In [1]:
import os
import torch
import random
import numpy as np
import pandas as pd
import nibabel as nib
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms.v2 as transforms

from tqdm import tqdm
from sklearn.metrics import confusion_matrix
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

from monai.networks.nets import EfficientNetBN, Densenet121

In [2]:
SEED = 3

# Seed every random number generator the notebook relies on.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # covers every visible GPU, no-op without CUDA

# NOTE: hash randomisation of the *running* interpreter is fixed at start-up;
# setting the variable here only affects processes spawned from this kernel.
os.environ['PYTHONHASHSEED'] = str(SEED)

# Deterministic cuDNN kernels. The autotuner (benchmark) must stay off, otherwise
# it may pick different convolution algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
excel_file = './TCIA_LGG_cases_159.xlsx'
df = pd.read_excel(excel_file)

# Map each subject id (column 'Filename') to its 1p/19q status string.
# Rows where either cell is not a string (e.g. empty cells) are skipped.
file_label_map = {}
for filename, label in zip(df['Filename'], df['1p/19q']):
    if isinstance(filename, str) and isinstance(label, str):
        file_label_map[filename] = label

file_paths = []
base_names = set()

# Collect every non-tumor NIfTI volume whose subject id has a label.
# The listing is sorted because os.listdir returns entries in arbitrary,
# filesystem-dependent order, which would make everything downstream
# (file order per subject, train/val split) vary between machines.
nifti_dir = './grouped_data'
for filename in sorted(os.listdir(nifti_dir)):
    if not filename.endswith(".nii.gz") or 'tumor' in filename.lower():
        continue
    # The subject id is the part of the file name before the first underscore.
    base_name = filename.split("_")[0]
    if base_name in file_label_map:
        file_paths.append(os.path.join(nifti_dir, filename))
        base_names.add(base_name)
    else:
        print(f"File: {filename}, Label: Not found in Excel file - skipping")

In [4]:
# Build one record per subject: {'base_name', 'files', 'label'}.
data = []
labels = []
# Iterate in sorted order: `base_names` is a set of strings, whose iteration
# order changes from one interpreter process to the next.
for base in sorted(base_names):
    # Match on the exact subject prefix of the file name. A plain substring test
    # on the path would also attach e.g. "LGG-104_..." to subject "LGG-10".
    # Sorting keeps the concatenation order of a subject's volumes consistent.
    subject_files = sorted(
        path for path in file_paths
        if os.path.basename(path).split("_")[0] == base
    )
    # 'n/n' -> 0; every other status string (expected: 'd/d') -> 1.
    subject_label = 0 if file_label_map[base] == 'n/n' else 1
    data.append({'base_name': base, 'files': subject_files, 'label': subject_label})
    labels.append(subject_label)

In [None]:
def show_sample(subject_name, nifti_dir='./grouped_data'):
    """
    Plots every slice (along the third axis) of one subject's NIfTI volume.

    The first file in `nifti_dir` (in sorted order) that ends with ".nii.gz",
    contains `subject_name` in its name and is not a tumor file is displayed.

    Parameters:
        - subject_name (str): Text that must appear in the file name.
        - nifti_dir (str): Directory that is searched for the volume.

    Returns:
        - None. Problems are reported with a printed message instead of an exception.
    """
    nifti_file = None
    for filename in sorted(os.listdir(nifti_dir)):
        if filename.endswith(".nii.gz") and subject_name in filename and 'tumor' not in filename.lower():
            nifti_file = os.path.join(nifti_dir, filename)
            break

    # Without this guard `nifti_file` would be unbound below and even the
    # error message in the except-branch would raise.
    if nifti_file is None:
        print(f"No matching NIfTI file found for subject {subject_name} in {nifti_dir}")
        return None

    try:
        img_data = nib.load(nifti_file).get_fdata()
        print(img_data.shape)
        num_slices = img_data.shape[2]

        # Smallest near-square grid that holds all slices. Rounding the square
        # root down leaves too few axes whenever num_slices is not a perfect square.
        n_cols = int(np.ceil(np.sqrt(num_slices)))
        n_rows = int(np.ceil(num_slices / n_cols))

        # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
        fig.suptitle(f"Slices of Subject: {subject_name}")

        for i, ax in enumerate(axes.ravel()):
            if i < num_slices:
                ax.imshow(img_data[:, :, i], cmap='gray')
                ax.set_title(f"Slice {i + 1}")
            # Hide the frame on used and on left-over panels alike.
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    except Exception as e:
        # Best-effort visualisation: report the problem and carry on.
        print(f"Error processing {nifti_file}: {e}")
        return None

show_sample('LGG-637')

In [None]:
class NiftiDataset(Dataset):
    """
    PyTorch Dataset for handling NIfTI images and their corresponding labels.

    All volumes are loaded into memory when the dataset is constructed; the
    optional transform is applied lazily, every time an item is fetched.
    """

    def __init__(self, data, transform=None):
        """
        Initializes the dataset.

        Parameters:
            - data (list): List of dictionaries, one per sample, with the following structure:
            {
                'base_name': str,
                'files': list[str],
                'label': int
            }
            - transform (callable, optional): Optional transform applied to the image tensor in __getitem__.

        For every sample the volumes listed in 'files' are loaded, concatenated along
        their third axis and permuted so that this axis comes first, giving one 3D
        float tensor of shape (total_slices, dim0, dim1). The tensor is stored in
        `self.samples` together with the label as a (tensor, label_tensor) tuple.

        Samples that list the same files (e.g. when the input list is repeated for
        augmentation) share a single tensor instead of being read and stored again.
        The shared tensor must therefore not be modified in place by the transform.

        Raises:
            - ValueError: If a sample has an empty 'files' list.
        """
        self.data = data
        self.transform = transform
        self.samples = []

        # Loaded volumes keyed by the tuple of file paths they were built from.
        volume_cache = {}

        for sample in self.data:
            files = tuple(sample['files'])
            if not files:
                raise ValueError(f"No NIfTI files listed for sample {sample.get('base_name')!r}")

            if files not in volume_cache:
                volumes = [
                    torch.tensor(nib.load(path).get_fdata(), dtype=torch.float)
                    for path in files
                ]
                # Stack along the slice axis, then move it to the front (channel position).
                volume_cache[files] = torch.cat(volumes, dim=2).permute(2, 0, 1)

            label = torch.tensor(sample['label'])
            self.samples.append((volume_cache[files], label))

    def __len__(self):
        """
        Returns the number of samples in the dataset.
        """
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Returns the item at the given index.

        Parameters:
            - idx (int): Index of the item.

        Returns:
            - tuple: (image tensor, label tensor); the image has been passed through
              the transform when one was supplied.
        """
        image, label = self.samples[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [7]:
# Split data into training and validation sets (80% train, 20% val), stratified by label.
# The stratification labels are derived from `data` itself so the cell does not depend
# on a separate `labels` variable, and random_state pins the split so re-running the
# cell always yields the same partition.
train_data, val_data = train_test_split(
    data,
    test_size=0.2,
    stratify=[sample['label'] for sample in data],
    random_state=SEED,
)

# Reduce the class imbalance of the training set by adding one extra copy of every
# sample of the less frequent class (nothing is added when the classes are already even).
train_labels = [sample['label'] for sample in train_data]
class_counts = {cls: train_labels.count(cls) for cls in sorted(set(train_labels))}
if len(set(class_counts.values())) > 1:
    minority_label = min(class_counts, key=class_counts.get)
    minority_samples = [sample for sample in train_data if sample['label'] == minority_label]
    train_data = train_data + minority_samples

In [8]:
# Intensity statistics handed to the Normalize transforms, each as a
# single-element list.
# NOTE(review): presumably measured on the training volumes - confirm provenance.
mean, std = [73.42668914794922], [288.2677307128906]

In [9]:
# Training augmentation. All geometric transforms run on the raw intensities and
# Normalize comes last, so the default zero fill of rotation, translation and
# perspective warps ends up at the same value after normalisation. With Normalize
# in the middle, the perspective fill would be 0 in normalised space instead.
# NOTE(review): assumes a raw value of 0 is image background - confirm.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=.5),
    transforms.RandomVerticalFlip(p=.5),
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=0, translate=(.1, .1)),
    transforms.RandomPerspective(distortion_scale=0.1, p=0.5),
    transforms.Normalize(mean=mean, std=std),
])

# Validation data is only normalised, never augmented.
valid_transform = transforms.Compose([
    transforms.Normalize(mean=mean, std=std),
])

In [None]:
# How many times each training sample appears in the training dataset. The
# transform runs on every item access, so each repetition of a sample is
# augmented independently.
TRAIN_REPEATS = 3

train_dataset = NiftiDataset(train_data * TRAIN_REPEATS, transform=train_transform)
test_dataset = NiftiDataset(val_data, transform=valid_transform)

In [None]:
batch_size = 8

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Sanity-check the shape of one training batch. Dedicated loop names are used so the
# module-level `labels` list built during data assembly is not overwritten by a tensor.
for batch_images, batch_labels in train_loader:
    print(f"Image batch shape: {batch_images.shape}, Label batch shape: {batch_labels.shape}")
    break  # Just checking one batch

In [12]:
# Implementation of Dice Loss for binary classification
class DiceLoss(nn.Module):
    """
    Soft Dice loss computed over the whole batch from raw logits.

    Parameters:
        smooth (float): Constant added to both numerator and denominator of the
            Dice ratio. It prevents division by zero and, because it is also in
            the numerator, lets a batch without positives reach a loss of 0
            when nothing is predicted (otherwise such a batch would always
            yield the maximal loss of 1 and provide no gradient).
    """

    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """
        Parameters:
            y_pred (torch.Tensor): Logits; converted to probabilities with a sigmoid.
            y_true (torch.Tensor): Targets of the same shape with values 0 or 1.

        Returns:
            torch.Tensor: Scalar loss, 1 - Dice coefficient.
        """
        probs = torch.sigmoid(y_pred)
        intersection = torch.sum(y_true * probs)
        total = torch.sum(y_true) + torch.sum(probs)
        dice = (2 * intersection + self.smooth) / (total + self.smooth)
        return 1 - dice
    
# Implementation of Focal Loss for binary classification
class FocalLoss(nn.Module):
    """
    Binary focal loss computed from raw logits.

    Parameters:
        alpha (float): Weight of the positive class; negatives get 1 - alpha.
        gamma (float): Focusing exponent; 0 reduces the loss to alpha-weighted BCE.
    """

    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, y_pred, y_true):
        """
        Parameters:
            y_pred (torch.Tensor): Logits.
            y_true (torch.Tensor): Float targets of the same shape with values 0 or 1.

        Returns:
            torch.Tensor: Scalar loss, mean over all elements.
        """
        # Work on logits directly; sigmoid followed by BCELoss loses precision
        # for large-magnitude logits.
        bce = nn.functional.binary_cross_entropy_with_logits(y_pred, y_true, reduction='none')
        probs = torch.sigmoid(y_pred)
        # Probability assigned to the *true* class of each element.
        p_t = probs * y_true + (1 - probs) * (1 - y_true)
        alpha_t = self.alpha * y_true + (1 - self.alpha) * (1 - y_true)
        # Down-weight well-classified elements of either class via (1 - p_t)^gamma.
        loss = alpha_t * (1 - p_t) ** self.gamma * bce
        return loss.mean()

In [13]:
# Weighted combination of Dice and BCEWithLogits loss for binary classification
class CombinedLoss(nn.Module):
    """
    Sum of Dice loss and binary cross-entropy, both computed from raw logits.

    Parameters:
        dice_weight (float): Multiplier of the Dice term.
        bce_weight (float): Multiplier of the BCE term.

    With the defaults both terms contribute with weight 1.
    """

    def __init__(self, dice_weight=1.0, bce_weight=1.0):
        super(CombinedLoss, self).__init__()
        self.dice_loss = DiceLoss()
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight

    def forward(self, y_pred, y_true):
        """
        Parameters:
            y_pred (torch.Tensor): Logits.
            y_true (torch.Tensor): Float targets of the same shape with values 0 or 1.

        Returns:
            torch.Tensor: Scalar combined loss.
        """
        dice_term = self.dice_loss(y_pred, y_true)
        bce_term = self.bce_loss(y_pred, y_true)
        return self.dice_weight * dice_term + self.bce_weight * bce_term

In [14]:
def train_model(model, train_loader, test_loader, num_epochs=100, lr=1e-4, device=None,
                output_dir="./output_models/", checkpoint_name="efficientnet_b1_trained_2.pth"):
    """
    Trains a single-logit binary classifier and checkpoints the best epoch.

    Each epoch runs one pass over `train_loader` (AdamW, CombinedLoss, cosine
    learning-rate schedule) followed by an evaluation pass over `test_loader`.
    The state dict is written to disk whenever validation accuracy improves, or
    ties the best accuracy with a validation loss that is not worse.

    Parameters:
        model (torch.nn.Module): Network producing one logit per sample.
        train_loader (DataLoader): DataLoader for training set.
        test_loader (DataLoader): DataLoader for validation/test set.
        num_epochs (int): Number of training epochs.
        lr (float): Initial learning rate.
        device (torch.device): Device (CPU/GPU) to train on; chosen automatically when None.
        output_dir (str): Directory for the checkpoint; created if missing.
        checkpoint_name (str): File name of the checkpoint inside `output_dir`.

    Returns:
        model: The model with the weights of the *last* epoch. The best-scoring
        weights are the ones stored in the checkpoint file.
    """
    # Ensure device is set
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)

    # Define loss function, optimizer and LR scheduler
    criterion = CombinedLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)

    # torch.save does not create missing directories, so make sure it exists up front.
    os.makedirs(output_dir, exist_ok=True)
    model_path = os.path.join(output_dir, checkpoint_name)

    # Track best model (validation accuracy first, validation loss as tie-breaker)
    best_acc = 0.0
    best_loss = float("inf")

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # --- TRAINING PHASE ---
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(train_loader, desc="Training", leave=False):
            # Labels become float column vectors to match the (batch, 1) logits.
            images, labels = images.to(device), labels.to(device).float().unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            # A logit >= 0 corresponds to a predicted probability >= 0.5.
            preds = (outputs >= 0).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_loss = running_loss / total
        train_acc = correct / total

        # --- VALIDATION PHASE ---
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(test_loader, desc="Validation", leave=False):
                images, labels = images.to(device), labels.to(device).float().unsqueeze(1)

                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * images.size(0)
                preds = (outputs >= 0).float()
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

                # reshape(-1) always yields a list, even for a batch of one sample
                # (squeeze() would collapse it to a scalar and break extend()).
                all_preds.extend(preds.reshape(-1).long().tolist())
                all_labels.extend(labels.reshape(-1).long().tolist())

        val_loss /= val_total
        val_acc = val_correct / val_total

        # Fixed label set keeps the matrix 2x2 even if one class is never seen/predicted.
        cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
        print("Confusion Matrix:")
        print(cm)

        # Print log
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4%}")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4%}")

        # Update learning rate
        scheduler.step()

        # Save best model
        improved = val_acc > best_acc or (val_acc == best_acc and val_loss <= best_loss)
        if improved:
            best_acc = val_acc
            best_loss = val_loss
            torch.save(model.state_dict(), model_path)
            print(f"Model saved to {model_path}")

    print(f"\nBest Validation Accuracy: {best_acc:.4%}")

    return model

In [None]:
# Use the GPU when one is available, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2D EfficientNet-B0 with 112 input channels and a single output logit.
# NOTE(review): in_channels presumably equals the number of stacked slices per
# sample - confirm against the dataset's tensor shape.
# Alternatives kept for reference:
#   Densenet121(spatial_dims=2, in_channels=112, out_channels=1, pretrained=True)
#   model.load_state_dict(torch.load("./output_models/efficientnet_b0_trained_750.pth"))
model = EfficientNetBN(
    model_name="efficientnet-b0",
    spatial_dims=2,
    in_channels=112,
    num_classes=1,
    pretrained=True,
)

model.to(device)
print("Model loaded to device")

In [None]:
# Run the training loop: 30 epochs, initial learning rate 2e-4
model = train_model(
    model,
    train_loader,
    test_loader,
    num_epochs=30,
    lr=2e-4,
)