In [15]:
import torch
from torch.utils.data import DataLoader, Dataset
import torchaudio
from transformers import AutoFeatureExtractor, ASTForAudioClassification
from glob import glob
from collections import defaultdict
from tqdm import tqdm

In [16]:
# Define dataset class
class AudioDataset(Dataset):
    def __init__(self, file_paths, feature_extractor, target_sampling_rate=16000):
        self.file_paths = file_paths
        self.feature_extractor = feature_extractor
        self.target_sampling_rate = target_sampling_rate

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        data, sr = torchaudio.load(file_path)
        data = torchaudio.functional.resample(data, orig_freq=sr, new_freq=self.target_sampling_rate)
        data = data.squeeze()
        inputs = self.feature_extractor(data, sampling_rate=self.target_sampling_rate, return_tensors="pt")
        inputs['input_values'] = inputs['input_values'].squeeze(0)  # Remove batch dimension
        return inputs['input_values'], file_path


In [73]:
# Inference function
def inference_batch(model, dataloader, k=5, with_logit=False):
    d = defaultdict(int)
    counts = 0
    tqdm_bar = tqdm(dataloader)

    with open("/root/asset/test_only_speech_list_k5.txt", "w") as tf:
        for batch, paths in tqdm_bar:
            batch = batch.to('cuda:0')
            with torch.no_grad():
                outputs = model(input_values=batch).logits

            for i, logits in enumerate(outputs):
                logits = logits.squeeze()
                predicted_class_ids = torch.argsort(logits)[-k:]
                predicted_labels = [model.config.id2label[_id.item()] for _id in predicted_class_ids]

                for label in predicted_labels:
                    d[label] += 1

                #list_of_lists = [str(tensor.tolist()) for tensor in sorted(logits)[-k:]]
                sorted_indices = torch.argsort(logits)

                # 0의 정렬된 인덱스에서의 위치를 찾음
                sorted_position = (sorted_indices == 0).nonzero(as_tuple=True)[0].item()

                # 뒤에서 몇 번째인지 계산
                reverse_position = len(logits) - sorted_position - 1
                zero_logit_value = logits[0].item()

                if reverse_position > 2 and zero_logit_value <= -2.5:
                    tf.write(paths[i] + " --> " + str(reverse_position) + " --> " + str(zero_logit_value) + "\n")
                    counts += 1
                    

            tqdm_bar.set_postfix(only_speech=d)

    return d, counts

In [74]:
# Load model and feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
model = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
model = model.to('cuda:0')
model.eval()

# Parameters
batch_size = 32  # Adjust batch size according to your GPU memory
file_paths = glob("/root/data/test/*.ogg")
dataset = AudioDataset(file_paths, feature_extractor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)


# Run inference
d, counts = inference_batch(model, dataloader, k=5, with_logit=True)
print("Finished processing. Total non-speech files:", counts)

100%|██████████| 1563/1563 [18:44<00:00,  1.39it/s, only_speech=defaultdict(<class 'int'>, {'Breaking': 903, 'Bang': 79, 'Burst, pop': 546, 'Explosion': 930, 'Speech': 44871, 'Narration, monologue': 8689, 'Female speech, woman speaking': 7520, 'Inside, small room': 7502, 'Animal': 8360, 'Sliding door': 366, 'Door': 1144, 'Stomach rumble': 237, 'Water': 2345, 'Knock': 1585, 'Slam': 235, 'Coin (dropping)': 416, 'Typing': 1187, 'Computer keyboard': 1384, 'Chop': 1063, 'Bouncing': 181, 'Typewriter': 543, 'Scissors': 262, 'Pig': 624, 'Oink': 1019, 'Grunt': 365, 'Music': 10563, 'Helicopter': 670, 'Vehicle': 8696, 'Speech synthesizer': 3380, 'Conversation': 6158, 'Tap': 930, 'Male speech, man speaking': 2607, 'Applause': 720, 'Clapping': 821, 'Tick': 1268, 'Tick-tock': 1097, 'Fill (with liquid)': 765, 'Liquid': 1901, 'Toilet flush': 943, 'Tools': 1023, 'Power tool': 588, 'Wood': 1363, 'Chainsaw': 703, 'Sigh': 554, 'Breathing': 469, 'Gasp': 1434, 'Snort': 1087, 'Chink, clink': 1430, 'Crack': 6

Finished processing. Total non-speech files: 6305





In [75]:
import pandas as pd

file_ids = []

with open("/root/asset/test_only_speech_list_k5.txt", "r") as file:
    for line in file:
        file_path = line.split(' ')[0]
        file_id = file_path.split('/')[-1].split('.')[0]
        file_ids.append(file_id.strip("\n"))

# Create a DataFrame
df = pd.DataFrame(file_ids, columns=['id'])

# Save to CSV
csv_path = "./infer_masking.csv"
df.to_csv(csv_path, index=False)

print(f"CSV file saved to {csv_path}")


CSV file saved to ./infer_masking.csv


In [76]:
import pandas as pd

# CSV 파일 읽기
df1 = pd.read_csv('./infer_masking.csv')
df2 = pd.read_csv('/root/asset/ex7_19/new_masking.csv')

# 'id' 열에 strip() 적용
df1['id'] = df1['id'].str.strip()
df2['id'] = df2['id'].str.strip()

# 'id' 열의 교집합 추출
common_ids = pd.merge(df1, df2, on='id')

# 교집합에 속하지 않는 'id' 찾기
df1_ids = set(df1['id'])
common_ids_set = set(common_ids['id'])
non_common_ids = df1_ids - common_ids_set

# 교집합에 속하지 않는 행 출력
for idx, row in df1.iterrows():
    if row['id'] in non_common_ids:
        print(row)

print(f"df1 : {len(df1)}, df2 : {len(df2)}, merge : {len(common_ids)}")


id    TEST_00191
Name: 21, dtype: object
id    TEST_00281
Name: 33, dtype: object
id    TEST_00351
Name: 40, dtype: object
id    TEST_00356
Name: 42, dtype: object
id    TEST_00575
Name: 83, dtype: object
id    TEST_00777
Name: 109, dtype: object
id    TEST_01004
Name: 143, dtype: object
id    TEST_01117
Name: 151, dtype: object
id    TEST_01145
Name: 159, dtype: object
id    TEST_01301
Name: 177, dtype: object
id    TEST_01815
Name: 227, dtype: object
id    TEST_02359
Name: 307, dtype: object
id    TEST_02549
Name: 331, dtype: object
id    TEST_02637
Name: 341, dtype: object
id    TEST_02740
Name: 357, dtype: object
id    TEST_02950
Name: 379, dtype: object
id    TEST_03045
Name: 388, dtype: object
id    TEST_03717
Name: 468, dtype: object
id    TEST_03781
Name: 475, dtype: object
id    TEST_04340
Name: 539, dtype: object
id    TEST_04425
Name: 552, dtype: object
id    TEST_04969
Name: 638, dtype: object
id    TEST_05001
Name: 640, dtype: object
id    TEST_05190
Name: 669, dtype: obje