### Install Dependencies

In [1]:
!pip install timm --no-deps
import timm
print("Timm version:", timm.__version__)





Timm version: 1.0.19


### Configuration & Imports

In [2]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix
from sklearn.preprocessing import label_binarize
from PIL import Image
from tqdm.notebook import tqdm
import timm
import warnings
import gc

warnings.filterwarnings("ignore")

# --- OPTIMIZED CONFIGURATION (Fixes OOM) ---
Config = {
    'model_name': 'maxvit_base_tf_512.in21k_ft_in1k',
    'img_size': 512,
    'batch_size': 4,              # Reduced to fit VRAM
    'accum_iter': 4,              # Effective Batch Size = 16
    'epochs': 15,
    'learning_rate': 2e-5,
    'weight_decay': 0.05,
    'drop_path_rate': 0.2,
    'num_workers': 2,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'seed': 42
}

# --- PATHS ---
ROOT_DIR = "/kaggle/input/vindr-spinexr-modified/vindr-spinexr-a-large-annotated-medical-image-dataset"
TRAIN_CSV_PATH = os.path.join(ROOT_DIR, "annotations/train.csv")
TEST_CSV_PATH = os.path.join(ROOT_DIR, "annotations/test.csv")
TRAIN_IMG_DIR = os.path.join(ROOT_DIR, "train_png")
TEST_IMG_DIR = os.path.join(ROOT_DIR, "test_png")

# --- DATASET CLASS ---
class SpineDataset(Dataset):
    def __init__(self, dataframe, root_dir, transform=None):
        self.df = dataframe
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_id = row['image_id']
        img_name = f"{img_id}.png" if not str(img_id).endswith('.png') else img_id
        img_path = os.path.join(self.root_dir, img_name)
        
        try:
            image = Image.open(img_path).convert("RGB")
        except:
            image = Image.new('RGB', (Config['img_size'], Config['img_size']))
            
        label = torch.tensor(row['label_encoded'], dtype=torch.long)
        if self.transform:
            image = self.transform(image)
        return image, label

# --- PREPROCESSING ---
df = pd.read_csv(TRAIN_CSV_PATH).drop_duplicates(subset=['image_id']).reset_index(drop=True)
encoder = LabelEncoder()
df['label_encoded'] = encoder.fit_transform(df['lesion_type'])
Config['num_classes'] = len(encoder.classes_)

# Class Weights
class_weights = compute_class_weight('balanced', classes=np.unique(df['label_encoded']), y=df['label_encoded'])
weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(Config['device'])

# Transforms
train_transforms = transforms.Compose([
    transforms.Resize((Config['img_size'], Config['img_size'])),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((Config['img_size'], Config['img_size'])),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Loaders
train_df, val_df = train_test_split(df, test_size=0.2, random_state=Config['seed'], stratify=df['label_encoded'])
train_loader = DataLoader(SpineDataset(train_df, TRAIN_IMG_DIR, train_transforms), 
                          batch_size=Config['batch_size'], shuffle=True, num_workers=Config['num_workers'])
val_loader = DataLoader(SpineDataset(val_df, TRAIN_IMG_DIR, val_transforms), 
                        batch_size=Config['batch_size'], shuffle=False, num_workers=Config['num_workers'])

print(f"Setup Complete. Training on {len(train_df)} images.")

Setup Complete. Training on 6711 images.


### Model Initialization

In [3]:
import gc

# 1. Force Clear Memory
torch.cuda.empty_cache()
gc.collect()

print(f"Allocated Memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

# 2. Create Model
model = timm.create_model(
    Config['model_name'], 
    pretrained=True, 
    num_classes=Config['num_classes'],
    drop_path_rate=Config['drop_path_rate']
)

# 3. Move to Device safely
try:
    model = model.to(Config['device'])
    print("Model successfully moved to GPU.")
except RuntimeError as e:
    print(f"ERROR: {e}")
    print("Tip: Restart your kernel (Run > Restart Session) to clear old memory.")

# 4. Optimizer & Loss
optimizer = optim.AdamW(model.parameters(), lr=Config['learning_rate'], weight_decay=Config['weight_decay'])
criterion = nn.CrossEntropyLoss(weight=weights_tensor)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=Config['learning_rate'], steps_per_epoch=len(train_loader)//Config['accum_iter'], epochs=Config['epochs']
)

Allocated Memory: 0.00 GB
Model successfully moved to GPU.


### Training Loop (With AMP & Accumulation)

In [4]:
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
best_acc = 0.0

print(f"Starting Training for {Config['epochs']} Epochs...")

for epoch in range(Config['epochs']):
    # --- TRAIN ---
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    
    optimizer.zero_grad()
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config['epochs']} [Train]")
    
    for i, (images, labels) in enumerate(loop):
        images, labels = images.to(Config['device']), labels.to(Config['device'])
        
        # Mixed Precision
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss = loss / Config['accum_iter'] # Normalize
        
        scaler.scale(loss).backward()
        
        if (i + 1) % Config['accum_iter'] == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()
        
        train_loss += loss.item() * Config['accum_iter']
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        loop.set_postfix(loss=loss.item() * Config['accum_iter'])
    
    avg_train_loss = train_loss / len(train_loader)
    
    # Cleanup before validation
    del images, labels, outputs
    torch.cuda.empty_cache()
    gc.collect()
    
    # --- VALIDATE ---
    model.eval()
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(Config['device']), labels.to(Config['device'])
            with autocast():
                outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
            
    val_acc = 100 * val_correct / val_total
    print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f} | Val Acc: {val_acc:.2f}%")
    
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "maxvit_best_model.pth")
        print(f"--> Best Model Saved! ({val_acc:.2f}%)")
    print("-" * 50)

Starting Training for 15 Epochs...


Epoch 1/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 1: Train Loss: 1.9942 | Val Acc: 54.77%
--> Best Model Saved! (54.77%)
--------------------------------------------------


Epoch 2/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 2: Train Loss: 1.6243 | Val Acc: 58.52%
--> Best Model Saved! (58.52%)
--------------------------------------------------


Epoch 3/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 3: Train Loss: 1.3947 | Val Acc: 61.56%
--> Best Model Saved! (61.56%)
--------------------------------------------------


Epoch 4/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 4: Train Loss: 1.2244 | Val Acc: 64.78%
--> Best Model Saved! (64.78%)
--------------------------------------------------


Epoch 5/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 5: Train Loss: 1.1212 | Val Acc: 65.14%
--> Best Model Saved! (65.14%)
--------------------------------------------------


Epoch 6/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 6: Train Loss: 1.0715 | Val Acc: 67.58%
--> Best Model Saved! (67.58%)
--------------------------------------------------


Epoch 7/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 7: Train Loss: 1.0167 | Val Acc: 67.10%
--------------------------------------------------


Epoch 8/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 8: Train Loss: 0.9572 | Val Acc: 69.85%
--> Best Model Saved! (69.85%)
--------------------------------------------------


Epoch 9/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 9: Train Loss: 0.9403 | Val Acc: 69.07%
--------------------------------------------------


Epoch 10/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 10: Train Loss: 0.8966 | Val Acc: 69.79%
--------------------------------------------------


Epoch 11/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 11: Train Loss: 0.8728 | Val Acc: 70.14%
--> Best Model Saved! (70.14%)
--------------------------------------------------


Epoch 12/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 12: Train Loss: 0.8543 | Val Acc: 70.26%
--> Best Model Saved! (70.26%)
--------------------------------------------------


Epoch 13/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 13: Train Loss: 0.8414 | Val Acc: 70.86%
--> Best Model Saved! (70.86%)
--------------------------------------------------


Epoch 14/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 14: Train Loss: 0.8340 | Val Acc: 70.02%
--------------------------------------------------


Epoch 15/15 [Train]:   0%|          | 0/1678 [00:00<?, ?it/s]

Epoch 15: Train Loss: 0.8413 | Val Acc: 71.16%
--> Best Model Saved! (71.16%)
--------------------------------------------------


### Final Evalution

In [11]:
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, classification_report, accuracy_score

# ==========================================
# 1. SETUP & DATA PREPARATION
# ==========================================
test_df = pd.read_csv(TEST_CSV_PATH)

# UPDATE: Added missing classes from your warning
class_mapper = {
    'No finding': 0,
    'Disc space narrowing': 1,
    'Foraminal stenosis': 2,
    'Osteophytes': 3,
    'Spondylolisthesis': 4,      # Correct spelling (likely used in training)
    'Spondylolysthesis': 4,      # CSV spelling (maps to same class ID 4)
    'Vertebral collapse': 5,
    'Scoliosis': 6,
    'Surgical implant': 7,       # Assigned new ID (Ensure this matches training if used)
    'Other lesions': 8           # Assigned new ID (Ensure this matches training if used)
}

print("Mapping labels...")
test_df['label_encoded'] = test_df['lesion_type'].map(class_mapper)

# Drop rows that still have no match (Safety check)
if test_df['label_encoded'].isnull().any():
    print(f"⚠️ Warning: Still dropping {test_df['label_encoded'].isnull().sum()} rows with unknown labels.")
    test_df = test_df.dropna(subset=['label_encoded'])

test_df['label_encoded'] = test_df['label_encoded'].astype(int)

# REMOVED .head(50) -> Running on FULL dataset
test_df = test_df.drop_duplicates(subset=['image_id']).reset_index(drop=True)
print(f"Total Test Images: {len(test_df)}")

test_dataset = SpineDataset(test_df, TEST_IMG_DIR, transform=val_transforms)
test_loader = DataLoader(test_dataset, batch_size=Config['batch_size'], shuffle=False, num_workers=Config['num_workers'])

# ==========================================
# 2. LOAD MODEL
# ==========================================
model_path = "maxvit_best_model.pth" 
print(f"Loading model from: {model_path}")

model.load_state_dict(torch.load(model_path))
model.eval()

# ==========================================
# 3. FULL INFERENCE LOOP
# ==========================================
y_true, y_pred, y_probs = [], [], []

print("Running Full Inference...")
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(Config['device'])
        
        with torch.amp.autocast('cuda'): 
            outputs = model(images)
            # Get probabilities
            probs = torch.softmax(outputs, dim=1)
            # Get predicted class
            _, preds = torch.max(outputs, 1)
        
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())
        y_probs.extend(probs.cpu().numpy())

# ==========================================
# 4. METRICS & SAVING
# ==========================================
y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Calculate Accuracy
acc = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='weighted')

print("\n" + "="*30)
print(f"Test Accuracy: {acc*100:.2f}%")
print(f"Weighted F1:   {f1*100:.2f}%")
print("="*30)

# Print Detailed Report
print("\nClassification Report:")
print(classification_report(y_true, y_pred))

# Save Results to CSV
results_df = pd.DataFrame({
    'image_id': test_df['image_id'],
    'true_label': y_true,
    'predicted_label': y_pred
})
results_df.to_csv("test_predictions.csv", index=False)
print("✅ Predictions saved to 'test_predictions.csv'")

Mapping labels...
Total Test Images: 2077
Loading model from: maxvit_best_model.pth
Running Full Inference...


100%|██████████| 520/520 [02:32<00:00,  3.41it/s]


Test Accuracy: 28.12%
Weighted F1:   28.59%

Classification Report:
              precision    recall  f1-score   support

           0       0.22      0.01      0.01      1070
           1       0.00      0.00      0.00        59
           2       0.02      0.40      0.03        47
           3       0.74      0.69      0.72       806
           4       0.00      0.00      0.00        13
           5       0.00      0.00      0.00        16
           6       0.00      0.00      0.00         0
           7       0.00      0.00      0.00        34
           8       0.00      0.00      0.00        32

    accuracy                           0.28      2077
   macro avg       0.11      0.12      0.08      2077
weighted avg       0.40      0.28      0.29      2077

✅ Predictions saved to 'test_predictions.csv'



