# Dependencies

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt

# Custom Dataset

In [None]:
class SimpleVeinDataset(Dataset):
    """Dataset of paired grayscale images and masks listed in a CSV file.

    Each CSV row names one sample through its ``ImageID`` column. The image is
    read from ``<image_dir>/<ImageID>.png`` and the mask from
    ``<mask_dir>/<ImageID>.jpg``.
    """

    def __init__(self, csv_file, image_dir, mask_dir):
        """
        Args:
            csv_file (str): Path to CSV with BatId and ImageID columns
            image_dir (str): Path to image folder
            mask_dir (str): Path to mask folder
        """
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.mask_dir = mask_dir

    def __len__(self):
        """Return the number of CSV rows, i.e. the number of samples."""
        return len(self.data)

    def _image_id(self, idx):
        """Return the ``ImageID`` value stored in row ``idx``.

        The column is selected before the row on purpose: ``iloc[idx]`` on the
        whole frame builds a row Series, which upcasts mixed numeric columns to
        a common dtype. An integer id sitting next to a float column would then
        come back as ``12.0`` and produce the filename ``12.0.png``.
        """
        return self.data['ImageID'].iloc[idx]

    @staticmethod
    def _load_grayscale(path):
        """Read the file at ``path`` as a grayscale tensor scaled by 1/255.

        Args:
            path (str): Path of the image file to read.

        Returns:
            torch.Tensor: Float tensor with a leading channel dimension of
            size 1 in front of the pixel grid.
        """
        # Copy the pixels out inside the context manager so the underlying
        # file is released before returning.
        with Image.open(path) as img:
            pixels = np.array(img.convert('L'))

        # Scale 8-bit values by 1/255, then add the channel dimension.
        return (torch.from_numpy(pixels).float() / 255.0).unsqueeze(0)

    def __getitem__(self, idx):
        """Load the image/mask pair for row ``idx``.

        Args:
            idx (int): Row position in the CSV.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(image, mask)``, each a
            single-channel float tensor scaled by 1/255.
        """
        image_id = self._image_id(idx)

        img_path = os.path.join(self.image_dir, f"{image_id}.png")
        mask_path = os.path.join(self.mask_dir, f"{image_id}.jpg")

        image = self._load_grayscale(img_path)
        # NOTE(review): masks are stored as JPEG, so after scaling they may
        # hold values other than exactly 0 and 1 - confirm whether downstream
        # code expects a thresholded mask.
        mask = self._load_grayscale(mask_path)

        return image, mask

# Data Loading

In [None]:
# Dataset backed by the CSV index plus the image and mask folders
dataset = SimpleVeinDataset(
    csv_file='../Dataset/dataset.csv',
    image_dir='../Dataset/Images',
    mask_dir='../Dataset/Masks',
)

# Loader used for the checks below: one sample per batch, in CSV order.
# TODO: once the training loop is built, raise batch_size to 2 and set shuffle=True.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [None]:
# Walk the whole dataset once and report every sample whose image or mask is
# not 1024x1024. Any error raised while the dataset loads a sample (for example
# a missing image or mask file) also surfaces during this pass.
for i, (image, mask) in enumerate(dataloader):
    image_hw = (image.size(2), image.size(3))
    mask_hw = (mask.size(2), mask.size(3))

    if image_hw != (1024, 1024):
        print(f"Image {i} has shape {image.shape}")
    if mask_hw != (1024, 1024):
        print(f"Mask {i} has shape {mask.shape}")

# Verifying Dataset

In [None]:
# Number of image/mask pairs to preview (one figure row per pair)
num_preview = 5

# Create dataloader with batch size 1 for easy visualization
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# One row per sample: image in the left column, mask in the right column
fig, axes = plt.subplots(num_preview, 2, figsize=(10, 20))

# zip() stops as soon as the axes rows run out, so no extra sample is read
# from disk just to be discarded.
for i, ((image_ax, mask_ax), (image, mask)) in enumerate(zip(axes, dataloader)):
    # Remove batch and channel dimensions and convert to numpy
    image = image.squeeze().numpy()
    mask = mask.squeeze().numpy()

    # Plot image
    image_ax.imshow(image, cmap='gray')
    image_ax.set_title(f'Image {i+1}')
    image_ax.axis('off')

    # Plot mask
    mask_ax.imshow(mask, cmap='gray')
    mask_ax.set_title(f'Mask {i+1}')
    mask_ax.axis('off')

# Lay out the figure after drawing so the spacing accounts for the titles
plt.tight_layout(pad=3.0)
plt.show()