In [None]:
# import gdown

# file_id = "1MHw4Ufxck5NuHlSMi1PUfiXQhy0vcKOE"
# url = f"https://drive.google.com/uc?id={file_id}"
# gdown.download(url, quiet=False)

In [None]:
import re
import os
import math
import torch
import shutil
import random
import numpy as np
import pandas as pd
import nibabel as nib
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.models as models
import torchvision.transforms.v2 as transforms

from tqdm import tqdm
from scipy.ndimage import zoom
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

In [None]:
def remove_black_slices(nifti_file):
    """Removes black slices (slices whose voxels are all 0) from a NIfTI file.

    Slices are taken along the third array axis. The result is written to a
    new file whose path is derived from ``nifti_file`` by inserting
    ``_no_black_slices`` before the ``.nii.gz`` suffix and replacing
    ``skullstripped_nifti`` with ``final_data``; the target directory is
    created when missing.

    Args:
        nifti_file: Path to the NIfTI file.

    Returns:
        The path to the modified NIfTI file, or None if an error occurs
        (including the case where every slice is black).
    """
    try:
        img = nib.load(nifti_file)
        img_data = img.get_fdata()

        # A slice is kept when it has at least one non-zero voxel. Testing the
        # voxels directly (rather than the slice sum) avoids dropping slices
        # whose positive and negative values happen to cancel out.
        keep_mask = np.any(img_data != 0, axis=(0, 1))
        slices_to_keep = np.where(keep_mask)[0]

        if slices_to_keep.size == 0:
            # Nothing to save; report it explicitly instead of failing on
            # slices_to_keep[0] below.
            print(f"Error processing {nifti_file}: every slice is black")
            return None

        new_data = img_data[:, :, slices_to_keep]

        # New voxel (i, j, k) corresponds to old voxel (i, j, k + first_kept),
        # so the world-space origin moves by first_kept steps along the slice
        # direction, i.e. along the third column of the affine.
        # NOTE: if black slices are also removed from the middle of the
        # volume, the kept slices are no longer evenly spaced and no affine
        # can describe them exactly.
        first_kept = int(slices_to_keep[0])
        new_affine = img.affine.copy()
        new_affine[:3, 3] = img.affine[:3, 3] + img.affine[:3, 2] * first_kept

        new_img = nib.Nifti1Image(new_data, new_affine, header=img.header)
        # Keep the header's dimensions in sync with the reduced volume.
        new_img.header.set_data_shape(new_data.shape)

        # Save under a name that records the removal of slices.
        new_nifti_file = nifti_file.replace('.nii.gz', '_no_black_slices.nii.gz').replace('skullstripped_nifti','final_data')
        os.makedirs(os.path.dirname(new_nifti_file) or '.', exist_ok=True)
        nib.save(new_img, new_nifti_file)

        return new_nifti_file
    except Exception as e:
        print(f"Error processing {nifti_file}: {e}")
        return None

In [None]:
# NOTE(review): train_files / train_labels / test_files / test_labels are not
# assigned by any cell above this one; in this notebook they are first created
# by a train_test_split cell further down, so this cell relies on out-of-order
# execution — confirm the intended cell order.
# skullstripped_nifti_dir is assigned here but not used in this cell.
skullstripped_nifti_dir = './skullstripped_nifti'

# Strip black slices from every listed file and, on success, replace the entry
# in the list (in place) with the path of the newly written file.
# label_list is unpacked but not used inside the loop.
for file_list, label_list in [ (train_files, train_labels), (test_files, test_labels)]:
    for i, nifti_file in enumerate(file_list):
      if os.path.exists(nifti_file):
          new_nifti = remove_black_slices(nifti_file)
          # remove_black_slices returns None on failure; keep the old path then.
          if new_nifti:

              file_list[i] = new_nifti
      else:
        print(f"File not found: {nifti_file}")

In [None]:
def decide_T1_T2(file_pair):
    """
    Given a list of 2 file names, returns a tuple (T1, T2, ambiguous)
    where T1 and T2 are the chosen file names (in that order) and
    ambiguous is a boolean that is True if the decision was ambiguous.

    The decision is based on case-insensitive keyword matching against each
    name. When both names carry exactly the same evidence (both match only
    T1 keywords, only T2 keywords, both sets, or neither), the input order is
    kept and ambiguous is True.

    Raises:
        ValueError: if file_pair does not contain exactly two names.
    """
    # Keywords that suggest T1 and T2 characteristics.
    t1_keywords = ['t1', 'spgr', 'flash', 'bravo', 'gre', 't1post', 't1flash3d']
    t2_keywords = ['t2', 'fse', 't2fse', 'ax_t2', 'tset2rst', 't2_rst']

    def is_t1(fname):
        lower = fname.lower()
        return any(kw in lower for kw in t1_keywords)

    def is_t2(fname):
        lower = fname.lower()
        return any(kw in lower for kw in t2_keywords)

    if len(file_pair) != 2:
        raise ValueError("The function expects a list of exactly two file names.")

    f1, f2 = file_pair[0], file_pair[1]
    f1_t1 = is_t1(f1)
    f1_t2 = is_t2(f1)
    f2_t1 = is_t1(f2)
    f2_t2 = is_t2(f2)

    # Ambiguous situation: both files show identical evidence, so nothing
    # distinguishes them. This must be checked first; otherwise two names that
    # each match both keyword sets would be accepted by the "f1 is T1 and f2
    # is T2" rule below and wrongly reported as unambiguous.
    # Default to assuming the first is T1 and the second is T2.
    if (f1_t1, f1_t2) == (f2_t1, f2_t2):
        return (f1, f2, True)

    # Case 1: One file matches T1 and the other matches T2.
    if f1_t1 and f2_t2:
        return (f1, f2, False)
    if f1_t2 and f2_t1:
        return (f2, f1, False)

    # Case 2: Only one file shows a T1 signature.
    if f1_t1 and not f2_t1:
        return (f1, f2, False)  # f1 is T1, assume f2 must be T2
    if f2_t1 and not f1_t1:
        return (f2, f1, False)  # f2 is T1, so f1 is T2

    # Case 3: Only one file shows a T2 signature.
    if f1_t2 and not f2_t2:
        return (f2, f1, False)  # f1 is T2, so f2 becomes T1
    if f2_t2 and not f1_t2:
        return (f1, f2, False)  # f2 is T2, so f1 becomes T1

    # Defensive fallback; the rules above cover every non-identical signature.
    return (f1, f2, True)

In [None]:
def group_by_subject(file_paths):
    """
    Collects file paths per subject, keyed by the first 'LGG-<digits>' token
    (e.g. 'LGG-104') found in each path.

    Paths without such a token are reported with a warning and left out.
    Returns a list with one tuple of paths per subject, in order of first
    appearance.
    """
    lgg_id = re.compile(r'LGG-\d+')
    by_subject = defaultdict(list)

    for candidate in file_paths:
        found = lgg_id.search(candidate)
        if found is None:
            print(f"Warning: Subject ID not found in {candidate}")
            continue
        by_subject[found.group()].append(candidate)

    # One immutable tuple per subject.
    return list(map(tuple, by_subject.values()))

# NOTE(review): file_paths is first assigned by a later cell in this notebook,
# so this call relies on out-of-order execution — confirm the intended order.
subject_files = group_by_subject(file_paths)

In [None]:
# Partition each subject's files into a T1 and a T2 image. Subjects for which
# no decision can be made are collected in `ambiguous` for manual review.
T1_images = []
T2_images = []
ambiguous = []

for subject in subject_files:
    # decide_T1_T2 raises ValueError unless it receives exactly two files;
    # a subject with any other number of files cannot be paired automatically,
    # so set it aside instead of aborting the whole loop.
    if len(subject) != 2:
        ambiguous.append(subject)
        continue

    T1, T2, amb = decide_T1_T2(subject)
    if not amb:
        T1_images.append(T1)
        T2_images.append(T2)
    else:
        ambiguous.append(subject)

In [None]:
# Number of T1 files, T2 files and undecided subjects collected above.
len(T1_images), len(T2_images), len(ambiguous)

In [None]:
def rename_files(file_paths, tag, output_dir="./grouped_data"):
    """
    Copy the given files into output_dir under the name "<LGG id>_<tag>.nii.gz".

    Parameters:
        file_paths (list): Paths of the files to copy. The subject id is the
            first 'LGG-<digits>' token found in each path.
        tag (str): Suffix placed after the subject id (e.g. "T1").
        output_dir (str): Destination directory; created when missing.

    Paths without a subject id are skipped with a warning. Files of the same
    subject copied with the same tag overwrite each other.
    """
    # shutil.copy does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    for file in file_paths:
        # Extract the subject ID from the filename.
        match = re.search(r'LGG-\d+', file)
        if match is None:
            print(f"Warning: Subject ID not found in {file}")
            continue
        # Construct the new filename.
        new_name = os.path.join(output_dir, f"{match.group()}_{tag}.nii.gz")
        shutil.copy(file, new_name)


In [None]:
# Copy the selected files to ./grouped_data as "<LGG id>_T1.nii.gz" / "<LGG id>_T2.nii.gz".
rename_files(T1_images, "T1")
rename_files(T2_images, "T2")

In [None]:
def show_all_slices(subject_name, nifti_dir='./final_data'):
    """Plot every slice (third array axis) of one NIfTI file in a grid.

    Parameters:
        subject_name (str): Substring that the file name must contain.
        nifti_dir (str): Directory that is searched for ``.nii.gz`` files.
            Files whose name contains 'tumor' (case-insensitive) are ignored.

    The first match in sorted file-name order is shown. Returns None; problems
    (no matching file, unreadable file) are reported with a printed message.
    """
    nifti_file = None
    # Sorted so that the chosen file does not depend on directory order.
    for filename in sorted(os.listdir(nifti_dir)):
        if filename.endswith(".nii.gz") and subject_name in filename and 'tumor' not in filename.lower():
            nifti_file = os.path.join(nifti_dir, filename)
            break

    if nifti_file is None:
        print(f"No NIfTI file matching '{subject_name}' found in {nifti_dir}")
        return None

    try:
        img = nib.load(nifti_file)
        img_data = img.get_fdata()
        print(img_data.shape)
        num_slices = img_data.shape[2]

        # Smallest near-square grid with at least num_slices cells. Truncating
        # sqrt(num_slices) for both dimensions yields too few cells whenever
        # the slice count is not a perfect square (e.g. 56 -> 7x7 = 49).
        n_cols = max(1, math.ceil(math.sqrt(num_slices)))
        n_rows = max(1, math.ceil(num_slices / n_cols))

        # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
        fig.suptitle(f"Slices of Subject: {subject_name}")
        flat_axes = axes.ravel()

        for i in range(num_slices):
            ax = flat_axes[i]
            ax.imshow(img_data[:, :, i], cmap='gray')
            ax.set_title(f"Slice {i + 1}")
            ax.axis('off')

        # Hide the unused cells at the end of the grid.
        for ax in flat_axes[num_slices:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    except Exception as e:
        print(f"Error processing {nifti_file}: {e}")
        return None
    
# Visual check of one volume from the default directory (./final_data).
show_all_slices('LGG-104_4.000000-Gad_Ax_T2_Straight-38151')

In [None]:
excel_file = './TCIA_LGG_cases_159.xlsx'
df = pd.read_excel(excel_file)

# Map each spreadsheet 'Filename' entry to its '1p/19q' label. Rows where
# either value is not a string (e.g. empty cells) are left out.
file_label_map = {}
for index, row in df.iterrows():
    filename = row['Filename']
    label = row['1p/19q']
    if isinstance(filename, str) and isinstance(label, str):
      file_label_map[filename] = label

file_paths = []
labels = []

nifti_dir = './final_data'
# os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting
# makes the random_state=42 split below reproducible across machines and runs.
for filename in sorted(os.listdir(nifti_dir)):
    if filename.endswith(".nii.gz") and 'tumor' not in filename.lower():
        # The lookup key is the third '_'-separated field of the file name,
        # cut at its first '.'.
        name_parts = filename.split("_")
        if len(name_parts) < 3:
            print(f"File: {filename}, unexpected name format - skipping")
            continue
        base_name = name_parts[2].split(".")[0]
        if base_name in file_label_map:
            label = file_label_map[base_name]
            file_paths.append(os.path.join(nifti_dir, filename))
            labels.append(label)
        else:
            print(f"File: {filename}, Label: Not found in Excel file - skipping")


# NOTE(review): the split is done per file. If several files belong to the
# same subject they can end up on both sides of the split — consider a
# subject-grouped split.
train_files, test_files, train_labels, test_labels = train_test_split(file_paths, labels, test_size=0.2, random_state=42)

print("Train data examples:")
for i in range(min(5, len(train_files))):
    print(f"File: {train_files[i]}, Label: {train_labels[i]}")

print("\nTest data examples:")
for i in range(min(5, len(test_files))):
    print(f"File: {test_files[i]}, Label: {test_labels[i]}")

In [None]:
def convert_to_boolean(input_list):
    """Encode labels as integers: 'n/n' becomes 0, any other label becomes 1.

    Integers 0 and 1 are passed through unchanged so that applying the
    function to already-encoded labels (e.g. when the cell is re-run) does not
    turn every label into 1.

    Parameters:
        input_list (list): Label strings, or integers produced by a previous call.

    Returns:
        list: One integer (0 or 1) per input element.
    """
    encoded = []
    for element in input_list:
        already_encoded = (
            isinstance(element, int)
            and not isinstance(element, bool)
            and element in (0, 1)
        )
        encoded.append(element if already_encoded else (0 if element == 'n/n' else 1))
    return encoded

# Replace the label strings with integers ('n/n' -> 0, anything else -> 1).
train_labels = convert_to_boolean(train_labels)
test_labels = convert_to_boolean(test_labels)

In [None]:
import nibabel as nib
import numpy as np

def analyze_nifti_slices(file_paths, bin_size=10):
    """
    Summarises how many slices (size of the third axis) the given NIfTI files have.

    Parameters:
        file_paths (list): List of file paths to NIfTI images.
        bin_size (int): Width of each histogram bin (default: 10).

    Returns:
        dict: A dictionary with:
            - "lowest": The smallest number of slices.
            - "highest": The largest number of slices.
            - "median": The median number of slices, truncated to int.
            - "average": The average number of slices.
            - "variance": The variance of slice counts.
            - "binned_counts": A dictionary mapping "start-end" bin ranges
              (ascending by start) to the number of images in that range.
    """
    # Only the image shape is needed, so the voxel data is never read here.
    slice_counts = np.array([nib.load(path).shape[2] for path in file_paths])

    stats = {
        "lowest": int(np.min(slice_counts)),
        "highest": int(np.max(slice_counts)),
        "median": int(np.median(slice_counts)),
        "average": float(np.mean(slice_counts)),
        "variance": float(np.var(slice_counts)),
    }

    # Walk the bin starts in ascending order so the histogram keys are
    # inserted already sorted.
    bin_starts = (slice_counts // bin_size) * bin_size
    histogram = {}
    for start in sorted(bin_starts):
        key = f"{start}-{start + bin_size - 1}"
        histogram[key] = histogram.get(key, 0) + 1

    stats["binned_counts"] = histogram

    return stats

# Slice-count statistics of the training files, with the default bin size of 10.
analyze_nifti_slices(train_files)

In [None]:
def duplicate_slices(data, target_slices):
    """
    If the image has less than half of the target slices, repeat every slice
    along the third axis before interpolation.

    Each slice is repeated ``target_slices // current_slices`` times in place
    (s0, s0, ..., s1, s1, ...). Because that factor is at least 2 whenever the
    volume has fewer than half of the target slices, a single pass always
    brings the slice count above ``target_slices / 2``.

    Parameters:
        data (numpy array): The 3D MRI data.
        target_slices (int): The desired number of slices.

    Returns:
        numpy array: The expanded data, or ``data`` itself when it already has
        at least half of the target slices.

    Raises:
        ValueError: if duplication is required but the volume has no slices.
    """
    current_slices = data.shape[2]

    if current_slices >= target_slices / 2:
        return data

    if current_slices == 0:
        # Would otherwise surface as a ZeroDivisionError in the factor below.
        raise ValueError("Cannot duplicate slices of a volume with zero slices.")

    # Determine duplication factor
    factor = int(target_slices // current_slices)
    expanded_data = np.repeat(data, factor, axis=2)

    print(f"Duplicated slices from {current_slices} to {expanded_data.shape[2]} before interpolation.")
    return expanded_data

In [None]:
def interpolate_nifti_images(file_paths, output_dir, target_slices=64, order=3):
    """
    Resamples each NIfTI image along its third axis to ``target_slices`` slices
    and writes the result to ``output_dir`` under the original base name.

    Volumes with very few slices are first expanded by duplicate_slices.

    Parameters:
        file_paths (list): List of NIfTI file paths.
        output_dir (str): Directory to save processed images (created if absent).
        target_slices (int): Desired number of slices.
        order (int): Spline order passed to scipy's zoom (default is 3, cubic).
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for source_path in file_paths:
        source_img = nib.load(source_path)
        voxels = source_img.get_fdata()
        original_slices = voxels.shape[2]

        # Thin volumes get their slices repeated before the spline resampling.
        voxels = duplicate_slices(voxels, target_slices)

        # Only the slice axis is rescaled; the in-plane resolution is kept.
        depth_scale = target_slices / voxels.shape[2]
        resampled_data = zoom(voxels, (1, 1, depth_scale), order=order)

        # The source affine and header are reused unchanged.
        # NOTE(review): the slice spacing they describe presumably no longer
        # matches the resampled grid — confirm whether that matters downstream.
        output_path = os.path.join(output_dir, os.path.basename(source_path))
        nib.save(
            nib.Nifti1Image(resampled_data, affine=source_img.affine, header=source_img.header),
            output_path,
        )

        print(f"Processed: {source_path} -> {output_path} (Initial slices: {original_slices}, Target slices: {target_slices})")


In [None]:
# Resample the test files to 56 slices; results go to ./interpolated_data/ under the original base names.
interpolate_nifti_images(test_files, './interpolated_data/', target_slices=56, order=3)

In [None]:
# Resample the training files to 56 slices into the same ./interpolated_data/ directory.
interpolate_nifti_images(train_files, './interpolated_data/', target_slices=56, order=3)

In [None]:
excel_file = './TCIA_LGG_cases_159.xlsx'
df = pd.read_excel(excel_file)

# Map each spreadsheet 'Filename' entry to its '1p/19q' label. Rows where
# either value is not a string (e.g. empty cells) are left out.
file_label_map = {}
for index, row in df.iterrows():
    filename = row['Filename']
    label = row['1p/19q']
    if isinstance(filename, str) and isinstance(label, str):
      file_label_map[filename] = label

file_paths = []
labels = []

nifti_dir = './interpolated_data'
# os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting
# makes the random_state=42 split below reproducible across machines and runs.
for filename in sorted(os.listdir(nifti_dir)):
    if filename.endswith(".nii.gz") and 'tumor' not in filename.lower():
        # The lookup key is the third '_'-separated field of the file name,
        # cut at its first '.'.
        name_parts = filename.split("_")
        if len(name_parts) < 3:
            print(f"File: {filename}, unexpected name format - skipping")
            continue
        base_name = name_parts[2].split(".")[0]
        if base_name in file_label_map:
            label = file_label_map[base_name]
            file_paths.append(os.path.join(nifti_dir, filename))
            labels.append(label)
        else:
            print(f"File: {filename}, Label: Not found in Excel file - skipping")


# NOTE(review): the split is done per file. If several files belong to the
# same subject they can end up on both sides of the split — consider a
# subject-grouped split.
train_files, test_files, train_labels, test_labels = train_test_split(file_paths, labels, test_size=0.2, random_state=42)

print("Train data examples:")
for i in range(min(5, len(train_files))):
    print(f"File: {train_files[i]}, Label: {train_labels[i]}")

print("\nTest data examples:")
for i in range(min(5, len(test_files))):
    print(f"File: {test_files[i]}, Label: {test_labels[i]}")

In [None]:
def convert_to_boolean(input_list):
    """Encode labels as integers: 'n/n' becomes 0, any other label becomes 1.

    Integers 0 and 1 are passed through unchanged so that applying the
    function to already-encoded labels (e.g. when the cell is re-run) does not
    turn every label into 1.

    Parameters:
        input_list (list): Label strings, or integers produced by a previous call.

    Returns:
        list: One integer (0 or 1) per input element.
    """
    encoded = []
    for element in input_list:
        already_encoded = (
            isinstance(element, int)
            and not isinstance(element, bool)
            and element in (0, 1)
        )
        encoded.append(element if already_encoded else (0 if element == 'n/n' else 1))
    return encoded

# Replace the label strings with integers ('n/n' -> 0, anything else -> 1).
train_labels = convert_to_boolean(train_labels)
test_labels = convert_to_boolean(test_labels)

In [None]:
def show_all_slices(subject_name, nifti_dir='./interpolated_data'):
    """Plot every slice (third array axis) of one NIfTI file in a grid.

    Parameters:
        subject_name (str): Substring that the file name must contain.
        nifti_dir (str): Directory that is searched for ``.nii.gz`` files.
            Files whose name contains 'tumor' (case-insensitive) are ignored.

    The first match in sorted file-name order is shown. Returns None; problems
    (no matching file, unreadable file) are reported with a printed message.
    """
    nifti_file = None
    # Sorted so that the chosen file does not depend on directory order.
    for filename in sorted(os.listdir(nifti_dir)):
        if filename.endswith(".nii.gz") and subject_name in filename and 'tumor' not in filename.lower():
            nifti_file = os.path.join(nifti_dir, filename)
            break

    if nifti_file is None:
        print(f"No NIfTI file matching '{subject_name}' found in {nifti_dir}")
        return None

    try:
        img = nib.load(nifti_file)
        img_data = img.get_fdata()
        print(img_data.shape)
        num_slices = img_data.shape[2]

        # Smallest near-square grid with at least num_slices cells. Truncating
        # sqrt(num_slices) for both dimensions yields too few cells whenever
        # the slice count is not a perfect square (e.g. 56 -> 7x7 = 49).
        n_cols = max(1, math.ceil(math.sqrt(num_slices)))
        n_rows = max(1, math.ceil(num_slices / n_cols))

        # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
        fig.suptitle(f"Slices of Subject: {subject_name}")
        flat_axes = axes.ravel()

        for i in range(num_slices):
            ax = flat_axes[i]
            ax.imshow(img_data[:, :, i], cmap='gray')
            ax.set_title(f"Slice {i + 1}")
            ax.axis('off')

        # Hide the unused cells at the end of the grid.
        for ax in flat_axes[num_slices:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    except Exception as e:
        print(f"Error processing {nifti_file}: {e}")
        return None
    
# Visual check of one resampled volume from the default directory (./interpolated_data).
show_all_slices('LGG-558_3.000000-axial_FSE-03138_no_black')

In [None]:
class GaussianNoise(torch.nn.Module):
    """
    Transform that perturbs a tensor with additive Gaussian noise.

    Each call draws fresh noise of the input's shape, scales it by ``std`` and
    shifts it by ``mean``.
    """
    def __init__(self, mean=0.0, std=0.05):
        super().__init__()
        self.mean, self.std = mean, std

    def forward(self, tensor):
        # Standard-normal samples with the same shape, dtype and device as the input.
        noise = torch.randn_like(tensor)
        scaled_noise = noise * self.std
        return tensor + scaled_noise + self.mean

In [None]:
class NiftiDataset(Dataset):
    """
    PyTorch Dataset that serves NIfTI volumes together with their labels.

    Every volume is read from disk on access, min-max scaled and returned as a
    float32 tensor.
    """

    def __init__(self, file_paths, labels, channel_mode="single", transform=None):
        """
        Stores the configuration; no file is read here.

        Parameters:
            file_paths (list): List of file paths to NIfTI images.
            labels (list): List of labels corresponding to each file.
            channel_mode (str): "single" adds a leading channel axis to the
                volume; "multi" moves the third (slice) axis to the front so
                that slices act as channels.
            transform (callable, optional): Applied to the image tensor in
                __getitem__.
        """
        assert channel_mode in ["single", "multi"], "channel_mode must be 'single' or 'multi'"
        self.transform = transform
        self.channel_mode = channel_mode
        self.labels = labels
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def load_nifti(self, file_path):
        """
        Reads one NIfTI file and returns it as a min-max scaled float32 tensor
        laid out according to channel_mode.
        """
        volume = nib.load(file_path).get_fdata()

        # Min-max scaling; the small constant guards against a constant volume.
        lowest = np.min(volume)
        highest = np.max(volume)
        volume = (volume - lowest) / (highest - lowest + 1e-8)

        as_tensor = torch.tensor(volume, dtype=torch.float32)

        if self.channel_mode == "single":
            # One channel in front of the volume axes (for 3D CNNs).
            return as_tensor.unsqueeze(0)
        # Slice axis first, so slices are treated as channels (for 2D CNNs).
        return as_tensor.permute(2, 0, 1)

    def __getitem__(self, idx):
        """
        Returns (image tensor, label as a long tensor) for the given index.
        """
        file_path, label = self.file_paths[idx], self.labels[idx]

        image = self.load_nifti(file_path)
        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

In [None]:
def calculate_mean_std(dataset):
    """
    Calculate the mean and standard deviation over all values of all images
    in the dataset.

    The statistics are accumulated as a running sum and sum of squares, so the
    result is the true dataset-wide mean and (population) standard deviation
    regardless of the image layout or of differing image sizes.

    Parameters:
        dataset (Dataset): Iterable of (image tensor, label) pairs.

    Returns:
        mean (float): The mean of the pixel values.
        std (float): The standard deviation of the pixel values.

    Raises:
        ValueError: if the dataset yields no values.
    """
    total_sum = 0.0
    total_sq_sum = 0.0
    num_values = 0

    for image, _ in tqdm(dataset, desc="Calculating mean and std"):
        # Accumulate in float64 to limit rounding error over many voxels.
        values = image.double()
        total_sum += values.sum().item()
        total_sq_sum += (values * values).sum().item()
        num_values += values.numel()

    if num_values == 0:
        raise ValueError("Cannot compute mean/std of an empty dataset.")

    mean = total_sum / num_values
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    variance = max(total_sq_sum / num_values - mean * mean, 0.0)
    std = math.sqrt(variance)

    return mean, std


# Statistics are computed on the training files only, after NiftiDataset's
# per-volume min-max scaling and without any transform.
train_dataset = NiftiDataset(train_files, train_labels, channel_mode="single")
mean, std = calculate_mean_std(train_dataset)
print(f"Mean: {mean}, Std: {std}")

In [None]:
# Training pipeline: random augmentation followed by normalisation with the
# training-set statistics. The single mean/std pair is given as one-element
# lists and applies to every channel.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    # GaussianNoise(std=0.05),
    transforms.Normalize(mean=[mean], std=[std]),
])

# Validation pipeline: no augmentation, but the same normalisation as for
# training, so the model sees validation inputs on the scale it was trained
# on. (An empty v2 Compose is also rejected at construction time.)
valid_transform = transforms.Compose([
    transforms.Normalize(mean=[mean], std=[std]),
])

In [None]:
# channel_mode="multi" moves the slice axis to the front, so the slices of a
# volume act as the channels of a 2D image.
# NOTE: this rebinds train_dataset, replacing the single-channel dataset that
# was used to compute the statistics.
train_dataset = NiftiDataset(train_files, train_labels, channel_mode="multi", transform=train_transform)
test_dataset = NiftiDataset(test_files, test_labels, channel_mode="multi", transform=valid_transform)

In [None]:
batch_size = 8

# Create DataLoaders (only the training loader shuffles)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Check data shapes (optional)
# NOTE: this loop rebinds the notebook-level names `images` and `labels`;
# `labels` previously held the list of label strings built when listing files.
for images, labels in train_loader:
    print(f"Image batch shape: {images.shape}, Label batch shape: {labels.shape}")
    break  # Just checking one batch

In [None]:
class CustomEfficientNet(nn.Module):
    """
    EfficientNet-B1 whose stem convolution accepts MRI slices as input channels.

    The ImageNet-pretrained backbone is kept; only the first convolution
    (newly initialised, since its input width differs from RGB) and the final
    classifier layer are replaced.

    Parameters:
        num_classes (int): Output size of the classifier when ``binary`` is False.
        input_channels (int): Number of input channels, i.e. slices per volume.
        binary (bool): If True the classifier emits a single logit (for
            BCEWithLogitsLoss); otherwise ``num_classes`` logits.
    """

    def __init__(self, num_classes=2, input_channels=56, binary=True):
        super(CustomEfficientNet, self).__init__()

        # Load pretrained EfficientNet-B1. IMAGENET1K_V1 is the weight set the
        # deprecated `pretrained=True` flag resolves to.
        self.efficientnet = models.efficientnet_b1(
            weights=models.EfficientNet_B1_Weights.IMAGENET1K_V1
        )

        # Replace the stem convolution so that it takes `input_channels`
        # channels instead of 3; the remaining settings are those of the
        # layer being replaced.
        self.efficientnet.features[0][0] = nn.Conv2d(
            in_channels=input_channels,
            out_channels=32,  # EfficientNet's expected out_channels
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False
        )

        # Modify the classifier for binary/multi-class classification
        num_ftrs = self.efficientnet.classifier[1].in_features
        out_features = 1 if binary else num_classes

        self.efficientnet.classifier[1] = nn.Sequential(
            nn.Dropout(0.3),  # Regularization to prevent overfitting
            nn.Linear(num_ftrs, out_features)
        )

    def forward(self, x):
        # x: batch of volumes with the slices as channels.
        return self.efficientnet(x)

In [None]:
def train_model(model, train_loader, test_loader, num_epochs=80, device=None):
    """
    Fine-tunes the CustomEfficientNet model on the given dataset.

    For the optimizer (AdamW, two parameter groups):
      - The first convolution (model.efficientnet.features[0][0]) uses a base
        lr=5e-4 that is cosine annealed over all `num_epochs` epochs from 5e-4
        down to 1e-4.
      - The rest of the parameters have a base lr=1e-4 but are effectively set to:
            * 1e-5 for the first 20 epochs (i.e. multiplier=0.1)
            * At epoch 20, the lr "jumps" to 1e-4, and then a cosine annealing
              schedule over the remaining epochs brings it down to 1e-5.
      With the default num_epochs=80 this is an 80-epoch / 60-epoch cosine.

    The weights of the epoch with the highest validation accuracy are restored
    before the model is returned.

    Parameters:
        model (torch.nn.Module): Model exposing `efficientnet.features[0][0]`
            and producing one logit per sample.
        train_loader (DataLoader): DataLoader for training set.
        test_loader (DataLoader): DataLoader for validation/test set.
        num_epochs (int): Number of training epochs.
        device (torch.device): Device (CPU/GPU) to train on.

    Returns:
        model: The trained model.
    """
    # Set device if not provided
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # -------------------------------
    # Setup parameter groups for the optimizer:
    #   Group 1: First conv layer (features[0][0])
    #   Group 2: All the rest of the parameters.
    # -------------------------------
    first_conv_params = list(model.efficientnet.features[0][0].parameters())
    # Compare parameters by their id to avoid tensor equality issues.
    first_conv_ids = set(map(id, first_conv_params))
    other_params = [p for p in model.parameters() if id(p) not in first_conv_ids]

    optimizer = optim.AdamW([
        {"params": first_conv_params, "lr": 5e-4},
        {"params": other_params, "lr": 1e-4}
    ], weight_decay=1e-5)

    # -------------------------------
    # Learning-rate multipliers (relative to each group's base lr).
    # The cosine spans follow num_epochs, so a run of any length ends at the
    # intended final learning rates.
    # -------------------------------
    warmup_epochs = 20
    full_span = max(1, num_epochs)
    post_warmup_span = max(1, num_epochs - warmup_epochs)

    def lambda1(epoch):
        # Group 1: from 5e-4 (epoch 0) to 1e-4 (last epoch).
        progress = min(epoch, full_span) / full_span
        return (1e-4 + 0.5 * (5e-4 - 1e-4) * (1 + math.cos(math.pi * progress))) / 5e-4

    def lambda2(epoch):
        # Group 2: constant 1e-5 during warm-up, then 1e-4 -> 1e-5.
        if epoch < warmup_epochs:
            return 0.1  # Effective lr = 1e-4 * 0.1 = 1e-5.
        progress = min(epoch - warmup_epochs, post_warmup_span) / post_warmup_span
        return (1e-5 + 0.5 * (1e-4 - 1e-5) * (1 + math.cos(math.pi * progress))) / 1e-4

    # Create the scheduler with a list of lambda functions (one per parameter group)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])

    # Loss function (using BCEWithLogitsLoss for binary classification)
    criterion = nn.BCEWithLogitsLoss()

    # For tracking the best model (by validation accuracy)
    best_acc = 0.0
    best_model_wts = None

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        # Optionally, print the current effective learning rates:
        current_lrs = scheduler.get_last_lr()
        print(f"Current effective LRs: Group1 (first conv): {current_lrs[0]:.2e}, Group2 (others): {current_lrs[1]:.2e}")

        # --- TRAINING PHASE ---
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(train_loader, desc="Training", leave=False):
            images = images.to(device)
            # For binary classification, ensure labels are floats and have shape [batch, 1]
            labels = labels.to(device).float().unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            preds = torch.sigmoid(outputs).round()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        # max(..., 1) keeps an empty loader from dividing by zero.
        train_loss = running_loss / max(total, 1)
        train_acc = correct / max(total, 1)

        # --- VALIDATION PHASE ---
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in tqdm(test_loader, desc="Validation", leave=False):
                images = images.to(device)
                labels = labels.to(device).float().unsqueeze(1)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                preds = torch.sigmoid(outputs).round()
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

        val_loss /= max(val_total, 1)
        val_acc = val_correct / max(val_total, 1)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4%}")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4%}")

        # Step the scheduler (updates both parameter groups)
        scheduler.step()

        # Save the best model based on validation accuracy.
        if val_acc > best_acc:
            best_acc = val_acc
            # state_dict() hands out references to the live parameters, which
            # later epochs keep updating. Clone every tensor so the snapshot
            # really preserves this epoch's weights.
            best_model_wts = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

    print(f"\nBest Validation Accuracy: {best_acc:.4%}")

    # Load best model weights (if available)
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)

    return model

In [None]:
# def train_model(model, train_loader, test_loader, num_epochs=20, lr=1e-4, device=None):
#     """
#     Fine-tunes the EfficientNet-B1 model on the given dataset.

#     Parameters:
#         model (torch.nn.Module): The modified EfficientNet-B1 model.
#         train_loader (DataLoader): DataLoader for training set.
#         test_loader (DataLoader): DataLoader for validation/test set.
#         num_epochs (int): Number of training epochs.
#         lr (float): Learning rate.
#         device (torch.device): Device (CPU/GPU) to train on.

#     Returns:
#         model: The trained model.
#     """
#     # Ensure device is set
#     if device is None:
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     model = model.to(device)
    
#     # Define loss function and optimizer
#     criterion = nn.BCEWithLogitsLoss()  # For binary classification
#     optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

#     # LR Scheduler
#     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)

#     # Track best model (based on validation accuracy)
#     best_acc = 0.0
#     best_model_wts = None

#     for epoch in range(num_epochs):
#         print(f"\nEpoch {epoch+1}/{num_epochs}")

#         # --- TRAINING PHASE ---
#         model.train()
#         running_loss = 0.0
#         correct = 0
#         total = 0

#         for images, labels in tqdm(train_loader, desc="Training", leave=False):
#             images, labels = images.to(device), labels.to(device).float().unsqueeze(1)  # Ensure correct shape

#             optimizer.zero_grad()
#             outputs = model(images)  # Forward pass
#             loss = criterion(outputs, labels)  # Compute loss

#             loss.backward()
#             optimizer.step()

#             running_loss += loss.item() * images.size(0)
#             preds = torch.sigmoid(outputs).round()  # Convert logits to 0/1 predictions
#             correct += (preds == labels).sum().item()
#             total += labels.size(0)

#         train_loss = running_loss / total
#         train_acc = correct / total

#         # --- VALIDATION PHASE ---
#         model.eval()
#         val_loss = 0.0
#         val_correct = 0
#         val_total = 0

#         with torch.no_grad():
#             for images, labels in tqdm(test_loader, desc="Validation", leave=False):
#                 images, labels = images.to(device), labels.to(device).float().unsqueeze(1)

#                 outputs = model(images)
#                 loss = criterion(outputs, labels)

#                 val_loss += loss.item() * images.size(0)
#                 preds = torch.sigmoid(outputs).round()
#                 val_correct += (preds == labels).sum().item()
#                 val_total += labels.size(0)

#         val_loss /= val_total
#         val_acc = val_correct / val_total

#         # Print log
#         print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4%}")
#         print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4%}")

#         # Update learning rate
#         scheduler.step()

#         # Save best model
#         if val_acc > best_acc:
#             best_acc = val_acc
#             best_model_wts = model.state_dict()

#     print(f"\nBest Validation Accuracy: {best_acc:.4%}")
    
#     # Load best model weights
#     if best_model_wts is not None:
#         model.load_state_dict(best_model_wts)
    
#     return model

In [None]:
num_classes = 2  # Only used for the classifier size when binary_classification is False
input_channels = 56  # Number of MRI slices (D)
# Single-logit head; labels come from the spreadsheet's '1p/19q' column ('n/n' -> 0, anything else -> 1)
binary_classification = True

model = CustomEfficientNet(num_classes, input_channels, binary_classification)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Print model summary
print(model)

In [None]:
# NOTE: test_loader is passed as the validation set, so the test data takes
# part in the selection of the returned weights.
trained_model = train_model(model, train_loader, test_loader, num_epochs=80)
# Create the output directory if it doesn't exist
output_dir = "./output_models/"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Save the trained model (weights only, via state_dict)
model_path = os.path.join(output_dir, "efficientnet_b1_trained.pth")
torch.save(trained_model.state_dict(), model_path)
print(f"Model saved to {model_path}")