In [None]:
# Cell 1: Download dataset from Figshare
import os
import requests
import zipfile
from pathlib import Path
from tqdm import tqdm

# Destination for the raw China-Fundus-CIMT download; created up front
# (including missing parents) so the downloader can write into it right away.
DATA_ROOT = Path("/content/data") / "China_Fundus_CIMT"
os.makedirs(DATA_ROOT, exist_ok=True)

def download_figshare_dataset(article_id=27907056, out_dir=DATA_ROOT, timeout=60):
    """Download every file attached to a Figshare article and unpack zip archives.

    The article's file list is read from the Figshare v2 API. Each file is
    streamed into ``out_dir``; files whose name ends in ``.zip`` are extracted
    into ``out_dir`` and the archive is then deleted.

    Args:
        article_id: Figshare article id to query.
        out_dir: Directory that receives the downloaded/extracted files.
        timeout: Timeout in seconds handed to every ``requests.get`` call, so
            a stalled connection raises instead of hanging the notebook.

    Raises:
        requests.HTTPError: If the metadata request or a file download returns
            an error status.
        ValueError: If the article lists no files.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Downloading from Figshare article {article_id}...")
    api_url = f"https://api.figshare.com/v2/articles/{article_id}"

    r = requests.get(api_url, timeout=timeout)
    r.raise_for_status()
    meta = r.json()

    files = meta.get("files", [])
    if not files:
        raise ValueError("No files found in Figshare article")

    for file_info in files:
        name = file_info['name']
        url = file_info['download_url']
        dest = out_dir / name

        if dest.exists():
            print(f"‚úì {name} already exists, skipping")
            continue

        print(f"Downloading {name}...")
        # Stream into a sibling ".part" file and rename only on success. An
        # interrupted download must never leave a truncated file at `dest`,
        # because the `dest.exists()` check above would skip it on the next run.
        partial = dest.with_name(dest.name + ".part")
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                # Without this an HTTP error page would be saved as the dataset.
                response.raise_for_status()
                # `or None`: an unknown length gives tqdm an open-ended bar
                # instead of a bogus total of 0.
                total_size = int(response.headers.get('content-length', 0)) or None

                with open(partial, 'wb') as f, tqdm(
                    desc=name,
                    total=total_size,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as pbar:
                    for chunk in response.iter_content(chunk_size=8192):
                        if not chunk:
                            continue
                        pbar.update(f.write(chunk))
            partial.replace(dest)
        finally:
            # Only still present if the download failed part-way.
            if partial.exists():
                partial.unlink()

        if dest.suffix == '.zip':
            print(f"Extracting {name}...")
            with zipfile.ZipFile(dest, 'r') as zip_ref:
                zip_ref.extractall(out_dir)
            # The archive is no longer needed once its contents are on disk.
            dest.unlink()

    print("\n‚úÖ Dataset download complete!")

# The extracted dataset folder doubles as the "already downloaded" marker.
if (DATA_ROOT / "Fundus_CIMT_2903 Dataset").exists():
    print("‚úÖ Dataset already exists")
else:
    download_figshare_dataset()

Downloading from Figshare article 27907056...
Downloading data_info.json...


data_info.json: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 736k/736k [00:00<00:00, 1.06MiB/s]


Downloading Fundus_CIMT_2903.zip...


Fundus_CIMT_2903.zip: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1.50G/1.50G [01:21<00:00, 19.8MiB/s]


Extracting Fundus_CIMT_2903.zip...

‚úÖ Dataset download complete!


In [None]:
# Cell 2: Verify dataset
# NOTE: from this cell on DATA_ROOT names the parent data directory, not the
# dataset folder it pointed to in Cell 1.
DATA_ROOT = Path("/content/data")
DATASETS = {name: DATA_ROOT / name for name in ("China_Fundus_CIMT",)}

# File suffixes (lower-case, with leading dot) counted as images.
img_exts = set(".png .jpg .jpeg .tif .tiff .bmp .gif".split())

def count_by_ext(d: Path, exts):
    """Count regular files anywhere below ``d`` whose lower-cased suffix is in ``exts``."""
    matches = 0
    for entry in d.rglob("*"):
        if not entry.is_file():
            continue
        if entry.suffix.lower() in exts:
            matches += 1
    return matches

print("Summary:")
for name, root in DATASETS.items():
    if not root.exists():
        # This is the verification cell: an absent dataset must be reported,
        # not skipped silently.
        print(f"MISSING {name:<16} path={root}")
        continue
    n_img = count_by_ext(root, img_exts)
    print(f"‚úÖ {name:<16} images‚âà{n_img:,}    path={root}")

Summary:
‚úÖ China_Fundus_CIMT images‚âà5,806    path=/content/data/China_Fundus_CIMT


In [None]:
# Cell 3: Preprocess the dataset
import json
import shutil
import numpy as np
import pandas as pd
from PIL import Image
from pathlib import Path

# Configuration
# Fix NumPy's global RNG before anything else in this cell runs.
SEED = 42
np.random.seed(SEED)

# Raw inputs: the same directory the download cell writes into.
RAW_DATA_ROOT = Path("/content/data") / "China_Fundus_CIMT"
DATASET_FOLDER = RAW_DATA_ROOT.joinpath("Fundus_CIMT_2903 Dataset")
DATA_INFO_JSON = RAW_DATA_ROOT.joinpath("data_info.json")

# Outputs of this cell (metadata.csv and the copied images) live here.
PROCESSED_ROOT = Path("/content/processed_data", "CIMT")
PROCESSED_ROOT.mkdir(parents=True, exist_ok=True)

# NOTE(review): not referenced by any cell visible here; images are copied
# unchanged and resized later by the transforms.
TARGET_SIZE = (512, 512)

print("Loading metadata...")
with open(DATA_INFO_JSON, 'r') as f:
    metadata_dict = json.load(f)

# Flatten the {patient_id: info} mapping into one record per patient.
# Left-hand keys are the CSV column names; right-hand keys are the field
# names used in data_info.json.
metadata_list = [
    {
        'patient_id': patient_id,
        'age': info['True_age'],
        'age_norm': info['age'],          # normalized age
        'gender': info['gender'],         # 0=female, 1=male
        'thickness': info['thickness'],   # raw CIMT field, parsed later
        'label': info['label'],           # 0=normal, 1=thickened (reference only)
        'group': info['group'],           # 1=train, 2=val, 3=test
        'left_image': info['left_eye'],
        'right_image': info['right_eye'],
    }
    for patient_id, info in metadata_dict.items()
]

metadata_df = pd.DataFrame(metadata_list)
metadata_df.to_csv(PROCESSED_ROOT / "metadata.csv", index=False)

# Patient total followed by the size of each predefined split.
print(
    f"‚úÖ Metadata saved: {len(metadata_df)} patients",
    f"   Train: {(metadata_df['group']==1).sum()}",
    f"   Val: {(metadata_df['group']==2).sum()}",
    f"   Test: {(metadata_df['group']==3).sum()}",
    sep="\n",
)

# Copy images to processed folder
images_out = PROCESSED_ROOT / "images"
images_out.mkdir(exist_ok=True)

print("\nCopying images...")
# Names listed in the metadata whose source file is absent. They are collected
# so the problem is visible here rather than surfacing later as a
# FileNotFoundError inside a DataLoader worker.
missing_images = []
for left_name, right_name in tqdm(
    zip(metadata_df['left_image'], metadata_df['right_image']),
    total=len(metadata_df),
    desc="Processing",
):
    for image_name in (left_name, right_name):
        src = DATASET_FOLDER / image_name
        dst = images_out / image_name
        if not src.exists():
            missing_images.append(image_name)
            continue
        if not dst.exists():
            # The name may contain sub-directories; make sure they exist.
            dst.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(src, dst)

if missing_images:
    print(f"\nWARNING: {len(missing_images)} image file(s) listed in the metadata "
          f"were not found under {DATASET_FOLDER}")
    print(f"   First missing: {missing_images[:5]}")

print("\n‚úÖ Preprocessing complete!")

Loading metadata...
‚úÖ Metadata saved: 2903 patients
   Train: 2603
   Val: 200
   Test: 100

Copying images...


Processing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2903/2903 [00:05<00:00, 490.48it/s] 


‚úÖ Preprocessing complete!





In [None]:
# Cell 4: Configuration - OPTIMIZED FOR COLAB FREE
import torch

# ==================== PATHS ====================
# Inputs written by the preprocessing cell.
PROCESSED_ROOT = Path("/content/processed_data/CIMT")
METADATA_CSV = PROCESSED_ROOT / "metadata.csv"
IMAGES_DIR = PROCESSED_ROOT / "images"

# Outputs of training. NOTE(review): only CHECKPOINT_DIR is written to by the
# cells visible here; LOGS_DIR and RESULTS_DIR are just created.
OUTPUT_DIR = Path("/content/outputs/cimt_regression")
CHECKPOINT_DIR = OUTPUT_DIR / "checkpoints"
LOGS_DIR = OUTPUT_DIR / "logs"
RESULTS_DIR = OUTPUT_DIR / "results"

for dir_path in [OUTPUT_DIR, CHECKPOINT_DIR, LOGS_DIR, RESULTS_DIR]:
    dir_path.mkdir(parents=True, exist_ok=True)

# ==================== MODEL ====================
MODEL_NAME = "seresnext50_32x4d"  # timm model name used for the shared backbone
USE_PRETRAINED = True
# NOTE(review): passed to the dataset, which stores it but does not branch on
# it in the visible code; the model always consumes the clinical vector.
USE_MULTIMODAL = True

CLINICAL_INPUT_DIM = 3  # 1 normalized age + 2-element one-hot gender
CLINICAL_HIDDEN_DIM = 128
# Per-eye feature width used to size the fusion input (2 * this + clinical).
# Must match what MODEL_NAME outputs with num_classes=0 — TODO confirm if the
# backbone is changed.
BACKBONE_OUTPUT_DIM = 2048
FUSION_HIDDEN_DIMS = [512, 128]
DROPOUT_RATE = 0.5

# ==================== DATA - OPTIMIZED ====================
IMAGE_SIZE = 512
BATCH_SIZE = 24  # per-step batch; the recorded run used an NVIDIA L4 (~24 GB)
NUM_WORKERS = 2
PIN_MEMORY = True
GRADIENT_ACCUMULATION_STEPS = 8  # Effective batch = BATCH_SIZE * 8 = 192

# ==================== TRAINING - REDUCED ====================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
USE_MIXED_PRECISION = True  # enables autocast + GradScaler in the training cells

# Reduced epochs for faster training
STAGE1_EPOCHS = 30  # Reduced from 100
STAGE1_LR = 0.001
STAGE1_LR_DECAY_FACTOR = 0.1  # LR is multiplied by this ...
STAGE1_LR_DECAY_EVERY = 15    # ... every this many epochs

# NOTE(review): the STAGE2_* values are not consumed by any cell visible here.
STAGE2_EPOCHS = 20  # Reduced from 100
STAGE2_LR = 0.00001
STAGE2_LR_DECAY_FACTOR = 0.1
STAGE2_LR_DECAY_EVERY = 10

# NOTE(review): informational only — the initialisation cell constructs
# optim.Adam directly and does not read OPTIMIZER.
OPTIMIZER = "adam"
WEIGHT_DECAY = 1e-4
BETAS = (0.9, 0.999)
LOSS_TYPE = "smooth_l1"  # Changed from 'weighted_bce'; "mse" is the other accepted value

# ==================== AUGMENTATION ====================
# Train-time augmentation parameters consumed by get_transforms().
HORIZONTAL_FLIP_PROB = 0.5
VERTICAL_FLIP_PROB = 0.5
ROTATION_RANGE = 20
COLOR_JITTER = True
BRIGHTNESS = 0.2
CONTRAST = 0.2
SATURATION = 0.2
HUE = 0.1

# Presumably the usual ImageNet channel statistics, matching the pretrained
# backbone — TODO confirm.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# ==================== EVALUATION ====================
SAVE_BEST_MODEL = True
METRIC_FOR_BEST = "mae"  # Changed from 'auc' - lower is better
EARLY_STOPPING_PATIENCE = 15  # epochs without a new best before stopping
LOG_INTERVAL = 10  # NOTE(review): not referenced by any cell visible here
SAVE_INTERVAL = 5  # periodic checkpoint every N epochs

# ==================== REGRESSION SPECIFIC ====================
CIMT_THRESHOLD = 0.9  # mm - for optional binary classification metrics

# ==================== SUMMARY ====================
# Banner and device information.
print(
    "="*60,
    "CONFIGURATION LOADED - CIMT REGRESSION",
    "="*60,
    f"Device: {DEVICE}",
    sep="\n",
)
if torch.cuda.is_available():
    print(
        f"GPU: {torch.cuda.get_device_name(0)}",
        f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
        sep="\n",
    )

# Memory-related settings; the effective batch is derived, not configured.
print(
    f"\nüíæ Memory Optimizations:",
    f"   ‚Ä¢ Batch size: {BATCH_SIZE}",
    f"   ‚Ä¢ Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}",
    f"   ‚Ä¢ Effective batch: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}",
    f"   ‚Ä¢ Mixed precision: {USE_MIXED_PRECISION}",
    sep="\n",
)

# Training schedule and objective.
print(
    f"\nüéØ Training Configuration:",
    f"   ‚Ä¢ Stage 1: {STAGE1_EPOCHS} epochs",
    f"   ‚Ä¢ Stage 2: {STAGE2_EPOCHS} epochs",
    f"   ‚Ä¢ Loss: {LOSS_TYPE}",
    f"   ‚Ä¢ Metric: {METRIC_FOR_BEST} (lower is better)",
    "="*60,
    sep="\n",
)

CONFIGURATION LOADED - CIMT REGRESSION
Device: cuda
GPU: NVIDIA L4
GPU Memory: 23.80 GB

üíæ Memory Optimizations:
   ‚Ä¢ Batch size: 24
   ‚Ä¢ Gradient accumulation: 8
   ‚Ä¢ Effective batch: 192
   ‚Ä¢ Mixed precision: True

üéØ Training Configuration:
   ‚Ä¢ Stage 1: 30 epochs
   ‚Ä¢ Stage 2: 20 epochs
   ‚Ä¢ Loss: smooth_l1
   ‚Ä¢ Metric: mae (lower is better)


In [None]:
# Cell 5: Utility functions
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and make cuDNN deterministic."""
    import random

    # Same order as before: stdlib, NumPy, torch CPU, torch CUDA.
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Deterministic kernels, no autotuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(SEED)

def parse_cimt_value(thickness_str):
    """
    Parse a CIMT thickness field into a single value in mm.

    The field may hold one number or several comma-separated numbers; the
    largest finite number is returned. Pieces that are not numbers, and
    non-finite pieces ("nan", "inf"), are ignored so the result does not
    depend on the order of the pieces.

    Example formats:
    - "0.85, 0.90" -> 0.90
    - "0.95"       -> 0.95
    - "nan, 0.9"   -> 0.9
    - "", None, "abc", "nan" -> None

    Args:
        thickness_str: Raw field value (string or number; NaN/None allowed).

    Returns:
        The largest finite value as a float, or None when nothing usable
        is present.
    """
    import math  # local import keeps this cell self-contained

    if pd.isna(thickness_str) or thickness_str == '':
        return None

    text = str(thickness_str).strip()

    # A single value is just the one-piece case of the comma-separated form.
    values = []
    for piece in text.split(','):
        try:
            value = float(piece.strip())
        except ValueError:
            continue
        # max() with a NaN operand is order-dependent, so drop non-finite
        # values instead of letting them leak into the regression target.
        if math.isfinite(value):
            values.append(value)

    return max(values) if values else None

print("‚úÖ Utility functions defined")

‚úÖ Utility functions defined


In [None]:
# Cell 6: Data transforms
from torchvision import transforms

def get_transforms():
    """Build the torchvision pipelines and return them as ``(train, val)``.

    The training pipeline resizes, applies random flips / rotation and
    (optionally) colour jitter, then converts to a normalized tensor. The
    validation pipeline only resizes, converts and normalizes.
    """
    target_size = (IMAGE_SIZE, IMAGE_SIZE)

    if COLOR_JITTER:
        color_step = transforms.ColorJitter(
            brightness=BRIGHTNESS,
            contrast=CONTRAST,
            saturation=SATURATION,
            hue=HUE,
        )
    else:
        # Identity placeholder so the pipeline keeps the same number of steps.
        color_step = transforms.Lambda(lambda x: x)

    train_steps = [
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(p=HORIZONTAL_FLIP_PROB),
        transforms.RandomVerticalFlip(p=VERTICAL_FLIP_PROB),
        transforms.RandomRotation(degrees=ROTATION_RANGE),
        color_step,
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]
    eval_steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]

    return transforms.Compose(train_steps), transforms.Compose(eval_steps)

print("‚úÖ Transforms defined")

‚úÖ Transforms defined


In [None]:
# Cell 7: Dataset class - MODIFIED FOR REGRESSION
from torch.utils.data import Dataset, DataLoader

class CIMTRegressionDataset(Dataset):
    """
    Dataset for CIMT regression task.
    Returns continuous CIMT value in mm (not binary label).

    Each item pairs a patient's left and right fundus image with a 3-element
    clinical vector and the CIMT target parsed from the ``thickness`` column.
    """
    def __init__(self, metadata_csv, images_dir, split="train",
                 transform=None, use_multimodal=True):
        """Load one split of the metadata CSV and keep rows with a usable CIMT.

        Args:
            metadata_csv: CSV with at least the columns ``group``,
                ``thickness``, ``left_image``, ``right_image``, ``age_norm``,
                ``gender`` and ``patient_id``.
            images_dir: Directory the image file names are relative to.
            split: One of ``'train'``, ``'val'``, ``'test'``.
            transform: Optional callable applied to each PIL image.
            use_multimodal: Stored on the instance; not otherwise used here.

        Raises:
            KeyError: If ``split`` is not one of the three known names.
        """
        self.images_dir = Path(images_dir)
        self.transform = transform
        self.use_multimodal = use_multimodal
        self.split = split

        # Load metadata
        df = pd.read_csv(metadata_csv)

        # Filter by split (group: 1=train, 2=val, 3=test)
        split_map = {'train': 1, 'val': 2, 'test': 3}
        df = df[df['group'] == split_map[split]].copy()

        # Parse CIMT values from thickness column
        df['cimt_mm'] = df['thickness'].apply(parse_cimt_value)

        # Remove samples with missing CIMT values
        df = df.dropna(subset=['cimt_mm']).reset_index(drop=True)

        self.data = df

        # Statistics
        cimt_values = self.data['cimt_mm'].values
        print(f"{split.upper()}: {len(self.data)} patients")
        if len(cimt_values) == 0:
            # min()/max() on an empty array raise an opaque reduction error;
            # report the empty split instead of crashing during construction.
            print("  WARNING: no rows with a parseable CIMT value in this split")
            return
        print(f"  CIMT range: [{cimt_values.min():.2f}, {cimt_values.max():.2f}] mm")
        print(f"  CIMT mean¬±std: {cimt_values.mean():.2f}¬±{cimt_values.std():.2f} mm")
        print(f"  Thickened (‚â•{CIMT_THRESHOLD}mm): {(cimt_values >= CIMT_THRESHOLD).sum()}")
        print(f"  Normal (<{CIMT_THRESHOLD}mm): {(cimt_values < CIMT_THRESHOLD).sum()}")

    def __len__(self):
        """Number of patients kept for this split."""
        return len(self.data)

    def _load_image(self, filename):
        """Open one image as RGB, close the file, then apply the transform (if any)."""
        # The context manager releases the file handle as soon as the pixel
        # data has been converted, instead of waiting for garbage collection.
        with Image.open(self.images_dir / filename) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

    def __getitem__(self, idx):
        """Return the sample dict for row ``idx``.

        Keys: ``left_image``, ``right_image`` (transformed images),
        ``clinical`` (float32 tensor ``[age_norm, 1 - gender, gender]``),
        ``cimt`` (float32 tensor of shape ``[1]``) and ``patient_id``.
        """
        row = self.data.iloc[idx]

        # Left is loaded and transformed before right, so random
        # augmentations are drawn in that order.
        left_img = self._load_image(row['left_image'])
        right_img = self._load_image(row['right_image'])

        # Clinical features: age followed by a two-slot gender encoding
        age = torch.tensor([row['age_norm']], dtype=torch.float32)
        gender = torch.tensor([1-row['gender'], row['gender']], dtype=torch.float32)
        clinical = torch.cat([age, gender])

        # CIMT value (continuous target) - shape [1]
        cimt_value = torch.tensor([row['cimt_mm']], dtype=torch.float32)

        return {
            'left_image': left_img,
            'right_image': right_img,
            'clinical': clinical,
            'cimt': cimt_value,  # Changed from 'label' to 'cimt'
            'patient_id': row['patient_id']
        }


def get_dataloaders():
    """Create the train/val/test datasets and return their loaders in that order.

    Only the training loader shuffles and drops the last incomplete batch;
    validation and test use the evaluation transform and keep every sample.
    """
    train_tf, eval_tf = get_transforms()

    # Built in train -> val -> test order so the per-split statistics are
    # printed in that order.
    datasets = {
        split_name: CIMTRegressionDataset(METADATA_CSV, IMAGES_DIR, split_name,
                                          split_tf, USE_MULTIMODAL)
        for split_name, split_tf in (('train', train_tf),
                                     ('val', eval_tf),
                                     ('test', eval_tf))
    }

    # Settings shared by all three loaders.
    shared = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                  pin_memory=PIN_MEMORY)

    train_loader = DataLoader(datasets['train'], shuffle=True, drop_last=True, **shared)
    val_loader = DataLoader(datasets['val'], shuffle=False, **shared)
    test_loader = DataLoader(datasets['test'], shuffle=False, **shared)

    return train_loader, val_loader, test_loader

print("‚úÖ Dataset class defined")

‚úÖ Dataset class defined


In [None]:
# Cell 8: Model architecture - MODIFIED FOR REGRESSION
import torch.nn as nn
import timm

class SiameseMultimodalCIMTRegression(nn.Module):
    """
    Siamese multimodal model for CIMT regression.
    Outputs a single scalar value (CIMT in mm) with no activation.

    One backbone (shared weights) encodes both eyes; its two feature vectors
    are concatenated with an embedding of the clinical vector and passed
    through an MLP that ends in a single linear unit.
    """
    def __init__(self):
        super().__init__()

        # Shared backbone for both eyes (unchanged)
        self.backbone = timm.create_model(
            MODEL_NAME,
            pretrained=USE_PRETRAINED,
            num_classes=0,  # Remove classification head
            global_pool='avg'
        )

        # Clinical feature processor (unchanged)
        self.clinical_fc = nn.Sequential(
            nn.Linear(CLINICAL_INPUT_DIM, CLINICAL_HIDDEN_DIM),
            nn.ReLU(),
            nn.Dropout(DROPOUT_RATE)
        )

        # Fusion layers - modified for regression.
        # Ask the backbone for its pooled feature width instead of trusting the
        # hard-coded constant, so changing MODEL_NAME cannot silently mismatch
        # the first fusion layer. The constant remains the fallback.
        backbone_dim = getattr(self.backbone, 'num_features', BACKBONE_OUTPUT_DIM)
        fusion_input_dim = backbone_dim * 2 + CLINICAL_HIDDEN_DIM

        layers = []
        in_dim = fusion_input_dim
        for hidden_dim in FUSION_HIDDEN_DIMS:
            layers.extend([
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(DROPOUT_RATE)
            ])
            in_dim = hidden_dim

        # Final regression head: outputs 1 scalar, no activation
        layers.append(nn.Linear(in_dim, 1))  # Changed: no sigmoid
        self.fusion = nn.Sequential(*layers)

    def forward(self, left_img, right_img, clinical):
        """Predict CIMT from both eye images and the clinical vector.

        Args:
            left_img: Batch of left-eye images accepted by the backbone.
            right_img: Batch of right-eye images, same batch size.
            clinical: Batch of clinical vectors with CLINICAL_INPUT_DIM features.

        Returns:
            Tensor of shape ``[batch_size, 1]`` (raw linear output).
        """
        # Extract features from both eyes (shared weights)
        left_features = self.backbone(left_img)
        right_features = self.backbone(right_img)

        # Concatenate bilateral features
        bilateral_features = torch.cat([left_features, right_features], dim=1)

        # Process clinical features
        clinical_features = self.clinical_fc(clinical)

        # Fuse all features
        fused = torch.cat([bilateral_features, clinical_features], dim=1)

        # Regression output (no sigmoid, pure linear output)
        output = self.fusion(fused)  # Shape: [batch_size, 1]

        return output

print("‚úÖ Model architecture defined")

‚úÖ Model architecture defined


In [None]:
# Cell 9: Loss and metrics - MODIFIED FOR REGRESSION
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

class RegressionMetricsCalculator:
    """
    Accumulate predictions/targets over batches and compute regression
    metrics (MAE, RMSE, MSE, R^2) plus the accuracy of the binary decision
    obtained by thresholding both prediction and target.
    """
    def __init__(self, threshold=CIMT_THRESHOLD):
        """Args:
            threshold: Value at or above which a CIMT counts as "thickened".
        """
        self.threshold = threshold
        self.reset()

    def reset(self):
        """Drop everything accumulated so far."""
        self.predictions = []
        self.targets = []

    def update(self, predictions, targets):
        """
        Append one batch.

        predictions: tensor of shape [batch_size, 1]
        targets: tensor of shape [batch_size, 1]
        """
        # .detach(): lets callers pass tensors that still require grad.
        # .float(): under autocast the model output is half precision; storing
        # it as float16 would quantize the values the metrics are computed
        # from, so accumulate in float32 like the targets.
        self.predictions.extend(predictions.detach().float().cpu().numpy().ravel())
        self.targets.extend(targets.detach().float().cpu().numpy().ravel())

    def compute(self):
        """Return a dict with ``mae``, ``rmse``, ``mse``, ``r2``,
        ``threshold_accuracy`` and ``threshold`` as plain Python floats.

        Raises:
            ValueError: If nothing has been accumulated.
        """
        if not self.targets:
            raise ValueError("No samples accumulated; call update() before compute()")

        # float32 on purpose: the threshold comparison below then happens in
        # the same precision the targets were stored in by the dataset.
        y_true = np.asarray(self.targets, dtype=np.float32)
        y_pred = np.asarray(self.predictions, dtype=np.float32)

        # Regression metrics
        mae = mean_absolute_error(y_true, y_pred)
        mse = mean_squared_error(y_true, y_pred)
        rmse = np.sqrt(mse)

        # R^2 (coefficient of determination)
        r2 = r2_score(y_true, y_pred)

        metrics = {
            'mae': float(mae),
            'rmse': float(rmse),
            'mse': float(mse),
            'r2': float(r2)
        }

        # Optional: Binary classification metrics at threshold
        y_true_binary = (y_true >= self.threshold).astype(int)
        y_pred_binary = (y_pred >= self.threshold).astype(int)

        accuracy = (y_true_binary == y_pred_binary).mean()

        metrics['threshold_accuracy'] = float(accuracy)
        metrics['threshold'] = self.threshold

        return metrics

print("‚úÖ Metrics defined")

‚úÖ Metrics defined


In [None]:
import torch.optim as optim
from torch.amp import autocast, GradScaler # Changed import from torch.cuda.amp
import time

# Enable mixed precision
# Module-level gradient scaler used by train_epoch() for the scaled backward
# pass and optimizer step; None when USE_MIXED_PRECISION is off. Because it is
# created once here, its state persists across epochs.
scaler = GradScaler() if USE_MIXED_PRECISION else None

def train_epoch(model, dataloader, criterion, optimizer, epoch):
    """Run one training epoch with gradient accumulation and return its metrics.

    Gradients are accumulated over GRADIENT_ACCUMULATION_STEPS mini-batches
    before each optimizer step. If the number of batches is not a multiple of
    that value, the final shorter group is stepped as well, with its loss
    averaged over its real size. Otherwise those gradients would be discarded
    by the ``zero_grad()`` at the start of the next epoch.

    Args:
        model: Network called as ``model(left_img, right_img, clinical)``.
        dataloader: Sized iterable of batches with the keys ``left_image``,
            ``right_image``, ``clinical`` and ``cimt``.
        criterion: Loss taking ``(predictions, targets)``.
        optimizer: Optimizer over the model parameters.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        The dict from ``RegressionMetricsCalculator.compute()`` plus
        ``'loss'``, the mean un-scaled batch loss over the epoch.
    """
    model.train()
    running_loss = 0.0
    metrics_calc = RegressionMetricsCalculator()

    accum_steps = max(1, GRADIENT_ACCUMULATION_STEPS)
    num_batches = len(dataloader)
    # Size and first index of the trailing, possibly incomplete group.
    tail_size = num_batches % accum_steps
    tail_start = num_batches - tail_size

    optimizer.zero_grad()

    pbar = tqdm(dataloader, desc=f'Epoch {epoch} [Train]')
    for batch_idx, batch in enumerate(pbar):
        left_img = batch['left_image'].to(DEVICE)
        right_img = batch['right_image'].to(DEVICE)
        clinical = batch['clinical'].to(DEVICE)
        targets = batch['cimt'].to(DEVICE)  # Changed from 'label' to 'cimt'

        # Divide by the size of the group this batch belongs to, so every
        # optimizer step sees an average over its own mini-batches.
        in_tail = tail_size > 0 and batch_idx >= tail_start
        group_size = tail_size if in_tail else accum_steps

        if USE_MIXED_PRECISION:
            with autocast(device_type=DEVICE.type):
                predictions = model(left_img, right_img, clinical)  # No sigmoid
                loss = criterion(predictions, targets)
            scaler.scale(loss / group_size).backward()
        else:
            predictions = model(left_img, right_img, clinical)
            loss = criterion(predictions, targets)
            (loss / group_size).backward()

        # Step at the end of every full group and after the very last batch.
        group_complete = (batch_idx + 1) % accum_steps == 0
        last_batch = (batch_idx + 1) == num_batches
        if group_complete or last_batch:
            if USE_MIXED_PRECISION:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        # `loss` is the un-scaled batch loss here.
        running_loss += loss.item()

        # Update metrics (no sigmoid needed for predictions)
        metrics_calc.update(predictions.detach(), targets)

        pbar.set_postfix({'loss': running_loss / (batch_idx + 1)})

    metrics = metrics_calc.compute()
    metrics['loss'] = running_loss / max(1, num_batches)
    return metrics


def validate(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns the dict from ``RegressionMetricsCalculator.compute()`` with an
    added ``'loss'`` entry: the mean of the per-batch losses.
    """
    model.eval()
    metrics_calc = RegressionMetricsCalculator()
    loss_total = 0.0

    def forward_and_loss(left, right, clin, target):
        # One forward pass plus its loss; shared by both precision branches.
        preds = model(left, right, clin)  # No sigmoid
        return preds, criterion(preds, target)

    batch_keys = ('left_image', 'right_image', 'clinical', 'cimt')

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validating'):
            # Move tensors to the device in the fixed key order above.
            inputs = [batch[key].to(DEVICE) for key in batch_keys]

            if USE_MIXED_PRECISION:
                with autocast(device_type=DEVICE.type):
                    predictions, loss = forward_and_loss(*inputs)
            else:
                predictions, loss = forward_and_loss(*inputs)

            loss_total += loss.item()

            # The last entry of `inputs` is the CIMT target.
            metrics_calc.update(predictions, inputs[-1])

    metrics = metrics_calc.compute()
    metrics['loss'] = loss_total / len(dataloader)
    return metrics

print("‚úÖ Training functions defined")

‚úÖ Training functions defined


In [None]:
# Cell 11: Initialize everything
print("Initializing...")

# Data pipelines for the three splits
train_loader, val_loader, test_loader = get_dataloaders()

# Network on the target device, with its parameter count
model = SiameseMultimodalCIMTRegression().to(DEVICE)
print(f"\n‚úÖ Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

# Regression criterion selected by LOSS_TYPE (SmoothL1 is more robust to outliers)
try:
    criterion = {"mse": nn.MSELoss, "smooth_l1": nn.SmoothL1Loss}[LOSS_TYPE]()
except KeyError:
    raise ValueError(f"Unknown loss type: {LOSS_TYPE}") from None

print(f"Loss function: {criterion.__class__.__name__}")

# Adam over all model parameters, starting at the Stage 1 learning rate
optimizer = optim.Adam(
    params=model.parameters(),
    lr=STAGE1_LR,
    betas=BETAS,
    weight_decay=WEIGHT_DECAY,
)

print("\n" + "="*60, "READY TO TRAIN!", "="*60, sep="\n")

Initializing...
TRAIN: 2603 patients
  CIMT range: [0.50, 1.70] mm
  CIMT mean¬±std: 1.00¬±0.18 mm
  Thickened (‚â•0.9mm): 1904
  Normal (<0.9mm): 699
VAL: 200 patients
  CIMT range: [0.50, 1.40] mm
  CIMT mean¬±std: 0.92¬±0.21 mm
  Thickened (‚â•0.9mm): 100
  Normal (<0.9mm): 100
TEST: 100 patients
  CIMT range: [0.50, 1.50] mm
  CIMT mean¬±std: 0.91¬±0.19 mm
  Thickened (‚â•0.9mm): 50
  Normal (<0.9mm): 50


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/111M [00:00<?, ?B/s]


‚úÖ Model created with 27,740,401 parameters
Loss function: SmoothL1Loss

READY TO TRAIN!


In [None]:
# Cell 12: Training Stage 1
print("\n" + "="*60)
print("STAGE 1: Training with higher LR")
print("="*60)

best_metric = float('inf')  # Changed: lower is better for MAE
patience_counter = 0
training_history = []

for epoch in range(1, STAGE1_EPOCHS + 1):
    # Step-wise learning rate decay: multiply the LR by the decay factor at
    # the start of every STAGE1_LR_DECAY_EVERY-th epoch after the first.
    if epoch > 1 and (epoch - 1) % STAGE1_LR_DECAY_EVERY == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= STAGE1_LR_DECAY_FACTOR

    # Train
    train_metrics = train_epoch(model, train_loader, criterion, optimizer, epoch)

    # Validate
    val_metrics = validate(model, val_loader, criterion)

    # Log
    current_lr = optimizer.param_groups[0]['lr']
    print(f"\nEpoch {epoch}/{STAGE1_EPOCHS}:")
    print(f"  Train - Loss: {train_metrics['loss']:.4f}, "
          f"MAE: {train_metrics['mae']:.3f}mm, RMSE: {train_metrics['rmse']:.3f}mm, "
          f"R¬≤: {train_metrics['r2']:.3f}")
    print(f"  Val   - Loss: {val_metrics['loss']:.4f}, "
          f"MAE: {val_metrics['mae']:.3f}mm, RMSE: {val_metrics['rmse']:.3f}mm, "
          f"R¬≤: {val_metrics['r2']:.3f}")
    print(f"  Val Threshold Acc (@{CIMT_THRESHOLD}mm): {val_metrics['threshold_accuracy']:.3f}")
    print(f"  LR: {current_lr:.6f}")

    # Record the epoch before anything that can leave the loop, so the epoch
    # that triggers early stopping is part of the history too.
    training_history.append({
        'epoch': epoch,
        'stage': 1,
        'train_loss': train_metrics['loss'],
        'train_mae': train_metrics['mae'],
        'train_rmse': train_metrics['rmse'],
        'train_r2': train_metrics['r2'],
        'val_loss': val_metrics['loss'],
        'val_mae': val_metrics['mae'],
        'val_rmse': val_metrics['rmse'],
        'val_r2': val_metrics['r2'],
        'val_threshold_acc': val_metrics['threshold_accuracy'],
        'lr': current_lr
    })

    # Save best model (lower MAE is better)
    current_metric = val_metrics[METRIC_FOR_BEST]
    if current_metric < best_metric:
        best_metric = current_metric
        patience_counter = 0
        print(f"  ‚úÖ New best {METRIC_FOR_BEST.upper()}: {best_metric:.4f}")
        if SAVE_BEST_MODEL:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_metric': best_metric,
                'val_metrics': val_metrics
            }, CHECKPOINT_DIR / "best_model_stage1.pth")
    else:
        patience_counter += 1
        print(f"  Patience: {patience_counter}/{EARLY_STOPPING_PATIENCE}")

    # Periodic checkpoint (independent of whether this epoch was a new best)
    if epoch % SAVE_INTERVAL == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, CHECKPOINT_DIR / f"checkpoint_stage1_epoch{epoch}.pth")

    # Early stopping is checked last so the bookkeeping above always runs.
    if patience_counter >= EARLY_STOPPING_PATIENCE:
        print(f"\n‚ö†Ô∏è Early stopping triggered at epoch {epoch}")
        break

# Load best model from stage 1
if SAVE_BEST_MODEL and (CHECKPOINT_DIR / "best_model_stage1.pth").exists():
    print("\nLoading best model from Stage 1...")
    # map_location: remap the saved tensors onto the current DEVICE so a
    # checkpoint written on a GPU runtime can still be restored on CPU.
    # weights_only=False: the checkpoint also stores the validation-metrics
    # dict; only load checkpoints produced by this notebook this way.
    checkpoint = torch.load(CHECKPOINT_DIR / "best_model_stage1.pth",
                            map_location=DEVICE, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"‚úÖ Loaded best model with {METRIC_FOR_BEST.upper()}: {checkpoint['best_metric']:.4f}")


STAGE 1: Training with higher LR


Epoch 1 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:49<00:00,  1.02s/it, loss=0.146]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.23it/s]



Epoch 1/30:
  Train - Loss: 0.1464, MAE: 0.413mm, RMSE: 0.560mm, R¬≤: -8.335
  Val   - Loss: 0.0201, MAE: 0.170mm, RMSE: 0.201mm, R¬≤: 0.115
  Val Threshold Acc (@0.9mm): 0.630
  LR: 0.001000
  ‚úÖ New best MAE: 0.1703


Epoch 2 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:42<00:00,  1.05it/s, loss=0.0407]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.24it/s]



Epoch 2/30:
  Train - Loss: 0.0407, MAE: 0.228mm, RMSE: 0.285mm, R¬≤: -1.414
  Val   - Loss: 0.0220, MAE: 0.173mm, RMSE: 0.204mm, R¬≤: 0.084
  Val Threshold Acc (@0.9mm): 0.600
  LR: 0.001000
  Patience: 1/15


Epoch 3 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:44<00:00,  1.03it/s, loss=0.0354]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.37it/s]



Epoch 3/30:
  Train - Loss: 0.0354, MAE: 0.210mm, RMSE: 0.266mm, R¬≤: -1.100
  Val   - Loss: 0.0203, MAE: 0.163mm, RMSE: 0.208mm, R¬≤: 0.050
  Val Threshold Acc (@0.9mm): 0.620
  LR: 0.001000
  ‚úÖ New best MAE: 0.1629


Epoch 4 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.02it/s, loss=0.0327]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.36it/s]



Epoch 4/30:
  Train - Loss: 0.0327, MAE: 0.203mm, RMSE: 0.256mm, R¬≤: -0.945
  Val   - Loss: 0.0221, MAE: 0.169mm, RMSE: 0.218mm, R¬≤: -0.042
  Val Threshold Acc (@0.9mm): 0.565
  LR: 0.001000
  Patience: 1/15


Epoch 5 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:42<00:00,  1.05it/s, loss=0.0313]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.27it/s]



Epoch 5/30:
  Train - Loss: 0.0313, MAE: 0.199mm, RMSE: 0.250mm, R¬≤: -0.855
  Val   - Loss: 0.0152, MAE: 0.147mm, RMSE: 0.176mm, R¬≤: 0.320
  Val Threshold Acc (@0.9mm): 0.735
  LR: 0.001000
  ‚úÖ New best MAE: 0.1469


Epoch 6 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:42<00:00,  1.05it/s, loss=0.0294]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.26it/s]



Epoch 6/30:
  Train - Loss: 0.0294, MAE: 0.194mm, RMSE: 0.242mm, R¬≤: -0.746
  Val   - Loss: 0.0155, MAE: 0.148mm, RMSE: 0.178mm, R¬≤: 0.302
  Val Threshold Acc (@0.9mm): 0.780
  LR: 0.001000
  Patience: 1/15


Epoch 7 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:43<00:00,  1.04it/s, loss=0.0258]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.37it/s]



Epoch 7/30:
  Train - Loss: 0.0258, MAE: 0.182mm, RMSE: 0.227mm, R¬≤: -0.527
  Val   - Loss: 0.0164, MAE: 0.147mm, RMSE: 0.180mm, R¬≤: 0.288
  Val Threshold Acc (@0.9mm): 0.690
  LR: 0.001000
  Patience: 2/15


Epoch 8 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:46<00:00,  1.02it/s, loss=0.026]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.33it/s]



Epoch 8/30:
  Train - Loss: 0.0260, MAE: 0.182mm, RMSE: 0.228mm, R¬≤: -0.542
  Val   - Loss: 0.0144, MAE: 0.141mm, RMSE: 0.172mm, R¬≤: 0.351
  Val Threshold Acc (@0.9mm): 0.775
  LR: 0.001000
  ‚úÖ New best MAE: 0.1410


Epoch 9 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:44<00:00,  1.03it/s, loss=0.0261]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.21it/s]



Epoch 9/30:
  Train - Loss: 0.0261, MAE: 0.181mm, RMSE: 0.229mm, R¬≤: -0.559
  Val   - Loss: 0.0145, MAE: 0.143mm, RMSE: 0.172mm, R¬≤: 0.352
  Val Threshold Acc (@0.9mm): 0.740
  LR: 0.001000
  Patience: 1/15


Epoch 10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:43<00:00,  1.04it/s, loss=0.025]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.36it/s]



Epoch 10/30:
  Train - Loss: 0.0250, MAE: 0.177mm, RMSE: 0.224mm, R¬≤: -0.485
  Val   - Loss: 0.0147, MAE: 0.141mm, RMSE: 0.172mm, R¬≤: 0.351
  Val Threshold Acc (@0.9mm): 0.740
  LR: 0.001000
  ‚úÖ New best MAE: 0.1407


Epoch 11 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:42<00:00,  1.05it/s, loss=0.0249]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.23it/s]



Epoch 11/30:
  Train - Loss: 0.0249, MAE: 0.179mm, RMSE: 0.223mm, R¬≤: -0.481
  Val   - Loss: 0.0159, MAE: 0.146mm, RMSE: 0.177mm, R¬≤: 0.309
  Val Threshold Acc (@0.9mm): 0.685
  LR: 0.001000
  Patience: 1/15


Epoch 12 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:44<00:00,  1.03it/s, loss=0.0236]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.19it/s]



Epoch 12/30:
  Train - Loss: 0.0236, MAE: 0.173mm, RMSE: 0.217mm, R¬≤: -0.409
  Val   - Loss: 0.0143, MAE: 0.141mm, RMSE: 0.170mm, R¬≤: 0.366
  Val Threshold Acc (@0.9mm): 0.760
  LR: 0.001000
  Patience: 2/15


Epoch 13 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:46<00:00,  1.01it/s, loss=0.0244]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.26it/s]



Epoch 13/30:
  Train - Loss: 0.0244, MAE: 0.176mm, RMSE: 0.221mm, R¬≤: -0.451
  Val   - Loss: 0.0137, MAE: 0.137mm, RMSE: 0.169mm, R¬≤: 0.376
  Val Threshold Acc (@0.9mm): 0.785
  LR: 0.001000
  ‚úÖ New best MAE: 0.1374


Epoch 14 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:43<00:00,  1.04it/s, loss=0.0247]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.26it/s]



Epoch 14/30:
  Train - Loss: 0.0247, MAE: 0.175mm, RMSE: 0.222mm, R¬≤: -0.464
  Val   - Loss: 0.0252, MAE: 0.181mm, RMSE: 0.218mm, R¬≤: -0.045
  Val Threshold Acc (@0.9mm): 0.570
  LR: 0.001000
  Patience: 1/15


Epoch 15 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:43<00:00,  1.04it/s, loss=0.0259]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.23it/s]



Epoch 15/30:
  Train - Loss: 0.0259, MAE: 0.181mm, RMSE: 0.228mm, R¬≤: -0.538
  Val   - Loss: 0.0148, MAE: 0.143mm, RMSE: 0.174mm, R¬≤: 0.334
  Val Threshold Acc (@0.9mm): 0.800
  LR: 0.001000
  Patience: 2/15


Epoch 16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:43<00:00,  1.04it/s, loss=0.0226]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.28it/s]



Epoch 16/30:
  Train - Loss: 0.0226, MAE: 0.168mm, RMSE: 0.213mm, R¬≤: -0.343
  Val   - Loss: 0.0156, MAE: 0.143mm, RMSE: 0.175mm, R¬≤: 0.324
  Val Threshold Acc (@0.9mm): 0.700
  LR: 0.000100
  Patience: 3/15


Epoch 17 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:44<00:00,  1.04it/s, loss=0.0212]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.26it/s]



Epoch 17/30:
  Train - Loss: 0.0212, MAE: 0.162mm, RMSE: 0.206mm, R¬≤: -0.258
  Val   - Loss: 0.0138, MAE: 0.137mm, RMSE: 0.167mm, R¬≤: 0.389
  Val Threshold Acc (@0.9mm): 0.755
  LR: 0.000100
  ‚úÖ New best MAE: 0.1372


Epoch 18 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:44<00:00,  1.03it/s, loss=0.0211]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.18it/s]



Epoch 18/30:
  Train - Loss: 0.0211, MAE: 0.163mm, RMSE: 0.206mm, R¬≤: -0.257
  Val   - Loss: 0.0146, MAE: 0.141mm, RMSE: 0.170mm, R¬≤: 0.365
  Val Threshold Acc (@0.9mm): 0.715
  LR: 0.000100
  Patience: 1/15


Epoch 19 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:47<00:00,  1.01it/s, loss=0.0228]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.19it/s]



Epoch 19/30:
  Train - Loss: 0.0228, MAE: 0.169mm, RMSE: 0.214mm, R¬≤: -0.356
  Val   - Loss: 0.0138, MAE: 0.136mm, RMSE: 0.167mm, R¬≤: 0.391
  Val Threshold Acc (@0.9mm): 0.745
  LR: 0.000100
  ‚úÖ New best MAE: 0.1365


Epoch 20 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.03it/s, loss=0.0214]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.36it/s]



Epoch 20/30:
  Train - Loss: 0.0214, MAE: 0.164mm, RMSE: 0.207mm, R¬≤: -0.271
  Val   - Loss: 0.0139, MAE: 0.137mm, RMSE: 0.167mm, R¬≤: 0.388
  Val Threshold Acc (@0.9mm): 0.735
  LR: 0.000100
  Patience: 1/15


Epoch 21 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:42<00:00,  1.05it/s, loss=0.0221]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.22it/s]



Epoch 21/30:
  Train - Loss: 0.0221, MAE: 0.166mm, RMSE: 0.210mm, R¬≤: -0.314
  Val   - Loss: 0.0139, MAE: 0.137mm, RMSE: 0.167mm, R¬≤: 0.389
  Val Threshold Acc (@0.9mm): 0.720
  LR: 0.000100
  Patience: 2/15


Epoch 22 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.03it/s, loss=0.0212]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.25it/s]



Epoch 22/30:
  Train - Loss: 0.0212, MAE: 0.164mm, RMSE: 0.206mm, R¬≤: -0.259
  Val   - Loss: 0.0132, MAE: 0.134mm, RMSE: 0.164mm, R¬≤: 0.412
  Val Threshold Acc (@0.9mm): 0.750
  LR: 0.000100
  ‚úÖ New best MAE: 0.1342


Epoch 23 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:47<00:00,  1.01it/s, loss=0.0217]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.23it/s]



Epoch 23/30:
  Train - Loss: 0.0217, MAE: 0.166mm, RMSE: 0.208mm, R¬≤: -0.284
  Val   - Loss: 0.0130, MAE: 0.133mm, RMSE: 0.164mm, R¬≤: 0.412
  Val Threshold Acc (@0.9mm): 0.770
  LR: 0.000100
  ‚úÖ New best MAE: 0.1328


Epoch 24 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:43<00:00,  1.04it/s, loss=0.0206]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.22it/s]



Epoch 24/30:
  Train - Loss: 0.0206, MAE: 0.161mm, RMSE: 0.203mm, R¬≤: -0.222
  Val   - Loss: 0.0133, MAE: 0.135mm, RMSE: 0.165mm, R¬≤: 0.403
  Val Threshold Acc (@0.9mm): 0.775
  LR: 0.000100
  Patience: 1/15


Epoch 25 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.02it/s, loss=0.0212]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.36it/s]



Epoch 25/30:
  Train - Loss: 0.0212, MAE: 0.164mm, RMSE: 0.206mm, R¬≤: -0.253
  Val   - Loss: 0.0142, MAE: 0.139mm, RMSE: 0.168mm, R¬≤: 0.380
  Val Threshold Acc (@0.9mm): 0.705
  LR: 0.000100
  Patience: 2/15


Epoch 26 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.02it/s, loss=0.02]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.32it/s]



Epoch 26/30:
  Train - Loss: 0.0200, MAE: 0.159mm, RMSE: 0.200mm, R¬≤: -0.192
  Val   - Loss: 0.0141, MAE: 0.138mm, RMSE: 0.168mm, R¬≤: 0.382
  Val Threshold Acc (@0.9mm): 0.725
  LR: 0.000100
  Patience: 3/15


Epoch 27 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:42<00:00,  1.05it/s, loss=0.0218]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.30it/s]



Epoch 27/30:
  Train - Loss: 0.0218, MAE: 0.164mm, RMSE: 0.209mm, R¬≤: -0.290
  Val   - Loss: 0.0135, MAE: 0.135mm, RMSE: 0.165mm, R¬≤: 0.401
  Val Threshold Acc (@0.9mm): 0.760
  LR: 0.000100
  Patience: 4/15


Epoch 28 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:44<00:00,  1.03it/s, loss=0.02]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.24it/s]



Epoch 28/30:
  Train - Loss: 0.0200, MAE: 0.158mm, RMSE: 0.200mm, R¬≤: -0.188
  Val   - Loss: 0.0134, MAE: 0.133mm, RMSE: 0.166mm, R¬≤: 0.398
  Val Threshold Acc (@0.9mm): 0.770
  LR: 0.000100
  Patience: 5/15


Epoch 29 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:44<00:00,  1.04it/s, loss=0.0208]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.35it/s]



Epoch 29/30:
  Train - Loss: 0.0208, MAE: 0.162mm, RMSE: 0.204mm, R¬≤: -0.235
  Val   - Loss: 0.0139, MAE: 0.133mm, RMSE: 0.168mm, R¬≤: 0.379
  Val Threshold Acc (@0.9mm): 0.775
  LR: 0.000100
  Patience: 6/15


Epoch 30 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.02it/s, loss=0.0208]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.25it/s]



Epoch 30/30:
  Train - Loss: 0.0208, MAE: 0.161mm, RMSE: 0.204mm, R¬≤: -0.232
  Val   - Loss: 0.0152, MAE: 0.141mm, RMSE: 0.173mm, R¬≤: 0.341
  Val Threshold Acc (@0.9mm): 0.715
  LR: 0.000100
  Patience: 7/15

Loading best model from Stage 1...
‚úÖ Loaded best model with MAE: 0.1328


In [None]:

# Cell 13: Training Stage 2 - Fine-tuning
# Continues training `model` with a fresh Adam optimizer at STAGE2_LR, applies a
# step decay to the learning rate, checkpoints the best validation epoch to
# best_model_final.pth and stops early once validation stops improving.
print("\n" + "="*60)
print("STAGE 2: Fine-tuning with lower LR")
print("="*60)

optimizer = optim.Adam(model.parameters(), lr=STAGE2_LR,
                      weight_decay=WEIGHT_DECAY, betas=BETAS)

# NOTE(review): best-metric tracking restarts from +inf here, so the first
# Stage 2 epoch always writes best_model_final.pth even when it is worse than
# the best Stage 1 value - confirm that this reset is intended.
best_metric = float('inf')
patience_counter = 0

for epoch in range(1, STAGE2_EPOCHS + 1):
    # Step decay: multiply the LR by STAGE2_LR_DECAY_FACTOR once every
    # STAGE2_LR_DECAY_EVERY epochs (never before the first epoch).
    if epoch > 1 and (epoch - 1) % STAGE2_LR_DECAY_EVERY == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= STAGE2_LR_DECAY_FACTOR

    # Train
    train_metrics = train_epoch(model, train_loader, criterion, optimizer, epoch)

    # Validate
    val_metrics = validate(model, val_loader, criterion)

    # Log
    current_lr = optimizer.param_groups[0]['lr']
    print(f"\nEpoch {epoch}/{STAGE2_EPOCHS}:")
    print(f"  Train - Loss: {train_metrics['loss']:.4f}, "
          f"MAE: {train_metrics['mae']:.3f}mm, RMSE: {train_metrics['rmse']:.3f}mm, "
          f"R¬≤: {train_metrics['r2']:.3f}")
    print(f"  Val   - Loss: {val_metrics['loss']:.4f}, "
          f"MAE: {val_metrics['mae']:.3f}mm, RMSE: {val_metrics['rmse']:.3f}mm, "
          f"R¬≤: {val_metrics['r2']:.3f}")
    print(f"  Val Threshold Acc (@{CIMT_THRESHOLD}mm): {val_metrics['threshold_accuracy']:.3f}")
    print(f"  LR: {current_lr:.6f}")

    # Record the epoch BEFORE the early-stopping check below; otherwise the
    # epoch that triggers the stop would be missing from the saved history.
    training_history.append({
        'epoch': epoch,
        'stage': 2,
        'train_loss': train_metrics['loss'],
        'train_mae': train_metrics['mae'],
        'train_rmse': train_metrics['rmse'],
        'train_r2': train_metrics['r2'],
        'val_loss': val_metrics['loss'],
        'val_mae': val_metrics['mae'],
        'val_rmse': val_metrics['rmse'],
        'val_r2': val_metrics['r2'],
        'val_threshold_acc': val_metrics['threshold_accuracy'],
        'lr': current_lr
    })

    # Save best model (the `<` comparison treats a lower METRIC_FOR_BEST as better)
    current_metric = val_metrics[METRIC_FOR_BEST]
    if current_metric < best_metric:
        best_metric = current_metric
        patience_counter = 0
        print(f"  ‚úÖ New best {METRIC_FOR_BEST.upper()}: {best_metric:.4f}")
        if SAVE_BEST_MODEL:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_metric': best_metric,
                'val_metrics': val_metrics
            }, CHECKPOINT_DIR / "best_model_final.pth")
    else:
        patience_counter += 1
        print(f"  Patience: {patience_counter}/{EARLY_STOPPING_PATIENCE}")

    # Stop once EARLY_STOPPING_PATIENCE consecutive epochs failed to improve.
    if patience_counter >= EARLY_STOPPING_PATIENCE:
        print(f"\n‚ö†Ô∏è Early stopping triggered at epoch {epoch}")
        break


STAGE 2: Fine-tuning with lower LR


Epoch 1 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:46<00:00,  1.02it/s, loss=0.0215]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.20it/s]



Epoch 1/20:
  Train - Loss: 0.0215, MAE: 0.162mm, RMSE: 0.208mm, R¬≤: -0.280
  Val   - Loss: 0.0132, MAE: 0.135mm, RMSE: 0.164mm, R¬≤: 0.408
  Val Threshold Acc (@0.9mm): 0.760
  LR: 0.000010
  ‚úÖ New best MAE: 0.1345


Epoch 2 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:46<00:00,  1.01it/s, loss=0.0208]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.30it/s]



Epoch 2/20:
  Train - Loss: 0.0208, MAE: 0.161mm, RMSE: 0.204mm, R¬≤: -0.237
  Val   - Loss: 0.0132, MAE: 0.134mm, RMSE: 0.164mm, R¬≤: 0.411
  Val Threshold Acc (@0.9mm): 0.760
  LR: 0.000010
  ‚úÖ New best MAE: 0.1338


Epoch 3 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.02it/s, loss=0.0213]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.25it/s]



Epoch 3/20:
  Train - Loss: 0.0213, MAE: 0.163mm, RMSE: 0.207mm, R¬≤: -0.267
  Val   - Loss: 0.0131, MAE: 0.134mm, RMSE: 0.164mm, R¬≤: 0.412
  Val Threshold Acc (@0.9mm): 0.760
  LR: 0.000010
  ‚úÖ New best MAE: 0.1336


Epoch 4 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.02it/s, loss=0.0208]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.18it/s]



Epoch 4/20:
  Train - Loss: 0.0208, MAE: 0.161mm, RMSE: 0.204mm, R¬≤: -0.237
  Val   - Loss: 0.0133, MAE: 0.134mm, RMSE: 0.164mm, R¬≤: 0.408
  Val Threshold Acc (@0.9mm): 0.760
  LR: 0.000010
  Patience: 1/15


Epoch 5 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.02it/s, loss=0.021]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.24it/s]



Epoch 5/20:
  Train - Loss: 0.0210, MAE: 0.163mm, RMSE: 0.205mm, R¬≤: -0.249
  Val   - Loss: 0.0133, MAE: 0.135mm, RMSE: 0.164mm, R¬≤: 0.408
  Val Threshold Acc (@0.9mm): 0.755
  LR: 0.000010
  Patience: 2/15


Epoch 6 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.03it/s, loss=0.0213]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.18it/s]



Epoch 6/20:
  Train - Loss: 0.0213, MAE: 0.164mm, RMSE: 0.206mm, R¬≤: -0.264
  Val   - Loss: 0.0131, MAE: 0.133mm, RMSE: 0.164mm, R¬≤: 0.412
  Val Threshold Acc (@0.9mm): 0.765
  LR: 0.000010
  ‚úÖ New best MAE: 0.1330


Epoch 7 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:46<00:00,  1.01it/s, loss=0.0215]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.24it/s]



Epoch 7/20:
  Train - Loss: 0.0215, MAE: 0.164mm, RMSE: 0.207mm, R¬≤: -0.271
  Val   - Loss: 0.0132, MAE: 0.134mm, RMSE: 0.164mm, R¬≤: 0.411
  Val Threshold Acc (@0.9mm): 0.760
  LR: 0.000010
  Patience: 1/15


Epoch 8 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.02it/s, loss=0.021]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.22it/s]



Epoch 8/20:
  Train - Loss: 0.0210, MAE: 0.163mm, RMSE: 0.205mm, R¬≤: -0.246
  Val   - Loss: 0.0130, MAE: 0.133mm, RMSE: 0.163mm, R¬≤: 0.415
  Val Threshold Acc (@0.9mm): 0.760
  LR: 0.000010
  ‚úÖ New best MAE: 0.1328


Epoch 9 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.02it/s, loss=0.0209]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.23it/s]



Epoch 9/20:
  Train - Loss: 0.0209, MAE: 0.163mm, RMSE: 0.205mm, R¬≤: -0.240
  Val   - Loss: 0.0131, MAE: 0.133mm, RMSE: 0.163mm, R¬≤: 0.413
  Val Threshold Acc (@0.9mm): 0.760
  LR: 0.000010
  Patience: 1/15


Epoch 10 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:47<00:00,  1.01it/s, loss=0.0215]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.23it/s]



Epoch 10/20:
  Train - Loss: 0.0215, MAE: 0.162mm, RMSE: 0.207mm, R¬≤: -0.275
  Val   - Loss: 0.0131, MAE: 0.134mm, RMSE: 0.163mm, R¬≤: 0.414
  Val Threshold Acc (@0.9mm): 0.755
  LR: 0.000010
  Patience: 2/15


Epoch 11 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.02it/s, loss=0.0206]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.32it/s]



Epoch 11/20:
  Train - Loss: 0.0206, MAE: 0.160mm, RMSE: 0.203mm, R¬≤: -0.224
  Val   - Loss: 0.0132, MAE: 0.134mm, RMSE: 0.164mm, R¬≤: 0.413
  Val Threshold Acc (@0.9mm): 0.755
  LR: 0.000001
  Patience: 3/15


Epoch 12 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:47<00:00,  1.01it/s, loss=0.0193]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.20it/s]



Epoch 12/20:
  Train - Loss: 0.0193, MAE: 0.156mm, RMSE: 0.196mm, R¬≤: -0.146
  Val   - Loss: 0.0133, MAE: 0.135mm, RMSE: 0.164mm, R¬≤: 0.408
  Val Threshold Acc (@0.9mm): 0.750
  LR: 0.000001
  Patience: 4/15


Epoch 13 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.02it/s, loss=0.0202]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.21it/s]



Epoch 13/20:
  Train - Loss: 0.0202, MAE: 0.161mm, RMSE: 0.201mm, R¬≤: -0.204
  Val   - Loss: 0.0130, MAE: 0.133mm, RMSE: 0.163mm, R¬≤: 0.415
  Val Threshold Acc (@0.9mm): 0.745
  LR: 0.000001
  Patience: 5/15


Epoch 14 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:47<00:00,  1.01it/s, loss=0.0213]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.19it/s]



Epoch 14/20:
  Train - Loss: 0.0213, MAE: 0.164mm, RMSE: 0.206mm, R¬≤: -0.265
  Val   - Loss: 0.0130, MAE: 0.133mm, RMSE: 0.163mm, R¬≤: 0.416
  Val Threshold Acc (@0.9mm): 0.760
  LR: 0.000001
  Patience: 6/15


Epoch 15 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:47<00:00,  1.01it/s, loss=0.021]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:03<00:00,  2.27it/s]



Epoch 15/20:
  Train - Loss: 0.0210, MAE: 0.163mm, RMSE: 0.205mm, R¬≤: -0.248
  Val   - Loss: 0.0131, MAE: 0.134mm, RMSE: 0.163mm, R¬≤: 0.413
  Val Threshold Acc (@0.9mm): 0.750
  LR: 0.000001
  Patience: 7/15


Epoch 16 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:46<00:00,  1.02it/s, loss=0.0204]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.20it/s]



Epoch 16/20:
  Train - Loss: 0.0204, MAE: 0.159mm, RMSE: 0.202mm, R¬≤: -0.214
  Val   - Loss: 0.0130, MAE: 0.133mm, RMSE: 0.163mm, R¬≤: 0.416
  Val Threshold Acc (@0.9mm): 0.750
  LR: 0.000001
  Patience: 8/15


Epoch 17 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:43<00:00,  1.04it/s, loss=0.0202]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.21it/s]



Epoch 17/20:
  Train - Loss: 0.0202, MAE: 0.160mm, RMSE: 0.201mm, R¬≤: -0.199
  Val   - Loss: 0.0131, MAE: 0.134mm, RMSE: 0.163mm, R¬≤: 0.414
  Val Threshold Acc (@0.9mm): 0.755
  LR: 0.000001
  Patience: 9/15


Epoch 18 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:45<00:00,  1.02it/s, loss=0.0216]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.23it/s]



Epoch 18/20:
  Train - Loss: 0.0216, MAE: 0.164mm, RMSE: 0.208mm, R¬≤: -0.284
  Val   - Loss: 0.0132, MAE: 0.134mm, RMSE: 0.164mm, R¬≤: 0.412
  Val Threshold Acc (@0.9mm): 0.750
  LR: 0.000001
  Patience: 10/15


Epoch 19 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:46<00:00,  1.01it/s, loss=0.0204]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.25it/s]



Epoch 19/20:
  Train - Loss: 0.0204, MAE: 0.159mm, RMSE: 0.202mm, R¬≤: -0.211
  Val   - Loss: 0.0131, MAE: 0.133mm, RMSE: 0.163mm, R¬≤: 0.416
  Val Threshold Acc (@0.9mm): 0.750
  LR: 0.000001
  Patience: 11/15


Epoch 20 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108/108 [01:46<00:00,  1.01it/s, loss=0.0197]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9/9 [00:04<00:00,  2.22it/s]


Epoch 20/20:
  Train - Loss: 0.0197, MAE: 0.157mm, RMSE: 0.199mm, R¬≤: -0.170
  Val   - Loss: 0.0133, MAE: 0.134mm, RMSE: 0.164mm, R¬≤: 0.409
  Val Threshold Acc (@0.9mm): 0.755
  LR: 0.000001
  Patience: 12/15





In [None]:
# Cell 14: Final evaluation on test set
# Restores the best Stage 2 checkpoint when one was saved, scores the model on
# the test loader, then persists the training history (CSV) and the test
# metrics (JSON) under RESULTS_DIR.
import json

banner_rule = "=" * 60
best_ckpt_path = CHECKPOINT_DIR / "best_model_final.pth"
history_csv_path = RESULTS_DIR / "training_history.csv"
test_json_path = RESULTS_DIR / "test_results.json"

print("\n" + banner_rule)
print("FINAL EVALUATION ON TEST SET")
print(banner_rule)

# Without a saved checkpoint the model is evaluated with whatever weights it
# currently holds.
if SAVE_BEST_MODEL and best_ckpt_path.exists():
    print("Loading best model...")
    checkpoint = torch.load(best_ckpt_path, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"‚úÖ Loaded best model with {METRIC_FOR_BEST.upper()}: {checkpoint['best_metric']:.4f}")

test_metrics = validate(model, test_loader, criterion)

print("\nTest Set Results:")
print(
    f"  Loss: {test_metrics['loss']:.4f}\n"
    f"  MAE: {test_metrics['mae']:.3f} mm\n"
    f"  RMSE: {test_metrics['rmse']:.3f} mm\n"
    f"  R¬≤: {test_metrics['r2']:.3f}\n"
    f"  Threshold Accuracy (@{CIMT_THRESHOLD}mm): {test_metrics['threshold_accuracy']:.3f}"
)

# Persist the per-epoch history collected by the training cells.
history_df = pd.DataFrame(training_history)
history_df.to_csv(history_csv_path, index=False)
print(f"\n‚úÖ Training history saved to {history_csv_path}")

# Persist the test metrics under "test_"-prefixed keys.
test_results = {
    f"test_{short_name}": test_metrics[metric_key]
    for short_name, metric_key in (
        ('loss', 'loss'),
        ('mae', 'mae'),
        ('rmse', 'rmse'),
        ('r2', 'r2'),
        ('threshold_acc', 'threshold_accuracy'),
    )
}

with open(test_json_path, 'w') as f:
    json.dump(test_results, f, indent=2)
print(f"‚úÖ Test results saved to {test_json_path}")

print("\n" + banner_rule)
print("TRAINING COMPLETE!")
print(banner_rule)


FINAL EVALUATION ON TEST SET
Loading best model...
‚úÖ Loaded best model with MAE: 0.1328


Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:02<00:00,  2.01it/s]


Test Set Results:
  Loss: 0.0077
  MAE: 0.105 mm
  RMSE: 0.133 mm
  R¬≤: 0.519
  Threshold Accuracy (@0.9mm): 0.860

‚úÖ Training history saved to /content/outputs/cimt_regression/results/training_history.csv
‚úÖ Test results saved to /content/outputs/cimt_regression/results/test_results.json

TRAINING COMPLETE!





In [None]:
# Cell 15: Inference example (optional)
# Example: Make predictions on a few samples

def predict_cimt(model, dataloader, num_samples=5):
    """
    Predict CIMT for the first few samples of a dataloader and compare each
    prediction with its ground-truth value.

    Args:
        model: Network invoked as ``model(left_image, right_image, clinical)``;
            it is switched to eval mode before predicting.
        dataloader: Iterable of batch dicts providing the keys 'left_image',
            'right_image', 'clinical', 'cimt' and 'patient_id'.
        num_samples: Maximum number of rows to return. Values <= 0 yield an
            empty table.

    Returns:
        pandas.DataFrame with one row per sample and the columns patient_id,
        predicted_cimt, true_cimt, error (absolute difference),
        predicted_class and true_class ('Thickened' when the value is
        >= CIMT_THRESHOLD, otherwise 'Normal').
    """
    columns = ['patient_id', 'predicted_cimt', 'true_cimt', 'error',
               'predicted_class', 'true_class']
    results = []

    model.eval()

    # Nothing requested: return an empty table that still has the usual columns.
    if num_samples <= 0:
        return pd.DataFrame(results, columns=columns)

    with torch.no_grad():
        for batch in dataloader:
            left_img = batch['left_image'].to(DEVICE)
            right_img = batch['right_image'].to(DEVICE)
            clinical = batch['clinical'].to(DEVICE)
            targets = batch['cimt'].cpu().numpy()
            patient_ids = batch['patient_id']

            predictions = model(left_img, right_img, clinical).cpu().numpy()

            for pid, pred_row, true_row in zip(patient_ids, predictions, targets):
                # Only the first element of each prediction/target row is used.
                pred_val = pred_row[0]
                true_val = true_row[0]

                results.append({
                    # Batched ids can be 0-d tensors; unwrap them so the table
                    # shows the plain id instead of "tensor(...)".
                    'patient_id': pid.item() if hasattr(pid, 'item') else pid,
                    'predicted_cimt': pred_val,
                    'true_cimt': true_val,
                    'error': abs(pred_val - true_val),
                    'predicted_class': 'Thickened' if pred_val >= CIMT_THRESHOLD else 'Normal',
                    'true_class': 'Thickened' if true_val >= CIMT_THRESHOLD else 'Normal'
                })

                # Stop as soon as the requested number of rows is collected.
                if len(results) >= num_samples:
                    return pd.DataFrame(results, columns=columns)

    # The dataloader held fewer than num_samples samples.
    return pd.DataFrame(results, columns=columns)

# Run the inference demo on the first ten test-set samples and print them as
# a plain-text table, followed by the mean absolute error over those rows.
print("Making predictions on test samples...\n")
predictions_df = predict_cimt(model, test_loader, num_samples=10)
print(predictions_df.to_string(index=False))

mean_abs_error = predictions_df['error'].mean()
print(f"\nAverage prediction error: {mean_abs_error:.3f} mm")

Making predictions on test samples...

       patient_id  predicted_cimt  true_cimt    error predicted_class true_class
tensor(151594008)        1.018096        1.1 0.081904       Thickened  Thickened
tensor(151932002)        1.014267        0.9 0.114267       Thickened  Thickened
tensor(152071010)        1.071923        1.1 0.028077       Thickened  Thickened
tensor(152073003)        1.040416        1.1 0.059584       Thickened  Thickened
tensor(152584002)        0.992586        0.9 0.092586       Thickened  Thickened
tensor(153884002)        1.018700        1.0 0.018700       Thickened  Thickened
tensor(155500003)        1.032282        1.1 0.067718       Thickened  Thickened
tensor(155500005)        1.029913        1.4 0.370087       Thickened  Thickened
tensor(157612002)        0.928235        1.0 0.071765       Thickened  Thickened
tensor(157742003)        1.032693        1.2 0.167307       Thickened  Thickened

Average prediction error: 0.107 mm


In [None]:
# ==================== SAVE THE CURRENT MODEL (ALREADY TRAINED) ====================
# Bundles the model/optimizer state currently in memory into one checkpoint
# file and, when running inside Google Colab, offers it as a browser download.

import torch
from pathlib import Path

print("Saving the fully-trained model that's currently in memory...")

# Take the epoch counts from the notebook configuration when it is available
# instead of hard-coding them; the fallbacks are the previously hard-coded values.
stage1_epochs = globals().get('STAGE1_EPOCHS', 30)
stage2_epochs = globals().get('STAGE2_EPOCHS', 20)
total_epochs = stage1_epochs + stage2_epochs

# best_metric only exists once a training cell has run in this kernel.
has_best_metric = 'best_metric' in globals()

# NOTE(review): once Cell 14 has run, `model` holds the weights restored from
# best_model_final.pth while `optimizer` holds the state after the last
# Stage 2 step, so the two state dicts may come from different epochs.
# 'epoch' is the configured total, not the number of epochs actually run when
# early stopping fired - confirm this is what should be exported.
final_checkpoint = {
    'epoch': total_epochs,
    'stage1_epochs': stage1_epochs,
    'stage2_epochs': stage2_epochs,
    'model_state_dict': model.state_dict(),  # Current model in memory
    'optimizer_state_dict': optimizer.state_dict(),
    'best_mae': best_metric if has_best_metric else None,
    'training_complete': True,
    'metrics': {
        'best_mae': float(best_metric) if has_best_metric else None,
    }
}

# Save it
save_path = Path(f'/content/cimt_fully_trained_epoch{total_epochs}.pth')
torch.save(final_checkpoint, save_path)

print("\n‚úÖ SAVED!")
print(f"   Path: {save_path}")
print(f"   Epoch: {total_epochs}")

# Download it (google.colab only exists inside a Colab runtime; elsewhere the
# checkpoint simply stays on disk instead of crashing the cell).
try:
    from google.colab import files
except ImportError:
    print(f"\nNot running in Colab - checkpoint left at {save_path}")
else:
    files.download(str(save_path))
    print("\n‚úÖ Downloaded! This is your fully-trained model.")

Saving the fully-trained model that's currently in memory...

‚úÖ SAVED!
   Path: /content/cimt_fully_trained_epoch50.pth
   Epoch: 50


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


‚úÖ Downloaded! This is your fully-trained model.
