In [None]:
import random

import numpy as np
import pandas as pd
import torch
from PIL import Image
from specaugment import spec_augment
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T

In [None]:
# Per-channel mean / std passed to T.Normalize (the standard ImageNet values).
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Stage 1 (runs before augmentation): resize the image and convert it to a tensor.
# With an int argument, T.Resize scales the shorter edge to 224 and keeps the
# aspect ratio, so the output is only square when the input is square.
transform1 = T.Compose([
    T.Resize(224),
    T.ToTensor(),
])


def _to_three_channels(x):
    """Return ``x`` replicated to 3 channels, unless it already has 3 channels.

    A (3, H, W) tensor is passed through untouched: repeating it would yield
    9 channels, which the 3-channel Normalize step below cannot handle.
    Every other input is tiled exactly as before with ``x.repeat(3, 1, 1)``.
    """
    if x.dim() == 3 and x.shape[0] == 3:
        return x
    return x.repeat(3, 1, 1)


# Stage 2 (runs after augmentation): expand to 3 channels, then normalize
# each channel with the statistics above.
transform2 = T.Compose([
    T.Lambda(_to_three_channels),
    T.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
])

In [None]:
class get_dataset(Dataset):
    """Image dataset with optional training-time augmentation.

    Each item is the image stored at ``file_path + fname`` (presumably a mel
    spectrogram rendered as an image, given the SpecAugment step) together
    with its label.

    Args:
        file_path: String prefix prepended to every file name. The two are
            joined by plain concatenation, so include any trailing separator.
        labels: pandas DataFrame with a ``"fname"`` column; the label is read
            from the column at position 1.
        for_training: When true, the random augmentations are enabled.
        transform1: Callable applied to the opened PIL image before
            augmentation. Skipped when ``None``.
        transform2: Callable applied after augmentation. Skipped when ``None``.
        num_classes: Length of the one-hot targets built by ``mixup``
            (defaults to 7, the value that used to be hard-coded).
    """

    def __init__(self, file_path, labels, for_training, transform1=None, transform2=None,
                 num_classes=7):
        self.df = labels
        self.file_path = file_path
        self.for_training = for_training
        self.transform1 = transform1
        self.transform2 = transform2
        self.num_classes = num_classes

    def __len__(self):
        """Return the number of rows in the labels frame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at position ``index``.

        Pipeline: load + ``transform1`` -> ``SpecAugment`` -> ``transform2``.
        ``mixup`` is not part of this pipeline; call it explicitly if wanted.
        """
        label = self.df.iloc[index, 1]
        file_data = self._load_image(index)
        file_data = self.SpecAugment(file_data)
        if self.transform2 is not None:
            file_data = self.transform2(file_data)
        return file_data, label

    def _image_path(self, index):
        """Build the file path for the row at *position* ``index``.

        Uses ``.iloc`` so the file name always comes from the same row as the
        label, even when the frame's index is not 0..n-1 (e.g. after a
        train/validation split without ``reset_index``).
        """
        return self.file_path + str(self.df["fname"].iloc[index])

    def _load_image(self, index):
        """Open the image for row ``index`` and apply ``transform1`` if set."""
        with Image.open(self._image_path(index)) as img:
            # Decode now: the underlying file is closed when the block exits.
            img.load()
            if self.transform1 is not None:
                return self.transform1(img)
            return img

    # SpecAugment: apply image warping, frequency mask and time mask to images.
    # Is applied 20% of the time, and only when for_training is true.
    def SpecAugment(self, data):
        """Return ``data``, passed through ``spec_augment`` with probability 0.2."""
        if self.for_training and random.random() < 0.2:
            data = spec_augment(mel_spectrogram=data, time_warping_para=3,
                                time_masking_para=5, frequency_masking_para=8)
        return data

    # Mix up: take a random image from the dataset, choose a random
    # interpolation value and merge the images.
    # Is applied 20% of the time, and only when for_training is true.
    def mixup(self, img, label):
        """Blend ``img`` with a randomly chosen dataset image.

        Args:
            img: Image already processed by ``transform1``.
            label: Integer class index of ``img``.

        Returns:
            ``(img, label)``. When mixing is applied, ``img`` is the convex
            combination of both images and ``label`` is the argmax of the
            equally blended one-hot targets, i.e. still a single hard class
            index. Otherwise the inputs are returned unchanged.

        NOTE(review): both images must have the same shape for the blend to
        work, which transform1 may not guarantee -- confirm.
        """
        if not self.for_training or random.random() >= 0.2:
            return img, label

        random_index = random.randint(0, len(self) - 1)
        label2 = self.df.iloc[random_index, 1]
        img2 = self._load_image(random_index)

        target1 = torch.zeros(self.num_classes)
        target2 = torch.zeros(self.num_classes)
        target1[int(label)] = 1
        target2[int(label2)] = 1

        # Mixup the images accordingly; Beta(1, 1) is uniform on [0, 1].
        alpha = 1.0
        beta = 1.0
        # Plain float so the products below stay torch tensors.
        lam = float(np.random.beta(alpha, beta))

        img = lam * img + (1 - lam) * img2
        label = lam * target1 + (1 - lam) * target2
        label = label.argmax().item()
        return img, label

In [None]:
def train_dataloader(file_path, label_path, for_training, batch_size):
    """Build a shuffling DataLoader over a ``get_dataset`` instance.

    Args:
        file_path: Prefix prepended to each image file name.
        label_path: Forwarded as the ``labels`` argument of ``get_dataset``.
            Despite the name, the dataset uses it directly as a DataFrame
            rather than reading it from disk.
        for_training: Forwarded to ``get_dataset``; enables augmentation.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` with ``shuffle=True`` that uses the module-level
        ``transform1`` and ``transform2``.
    """
    return DataLoader(
        get_dataset(
            file_path,
            label_path,
            for_training,
            transform1=transform1,
            transform2=transform2,
        ),
        batch_size=batch_size,
        shuffle=True,
    )

In [None]:
def val_dataloader(file_path, label_path, for_training, batch_size):
    """Build a non-shuffling DataLoader over a ``get_dataset`` instance.

    Args:
        file_path: Prefix prepended to each image file name.
        label_path: Forwarded as the ``labels`` argument of ``get_dataset``.
            Despite the name, the dataset uses it directly as a DataFrame
            rather than reading it from disk.
        for_training: Forwarded to ``get_dataset``; pass a false value to
            disable the random augmentation.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` with ``shuffle=False`` that uses the module-level
        ``transform1`` and ``transform2``.
    """
    return DataLoader(
        get_dataset(
            file_path,
            label_path,
            for_training,
            transform1=transform1,
            transform2=transform2,
        ),
        batch_size=batch_size,
        shuffle=False,
    )