In [4]:
# Install the required packages for the notebook

%pip install boto3 nibabel numpy matplotlib torch torchvision torchaudio pillow tensorflow


Note: you may need to restart the kernel to use updated packages.


In [5]:
# Sprint 1: Reading NII files from S3 and saving PNGs

import boto3
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import io
import tempfile
import os

# Initialize S3 resource and specify bucket and folder details
s3 = boto3.resource('s3')
bucket_name = 'chemocraft-data'
folder_path = 'MICCAI_BraTS2020_TrainingData/'
# folder_path = 'Data/BraTS20_Training_369 copy/'
bucket = s3.Bucket(bucket_name)

def plot_slice(data, crop, slice_idx, filename):
    # Crop the specified slice
    slice_2d = data[:, :, slice_idx]
    cropped_slice = slice_2d[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]
    
    # Display the cropped slice with matplotlib
    plt.figure(figsize=(6, 6))
    plt.imshow(cropped_slice, cmap='gray')
    plt.title(f'Slice {slice_idx} of {filename}')
    plt.axis('off')  # Hide axes for cleaner display
    plt.show()

def savePNG(data, crop, filename):    
    # Prepare directory structure
    fileWOext = filename.split(".")[0]
    TrainingCount = fileWOext.split("_")[-2]
    ScanType = fileWOext.split("_")[-1]
    slice_path = f"brain_slices/{TrainingCount}/{ScanType}/"
    print(f"Saving in directory: {slice_path}")

    # Iterate through each slice in the Z-Dimiension data and save as PNG
    for slice_idx in range(data.shape[2]):
        # Crop each slice
        slice_2d = data[:, :, slice_idx]
        cropped_slice = slice_2d[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1]]
        png_filename = f"{slice_path}{slice_idx}.png"
        
        # Local Saving
        # try:
        #     # Create directories as needed and save each slice
        #     os.makedirs(slice_path, exist_ok=True)
        #     mpimg.imsave(png_filename, cropped_slice, cmap='gray')
        # except Exception as e:
        #     print(f"ERROR: directory could not be made due to {e}")
        
        # Upload each PNG to S3
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_png:
            mpimg.imsave(temp_png.name, cropped_slice, cmap='gray')
            temp_png.flush()
            temp_png.seek(0)
            temp_png_name = temp_png.name  # Store the name to use it after the file is closed

        try:
            s3.Bucket(bucket_name).upload_file(temp_png_name, f"Akshay/{png_filename}")
            os.remove(temp_png_name)
        except Exception as e:
            print(f"ERROR: Could not upload or delete temporary PNG file due to {e}")

def render_nii_from_s3(filename, path):
    print(f"Fetching file: {filename}")

    try:
        obj = bucket.Object(path + filename)
        file_stream = io.BytesIO(obj.get()['Body'].read())
    except s3.meta.client.exceptions.NoSuchKey as e:
        print(f"ERROR: The specified key does not exist: {path + filename}")
        return
    except Exception as e:
        print(f"ERROR: An unexpected error occurred: {e}")
        return

    with tempfile.NamedTemporaryFile(suffix='.nii', delete=False) as temp_file:  # Prevent auto-deletion
        temp_file.write(file_stream.getvalue())
        temp_file.flush()

        temp_file_path = temp_file.name
        print(f"Temporary file created: {temp_file_path}")

    try:
        img = nib.load(temp_file_path)
        data = img.get_fdata()

        print(f"Data shape for {filename}: {data.shape}")
        
        if data.size == 0:
            print(f"No data found in {filename}")
            return

        # Define crop dimensions
        cropleft = 25
        cropright = data.shape[0] - 15
        cropbottom = data.shape[1] - 40
        croptop = 40
        
        crop = np.array([[croptop, cropbottom], [cropleft, cropright]])
        
        # Save the PNGs and plot a sample slice
        savePNG(data, crop, filename)
        
        # slice_idx = 88  # Choose a slice index for sample display
        # plot_slice(data, crop, slice_idx, filename)

    except Exception as e:
        print(f"Error loading file {filename}: {e}")
        
    finally:
        try:
            os.remove(temp_file_path)
            print(f"Deleted temporary file: {temp_file_path}")
        except OSError as cleanup_error:
            print(f"Error deleting temp file: {cleanup_error}")

def find_and_render_nii_files():
    found_files = False

    subfolders = set()  # use a set to ensure unique subfolder names
    for obj in bucket.objects.filter(Prefix=folder_path):
        # Get the path after the 'Data/' prefix and split it by '/'
        path_parts = obj.key[len(folder_path):].split('/')
        
        # Check if there's at least one part (indicating a subfolder)
        if len(path_parts) > 1:
            subfolders.add(f'{path_parts[0]}/')  # Add the subfolder name
            
    subfolders = sorted(subfolders)

    print(f"Root Directory: {folder_path.split('/')[0]}")
    # print(subfolders)

    for subfolder in subfolders:
        path = folder_path + subfolder
        print(f"Reading S3 in {path}")
        for obj in bucket.objects.filter(Prefix=path):
            if obj.key.endswith('.nii'):
                print(f"path: {path}")
                found_files = True
                filename = obj.key.split('/')[-1]  # Extract filename from path
                print(f"Found .nii file: {filename}")
                render_nii_from_s3(filename, path)

    if not found_files:
        print(f"No .nii files found in the folder {folder_path}")

# Main function
# find_and_render_nii_files()

In [6]:
# Sprint 2: GAN for Brain MRI Generation

import io
from io import BytesIO
import keras
import numpy as np
from PIL import Image
import boto3
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models # type: ignore
from tensorflow.keras.optimizers import Adam # type: ignore
from tensorflow.keras.utils import Sequence # type: ignore # type: ignore
from keras.preprocessing.image import img_to_array, load_img # type: ignore

# Initialize S3 resource and specify bucket and folder details
s3 = boto3.resource('s3')
bucket_name = 'chemocraft-data'
bucket = s3.Bucket(bucket_name)
folder_prefix = "Akshay/brain_slices/"

def load_images_from_s3(bucket, folder_prefix):
    print(f"Loading images from S3 bucket: {bucket_name}/{folder_prefix}")
    images = []
    for obj in bucket.objects.filter(Prefix=folder_prefix):
        if obj.key.endswith('.png'):
            file_stream = io.BytesIO(obj.get()['Body'].read())
            image = load_img(file_stream, target_size=(200, 160), color_mode='grayscale')
            image = img_to_array(image) / 255.0  # Normalize to [0, 1]
            images.append(image)
    return np.array(images)


# Generator Model
def build_generator(latent_dim=100, output_shape=(200, 160, 1)):
    """
    Builds the generator model for a GAN.
    
    Parameters:
    latent_dim (int): The size of the input latent vector.
    output_shape (tuple): The desired shape of the generated images (height, width, channels).
    
    Returns:
    A compiled Keras Sequential model.
    """
    model = models.Sequential(name="Generator")
    print("Building Generator Model")
    
    # Input layer
    print(f"Output shape: {output_shape}")
    model.add(layers.Input(shape=(latent_dim,)))
    model.add(layers.Dense(128 * 25 * 20, activation="relu"))
    model.add(layers.Reshape((25, 20, 128)))
    
    # Transposed convolutional layers to upsample
    model.add(layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding="same", activation="relu"))
    model.add(layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding="same", activation="relu"))
    # model.add(layers.Conv2DTranspose(32, kernel_size=4, strides=2, padding="same", activation="relu"))  
    # model.add(layers.Conv2DTranspose(8, kernel_size=4, strides=4, padding="same", activation="relu"))  
    model.add(layers.Conv2D(8, kernel_size=3, padding="same", activation="tanh"))
    model.add(layers.Conv2DTranspose(output_shape[2], kernel_size=4, strides=2, padding="same", activation="relu"))
    model.add(layers.Activation("tanh"))
    # model.summary()

    # model.compile(optimizer="adam", loss="binary_crossentropy")
    # model.compile(Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy')
    return model

# Discriminator Model
def build_discriminator(input_shape):
    """
    Builds the discriminator model for a GAN.
    
    Parameters:
    input_shape (tuple): The shape of the input images (height, width, channels).
    
    Returns:
    A compiled Keras Sequential model.
    """
    model = models.Sequential(name="Discriminator")
    print("Building Discriminator Model")
    
    # Input layer
    print(f"Input shape: {input_shape}")
    
    model.add(layers.Input(shape=input_shape))
    model.add(layers.Conv2D(32, kernel_size=4, strides=2, padding="same"))
    model.add(layers.Conv2D(32, kernel_size=4, strides=2, padding="same", input_shape=input_shape))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Conv2D(64, kernel_size=4, strides=2, padding="same"))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Conv2D(128, kernel_size=4, strides=2, padding="same"))
    model.add(layers.LeakyReLU(alpha=0.2))
    
    # Flatten and output layer
    model.add(layers.Flatten())
    model.add(layers.Dense(1, activation="sigmoid"))
    # model.summary()
    
    # model.compile(Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy')
    # model.compile(optimizer="adam", loss="binary_crossentropy")
    return model

# GAN Model
def compile_gan(generator, discriminator):
    """
    Compiles the GAN model by connecting the generator and discriminator.
    
    Parameters:
    generator (keras.Model): The generator model.
    discriminator (keras.Model): The discriminator model.
    
    Returns:
    A compiled GAN model.
    """
    # Ensure the discriminator weights are not updated during generator training
    discriminator.trainable = False
    gan_input = layers.Input(shape=(generator.input_shape[1],))
    generated_image = generator(gan_input)
    gan_output = discriminator(generated_image)
    gan = models.Model(gan_input, gan_output, name="GAN")
    gan.compile(Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy')
    
    return gan

# Instantiate models
latent_dim = 100
img_shape = (200, 160, 1)
generator = build_generator(latent_dim=latent_dim, output_shape=img_shape)
discriminator = build_discriminator(input_shape=img_shape)

generator.compile(Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy')
discriminator.compile(Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy')

gan = compile_gan(generator, discriminator)

# Display model summaries
generator.summary()
discriminator.summary()
gan.summary()

Building Generator Model
Output shape: (200, 160, 1)
Building Discriminator Model
Input shape: (200, 160, 1)


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [7]:
# keras implmentation
# def train_gan(generator, discriminator, gan, latent_dim, epochs, batch_size):
#     """
#     Train the GAN model using images loaded from S3.
    
#     Args:
#         generator: The generator model
#         discriminator: The discriminator model
#         gan: The combined GAN model
#         latent_dim: Dimension of the latent space
#         epochs: Number of training epochs
#         batch_size: Size of training batches
#     """
#     # Load images from S3
#     s3 = boto3.resource('s3')   
#     bucket = s3.Bucket('chemocraft-data')
#     images = load_images_from_s3(bucket, 'Akshay/brain_slices/320/')
    
#     half_batch = batch_size // 2
    
#     # Lists to store loss values for plotting
#     d_losses = []
#     g_losses = []
    
#     # Create a directory for saving generated images
#     os.makedirs('generated_images', exist_ok=True)
    
#     print("Starting GAN training...")
#     for epoch in range(epochs):
#         # ---------------------
#         #  Train Discriminator
#         # ---------------------
        
#         # Select a random half batch of real images
#         idx = np.random.randint(0, images.shape[0], half_batch)
#         real_imgs = images[idx]
        
#         # Generate a half batch of fake images
#         noise = np.random.normal(0, 1, (half_batch, latent_dim))
#         gen_imgs = generator.predict(noise)
        
#         # Train the discriminator
#         d_loss_real = discriminator.train_on_batch(real_imgs, np.ones((half_batch, 1)))
#         d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
#         d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
#         # ---------------------
#         #  Train Generator
#         # ---------------------
        
#         # Generate a batch of noise samples
#         noise = np.random.normal(0, 1, (batch_size, latent_dim))
        
#         # Train the generator
#         # The generator wants the discriminator to label the generated samples as valid (ones)
#         g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))
        
#         # Store losses
#         d_losses.append(d_loss[0])
#         g_losses.append(g_loss)
        
#         # Print progress
#         if epoch % 100 == 0:
#             print(f"Epoch {epoch}/{epochs}")
#             print(f"D loss: {d_loss[0]:.4f}, acc.: {d_loss[1]:.2%}")
#             print(f"G loss: {g_loss:.4f}")
            
#             # Save generated images
#             save_generated_images(epoch, generator, latent_dim)
            
#             # Save models periodically
#             if epoch % 1000 == 0:
#                 save_models(epoch, generator, discriminator)
                
#     return d_losses, g_losses

# def save_generated_images(epoch, generator, latent_dim, examples=16, dim=(4, 4)):
#     """
#     Generate and save sample images during training.
#     """
#     noise = np.random.normal(0, 1, (examples, latent_dim))
#     gen_imgs = generator.predict(noise)
    
#     # Rescale images from [-1, 1] to [0, 1]
#     gen_imgs = 0.5 * gen_imgs + 0.5
    
#     # Set up the plot
#     fig = plt.figure(figsize=(8, 8))
#     for i in range(examples):
#         plt.subplot(dim[0], dim[1], i+1)
#         plt.imshow(gen_imgs[i, :, :, 0], cmap='gray')
#         plt.axis('off')
    
#     # Save the plot
#     plt.savefig(f'generated_images/brain_mri_epoch_{epoch}.png')
#     plt.close()

# def save_models(epoch, generator, discriminator):
#     """
#     Save the generator and discriminator models.
#     """
#     generator.save(f'models/generator_epoch_{epoch}.h5')
#     discriminator.save(f'models/discriminator_epoch_{epoch}.h5')

# def plot_training_history(d_losses, g_losses):
#     """
#     Plot the training history of the GAN.
#     """
#     plt.figure(figsize=(10, 5))
#     plt.plot(d_losses, label='Discriminator Loss')
#     plt.plot(g_losses, label='Generator Loss')
#     plt.title('GAN Training History')
#     plt.xlabel('Epoch')
#     plt.ylabel('Loss')
#     plt.legend()
#     plt.grid(True)
#     plt.savefig('training_history.png')
#     plt.close()

# # Training
# epochs = 50
# batch_size = 155*5
# half_batch = batch_size // 2


# # Train the GAN
# d_losses, g_losses = train_gan(
#     generator=generator,
#     discriminator=discriminator,
#     gan=gan,
#     latent_dim=latent_dim,
#     epochs=epochs,
#     batch_size=half_batch
# )
    
#     # Plot training history
# plot_training_history(d_losses, g_losses)

In [8]:
# def train_gan(generator, discriminator, gan, latent_dim, epochs, batch_size):
#     # Load images from S3
#     s3 = boto3.resource('s3')   
#     bucket = s3.Bucket('chemocraft-data')
#     images = load_images_from_s3(bucket, 'Akshay/brain_slices/320/')

#     half_batch = 64//2
    
#     print("Training GAN...")
#     for epoch in range(epochs):
#         # Train discriminator with real images
#         idx = np.random.randint(0, images.shape[0], half_batch)
#         real_images = images[idx]
#         real_labels = np.ones((half_batch, 1))
        
#         # Generate fake images
#         noise = np.random.normal(0, 1, (half_batch, latent_dim))
#         fake_images = generator.predict(noise)
#         fake_labels = np.zeros((half_batch, 1))
        
#         # Train discriminator
#         d_loss_real = discriminator.train_on_batch(real_images, real_labels)
#         d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
#         d_loss = 0.5 * (d_loss_real + d_loss_fake)
        
#         # Update discriminator loss tracker
#         if hasattr(discriminator, 'metrics') and discriminator.metrics:
#             for metric in discriminator.metrics:
#                 if hasattr(metric, 'update_state'):
#                     metric.update_state(d_loss)
        
#         # Train generator
#         noise = np.random.normal(0, 1, (batch_size, latent_dim))
#         misleading_labels = np.ones((batch_size, 1))
#         g_loss = gan.train_on_batch(noise, misleading_labels)
        
#         # Update generator loss tracker
#         if hasattr(gan, 'metrics') and gan.metrics:
#             for metric in gan.metrics:
#                 if hasattr(metric, 'update_state'):
#                     metric.update_state(g_loss)
        
#         # Print loss values
#         print(f"{epoch+1}/{epochs}, D Loss: {d_loss}, G Loss: {g_loss}")
        
# train_gan(generator, discriminator, gan, latent_dim, epochs, batch_size)

In [9]:
# def generate_3d_image(generator, latent_dim, num_slices=150):
#     noise = np.random.normal(0, 1, (num_slices, latent_dim))
#     generated_slices = generator.predict(noise)
#     generated_3d_image = np.stack(generated_slices, axis=0)  # Shape: (num_slices, 200, 160)
#     return generated_3d_image

# # generated_3d_image = generate_3d_image(generator, latent_dim=100, num_slices=150)

# # Display the generated 3D image
# # fig, axs = plt.subplots(10, 15, figsize=(15, 10))
# # cnt = 0
# # for i in range(10):
# #     for j in range(15):
# #         axs[i, j].imshow(generated_3d_image[cnt, :, :, 0], cmap='gray')
# #         axs[i, j].axis('off')
# #         cnt += 1
# # plt.show()

# # Save the models
# generator.save('generator_model.h5')
# discriminator.save('discriminator_model.h5')
# gan.save('gan_model.h5')


In [10]:
# pytorch implementation
import io
import numpy as np
import boto3
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image

# Initialize S3 resource and specify bucket and folder details
s3 = boto3.resource('s3')
bucket_name = 'chemocraft-data'
bucket = s3.Bucket(bucket_name)
folder_prefix = "Akshay/brain_slices/"

# Set device for CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Function to load images from S3
def load_images_from_s3(bucket, folder_prefix):
    print(f"Loading images from S3 bucket: {bucket_name}/{folder_prefix}")
    images = []
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((200, 160)),
        transforms.ToTensor(),
    ])
    for obj in bucket.objects.filter(Prefix=folder_prefix):
        if obj.key.endswith('.png'):
            file_stream = io.BytesIO(obj.get()['Body'].read())
            image = Image.open(file_stream)
            image = transform(image)  # Normalize to [0, 1]
            images.append(image)
    return torch.stack(images).to(device)

# Generator Model
class Generator(nn.Module):
    def __init__(self, latent_dim=100, output_shape=(200, 160, 1)):
        super(Generator, self).__init__()
        self.init_dim = (128, 25, 20)
        self.fc = nn.Linear(latent_dim, 128 * 25 * 20)
        
        self.deconv_blocks = nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 8, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
            nn.ConvTranspose2d(8, output_shape[2], kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        )

    def forward(self, z):
        x = self.fc(z).view(-1, *self.init_dim)
        x = self.deconv_blocks(x)
        return x

# Discriminator Model
class Discriminator(nn.Module):
    def __init__(self, input_shape=(200, 160, 1)):
        super(Discriminator, self).__init__()
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(input_shape[2], 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2)
        )
        self.fc = nn.Linear(128 * 25 * 20, 1)

    def forward(self, img):
        x = self.conv_blocks(img)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return torch.sigmoid(x)

# Instantiate models and move them to CUDA
latent_dim = 100
img_shape = (200, 160, 1)
generator = Generator(latent_dim=latent_dim, output_shape=img_shape).to(device)
discriminator = Discriminator(input_shape=img_shape).to(device)

# Optimizers
lr = 0.0002
b1 = 0.5
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, 0.999))

# Loss function
criterion = nn.BCELoss()

# GAN training function
def train_gan(data_loader, num_epochs=1000):
    for epoch in range(num_epochs):
        for imgs in data_loader:
            # Configure input
            real_imgs = imgs.to(device)
            batch_size = real_imgs.size(0)

            # Train Discriminator
            optimizer_D.zero_grad()

            # Real images
            valid = torch.ones((batch_size, 1), requires_grad=False).to(device)
            real_loss = criterion(discriminator(real_imgs), valid)

            # Fake images
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_imgs = generator(z)
            fake = torch.zeros((batch_size, 1), requires_grad=False).to(device)
            fake_loss = criterion(discriminator(fake_imgs.detach()), fake)

            # Total Discriminator loss
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            g_loss = criterion(discriminator(fake_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

        print(f"Epoch [{epoch}/{num_epochs}] | D loss: {d_loss.item()} | G loss: {g_loss.item()}")

# Load data from S3 and start training
# Assuming data loader is created with loaded images and batch size
# data_loader = DataLoader(load_images_from_s3(bucket, folder_prefix), batch_size=32, shuffle=True)
# train_gan(data_loader)


In [12]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import DataLoader, TensorDataset

class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.init_size = 256 // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(1, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        ds_size = 256 // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity

latent_dim = 100
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

def train_gan(generator, discriminator, latent_dim, epochs, batch_size, images):
    """
    Train the GAN model using images loaded from S3.
    
    Args:
        generator: The generator model
        discriminator: The discriminator model
        latent_dim: Dimension of the latent space
        epochs: Number of training epochs
        batch_size: Size of training batches
        images: Tensor of real images loaded from S3
    """
    # Optimizers
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    criterion = nn.BCELoss()

    # Lists to store loss values for plotting
    d_losses = []
    g_losses = []

    # Create a directory for saving generated images
    os.makedirs('generated_images', exist_ok=True)

    print("Starting GAN training...")
    for epoch in range(epochs):
        for i in range(0, len(images), batch_size):
            # ---------------------
            #  Train Discriminator
            # ---------------------
            real_imgs = images[i:i+batch_size].to(device)
            batch_size = real_imgs.size(0)
            
            # Labels for real and fake images
            valid = torch.ones((batch_size, 1), device=device)
            fake = torch.zeros((batch_size, 1), device=device)

            # Train with real images
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_imgs), valid)
            real_loss.backward()

            # Train with fake images
            z = torch.randn(batch_size, latent_dim, device=device)
            gen_imgs = generator(z)
            fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
            fake_loss.backward()
            
            d_loss = real_loss + fake_loss
            optimizer_D.step()

            # ---------------------
            #  Train Generator
            # ---------------------
            optimizer_G.zero_grad()
            g_loss = criterion(discriminator(gen_imgs), valid)  # We want discriminator to believe the fake images are real
            g_loss.backward()
            optimizer_G.step()

            # Track losses
            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())

        # Print progress
        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs} | D loss: {d_loss.item():.4f} | G loss: {g_loss.item():.4f}")
            save_generated_images(epoch, generator, latent_dim)

            # Save models periodically
            if epoch % 100 == 0:
                save_models(epoch, generator, discriminator)

    return d_losses, g_losses

def save_generated_images(epoch, generator, latent_dim, examples=16, dim=(4, 4)):
    """
    Generate and save sample images during training.
    """
    z = torch.randn(examples, latent_dim, device=device)
    gen_imgs = generator(z).detach().cpu()
    gen_imgs = 0.5 * (gen_imgs + 1)  # Rescale images from [-1, 1] to [0, 1]
    
    save_image(gen_imgs, f'generated_images/brain_mri_epoch_{epoch}.png', nrow=dim[0], normalize=True)

def save_models(epoch, generator, discriminator):
    """
    Save the generator and discriminator models.
    """
    os.makedirs('models', exist_ok=True)
    torch.save(generator.state_dict(), f'models/generator_epoch_{epoch}.pth')
    torch.save(discriminator.state_dict(), f'models/discriminator_epoch_{epoch}.pth')

def plot_training_history(d_losses, g_losses):
    """
    Plot the training history of the GAN.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(d_losses, label='Discriminator Loss')
    plt.plot(g_losses, label='Generator Loss')
    plt.title('GAN Training History')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig('training_history.png')
    plt.close()

# Training parameters
epochs = 50
batch_size = 32

# Load images from S3 bucket and create a DataLoader
images = load_images_from_s3(bucket, 'Akshay/brain_slices/320/')
train_loader = DataLoader(TensorDataset(images), batch_size=batch_size, shuffle=True)

# Train the GAN
d_losses, g_losses = train_gan(generator, discriminator, latent_dim, epochs, batch_size, images)

# Plot training history
plot_training_history(d_losses, g_losses)


Loading images from S3 bucket: chemocraft-data/Akshay/brain_slices/320/
Starting GAN training...


RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x16640 and 32768x1)