In [15]:
import torch
from torch.utils.data import DataLoader, Dataset
import torchaudio
from transformers import AutoFeatureExtractor, ASTForAudioClassification
from glob import glob
from collections import defaultdict
from tqdm import tqdm

In [16]:
# Define dataset class
class AudioDataset(Dataset):
    def __init__(self, file_paths, feature_extractor, target_sampling_rate=16000):
        self.file_paths = file_paths
        self.feature_extractor = feature_extractor
        self.target_sampling_rate = target_sampling_rate

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        data, sr = torchaudio.load(file_path)
        data = torchaudio.functional.resample(data, orig_freq=sr, new_freq=self.target_sampling_rate)
        data = data.squeeze()
        inputs = self.feature_extractor(data, sampling_rate=self.target_sampling_rate, return_tensors="pt")
        inputs['input_values'] = inputs['input_values'].squeeze(0)  # Remove batch dimension
        return inputs['input_values'], file_path


In [44]:
# Inference function
def inference_batch(model, dataloader, k=5, with_logit=False):
    d = defaultdict(int)
    counts = 0
    tqdm_bar = tqdm(dataloader)

    with open("/root/asset/test_only_speech_list_k5.txt", "w") as tf:
        for batch, paths in tqdm_bar:
            batch = batch.to('cuda:0')
            with torch.no_grad():
                outputs = model(input_values=batch).logits

            for i, logits in enumerate(outputs):
                logits = logits.squeeze()
                predicted_class_ids = torch.argsort(logits)[-k:]
                predicted_labels = [model.config.id2label[_id.item()] for _id in predicted_class_ids]

                for label in predicted_labels:
                    d[label] += 1

                #list_of_lists = [str(tensor.tolist()) for tensor in sorted(logits)[-k:]]
                sorted_indices = torch.argsort(logits)

                # 0의 정렬된 인덱스에서의 위치를 찾음
                sorted_position = (sorted_indices == 0).nonzero(as_tuple=True)[0].item()

                # 뒤에서 몇 번째인지 계산
                reverse_position = len(logits) - sorted_position - 1
                zero_logit_value = logits[0].item()

                tf.write(paths[i] + " --> " + str(reverse_position) + " --> " + str(zero_logit_value) + "\n")
                counts += 1
                    

            tqdm_bar.set_postfix(only_speech=d)

    return d, counts

In [45]:
# Load model and feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
model = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
model = model.to('cuda:0')
model.eval()

# Parameters
batch_size = 32  # Adjust batch size according to your GPU memory
file_paths = glob("/root/data/test/*.ogg")
dataset = AudioDataset(file_paths, feature_extractor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)


# Run inference
d, counts = inference_batch(model, dataloader, k=5, with_logit=True)
print("Finished processing. Total non-speech files:", counts)

  5%|▌         | 82/1563 [00:58<17:44,  1.39it/s, only_speech=defaultdict(<class 'int'>, {'Breaking': 60, 'Bang': 5, 'Burst, pop': 31, 'Explosion': 48, 'Speech': 2348, 'Narration, monologue': 493, 'Female speech, woman speaking': 411, 'Inside, small room': 374, 'Animal': 419, 'Sliding door': 18, 'Door': 63, 'Stomach rumble': 12, 'Water': 110, 'Knock': 80, 'Slam': 16, 'Coin (dropping)': 24, 'Typing': 65, 'Computer keyboard': 76, 'Chop': 61, 'Bouncing': 18, 'Typewriter': 28, 'Scissors': 21, 'Pig': 36, 'Oink': 55, 'Grunt': 21, 'Music': 550, 'Helicopter': 31, 'Vehicle': 481, 'Speech synthesizer': 173, 'Conversation': 329, 'Tap': 44, 'Male speech, man speaking': 150, 'Applause': 54, 'Clapping': 61, 'Tick': 76, 'Tick-tock': 71, 'Fill (with liquid)': 41, 'Liquid': 90, 'Toilet flush': 43, 'Tools': 54, 'Power tool': 36, 'Wood': 77, 'Chainsaw': 38, 'Sigh': 29, 'Breathing': 27, 'Gasp': 74, 'Snort': 63, 'Chink, clink': 96, 'Crack': 43, 'Cap gun': 16, 'Sound effect': 126, 'Clock': 34, 'Cattle, bovi

KeyboardInterrupt: 