In [7]:
# Install the required packages for the notebook
%pip install boto3 nibabel numpy matplotlib torch torchvision torchaudio pillow tensorflow torchsummary

Note: you may need to restart the kernel to use updated packages.


In [8]:
# Sprint 1: Reading NII files from S3 and saving PNGs
import io
import os
import boto3
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import tempfile

# Source location of the raw NIfTI volumes inside S3
bucket_name = 'chemocraft-data'
folder_path = 'MICCAI_BraTS2020_TrainingData/'
# folder_path = 'Data/BraTS20_Training_369 copy/'

# S3 resource handle plus the bucket object used by the functions in this cell
s3 = boto3.resource('s3')
bucket = s3.Bucket(bucket_name)

def plot_slice(data, crop, slice_idx, filename):
    """Show a single cropped slice of a volume as a grayscale figure.

    Args:
        data: Array indexed as ``data[:, :, slice_idx]``.
        crop: Two ``(start, stop)`` pairs; ``crop[0]`` bounds the first axis of
            the slice and ``crop[1]`` bounds the second.
        slice_idx: Position along the third axis of the slice to display.
        filename: Only used to build the figure title.
    """
    row_window = slice(crop[0][0], crop[0][1])
    col_window = slice(crop[1][0], crop[1][1])
    preview = data[:, :, slice_idx][row_window, col_window]

    plt.figure(figsize=(6, 6))
    plt.imshow(preview, cmap='gray')
    plt.title(f'Slice {slice_idx} of {filename}')
    plt.axis('off')  # an image preview does not need tick marks
    plt.show()

def savePNG(data, crop, filename):
    """Crop every slice along the third axis of ``data`` and upload it to S3 as a PNG.

    The destination key is
    ``Akshay/brain_slices/<TrainingCount>/<ScanType>/<slice_idx>.png`` where
    ``TrainingCount`` and ``ScanType`` are the last two underscore-separated
    tokens of ``filename`` up to its first dot.

    Args:
        data: Array indexed as ``data[:, :, slice_idx]``.
        crop: Two ``(start, stop)`` pairs; ``crop[0]`` bounds the first axis of
            each slice and ``crop[1]`` bounds the second.
        filename: Name of the source file, used only to derive the S3 folder.

    Raises:
        IndexError: If ``filename`` has fewer than two underscore-separated tokens.

    Upload failures are reported with ``print`` and do not stop the loop; an
    error while rendering a PNG propagates to the caller. The temporary PNG is
    removed in every case.
    """
    # Prepare directory structure
    name_tokens = filename.split(".")[0].split("_")
    TrainingCount = name_tokens[-2]
    ScanType = name_tokens[-1]
    slice_path = f"brain_slices/{TrainingCount}/{ScanType}/"
    print(f"Saving in directory: {slice_path}")

    row_window = slice(crop[0][0], crop[0][1])
    col_window = slice(crop[1][0], crop[1][1])
    target_bucket = s3.Bucket(bucket_name)  # loop-invariant, build it once

    # Iterate through each slice in the Z dimension and save it as a PNG
    for slice_idx in range(data.shape[2]):
        cropped_slice = data[:, :, slice_idx][row_window, col_window]
        png_filename = f"{slice_path}{slice_idx}.png"

        # Reserve a unique path and close the handle immediately: imsave and
        # upload_file both work from the file name, not from the open handle.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_png:
            temp_png_name = temp_png.name

        try:
            mpimg.imsave(temp_png_name, cropped_slice, cmap='gray')
            try:
                target_bucket.upload_file(temp_png_name, f"Akshay/{png_filename}")
            except Exception as e:
                print(f"ERROR: Could not upload or delete temporary PNG file due to {e}")
        finally:
            # Always clean up, including when rendering or uploading failed
            try:
                os.remove(temp_png_name)
            except OSError as cleanup_error:
                print(f"Error deleting temp file: {cleanup_error}")

def render_nii_from_s3(filename, path):
    """Download one NIfTI object from the module-level ``bucket`` and export its slices.

    The object stored under the key ``path + filename`` is read into memory,
    written to a local temporary ``.nii`` file, loaded with ``nib.load`` and
    passed to ``savePNG`` together with a fixed crop window. The temporary file
    is deleted afterwards.

    Args:
        filename: Final component of the object key; also forwarded to ``savePNG``.
        path: Key prefix that is concatenated directly in front of ``filename``.

    Returns:
        None. Download, load and export problems are reported with ``print``
        instead of being raised.
    """
    key = path + filename
    print(f"Fetching file: {filename}")

    try:
        payload = bucket.Object(key).get()['Body'].read()
    except s3.meta.client.exceptions.NoSuchKey:
        print(f"ERROR: The specified key does not exist: {key}")
        return
    except Exception as e:
        print(f"ERROR: An unexpected error occurred: {e}")
        return

    # delete=False: the file must outlive this handle so nib.load can open it by path
    with tempfile.NamedTemporaryFile(suffix='.nii', delete=False) as temp_file:
        temp_file.write(payload)
        temp_file.flush()

        temp_file_path = temp_file.name
        print(f"Temporary file created: {temp_file_path}")

    try:
        img = nib.load(temp_file_path)
        data = img.get_fdata()

        print(f"Data shape for {filename}: {data.shape}")

        if data.size == 0:
            print(f"No data found in {filename}")
            return

        # Define crop dimensions. savePNG applies crop[0] to axis 0 and crop[1]
        # to axis 1, so each upper bound is derived from its own axis length
        # (for square slices this is identical to the previous behaviour).
        croptop = 40
        cropbottom = data.shape[0] - 40
        cropleft = 25
        cropright = data.shape[1] - 15

        crop = np.array([[croptop, cropbottom], [cropleft, cropright]])

        # Save every slice as a PNG
        savePNG(data, crop, filename)

    except Exception as e:
        print(f"Error loading file {filename}: {e}")

    finally:
        # Runs on success, on the early return above, and on failure
        try:
            os.remove(temp_file_path)
            print(f"Deleted temporary file: {temp_file_path}")
        except OSError as cleanup_error:
            print(f"Error deleting temp file: {cleanup_error}")

def find_and_render_nii_files():
    """Walk ``folder_path`` in the module-level ``bucket`` and export every ``.nii`` file.

    First the distinct first-level subfolders below ``folder_path`` are
    collected, then each subfolder is listed (in sorted order) and every key
    ending in ``.nii`` is handed to ``render_nii_from_s3``. A message is printed
    when no such key exists.
    """
    found_files = False

    subfolders = set()  # a set keeps each subfolder name once
    for obj in bucket.objects.filter(Prefix=folder_path):
        # Key relative to folder_path, split into its path components
        relative_parts = obj.key[len(folder_path):].split('/')

        # More than one component means the object lives inside a subfolder
        if len(relative_parts) > 1:
            subfolders.add(f'{relative_parts[0]}/')

    print(f"Root Directory: {folder_path.split('/')[0]}")

    for subfolder in sorted(subfolders):
        path = folder_path + subfolder
        print(f"Reading S3 in {path}")
        for obj in bucket.objects.filter(Prefix=path):
            if not obj.key.endswith('.nii'):
                continue

            # Split on the last '/' so that a file nested deeper than `path`
            # is still requested under its real key (directory + filename).
            key_dir, _, filename = obj.key.rpartition('/')
            key_dir += '/'

            print(f"path: {key_dir}")
            found_files = True
            print(f"Found .nii file: {filename}")
            render_nii_from_s3(filename, key_dir)

    if not found_files:
        print(f"No .nii files found in the folder {folder_path}")

# Main function
# find_and_render_nii_files()

In [9]:
# Sprint 2: GAN for Brain MRI Generation
# pytorch implementation
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.utils import save_image
from torch.utils.data import DataLoader, TensorDataset
from torch.amp import autocast, GradScaler
from torchsummary import summary
from torchvision import transforms
from PIL import Image

# Initialize S3 resource and specify bucket and folder details
s3 = boto3.resource('s3')
bucket_name = 'chemocraft-data'
bucket = s3.Bucket(bucket_name)
folder_prefix = "Akshay/brain_slices/"


# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Print CUDA device information
print(f"CUDA Device Available: {torch.cuda.is_available()}")
print(f"CUDA Device Count: {torch.cuda.device_count()}")
if device.type == 'cuda':
    # These two queries raise when no CUDA device exists, so only run them on GPU
    print(f"Current CUDA Device: {torch.cuda.current_device()}")
    print(f"CUDA Device Name: {torch.cuda.get_device_name(device)}")

# Function to load images from S3
def load_images_from_s3(bucket, folder_prefix):
    """Load every PNG below ``folder_prefix`` from ``bucket`` into one tensor.

    Each image is converted to grayscale, resized to 200x160 and turned into a
    tensor with ``transforms.ToTensor`` before all images are stacked and moved
    to the module-level ``device``.

    Args:
        bucket: boto3 S3 Bucket resource to list and read from.
        folder_prefix: Key prefix; every key that starts with it and ends in
            ``.png`` is loaded.

    Returns:
        A tensor holding all loaded images stacked along a new first dimension.

    Raises:
        RuntimeError: If no ``.png`` object exists under the prefix (same
            exception type ``torch.stack`` raised for an empty list, but with a
            message that names the bucket and prefix).
    """
    print(f"Loading images from S3 bucket: {bucket.name}/{folder_prefix}")
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((200, 160)),
        transforms.ToTensor(),
    ])

    images = []
    for obj in bucket.objects.filter(Prefix=folder_prefix):
        if not obj.key.endswith('.png'):
            continue
        file_stream = io.BytesIO(obj.get()['Body'].read())
        # The context manager releases the decoder as soon as the tensor exists
        with Image.open(file_stream) as image:
            images.append(transform(image))  # Normalize to [0, 1]

    if not images:
        raise RuntimeError(
            f"No .png images found in S3 bucket {bucket.name} under prefix {folder_prefix!r}"
        )
    return torch.stack(images).to(device)

def save_generated_images(epoch, generator, latent_dim, examples=16, dim=(4, 4)):
    """
    Generate sample images with the generator and save them as one PNG grid.

    Args:
        epoch: Value embedded in the output file name.
        generator: Model called with a ``(examples, latent_dim)`` noise tensor.
        latent_dim: Size of each noise vector.
        examples: Number of images to sample.
        dim: Only ``dim[0]`` is used; it is passed to ``save_image`` as ``nrow``.

    The grid is written to ``generated_images/brain_mri_epoch_<epoch>.png``;
    the directory is created when missing. Noise is created on the
    module-level ``device``.
    """
    # Do not depend on a caller having created the output directory
    os.makedirs('generated_images', exist_ok=True)

    z = torch.randn(examples, latent_dim, device=device)
    # Sampling needs no gradients; skipping the autograd graph saves GPU memory
    with torch.no_grad():
        gen_imgs = generator(z).cpu()
    gen_imgs = 0.5 * (gen_imgs + 1)  # Rescale images from [-1, 1] to [0, 1]

    save_image(gen_imgs, f'generated_images/brain_mri_epoch_{epoch}.png', nrow=dim[0], normalize=True)

def save_models(epoch, generator, discriminator):
    """
    Write the state dicts of both networks to the ``models`` directory.

    Args:
        epoch: Value embedded in both checkpoint file names.
        generator: Module whose ``state_dict()`` is saved first.
        discriminator: Module whose ``state_dict()`` is saved second.
    """
    checkpoints = (
        (f'models/generator_epoch_{epoch}.pth', generator),
        (f'models/discriminator_epoch_{epoch}.pth', discriminator),
    )
    os.makedirs('models', exist_ok=True)
    for target, network in checkpoints:
        torch.save(network.state_dict(), target)

def plot_training_history(d_losses, g_losses):
    """
    Plot discriminator and generator losses and save the figure to disk.

    Args:
        d_losses: Sequence of discriminator loss values.
        g_losses: Sequence of generator loss values.

    The x axis is the position of each value in its list. The figure is written
    to ``training_history.png`` and closed without being displayed.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(d_losses, label='Discriminator Loss')
    ax.plot(g_losses, label='Generator Loss')
    ax.set_title('GAN Training History')
    # train_gan records one value per batch, so the axis counts iterations
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Loss')
    ax.legend()  # without this the line labels above are never shown
    ax.grid(True)
    fig.savefig('training_history.png')
    plt.close(fig)

CUDA Device Available: True
CUDA Device Count: 1
Current CUDA Device: 0
CUDA Device Name: NVIDIA GeForce GTX 1650


In [10]:
# def generate_3d_image(generator, latent_dim, num_slices=150):
#     noise = np.random.normal(0, 1, (num_slices, latent_dim))
#     generated_slices = generator.predict(noise)
#     generated_3d_image = np.stack(generated_slices, axis=0)  # Shape: (num_slices, 200, 160)
#     return generated_3d_image

# # generated_3d_image = generate_3d_image(generator, latent_dim=100, num_slices=150)

# # Display the generated 3D image
# # fig, axs = plt.subplots(10, 15, figsize=(15, 10))
# # cnt = 0
# # for i in range(10):
# #     for j in range(15):
# #         axs[i, j].imshow(generated_3d_image[cnt, :, :, 0], cmap='gray')
# #         axs[i, j].axis('off')
# #         cnt += 1
# # plt.show()

# # Save the models
# generator.save('generator_model.h5')
# discriminator.save('discriminator_model.h5')
# gan.save('gan_model.h5')


In [17]:
class Generator(nn.Module):
    """Map a latent vector to a single-channel 200x160 image.

    A linear layer expands the latent vector to a 128-channel 25x20 feature
    map, which is upsampled by a factor of two three times (50x40, 100x80,
    200x160) with a 3x3 convolution after each upsampling step. The output
    passes through ``Tanh``.

    Args:
        latent_dim: Size of the input noise vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Start at 25x20 (200/8 x 160/8); three 2x upsamplings give 200x160
        self.init_height = 25  # 200 // 8
        self.init_width = 20   # 160 // 8

        # Initial layer to convert latent vector to features
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, 128 * self.init_height * self.init_width),
            nn.ReLU()
        )

        # The second positional argument of BatchNorm2d is `eps`, so the former
        # `BatchNorm2d(C, 0.8)` added 0.8 to the variance instead of setting the
        # momentum. The value is now passed by keyword.
        # NOTE(review): PyTorch's momentum weights the *new* batch statistic; if
        # 0.8 was meant as the weight of the running average, use 0.2 - confirm.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),  # 50x40
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),  # 100x80
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),  # 200x160
            nn.Conv2d(64, 1, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Return generated images for a batch of latent vectors ``z``."""
        out = self.l1(z)
        # Reshape the flat features into (batch, 128, 25, 20) for the conv stack
        out = out.view(-1, 128, self.init_height, self.init_width)
        img = self.conv_blocks(out)
        return img

class Discriminator(nn.Module):
    """Score single-channel 200x160 images with a value in (0, 1).

    Four stride-2 3x3 convolution blocks reduce the input to a 128-channel
    13x10 feature map, which is flattened and fed to a linear layer followed
    by ``Sigmoid``.
    """

    def __init__(self):
        super().__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Conv(3x3, stride 2) -> LeakyReLU -> Dropout2d, optionally followed by BatchNorm."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                     nn.LeakyReLU(0.2, inplace=True),
                     nn.Dropout2d(0.25)]
            if bn:
                # Keyword argument: the second positional parameter of
                # BatchNorm2d is `eps`, which the former `(out_filters, 0.8)`
                # call set to 0.8 by accident.
                # NOTE(review): if 0.8 was meant as the weight of the running
                # average, the PyTorch equivalent is momentum=0.2 - confirm.
                block.append(nn.BatchNorm2d(out_filters, momentum=0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(1, 16, bn=False),    # Output: 100x80
            *discriminator_block(16, 32),             # Output: 50x40
            *discriminator_block(32, 64),             # Output: 25x20
            *discriminator_block(64, 128),            # Output: 13x10
        )

        # Each block halves both sides, rounding up:
        # 200x160 -> 100x80 -> 50x40 -> 25x20 -> 13x10
        ds_height, ds_width = 200, 160
        for _ in range(4):
            ds_height = (ds_height + 1) // 2
            ds_width = (ds_width + 1) // 2
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_height * ds_width, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return one validity score per image in the batch ``img``."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)  # flatten everything except the batch axis
        validity = self.adv_layer(out)
        return validity
    
def debug_dimensions():
    """Print the tensor shape after each stage of a generator -> discriminator pass.

    Uses the module-level ``generator``, ``discriminator``, ``batch_size``,
    ``latent_dim`` and ``device``. The pass runs in eval mode without gradient
    tracking so that this diagnostic neither allocates an autograd graph nor
    updates BatchNorm running statistics; the previous train/eval mode of both
    models is restored afterwards.
    """
    generator_was_training = generator.training
    discriminator_was_training = discriminator.training
    generator.eval()
    discriminator.eval()

    try:
        with torch.no_grad():
            # Test forward pass with dimension printing
            z = torch.randn(batch_size, latent_dim, device=device)
            print(f"1. Input noise shape: {z.shape}")

            # Generator forward pass
            fake_imgs = generator(z)
            print(f"2. Generated images shape: {fake_imgs.shape}")

            # Discriminator forward pass with detailed shape printing
            d_out = discriminator.model(fake_imgs)
            print(f"3. After discriminator conv layers: {d_out.shape}")

            d_out_flat = d_out.view(d_out.shape[0], -1)
            print(f"4. After flattening: {d_out_flat.shape}")

            validity = discriminator.adv_layer(d_out_flat)
            print(f"5. Final output shape: {validity.shape}")
    finally:
        generator.train(generator_was_training)
        discriminator.train(discriminator_was_training)

    torch.cuda.empty_cache()  # Clear the CUDA cache

# Hyperparameters for the model-inspection steps in this cell.
# latent_dim is reused by the training cell; batch_size is reassigned there.
latent_dim = 100
batch_size = 32  # Reduce the batch size to fit in GPU memory
# Build both networks and move them to the selected device
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
torch.cuda.empty_cache()  # Clear the CUDA cache

# Print model summaries (input_size excludes the batch dimension)
print("\nGenerator Summary:")
summary(generator, input_size=(latent_dim,))
print("\nDiscriminator Summary:")
summary(discriminator, input_size=(1, 200, 160))

# Print the tensor shape after each stage of one generator -> discriminator pass
debug_dimensions()


Generator Summary:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                [-1, 64000]       6,464,000
              ReLU-2                [-1, 64000]               0
       BatchNorm2d-3          [-1, 128, 25, 20]             256
          Upsample-4          [-1, 128, 50, 40]               0
            Conv2d-5          [-1, 128, 50, 40]         147,584
       BatchNorm2d-6          [-1, 128, 50, 40]             256
         LeakyReLU-7          [-1, 128, 50, 40]               0
          Upsample-8         [-1, 128, 100, 80]               0
            Conv2d-9          [-1, 64, 100, 80]          73,792
      BatchNorm2d-10          [-1, 64, 100, 80]             128
        LeakyReLU-11          [-1, 64, 100, 80]               0
         Upsample-12         [-1, 64, 200, 160]               0
           Conv2d-13          [-1, 1, 200, 160]             577
             Tanh-1

In [16]:
def train_gan(generator, discriminator, latent_dim, epochs, batch_size, images, accumulation_steps=8):
    """
    Train the GAN model using images loaded from S3.

    Each batch performs one discriminator update (real batch plus detached
    fake batch) followed by one generator update. Every 10th epoch a sample
    grid is written with ``save_generated_images``; every 100th epoch both
    models are saved with ``save_models``.

    Args:
        generator: The generator model
        discriminator: The discriminator model
        latent_dim: Dimension of the latent space
        epochs: Number of training epochs
        batch_size: Size of training batches (the last batch of an epoch may be smaller)
        images: Tensor of real images loaded from S3
        accumulation_steps: Retained for backward compatibility; unused. The
            former "accumulation" branch only repeated optimizer steps on
            gradients that had already been applied.

    Returns:
        Tuple ``(d_losses, g_losses)`` with one loss value per processed batch.

    Raises:
        ValueError: If ``images`` is empty.
    """
    num_images = len(images)
    if num_images == 0:
        # Without any batch the progress report below would hit undefined losses
        raise ValueError("train_gan requires at least one training image")

    # Optimizers
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    criterion = nn.BCELoss()

    # Lists to store loss values for plotting
    d_losses = []
    g_losses = []

    # Create a directory for saving generated images
    os.makedirs('generated_images', exist_ok=True)

    # NOTE(review): load_images_from_s3 yields values in [0, 1] while the
    # generator ends in Tanh ([-1, 1]); consider rescaling the real images
    # to the same range - confirm before changing.
    print("Starting GAN training...")
    for epoch in range(epochs):
        for start in range(0, num_images, batch_size):
            # ---------------------
            #  Train Discriminator
            # ---------------------
            real_imgs = images[start:start + batch_size].to(device)
            # Keep the size of this batch in its own variable: overwriting
            # `batch_size` let a short final batch shrink the stride of every
            # following epoch.
            current_batch = real_imgs.size(0)

            # Labels for real and fake images
            valid = torch.ones((current_batch, 1), device=device)
            fake = torch.zeros((current_batch, 1), device=device)

            # Train with real images
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_imgs), valid)
            real_loss.backward()

            # Train with fake images; detach so no gradient reaches the generator
            z = torch.randn(current_batch, latent_dim, device=device)
            gen_imgs = generator(z)
            fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
            fake_loss.backward()

            d_loss = real_loss + fake_loss
            optimizer_D.step()

            # ---------------------
            #  Train Generator
            # ---------------------
            optimizer_G.zero_grad()
            # We want the discriminator to believe the fake images are real
            g_loss = criterion(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            # Track losses
            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())

        # Print progress
        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs} | D loss: {d_loss.item():.4f} | G loss: {g_loss.item():.4f}")
            save_generated_images(epoch, generator, latent_dim)

            # Save models periodically
            if epoch % 100 == 0:
                save_models(epoch, generator, discriminator)

    return d_losses, g_losses

# Training parameters
epochs = 500
batch_size = 16  # Reduce the batch size to fit in GPU memory

# Load images from S3 bucket and create a DataLoader
# NOTE(review): the prefix has no trailing '/', so any key that merely starts
# with "Akshay/brain_slices/320" would be loaded as well - confirm this is intended.
images = load_images_from_s3(bucket, 'Akshay/brain_slices/320')
# NOTE(review): train_loader is not passed to train_gan below, which slices
# `images` directly in a fixed order.
train_loader = DataLoader(TensorDataset(images), batch_size=batch_size, shuffle=True)

# Train the GAN
# NOTE(review): this is set after tensors were already moved to `device`;
# presumably the allocator only reads it at CUDA start-up, so it may have no
# effect here - confirm, and move it before the first CUDA use if so.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'  # Set the environment variable
d_losses, g_losses = train_gan(generator, discriminator, latent_dim, epochs, batch_size, images)

# Plot training history (written to training_history.png)
plot_training_history(d_losses, g_losses)


Loading images from S3 bucket: chemocraft-data/Akshay/brain_slices/320

Generator Summary:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                [-1, 64000]       6,464,000
              ReLU-2                [-1, 64000]               0
       BatchNorm2d-3          [-1, 128, 25, 20]             256
          Upsample-4          [-1, 128, 50, 40]               0
            Conv2d-5          [-1, 128, 50, 40]         147,584
       BatchNorm2d-6          [-1, 128, 50, 40]             256
         LeakyReLU-7          [-1, 128, 50, 40]               0
          Upsample-8         [-1, 128, 100, 80]               0
            Conv2d-9          [-1, 64, 100, 80]          73,792
      BatchNorm2d-10          [-1, 64, 100, 80]             128
        LeakyReLU-11          [-1, 64, 100, 80]               0
         Upsample-12         [-1, 64, 200, 160]               0
           C