## Audio2Image

Run this notebook to train the model

In [19]:
!pip install PySoundFile



In [70]:
import torch
import torchaudio
import torchvision
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

### Step 1: Create Dataset and DataLoader

In [71]:
IMG_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.ToTensor()
])

In [72]:
NEW_COLUMN_NAMES = {
    '---g-f_I2yQ': 'youtube_video_id',
    '1': 'start_seconds',
    'people marching': 'label',
    'test': 'split',
}

In [119]:
class AudioDataset(Dataset):
    def __init__(self, csv_file, audio_dir, img_dir, img_transform=None, embeddings=None):
        self.audio_dir = audio_dir
        self.img_dir = img_dir
        self.img_transform = img_transform
        self.embeddings = embeddings
        self.df = pd.read_csv(csv_file)
        self.rename_columns()
        self.add_columns()
        self.remove_invalid_rows()

    @staticmethod
    def check_validity(image_path):
        try:
            Image.open(image_path)
            return True
        except:
            return False

    def remove_invalid_rows(self):
        self.df['is_valid'] = self.df['img_path'].apply(AudioDataset.check_validity)
        self.df = self.df[self.df['is_valid'] == True]
        self.df = self.df.drop(columns=['is_valid'])

    def rename_columns(self):
        self.df.rename(columns=NEW_COLUMN_NAMES, inplace=True)

    def add_columns(self):
        self.df['audio_path'] = self.df['youtube_video_id'].apply(
            lambda x: self.audio_dir + '/' + 'audio_' + x + '.wav')
        self.df['img_path'] = self.df['youtube_video_id'].apply(lambda x: self.img_dir + '/' + 'image_' + x + '.jpg')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        print(f"Index is {idx}")
        if idx >= len(self.df):
            return None, None
        
        if torch.is_tensor(idx):
            idx = idx.tolist()
        else:
            idx = [idx]
        audio_path = self.df.loc[idx, 'audio_path'].values[0]
        waveform, sample_rate = torchaudio.load(audio_path, normalize=True)
        transform = torchaudio.transforms.Resample(sample_rate, sample_rate / 10)
        waveform = transform(waveform)

        img_path = self.df.loc[idx, 'img_path'].values[0]
        img = Image.open(img_path)

        if self.img_transform is not None:
            img = self.img_transform(img)

        if self.embeddings is not None:
            waveform = self.embeddings(waveform)

        return waveform, img

In [120]:
CSV_FILE = './vggsound.csv'
AUDIO_DIR = './data/audio'
IMG_DIR = './data/image'

In [121]:
audio2image_dataset = AudioDataset(CSV_FILE, AUDIO_DIR, IMG_DIR, IMG_TRANSFORM)

In [122]:
BATCH_SIZE = 1

In [123]:
audio2image_dataloader = torch.utils.data.DataLoader(audio2image_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [129]:
for wave, image in audio2image_dataloader:
    print(wave)
    print(image)

Index is 0
tensor([[[-0.0183, -0.0491, -0.0259,  ...,  0.0666,  0.0817,  0.0243],
         [-0.0183, -0.0491, -0.0259,  ...,  0.0666,  0.0817,  0.0243]]])
tensor([[[[0.2353, 0.2314, 0.2353,  ..., 0.3020, 0.3176, 0.3176],
          [0.2510, 0.2471, 0.2431,  ..., 0.3059, 0.3216, 0.3216],
          [0.2627, 0.2588, 0.2588,  ..., 0.3176, 0.3294, 0.3294],
          ...,
          [0.1608, 0.1608, 0.1608,  ..., 0.0745, 0.0784, 0.0863],
          [0.1608, 0.1608, 0.1608,  ..., 0.0706, 0.0745, 0.0863],
          [0.1647, 0.1647, 0.1647,  ..., 0.0706, 0.0745, 0.0824]],

         [[0.2314, 0.2275, 0.2314,  ..., 0.2510, 0.2667, 0.2667],
          [0.2510, 0.2471, 0.2431,  ..., 0.2588, 0.2745, 0.2745],
          [0.2627, 0.2588, 0.2588,  ..., 0.2745, 0.2863, 0.2863],
          ...,
          [0.1804, 0.1804, 0.1804,  ..., 0.0706, 0.0745, 0.0824],
          [0.1804, 0.1804, 0.1804,  ..., 0.0667, 0.0706, 0.0824],
          [0.1843, 0.1843, 0.1843,  ..., 0.0667, 0.0706, 0.0784]],

         [[0.2118, 

KeyError: "None of [Index([2], dtype='int32')] are in the [index]"