# In [None]:
# Import Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import timm  # For Vision Transformers
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from PIL import Image
import os
import cv2
from torchvision import transforms
import torch.nn.functional as F

# Configuration
ACTIVATE_WEIGHTS_TENSOR = 0  # Class-weighted loss switch: a truthy value enables it, 0 (current) leaves the loss unweighted
ACTIVATE_CROPPING = 0  # Card-cropping switch: a truthy value runs crop_card on each image, 0 (current) skips it
BATCH_SIZE = 32  # Suitable for ViT-Large with 384x384 inputs on H200
EPOCHS = 100  # Single-phase training
DROPOUT_PROB = 0.3  # NOTE(review): not referenced anywhere in the visible code; confirm whether the model should receive it
LR = 1e-5  # Lower learning rate for ViT
WEIGHT_DECAY = 1e-4  # Passed to AdamW as weight_decay
PATIENCE = 10  # Increased patience for early stopping
SPLIT_FRAC = 0.7  # Fraction of rows sampled into the training set; the remainder is used for validation

# Set Device
# Use the GPU whenever CUDA reports itself as available; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

# Load Dataset
df = pd.read_csv('../scrape/psa_sales_20250222_170248.csv')  # Replace with your CSV path
image_dir = '../scrape/cropped'  # Replace with your image directory
df['filename'] = df['certNumber'].apply(lambda x: os.path.join(image_dir, f"cert_{x}.jpg"))

# Probe the filesystem once per row and reuse the result for both the report
# and the filter, so the count and the surviving rows always agree.
image_exists = df['filename'].apply(os.path.exists).astype(bool)

# Print Missing Images
missing_images = int((~image_exists).sum())
print(f"Missing images: {missing_images}")

# Remove Non-Existing Images
df = df[image_exists]

# Split into Training and Validation Sets
# A seeded random sample becomes the training set; every row not sampled is validation.
train_df = df.sample(frac=SPLIT_FRAC, random_state=42)
val_df = df.drop(index=train_df.index)

# Encode Labels
# The encoder is fitted on the full frame so both splits share one grade -> integer mapping.
le = LabelEncoder().fit(df['grade'])
for split_df in (train_df, val_df):
    split_df['label'] = le.transform(split_df['grade'])

# Check Class Distribution
grade_column = df['grade']
print("Unique grades in full dataset:", grade_column.unique())
print("Number of unique grades in full dataset:", grade_column.nunique())

# One row per grade with its absolute count and its share of the full dataset in percent.
grade_counts = grade_column.value_counts().rename_axis('grade').reset_index(name='count')
grade_counts = grade_counts.assign(percent=grade_counts['count'] / len(df) * 100)
print(grade_counts)

# Compute Class Weights for Imbalance
if ACTIVATE_WEIGHTS_TENSOR:
    # The loss indexes its weight vector by encoded label, so the vector must
    # hold exactly one entry per encoder class. Building it only from the grades
    # present in train_df would produce a shorter, misaligned vector whenever
    # the random split leaves a grade out of the training set.
    train_labels = train_df['label'].to_numpy()
    present_labels = np.unique(train_labels)
    balanced_weights = compute_class_weight('balanced', classes=present_labels, y=train_labels)
    # Classes never seen during training keep a neutral weight of 1.0.
    class_weights = np.ones(len(le.classes_), dtype=np.float64)
    class_weights[present_labels] = balanced_weights
    class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)
    print("Class weights computed.")

# Define Cropping Functions
def crop_card_for_light_image(image):
    """Crop the card region out of an RGB array using a global Otsu threshold.

    Contours of the thresholded image are visited from largest to smallest
    contour area. The first one whose bounding box covers between 48% and 60%
    of the whole image is used as the crop.

    Args:
        image: RGB image as a NumPy array (height, width, channels).

    Returns:
        The cropped sub-array (a view into ``image``), or ``None`` when no
        contour exists or none has a bounding box in the accepted size range.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    _, binary_mask = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    found, _ = cv2.findContours(binary_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    if not found:
        return None

    img_h, img_w = image.shape[:2]
    total_area = img_h * img_w
    lower, upper = 0.48 * total_area, 0.6 * total_area

    for candidate in sorted(found, key=cv2.contourArea, reverse=True):
        perimeter = cv2.arcLength(candidate, True)
        polygon = cv2.approxPolyDP(candidate, 0.001 * perimeter, True)
        left, top, box_w, box_h = cv2.boundingRect(polygon)
        # The size test is applied to the bounding box, not to the contour area itself.
        if lower <= box_w * box_h <= upper:
            return image[top:top + box_h, left:left + box_w]
    return None

def crop_card_for_dark_image(image):
    """Crop the card region out of an RGB array using adaptive threshold + edges.

    Pipeline: grayscale -> 3x3 Gaussian blur -> mean adaptive threshold ->
    Canny edges -> inversion -> 7x7 rectangular morphological opening ->
    contour search. Contours are visited from largest to smallest area and the
    first one whose own area covers 48%-70% of the image supplies the crop.

    Args:
        image: RGB image as a NumPy array (height, width, channels).

    Returns:
        The cropped sub-array (a view into ``image``), or ``None`` when no
        contour exists or none has an area in the accepted range.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # NOTE(review): sigma is passed as -10; confirm that this value is intended.
    smoothed = cv2.GaussianBlur(grayscale, (3, 3), -10)
    local_binary = cv2.adaptiveThreshold(
        smoothed, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 7, 3
    )
    inverted_edges = 255 - cv2.Canny(local_binary, 100, 200)
    opening_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
    cleaned = cv2.morphologyEx(inverted_edges, cv2.MORPH_OPEN, opening_kernel)
    found, _ = cv2.findContours(cleaned, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    if not found:
        return None

    img_h, img_w = image.shape[:2]
    total_area = img_h * img_w
    lower, upper = 0.48 * total_area, 0.7 * total_area

    for candidate in sorted(found, key=cv2.contourArea, reverse=True):
        # Unlike the light-image variant, the contour's own area is tested here.
        if not lower <= cv2.contourArea(candidate) <= upper:
            continue
        left, top, box_w, box_h = cv2.boundingRect(candidate)
        return image[top:top + box_h, left:left + box_w]
    return None

def crop_card(image):
    """Crop the card from ``image``, trying the light-image strategy first.

    Args:
        image: RGB image as a NumPy array.

    Returns:
        The first non-``None`` crop produced by ``crop_card_for_light_image``
        and then ``crop_card_for_dark_image``. If both fail, the very same
        ``image`` object is returned, so callers can detect a failed crop with
        an identity check.
    """
    for cropper in (crop_card_for_light_image, crop_card_for_dark_image):
        result = cropper(image)
        if result is not None:
            return result
    return image  # Fallback to original if cropping fails

def pad_to_square(image, fill=0, padding_mode="constant"):
    """Pad the shorter side of ``image`` so that width equals height.

    The padding is split between the two sides of the short axis; when the
    difference is odd, the right/bottom side receives the extra pixel.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` (e.g. a PIL image).
        fill: Fill value forwarded to ``transforms.Pad``.
        padding_mode: Padding mode forwarded to ``transforms.Pad``.

    Returns:
        ``image`` itself when it is already square, otherwise the padded image.
    """
    width, height = image.size
    side = max(width, height)
    if width == height:
        return image

    extra_w, extra_h = side - width, side - height
    left, top = extra_w // 2, extra_h // 2
    # Order expected by transforms.Pad: left, top, right, bottom.
    borders = (left, top, extra_w - left, extra_h - top)
    padder = transforms.Pad(borders, fill=fill, padding_mode=padding_mode)
    return padder(image)

# Define Enhanced Transforms with Augmentation
# pad_to_square is passed directly (its defaults already are fill=0 and
# padding_mode="constant") rather than through an anonymous lambda: lambdas
# cannot be pickled, which breaks DataLoader worker processes on platforms
# that start workers with "spawn".
train_transform = transforms.Compose([
    transforms.Lambda(pad_to_square),
    transforms.Resize(384),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation uses the same geometry and normalisation but no random augmentation.
val_transform = transforms.Compose([
    transforms.Lambda(pad_to_square),
    transforms.Resize(384),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Visualize Sample Images After Transform
def show_sample_images(df, transform, title="Sample Images", num_samples=5):
    """Plot a random handful of images from ``df`` after applying ``transform``.

    Args:
        df: DataFrame with a 'filename' column (image path) and a 'grade'
            column (used as each subplot's title).
        transform: Callable turning a PIL image into a CHW tensor. The display
            code undoes the mean/std normalisation used by this file's
            pipelines, so ``transform`` is expected to end with that Normalize.
        title: Figure title.
        num_samples: Maximum number of images to draw; fewer are shown when
            ``df`` has fewer rows.
    """
    count = min(num_samples, len(df))
    if count <= 0:
        print("No images available to display.")
        return

    sample_df = df.sample(count)
    # squeeze=False keeps `axes` two-dimensional even when a single image is shown.
    _, axes = plt.subplots(1, count, figsize=(3 * count, 3), squeeze=False)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for ax, (_, row) in zip(axes[0], sample_df.iterrows()):
        # The context manager closes the file once the pixels have been decoded.
        with Image.open(row['filename']) as handle:
            image = handle.convert('RGB')
        if ACTIVATE_CROPPING:
            image = Image.fromarray(crop_card(np.array(image)))
        tensor = transform(image)
        pixels = tensor.permute(1, 2, 0).numpy() * std + mean
        # De-normalised values can fall slightly outside [0, 1]; clip them so
        # imshow receives a valid float image.
        ax.imshow(np.clip(pixels, 0.0, 1.0))
        ax.axis('off')
        ax.set_title(row['grade'])
    plt.suptitle(title)
    plt.show()

print("\nSample images after train_transform:")
show_sample_images(train_df, train_transform, "Training Samples After Transform")

# Custom Dataset
class CardDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a DataFrame.

    Args:
        df: DataFrame with a 'filename' column (image path) and a 'label'
            column (encoded class index).
        transform: Optional callable applied to the PIL image before returning.

    Attributes:
        failed_crops: List of image paths for which ``crop_card`` fell back to
            the uncropped image, or ``None`` when cropping is disabled.
            NOTE(review): when the DataLoader uses worker processes, each
            worker appends to its own copy of the dataset, so this list is
            expected to stay empty in the parent process. Confirm before
            relying on the count. Paths are appended on every access, so they
            repeat across epochs.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        self.failed_crops = [] if ACTIVATE_CROPPING else None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Materialise the row once instead of once per column.
        row = self.df.iloc[idx]
        img_path = row['filename']
        label = row['label']
        # The context manager closes the file as soon as the pixels are decoded.
        with Image.open(img_path) as handle:
            image = handle.convert('RGB')
        if ACTIVATE_CROPPING:
            image_np = np.array(image)
            cropped_np = crop_card(image_np)
            if cropped_np is image_np:  # crop_card returns its input object when cropping failed
                self.failed_crops.append(img_path)
            image = Image.fromarray(cropped_np)
        if self.transform:
            image = self.transform(image)
        return image, label

# Create Datasets and Dataloaders
train_dataset = CardDataset(train_df, transform=train_transform)
val_dataset = CardDataset(val_df, transform=val_transform)

# Both loaders share every setting except shuffling: only training batches are shuffled.
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

# Log Failed Crops
# Crop failures are only recorded while the datasets are iterated, so nothing can be reported yet.
if ACTIVATE_CROPPING:
    print("Failed crops will be logged after first epoch.")

# Load Vision Transformer Model
# The classification head is sized to the number of classes known to the label encoder.
# NOTE(review): DROPOUT_PROB is not passed to the model here; confirm whether that is intended.
model = timm.create_model(
    'vit_large_patch16_384',
    pretrained=True,
    num_classes=len(le.classes_),
).to(device)

# Set All Parameters to Require Gradients
model.requires_grad_(True)

# Loss Function with Class Weights and Label Smoothing
_loss_weights = class_weights_tensor if ACTIVATE_WEIGHTS_TENSOR else None
criterion = nn.CrossEntropyLoss(weight=_loss_weights, label_smoothing=0.1)

# Optimizer and Scheduler
# Cosine annealing spans the full epoch budget and bottoms out at 1e-6.
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)

# Validation Function
def validate(model, val_loader, criterion, use_tta=False):
    """Evaluate ``model`` on ``val_loader`` and return ``(loss, accuracy)``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        val_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable applied to float32 logits and labels.
        use_tta: When true, logits of each batch are averaged with the logits
            of its horizontally flipped copy (flip along dim 3).

    Returns:
        Tuple ``(val_loss, val_acc)``. The loss is averaged per sample, so a
        short final batch is not over-weighted. Accuracy is the fraction of
        correctly classified samples.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    # Run on whatever device the model lives on. The module-level `device`
    # is only a fallback for models without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device
    use_amp = run_device.type == 'cuda'  # mixed precision only makes sense on CUDA here

    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)
            views = [inputs]
            if use_tta:
                views.append(torch.flip(inputs, dims=[3]))  # Horizontal Flip
            with torch.amp.autocast('cuda', enabled=use_amp):
                logits = [model(view) for view in views]
            # Cast to float32 before averaging so half-precision rounding does
            # not leak into the loss.
            outputs = torch.stack([out.float() for out in logits]).mean(dim=0)
            batch_size = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("val_loader yielded no samples to validate on")
    return loss_sum / total, correct / total

# Training Function with Early Stopping
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs,
                patience=None, checkpoint_path='best_model.pth'):
    """Train ``model`` with mixed precision, per-epoch validation and early stopping.

    After every epoch the model is validated with ``validate`` (no TTA) and the
    scheduler is stepped. Whenever the validation loss improves, the state dict
    is saved to ``checkpoint_path``. Training stops once the loss has failed to
    improve for ``patience`` consecutive epochs.

    Args:
        model: Network to train.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of validation batches, forwarded to ``validate``.
        criterion: Loss callable applied to float32 logits and labels.
        optimizer: Optimizer stepped once per batch through the grad scaler.
        scheduler: LR scheduler stepped once per epoch.
        num_epochs: Maximum number of epochs.
        patience: Epochs without improvement tolerated before stopping;
            defaults to the module-level ``PATIENCE``.
        checkpoint_path: File that receives the best state dict.

    Returns:
        Dict with per-epoch lists under 'train_loss', 'val_loss' and
        'val_accuracy'.

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    if patience is None:
        patience = PATIENCE
    # Run on whatever device the model lives on. The module-level `device`
    # is only a fallback for models without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device
    use_amp = run_device.type == 'cuda'  # AMP and loss scaling are CUDA-only here

    best_val_loss = float('inf')
    patience_counter = 0
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}

    for epoch in range(num_epochs):
        model.train()
        loss_sum = 0.0
        seen = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)
            optimizer.zero_grad()
            with torch.amp.autocast('cuda', enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs.float(), labels)  # Cast to Float32
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # Weight each batch by its size so a short final batch does not
            # skew the epoch average.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            seen += batch_size
        if seen == 0:
            raise ValueError("train_loader yielded no samples to train on")
        train_loss = loss_sum / seen

        val_loss, val_acc = validate(model, val_loader, criterion, use_tta=False)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_acc)

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
        scheduler.step()

        # Log failed crops after first epoch. The datasets are taken from the
        # loaders that were passed in rather than from module-level globals.
        # NOTE(review): with DataLoader worker processes the lists are filled
        # in the workers' copies, so these counts are expected to read 0 in
        # the parent process. Confirm before relying on them.
        if epoch == 0 and ACTIVATE_CROPPING:
            train_failed = getattr(train_loader.dataset, 'failed_crops', None) or []
            val_failed = getattr(val_loader.dataset, 'failed_crops', None) or []
            print(f"Training failed crops: {len(train_failed)}")
            print(f"Validation failed crops: {len(val_failed)}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print('Early stopping')
                break

    return history

# Train the Model
history = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=EPOCHS)

# Load Best Model
# map_location places the tensors on the current device even if the checkpoint
# was written from a different one. weights_only=True restricts unpickling to
# tensor data, which is all a state dict needs.
best_state = torch.load('best_model.pth', map_location=device, weights_only=True)
model.load_state_dict(best_state)

# Evaluate with and without TTA
val_loss, val_acc = validate(model, val_loader, criterion, use_tta=False)
print(f"Validation Accuracy without TTA: {val_acc:.4f}")

val_loss_tta, val_acc_tta = validate(model, val_loader, criterion, use_tta=True)
print(f"Validation Accuracy with TTA: {val_acc_tta:.4f}")

# Visualize Training History
# Left panel: validation accuracy together with training loss.
# Right panel: validation loss against training loss.
_, (ax_mixed, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

ax_mixed.plot(history['val_accuracy'], label='Val Accuracy')
ax_mixed.plot(history['train_loss'], label='Train Loss')
ax_mixed.set_title('Model Accuracy and Train Loss')
ax_mixed.set_xlabel('Epoch')
ax_mixed.set_ylabel('Metric')
ax_mixed.legend()

ax_loss.plot(history['val_loss'], label='Val Loss')
ax_loss.plot(history['train_loss'], label='Train Loss')
ax_loss.set_title('Model Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

plt.tight_layout()
plt.show()

# Save Final Model
# Writes the weights currently held by `model` under a separate file name.
final_state = model.state_dict()
torch.save(final_state, 'card_grader_model.pth')

# Prediction Function with TTA
def predict_grade(img_path, model, le, transform, use_tta=False):
    """Predict the grade label for a single image file.

    Args:
        img_path: Path of the image to classify.
        model: Trained network; it is switched to eval mode.
        le: Fitted label encoder used to map the predicted index back to a grade.
        transform: Callable turning a PIL image into a CHW tensor.
        use_tta: When true, the logits of the image and of its horizontally
            flipped tensor (flip along the width axis, the same flip that
            ``validate`` applies) are averaged before taking the argmax.

    Returns:
        The decoded grade for the highest-scoring class.
    """
    model.eval()
    # Run on whatever device the model lives on. The module-level `device`
    # is only a fallback for models without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device
    use_amp = run_device.type == 'cuda'

    # The context manager closes the file as soon as the pixels are decoded.
    with Image.open(img_path) as handle:
        image = handle.convert('RGB')
    if ACTIVATE_CROPPING:
        image = Image.fromarray(crop_card(np.array(image)))

    batch = transform(image).unsqueeze(0).to(run_device)
    if use_tta:
        # Flip the already-transformed tensor so both views share identical
        # preprocessing and the TTA matches what validation measured.
        batch = torch.cat([batch, torch.flip(batch, dims=[3])])

    with torch.no_grad():
        # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast
        # and matches the API used by the training and validation code.
        with torch.amp.autocast('cuda', enabled=use_amp):
            logits = model(batch)
    # Average over the views in float32 (a no-op mean when TTA is off).
    outputs = logits.float().mean(dim=0, keepdim=True)
    predicted = outputs.argmax(dim=1)
    return le.inverse_transform([predicted.item()])[0]

# Test Prediction
sample_img_path = val_df['filename'].iloc[0]
grade_pred = predict_grade(sample_img_path, model, le, val_transform, use_tta=True)
print(f"Predicted grade for sample image: {grade_pred}")