### ConvNeXt

### Load data

In [1]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from PIL import Image
from tqdm.notebook import tqdm
import timm
import warnings

warnings.filterwarnings("ignore")

# --- CONFIGURATION (PRO SETTINGS) ---
Config = {
    'model_name': 'convnext_tiny', 
    'img_size': 224,
    'batch_size': 32,             
    'epochs': 15,                 # Increased to allow learning rare classes
    'learning_rate': 3e-4,        # Slightly higher start, will decay
    'weight_decay': 1e-5,         # Prevents overfitting
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'seed': 42
}

# --- PATHS ---
ROOT_DIR = "/kaggle/input/vindr-spinexr-modified/vindr-spinexr-a-large-annotated-medical-image-dataset"
TRAIN_CSV_PATH = os.path.join(ROOT_DIR, "annotations/train.csv")
TEST_CSV_PATH = os.path.join(ROOT_DIR, "annotations/test.csv")
TRAIN_IMG_DIR = os.path.join(ROOT_DIR, "train_png")
TEST_IMG_DIR = os.path.join(ROOT_DIR, "test_png")

print(f"Device: {Config['device']}")
print("Configuration Loaded.")



Device: cuda
Configuration Loaded.


### Load Model

In [2]:
# 1. Load Data
df = pd.read_csv(TRAIN_CSV_PATH)
ID_COL = 'image_id'
LABEL_COL = 'lesion_type'

# 2. Preprocessing
# Drop duplicates (Force Single-Label Classification)
df = df.drop_duplicates(subset=[ID_COL]).reset_index(drop=True)

# Encode Labels
encoder = LabelEncoder()
df['label_encoded'] = encoder.fit_transform(df[LABEL_COL])
Config['num_classes'] = len(encoder.classes_)

# --- CRITICAL: CALCULATE CLASS WEIGHTS ---
# This fixes the imbalance problem
class_weights = compute_class_weight(
    class_weight='balanced', 
    classes=np.unique(df['label_encoded']), 
    y=df['label_encoded']
)
weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(Config['device'])

print("Computed Class Weights (Penalty for missing):")
for cls, w in zip(encoder.classes_, class_weights):
    print(f"  {cls}: {w:.2f}x")

# 3. Dataset Class
class SpineDataset(Dataset):
    def __init__(self, dataframe, root_dir, transform=None):
        self.df = dataframe
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_id = row[ID_COL]
        # Handle extension
        img_name = f"{img_id}.png" if not str(img_id).endswith('.png') else img_id
        img_path = os.path.join(self.root_dir, img_name)
        
        try:
            image = Image.open(img_path).convert("RGB")
        except:
            image = Image.new('RGB', (Config['img_size'], Config['img_size']))
            
        label = torch.tensor(row['label_encoded'], dtype=torch.long)

        if self.transform:
            image = self.transform(image)
        return image, label

# 4. Transforms (Stronger Augmentation)
train_transforms = transforms.Compose([
    transforms.Resize((Config['img_size'], Config['img_size'])),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15), 
    transforms.ColorJitter(brightness=0.1, contrast=0.1), # New: Helps model generalize
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((Config['img_size'], Config['img_size'])),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 5. Loaders
train_df, val_df = train_test_split(df, test_size=0.2, random_state=Config['seed'], stratify=df['label_encoded'])

train_loader = DataLoader(SpineDataset(train_df, TRAIN_IMG_DIR, train_transforms), 
                          batch_size=Config['batch_size'], shuffle=True, num_workers=2)

val_loader = DataLoader(SpineDataset(val_df, TRAIN_IMG_DIR, val_transforms), 
                        batch_size=Config['batch_size'], shuffle=False, num_workers=2)

print(f"Training on {len(train_df)} images | Validating on {len(val_df)} images")

Computed Class Weights (Penalty for missing):
  Disc space narrowing: 4.77x
  Foraminal stenosis: 4.50x
  No finding: 0.25x
  Osteophytes: 0.32x
  Other lesions: 6.36x
  Spondylolysthesis: 15.20x
  Surgical implant: 7.77x
  Vertebral collapse: 18.73x
Training on 6711 images | Validating on 1678 images


### Train model

In [3]:
# 1. Model Setup
model = timm.create_model(Config['model_name'], pretrained=True)
model.head.fc = nn.Linear(model.head.fc.in_features, Config['num_classes'])
model = model.to(Config['device'])

# 2. Loss & Optimizer (WITH WEIGHTS)
# The 'weight' argument here tells the model: "Pay attention to the rare classes!"
criterion = nn.CrossEntropyLoss(weight=weights_tensor) 

optimizer = optim.AdamW(model.parameters(), lr=Config['learning_rate'], weight_decay=Config['weight_decay'])

# Scheduler: Lowers LR if validation loss stops improving
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

# 3. Training Loop
best_acc = 0.0

print("Starting Training...")

for epoch in range(Config['epochs']):
    # --- TRAIN ---
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    
    # Using simple loop to avoid tqdm nesting issues
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config['epochs']}"):
        images, labels = images.to(Config['device']), labels.to(Config['device'])
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels) # This loss is now WEIGHTED
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    avg_train_loss = train_loss / len(train_loader)
    train_acc = 100 * correct / total
    
    # --- VALIDATE ---
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(Config['device']), labels.to(Config['device'])
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    avg_val_loss = val_loss / len(val_loader)
    val_acc = 100 * correct / total
    
    # --- LOGGING ---
    print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
    
    # Update Scheduler
    scheduler.step(avg_val_loss)
    
    # Save Best
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_convnext_model.pth")
        print(">>> Best Model Saved!")
    print("-" * 30)

Starting Training...


Epoch 1/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.1864 | Train Acc: 11.24%
Val Loss:   2.1389 | Val Acc:   1.61%
>>> Best Model Saved!
------------------------------


Epoch 2/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.1166 | Train Acc: 19.52%
Val Loss:   2.1050 | Val Acc:   38.74%
>>> Best Model Saved!
------------------------------


Epoch 3/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.1018 | Train Acc: 20.95%
Val Loss:   2.0735 | Val Acc:   50.77%
>>> Best Model Saved!
------------------------------


Epoch 4/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.0941 | Train Acc: 27.78%
Val Loss:   2.1073 | Val Acc:   1.97%
------------------------------


Epoch 5/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.0953 | Train Acc: 18.61%
Val Loss:   2.0658 | Val Acc:   50.77%
------------------------------


Epoch 6/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.0833 | Train Acc: 29.49%
Val Loss:   2.0701 | Val Acc:   2.80%
------------------------------


Epoch 7/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.0801 | Train Acc: 33.56%
Val Loss:   2.0622 | Val Acc:   2.62%
------------------------------


Epoch 8/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.0791 | Train Acc: 35.14%
Val Loss:   2.0576 | Val Acc:   50.77%
------------------------------


Epoch 9/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.0811 | Train Acc: 31.59%
Val Loss:   2.0594 | Val Acc:   38.74%
------------------------------


Epoch 10/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.0999 | Train Acc: 31.99%
Val Loss:   2.0562 | Val Acc:   38.74%
------------------------------


Epoch 11/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.0840 | Train Acc: 34.67%
Val Loss:   2.0663 | Val Acc:   50.42%
------------------------------


Epoch 12/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.0707 | Train Acc: 33.29%
Val Loss:   2.0731 | Val Acc:   2.62%
------------------------------


Epoch 13/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 2.0603 | Train Acc: 36.88%
Val Loss:   1.9921 | Val Acc:   40.46%
------------------------------


Epoch 14/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 1.9624 | Train Acc: 34.57%
Val Loss:   1.8759 | Val Acc:   31.29%
------------------------------


Epoch 15/15:   0%|          | 0/210 [00:00<?, ?it/s]

Train Loss: 1.9029 | Train Acc: 26.73%
Val Loss:   1.8618 | Val Acc:   21.22%
------------------------------


### Test Model

In [4]:
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix
from sklearn.preprocessing import label_binarize

# 1. Setup Test Data
test_df = pd.read_csv(TEST_CSV_PATH).drop_duplicates(subset=[ID_COL]).reset_index(drop=True)
# Keep only classes we know
test_df = test_df[test_df[LABEL_COL].isin(encoder.classes_)]
test_df['label_encoded'] = encoder.transform(test_df[LABEL_COL])

test_dataset = SpineDataset(test_df, TEST_IMG_DIR, transform=val_transforms)
# num_workers=0 to prevent freezing
test_loader = DataLoader(test_dataset, batch_size=Config['batch_size'], shuffle=False, num_workers=0)

# 2. Load Best Model
model.load_state_dict(torch.load("best_convnext_model.pth"))
model.eval()

# 3. Predict
y_true, y_pred, y_probs = [], [], []

print("Running Inference...")
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(Config['device'])
        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)
        _, preds = torch.max(outputs, 1)
        
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())
        y_probs.extend(probs.cpu().numpy())

# 4. Metrics
y_true = np.array(y_true)
y_pred = np.array(y_pred)
y_probs = np.array(y_probs)

f1 = f1_score(y_true, y_pred, average='weighted')

# Specificity & Sensitivity
cm = confusion_matrix(y_true, y_pred)
specificities = []
sensitivities = []
for i in range(Config['num_classes']):
    tn = cm.sum() - (cm[i, :].sum() + cm[:, i].sum() - cm[i, i])
    fp = cm[:, i].sum() - cm[i, i]
    fn = cm[i, :].sum() - cm[i, i]
    tp = cm[i, i]
    spec = tn / (tn + fp) if (tn + fp) > 0 else 0
    sens = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificities.append(spec)
    sensitivities.append(sens)

avg_spec = np.mean(specificities)
avg_sens = np.mean(sensitivities)

# AUROC
try:
    if Config['num_classes'] == 2:
        auroc = roc_auc_score(y_true, y_probs[:, 1])
    else:
        y_true_bin = label_binarize(y_true, classes=range(Config['num_classes']))
        auroc = roc_auc_score(y_true_bin, y_probs, multi_class='ovr', average='weighted')
except:
    auroc = 0.5

print("\n" + "="*40)
print(f"{'METRIC':<15} | {'VALUE':<10}")
print("-" * 40)
print(f"{'AUROC':<15} | {auroc*100:.2f}%")
print(f"{'F1 Score':<15} | {f1*100:.2f}%")
print(f"{'Sensitivity':<15} | {avg_sens*100:.2f}%")
print(f"{'Specificity':<15} | {avg_spec*100:.2f}%")
print("="*40)

Running Inference...


  0%|          | 0/65 [00:00<?, ?it/s]


METRIC          | VALUE     
----------------------------------------
AUROC           | 60.14%
F1 Score        | 35.03%
Sensitivity     | 12.50%
Specificity     | 87.50%
