# SimSiam Evaluation

In [None]:
import os
import json
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import models
import matplotlib.pyplot as plt
from torchvision import transforms
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression


# Use the first CUDA device when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input locations (Kaggle datasets): the labelled image root and the SimSiam
# checkpoint (the filename suggests a ResNet-18 trained for 50 epochs).
DATA_ROOT = "/kaggle/input/ssl-dataset/ssl_dataset"
CHECKPOINT_PATH = "/kaggle/input/simsaimmodel/pytorch/default/1/simsiam_r18_epoch50.pth"
VAL_FOLDER = "val.X"  # sub-directory of DATA_ROOT containing the validation images
LABELS_PATH = os.path.join(DATA_ROOT, "Labels.json")  # JSON: image-folder name -> class name

BATCH_SIZE = 64  # images per DataLoader batch
IMG_SIZE = 96  # side length (pixels) of the square centre crop fed to the encoder


## Model, Dataset, and Feature Extraction

In [None]:
class SimSiamBackbone(nn.Module):
    """Frozen ResNet-18 encoder restored from a SimSiam training checkpoint.

    Only the encoder (ResNet-18 without its final ``fc`` layer) is rebuilt; any
    projector/predictor weights stored in the checkpoint are not needed for
    feature extraction and are ignored.

    Args:
        checkpoint_path: Path to a ``torch.save`` file holding either a dict with a
            ``'model_state'`` entry or a bare state dict. Encoder weights must be
            stored under keys prefixed with ``'encoder.'`` (optionally preceded by
            ``'module.'``).

    Raises:
        RuntimeError: If the checkpoint has no encoder weights, or they do not match
            the ResNet-18 encoder exactly (strict loading).
    """

    def __init__(self, checkpoint_path):
        super().__init__()
        # No-argument constructor = random init on both old (`pretrained`) and new
        # (`weights`) torchvision APIs; the real weights come from the checkpoint.
        base = models.resnet18()
        # All children except `fc`, i.e. up to the global average pool.
        self.encoder = nn.Sequential(*list(base.children())[:-1])

        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        encoder_state = self._encoder_state_dict(checkpoint)
        # strict=True: a partial or renamed checkpoint must fail loudly instead of
        # silently leaving randomly initialised layers in the evaluated encoder.
        self.encoder.load_state_dict(encoder_state, strict=True)
        print(f"Loaded {len(encoder_state)} encoder tensors from {checkpoint_path}")

        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False

    @staticmethod
    def _encoder_state_dict(checkpoint):
        """Return the encoder weights of ``checkpoint`` keyed relative to the encoder.

        Accepts ``{'model_state': state_dict, ...}`` or a bare state dict. A leading
        ``'module.'`` (as written by ``nn.DataParallel``) is stripped, and only keys
        under ``'encoder.'`` are kept, with that prefix removed.

        Raises:
            RuntimeError: If no ``'encoder.*'`` key is present.
        """
        state_dict = checkpoint['model_state'] if 'model_state' in checkpoint else checkpoint
        encoder_state = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                key = key[len("module."):]
            if key.startswith("encoder."):
                encoder_state[key[len("encoder."):]] = value
        if not encoder_state:
            raise RuntimeError("checkpoint contains no 'encoder.*' weights")
        return encoder_state

    def forward(self, x):
        """Return flattened encoder features, one row per image in batch ``x``.

        ``x`` is presumably a normalised (batch, 3, H, W) image tensor, as produced
        by the evaluation transform — confirm for other callers.
        """
        with torch.no_grad():
            feat = self.encoder(x)
            feat = torch.flatten(feat, 1)
        return feat

class SimSiamEvalDataset(Dataset):
    """Labelled images found under ``folder_path``, labelled by their parent folder.

    Args:
        folder_path: Root directory scanned recursively for .jpg/.jpeg/.png files.
        label_dict: Maps a parent-folder name to a class name; images in folders
            that are not keys of this mapping are skipped.
        transform: Optional callable applied to each RGB PIL image.

    Class indices follow the sorted order of the distinct class names in
    ``label_dict``. Directories and files are visited in sorted order, so the
    sample order is reproducible across filesystems.
    """

    def __init__(self, folder_path, label_dict, transform):
        self.image_paths = []
        self.labels = []
        self.class_names = sorted(set(label_dict.values()))
        self.class_to_idx = {cls: i for i, cls in enumerate(self.class_names)}

        for root, dirs, files in os.walk(folder_path):
            # os.walk yields entries in arbitrary filesystem order; sorting `dirs`
            # in place also fixes the order in which sub-directories are visited.
            dirs.sort()
            class_folder = os.path.basename(root)
            if class_folder not in label_dict:
                continue
            label = self.class_to_idx[label_dict[class_folder]]
            for file in sorted(files):
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.image_paths.append(os.path.join(root, file))
                    self.labels.append(label)

        self.transform = transform
        print(f"Found {len(self.image_paths)} labeled images across {len(self.class_names)} classes")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, class_index)``.

        If loading or transforming fails, return a zero image of size
        IMG_SIZE x IMG_SIZE with label -1 so the consumer can filter it out.
        """
        path = self.image_paths[idx]
        try:
            # Context manager closes the underlying file; convert() returns a new image.
            with Image.open(path) as im:
                img = im.convert('RGB')
            if self.transform:
                img = self.transform(img)
        except Exception as e:
            # Deliberate best-effort: one corrupt file must not abort the evaluation.
            print(f"Error loading {path}: {str(e)}")
            return torch.zeros(3, IMG_SIZE, IMG_SIZE), -1
        return img, self.labels[idx]


# Deterministic evaluation preprocessing: resize the short side to 128, take the
# central IMG_SIZE crop, then normalise with the ImageNet channel mean/std.
val_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Explicit encoding: JSON is UTF-8 and the platform default may differ.
with open(LABELS_PATH, encoding="utf-8") as f:
    label_dict = json.load(f)

val_dataset = SimSiamEvalDataset(
    folder_path=os.path.join(DATA_ROOT, VAL_FOLDER),
    label_dict=label_dict,
    transform=val_transform
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep sample order stable so features line up run to run
    num_workers=2,
    # Pinned host memory only speeds up host->GPU copies; without CUDA it has no benefit.
    pin_memory=device.type == "cuda",
)

model = SimSiamBackbone(CHECKPOINT_PATH).to(device)
# Inference only: make sure every submodule (BatchNorm in particular) is in eval mode.
model.eval()

def extract_features(loader, backbone=None, target_device=None):
    """Embed every batch from ``loader`` with a frozen backbone.

    Args:
        loader: Iterable of ``(images, labels)`` batches, e.g. ``val_loader``.
        backbone: Feature extractor to call on each image batch; defaults to the
            module-level ``model``.
        target_device: Device the images are moved to; defaults to the module-level
            ``device``.

    Returns:
        ``(features, labels)`` CPU tensors. Samples labelled -1 (images the dataset
        failed to load) are removed.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    backbone = model if backbone is None else backbone
    target_device = device if target_device is None else target_device

    feature_batches, label_batches = [], []
    with torch.no_grad():
        for imgs, lbls in tqdm(loader, desc="Extracting features (SimSiam)"):
            feats = backbone(imgs.to(target_device, non_blocking=True)).cpu()
            feature_batches.append(feats)
            label_batches.append(torch.as_tensor(lbls))

    if not feature_batches:
        # torch.cat([]) would otherwise fail with an unhelpful message.
        raise ValueError("extract_features: loader yielded no batches")

    features = torch.cat(feature_batches)
    labels = torch.cat(label_batches)
    keep = labels != -1  # drop placeholders for unreadable images
    return features[keep], labels[keep]

# Embed the whole validation set; unreadable images (label -1) are dropped by
# extract_features, so X_val may have fewer rows than the dataset has files.
X_val, y_val = extract_features(val_loader)
n_val_images = len(val_dataset)
print(f"Total validation images: {n_val_images}")
print(f"Feature shape: {X_val.shape}, Labels shape: {y_val.shape}")

## Linear Classifier and Evaluation

In [None]:
from sklearn.model_selection import train_test_split

# Fraction of the labelled features held out to score the linear probe.
# 0 keeps the original protocol: the probe is fit and scored on the SAME samples,
# so the numbers below are training accuracy (an optimistic measure of linear
# separability, not a generalisation estimate). Set e.g. 0.2 for a held-out
# score; val_preds / y_eval then cover only the held-out subset.
PROBE_TEST_SIZE = 0.0

X_np = X_val.numpy()
y_np = y_val.numpy()
if PROBE_TEST_SIZE:
    try:
        X_fit, X_eval, y_fit, y_eval = train_test_split(
            X_np, y_np, test_size=PROBE_TEST_SIZE, stratify=y_np, random_state=42
        )
    except ValueError:
        # Some class is too small to stratify; fall back to a plain random split.
        X_fit, X_eval, y_fit, y_eval = train_test_split(
            X_np, y_np, test_size=PROBE_TEST_SIZE, random_state=42
        )
else:
    X_fit, X_eval, y_fit, y_eval = X_np, X_np, y_np, y_np

# Standardise with statistics of the fitting data and apply the SAME transform to
# the data being scored. The classifier is trained on scaled features, so scoring
# raw features (as the previous version did) gives meaningless probabilities.
scaler = StandardScaler()
X_fit_scaled = scaler.fit_transform(X_fit)
X_eval_scaled = scaler.transform(X_eval)

# `multi_class` is omitted: it is deprecated in recent scikit-learn. With the saga
# solver and more than two classes, the default already fits a multinomial model.
clf = LogisticRegression(
    max_iter=5000,
    solver='saga',
    random_state=42
)
clf.fit(X_fit_scaled, y_fit)

val_probs = clf.predict_proba(X_eval_scaled)
# predict_proba columns are ordered like clf.classes_. That ordering is not
# 0..K-1 when a class is absent from the fitting data, so map column ranks back
# to real labels. A stable sort of -p keeps argmax's tie-breaking (lowest column first).
ranked_labels = clf.classes_[np.argsort(-val_probs, axis=1, kind='stable')]
val_preds = ranked_labels[:, 0]

acc = accuracy_score(y_eval, val_preds)
f1 = f1_score(y_eval, val_preds, average='macro')

top1_accuracy = float(np.mean(val_preds == y_eval))
k = min(5, ranked_labels.shape[1])  # cannot rank 5 guesses with fewer than 5 classes
top5_accuracy = float(np.mean(np.any(ranked_labels[:, :k] == y_eval[:, None], axis=1)))

print(f"\nEvaluation Results (SimSiam):")
print(f"Accuracy (Top-1): {acc*100:.2f}%")
print(f"Macro F1 Score: {f1*100:.2f}%")
print(f"Top-1 Accuracy: {top1_accuracy*100:.2f}%")
print(f"Top-{k} Accuracy: {top5_accuracy*100:.2f}%")