In [None]:
# Mount to Google Drive
from google.colab import drive
import os

drive.mount("/content/drive", force_remount=True)
FOLDERNAME = "poster"
# Use an absolute path: the original relative `%cd drive/MyDrive/...` only works
# while the cwd is /content, so re-running this cell after the first chdir failed.
os.chdir(os.path.join("/content/drive/MyDrive", FOLDERNAME))
print(os.getcwd())

Mounted at /content/drive
/content/drive/MyDrive/poster


In [None]:
# Load and split CSV
import pandas as pd
def load_and_split_dataset(csv_file, train_group='Train', val_group='Valid'):
    """Read the split CSV and return scan labels plus the train/validation rows.

    Args:
        csv_file: Path or file-like object accepted by ``pd.read_csv``. The CSV
            must contain ``id``, ``scan-level label`` and ``group`` columns.
        train_group: Value of the ``group`` column marking training rows.
        val_group: Value of the ``group`` column marking validation rows.

    Returns:
        tuple: ``(scan_labels, train_df, val_df)``. ``scan_labels`` is a Series
        mapping ``id`` -> scan-level label over *all* rows of the CSV. The two
        DataFrames hold the matching rows, re-indexed from 0.

    Raises:
        ValueError: if ``id`` values are not unique. A per-id label lookup
            would otherwise return several values instead of one label.
    """
    df = pd.read_csv(csv_file)
    if not df['id'].is_unique:
        dupes = df.loc[df['id'].duplicated(), 'id'].unique().tolist()
        raise ValueError(f"Duplicate scan ids in {csv_file!r}: {dupes[:10]}")
    scan_labels = df.set_index('id')['scan-level label']
    scan_train_data = df[df['group'] == train_group].reset_index(drop=True)
    scan_val_data = df[df['group'] == val_group].reset_index(drop=True)
    return scan_labels, scan_train_data, scan_val_data


In [None]:
# Transforms
import torchvision.transforms as T
SIZE = 224  # square input resolution fed to ResNet-18

# Resize to SIZE x SIZE, then light geometric augmentation: a random rotation
# in [-10, 10] degrees and independent 50% horizontal / vertical flips.
# No Normalize step: ScanDataset already windows intensities into [0, 1].
transform = T.Compose(
    [T.Resize((SIZE, SIZE))]
    + [
        T.RandomRotation(degrees=(-10, 10)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
    ]
)


In [None]:
from typing_extensions import final
import torchvision.transforms as T
import pandas as pd
import numpy as np
import nibabel as nib
import os
import cv2
from torch.utils.data import Dataset, DataLoader

class ScanDataset(Dataset):
    """Dataset yielding a window of consecutive slices from each NIfTI scan.

    Each item is ``(slices, label)``:
      * ``slices`` is a float32 tensor of shape ``(S, 1, H, W)``. ``S`` is
        ``num_slices``, or fewer if the scan has fewer slices. ``H``/``W`` are
        the in-plane size after ``transform``. Intensities are clipped to
        ``hu_window`` and rescaled to [0, 1].
      * ``label`` is ``torch.tensor(scan_labels[scan_id])``.
    """

    def __init__(self, folder_path, data_df, scan_labels, num_slices=50, transform=None,
                 hu_window=(-100, 400)):
        """
        Args:
            folder_path (str): Folder containing ``<id>.nii.gz`` scans.
            data_df (DataFrame): DataFrame whose ``id`` column lists the scans to use.
            scan_labels (Mapping): Scan ID -> scan-level label (dict or pandas Series).
            num_slices (int): Number of consecutive slices read along the last axis.
            transform (callable, optional): Applied once to the whole
                ``(S, 1, H, W)`` slice stack. Random augmentations therefore use
                the same parameters for every slice of a scan. ``None`` means no
                transform.
            hu_window (tuple): ``(min, max)`` intensity window. Values are clipped
                to it and rescaled to [0, 1]. The default (-100, 400) is the
                abdominal CT window used originally.
        """
        self.folder_path = folder_path
        self.scan_labels = scan_labels
        self.num_slices = num_slices
        self.transform = transform
        self.hu_window = hu_window
        self.scan_ids = data_df['id'].tolist()

    def __len__(self):
        return len(self.scan_ids)

    @staticmethod
    def _slice_range(depth, num_slices):
        """Return ``(start, end)`` of the slice window along the depth axis.

        Skips a depth-dependent fraction of leading slices: 5% if depth < 60,
        20% if depth > 130, otherwise 10%. It then takes ``num_slices`` slices.
        If that would run past the end, the window is shifted back to end at
        ``depth``. Scans shorter than ``num_slices`` yield all their slices.
        """
        if depth < 60:
            start = round(depth * 0.05)
        elif depth > 130:
            start = round(depth * 0.2)
        else:
            start = round(depth * 0.1)
        start = max(0, start)
        end = min(start + num_slices, depth)
        if end - start != num_slices:
            start = max(0, end - num_slices)
            end = min(start + num_slices, depth)
        return start, end

    @staticmethod
    def _apply_hu_window(volume, window_min, window_max):
        """Clip ``volume`` to [window_min, window_max] and rescale it to [0, 1]."""
        clipped = np.clip(volume, window_min, window_max)
        return (clipped - window_min) / (window_max - window_min)

    @staticmethod
    def _to_slice_tensor(volume):
        """Convert an ``(H, W, S)`` array into a float32 tensor of shape ``(S, 1, H, W)``."""
        stacked = np.moveaxis(volume, -1, 0)[:, np.newaxis, :, :]
        return torch.tensor(stacked, dtype=torch.float32)

    def __getitem__(self, idx):
        scan_id = self.scan_ids[idx]
        scan_file = os.path.join(self.folder_path, f"{scan_id}.nii.gz")

        img = nib.load(scan_file)
        # Unpacking also rejects volumes that are not 3-D.
        _, _, depth = img.shape

        start, end = self._slice_range(depth, self.num_slices)
        # Slice the lazy `dataobj` proxy instead of calling get_fdata(), so only
        # the selected slices are materialized as an array.
        volume = np.array(img.dataobj[:, :, start:end], dtype=np.float32)
        volume = self._apply_hu_window(volume, *self.hu_window)

        slices = self._to_slice_tensor(volume)  # (S, 1, H, W)

        if self.transform is not None:
            # A single call on the whole stack: torchvision tensor transforms
            # accept [..., H, W], so random rotation/flip parameters are drawn
            # once and shared by all slices, keeping the scan spatially coherent.
            slices = self.transform(slices)

        scan_label = torch.tensor(self.scan_labels[scan_id])
        return slices, scan_label

In [None]:
# Define device
import torch
# Prefer the GPU whenever PyTorch can see one; fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cuda


In [None]:
csv_path = "dataset/TrainValid_split.csv"  # relative to the current working directory
scan_labels, scan_train_data, scan_val_data = load_and_split_dataset(csv_file=csv_path)

In [None]:
# Create datasets
folder_path = "dataset/1_Train,Valid_Image"

# Re-split from csv_path rather than reading scan_train_data / scan_val_data.
# This cell rebinds those names to ScanDataset objects, so reading them back
# as DataFrames would fail if the cell were run a second time.
_, train_split_df, val_split_df = load_and_split_dataset(csv_path)

# Validation should be deterministic: resize only, no random rotation/flips.
eval_transform = T.Compose([T.Resize((SIZE, SIZE))])

scan_train_data = ScanDataset(folder_path=folder_path, data_df=train_split_df, scan_labels=scan_labels, transform=transform)
scan_val_data = ScanDataset(folder_path=folder_path, data_df=val_split_df, scan_labels=scan_labels, transform=eval_transform)

In [None]:
# Slice-stack shape of one training item (slices, channels, height, width).
scan_train_data[1][0].size()

torch.Size([50, 1, 224, 224])

In [None]:
# Slice-stack shape of one validation item (slices, channels, height, width).
scan_val_data[1][0].size()

torch.Size([50, 1, 224, 224])

In [None]:
# Number of scans in each split.
NUM_SCAN_TRAIN, NUM_SCAN_VAL = len(scan_train_data), len(scan_val_data)
print(f"Train dataset size: {NUM_SCAN_TRAIN}")
print(f"Val dataset size: {NUM_SCAN_VAL}")

Train dataset size: 800
Val dataset size: 200


In [None]:
# Create DataLoaders
import torch

BATCH = 8
LOADER_SEED = 0  # makes the training shuffle order reproducible across runs

mini_scan_trains = DataLoader(
    scan_train_data,
    batch_size=BATCH,
    shuffle=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)
mini_scan_vals = DataLoader(scan_val_data, batch_size=BATCH, shuffle=False)

# Report batch counts; the DataLoader repr only shows a memory address.
print(f"Train batches per epoch: {len(mini_scan_trains)}")
print(f"Val batches: {len(mini_scan_vals)}")

<torch.utils.data.dataloader.DataLoader object at 0x7df8b21aca10>
<torch.utils.data.dataloader.DataLoader object at 0x7df8bd0674d0>


In [None]:
# Draw a single training batch and show the collated input / label shapes.
x, y = next(iter(mini_scan_trains))
print(*(t.shape for t in (x, y)))

torch.Size([8, 50, 1, 224, 224]) torch.Size([8])


In [None]:
# Scan-level labels of the batch drawn in the previous cell.
y

tensor([1, 0, 0, 1, 1, 0, 0, 1])

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class CNN_LSTM(nn.Module):
    """Scan classifier: per-slice ResNet-18 features fed through an LSTM.

    Input ``x`` has shape (N, F, C, H, W) = (batch, slices, 1, height, width).
    Output is (N, num_classes) raw logits computed from the LSTM output at the
    last slice.
    """

    def __init__(self, num_classes=2, feature_dim=64, hidden_size=256, pretrained=True):
        """
        Args:
            num_classes (int): Number of output logits. Default 2.
            feature_dim (int): Size of the per-slice feature vector from the
                ResNet head; also the LSTM input size. Default 64.
            hidden_size (int): LSTM hidden size. Default 256.
            pretrained (bool): Load torchvision's default ResNet-18 weights
                (default True). Pass False to skip the download.
        """
        super().__init__()

        # Feature extractor: ResNet-18 adapted to 1-channel (grayscale) input.
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet18(weights=weights)
        rgb_conv = self.resnet.conv1
        gray_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # Sum the pretrained RGB filters over the input-channel axis. For a
            # grayscale image replicated into 3 channels, this reproduces the
            # pretrained conv1 response instead of discarding it for a random init.
            with torch.no_grad():
                gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        self.resnet.conv1 = gray_conv
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, feature_dim)

        # Sequence model over the slice axis.
        self.lstm = nn.LSTM(input_size=feature_dim, hidden_size=hidden_size, batch_first=True)

        # Final classification layer.
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """
        x: (Batch, Slices, Channels, Height, Width) -> (N, F, C, H, W)
        Returns: (N, num_classes) logits.
        """
        N, F, C, H, W = x.shape

        # Run every slice through the ResNet in one batched call instead of a
        # Python loop over F. In train mode, BatchNorm now computes statistics
        # over all N*F slices at once rather than over N slices per position.
        frames = x.reshape(N * F, C, H, W)
        features = self.resnet(frames).view(N, F, -1)  # (N, F, feature_dim)

        output, _ = self.lstm(features)  # (N, F, hidden_size)

        # Classify from the LSTM output after the last slice.
        return self.fc(output[:, -1, :])


In [None]:
# Build the CNN-LSTM; the first run downloads the pretrained ResNet-18 weights.
model = CNN_LSTM()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 226MB/s]


In [None]:
# Move the model's parameters and buffers to the selected device before the optimizer is built from them.
model = model.to(device)

In [None]:
import torch.optim as optim
NUM_EPOCHS = 10   # full passes over the training loader
PRINT_EVERY = 10  # batches between loss / validation reports within an epoch
loss_function = nn.CrossEntropyLoss()  # CNN_LSTM.forward returns raw logits
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
def val(mini_scan_vals, model, device):
  """Compute and print the classification accuracy of ``model`` over a loader.

  Args:
    mini_scan_vals: Iterable of ``(x, y)`` batches (e.g. a DataLoader).
    model: Module returning class scores of shape (batch, num_classes).
    device: Device the batches are moved to before the forward pass.

  Returns:
    float: Fraction of samples whose argmax prediction equals the label,
    or 0.0 if the loader yields no samples.

  Note: leaves ``model`` in eval mode. Callers that continue training must
  call ``model.train()`` again.
  """
  model.eval()
  correct = 0
  seen = 0
  with torch.no_grad():
    for x, y in mini_scan_vals:
      x = x.to(device)
      y = y.to(device)
      predictions = model(x).argmax(dim=1)
      correct += predictions.eq(y).sum().item()
      # Count samples from the batches themselves rather than a global total,
      # so any loader (subset, different split) is scored correctly.
      seen += y.size(0)
  val_acc = correct / seen if seen else 0.0
  print('Val Acc:', val_acc)
  return val_acc

In [None]:
TRAIN_LOSS = []     # loss of every training batch
VAL_ACC_LIST = []   # validation accuracy at each report
BATCHES = []        # global batch index for each TRAIN_LOSS entry
VAL_BATCHES = []    # global batch index for each VAL_ACC_LIST entry

def train(mini_scan_trains, model, loss_function, optimizer, device, mini_vals,
          num_epochs=None, print_every=None):
  """Train ``model`` and periodically report validation accuracy.

  Args:
    mini_scan_trains: Iterable of ``(x, y)`` training batches.
    model: Module being trained.
    loss_function: Callable ``(scores, y) -> loss tensor``.
    optimizer: Optimizer over ``model``'s parameters.
    device: Device the batches are moved to.
    mini_vals: Validation batches passed to ``val``.
    num_epochs: Passes over ``mini_scan_trains``. Defaults to the global
      ``NUM_EPOCHS`` at call time.
    print_every: Validate every this many batches within an epoch (the counter
      restarts each epoch). Defaults to the global ``PRINT_EVERY`` at call time.

  Side effects: appends to the module-level lists TRAIN_LOSS, BATCHES,
  VAL_ACC_LIST and VAL_BATCHES (mutated in place, so no ``global`` is needed).
  """
  if num_epochs is None:
    num_epochs = NUM_EPOCHS
  if print_every is None:
    print_every = PRINT_EVERY
  batch_count = 0
  for _epoch in range(num_epochs):
    for count, (x, y) in enumerate(mini_scan_trains):
      # val() switches the model to eval mode, so re-enable training mode each batch.
      model.train()
      x = x.to(device)
      y = y.to(device)
      scores = model(x)
      loss = loss_function(scores, y)

      # Update before validating: backward() releases the saved activations,
      # so val() does not run while the training graph is still held in memory.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      TRAIN_LOSS.append(loss.item())
      BATCHES.append(batch_count)
      if count % print_every == 0:
        print('Training loss:', loss.item(), end = ' / ')
        VAL_ACC_LIST.append(val(mini_vals, model, device))
        VAL_BATCHES.append(batch_count)
      batch_count += 1

In [None]:
# Launch training; validation batches are passed by keyword for clarity.
train(mini_scan_trains, model, loss_function, optimizer, device, mini_vals=mini_scan_vals)

Training loss: 0.6910021305084229 / Val Acc: 0.51
Training loss: 0.7572619318962097 / Val Acc: 0.58
Training loss: 0.6188276410102844 / Val Acc: 0.565
Training loss: 0.6811545491218567 / Val Acc: 0.62
Training loss: 0.5175712704658508 / Val Acc: 0.67
Training loss: 0.552336573600769 / Val Acc: 0.71
Training loss: 0.36247357726097107 / Val Acc: 0.695
Training loss: 0.528309166431427 / Val Acc: 0.635
Training loss: 0.7364122867584229 / Val Acc: 0.675
Training loss: 0.8677643537521362 / Val Acc: 0.7
Training loss: 0.48966965079307556 / Val Acc: 0.65
Training loss: 0.5960067510604858 / Val Acc: 0.665
Training loss: 0.6547960042953491 / Val Acc: 0.695
Training loss: 0.5239405035972595 / Val Acc: 0.67
Training loss: 0.3068825304508209 / Val Acc: 0.665
Training loss: 0.7026999592781067 / Val Acc: 0.695
Training loss: 0.5422991514205933 / Val Acc: 0.63
Training loss: 0.6478659510612488 / Val Acc: 0.615
Training loss: 0.4264119267463684 / Val Acc: 0.67
Training loss: 0.49599581956863403 / Val A