# MAE Evaluation

In [None]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import timm
import torch
import torch.nn as nn
from PIL import Image
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

# Device setup: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Paths (Kaggle input mounts)
DATA_ROOT = "/kaggle/input/ssl-dataset/ssl_dataset"
CHECKPOINT_PATH = "/kaggle/input/mae_save/pytorch/default/1/mae_checkpoint_epoch9.pth"
VAL_FOLDER = "val.X"  # sub-folder of DATA_ROOT holding the evaluation images
LABELS_PATH = os.path.join(DATA_ROOT, "Labels.json")

# Hyperparameters: images per DataLoader batch, and side length of the square crop.
BATCH_SIZE, IMG_SIZE = 64, 224

# --- MAE Model Loading ---
class CustomMAE(nn.Module):
    """MAE container whose evaluation entry point is ``get_features``.

    The encoder's sub-modules (``patch_embed``, ``pos_embed``,
    ``encoder_blocks``, ``encoder_norm``) are stored under their own
    attribute names in addition to ``encoder`` itself.
    NOTE(review): presumably this mirrors the training-time model so that the
    checkpoint's state_dict keys line up — confirm against the training code.
    ``mask_token``, ``decoder`` and ``reconstruction_head`` are held but not
    used by ``get_features``.
    """

    def __init__(self, encoder, patch_embed, pos_embed, encoder_blocks,
                 encoder_norm, mask_token, decoder, reconstruction_head):
        super().__init__()
        self.encoder = encoder
        self.patch_embed = patch_embed
        self.pos_embed = pos_embed
        self.encoder_blocks = encoder_blocks
        self.encoder_norm = encoder_norm
        self.mask_token = mask_token
        self.decoder = decoder
        self.reconstruction_head = reconstruction_head

    def forward(self, x):
        """Not available in this evaluation script.

        Raises:
            NotImplementedError: always. Failing loudly prevents ``model(x)``
                from silently evaluating to ``None``.
        """
        raise NotImplementedError(
            "CustomMAE.forward (masked reconstruction) is not implemented in this "
            "evaluation script; call get_features(x) for encoder embeddings."
        )

    def get_features(self, x):
        """Encode ``x`` without masking and mean-pool the patch tokens.

        Args:
            x: Image batch accepted by ``patch_embed`` (presumably
                ``(B, 3, IMG_SIZE, IMG_SIZE)`` — TODO confirm). ``patch_embed``
                is assumed to return ``(B, N, D)`` patch tokens.

        Returns:
            ``(B, D)`` tensor: the average over all patch tokens of the
            normalised encoder output.
        """
        tokens = self.patch_embed(x)
        # pos_embed slot 0 belongs to a [CLS] token, which is not prepended here,
        # so only the patch positions are added.
        tokens = tokens + self.pos_embed[:, 1:, :]
        for block in self.encoder_blocks:
            tokens = block(tokens)
        tokens = self.encoder_norm(tokens)
        return tokens.mean(dim=1)  # global average pooling over patches

# Initialize model components
# pretrained=False: weights start random here and are overwritten by the
# checkpoint loaded at the end of this cell.
encoder = timm.create_model('vit_tiny_patch16_224', pretrained=False)
encoder.reset_classifier(0)  # Remove classification head

# Extract components
# These are references to the encoder's own sub-modules/parameters (the same
# objects), which CustomMAE then registers again under top-level names.
patch_embed = encoder.patch_embed
pos_embed = encoder.pos_embed
encoder_blocks = encoder.blocks
encoder_norm = encoder.norm
embed_dim = encoder.embed_dim  # encoder token width

# MAE-specific components
# NOTE(review): these settings presumably mirror the pre-training config;
# load_state_dict below is strict by default and fails on any parameter
# name/shape mismatch. None of them are used by get_features.
mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
decoder_layer = nn.TransformerEncoderLayer(
    d_model=embed_dim,
    nhead=4,
    dim_feedforward=embed_dim*2,
    batch_first=True
)
decoder = nn.TransformerEncoder(decoder_layer, num_layers=4)
reconstruction_head = nn.Linear(embed_dim, 16*16*3)  # For 16x16 patches

# Initialize model
model = CustomMAE(
    encoder=encoder,
    patch_embed=patch_embed,
    pos_embed=pos_embed,
    encoder_blocks=encoder_blocks,
    encoder_norm=encoder_norm,
    mask_token=mask_token,
    decoder=decoder,
    reconstruction_head=reconstruction_head
).to(device)

# Load checkpoint
# The checkpoint is a dict holding the weights under 'model_state_dict'
# (and, optionally, training histories read by plot_training_history).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference behaviour for dropout/normalisation layers

# --- Dataset Class ---
class MAEEvalDataset(Dataset):
    """Labelled images found under ``folder_path`` in per-class folders.

    Args:
        folder_path: Root directory, searched recursively.
        label_dict: Maps a class-folder name (an image's immediate parent
            directory) to a class name. Images in folders that are not keys
            are skipped.
        transform: Optional callable applied to each RGB PIL image.

    Class indices come from the sorted distinct values of ``label_dict``.
    Directories and files are visited in sorted order, so the sample order
    does not depend on the filesystem's listing order.
    """

    def __init__(self, folder_path, label_dict, transform):
        self.image_paths = []
        self.labels = []
        self.class_names = sorted(set(label_dict.values()))
        self.class_to_idx = {cls: i for i, cls in enumerate(self.class_names)}

        for root, dirs, files in os.walk(folder_path):
            # os.walk lists entries in arbitrary filesystem order; sorting
            # `dirs` in place also fixes the order of the descent.
            dirs.sort()
            class_folder = os.path.basename(root)
            if class_folder not in label_dict:
                continue  # still descends into this folder's subdirectories
            label = self.class_to_idx[label_dict[class_folder]]
            for file in sorted(files):
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.image_paths.append(os.path.join(root, file))
                    self.labels.append(label)

        self.transform = transform
        print(f"Found {len(self.image_paths)} labeled images across {len(self.class_names)} classes")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        If loading or transforming fails, the error is printed and a zero
        image with label -1 is returned. extract_features filters out
        label -1, so one unreadable file does not abort the pass.
        """
        path = self.image_paths[idx]
        try:
            # The context manager closes the file handle promptly; convert()
            # loads the pixel data first, so the result does not need the file.
            with Image.open(path) as raw:
                img = raw.convert('RGB')
            if self.transform:
                img = self.transform(img)
        except Exception as e:  # deliberate best-effort: keep evaluating other files
            print(f"Error loading {path}: {str(e)}")
            return torch.zeros(3, IMG_SIZE, IMG_SIZE), -1
        return img, self.labels[idx]

# --- Transformations ---
# Deterministic eval preprocessing:
#   - resize the shorter side to 256,
#   - centre-crop to IMG_SIZE,
#   - convert to a [0, 1] float tensor,
#   - normalise with the ImageNet channel mean/std.
# NOTE(review): assumes the MAE was pre-trained with this same normalisation — confirm.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# --- Load Labels and Dataset ---
# Labels.json is read as a dict keyed by class-folder name (the lookup key
# MAEEvalDataset uses). Explicit UTF-8 avoids locale-dependent decoding.
with open(LABELS_PATH, encoding='utf-8') as f:
    label_dict = json.load(f)

val_dataset = MAEEvalDataset(
    folder_path=os.path.join(DATA_ROOT, VAL_FOLDER),
    label_dict=label_dict,
    transform=val_transform
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep feature order aligned with val_dataset.image_paths
    num_workers=2,
    # Page-locked host memory only helps host-to-GPU copies. Without a GPU,
    # PyTorch ignores the flag and emits a warning.
    pin_memory=device.type == 'cuda',
)
print(f"Total validation images: {len(val_dataset)}")

# --- Feature Extraction ---
def extract_features(loader, encoder=None, target_device=None):
    """Encode every batch from ``loader`` and return ``(features, labels)`` on CPU.

    Args:
        loader: Iterable of ``(images, labels)`` batches, e.g. ``val_loader``.
        encoder: Module providing ``get_features``. Defaults to the global
            ``model``. It is switched to eval mode before use.
        target_device: Device the images are moved to. Defaults to the
            global ``device``.

    Returns:
        ``(features, labels)`` tensors. Samples whose label is -1 (images
        ``MAEEvalDataset`` failed to load) are removed.

    Raises:
        ValueError: if ``loader`` yields no batches. This replaces the opaque
            error ``torch.cat`` raises on an empty list.
    """
    if encoder is None:
        encoder = model
    if target_device is None:
        target_device = device
    encoder.eval()

    feature_batches, label_batches = [], []
    with torch.no_grad():
        for imgs, lbls in tqdm(loader, desc="Extracting features"):
            # With pinned host memory, non_blocking lets the copy run asynchronously.
            imgs = imgs.to(target_device, non_blocking=True)
            feature_batches.append(encoder.get_features(imgs).cpu())
            label_batches.append(torch.as_tensor(lbls))

    if not feature_batches:
        raise ValueError("extract_features: loader produced no batches; is the dataset empty?")

    features = torch.cat(feature_batches)
    labels = torch.cat(label_batches)

    # Drop the placeholders returned for unreadable images (label == -1).
    valid_idx = labels != -1
    return features[valid_idx], labels[valid_idx]

X_val, y_val = extract_features(val_loader)
print(f"Feature shape: {X_val.shape}, Labels shape: {y_val.shape}")

# --- Linear Probing ---
# A probe scored on the same samples it was fit on measures memorisation,
# not feature quality. Only the validation split is encoded here, so use
# stratified k-fold cross-validation: every entry of `val_preds` comes from a
# probe that never saw that sample. (Fitting on train-split features and
# scoring on val is the usual protocol when those features are available.)
X_np = X_val.numpy()
y_np = y_val.numpy()

# Use up to 5 folds, capped at the smallest class size so each class can
# appear in every fold. StratifiedKFold needs at least 2 folds.
class_counts = np.bincount(y_np)
n_splits = max(2, min(5, int(class_counts[class_counts > 0].min())))
cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

# Standardise features before the linear classifier. The MAE paper's linear
# probe likewise normalises features with a non-affine BatchNorm. lbfgs already
# fits a multinomial model for multiclass targets, so the deprecated
# `multi_class` argument is not passed.
clf = make_pipeline(
    StandardScaler(),
    LogisticRegression(max_iter=1000, solver='lbfgs', random_state=42),
)
val_preds = cross_val_predict(clf, X_np, y_np, cv=cv)

# Refit on every sample so `clf` remains a fitted probe for any later use.
clf.fit(X_np, y_np)

# --- Evaluation ---
acc = accuracy_score(y_np, val_preds)
f1 = f1_score(y_np, val_preds, average='macro')

print(f"\nEvaluation Results ({n_splits}-fold cross-validated linear probe):")
print(f"Accuracy: {acc*100:.2f}%")
print(f"Macro F1 Score: {f1*100:.2f}%")

# --- Confusion Matrix ---
# Rows are true classes and columns are predicted classes (sklearn convention).
true_labels = y_val.numpy()
cm = confusion_matrix(true_labels, val_preds)

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(cm, ax=ax, annot=False, fmt='d', cmap='Blues')
ax.set_title(f"Confusion Matrix (Accuracy: {acc*100:.2f}%)")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
fig.savefig("/kaggle/working/confusion_matrix.png", bbox_inches='tight')
plt.show()

# --- Training History Visualization ---
def plot_training_history(ckpt=None):
    """Plot the loss curves, and accuracy curves if recorded, stored in a checkpoint.

    Args:
        ckpt: Checkpoint dict to read from. Defaults to the module-level
            ``checkpoint`` loaded earlier in this notebook.

    Reads the optional keys ``train_loss_history``, ``val_loss_history``,
    ``train_acc_history`` and ``val_acc_history``.

    - Every non-empty loss series is drawn.
    - An accuracy panel is added only if an accuracy series exists.
    - Each series is plotted against its own length, so histories of
      different lengths do not break ``plot``.

    The figure is saved to ``/kaggle/working/training_history.png`` and shown.
    If no loss history is present, a message is printed instead.
    """
    if ckpt is None:
        ckpt = checkpoint

    def _present(pairs):
        # (label, values) for each key that exists and holds a non-empty series.
        curves = []
        for key, label in pairs:
            values = ckpt.get(key)
            if values is not None and len(values) > 0:
                curves.append((label, list(values)))
        return curves

    loss_curves = _present([
        ('train_loss_history', 'Training Loss'),
        ('val_loss_history', 'Validation Loss'),
    ])
    acc_curves = _present([
        ('train_acc_history', 'Training Accuracy'),
        ('val_acc_history', 'Validation Accuracy'),
    ])

    if not loss_curves:
        print("No training history found in checkpoint")
        return

    panels = [(loss_curves, 'Loss', 'Training and Validation Loss')]
    if acc_curves:
        panels.append((acc_curves, 'Accuracy', 'Training and Validation Accuracy'))

    # One panel per metric; no empty subplot when accuracy was never recorded.
    fig, axes = plt.subplots(1, len(panels), figsize=(6 * len(panels), 5), squeeze=False)
    for ax, (curves, ylabel, title) in zip(axes[0], panels):
        for label, values in curves:
            ax.plot(range(1, len(values) + 1), values, label=label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()

    fig.tight_layout()
    fig.savefig('/kaggle/working/training_history.png')
    plt.show()

plot_training_history()

# --- Save Results ---
# Write the headline metrics as plain text alongside the saved figures.
metric_lines = (
    f"Accuracy: {acc*100:.2f}%\n",
    f"Macro F1 Score: {f1*100:.2f}%\n",
)
with open('/kaggle/working/results.txt', 'w') as f:
    f.writelines(metric_lines)