<a href="https://colab.research.google.com/github/spate472/RecreatingRetinaUNet/blob/main/Attempt2_ToyDataPyHealth.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the notebook's runtime dependencies (IPython `!` shell escapes; run in Colab/Jupyter).
# NOTE(review): pandas is pinned to 2.2.2 here, presumably for compatibility with the
# installed pyhealth / Colab runtime — confirm before changing the pin.
!pip install pyhealth pydicom scikit-image matplotlib
!pip install pandas==2.2.2
!pip install torch torchvision

In [None]:
import pydicom
import pandas
import numpy
import pyhealth
import skimage

# Report the version of each key library, one "<module>: <version>" line apiece.
print("\n".join(
    f"{mod.__name__}: {mod.__version__}"
    for mod in (pydicom, pandas, numpy, pyhealth, skimage)
))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
from skimage.draw import disk
from pyhealth.datasets.base_dataset_v2 import BaseDataset
from pyhealth.data import Patient, Visit, Event
from torch.utils.data import DataLoader, random_split
import os
import json
import torchvision.transforms as transforms

# 1. Define the Custom Dataset (Extending BaseDataset)
class ToyImageDataset(BaseDataset):
    """Synthetic single-channel segmentation dataset wrapped in PyHealth objects.

    Ten patients are generated, each with one visit ("visit_0") that holds one
    image and its mask. Every image is a square canvas with a bright disk plus
    uniform noise. When ``task_type == 'shapes'`` and the randomly drawn label
    is 1, the disk is hollowed into a ring. The mask is 1.0 exactly where the
    shape (disk or ring) was drawn and 0.0 elsewhere.

    ``__getitem__`` returns ``{"image": (1, H, W) float tensor,
    "mask": (1, H, W) float tensor}``.

    Args:
        root: Directory for the dataset; created if it does not exist.
        task_type: ``'shapes'`` enables the ring variant for label-1 samples.
        dev: Device string stored on the instance (not used by this class).
        verbose: If True, print a message when the data is generated.
        dataset_name: Name forwarded to ``BaseDataset``.
        image_size: Side length of the square images (default 320).
    """

    def __init__(self, root, task_type='shapes', dev='cpu', verbose=True,
                 dataset_name="toy_images", image_size=320):
        self.task_type = task_type
        self.verbose = verbose
        self.dev = dev
        self.image_size = image_size
        self.data_dir = root
        self._generated = None  # cached (patient_list, samples) from process()
        self._arrays = {}       # (patient_id, visit_id) -> (image ndarray, mask ndarray)
        os.makedirs(self.data_dir, exist_ok=True)
        patient_list, sample_list = self.process()
        # Only root/dataset_name are forwarded. NOTE(review): depending on the
        # installed pyhealth version, BaseDataset.__init__ may forward unknown
        # kwargs to process() and/or call process() itself, so no extra
        # ``sample_list=`` kwarg is passed, and process() returns its cached
        # result on repeat calls so the data is generated only once.
        super().__init__(root=root, dataset_name=dataset_name)
        self.patient_list = patient_list
        self.sample_list = sample_list
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # (H, W) float32 ndarray -> (1, H, W) float tensor
            transforms.Grayscale(num_output_channels=1),  # keep a single channel
        ])

    def process(self, **kwargs):
        """Generate the toy images/masks and wrap them in PyHealth objects.

        The data is synthesized on the first call only; later calls return the
        cached result. Extra keyword arguments are accepted and ignored so the
        method is safe to invoke from ``BaseDataset``.

        Returns:
            tuple: ``(patient_list, samples)``, where ``patient_list`` holds one
            ``Patient`` per generated patient and ``samples`` is a list of
            ``(patient_id, visit_id)`` string tuples, one per sample.
        """
        cached = getattr(self, "_generated", None)
        if cached is not None:
            return cached

        if self.verbose:
            print("Generating Toy Dataset...")

        patient_list = []
        samples = []
        arrays = {}
        visit_id = "visit_0"

        for patient_id in range(10):  # small number of patients
            patient_str = str(patient_id)
            patient = Patient(patient_id=patient_str)
            visit = Visit(visit_id=visit_id, patient_id=patient_str)

            img, mask = self._generate_image_and_mask(random.randint(0, 1))  # binary label
            for code, array in (("image", img), ("mask", mask)):
                visit.add_event(Event(
                    visit_id=visit_id,
                    patient_id=patient_str,
                    code=code,
                    value=torch.tensor(array, dtype=torch.float32).unsqueeze(0),
                    vocabulary="imaging",
                ))
            patient.add_visit(visit)

            patient_list.append(patient)
            samples.append((patient_str, visit_id))
            # Direct O(1) lookup used by __getitem__, so item access does not
            # depend on the PyHealth event-query API.
            arrays[(patient_str, visit_id)] = (img, mask)

        self._arrays = arrays
        self._generated = (patient_list, samples)
        return self._generated

    def _generate_image_and_mask(self, label):
        """Draw one noisy image and its binary segmentation mask.

        Args:
            label: 0 gives a filled disk; 1 gives a ring when
                ``task_type == 'shapes'`` (otherwise also a filled disk).

        Returns:
            ``(img, mask)``, both float32 arrays of shape
            ``(image_size, image_size)``. ``img`` is 0.2 inside the shape plus
            uniform noise in [0, 0.05) everywhere. ``mask`` is 1.0 inside the
            shape and 0.0 elsewhere, without noise.
        """
        size = self.image_size
        shape = (size, size)
        r = max(1, size // 16)  # 20 px for the default 320 px image
        img = np.zeros(shape, dtype=np.float32)
        mask = np.zeros(shape, dtype=np.float32)

        # Centers in [2r, size - 2r] keep the whole disk inside the frame
        # (40..280 for the default size).
        x, y = random.randint(2 * r, size - 2 * r), random.randint(2 * r, size - 2 * r)
        rr, cc = disk((x, y), r, shape=shape)
        img[rr, cc] = 0.2  # shape intensity
        mask[rr, cc] = 1.0

        if self.task_type == 'shapes' and label == 1:
            rr_inner, cc_inner = disk((x, y), r // 2, shape=shape)
            img[rr_inner, cc_inner] -= 0.2  # hollow out the center -> ring
            mask[rr_inner, cc_inner] = 0.0

        img += np.random.uniform(0, 0.05, img.shape)  # noise on the image only
        return img, mask

    def __len__(self):
        """Number of samples (one per generated patient visit)."""
        return len(self.sample_list)

    def __getitem__(self, idx):
        """Return the image/mask pair for sample ``idx``.

        Args:
            idx: Index into ``self.sample_list``.

        Returns:
            dict with ``"image"`` and ``"mask"``, each a (1, H, W) float tensor.
            The mask is scaled by its maximum when that maximum is positive.
        """
        patient_id, visit_id = self.sample_list[idx]
        arrays = self._arrays.get((patient_id, visit_id))
        if arrays is None:
            print(f"Warning: Missing data for patient {patient_id}, visit {visit_id}")
            zeros = np.zeros((self.image_size, self.image_size), dtype=np.float32)
            arrays = (zeros, zeros)
        img, mask_arr = arrays

        # The stored values are ndarrays, which is what ToTensor accepts.
        image = self.transform(img)
        mask = self.transform(mask_arr)

        peak = mask.max()
        if peak > 0:  # avoid 0/0 -> NaN for an all-zero mask
            mask = mask / peak

        return {
            "image": image,
            "mask": mask,
        }

    def stat(self):
        """Placeholder statistics hook; prints a notice and returns None."""
        print("Stat method was called")
        return None

# 2. Model and training utilities (largely unchanged from the previous attempt)

# 2a. A simplified RetinaUNet-like model, adapted for the toy data
class RetinaUNet(nn.Module):
    """Simplified RetinaUNet-like segmentation network for the toy data.

    Encoder: two conv(3x3, padding 1) + ReLU + maxpool(2) stages
    (1 -> 16 -> 32 channels), which divide each spatial side by 4.
    Bottleneck: flatten -> Linear(flat, 64) -> ReLU -> Linear(64, flat),
    reshaped back to (32, image_size/4, image_size/4).
    Decoder: two stride-2 transposed convolutions (32 -> 16 -> 1 channels)
    that restore the input resolution, followed by a sigmoid.

    Args:
        image_size: Side length of the square single-channel input. Must be
            divisible by 4. Defaults to 320, the toy dataset's image size.

    Input ``(B, 1, image_size, image_size)``; output has the same shape and
    holds per-pixel probabilities in [0, 1].

    Raises:
        ValueError: if ``image_size`` is not divisible by 4.
    """

    def __init__(self, image_size=320):
        super().__init__()
        if image_size % 4 != 0:
            raise ValueError(f"image_size must be divisible by 4, got {image_size}")
        self.feat_size = image_size // 4  # spatial side after the two 2x2 pools
        flat = 32 * self.feat_size * self.feat_size

        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.deconv1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        # stride=2 doubles the side again so the output matches the input size
        # (with the default stride 1 a 320x320 input came out 161x161).
        self.deconv2 = nn.ConvTranspose2d(16, 1, kernel_size=2, stride=2)
        self.fc1 = nn.Linear(flat, 64)
        self.fc2 = nn.Linear(64, flat)

    def forward(self, image):
        """Map a (B, 1, S, S) batch to (B, 1, S, S) foreground probabilities."""
        # Encoding path: S -> S/2 -> S/4
        x = self.pool(F.relu(self.conv1(image)))
        x = self.pool(F.relu(self.conv2(x)))

        # Fully connected bottleneck
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        # Back to a (B, 32, S/4, S/4) feature map for the decoder
        x = x.view(x.size(0), 32, self.feat_size, self.feat_size)

        # Decoding path: S/4 -> S/2 -> S
        x = self.deconv1(x)
        x = self.deconv2(x)

        return torch.sigmoid(x)

# 3. Main execution block
DATA_DIR = "/content/toy_dataset"
dataset = ToyImageDataset(root=DATA_DIR, task_type='shapes')

# Split 60% / 20% / 20% into train / validation / test. The fixed generator
# seed makes the split reproducible; the test size takes the remainder so the
# three sizes always sum to len(dataset).
train_size = int(0.6 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

# Wrap each split in a DataLoader; only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

device = 'cpu'

# 4. Model
model = RetinaUNet().to(device)

# Optimizer and loss.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# RetinaUNet.forward already ends in torch.sigmoid, so the loss must take
# probabilities. BCEWithLogitsLoss would apply a second sigmoid, squashing
# every prediction into (0.5, 0.73) and corrupting the gradients.
criterion = nn.BCELoss()

#---------------------------------------------------------------------------------------
#Training Function
def train_epoch(model, dataloader, optimizer, criterion, epoch, device):
    """Run one optimization pass over ``dataloader`` and report the mean loss.

    Args:
        model: Module to train; switched to training mode.
        dataloader: Iterable of dicts with ``"image"`` and ``"mask"`` tensors.
            It does not need to support ``len()``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, masks)``.
        epoch: Epoch number, used only in the printed message.
        device: Device the batches are moved to.

    Returns:
        The mean of the per-batch losses, or ``float('nan')`` when the
        dataloader yields no batches (previously a ZeroDivisionError).
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        images = batch["image"].to(device)
        masks = batch["mask"].to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()  # backpropagate
        optimizer.step()  # update parameters
        running_loss += loss.item()
        num_batches += 1

    avg_loss = running_loss / num_batches if num_batches else float("nan")
    print(f'Epoch {epoch} - Training Loss: {avg_loss}')
    return avg_loss
#---------------------------------------------------------------------------------------
#Evaluation Function
def evaluate_model(model, dataloader, device):
    """Compute the mean per-batch Dice score of ``model`` over ``dataloader``.

    The model is put in eval mode and run without gradient tracking. Each
    batch is scored with ``dice_coefficient`` on the raw model outputs
    (threshold 0.5), and the batch scores are averaged.

    Args:
        model: Module whose outputs are compared with the masks.
        dataloader: Iterable of dicts with ``"image"`` and ``"mask"`` tensors.
            It does not need to support ``len()``.
        device: Device the batches are moved to.

    Returns:
        The average Dice score, or ``float('nan')`` when the dataloader yields
        no batches (previously a ZeroDivisionError).
    """
    model.eval()
    dice_total = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            images = batch["image"].to(device)
            masks = batch["mask"].to(device)
            outputs = model(images)
            dice_total += dice_coefficient(outputs, masks).item()
            num_batches += 1
    avg_dice = dice_total / num_batches if num_batches else float("nan")
    print(f'Test Dice Score: {avg_dice}')
    return avg_dice

#---------------------------------------------------------------------------------------
#Helper Function
def dice_coefficient(pred, target, threshold=0.5, eps=1e-6):
    """Dice score between binarized ``pred`` and ``target``.

    Both tensors are binarized with ``> threshold``, and the overlap is pooled
    over every element (the whole batch counts as one mask).

    Args:
        pred: Predicted probabilities.
        target: Ground-truth mask, same shape as ``pred``.
        threshold: Binarization cut-off applied to both tensors.
        eps: Smoothing term added to numerator and denominator. It guards
            against 0/0 and makes two empty masks (a perfect match) score 1.0
            instead of 0.0.

    Returns:
        A 0-dim tensor in [0, 1].
    """
    pred = (pred > threshold).float()
    target = (target > threshold).float()
    intersection = torch.sum(pred * target)
    total = torch.sum(pred) + torch.sum(target)
    return (2.0 * intersection + eps) / (total + eps)

# 5. Training and Evaluation Loop
num_epochs = 10  # number of passes over the training split

# Optimize on the training split once per epoch ...
for epoch in range(num_epochs):
    train_epoch(
        model=model,
        dataloader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
        epoch=epoch,
        device=device,
    )

# ... then score the trained model once on the held-out test split.
evaluate_model(model=model, dataloader=test_loader, device=device)