# Momentum Contrast (MoCo) Evaluation

In [None]:
import os
import json
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms, models
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import Dataset, DataLoader
import seaborn as sns
import torch.nn as nn

# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input locations (Kaggle input mounts).
DATA_ROOT = "/kaggle/input/ssl-dataset/ssl_dataset"
CHECKPOINT_PATH = "/kaggle/input/moco-model-epoch/pytorch/default/1/checkpoint_epoch.pth"
VAL_FOLDER = "val.X"  # sub-directory of DATA_ROOT that holds the validation images
LABELS_PATH = os.path.join(DATA_ROOT, "Labels.json")  # folder-name -> class-name mapping

# Loader / preprocessing settings.
BATCH_SIZE = 64
IMG_SIZE = 224  # side length of the square center crop fed to the network


In [None]:

class MoCoBackbone(nn.Module):
    """Frozen ResNet-18 feature extractor initialised from a MoCo query encoder.

    A torchvision ResNet-18 (no pretrained weights) has its ``fc`` head replaced
    by ``nn.Identity`` and is loaded with the ``encoder_q`` tensors found in the
    checkpoint, so ``forward`` returns the backbone's pooled features. All
    parameters are frozen and ``forward`` runs under ``torch.no_grad()``.

    Args:
        checkpoint_path: File written by ``torch.save``; either a dict holding
            the MoCo state dict under ``'model_state_dict'`` or that state dict
            itself.

    Raises:
        RuntimeError: If none of the ResNet-18 backbone tensors could be loaded
            from the checkpoint (e.g. the key prefixes did not match). Without
            this check a key mismatch would silently leave the network randomly
            initialised and the evaluation would be meaningless.
    """

    def __init__(self, checkpoint_path):
        super().__init__()

        base_encoder = models.resnet18(pretrained=False)
        base_encoder.fc = nn.Identity()

        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        encoder_weights = self._query_encoder_weights(state_dict)

        # strict=False because the query encoder may carry head tensors (e.g.
        # fc.*) with no counterpart once fc is Identity. Instead of trusting
        # the load blindly, verify that the backbone itself was populated.
        missing, unexpected = base_encoder.load_state_dict(encoder_weights, strict=False)
        n_backbone_tensors = len(base_encoder.state_dict())
        if len(missing) == n_backbone_tensors:
            raise RuntimeError(
                f"No encoder_q backbone weights could be loaded from {checkpoint_path!r}; "
                "refusing to evaluate a randomly initialised ResNet-18."
            )
        if missing:
            print(f"Warning: {len(missing)} backbone tensors missing from checkpoint, "
                  f"e.g. {missing[:5]}")
        if unexpected:
            print(f"Ignored {len(unexpected)} encoder_q tensors with no backbone "
                  f"counterpart, e.g. {unexpected[:5]}")

        for param in base_encoder.parameters():
            param.requires_grad = False

        self.encoder = base_encoder

    @staticmethod
    def _query_encoder_weights(state_dict):
        """Return the ``encoder_q`` tensors of *state_dict*, keyed by the name
        that follows the first ``encoder_q.`` segment.

        Any wrapper prefix before ``encoder_q.`` (such as ``module.``) is
        discarded; entries that do not contain ``encoder_q.`` are dropped.
        """
        weights = {}
        for key, value in state_dict.items():
            _, sep, name = key.partition("encoder_q.")
            if sep:
                weights[name] = value
        return weights

    def forward(self, x):
        """Return backbone features for the batch *x* without tracking gradients."""
        with torch.no_grad():
            features = self.encoder(x)
        return features

# Build the frozen backbone, move it to the chosen device and put it in eval
# mode (nn.Module.to and nn.Module.eval both return the module itself).
model = MoCoBackbone(CHECKPOINT_PATH).to(device).eval()

class MoCoEvalDataset(Dataset):
    """Labelled images discovered from class-named sub-folders.

    Every ``.jpg``/``.jpeg``/``.png`` file (extension matched case-insensitively)
    found anywhere under ``folder_path`` is included when the name of its
    immediate parent directory is a key of ``label_dict``. Its label is the
    index of ``label_dict[parent_name]`` in ``class_names``, the sorted list of
    distinct values of ``label_dict``.

    Directories and files are visited in sorted order, so ``image_paths`` and
    ``labels`` come out in the same order on every run and machine
    (``os.walk`` itself makes no ordering guarantee).

    Args:
        folder_path: Root directory to scan recursively.
        label_dict: Mapping from folder name to class name.
        transform: Optional callable applied to each RGB PIL image.
    """

    def __init__(self, folder_path, label_dict, transform):
        self.image_paths = []
        self.labels = []
        self.class_names = sorted(set(label_dict.values()))
        self.class_to_idx = {cls: i for i, cls in enumerate(self.class_names)}

        for root, dirs, files in os.walk(folder_path):
            # Sorting dirs in place also fixes the order os.walk descends into them.
            dirs.sort()
            class_folder = os.path.basename(root)
            if class_folder not in label_dict:
                continue
            label = self.class_to_idx[label_dict[class_folder]]
            for file in sorted(files):
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.image_paths.append(os.path.join(root, file))
                    self.labels.append(label)

        self.transform = transform
        print(f"Found {len(self.image_paths)} labeled images across {len(self.class_names)} classes")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        If the file cannot be opened or transformed, the error is printed and a
        zero tensor of shape ``(3, IMG_SIZE, IMG_SIZE)`` is returned with label
        ``-1``, so the sample can be filtered out downstream instead of
        aborting the whole pass.
        """
        path = self.image_paths[idx]
        label = self.labels[idx]
        try:
            # The context manager releases the file handle once pixels are decoded.
            with Image.open(path) as src:
                img = src.convert('RGB')
            if self.transform:
                img = self.transform(img)
        except Exception as e:  # deliberate best-effort: one bad file must not stop the run
            print(f"Error loading {path}: {str(e)}")
            return torch.zeros(3, IMG_SIZE, IMG_SIZE), -1
        return img, label

# Deterministic evaluation preprocessing: resize the short side to 256, take a
# centre crop of IMG_SIZE, then normalise with the standard ImageNet mean/std.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


# JSON text is UTF-8; don't let the platform's locale encoding decide how
# non-ASCII class names are decoded.
with open(LABELS_PATH, encoding='utf-8') as f:
    label_dict = json.load(f)

val_dataset = MoCoEvalDataset(
    folder_path=os.path.join(DATA_ROOT, VAL_FOLDER),
    label_dict=label_dict,
    transform=val_transform
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU runs.
    pin_memory=device.type == 'cuda',
)
print(f"Total validation images: {len(val_dataset)}")

def extract_features(loader, *, encoder=None, target_device=None):
    """Embed every batch of *loader* and return the features with their labels.

    Args:
        loader: Iterable of ``(images, labels)`` batches (e.g. a DataLoader).
            Samples whose label is ``-1`` (the sentinel MoCoEvalDataset
            returns for unreadable images) are dropped from the result.
        encoder: Module used to embed each batch. Defaults to the
            module-level ``model``.
        target_device: Device the images are moved to before the forward
            pass. Defaults to the module-level ``device``.

    Returns:
        ``(features, labels)`` as CPU tensors, filtered to valid samples.

    Raises:
        ValueError: If *loader* yields no batches.
    """
    encoder = model if encoder is None else encoder
    target_device = device if target_device is None else target_device

    feature_batches, label_batches = [], []
    with torch.no_grad():
        for imgs, lbls in tqdm(loader, desc="Extracting features"):
            imgs = imgs.to(target_device)
            feature_batches.append(encoder(imgs).cpu())
            label_batches.append(lbls)

    # torch.cat on an empty list fails with an unhelpful message.
    if not feature_batches:
        raise ValueError("loader produced no batches; nothing to extract features from")

    features = torch.cat(feature_batches)
    labels = torch.cat(label_batches)
    valid_idx = labels != -1
    return features[valid_idx], labels[valid_idx]

X_val, y_val = extract_features(val_loader)
print(f"Feature shape: {X_val.shape}, Labels shape: {y_val.shape}")

# --- Linear Classifier ---
# A linear probe must be scored on samples it was not fit on; fitting and
# scoring on the same features reports training accuracy. Only the validation
# split is loaded here, so hold out a (stratified when possible) fraction of it
# for scoring and fit the probe on the rest.
from sklearn.model_selection import train_test_split

HOLDOUT_FRACTION = 0.2
X_all, y_all = X_val.numpy(), y_val.numpy()
_, class_counts = np.unique(y_all, return_counts=True)
n_holdout = int(np.ceil(HOLDOUT_FRACTION * len(y_all)))
# Stratification needs >= 2 samples per class and room for every class on both sides.
can_stratify = (
    class_counts.min() >= 2
    and min(n_holdout, len(y_all) - n_holdout) >= len(class_counts)
)
X_fit, X_eval, y_fit, y_eval = train_test_split(
    X_all,
    y_all,
    test_size=HOLDOUT_FRACTION,
    stratify=y_all if can_stratify else None,
    random_state=42,
)
print(f"Probe fit on {len(y_fit)} samples, evaluated on {len(y_eval)} held-out samples")

# With solver='lbfgs', scikit-learn already fits a multinomial (softmax) model
# for multiclass targets; the explicit multi_class argument is deprecated.
clf = LogisticRegression(
    max_iter=1000,
    solver='lbfgs',
    random_state=42
)
clf.fit(X_fit, y_fit)

val_probs = clf.predict_proba(X_eval)
# predict_proba columns are ordered by clf.classes_, which is not guaranteed to
# be 0..K-1 (e.g. a class with no images), so map column indices back to labels.
classes_tensor = torch.as_tensor(clf.classes_)
val_preds = clf.classes_[np.argmax(val_probs, axis=1)]

acc = accuracy_score(y_eval, val_preds)
f1 = f1_score(y_eval, val_preds, average='macro')


val_probs_tensor = torch.tensor(val_probs)
true_labels_tensor = torch.as_tensor(y_eval)

top1_preds = classes_tensor[torch.argmax(val_probs_tensor, dim=1)]
top1_correct = (top1_preds == true_labels_tensor).sum().item()
top1_accuracy = top1_correct / len(y_eval)

# topk needs k <= number of classes.
top_k = min(5, len(clf.classes_))
top5_preds = classes_tensor[torch.topk(val_probs_tensor, k=top_k, dim=1).indices]
top5_correct = torch.any(top5_preds == true_labels_tensor.unsqueeze(1), dim=1).sum().item()
top5_accuracy = top5_correct / len(y_eval)


print(f"\nEvaluation Results (MoCo):")
print(f"Accuracy (Top-1): {acc*100:.2f}%")
print(f"Macro F1 Score: {f1*100:.2f}%")
print(f"Top-1 Accuracy: {top1_accuracy*100:.2f}%")
print(f"Top-5 Accuracy: {top5_accuracy*100:.2f}%")


# NOTE(review): output pasted from an earlier run of this notebook. Confirm which
# evaluation protocol produced it. If the logistic-regression probe was fit and
# scored on the same validation features, these figures measure training-set fit,
# not held-out accuracy. Re-run and refresh them whenever the probe changes.
'''Evaluation Results (MoCo):
Accuracy (Top-1): 46.20%
Macro F1 Score: 45.58%
Top-1 Accuracy: 46.20%
Top-5 Accuracy: 73.96%'''