In [1]:
!python -m venv venv && source venv/bin/activate
pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118  # cope cuda 
pip install monai[nib] nibabel numpy scikit-image tqdm matplotlib albumentations pytorch-lightning
pip install gradio  # present

SyntaxError: invalid syntax (3093708394.py, line 2)

In [None]:
!git clone https://github.com/USERNAME/REPO_NAME.git

# train.py

In [None]:
import os
from glob import glob
import numpy as np
import torch
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged,
    RandCropByPosNegLabeld, RandFlipd, RandRotate90d, ToTensord
)
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.apps import download_and_extract

# -------- CONFIG ----------

In [None]:
data_dir = "/kaggle/input/aortic-seg"  
# data structure:
# data/
#   images/
#   masks/
images = sorted(glob(os.path.join(data_dir, "images", "*.nii*")))
masks  = sorted(glob(os.path.join(data_dir, "masks",  "*.nii*")))

# organization as images
assert len(images) == len(masks)

# individed
n = len(images)
train_ids = list(range(0, int(0.7*n)))
val_ids   = list(range(int(0.7*n), int(0.85*n)))
test_ids  = list(range(int(0.85*n), n))

def make_pairs(ids):
    return [{"image": images[i], "label": masks[i]} for i in ids]

train_files = make_pairs(train_ids)
val_files   = make_pairs(val_ids)
test_files  = make_pairs(test_ids)



# -------- TRANSFORMS ----------

In [None]:
train_transforms = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.0,1.0,1.0), mode=("bilinear","nearest")),
    Orientationd(keys=["image","label"], axcodes="RAS"),
    ScaleIntensityRanged(keys=["image"], a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True),
    RandCropByPosNegLabeld(
        keys=["image","label"], label_key="label",
        spatial_size=(128,128,64), pos=1, neg=1, num_samples=4
    ),
    RandFlipd(keys=["image","label"], prob=0.5, spatial_axis=0),
    RandRotate90d(keys=["image","label"], prob=0.5, max_k=3),
    ToTensord(keys=["image","label"])
]

val_transforms = [
    LoadImaged(keys=["image","label"]),
    EnsureChannelFirstd(keys=["image","label"]),
    Spacingd(keys=["image", "label"], pixdim=(1.0,1.0,1.0), mode=("bilinear","nearest")),
    Orientationd(keys=["image","label"], axcodes="RAS"),
    ScaleIntensityRanged(keys=["image"], a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True),
    ToTensord(keys=["image","label"])
]



# -------- DATASETS ----------

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16,32,64,128),
    strides=(2,2,2),
    num_res_units=2,
).to(device)

loss_function = DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

# -------- TRAIN LOOP ----------

In [None]:
max_epochs = 200
best_metric = -1
save_path = "./best_model.pth"

for epoch in range(max_epochs):
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= step

# VALIDATION (quick)

In [None]:
model.eval()
    with torch.no_grad():
        val_dice = []
        for val_data in val_loader:
            val_inputs, val_labels = val_data["image"].to(device), val_data["label"].to(device)
            val_outputs = sliding_window_inference(val_inputs, (128,128,64), 4, model, overlap=0.5)
            val_outputs = torch.sigmoid(val_outputs)
            val_outputs = (val_outputs > 0.5).float()
            # compute dice
            dice = dice_metric(y_pred=val_outputs, y=val_labels)
            val_dice.append(dice.item())
        mean_dice = np.mean(val_dice)

    print(f"Epoch {epoch+1}/{max_epochs} loss: {epoch_loss:.4f} val_dice: {mean_dice:.4f}")

# save best

In [None]:
    if mean_dice > best_metric:
        best_metric = mean_dice
        torch.save(model.state_dict(), save_path)
        print(f"Saved new best: {best_metric:.4f}")

print("Training finished.")