In [23]:
import os
import glob
import numpy as np
import torch 
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as T


In [32]:
dataset_path = "../Data/SubURMP64/images/clean"
train_folder = "train"

image_path = os.path.join(dataset_path, train_folder)
embedding_path = "../Embeddings/audio/subURMPClean/train/all_embeddings_pca.pt"
names_path = "../Embeddings/audio/subURMPClean/train/all_file_names.csv"
img_size = 64
batch_size = 2

In [27]:
class CustomDataset(Dataset):
    def __init__(self, image_path, name_path, embedding_path, image_size, transform=None):
        self.image_path = image_path
        self.transfrom = transform
        self.name_path = name_path
        self.image_dim = (image_size, image_size)

        self.embedding_array = torch.load(embedding_path)
        self.embedding_names = np.loadtxt(name_path, delimiter=',', dtype=str)

        self.file_list = glob.glob(self.image_path + "*")
        self.data = []
        for class_path in self.file_list:
            for img_pth in glob.glob(class_path + "/*.jpg"):
                self.data.append(img_pth)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_path = self.data[idx]
        with Image.open(img_path) as image:
            image.Load()
        image = image.convert("RGB")
        img_path = img_path.split("/")[-1]
        embedding_index = np.where(self.embedding_names == img_path)
        embedding = self.embedding_array[embedding_index]
        embedding = embedding.squeeze()

        if self.transform:
            image = self.transform(image)

        return image, embedding

In [38]:
os.listdir(os.path.join(dataset_path, train_folder, "cello"))

['cello08_53800.jpg',
 'cello05_13200.jpg',
 'cello09_196900.jpg',
 'cello03_9300.jpg',
 'cello05_17700.jpg',
 'cello00_24500.jpg',
 'cello02_86900.jpg',
 'cello09_221900.jpg',
 'cello03_125600.jpg',
 'cello08_8800.jpg',
 'cello00_20000.jpg',
 'cello02_9300.jpg',
 'cello03_121300.jpg',
 'cello02_39300.jpg',
 'cello07_157200.jpg',
 'cello03_25400.jpg',
 'cello07_153700.jpg',
 'cello00_56800.jpg',
 'cello03_21100.jpg',
 'cello09_74600.jpg',
 'cello08_25000.jpg',
 'cello07_43500.jpg',
 'cello06_12300.jpg',
 'cello09_39700.jpg',
 'cello08_68100.jpg',
 'cello07_47000.jpg',
 'cello06_16600.jpg',
 'cello09_70300.jpg',
 'cello08_21500.jpg',
 'cello09_129300.jpg',
 'cello05_23300.jpg',
 'cello08_63900.jpg',
 'cello02_106300.jpg',
 'cello05_27600.jpg',
 'cello02_102600.jpg',
 'cello00_59500.jpg',
 'cello03_115700.jpg',
 'cello09_211800.jpg',
 'cello00_14400.jpg',
 'cello03_111200.jpg',
 'cello03_67800.jpg',
 'cello00_10100.jpg',
 'cello02_44300.jpg',
 'cello03_15500.jpg',
 'cello03_91600.jpg',
 

In [30]:
train_dataset = CustomDataset(image_path=image_path, name_path=names_path, embedding_path=embedding_path, image_size=img_size)


In [33]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  # num_workers=args.num_workers
                                  )

ValueError: num_samples should be a positive integer value, but got num_samples=0

In [None]:
def get_data(args):
    train_transforms = torchvision.transforms.Compose([
        T.Resize(args.img_size + int(.25*args.img_size)),  # args.img_size + 1/4 *args.img_size
        T.RandomResizedCrop(args.img_size, scale=(0.8, 1.0)),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    val_transforms = torchvision.transforms.Compose([
        T.Resize(args.img_size),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_dataset = CustomDataset(args, os.path.join(args.dataset_path, args.train_folder), transform=train_transforms)
    val_dataset = CustomDataset(args, os.path.join(args.dataset_path, args.val_folder), transform=val_transforms)
    
    # if args.slice_size>1:
    #     train_dataset = torch.utils.data.Subset(train_dataset, indices=range(0, len(train_dataset), args.slice_size))
    #     val_dataset = torch.utils.data.Subset(val_dataset, indices=range(0, len(val_dataset), args.slice_size))

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  # num_workers=args.num_workers
                                  )
    val_dataset = DataLoader(val_dataset, batch_size=2*args.batch_size, shuffle=False, 
                             #num_workers=args.num_workers
                             )
    return train_dataloader, val_dataset