In [None]:
# Delete every *.pth file in the working directory (best_model_*.pth,
# OVERFIT_model.pth, card_grader_model_coral.pth from earlier runs) —
# presumably so later cells cannot load a stale checkpoint by mistake.
!rm *.pth

In [None]:
# Cell 1: Imports and Setup
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.cuda.amp import autocast, GradScaler
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from PIL import Image
import os
import timm
from torchvision.transforms import RandAugment, RandomErasing
import re

# Hyperparameters — the constants below are the values actually used for training.
ACTIVATE_WEIGHTS_TENSOR = 0  # Set to 1 to use class weights, 0 to disable
BATCH_SIZE = 128
EPOCHS_INITIAL = 50  # epochs for the head-only phase (backbone frozen)
EPOCHS_FINE_TUNE = 100  # epochs for the full-network fine-tuning phase
DROPOUT_PROB = 0.3  # NOTE(review): not referenced by the training code in this notebook
LR_INITIAL = 0.0005
LR_FINE_TUNE = 5e-4  # an earlier setting was 5e-5
WEIGHT_DECAY = 1e-4
PATIENCE = 7  # epochs without validation-loss improvement before early stopping
SPLIT_FRAC = 0.8  # Fraction of data for training (rest for validation)

# Alternative timm backbones (disabled):
# AI_MODEL='convnext_large_mlp.clip_laion2b_soup_ft_in12k_384'
# AI_MODEL='convnext_xxlarge.clip_laion2b_soup_ft_in12k'
# AI_MODEL="convnextv2_nano.fcmae"
AI_MODEL = "convnext_base"

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Cell 2: Load and Preprocess Dataset
# Alternative CSV-based data source (disabled):
# df = pd.read_csv('../scrape/psa_sales4.csv')
# image_dir = '../scrape/cropped4'
# df['full_path'] = df['certNumber'].apply(lambda x: os.path.join(image_dir, f"cert_{x}.jpg"))

image_dir = '../grade_comparisons/'  # <-- Update this path as needed

# List all JPEG files in the directory. The names are sorted because
# os.listdir returns entries in arbitrary, filesystem-dependent order; a
# stable order is required for the seeded train/validation split to be
# reproducible across machines.
file_names = sorted(
    f for f in os.listdir(image_dir)
    if os.path.isfile(os.path.join(image_dir, f)) and f.lower().endswith('.jpg')
)

# Grade embedded in names like "cropped_9_cert..." or "cropped_8.5_cert...";
# the fractional part is optional so half-point grades are not dropped.
_GRADE_PATTERN = re.compile(r'cropped_(\d+(?:\.\d+)?)_cert')


def extract_grade(full_path):
    """
    Extract the numeric grade encoded in an image file name.

    Args:
        full_path: file name or path containing ``cropped_{grade}_cert``.
    Returns:
        ``int`` for whole grades (e.g. ``9``), ``float`` for fractional
        grades (e.g. ``8.5``), or ``None`` after printing a message when the
        pattern is absent.
    """
    match = _GRADE_PATTERN.search(full_path)
    if match is None:
        print("no match for", full_path)
        return None
    grade_text = match.group(1)
    # Keep whole grades as int so existing integer labels are unchanged.
    return float(grade_text) if '.' in grade_text else int(grade_text)

# Create a DataFrame with the file names (joined with image_dir below)
df = pd.DataFrame({'full_path': file_names})

# Apply the extraction function to create a new 'grade' column
df['grade'] = df['full_path'].apply(extract_grade)

# Drop files whose name yields no grade. A NaN grade would otherwise either
# become an extra class in the LabelEncoder or make it fail, corrupting the
# ordinal targets.
has_grade = df['grade'].notna()
if not has_grade.all():
    print(f"Dropping {int((~has_grade).sum())} files without a parsable grade")
df = df[has_grade].copy()
# Missing values force a float column; restore integer grades when all are whole.
if (df['grade'] % 1 == 0).all():
    df['grade'] = df['grade'].astype(int)

# Create a column for the full file path
df['full_path'] = df['full_path'].apply(lambda x: os.path.join(image_dir, x))

# Check for missing images (mask computed once and reused for filtering)
exists_mask = df['full_path'].apply(os.path.exists).astype(bool)
missing_count = int((~exists_mask).sum())
print(f"Missing images: {missing_count}")

# Filter out missing images based on the full path
df = df[exists_mask].copy()
print(f"Dataset size after filtering: {len(df)}")

# Check class distribution
print("Unique grades in full dataset:", df['grade'].unique())
print("Number of unique grades in full dataset:", df['grade'].nunique())

grade_counts = df['grade'].value_counts()

# also add percent of total to grade_counts as a new column
grade_counts = grade_counts.reset_index()
grade_counts.columns = ['grade', 'count']
grade_counts['percent'] = grade_counts['count'] / len(df) * 100

print(grade_counts)

# Split into train and validation sets
train_df = df.sample(frac=SPLIT_FRAC, random_state=42)
val_df = df.drop(train_df.index)
print(f"Training samples: {len(train_df)}, Validation samples: {len(val_df)}")

# Encode grades as contiguous labels 0..K-1. LabelEncoder sorts the distinct
# grades, so label order follows grade order, which the CORAL targets rely on.
le = LabelEncoder()
le.fit(df['grade'])
train_df['label'] = le.transform(train_df['grade'])
val_df['label'] = le.transform(val_df['grade'])
num_classes = len(le.classes_)  # number of distinct grades present in the data
print(f"Number of classes: {num_classes}")

# Compute class weights to handle imbalance (optional)
if ACTIVATE_WEIGHTS_TENSOR:
    present_labels = np.unique(train_df['label'])
    present_weights = compute_class_weight('balanced', classes=present_labels, y=train_df['label'])
    # coral_loss indexes this tensor by label, so it needs one entry per class
    # even when a rare class is absent from the training split (weight 1.0).
    class_weights = np.ones(num_classes, dtype=np.float64)
    class_weights[present_labels] = present_weights
    class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)
    print("Class weights computed and moved to device")
else:
    class_weights_tensor = None
    print("Class weights disabled")

# Cell 3: Define Transforms
def pad_to_square(image, fill=0, padding_mode="constant"):
    """
    Pad a PIL image symmetrically so its width equals its height.

    Any odd leftover pixel goes on the right/bottom edge. ``fill`` and
    ``padding_mode`` are forwarded to ``transforms.Pad``.
    """
    width, height = image.size
    side = max(width, height)
    extra_w = side - width
    extra_h = side - height
    left = extra_w // 2
    top = extra_h // 2
    padding = (left, top, extra_w - left, extra_h - top)
    pad = transforms.Pad(padding, fill=fill, padding_mode=padding_mode)
    return pad(image)

# Shared preprocessing steps: pad to a square with black borders (keeps the
# card's aspect ratio), and ImageNet mean/std normalization.
_square_pad = transforms.Lambda(lambda img: pad_to_square(img, fill=0))
_imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: RandAugment on the PIL image, then RandomErasing on the
# tensor. RandomErasing runs before normalization, so erased pixels are 0
# in [0, 1] pixel space.
train_transform = transforms.Compose(
    [
        _square_pad,
        transforms.Resize(224),
        RandAugment(num_ops=2, magnitude=9),
        transforms.ToTensor(),
        transforms.RandomErasing(p=0.5),
        _imagenet_normalize,
    ]
)

# Validation pipeline: deterministic, no augmentation.
val_transform = transforms.Compose(
    [
        _square_pad,
        transforms.Resize(224),
        transforms.ToTensor(),
        _imagenet_normalize,
    ]
)

# Cell 4: Create Dataset and DataLoader
class CardDataset(Dataset):
    """
    Map-style dataset yielding ``(image, label)`` pairs from a DataFrame.

    Args:
        df: DataFrame with a ``full_path`` column (image file path) and a
            ``label`` column (integer class index from the LabelEncoder).
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        # Cache the two columns as plain lists. Per-item ``df.iloc[idx][col]``
        # builds a full row Series twice per sample, which is slow in workers.
        self._paths = df['full_path'].tolist()
        self._labels = df['label'].tolist()

    def __len__(self):
        return len(self._paths)

    def __getitem__(self, idx):
        # The context manager releases the file handle even if decoding fails;
        # convert() returns a new image that stays valid after the file closes.
        with Image.open(self._paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self._labels[idx]

# Wrap each split in a CardDataset: augmenting transform for training,
# deterministic transform for validation.
train_dataset, val_dataset = (
    CardDataset(train_df, transform=train_transform),
    CardDataset(val_df, transform=val_transform),
)

# Both loaders share batch size and worker count; only training is shuffled.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
print("Data loaders created")

# Cell 5: Define CORAL Head and Modify Model
class CoralHead(nn.Module):
    """
    Ordinal-regression head emitting K-1 threshold logits for K classes.

    Each threshold has its own weight row in ``self.fc``. The original
    CORAL formulation instead shares one weight vector across thresholds,
    so predictions here are not guaranteed to be rank-consistent.

    Args:
        in_features: width of the incoming feature vector.
        num_classes: number of ordinal classes K.
    """

    def __init__(self, in_features, num_classes):
        super().__init__()
        num_thresholds = num_classes - 1
        self.fc = nn.Linear(in_features, num_thresholds)

    def forward(self, x):
        """Return raw (pre-sigmoid) threshold logits for ``x``."""
        threshold_logits = self.fc(x)
        return threshold_logits

# Build the pretrained timm backbone, then swap its final classifier layer
# for a CORAL head sized to the number of grade classes.
backbone = timm.create_model(AI_MODEL, pretrained=True)
in_features = backbone.head.fc.in_features
backbone.head.fc = CoralHead(in_features, num_classes)
model = backbone.to(device)
print("Model with CORAL head initialized")

# Cell 6: Define CORAL Loss Function
def coral_loss(logits, levels, class_weights=None):
    """
    Compute CORAL loss as binary cross-entropy over every ordinal threshold.

    For a true label ``y`` the target for threshold ``k`` is 1 if ``y > k``
    and 0 otherwise. Per-(sample, threshold) losses are optionally scaled by
    the weight of each sample's class, then averaged over all entries.

    Args:
        logits: Tensor of shape (batch_size, K-1) - model outputs
        levels: Tensor of shape (batch_size,) - true labels (0 to K-1)
        class_weights: Tensor of shape (K,) - optional per-class weights
    Returns:
        Scalar tensor with the mean loss.
    """
    # The threshold count and device come from the logits themselves, so the
    # loss does not depend on module-level ``num_classes`` / ``device``.
    num_thresholds = logits.size(1)
    # view(-1) rather than squeeze() keeps a 1-D index tensor for batch size 1.
    levels = levels.view(-1).to(logits.device)
    thresholds = torch.arange(num_thresholds, device=logits.device)
    # targets[i, k] = 1 if sample i's label exceeds threshold k, else 0
    targets = (levels.unsqueeze(1) > thresholds).to(logits.dtype)
    loss = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    if class_weights is not None:
        sample_weights = class_weights.to(logits.device)[levels]
        loss = loss * sample_weights.view(-1, 1)
    return loss.mean()

# Cell 7: Define Validation Function
def validate_coral(model, val_loader, criterion):
    """
    Evaluate a CORAL model on a loader without updating weights.

    Args:
        model: network returning (batch, K-1) threshold logits.
        val_loader: iterable of ``(inputs, labels)`` batches.
        criterion: callable ``(logits, labels) -> scalar loss tensor``.
    Returns:
        ``(val_loss, val_acc)``. The loss is averaged per sample (each batch's
        mean is weighted by its size). The accuracy is the exact-match rate
        between predicted and true levels.
    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    # Send batches to the device the model lives on instead of a global.
    model_device = next(model.parameters()).device
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(model_device), labels.to(model_device)
            logits = model(inputs)
            batch_size = labels.size(0)
            # Weight by batch size so a smaller final batch is not over-counted.
            loss_sum += criterion(logits, labels).item() * batch_size
            # Predicted level = number of thresholds with probability > 0.5.
            pred_levels = (torch.sigmoid(logits) > 0.5).sum(dim=1)
            correct += (pred_levels == labels).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("validation loader yielded no samples")
    return loss_sum / total, correct / total

# Cell 8: Define Training Function
def train_model_coral(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, phase='initial', *, patience=None):
    """
    Train a CORAL model with mixed precision, validating after every epoch.

    Each epoch runs one pass over ``train_loader`` under autocast with a
    GradScaler, calls ``validate_coral``, steps ``scheduler``, and then:

    * saves the weights with the lowest validation loss so far to
      ``best_model_{phase.lower()}.pth``;
    * saves the weights with the largest ``val_loss - train_loss`` gap so far
      to ``OVERFIT_model.pth``. The name is fixed, so each call overwrites it.

    Training stops early once validation loss has not improved for
    ``patience`` consecutive epochs.

    Args:
        model: network returning (batch, K-1) CORAL logits.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        val_loader: iterable of ``(inputs, labels)`` validation batches.
        criterion: callable ``(logits, labels) -> scalar loss``.
        optimizer: optimizer over the trainable parameters.
        scheduler: LR scheduler, stepped once per epoch.
        num_epochs: maximum number of epochs.
        phase: label used in log lines and in the best-checkpoint file name.
        patience: early-stopping patience. ``None`` uses the module-level
            ``PATIENCE``.
    Returns:
        dict of per-epoch lists: 'train_loss', 'val_loss', 'val_accuracy'.
    """
    if patience is None:
        patience = PATIENCE
    # Move batches to wherever the model lives instead of relying on a global.
    model_device = next(model.parameters()).device
    best_val_loss = float('inf')
    best_overfit_gap = -float('inf')  # largest (val_loss - train_loss) seen so far
    patience_counter = 0
    scaler = GradScaler()
    history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(model_device), labels.to(model_device)
            optimizer.zero_grad()
            with autocast():
                logits = model(inputs)
                loss = criterion(logits, labels)
            # The scaler scales the loss against fp16 gradient underflow and
            # unscales the gradients inside scaler.step().
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        val_loss, val_acc = validate_coral(model, val_loader, criterion)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_acc)

        print(f'{phase} Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
        scheduler.step()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), f'best_model_{phase.lower()}.pth')
        else:
            patience_counter += 1

        # Evaluate the overfit gap before the early-stop check, so the final
        # epoch of an early-stopped run is also considered.
        overfit_gap = val_loss - train_loss
        if overfit_gap > best_overfit_gap:
            best_overfit_gap = overfit_gap
            torch.save(model.state_dict(), 'OVERFIT_model.pth')
            print(f"Epoch {epoch+1}: Overfit gap {overfit_gap:.4f} (new maximum), saving OVERFIT_model.pth")

        if patience_counter >= patience:
            print('Early stopping triggered')
            break

    return history

# Cell 9: Initial Training (Frozen Backbone)
# Freeze the backbone: any parameter whose name does not contain "head"
# stops receiving gradients; head parameters are left untouched.
for name, param in model.named_parameters():
    if 'head' in name:
        continue
    param.requires_grad = False

# Only the still-trainable parameters are handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=LR_INITIAL, weight_decay=WEIGHT_DECAY)
# Cosine decay across the whole initial budget, bottoming out at 1e-6.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_INITIAL, eta_min=1e-6)

print("Starting initial training...")
history_initial = train_model_coral(
    model,
    train_loader,
    val_loader,
    criterion=lambda logits, labels: coral_loss(logits, labels, class_weights_tensor),
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=EPOCHS_INITIAL,
    phase='Initial',
)

# Cell 10: Fine-Tuning (Unfreeze All Layers)
# Start from the best head-only checkpoint rather than the last epoch's weights.
# map_location restores the tensors onto the current device even if the
# checkpoint was written from a different one.
model.load_state_dict(torch.load('best_model_initial.pth', map_location=device))
print("Loaded best model from initial training")

# Unfreeze all layers
for param in model.parameters():
    param.requires_grad = True

# Fresh optimizer and scheduler over every parameter for the fine-tuning phase
optimizer = optim.AdamW(model.parameters(), lr=LR_FINE_TUNE, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_FINE_TUNE, eta_min=1e-6)

# Fine-tune the model
print("Starting fine-tuning...")
history_fine = train_model_coral(
    model, train_loader, val_loader,
    lambda logits, labels: coral_loss(logits, labels, class_weights_tensor),
    optimizer, scheduler, EPOCHS_FINE_TUNE, 'Fine-tune'
)

# Cell 11: Final Evaluation and Save Model
# Load the best fine-tuned checkpoint (lowest validation loss in that phase).
# map_location restores the tensors onto the current device.
model.load_state_dict(torch.load('best_model_fine-tune.pth', map_location=device))
print("Loaded best fine-tuned model")

# Final evaluation
val_loss, val_acc = validate_coral(model, val_loader, lambda logits, labels: coral_loss(logits, labels, class_weights_tensor))
print(f"Final Validation Loss: {val_loss:.4f}, Final Validation Accuracy: {val_acc:.4f}")

# Save the final model (the best fine-tuned weights) under its release name
torch.save(model.state_dict(), 'card_grader_model_coral.pth')
print("Final model saved as 'card_grader_model_coral.pth'")

# Cell 12: Prediction Function
def predict_grade_coral(img_path, model, le, transform):
    """
    Predict the grade of a single card image with a CORAL model.

    The predicted level is the number of thresholds whose sigmoid probability
    exceeds 0.5. It is mapped back to a grade with ``le.inverse_transform``.

    Args:
        img_path: Path to the image file
        model: Trained CORAL model emitting K-1 threshold logits per image
        le: fitted LabelEncoder used to encode the training labels
        transform: preprocessing pipeline applied to the RGB PIL image
    Returns:
        The decoded grade, i.e. an element of ``le.classes_``. In this
        notebook that is a numeric grade, not a string.
    """
    model.eval()
    # Put the input on the same device as the model's weights.
    model_device = next(model.parameters()).device
    with Image.open(img_path) as img:
        image = img.convert('RGB')
    batch = transform(image).unsqueeze(0).to(model_device)
    with torch.no_grad():
        logits = model(batch)
    # With K-1 logits the level is always in 0..K-1, a valid encoder index,
    # so no boundary special-casing is needed.
    pred_level = int((torch.sigmoid(logits) > 0.5).sum().item())
    return le.inverse_transform([pred_level])[0]

In [None]:
# Cell 13: Plot Training Metrics
import matplotlib.pyplot as plt

# Stitch both phases end to end so epochs are numbered continuously:
# fine-tuning epochs follow the initial-phase epochs on the x-axis.
combined_train_loss = [*history_initial['train_loss'], *history_fine['train_loss']]
combined_val_loss = [*history_initial['val_loss'], *history_fine['val_loss']]
combined_val_accuracy = [*history_initial['val_accuracy'], *history_fine['val_accuracy']]

epochs = np.arange(1, len(combined_train_loss) + 1)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: training vs. validation loss
loss_ax.plot(epochs, combined_train_loss, label='Train Loss', marker='o')
loss_ax.plot(epochs, combined_val_loss, label='Validation Loss', marker='x')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss over Epochs')
loss_ax.legend()

# Right panel: validation accuracy
acc_ax.plot(epochs, combined_val_accuracy, label='Validation Accuracy', marker='s')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_title('Validation Accuracy over Epochs')
acc_ax.legend()

fig.tight_layout()
plt.show()


In [None]:
# Inference on the validation and training splits (uses the GPU when available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load (not save) the final model written by the training cells. It holds the
# best fine-tuned weights, so this cell evaluates the saved artifact whatever
# the in-memory model currently contains.
model.load_state_dict(torch.load('card_grader_model_coral.pth', map_location=device))
model.to(device)
model.eval()

# show a progress bar for inference
from tqdm import tqdm
tqdm.pandas()


def _report_split_accuracy(frame, split_name):
    """Predict every image in ``frame`` and print overall and per-grade accuracy.

    Adds 'predicted' and 'correct' columns to ``frame`` in place.
    """
    frame['predicted'] = frame['full_path'].progress_apply(
        lambda path: predict_grade_coral(path, model, le, val_transform)
    )
    frame['correct'] = frame['predicted'] == frame['grade']
    print(f"Overall {split_name} Set Accuracy: {frame['correct'].mean():.4f}")
    # accuracy by grade, sorted by accuracy, in percent
    print(frame.groupby('grade')['correct'].mean().sort_values(ascending=False) * 100)


# infer on the validation set
_report_split_accuracy(val_df, 'Validation')

# for fun infer on the training set too (non-augmenting val_transform)
_report_split_accuracy(train_df, 'Training')

