# MAE Evaluation

In [None]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import timm
import torch
import torch.nn as nn
from PIL import Image
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

# Device setup: use CUDA when PyTorch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Paths
DATA_ROOT = "/kaggle/input/ssl-dataset/ssl_dataset"
CHECKPOINT_PATH = "/kaggle/input/maemodel/pytorch/default/1/mae_checkpoint_epoch50.pth"
VAL_FOLDER = "val.X"
LABELS_PATH = os.path.join(DATA_ROOT, "Labels.json")

# Hyperparameters
BATCH_SIZE = 64
IMG_SIZE = 224

# --- MAE Model Loading ---
# NOTE: CustomMAE is not defined in this cell. It must already be in scope
# (imported, or defined by a cell that ran earlier) before this code runs.
encoder = timm.create_model('vit_tiny_patch16_224', pretrained=False)
# Called with 0 classes; presumably this strips the classification head so the
# encoder yields features only -- confirm against the timm documentation.
encoder.reset_classifier(0)

# Pull the encoder's sub-modules out so they can be handed to CustomMAE by name.
patch_embed, pos_embed = encoder.patch_embed, encoder.pos_embed
encoder_blocks, encoder_norm = encoder.blocks, encoder.norm
embed_dim = encoder.embed_dim

# Decoder-side pieces, built with the same width as the encoder embedding.
mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
decoder_layer = nn.TransformerEncoderLayer(
    d_model=embed_dim,
    nhead=4,
    dim_feedforward=2 * embed_dim,
    batch_first=True,
)
decoder = nn.TransformerEncoder(decoder_layer, num_layers=4)
# Output width 3 * 16 * 16 -- presumably the RGB pixels of one 16x16 patch
# (TODO confirm against the training code).
reconstruction_head = nn.Linear(embed_dim, 3 * 16 * 16)

model = CustomMAE(
    encoder=encoder,
    patch_embed=patch_embed,
    pos_embed=pos_embed,
    encoder_blocks=encoder_blocks,
    encoder_norm=encoder_norm,
    mask_token=mask_token,
    decoder=decoder,
    reconstruction_head=reconstruction_head,
)
model = model.to(device)

# Restore the trained weights; `checkpoint` is kept around because the
# training-history plot further down reads from it as well.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

def get_features(self, x):
    """Encode a batch with the full (unmasked) encoder and mean-pool the tokens.

    Written as a method body: ``self`` must expose ``patch_embed``,
    ``pos_embed``, ``encoder_blocks`` and ``encoder_norm``.

    Args:
        self: Object carrying the encoder components listed above.
        x: Input batch passed straight to ``self.patch_embed``.

    Returns:
        The normalised token sequence averaged over dimension 1.
    """
    tokens = self.patch_embed(x)
    # Skip the first positional-embedding entry (presumably the class-token
    # slot, since no class token is prepended here -- TODO confirm).
    patch_positions = self.pos_embed[:, 1:, :]
    tokens = tokens + patch_positions
    for block in self.encoder_blocks:
        tokens = block(tokens)
    normalised = self.encoder_norm(tokens)
    return torch.mean(normalised, dim=1)

# Bind get_features to this particular model instance so it can be called as
# model.get_features(batch) without modifying the CustomMAE class itself.
setattr(model, 'get_features', get_features.__get__(model, type(model)))

# --- Dataset Class ---
class MAEEvalDataset(Dataset):
    """Labeled image dataset built by walking a directory tree.

    An image is included when its immediate parent directory name is a key of
    ``label_dict``. Class indices are assigned from the sorted set of all
    values in ``label_dict``, so they do not depend on which classes actually
    have images on disk.

    Args:
        folder_path: Root directory to scan recursively.
        label_dict: Mapping from directory name to class name.
        transform: Callable applied to each loaded RGB image, or a falsy value
            to return the PIL image unchanged.
    """

    def __init__(self, folder_path, label_dict, transform):
        self.image_paths = []
        self.labels = []
        self.class_names = sorted(set(label_dict.values()))
        self.class_to_idx = {cls: i for i, cls in enumerate(self.class_names)}

        for root, dirs, files in os.walk(folder_path):
            # os.walk yields entries in filesystem order; sorting in place
            # makes the traversal (and therefore sample order) reproducible.
            dirs.sort()
            cls_folder = os.path.basename(root)
            if cls_folder not in label_dict:
                continue
            class_idx = self.class_to_idx[label_dict[cls_folder]]
            for file in sorted(files):
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.image_paths.append(os.path.join(root, file))
                    self.labels.append(class_idx)

        self.transform = transform
        print(f"Found {len(self.image_paths)} labeled images across {len(self.class_names)} classes")

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Loading is best-effort: if the image cannot be opened, decoded or
        transformed, the error is printed and a zero tensor with label ``-1``
        is returned so the caller can filter the sample out instead of the
        whole DataLoader crashing.
        """
        path = self.image_paths[idx]
        try:
            # Image.open is lazy; the context manager guarantees the file
            # handle is released even when decoding fails part-way.
            with Image.open(path) as raw:
                img = raw.convert('RGB')
            if self.transform:
                img = self.transform(img)
        except Exception as e:
            print(f"Error loading {path}: {e}")
            return torch.zeros(3, IMG_SIZE, IMG_SIZE), -1
        return img, self.labels[idx]

# --- Transforms ---
# --- Transforms ---
# Deterministic evaluation preprocessing: resize, centre-crop to IMG_SIZE,
# convert to a tensor, then normalise per channel. The mean/std values look
# like the usual ImageNet statistics -- confirm they match the training setup.
val_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# --- Load Labels & DataLoader ---
# Labels.json is passed to the dataset as the folder-name -> class-name mapping.
with open(LABELS_PATH) as labels_file:
    label_dict = json.load(labels_file)

val_dataset = MAEEvalDataset(
    os.path.join(DATA_ROOT, VAL_FOLDER),
    label_dict,
    val_transform,
)

# No shuffling: sample order stays fixed across runs.
val_loader = DataLoader(
    val_dataset,
    BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)
print(f"Total validation images: {len(val_dataset)}")

# --- Feature Extraction ---
def extract_features(loader):
    """Run every batch of ``loader`` through ``model.get_features``.

    Uses the module-level ``model`` and ``device``. Gradients are disabled
    during the forward passes and features are moved to the CPU per batch.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        ``(features, labels)`` tensors concatenated over all batches, with
        every sample whose label equals ``-1`` removed (the dataset uses
        ``-1`` to mark images that failed to load).
    """
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for images, targets in tqdm(loader, desc="Extracting features"):
            batch_features = model.get_features(images.to(device))
            feature_batches.append(batch_features.cpu())
            label_batches.append(targets)

    all_features = torch.cat(feature_batches)
    all_labels = torch.cat(label_batches)
    # Boolean mask of the samples that loaded successfully.
    keep = all_labels.ne(-1)
    return all_features[keep], all_labels[keep]

X_val, y_val = extract_features(val_loader)
print(f"Feature shape: {X_val.shape}, Labels shape: {y_val.shape}")

# --- Linear Probing ---
# The probe must be scored on samples it was not fitted on; fitting and
# evaluating on the same features only reports training accuracy. Only one
# labeled folder is loaded in this script, so hold out 20% of it for scoring.
X_all, y_all = X_val.numpy(), y_val.numpy()
_, class_counts = np.unique(y_all, return_counts=True)
n_holdout = int(np.ceil(0.2 * len(y_all)))
# A stratified split needs >= 2 samples per class and at least one held-out
# slot per class; otherwise fall back to a plain random split.
can_stratify = class_counts.min() >= 2 and n_holdout >= len(class_counts)
X_probe_train, X_probe_test, y_probe_train, y_probe_test = train_test_split(
    X_all,
    y_all,
    test_size=0.2,
    random_state=42,
    stratify=y_all if can_stratify else None,
)
print(f"Linear probe split: {len(y_probe_train)} train / {len(y_probe_test)} held-out")

# multi_class is deprecated in recent scikit-learn releases; with the lbfgs
# solver a multiclass problem is already fitted as multinomial by default.
clf = LogisticRegression(
    max_iter=1000,
    solver='lbfgs',
    random_state=42
)
clf.fit(X_probe_train, y_probe_train)

# --- Evaluation ---
val_probs = clf.predict_proba(X_probe_test)
# predict_proba columns are ordered like clf.classes_, which only contains the
# labels seen during fit. Map column indices back to label ids instead of
# assuming column i == class i.
val_preds = clf.classes_[np.argmax(val_probs, axis=1)]

acc = accuracy_score(y_probe_test, val_preds)
f1  = f1_score(y_probe_test, val_preds, average='macro')

# Top-1 & Top-5
probs_tensor = torch.tensor(val_probs)
true_tensor = torch.tensor(y_probe_test)
class_ids = torch.tensor(clf.classes_)

top1 = (class_ids[torch.argmax(probs_tensor, 1)] == true_tensor).float().mean().item()
# topk raises when k exceeds the number of columns, so clamp k.
top_k = min(5, probs_tensor.shape[1])
topk_labels = class_ids[torch.topk(probs_tensor, top_k, 1).indices]
top5 = (torch.any(topk_labels == true_tensor.unsqueeze(1), 1)
        .float().mean().item())

print("\nEvaluation Results:")
print(f"F1 Score (Macro):  {f1*100:.2f}%")
print(f"Top-1 Accuracy:     {top1*100:.2f}%")
print(f"Top-5 Accuracy:     {top5*100:.2f}%")

# --- Confusion Matrix ---
# Restrict the matrix to labels that actually occur and derive the tick names
# from the same list, so rows/columns and names cannot drift out of alignment.
shown_labels = np.union1d(y_probe_test, val_preds)
cm = confusion_matrix(y_probe_test, val_preds, labels=shown_labels)
class_names = val_dataset.class_names
tick_names = [class_names[i] for i in shown_labels]

plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=tick_names,
            yticklabels=tick_names)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

# --- Training History Visualization ---
def plot_training_history(history=None, save_path='/kaggle/working/training_history.png'):
    """Plot loss (and, when present, accuracy) curves from a training history.

    Args:
        history: Mapping that may contain ``'train_loss_history'``,
            ``'val_loss_history'``, ``'train_acc_history'`` and
            ``'val_acc_history'``. Defaults to the module-level ``checkpoint``.
        save_path: Where to write the figure; pass ``None`` to skip saving.

    Returns:
        None. Prints a message and does nothing else when either loss history
        is missing.
    """
    if history is None:
        history = checkpoint

    if 'train_loss_history' not in history or 'val_loss_history' not in history:
        print("No training history found in checkpoint")
        return

    train_loss = history['train_loss_history']
    val_loss = history['val_loss_history']
    train_acc = history.get('train_acc_history', [])
    val_acc = history.get('val_acc_history', [])
    # len() rather than truthiness so array-like histories work too.
    has_acc = len(train_acc) > 0 and len(val_acc) > 0

    def _plot(ax, series, label):
        # Each series gets its own x-range: train/val histories may differ in
        # length (e.g. validation not run every epoch), which would otherwise
        # make plot() raise on mismatched x/y sizes.
        ax.plot(range(1, len(series) + 1), series, label=label)

    # Only allocate the second panel when there is accuracy data to show.
    n_panels = 2 if has_acc else 1
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5), squeeze=False)

    loss_ax = axes[0, 0]
    _plot(loss_ax, train_loss, 'Training Loss')
    _plot(loss_ax, val_loss, 'Validation Loss')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()

    if has_acc:
        acc_ax = axes[0, 1]
        _plot(acc_ax, train_acc, 'Training Accuracy')
        _plot(acc_ax, val_acc, 'Validation Accuracy')
        acc_ax.set_xlabel('Epochs')
        acc_ax.set_ylabel('Accuracy')
        acc_ax.set_title('Training and Validation Accuracy')
        acc_ax.legend()

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

plot_training_history()

# --- Save Results ---
# Persist the headline metrics as percentages with two decimals, one per line.
results_lines = [
    f"Accuracy: {acc*100:.2f}%\n",
    f"Macro F1 Score: {f1*100:.2f}%\n",
]
with open('/kaggle/working/results.txt', 'w') as results_file:
    results_file.write("".join(results_lines))