## Audio2Image

Run this notebook to train the Audio2Image model.

In [2]:
import os
!pip install PySoundFile



In [3]:
import torch
import torchaudio
import torchvision
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

### Step 1: Create Dataset and DataLoader

In [4]:
# Image preprocessing pipeline: resize every image to 256x256, then convert
# the PIL image to a tensor.
IMG_TRANSFORM = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size=(256, 256)),
        torchvision.transforms.ToTensor(),
    ]
)

In [5]:
# Mapping from the column names pandas assigns when reading the CSV to
# descriptive names.
# NOTE(review): the keys look like the values of the first CSV row (i.e. the
# file presumably has no header line) -- confirm against vggsound.csv.
NEW_COLUMN_NAMES = dict(
    zip(
        ('---g-f_I2yQ', '1', 'people marching', 'test'),
        ('youtube_video_id', 'start_seconds', 'label', 'split'),
    )
)

In [39]:
class AudioDataset(Dataset):
    """Dataset of (audio, image, label) samples indexed by a CSV file.

    Each CSV row names a YouTube video id; the audio clip and image for that
    id are looked up as ``<audio_dir>/audio_<id>.wav`` and
    ``<img_dir>/image_<id>.jpg``. Rows whose image cannot be opened are
    dropped at construction time.

    NOTE(review): ``pd.read_csv`` is called with its default header handling
    and the columns are then renamed via ``NEW_COLUMN_NAMES``, whose keys look
    like the first data row of the CSV. If the file really has no header line,
    that first sample is consumed as the header -- confirm.

    Args:
        csv_file: Path of the CSV index file.
        audio_dir: Directory containing the ``audio_<id>.wav`` files.
        img_dir: Directory containing the ``image_<id>.jpg`` files.
        img_transform: Optional callable applied to the loaded PIL image.
        embeddings: Optional callable applied to the resampled waveform.
    """

    def __init__(self, csv_file, audio_dir, img_dir, img_transform=None, embeddings=None):
        self.audio_dir = audio_dir
        self.img_dir = img_dir
        self.img_transform = img_transform
        self.embeddings = embeddings
        self.df = pd.read_csv(csv_file)
        self.rename_columns()
        self.add_columns()
        self.remove_invalid_rows()

    @staticmethod
    def check_validity(image_path):
        """Return True if ``image_path`` can be opened as an image, else False."""
        try:
            # The context manager closes the file handle again; a bare
            # Image.open() would keep one handle open per CSV row.
            with Image.open(image_path):
                return True
        except Exception:
            # Best-effort check: any failure means "not usable", but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            return False

    def remove_invalid_rows(self):
        """Drop rows whose image cannot be opened and renumber the rest."""
        # astype(bool) keeps the mask boolean even for an empty frame, where
        # apply() would otherwise yield an object-dtype Series.
        valid_mask = self.df['img_path'].apply(AudioDataset.check_validity).astype(bool)
        # Reset the index so row labels are 0..len-1 again; samplers hand
        # __getitem__ positions in range(len(self)), not original CSV labels.
        self.df = self.df.loc[valid_mask].reset_index(drop=True)

    def rename_columns(self):
        """Rename the CSV columns according to ``NEW_COLUMN_NAMES``."""
        self.df.rename(columns=NEW_COLUMN_NAMES, inplace=True)

    def add_columns(self):
        """Add ``audio_path`` and ``img_path`` columns derived from the video id."""
        video_ids = self.df['youtube_video_id']
        self.df['audio_path'] = video_ids.apply(
            lambda video_id: self.audio_dir + '/' + 'audio_' + video_id + '.wav')
        self.df['img_path'] = video_ids.apply(
            lambda video_id: self.img_dir + '/' + 'image_' + video_id + '.jpg')

    def __len__(self):
        """Return the number of usable rows."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one sample by position.

        Args:
            idx: Integer position, or a tensor/list holding one (only the
                first element of a sequence is used).

        Returns:
            ``(waveform, image, label)`` where ``waveform`` is the resampled
            audio averaged over dim 0 (after ``embeddings``, if given) and
            ``image`` is the output of ``img_transform`` (or the PIL image if
            no transform was given). If anything fails while loading, a
            warning is issued and ``(None, None, None)`` is returned.
        """
        print(f"Index = {idx}")
        try:
            # Normalise tensors (0-dim or 1-element) and lists to a plain int.
            if torch.is_tensor(idx):
                idx = idx.tolist()
            if isinstance(idx, (list, tuple)):
                idx = idx[0]
            # Positional lookup: valid for every idx in range(len(self)) even
            # though rows were filtered at construction time.
            row = self.df.iloc[idx]

            waveform, sample_rate = torchaudio.load(row['audio_path'], normalize=True)
            # Downsample by a factor of 10. Integer division keeps the target
            # rate an int, which is what Resample is documented to take.
            resample = torchaudio.transforms.Resample(sample_rate, sample_rate // 10)
            waveform = resample(waveform)

            label = row['label']

            # Read the pixel data while the file is open, then release the
            # handle; the loaded image stays usable after the with-block.
            with Image.open(row['img_path']) as img:
                img.load()

            if self.img_transform is not None:
                img = self.img_transform(img)

            if self.embeddings is not None:
                waveform = self.embeddings(waveform)

            return waveform.mean(dim=0), img, label
        except Exception as err:
            # Keep the best-effort contract (callers get a None triple), but
            # make the failure visible instead of silently hiding it.
            import warnings
            warnings.warn(f"AudioDataset: failed to load sample {idx!r}: {err!r}")
            return None, None, None

In [40]:
# Locations of the CSV index and of the audio / image files, relative to the
# working directory of the notebook.
CSV_FILE = "./vggsound.csv"
AUDIO_DIR = "./data/audio"
IMG_DIR = "./data/image"

In [41]:
# Build the dataset; no embeddings callable is supplied, so waveforms are
# returned without that extra step.
audio2image_dataset = AudioDataset(
    csv_file=CSV_FILE,
    audio_dir=AUDIO_DIR,
    img_dir=IMG_DIR,
    img_transform=IMG_TRANSFORM,
)

In [42]:
def custom_collate(batch):
    """Collate ``(audio, image, label)`` samples into one batch.

    Samples whose audio or image is ``None`` are skipped: the dataset's
    ``__getitem__`` returns ``(None, None, None)`` when a sample fails to
    load, and ``torch.concat`` would otherwise raise on the ``None`` entries.

    Args:
        batch: Sequence of ``(audio, image, label)`` tuples.

    Returns:
        A tuple ``(audios, images, labels)`` where ``audios`` and ``images``
        are the per-sample tensors concatenated along dim 0 and ``labels`` is
        a list. If every sample in the batch failed to load, two empty tensors
        and an empty list are returned.
    """
    usable = [
        sample for sample in batch
        if sample[0] is not None and sample[1] is not None
    ]
    if not usable:
        # Nothing loadable in this batch; callers can detect this via
        # numel() == 0 / an empty label list.
        return torch.empty(0), torch.empty(0), []

    audios, images, labels = zip(*usable)
    return (
        torch.concat(audios, dim=0),
        torch.concat(images, dim=0),
        list(labels),
    )

In [43]:
# Number of samples per batch passed to the DataLoader.
BATCH_SIZE = 1

In [44]:
# Pass the worker count to the constructor so DataLoader sees it while it
# validates and sets up its arguments. Windows (os.name == 'nt') stays on 0
# workers, and `or 0` covers os.cpu_count() returning None.
audio2image_dataloader = torch.utils.data.DataLoader(
    audio2image_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=custom_collate,
    num_workers=0 if os.name == 'nt' else (os.cpu_count() or 0),
)

In [45]:
# num_workers fails in Windows, so worker processes are only enabled elsewhere.
# NOTE(review): assigning num_workers after construction bypasses the
# constructor's argument handling -- confirm this is honoured by the installed
# PyTorch version; passing num_workers to DataLoader(...) is the documented way.
if os.name != 'nt':
    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 0 (load in the main process) instead of storing None.
    audio2image_dataloader.num_workers = os.cpu_count() or 0

In [46]:
# Pull two batches from the loader (presumably as a quick smoke test of the
# dataset and collate function). Each iteration rebinds waves/pics/texts, so
# only the second batch remains bound afterwards.
dataiter = iter(audio2image_dataloader)
for i in range(2):
    waves, pics, texts = next(dataiter)

Index = 0
Index = 3
