In [15]:
import numpy as np
import pandas as pd
import torch
import os
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.models import swin_t, Swin_T_Weights
from datasets import load_dataset
from PIL import Image

In [35]:
# Root of the pre-extracted GDC Atlas patches. Set the GDC_PATCHES_DIR
# environment variable to run on a machine where the lab share is mounted
# elsewhere; the default is the original lab path.
PATCHES_DIR = os.environ.get(
    'GDC_PATCHES_DIR',
    '/GILMLab/GILMLabProjects/DeepLearning/deepquantification/data/GDCAtlas-Data/patches',
)
# Each split directory is read by GDC, which expects a metadata.csv inside it.
train_dir = os.path.join(PATCHES_DIR, 'train')
test_dir = os.path.join(PATCHES_DIR, 'test')

In [28]:
# Per-sample preprocessing: convert the PIL image to a CHW float tensor, then
# apply (x - 0.5) / 0.5 to each of the three channels. For 8-bit images
# ToTensor yields values in [0, 1], so the result lands in [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

class GDC(Dataset):
    """Image dataset for one split of the GDC patch data.

    ``src_dir`` must contain a ``metadata.csv`` with a ``file_name`` column
    (image paths relative to ``src_dir``) and a ``label`` column. Images are
    opened lazily, one per ``__getitem__`` call.

    Args:
        src_dir: Directory holding ``metadata.csv`` and the image files.
        transforms: Optional callable applied to each opened PIL image
            (e.g. a torchvision ``Compose``). The name shadows the
            ``torchvision.transforms`` module inside ``__init__`` only; it is
            kept for compatibility with keyword callers.
    """

    def __init__(self, src_dir, transforms=None):
        # Keep the split directory: __getitem__ needs it to resolve the
        # relative file names (it previously read an undefined global
        # `src_dir` and raised NameError on the first sample).
        self.src_dir = src_dir
        metadata = pd.read_csv(os.path.join(src_dir, 'metadata.csv'))
        self.images = metadata['file_name'].values
        self.labels = metadata['label'].values
        self.transforms = transforms

    def __len__(self):
        """Number of samples, i.e. rows in metadata.csv."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for sample ``idx``.

        With ``transforms`` set, ``image_tensor`` is whatever the transform
        returns. Without it, the raw pixels are cast to float32 and given a
        leading axis of size 1 (presumably meant for single-channel images —
        TODO confirm; a colour image would come out as (1, H, W, C)).
        """
        path = os.path.join(self.src_dir, self.images[idx])
        label = self.labels[idx]
        # Use a context manager so the file handle is released once pixels
        # are materialised. The transform runs inside the block because PIL
        # reads pixel data lazily from the open file.
        with Image.open(path) as image:
            if self.transforms is not None:
                img_tensor = self.transforms(image)
            else:
                img_tensor = torch.from_numpy(np.float32(image)).unsqueeze(0)
        return img_tensor, label

In [36]:
# Samples per mini-batch for both splits.
BATCH_SIZE = 100

train_dataset = GDC(train_dir, transform)
# shuffle=True makes the DataLoader reshuffle the training set every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = GDC(test_dir, transform)
# Evaluation gains nothing from shuffling; a fixed order keeps per-sample
# predictions reproducible and aligned with the rows of test metadata.csv.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)