In [58]:
import glob
import os
import numpy as np

import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset


from PIL import Image

In [148]:
img_size = 64
num_workers = 0
batch_size = 2


embeddingPath = "Embeddings/audio/subURMPClean/train/all_embeddings.pt"
embeddingNamePath = "Embeddings/audio/subURMPClean/train/all_file_names.csv"
dataset_path = "Data/SubURMP64/images/clean"
train_folder = "trial"
val_folder = "trial"



In [164]:
class CustomDataset(Dataset):
    def __init__(self, image_path, embeddingPath, namePath, transform=None):
        

        self.image_path = image_path # Set image path e.g. trial for demonstration, train for application
        self.transform = transform
        file_list = glob.glob(self.image_path + "*")

        # Getting image names
        self.data = []
        for class_path in file_list:
            for img_path in glob.glob(class_path + "/*.jpg"):
                self.data.append(img_path)

        # Reading in embedding file and associated names
        self.embeddingArr = torch.load(embeddingPath)
        self.fileNames = np.loadtxt(namePath, delimiter=',', dtype=str)



        self.img_dim = (64, 64)

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_path = self.data[idx]  # Gets path to image

        with Image.open(img_path) as img:  # Loads image
            img.load()
        img = img.convert("RGB")  # Converts image to rgb
        img_path = img_path.split("/")[-1]
        #TODO use img_path to lookup embedding from embedding file created in __init__

        embeddingIdx = np.where(self.fileNames == img_path)  # Index for embeddings where it corresponds to the desired file name
        embedding = self.embeddingArr[embeddingIdx]  # Embeddings for associated index


        if self.transform:
            img = self.transform(img)

        return img, embedding  

In [165]:

# Any paramater in get_data() should be found in args in actual implementation
def get_data(img_size, dataset_path, embeddingPath, namePath, train_folder, val_folder, batch_size, num_workers):  # Defines dataloaders and transformations for data
    train_transforms = torchvision.transforms.Compose([
        T.Resize(img_size + int(.25*img_size)),  # args.img_size + 1/4 *args.img_size
        T.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    val_transforms = torchvision.transforms.Compose([
        T.Resize(img_size),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_dataset = CustomDataset(image_path=f"{dataset_path}/{train_folder}/", embeddingPath=embeddingPath, namePath=namePath,transform=train_transforms)
    val_dataset = CustomDataset(image_path=f"{dataset_path}/{val_folder}/", embeddingPath=embeddingPath, namePath=namePath,transform=val_transforms)
    

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=num_workers) #Defines the train dataloader
    val_dataset = DataLoader(val_dataset, batch_size=2*batch_size, shuffle=False, num_workers=num_workers)
    
    return train_dataloader, val_dataset

In [166]:
train_dataloader, val_dataloader = get_data(img_size, dataset_path, embeddingPath, embeddingNamePath, 
                                            train_folder, val_folder, batch_size, num_workers)

In [167]:
imgsList = []
labsList = []
embList = []
for imgs, labs, emb in train_dataloader:
    imgsList.append(imgs)
    labsList.append(labs)
    embList.append(emb)

labsList


[('bassoon00_22700.jpg', 'cello00_5800.jpg')]

### Next piece of the puzzle 

The next part of the source code that is relevant is the one_epoch method in ddpm_conditional. Specifically line 118. 