In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install torch torchvision diffusers transformers accelerate matplotlib scikit-image tqdm

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [3]:
from google.colab import files
files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"abhaypherali","key":"859f3b85bfb14b823365766df25ebae5"}'}

In [4]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [5]:
!kaggle datasets download -d mateuszbuda/lgg-mri-segmentation
!unzip -q lgg-mri-segmentation.zip

Dataset URL: https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation
License(s): CC-BY-NC-SA-4.0
Downloading lgg-mri-segmentation.zip to /content
 97% 690M/714M [00:02<00:00, 254MB/s]
100% 714M/714M [00:02<00:00, 340MB/s]


In [11]:
# Block 1: Imports & Configuration
# First, we pull in all the tools we need for this project.
# This includes PyTorch for building the models, tools for handling images,
# and the `diffusers` library from Hugging Face which gives us the U-Net and scheduler.

import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from diffusers import UNet2DModel, DDPMScheduler
from accelerate import Accelerator
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np

# This is our main settings area. We can easily change things like image size,
# how long to train, and where to save our results, all in one place.
config = {
    "dataset_path": "./lgg-mri-segmentation/kaggle_3m", # Where our MRI images are stored.
    "image_size": 128,          # We'll resize all images to 128x128 pixels.
    "batch_size": 16,           # Process 16 images at a time during training.
    "latent_channels": 128,      # The number of channels in our compressed "summary" image.
    "latent_size": 16,          # The size of the compressed summary (16x16).
    "ae_epochs": 75,            # How many times to loop over the entire dataset for autoencoder training.
    "ae_lr": 1e-4,              # The learning rate for the autoencoder.
    "unet_epochs": 150,         # How many times to loop over the dataset for U-Net training.
    "unet_lr": 1e-4,            # The learning rate for the U-Net.
    "num_train_timesteps": 1000, # The number of steps in the diffusion (noising/denoising) process.
    "save_dir": "/content/drive/My Drive/MRI_Project_Outputs_More_Epochs", # Where to save our trained models.
}

# Make sure the folder to save our models exists.
os.makedirs(config["save_dir"], exist_ok=True)


# Block 2: Data Preparation
# This part of the code sets up a custom 'Dataset' class. Its job is to find all our MRI images,
# load them one by one, and prepare them for the model by resizing and normalizing them.

class MRIDataset(Dataset):
    # This is our custom dataset for loading the MRI scans.
    def __init__(self, root_dir, transform=None):
        # Find all the image files, but ignore the segmentation masks.
        self.image_paths = glob.glob(os.path.join(root_dir, '*/*[0-9].tif'))
        self.transform = transform

    def __len__(self):
        # A simple function to tell PyTorch how many images we have in total.
        return len(self.image_paths)

    def __getitem__(self, idx):
        # This function gets one specific image from our dataset.
        img_path = self.image_paths[idx]
        # Open the image file and convert it to grayscale.
        image = Image.open(img_path).convert('L')
        # Apply the transformations we defined (like resize and normalize).
        if self.transform:
            image = self.transform(image)
        return image


# Block 3: Autoencoder Model Architecture
# Here, we define the two main parts of our Autoencoder: the Encoder and the Decoder.
# The Encoder's job is to compress the big image into a small summary (the latent space).
# The Decoder's job is to reconstruct the image from that small summary.

class Encoder(nn.Module):
    # The Encoder squishes the image down into a small summary.
    def __init__(self, latent_channels):
        super().__init__()
        self.model = nn.Sequential(
            # Each Conv2d layer with stride=2 halves the image size.
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1), # 128x128 -> 64x64
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 64x64 -> 32x32
            nn.ReLU(),
            nn.Conv2d(64, latent_channels, kernel_size=3, stride=2, padding=1), # 32x32 -> 16x16
        )
    def forward(self, x): return self.model(x)

class Decoder(nn.Module):
    # The Decoder rebuilds the image from the small summary.
    def __init__(self, latent_channels):
        super().__init__()
        self.model = nn.Sequential(
            # Each ConvTranspose2d layer doubles the image size, reversing the encoder's job.
            nn.ConvTranspose2d(latent_channels, 64, kernel_size=4, stride=2, padding=1), # 16x16 -> 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), # 32x32 -> 64x64
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1), # 64x64 -> 128x128
            nn.Tanh(), # Tanh ensures the output pixels are in the right range (-1 to 1).
        )
    def forward(self, x): return self.model(x)

class Autoencoder(nn.Module):
    # This class just wraps the Encoder and Decoder together into one neat package.
    def __init__(self, latent_channels):
        super().__init__()
        self.encoder = Encoder(latent_channels)
        self.decoder = Decoder(latent_channels)
    def forward(self, x):
        # When we give it an image, it first encodes it, then decodes it.
        latent = self.encoder(x)
        return self.decoder(latent)

# Block 4: Autoencoder Training
# This is where train the Autoencoder.
# We show it an image, ask it to reconstruct it, and penalise it if the reconstruction
# is different from the original. It learns by trying to minimise this penalty.

def train_autoencoder(dataloader):
    print("--- Stage 1: Training Autoencoder ---")
    accelerator = Accelerator()
    device = accelerator.device
    autoencoder = Autoencoder(config["latent_channels"]).to(device)
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=config["ae_lr"])
    criterion = nn.MSELoss() # Use Mean Squared Error to see how different the images are.
    dataloader_acc, autoencoder, optimizer = accelerator.prepare(dataloader, autoencoder, optimizer)

    # Loop over the dataset for the number of epochs we set.
    for epoch in range(config["ae_epochs"]):
        progress_bar = tqdm(total=len(dataloader_acc), desc=f"Epoch {epoch + 1}/{config['ae_epochs']}")
        # Process the data one batch at a time.
        for batch in dataloader_acc:
            images = batch.to(device)
            optimizer.zero_grad() # Reset gradients from the last step.
            reconstructed_images = autoencoder(images) # Get the model's reconstruction.
            loss = criterion(reconstructed_images, images) # Calculate how "wrong" it was.
            accelerator.backward(loss) # Figure out which weights contributed to the error.
            optimizer.step() # Update the weights to do better next time.
            progress_bar.update(1)
            progress_bar.set_postfix(loss=loss.item())
        progress_bar.close()

    # Save the trained model's weights to a file.
    ae_save_path = os.path.join(config["save_dir"], "autoencoder.pth")
    unwrapped_ae = accelerator.unwrap_model(autoencoder)
    torch.save(unwrapped_ae.state_dict(), ae_save_path)
    print(f"Autoencoder training finished. Model saved to {ae_save_path}")
    return unwrapped_ae

# Block 5: Diffusion Model (U-Net) Training
# We teach the U-Net how to denoise.
# We take the clean summaries from our autoencoder, add a bunch of noise,
# and then ask the U-Net to predict exactly what noise we added.

def train_diffusion_model(autoencoder, dataloader):
    print("\n--- Stage 2: Training Diffusion U-Net ---")
    # First, freeze the autoencoder. We don't want to change it anymore.
    autoencoder.eval()
    for param in autoencoder.parameters():
        param.requires_grad = False

    accelerator = Accelerator()
    device = accelerator.device
    # Set up our U-Net model with the right dimensions.
    unet = UNet2DModel(
        sample_size=config["latent_size"], in_channels=config["latent_channels"],
        out_channels=config["latent_channels"], layers_per_block=2,
        block_out_channels=(128, 128, 256),
        down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D"),
    )
    # The scheduler manages the math for adding noise.
    noise_scheduler = DDPMScheduler(num_train_timesteps=config["num_train_timesteps"])
    unet.to(device)
    autoencoder.to(device)
    optimizer = torch.optim.Adam(unet.parameters(), lr=config["unet_lr"])
    unet, optimizer, dataloader_acc = accelerator.prepare(unet, optimizer, dataloader)

    # Loop over the dataset for the number of epochs we set.
    for epoch in range(config["unet_epochs"]):
        progress_bar = tqdm(total=len(dataloader_acc), desc=f"Epoch {epoch + 1}/{config['unet_epochs']}")
        for batch in dataloader_acc:
            images = batch.to(device)
            optimizer.zero_grad()
            # Get the clean summaries from our frozen encoder.
            with torch.no_grad():
                clean_latents = autoencoder.encoder(images)
            # Create some random noise.
            noise = torch.randn_like(clean_latents)
            # Pick a random timestep (which determines how much noise to add).
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (clean_latents.shape[0],), device=device).long()
            # Add the noise to our clean summaries.
            noisy_latents = noise_scheduler.add_noise(clean_latents, noise, timesteps)
            # Ask the U-Net to predict the noise we just added.
            noise_pred = unet(noisy_latents, timesteps).sample
            # Calculate how wrong the U-Net's prediction was.
            loss = F.mse_loss(noise_pred, noise)
            accelerator.backward(loss) # Figure out which weights to change.
            optimizer.step() # Update the U-Net to get better at predicting noise.
            progress_bar.update(1)
            progress_bar.set_postfix(loss=loss.item())
        progress_bar.close()

    # Save our newly trained U-Net model.
    unet_save_path = os.path.join(config["save_dir"], "unet.pth")
    accelerator.wait_for_everyone()
    unwrapped_unet = accelerator.unwrap_model(unet)
    torch.save(unwrapped_unet.state_dict(), unet_save_path)
    print(f"U-Net training finished. Model saved to {unet_save_path}")
    return unwrapped_unet

# Block 6: Image Generation (Inference)
# This is where we put it all together to create a new image.
# We start with pure random noise and use our expert U-Net to gradually
# denoise it until a clean image summary emerges.

def generate_image(unet, autoencoder):
    print("\n--- Stage 3: Generating New Image ---")

    # Make sure all our models are on the right device (like the GPU).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    unet.to(device)
    autoencoder.to(device)
    # Set models to evaluation mode (important for inference).
    unet.eval()
    autoencoder.eval()

    scheduler = DDPMScheduler(num_train_timesteps=config["num_train_timesteps"])

    # Create a random starting point (a block of pure noise).
    # The seed makes sure we get the same "random" noise every time, for reproducible results.
    generator = torch.Generator(device=device)
    generator.manual_seed(19)

    latents = torch.randn(
        (1, config["latent_channels"], config["latent_size"], config["latent_size"]),
        generator=generator,
        device=device,
    )

    # Loop backwards from 999 down to 0, removing a bit of noise at each step.
    scheduler.set_timesteps(config["num_train_timesteps"])
    progress_bar = tqdm(scheduler.timesteps, desc="Generating Image")
    for t in progress_bar:
        with torch.no_grad():
            # Ask the U-Net to predict the noise in our current latents.
            noise_pred = unet(latents, t).sample
        # Use the scheduler to remove that predicted noise.
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # Now that we have a clean summary, use the decoder to rebuild the final image.
    with torch.no_grad():
        generated_image = autoencoder.decoder(latents)

    # Clean up the image for display (un-normalize it from [-1, 1] to [0, 1]).
    generated_image = (generated_image + 1) / 2
    generated_image = generated_image.clamp(0, 1)
    img_np = generated_image[0].cpu().permute(1, 2, 0).numpy()

    # Plot and save the final generated image.
    plt.figure(figsize=(6, 6))
    plt.imshow(img_np, cmap='gray')
    plt.title("Generated MRI Image")
    plt.axis('off')
    save_path = os.path.join(config["save_dir"], "generated_mri.png")
    plt.savefig(save_path)
    plt.show()
    print(f"Generated image saved to {save_path}")


# Block 7: Main Execution
# This is the main function that runs our script.
# First, it checks if we've already trained our models.
# If we have, it just loads them and generates an image. If not, it starts the whole training process.

def main():
    # Define the paths where our trained models should be.
    ae_path = os.path.join(config["save_dir"], "autoencoder.pth")
    unet_path = os.path.join(config["save_dir"], "unet.pth")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # If both trained models already exist...
    if os.path.exists(ae_path) and os.path.exists(unet_path):
        print("--- Loading pre-trained models for inference ---")

        # ...then just load them up.
        trained_autoencoder = Autoencoder(config["latent_channels"])
        trained_unet = UNet2DModel(
            sample_size=config["latent_size"], in_channels=config["latent_channels"],
            out_channels=config["latent_channels"], layers_per_block=2,
            block_out_channels=(128, 128, 256),
            down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D"),
        )
        trained_autoencoder.load_state_dict(torch.load(ae_path, map_location=device))
        trained_unet.load_state_dict(torch.load(unet_path, map_location=device))

        # And jump straight to generating an image.
        print("Models loaded successfully. Generating a new image...")
        generate_image(trained_unet, trained_autoencoder)
        return # Exit the function here.

    # If the models don't exist, we have to train them from scratch.
    print("--- Pre-trained models not found. Starting full training process. ---")

    # Set up the image transformations.
    transform = transforms.Compose([
        transforms.Resize((config["image_size"], config["image_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    # Try to load our dataset.
    try:
        dataset = MRIDataset(root_dir=config["dataset_path"], transform=transform)
        if len(dataset) == 0:
            raise FileNotFoundError(f"No image files were found in the directory '{config['dataset_path']}'.")
        dataloader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=True)
        print(f"Successfully loaded {len(dataset)} images.")
    except FileNotFoundError as e:
        print(f"\n[Dataset Error] {e}")
        return

    # Run the two training stages.
    trained_autoencoder = train_autoencoder(dataloader)
    trained_unet = train_diffusion_model(trained_autoencoder, dataloader)
    # Finally, generate an image with our newly trained models.
    generate_image(trained_unet, trained_autoencoder)

# Makes the script runnable.
if __name__ == "__main__":
    main()

--- Loading pre-trained models for inference ---
Models loaded successfully. Generating a new image...

--- Stage 3: Generating New Image ---


Generating Image:   0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 