In [1]:
# import os
# import shutil

# # link to data 
# # https://www.kaggle.com/datasets/abdullahalmunem/brats17/data

# # Define the base path and destination path
# base_path = "/Users/lindatang/Desktop/tumor_dl/BRATS2017/Brats17TrainingData/HGG"
# destination_path = "/Users/lindatang/Desktop/tumor_dl/BRATS2017/TrainingDataset/images"

# # Ensure the destination directory exists
# os.makedirs(destination_path, exist_ok=True)

# # Loop through the subfolders in the base path
# for patient_folder in os.listdir(base_path):
#     patient_path = os.path.join(base_path, patient_folder)
    
#     # Check if it's a directory
#     if os.path.isdir(patient_path):
#         # Loop through the files in the patient folder
#         for file_name in os.listdir(patient_path):
#             if file_name.endswith("flair.nii"):
#                 # Construct full file path
#                 file_path = os.path.join(patient_path, file_name)
                
#                 # Copy the file to the destination folder
#                 shutil.copy(file_path, destination_path)
#                 print(f"Copied {file_name} to {destination_path}")

# print("Finished copying t1ce.nii files.")

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset
import nibabel as nib
import numpy as np
import pandas as pd
import torchio as tio
from img_branch_utils import GBMdataset
import pandas as pd

In [3]:
from torch.utils.data import DataLoader

# Define the directory with images and segmentation
image_dir = "/Users/lindatang/Desktop/tumor_dl/BRATS2017/TrainingDataset/images"

# Define the path to the CSV file containing the patient survival data
csv_path = "/Users/lindatang/Desktop/tumor_dl/BRATS2017/survival_data_small.csv" # TODO-

# Create the dataset
dataset = GBMdataset(image_dir=image_dir, csv_path=csv_path)

# Create the DataLoader
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

In [4]:
# # Example loop over the data
# for batch_idx, (images, survival_times) in enumerate(dataloader):
#     if batch_idx < 2:  # Only print for the first 2 batches
#         print(f"Batch {batch_idx + 1}")
#         print(f"Images shape: {images.shape}")
#         print(f"Survival times: {survival_times}")
#     else:
#         break  # Exit the loop after the first 2 batches

In [5]:
class GaussianNoise3D(nn.Module):
    def __init__(self, mean=0.0, std=0.1):
        super(GaussianNoise3D, self).__init__()
        self.mean = mean
        self.std = std
        
    def forward(self, x):
        if self.training:  # Apply noise only during training
            noise = torch.randn_like(x) * self.std + self.mean
            return x + noise
        return x
    
class Encoder(nn.Module):
    def __init__(self, input_shape, network_depth, no_convolutions, conv_filter_no_init, 
                 conv_kernel_size, latent_representation_dim, l1, l2, dropout_value, 
                 use_batch_normalization, activation, gaussian_noise_std=None):
        super(Encoder, self).__init__()
        self.input_shape = input_shape
        self.network_depth = network_depth
        self.no_convolutions = no_convolutions
        self.conv_filter_no_init = conv_filter_no_init
        self.conv_kernel_size = conv_kernel_size
        self.latent_representation_dim = latent_representation_dim
        self.l1 = l1
        self.l2 = l2
        self.dropout_value = dropout_value
        self.use_batch_normalization = use_batch_normalization
        self.activation = activation
        self.gaussian_noise_std = gaussian_noise_std   
        self.encoder_layers = nn.ModuleList()

        # Gaussian noise layer
        if gaussian_noise_std:
            self.noise_layer = GaussianNoise3D(gaussian_noise_std)
        else:
            self.noise_layer = None

        # Convolutional layers
        in_channels = input_shape[0]
        for i in range(network_depth):
            for j in range(no_convolutions):
                out_channels = self.conv_filter_no_init * (2 ** i)
                conv_layer = nn.Conv3d(in_channels, out_channels, conv_kernel_size, padding=conv_kernel_size // 2)
                self.encoder_layers.append(conv_layer)
                if self.use_batch_normalization:
                    self.encoder_layers.append(nn.BatchNorm3d(out_channels))
                if self.activation == 'leakyrelu':
                    self.encoder_layers.append(nn.LeakyReLU(inplace=True))
                else:
                    self.encoder_layers.append(nn.ReLU(inplace=True))
                in_channels = out_channels
            self.encoder_layers.append(nn.MaxPool3d(kernel_size=2, stride=2))
            if dropout_value:
                self.encoder_layers.append(nn.Dropout3d(p=dropout_value))

        self.flatten = nn.Flatten()

        # Calculate feature map size after convolution
        with torch.no_grad():
            dummy_input = torch.zeros(1, *input_shape)
            conv_output = self._forward_conv_layers(dummy_input)
            self.feature_map_size = conv_output.size()
            flattened_dim = conv_output.view(1, -1).size(1)

        # Fully connected layer
        self.fc = nn.Linear(flattened_dim, latent_representation_dim)
        if activation == 'leakyrelu':
            self.activation_fn = nn.LeakyReLU(inplace=True)
        else:
            self.activation_fn = nn.ReLU(inplace=True)
            
    def _forward_conv_layers(self, x):
        if self.noise_layer:
            x = self.noise_layer(x)
        for layer in self.encoder_layers:
            x = layer(x)
        return x
    
    def forward(self, x):
        x = self._forward_conv_layers(x)
        x = self.flatten(x)
        x = self.fc(x)
        x = self.activation_fn(x)
        return x

In [6]:
## hyperparameter tuning

from ray import tune
from ray import train
from ray.train import Checkpoint, get_checkpoint
from ray.tune.schedulers import ASHAScheduler
import ray.cloudpickle as pickle
from ray.tune import CLIReporter

In [7]:
def downsample_fn(depth):
    if depth == 4:
        return [(2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)]  # Increase depth upsample factor to 2
    elif depth == 2:
        return [(4, 4, 4), (4, 4, 4)]  # Increase depth upsample factor to 4 if needed
    else:
        raise ValueError(f'Unsupported depth: {depth}')
    
class Decoder3D(nn.Module):
    def __init__(self, conv_shape, network_depth, no_convolutions, conv_filter_no_init,
                 conv_kernel_size, latent_representation_dim, output_channels=5, l1=0.0, l2=0.0,
                 dropout_value=0.0, use_batch_normalization=False, activation='relu'):
        super(Decoder3D, self).__init__()
        self.conv_shape = conv_shape  # Shape of the feature map at the start of the decoder
        self.network_depth = network_depth
        self.no_convolutions = no_convolutions
        self.conv_filter_no_init = conv_filter_no_init
        self.conv_kernel_size = conv_kernel_size
        self.latent_representation_dim = latent_representation_dim
        self.l1 = l1
        self.l2 = l2
        self.dropout_value = dropout_value
        self.use_batch_normalization = use_batch_normalization
        self.activation = activation
        self.output_channels = output_channels  # Final output channels (e.g., 5 channels for MRI modalities)
        
        # Activation function
        if activation == 'leakyrelu':
            self.activation_fn = nn.LeakyReLU(inplace=True)
        else:
            self.activation_fn = nn.ReLU(inplace=True)

        # Fully connected layer to reshape the latent vector into a 3D shape
        self.fc = nn.Linear(latent_representation_dim, np.prod(self.conv_shape))

        # Reshape layer to convert the flat output of the FC layer into a 3D volume
        self.reshape = lambda x: x.view(-1, *self.conv_shape)
        
        # Decoder layers (upsample and conv layers)
        self.decoder_layers = nn.ModuleList()
        
        # Reverse the depth, so we progressively upsample back to the original image size
        in_channels = conv_shape[0]  # Start with the number of channels from the conv_shape
        for i in reversed(range(network_depth)):
            # Upsampling layer
            upsample_factors = downsample_fn(network_depth)[i]
            self.decoder_layers.append(nn.Upsample(scale_factor=upsample_factors, mode='trilinear', align_corners=False))
            
            # Convolution layers
            out_channels = self.conv_filter_no_init * (2 ** i)  # Reduce the number of channels as we move up the network
            for j in range(no_convolutions):
                print(f"Layer {i}-{j}: in_channels = {in_channels}, out_channels = {out_channels}")
                self.decoder_layers.append(nn.Conv3d(in_channels, out_channels, kernel_size=conv_kernel_size, padding=1))
                if use_batch_normalization:
                    self.decoder_layers.append(nn.BatchNorm3d(out_channels))
                if activation == 'leakyrelu':
                    self.decoder_layers.append(nn.LeakyReLU(inplace=True))
                else:
                    self.decoder_layers.append(nn.ReLU(inplace=True))
                in_channels = out_channels  # Update in_channels for the next layer
            if dropout_value > 0.0:
                self.decoder_layers.append(nn.Dropout3d(p=dropout_value))
        
        # Final convolution to produce the reconstructed image with the correct number of output channels
        print(f"Final Layer: in_channels = {in_channels}, out_channels = {self.output_channels}")
        self.final_conv = nn.Conv3d(in_channels, self.output_channels, conv_kernel_size, padding=1)
        self.final_activation = nn.ReLU()  # You could change this to `nn.Sigmoid()` or `nn.Tanh()` depending on the data range

    def forward(self, x):
        # Expand the latent vector via the fully connected layer
        x = self.fc(x)
        x = self.activation_fn(x)
        
        # Reshape to the shape required for the convolutional layers
        x = self.reshape(x)  # Reshape to 3D tensor (batch_size, channels, depth, height, width)
        
        # Apply the decoder layers
        for layer in self.decoder_layers:
            x = layer(x)
        
        # Final convolution to produce the output volume
        x = self.final_conv(x)
        x = self.final_activation(x)
        return x

In [8]:
class LatentParametersModel(nn.Module):
    def __init__(self, latent_representation_dim, l1=0.0, l2=0.0):
        super(LatentParametersModel, self).__init__()
        self.mu_sigma_layer = nn.Linear(
            in_features=latent_representation_dim, 
            out_features=2
        )
        nn.init.xavier_uniform_(self.mu_sigma_layer.weight)
        nn.init.zeros_(self.mu_sigma_layer.bias)
        self.l1 = l1
        self.l2 = l2
    
    def forward(self, x):
        mu_sigma = self.mu_sigma_layer(x)
        return mu_sigma

In [9]:
def reconstruction_loss(y_true, y_pred):
    mse_loss = F.mse_loss(y_pred, y_true, reduction='none')
    reduced_loss = mse_loss.mean(dim=[1, 2, 3])
    return reduced_loss

def survival_loss(mu, sigma, x, delta):
    """
    Custom loss function based on the negative log-likelihood.

    :param mu: Predicted mean (log of hazard ratio), tensor of shape (batch_size,)
    :param sigma: Predicted standard deviation (scale parameter), tensor of shape (batch_size,)
    :param x: Observed time (log-transformed), tensor of shape (batch_size,)
    :param delta: Event indicator (1 if event occurred, 0 if censored), tensor of shape (batch_size,)
    :return: Computed loss, scalar value
    """
    # Negative log-likelihood term
    total_loss = -(torch.log(x)-mu)/sigma.sum()+(delta * torch.log(sigma) + (1 + delta) * torch.log(1 + torch.exp((torch.log(x)-mu)/sigma)))
    
    # Return the mean loss across the batch
    return total_loss / x.size(0)

# # # Example usage
# mu = torch.tensor([0.5, 0.8, 0.3])  # Predicted means (log hazard ratios)
# sigma = torch.tensor([1.1, 1.2, 1.1])  # Predicted standard deviations (not log-transformed)
# x = torch.tensor([1.0, 0.8, 0.9])  # Log-transformed observed times
# delta = torch.tensor([1.0, 0.0, 1.0])  # Event indicators

# loss = survival_loss(mu, sigma, x, delta)
# print(loss)

In [10]:
# Set up hyperparameter grid for Encoder 

# input_shape = (5, 128, 128, 128)  # 5 channels, 240x240x155 spatial dimensions
# conv_shape = (256, 8, 8, 8)

#network_depth = 1 #4
#no_convolutions = 1 #2
#conv_filter_no_init = 2 #32
#conv_kernel_size = 5 #3
#latent_representation_dim = 8 #128
#l1 = 0.01
#l2 = 0.01
#dropout_value = 0.5
#use_batch_normalization = True
#activation = 'leakyrelu'
#gaussian_noise_std = 0.1

config = { #lr? batchsize? 
    "network_depth": tune.choice([1, 2]),
    "no_convolutions": tune.choice([1]),

    "conv_filter_no_init": tune.choice([2]),
    "conv_kernel_size": tune.choice([5]),
    "latent_representation_dim": tune.choice([8]),
    
    "l1": tune.choice([0.001]),
    "l2": tune.choice([0.001]),

    "dropout_value": tune.choice([0.5]),
    "use_batch_normalization": tune.choice([True]), 
    "activation": tune.choice(['leaky_relu']), #'relu', 'tanh', 'sigmoid', 
    #"output_channels": tune.choice([2]), #? 
    "gaussian_noise_std": tune.choice([0.1]),
}

### example
# encoder = encoder(
#     input_shape=input_shape,
#     network_depth=network_depth,
#     no_convolutions=no_convolutions,
#     conv_filter_no_init=conv_filter_no_init,
#     conv_kernel_size=conv_kernel_size,
#     latent_representation_dim=latent_representation_dim,
#     l1=l1,
#     l2=l2,
#     dropout_value=dropout_value,
#     use_batch_normalization=use_batch_normalization,
#     activation=activation,
#     gaussian_noise_std=gaussian_noise_std
# )

# print(encoder.feature_map_size)


In [11]:
def train(config, dataloader=None): # remove decoder later

    ## define encoder
    encoder = Encoder(
        input_shape=(5, 128, 128, 128),
        network_depth=config["network_depth"],
        no_convolutions=config["no_convolutions"],
        conv_filter_no_init=config["conv_filter_no_init"],
        conv_kernel_size=config["conv_kernel_size"],
        latent_representation_dim=config["latent_representation_dim"],
        l1=config["l1"],
        l2=config["l2"],
        dropout_value=config["dropout_value"],
        use_batch_normalization=config["use_batch_normalization"],
        activation=config["activation"],
        gaussian_noise_std=config["gaussian_noise_std"]
    )
    print(encoder.feature_map_size)

    ## define decoder 
    latent_vector = torch.randn(1, 128)  # Latent vector of size [batch_size, latent_representation_dim]
    conv_shape = (256, 8, 8, 8)  # Example shape of the feature map before the decoder
    decoder_model = Decoder3D(conv_shape=conv_shape, network_depth=4, no_convolutions=2, 
                            conv_filter_no_init=64, conv_kernel_size=3, 
                            latent_representation_dim=128)

    # output_image = decoder_model(latent_vector)
    # print("Decoder output shape:", output_image.shape)  # Should be [batch_size, 5, 240, 240, 150] or similar

    optimizer = optim.SGD(encoder.parameters(), lr=0.01, momentum=0.9)
    encoder.train()
    decoder_model.train()
    device = "cpu" # TODO - 
    encoder.to(device)
    decoder_model.to(device)

    for epoch in range(2):

        total_loss = 0
        print("*********")
        
        for batch in dataloader:
            inputs, targets = batch

            inputs = inputs.squeeze(2)  # Removes the dimension at index 2

            # Forward pass through encoder
            encoded = encoder(inputs)

            # Generate prediction from latent dimension
            model = LatentParametersModel(latent_representation_dim=latent_representation_dim)
            mu_sigma = model(encoded)
            mu = mu_sigma[0,0]
            sigma = mu_sigma[0,1]

            # Calculate loss
            delta = torch.tensor([1.0])  #TODO - not defined?

            loss = survival_loss(mu, sigma, targets, delta)

            print("~~~~~")
            print("the current loss is: ")
            print(loss.item())

            if torch.isnan(loss).any():
                # TODO - 
                continue

            total_loss += loss.item()

            # Decoder

            # Backpropagation and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        average_loss = total_loss / len(dataloader)
        print("epoch", epoch)
        print("average loss is:", average_loss)

    print("Finished Training")

In [12]:
# Define a scheduler
scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=10,
    grace_period=1,
    reduction_factor=2
)

# Define a progress reporter
reporter = CLIReporter(
    metric_columns=["loss", "training_iteration"]
)

In [13]:
from functools import partial

# set up encoder like this 
scheduler = ASHAScheduler(
      metric="loss",
      mode="min",
      max_t=10,
      grace_period=1,
      reduction_factor=2
  )

analysis = tune.run(
      #tune.with_parameters(train, dataloader=dataloader),  # Pass your dataloader
      partial(train, dataloader=dataloader),  # Pass your dataloader
      resources_per_trial={"cpu": 1, "gpu": 0},
      config=config,
      num_samples=2,  # Number of samples from the search space
      scheduler=scheduler, 
      resume=True
  )

2024-08-19 15:38:30,441	INFO worker.py:1781 -- Started a local Ray instance.
2024-08-19 15:38:30,980	INFO tune.py:253 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `tune.run(...)`.
2024-08-19 15:38:30,981	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


0,1
Current time:,2024-08-19 15:44:26
Running for:,00:05:55.35
Memory:,6.7/8.0 GiB

Trial name,status,loc,activation,conv_filter_no_init,conv_kernel_size,dropout_value,gaussian_noise_std,l1,l2,latent_representatio n_dim,network_depth,no_convolutions,use_batch_normalizat ion
train_9d429_00000,RUNNING,127.0.0.1:1142,leaky_relu,2,5,0.5,0.1,0.001,0.001,8,1,1,True
train_9d429_00001,RUNNING,127.0.0.1:1143,leaky_relu,2,5,0.5,0.1,0.001,0.001,8,1,1,True
train_9d429_00002,RUNNING,127.0.0.1:1144,leaky_relu,2,5,0.5,0.1,0.001,0.001,8,2,1,True
train_9d429_00003,RUNNING,127.0.0.1:1145,leaky_relu,2,5,0.5,0.1,0.001,0.001,8,2,1,True


[36m(func pid=1142)[0m torch.Size([1, 2, 64, 64, 64])
[36m(func pid=1142)[0m Layer 3-0: in_channels = 256, out_channels = 512
[36m(func pid=1142)[0m Layer 3-1: in_channels = 512, out_channels = 512
[36m(func pid=1142)[0m Layer 2-0: in_channels = 512, out_channels = 256
[36m(func pid=1142)[0m Layer 2-1: in_channels = 256, out_channels = 256
[36m(func pid=1144)[0m torch.Size([1, 4, 32, 32, 32])
[36m(func pid=1142)[0m Final Layer: in_channels = 64, out_channels = 5
[36m(func pid=1145)[0m torch.Size([1, 4, 32, 32, 32])
[36m(func pid=1143)[0m torch.Size([1, 2, 64, 64, 64])
[36m(func pid=1142)[0m *********


2024-08-19 15:44:26,356	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/lindatang/ray_results/train_2024-08-19_15-38-30' in 0.3969s.
2024-08-19 15:44:36,781	INFO tune.py:1041 -- Total run time: 365.80 seconds (354.95 seconds for the tuning loop).
Resume experiment with: tune.run(..., resume=True)


