In [None]:
from preprocessing import DicomMontageDataset, create_train_test_val

In [None]:
# Split the data into train / validation / test subsets and wrap each in a DataLoader.
# The returned mapping is keyed "<split>_dataset" and "<split>_loader".
dataset_dataloader_dict = create_train_test_val(
    "dataset",                   # presumably the data root directory — confirm
    "dataset/labels_multi.csv",  # path to the labels CSV
    batch_size=8,
    num_workers=4,
    test_size=0.25,
    val_size=0.5,
    random_state=42,             # presumably seeds the split so it is reproducible
)

# Expose each split under the names used by the training cell.
train_dataset, val_dataset, test_dataset = (
    dataset_dataloader_dict[key]
    for key in ("train_dataset", "val_dataset", "test_dataset")
)
train_loader, val_loader, test_loader = (
    dataset_dataloader_dict[key]
    for key in ("train_loader", "val_loader", "test_loader")
)

In [None]:
from monai.networks.nets import SEResNet50
import torch.nn as nn
import torch

# Build a 3D SE-ResNet50 classifier.
# No `pretrained=` argument is passed, so no pretrained weights are requested
# (MONAI's default). The network therefore starts from a random initialisation,
# and the seed below makes that initialisation reproducible.
# Seeding the global torch RNG also makes DataLoader shuffling reproducible.
SEED = 42  # same value as the random_state used for the data split
torch.manual_seed(SEED)

NUM_CLASSES = 4  # multi-class (4-way) classification, trained with CrossEntropyLoss
model = SEResNet50(
    spatial_dims=3,
    in_channels=1,           # single input channel (MONAI's name; not n_input_channels)
    num_classes=NUM_CLASSES,
)

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Optionally freeze the first residual stage so only the later layers are fine-tuned.
# This must happen before the optimizer is built, so that frozen parameters are
# never handed to AdamW.
FREEZE_LAYER1 = False
if FREEZE_LAYER1:
    for param in model.layer1.parameters():
        param.requires_grad = False

# With FREEZE_LAYER1 = False every parameter is trainable, so this is all of them.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-4,
)
criterion = nn.CrossEntropyLoss()

num_epochs = 10

def run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader``; return ``(mean_loss, accuracy_pct)``.

    If ``optimizer`` is given, the model is put in training mode and updated
    after every batch. Otherwise it runs in eval mode with gradients disabled
    (a validation/test pass).

    The loss is averaged per sample: each batch loss is weighted by its batch
    size. This keeps train and validation losses comparable even though their
    loaders have different numbers of batches. It assumes ``criterion`` uses
    ``reduction="mean"`` (the ``nn.CrossEntropyLoss`` default).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device)  # presumably (B, 1, 10, 224, 224) — confirm against the dataset
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size  # undo the per-batch mean
            n_correct += outputs.argmax(dim=1).eq(labels).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError("loader yielded no samples; cannot compute loss/accuracy")
    return loss_sum / n_seen, 100 * n_correct / n_seen


# Per-epoch metrics, kept so later cells can plot or inspect them.
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

for epoch in range(num_epochs):
    # --- Training ---
    train_loss, train_acc = run_epoch(model, train_loader, criterion, device, optimizer)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

    # --- Validation ---
    val_loss, val_acc = run_epoch(model, val_loader, criterion, device)
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%\n")

    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)