# In [None]:
# Revised Code: PSA Card Grade Prediction with EfficientNet

import os
import cv2
import torch
import timm  # Install via: pip install timm
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight

# --- Configuration ---
ACTIVATE_WEIGHTS_TENSOR = False  # weight CrossEntropyLoss by 'balanced' class weights
ACTIVATE_CROPPING = False        # run crop_card() on each image before the transforms
BATCH_SIZE = 64
SET_FRAC = 0.7                   # fraction of rows used for training; the rest validate
NUM_EPOCHS = 50  # initial training, can adjust later

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# --- Load dataset ---
df = pd.read_csv('../scrape/psa_sales_20250222_170248.csv')
image_dir = '../scrape/pictures'
df['filename'] = df['certNumber'].apply(lambda x: os.path.join(image_dir, f"cert_{x}.jpg"))
# Stat every file once and reuse the mask for both the report and the filter.
has_image = df['filename'].apply(os.path.exists)
print(f"Missing images: {int((~has_image).sum())}")
df = df[has_image].copy()
if df.empty:
    # Fail here with a clear message instead of deep inside sampling/training.
    raise FileNotFoundError(f"No images found for any CSV row under {image_dir!r}")

# --- Split into training and validation sets ---
# NOTE: plain random split (not stratified); a rare grade can end up entirely
# in one of the two splits.
train_df = df.sample(frac=SET_FRAC, random_state=42)
val_df = df.drop(train_df.index)

# --- Encode labels ---
# Fit on the full frame so both splits share one grade -> index mapping.
le = LabelEncoder()
le.fit(df['grade'])
train_df['label'] = le.transform(train_df['grade'])
val_df['label'] = le.transform(val_df['grade'])

print("Unique grades:", df['grade'].unique())

grade_counts = df['grade'].value_counts().reset_index()
grade_counts.columns = ['grade', 'count']
grade_counts['percent'] = grade_counts['count'] / len(df) * 100
print(grade_counts)

# --- Compute class weights (if needed) ---
if ACTIVATE_WEIGHTS_TENSOR:
    # One weight per encoder class, in le.classes_ order, so the tensor always
    # matches the model's output width (len(le.classes_)). 'balanced' weights
    # are computed over the labels actually present in train_df, because
    # compute_class_weight rejects classes that do not occur in y. A grade seen
    # only in val_df keeps a neutral weight of 1.0.
    present_labels = np.unique(train_df['label'])
    class_weights = np.ones(len(le.classes_), dtype=np.float64)
    class_weights[present_labels] = compute_class_weight(
        'balanced', classes=present_labels, y=train_df['label']
    )
    class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)

# --- Card cropping helpers ---
def crop_card_for_light_image(image):
    """Try to cut the card out of a photo taken against a light background.

    The grayscale image is binarized with Otsu's threshold. Contours are then
    visited from largest to smallest area. Each one is simplified with
    ``approxPolyDP`` and its axis-aligned bounding box is measured. The first
    box covering between 48% and 60% of the frame wins.

    Args:
        image: RGB pixel array (this file passes ``np.array`` of a PIL RGB image).

    Returns:
        The slice of ``image`` under the winning bounding box, or ``None`` when
        no contour is found or none falls inside the size band.
    """
    mask = cv2.threshold(
        cv2.cvtColor(image, cv2.COLOR_RGB2GRAY),
        0,
        255,
        cv2.THRESH_BINARY + cv2.THRESH_OTSU,
    )[1]
    found, _hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    if not found:
        return None

    rows, cols = image.shape[:2]
    frame_area = rows * cols
    # Accepted bounding-box area band, as a fraction of the whole frame.
    lower, upper = 0.48 * frame_area, 0.6 * frame_area

    for candidate in sorted(found, key=cv2.contourArea, reverse=True):
        tolerance = 0.001 * cv2.arcLength(candidate, True)
        polygon = cv2.approxPolyDP(candidate, tolerance, True)
        left, top, box_w, box_h = cv2.boundingRect(polygon)
        if lower <= box_w * box_h <= upper:
            return image[top:top + box_h, left:left + box_w]
    return None

def crop_card_for_dark_image(image):
    """Try to cut the card out of a photo taken against a dark background.

    Pipeline:
    1. Grayscale, then a light 3x3 Gaussian blur.
    2. Mean adaptive threshold, then Canny edges, inverted so edges are black.
    3. A 7x7 morphological opening.

    Contours of the result are ranked by area. The first contour whose own area
    lies between 48% and 70% of the frame determines the crop.

    Args:
        image: RGB pixel array (this file passes ``np.array`` of a PIL RGB image).

    Returns:
        The slice of ``image`` under that contour's bounding box, or ``None``
        when no contour qualifies.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # NOTE(review): sigmaX is negative here; OpenCV derives sigma from the
    # kernel size when it is non-positive, so this likely acts like sigmaX=0.
    smoothed = cv2.GaussianBlur(grayscale, (3, 3), -10)
    binary = cv2.adaptiveThreshold(
        smoothed, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 7, 3
    )
    inverted_edges = 255 - cv2.Canny(binary, 100, 200)
    opened = cv2.morphologyEx(
        inverted_edges,
        cv2.MORPH_OPEN,
        cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7)),
    )
    found, _hierarchy = cv2.findContours(opened, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    if not found:
        return None

    rows, cols = image.shape[:2]
    frame_area = rows * cols
    lower, upper = 0.48 * frame_area, 0.7 * frame_area

    ranked = sorted(found, key=cv2.contourArea, reverse=True)
    winner = next((c for c in ranked if lower <= cv2.contourArea(c) <= upper), None)
    if winner is None:
        return None
    left, top, box_w, box_h = cv2.boundingRect(winner)
    return image[top:top + box_h, left:left + box_w]

def crop_card(image):
    """Crop the card from ``image``, trying the light then the dark strategy.

    Args:
        image: RGB pixel array.

    Returns:
        The first non-``None`` crop. If both detectors fail, returns the very
        same ``image`` object, so callers can detect the fallback with ``is``.
    """
    for strategy in (crop_card_for_light_image, crop_card_for_dark_image):
        result = strategy(image)
        if result is not None:
            return result
    return image  # Fallback: identity, not a copy

def pad_to_square(image, fill=0, padding_mode="constant"):
    """Pad a PIL image to a square canvas, centring the original content.

    Any odd leftover pixel goes to the right/bottom edge.

    Args:
        image: PIL image (``image.size`` is read as ``(width, height)``).
        fill: Fill value forwarded to ``transforms.Pad``.
        padding_mode: Padding mode forwarded to ``transforms.Pad``.

    Returns:
        ``image`` itself when it is already square, otherwise a padded copy.
    """
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    extra_w, extra_h = side - width, side - height
    left, top = extra_w // 2, extra_h // 2
    border = (left, top, extra_w - left, extra_h - top)
    return transforms.Pad(border, fill=fill, padding_mode=padding_mode)(image)

# --- Transform pipelines ---
# Per-channel ImageNet statistics. timm's pretrained efficientnet_b3 weights
# expect inputs normalized this way (confirm via model.pretrained_cfg['mean'] /
# ['std']). Feeding raw [0, 1] tensors shifts the input distribution that the
# frozen backbone sees.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# NOTE(review): lambdas inside transforms cannot be pickled, so
# DataLoader(num_workers > 0) will likely fail under the 'spawn' start method.
train_transform = transforms.Compose([
    transforms.Lambda(lambda img: pad_to_square(img, fill=0)),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    # Erase after normalizing, so erased pixels (value 0) equal the mean colour.
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.1)),
])

val_transform = transforms.Compose([
    transforms.Lambda(lambda img: pad_to_square(img, fill=0)),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])


def show_sample_images(df, transform, mean=None, std=None, n=5):
    """Plot ``n`` randomly sampled images from ``df`` after ``transform``.

    Args:
        df: DataFrame with a ``filename`` column of image paths.
        transform: Callable mapping a PIL RGB image to a (C, H, W) tensor.
        mean, std: Per-channel statistics of a ``Normalize`` step inside
            ``transform``. When both are given, the normalization is undone so
            the images display with their real colours.
        n: Number of images to show (capped at ``len(df)``).
    """
    n = min(n, len(df))
    if n == 0:
        return
    sample_df = df.sample(n)
    # squeeze=False keeps `axes` 2-D even when n == 1.
    _, axes = plt.subplots(1, n, figsize=(3 * n, 3), squeeze=False)
    for ax, img_path in zip(axes[0], sample_df['filename']):
        with Image.open(img_path) as raw:
            tensor = transform(raw.convert('RGB'))
        if mean is not None and std is not None:
            tensor = tensor * torch.tensor(std).view(-1, 1, 1) + torch.tensor(mean).view(-1, 1, 1)
        ax.imshow(tensor.permute(1, 2, 0).clamp(0, 1).numpy())
        ax.axis('off')
    plt.show()


print("\nSample transformed training images:")
show_sample_images(train_df, train_transform, mean=IMAGENET_MEAN, std=IMAGENET_STD)

# --- Custom Dataset ---
class CardDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a DataFrame.

    Args:
        df: DataFrame with at least ``filename`` (image path) and ``label``
            (integer class index) columns. The index is reset so positional
            ``idx`` lookups are contiguous.
        transform: Optional callable applied to the PIL image.
        crop: Whether to run ``crop_card`` before ``transform``. ``None``
            (default) uses the module-level ``ACTIVATE_CROPPING`` flag, read
            once at construction time.

    Attributes:
        failed_crops: Unique paths whose crop fell back to the full image, in
            first-seen order. Always defined (empty when cropping is off).
            NOTE: with ``DataLoader(num_workers > 0)``, ``__getitem__`` runs in
            worker processes, so paths recorded there are not visible on this
            object in the main process.
    """

    def __init__(self, df, transform=None, crop=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.crop = ACTIVATE_CROPPING if crop is None else crop
        self.failed_crops = []
        # Set mirror of failed_crops, so repeat epochs do not add duplicates.
        self._failed_set = set()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path, label = row['filename'], row['label']
        # convert() returns a loaded copy, so the file can be closed right away.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        if self.crop:
            pixels = np.array(image)
            cropped = crop_card(pixels)
            # crop_card returns its input object unchanged when it fails.
            if cropped is pixels and img_path not in self._failed_set:
                self._failed_set.add(img_path)
                self.failed_crops.append(img_path)
            image = Image.fromarray(cropped)
        if self.transform:
            image = self.transform(image)
        return image, label

train_dataset = CardDataset(train_df, transform=train_transform)
val_dataset = CardDataset(val_df, transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)


def _count_failed_crops(frame):
    """Return how many images in ``frame['filename']`` make crop_card fall back."""
    failed = 0
    for path in frame['filename']:
        with Image.open(path) as raw:
            pixels = np.array(raw.convert('RGB'))
        # crop_card hands back the identical object when neither detector matches.
        if crop_card(pixels) is pixels:
            failed += 1
    return failed


if ACTIVATE_CROPPING:
    # The datasets only record failures lazily inside __getitem__, and no item
    # has been loaded yet. With num_workers > 0 those records also live in
    # worker processes. Count the fallbacks eagerly with one pass instead.
    print(f"Failed crops in training: {_count_failed_crops(train_df)}")
    print(f"Failed crops in validation: {_count_failed_crops(val_df)}")

# --- Build the model using EfficientNet ---
# pretrained=True makes timm fetch ImageNet weights (network access on first use).
model = timm.create_model('efficientnet_b3', pretrained=True)
n_features = model.classifier.in_features  # input width of timm's stock classifier
# Replace the stock classifier with a small MLP head sized to the grade set.
head_layers = [
    nn.Linear(n_features, 512),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, len(le.classes_)),
]
model.classifier = nn.Sequential(*head_layers)
model = model.to(device)

# Phase 1: freeze every parameter whose name lacks 'classifier', so only the
# new head trains.
# NOTE(review): BatchNorm running statistics in the frozen backbone still
# update whenever the model is in train() mode -- confirm that is intended.
frozen_params = (p for n, p in model.named_parameters() if 'classifier' not in n)
for param in frozen_params:
    param.requires_grad = False

# --- Loss, Optimizer, and Scheduler ---
loss_weight = class_weights_tensor if ACTIVATE_WEIGHTS_TENSOR else None
criterion = nn.CrossEntropyLoss(weight=loss_weight)

# Only trainable parameters are handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# --- Validation function ---
def validate(model, loader, criterion, device=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Classifier returning per-class scores of shape (batch, classes).
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``, assumed to
            return a per-batch mean (as ``nn.CrossEntropyLoss`` does by default).
        device: Where to move each batch. Defaults to the device of the model's
            first parameter (CPU if the model has none).

    Returns:
        ``(mean_loss, accuracy)``. Batch losses are weighted by batch size, so
        a short final batch does not skew the average.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            batch_size = labels.size(0)
            total_loss += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("validation loader yielded no samples")
    return total_loss / total, correct / total

# --- Training function with early stopping ---
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs,
                phase='Initial', patience=7, device=None):
    """Train ``model`` with early stopping on validation loss.

    Each epoch:
    - runs one pass over ``train_loader``;
    - evaluates with ``validate`` on ``val_loader``;
    - steps ``scheduler`` once;
    - saves ``model.state_dict()`` to ``best_model_<phase lower-cased>.pth``
      whenever validation loss improves.

    Training stops after ``patience`` consecutive epochs without improvement.
    Mixed precision (autocast + GradScaler) is enabled only on CUDA; on other
    devices both are disabled and training runs in full precision.

    Args:
        model, criterion, optimizer, scheduler: Usual PyTorch training objects.
        train_loader, val_loader: Iterables of ``(inputs, labels)`` batches.
        num_epochs: Maximum number of epochs.
        phase: Label used in log lines and in the checkpoint filename.
        patience: Non-improving epochs tolerated before stopping.
        device: Where to move batches. Defaults to the device of the model's
            first parameter.

    Returns:
        Dict with lists ``'train_loss'``, ``'val_loss'`` and ``'val_accuracy'``,
        one entry per completed epoch. The model is left with the weights of
        the LAST epoch run; reload the checkpoint to get the best ones.
    """
    if device is None:
        device = next(model.parameters()).device
    device = torch.device(device)
    use_amp = device.type == 'cuda'
    best_val_loss = float('inf')
    patience_counter = 0
    scaler = GradScaler(enabled=use_amp)
    history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}
    checkpoint_path = f'best_model_{phase.lower()}.pth'

    for epoch in range(num_epochs):
        model.train()
        running_loss, seen = 0.0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # Weight by batch size, so the epoch average is per sample.
            running_loss += loss.item() * labels.size(0)
            seen += labels.size(0)
        train_loss = running_loss / max(seen, 1)

        val_loss, val_acc = validate(model, val_loader, criterion)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_acc)

        print(f"{phase} Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
        scheduler.step()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break
    return history

# --- Initial training ---
# Only parameters with requires_grad=True were handed to the current optimizer.
history_initial = train_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler,
    num_epochs=NUM_EPOCHS, phase='Initial',
)

# --- Fine-tuning: Unfreeze later layers ---
# Resume from the best checkpoint of the first phase, not from whatever
# weights its last (possibly early-stopped) epoch left behind.
model.load_state_dict(torch.load('best_model_initial.pth'))
FINETUNE_KEYS = ('classifier', 'blocks.5', 'blocks.6')
# A parameter is trainable iff its name contains one of the keys; all others
# are (re-)frozen.
# NOTE(review): layers after blocks.6 whose names lack these keys stay frozen
# -- confirm that is intended.
for name, param in model.named_parameters():
    param.requires_grad = any(key in name for key in FINETUNE_KEYS)
finetune_params = [p for p in model.parameters() if p.requires_grad]
# Fresh optimizer and scheduler with a 10x lower learning rate than phase 1.
optimizer = optim.AdamW(finetune_params, lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
history_finetune = train_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler,
    num_epochs=NUM_EPOCHS, phase='FineTune',
)

# --- Load best fine-tuned model ---
model.load_state_dict(torch.load('best_model_finetune.pth'))

# --- Visualize training history ---
# Both phases are concatenated on one epoch axis.
plt.figure(figsize=(12, 5))
panels = (
    ('val_accuracy', 'Validation Accuracy', 'Accuracy'),
    ('val_loss', 'Validation Loss', 'Loss'),
)
for panel_index, (metric_key, panel_title, axis_label) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_index)
    plt.plot(history_initial[metric_key] + history_finetune[metric_key], label=panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(axis_label)
    plt.title(panel_title)
    plt.legend()

plt.tight_layout()
plt.show()

# --- Save final model ---
torch.save(model.state_dict(), 'card_grader_model_efficientnet.pth')

# --- Prediction function ---
def predict_grade(img_path, model, le, transform, crop=None):
    """Predict the grade label for a single image file.

    Args:
        img_path: Path to the image.
        model: Trained classifier; switched to eval mode by this call.
        le: Fitted ``LabelEncoder`` mapping class indices back to grades.
        transform: Preprocessing pipeline. It should match the one used for
            validation (e.g. ``val_transform``).
        crop: Whether to run ``crop_card`` first. ``None`` follows
            ``ACTIVATE_CROPPING``, so inference preprocessing matches what
            ``CardDataset`` fed the model during training.

    Returns:
        The predicted grade, as stored in ``le.classes_``.
    """
    if crop is None:
        crop = ACTIVATE_CROPPING
    model.eval()
    target_device = next(model.parameters()).device
    with Image.open(img_path) as raw:
        image = raw.convert('RGB')
    if crop:
        image = Image.fromarray(crop_card(np.array(image)))
    batch = transform(image).unsqueeze(0).to(target_device)
    with torch.no_grad():
        pred = model(batch).argmax(dim=1)
    return le.inverse_transform([pred.item()])[0]

# --- Evaluate final model ---
# NOTE: the validation split also drove early stopping and checkpoint
# selection, so this accuracy is an optimistic estimate, not a held-out test score.
val_metrics = validate(model, val_loader, criterion)
val_loss, val_acc = val_metrics
print(f"Overall Validation Accuracy: {val_acc:.4f}")
