---
# Task 7 - SimSiam (second self-supervised learning method)
## Shourav Deb [2021-3-60-274]
---

## Cell 1: Ablation - varying the number of pretraining epochs (100, 200, 400)

In [1]:
import os, time, json, shutil
import torch
from torch import optim
from tqdm import tqdm


# ---- Ablation schedule ---------------------------------------------------------
ablation_epochs = [100, 200, 400]   # pretraining budgets (epochs) to compare
save_interval = 20                  # write a full checkpoint every N epochs
zip_every_k_backups = 1             # controls ZIP backups of each run dir; 0 disables them
zip_output_name_template = "simsiam_ablation_epochs_{E}.zip"

# ---- Optimizer hyperparameters (values defined by an earlier cell take priority)
learning_rate = globals().get("learning_rate", 0.05)
momentum = globals().get("momentum", 0.9)
weight_decay = globals().get("weight_decay", 1e-4)

# ---- Objects that earlier cells must have defined -------------------------------
for _required_name, _missing_msg in (
    ("train_loader", "train_loader not found"),
    ("SimSiam", "SimSiam class not defined"),
):
    if _required_name not in globals():
        raise RuntimeError(_missing_msg)

# ---- Output location and runtime settings --------------------------------------
OUT_DIR = globals().get("OUT_DIR", "/kaggle/working/simsiam_task4")
os.makedirs(OUT_DIR, exist_ok=True)

if torch.cuda.is_available():
    _default_device = "cuda"
else:
    _default_device = "cpu"
device = globals().get("DEVICE", _default_device)
BACKBONE = globals().get("BACKBONE", "resnet18")

# Negative cosine similarity with stop-gradient on the target (SimSiam's D(p, z)).
def negative_cosine_similarity(p, z):
    """Return the negative mean cosine similarity between ``p`` and ``z``.

    ``z`` is detached first, so gradients flow only through ``p``. Both inputs
    are L2-normalised along dim 1 before the row-wise dot product; the result
    is the negated mean over dim 0.
    """
    target = torch.nn.functional.normalize(z.detach(), dim=1)
    pred = torch.nn.functional.normalize(p, dim=1)
    row_cosine = (pred * target).sum(dim=1)
    return row_cosine.mean().neg()

# Save a resumable checkpoint plus an encoder-only export for one epoch.
def save_full_checkpoint(model, optimizer, scheduler, run_dir, epoch_num, avg_loss):
    """Write ``<run_dir>/epoch_<NNN>/`` with checkpoint.pth, encoder.pth and metadata.json.

    Args:
        model: the training module; its full ``state_dict`` goes into
            checkpoint.pth and ``model.encoder``'s state dict into encoder.pth.
        optimizer: optional; its state dict is stored when not None.
        scheduler: optional; its state dict is stored when not None.
        run_dir: directory of the current ablation run.
        epoch_num: epoch number used in the folder name (zero-padded to 3 digits).
        avg_loss: mean training loss of that epoch.

    Returns:
        Path of the written checkpoint.pth.

    The encoder export is best-effort. If it fails (e.g. the model has no
    ``encoder`` attribute), a warning is printed instead of raising, because
    the full checkpoint has already been written. ``feat_dim`` is stored as
    None when the model does not expose it.
    """
    epoch_dir = os.path.join(run_dir, f"epoch_{epoch_num:03d}")
    os.makedirs(epoch_dir, exist_ok=True)

    ck = {
        "epoch": int(epoch_num),
        "timestamp": time.time(),
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict() if optimizer is not None else None,
        "scheduler_state": scheduler.state_dict() if scheduler is not None else None,
        "avg_loss": float(avg_loss),
        "manifest": globals().get("split_manifest", None),
    }
    ck_path = os.path.join(epoch_dir, "checkpoint.pth")
    torch.save(ck, ck_path)

    enc_path = os.path.join(epoch_dir, "encoder.pth")
    try:
        torch.save(
            {
                "encoder_state_dict": model.encoder.state_dict(),
                "feat_dim": getattr(model, "feat_dim", None),
            },
            enc_path,
        )
    except Exception as e:
        # Previously swallowed silently; a missing encoder.pth should be visible.
        print(f"Warning: could not write encoder.pth for epoch {epoch_num} ->", e)

    meta = {"epoch": int(epoch_num), "avg_loss": float(avg_loss), "saved_at": time.time()}
    with open(os.path.join(epoch_dir, "metadata.json"), "w") as f:
        json.dump(meta, f)
    return ck_path

# Main ablation loop. For every pretraining budget E, this:
#   * trains a fresh SimSiam model for E epochs,
#   * resumes from the newest saved checkpoint of that run if one exists,
#   * checkpoints every `save_interval` epochs and at the final epoch,
#   * keeps a ZIP backup of the run directory up to date.
def _zip_run_dir(src_dir, zip_path):
    """Replace ``zip_path`` with a fresh ZIP archive of ``src_dir``'s contents."""
    if os.path.exists(zip_path):
        os.remove(zip_path)
    shutil.make_archive(base_name=os.path.splitext(zip_path)[0], format="zip", root_dir=src_dir)


for E in ablation_epochs:
    run_dir = os.path.join(OUT_DIR, f"ablation_epochs_{E}")
    os.makedirs(run_dir, exist_ok=True)
    zip_out_path = os.path.join("/kaggle/working", zip_output_name_template.format(E=E))

    print(f"\n=== Ablation config: {E} epochs | saving every {save_interval} epochs | run_dir: {run_dir} ===")

    model_ab = SimSiam(backbone=BACKBONE).to(device)
    opt_ab = optim.SGD(model_ab.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    sched_ab = torch.optim.lr_scheduler.CosineAnnealingLR(opt_ab, T_max=E)

    # Discover previously saved "epoch_NNN" folders of this run (run_dir exists here).
    existing_epochs = []
    for name in os.listdir(run_dir):
        if name.startswith("epoch_"):
            try:
                existing_epochs.append(int(name.split("_")[1]))
            except ValueError:
                pass
    last_saved = max(existing_epochs) if existing_epochs else 0
    start_epoch = 1

    if last_saved > 0:
        ckpt_path = os.path.join(run_dir, f"epoch_{last_saved:03d}", "checkpoint.pth")
        if os.path.exists(ckpt_path):
            print(f"Resuming from saved checkpoint at epoch {last_saved}: {ckpt_path}")
            # Locally written checkpoint that also stores non-tensor metadata.
            ck = torch.load(ckpt_path, map_location=device, weights_only=False)
            try:
                model_ab.load_state_dict(ck["model_state"])
            except Exception as e:
                print("Warning: model_state load partial/failed ->", e)
            try:
                if ck.get("optimizer_state") is not None:
                    opt_ab.load_state_dict(ck["optimizer_state"])
            except Exception as e:
                print("Warning: optimizer_state load failed ->", e)
            try:
                if ck.get("scheduler_state") is not None:
                    sched_ab.load_state_dict(ck["scheduler_state"])
            except Exception as e:
                # Without its state the cosine schedule restarts from the initial LR.
                print("Warning: scheduler_state load failed ->", e)
            start_epoch = last_saved + 1
            print(f"Resumed. Next epoch will be {start_epoch} (1-based) out of {E}.")
            print(f"Note: last saved epoch = {last_saved}. Progress after that may be lost if run stopped earlier than next save point.")
        else:
            print("Found epoch folders but checkpoint.pth missing.")
            # Nothing was resumed, so no epoch of *this* training run is saved yet.
            last_saved = 0

    last_completed = start_epoch - 1   # newest epoch whose training fully finished
    last_avg_loss = 0.0                # mean loss of `last_completed`
    save_points_reached = 0            # counts save points; drives zip_every_k_backups

    try:
        for epoch in range(start_epoch, E + 1):
            model_ab.train()
            running_loss = 0.0
            n_steps = 0

            loop = tqdm(train_loader, desc=f"Ablation E={E} Epoch {epoch}/{E}", leave=False)
            for x1, x2, _, _ in loop:
                x1 = x1.to(device)
                x2 = x2.to(device)

                p1, p2, z1, z2 = model_ab(x1, x2)
                # Symmetrised SimSiam loss.
                loss = 0.5 * negative_cosine_similarity(p1, z2) + 0.5 * negative_cosine_similarity(p2, z1)

                opt_ab.zero_grad()
                loss.backward()
                opt_ab.step()

                step_loss = loss.item()
                running_loss += step_loss
                n_steps += 1
                loop.set_postfix(loss=f"{step_loss:.4f}")

            avg_loss = (running_loss / n_steps) if n_steps > 0 else 0.0
            sched_ab.step()
            last_completed, last_avg_loss = epoch, avg_loss
            print(f"Epoch {epoch}/{E} completed - Avg Loss: {avg_loss:.4f}")

            if (epoch % save_interval == 0) or (epoch == E):
                try:
                    ck_path = save_full_checkpoint(model_ab, opt_ab, sched_ab, run_dir, epoch, avg_loss)
                    last_saved = epoch
                    print("Saved full checkpoint for epoch:", epoch, "->", ck_path)
                except Exception as e:
                    print("Warning: could not save checkpoint:", e)

                save_points_reached += 1
                if zip_every_k_backups and zip_every_k_backups > 0 and save_points_reached % zip_every_k_backups == 0:
                    try:
                        _zip_run_dir(run_dir, zip_out_path)
                        print("Saved ZIP backup to:", zip_out_path)
                    except Exception as e:
                        print("Warning: ZIP backup failed ->", e)

    except KeyboardInterrupt:
        print("KeyboardInterrupt caught - attempting to save resume checkpoint.")
        # Only a *completed* epoch may be checkpointed. The interrupted epoch's
        # weights are mid-update and its scheduler step has not run, so saving
        # it under its own number would make a resume skip the rest of it.
        is_save_point = last_completed > 0 and (last_completed % save_interval == 0 or last_completed == E)
        if is_save_point and last_completed > last_saved:
            try:
                save_full_checkpoint(model_ab, opt_ab, sched_ab, run_dir, last_completed, last_avg_loss)
                print("Saved checkpoint for last completed epoch:", last_completed)
            except Exception as e:
                print("Could not save on interrupt:", e)
        else:
            print(f"Last completed epoch {last_completed} needs no extra save. Last permanent save remains epoch {last_saved}.")
        raise

    # Refresh the ZIP once more so it reflects the final state of this run.
    if zip_every_k_backups and zip_every_k_backups > 0:
        try:
            _zip_run_dir(run_dir, zip_out_path)
            print("Completed ablation run. Final ZIP:", zip_out_path)
        except Exception as e:
            print("Could not create final ZIP:", e)

print("\nAll ablation configs processed.")



=== Ablation config: 100 epochs | saving every 20 epochs | run_dir: /kaggle/working/simsiam_task4/ablation_epochs_100 ===

Epoch 1/100 completed - Avg Loss: -0.3544
                                                                                         
Epoch 2/100 completed - Avg Loss: -0.6602
                                                                                         
Epoch 3/100 completed - Avg Loss: -0.8142
                                                                                         
Epoch 4/100 completed - Avg Loss: -0.8335
                                                                                         
Epoch 5/100 completed - Avg Loss: -0.8616
                                                                                         
Epoch 6/100 completed - Avg Loss: -0.8948
                                                                                         
Epoch 7/100 completed - Avg Loss: -0.8459
                                          

## Cell 2: Ablation evaluation of the saved checkpoints (epoch_020 to epoch_100)

In [3]:
import os, json, time, shutil, glob
from pathlib import Path
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms, models
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_auc_score, roc_curve
from sklearn.metrics import precision_recall_fscore_support, silhouette_score
from sklearn.preprocessing import label_binarize
import joblib
import matplotlib.pyplot as plt
import seaborn as sns


# ---- Evaluation configuration --------------------------------------------------
BACKBONE = "resnet18"
RESOLUTION = 224
BATCH_SIZE = 64
NUM_WORKERS = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUT_ROOT = "/kaggle/working/task7_ablation"
os.makedirs(OUT_ROOT, exist_ok=True)

SPLIT_MANIFEST = "/kaggle/working/corrected_split_manifest.json"
# Attached Kaggle datasets are mounted under /kaggle/input; the epoch dirs in
# this cell's recorded output live there. A /kaggle/output path makes the
# epoch_* glob match nothing, and no checkpoint would be evaluated.
ABLATION_BASE = "/kaggle/input/all-files/All Files/simsiam_ablation_epochs_100"

# Probes fitted on frozen encoder features. Fixed seeds make the scores
# comparable across checkpoints: differences then come from the encoder, not
# from probe initialisation or SVC's probability-calibration shuffling.
SEED = 42
CLASSIFIERS = {
    "LogisticRegression": LogisticRegression(max_iter=2000),
    "SVM_RBF": SVC(kernel="rbf", probability=True, random_state=SEED),
    "RandomForest": RandomForestClassifier(n_estimators=100, random_state=SEED),
    "DecisionTree": DecisionTreeClassifier(random_state=SEED),
    "MLP": MLPClassifier(hidden_layer_sizes=(512,), max_iter=500, random_state=SEED),
}

DO_FINETUNE = False          # currently not referenced in this cell
RUN_LABEL_EFFICIENCY = True  # also fit LogisticRegression on 1%..100% of the train features


def build_encoder(backbone="resnet18"):
    """Return a randomly initialised torchvision ResNet trunk without its classifier.

    Every child module except the final ``fc`` layer is kept in an
    ``nn.Sequential``. ``encoder.feat_dim`` records the width of the features
    that fed that ``fc`` (512 for resnet18, 2048 for resnet50). Any other
    backbone name raises ValueError.
    """
    if backbone == "resnet18":
        ctor = models.resnet18
    elif backbone == "resnet50":
        ctor = models.resnet50
    else:
        raise ValueError("Unsupported backbone")

    trunk = ctor(weights=None)
    pooled_width = trunk.fc.in_features  # input width of the dropped classifier
    encoder = nn.Sequential(*list(trunk.children())[:-1])
    encoder.feat_dim = pooled_width
    return encoder

def try_load_encoder(encoder, path):
    """Load weights from the checkpoint at ``path`` into ``encoder`` (in place).

    The file may hold a bare state dict, or a dict wrapping one under the first
    present key among "encoder_state_dict", "encoder", "model_state", "model"
    and "state_dict". A pickled module found there is reduced to its state dict.

    A strict load is tried first. If it fails, keys are remapped and the load
    is retried:

    * "module.encoder." / "encoder." prefixes are stripped. When any key
      carries one of them, all other keys (e.g. projector/predictor heads of a
      full-model checkpoint) are dropped.
    * Otherwise a leading "module." (DataParallel wrapper) is stripped.

    Returns:
        True if the weights were loaded, False otherwise (the error is printed).
    """
    ck = torch.load(path, map_location="cpu")

    state = ck
    if isinstance(ck, dict):
        for key in ["encoder_state_dict", "encoder", "model_state", "model", "state_dict"]:
            if key in ck:
                state = ck[key]
                break
    if isinstance(state, nn.Module):
        state = state.state_dict()

    try:
        encoder.load_state_dict(state)
        return True
    except Exception:
        pass  # fall through to prefix remapping

    try:
        enc_prefixes = ("module.encoder.", "encoder.")
        has_enc_prefix = any(k.startswith(enc_prefixes) for k in state.keys())
        mapped = {}
        for k, v in state.items():
            if k.startswith("module.encoder."):
                mapped[k[len("module.encoder."):]] = v
            elif k.startswith("encoder."):
                mapped[k[len("encoder."):]] = v
            elif not has_enc_prefix:
                mapped[k[len("module."):] if k.startswith("module.") else k] = v
        encoder.load_state_dict(mapped)
        return True
    except Exception as e:
        print("Failed to load encoder weights from", path, "error:", e)
        return False


class ManifestDataset(Dataset):
    """Map-style dataset over image file paths listed in the split manifest.

    Each item is ``(image, label, path)``. The file is opened with PIL,
    converted to RGB, optionally passed through ``transform``, and returned
    together with its label and source path.
    """

    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        lbl = self.labels[idx]
        # The context manager releases the file handle even for multi-frame
        # formats, which PIL otherwise keeps open after loading; convert()
        # returns an independent in-memory copy.
        with Image.open(p) as im:
            img = im.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, lbl, p

# Deterministic preprocessing for feature extraction:
#   1. resize the shorter side to int(1.1 * RESOLUTION),
#   2. center-crop to RESOLUTION x RESOLUTION,
#   3. convert to a tensor,
#   4. normalize with the standard ImageNet per-channel mean/std.
eval_transform = transforms.Compose(
    [
        transforms.Resize(int(RESOLUTION * 1.1)),
        transforms.CenterCrop(RESOLUTION),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def extract_features(encoder, paths, batch_size=64, workers=2, save_path=None):
    """Embed every image in ``paths`` with the frozen ``encoder``.

    Behaviour:
    * Images go through ``eval_transform`` in order (no shuffling).
    * The encoder runs in eval mode under ``torch.no_grad()`` on ``DEVICE``.
    * Each output is flattened to one row.

    Returns:
        A NumPy array with one row per path, in the order of ``paths``. It is
        also written with ``np.save`` to ``save_path`` when one is given.

    An empty ``paths`` yields a ``(0, encoder.feat_dim)`` float32 array
    (width 0 if the encoder has no ``feat_dim``). Without this guard,
    ``np.vstack`` would fail on an empty list.
    """
    if len(paths) == 0:
        feats = np.zeros((0, getattr(encoder, "feat_dim", 0)), dtype=np.float32)
    else:
        # Labels are unused for extraction, so dummy zeros are passed.
        ds = ManifestDataset(paths, [0] * len(paths), transform=eval_transform)
        loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=workers)
        encoder.eval()
        enc = encoder.to(DEVICE)
        chunks = []
        with torch.no_grad():
            for imgs, _, _ in loader:
                imgs = imgs.to(DEVICE)
                chunks.append(enc(imgs).view(imgs.size(0), -1).cpu().numpy())
        feats = np.vstack(chunks)
    if save_path:
        np.save(save_path, feats)
    return feats

def _probe_aucs(clf, feats, labels):
    """Return ``(per_class_auc, macro_auc, micro_auc)`` for a fitted classifier.

    ``predict_proba`` columns follow ``clf.classes_``, so the labels are
    binarised against that same class list. For two classes ``label_binarize``
    returns only the positive-class column, so the negative column is prepended
    to get one column per probability output. Per-class AUCs that are
    undefined (e.g. the class is absent from ``labels``) become None. All three
    values are None when the classifier has no ``predict_proba`` or
    binarisation fails.
    """
    if not hasattr(clf, "predict_proba"):
        return None, None, None
    try:
        probs = clf.predict_proba(feats)
        class_list = list(clf.classes_)
        y_bin = label_binarize(labels, classes=class_list)
        if y_bin.shape[1] == 1:
            y_bin = np.hstack([1 - y_bin, y_bin])
    except Exception:
        return None, None, None

    aucs = {}
    for i, c in enumerate(class_list):
        try:
            aucs[int(c)] = float(roc_auc_score(y_bin[:, i], probs[:, i]))
        except Exception:
            aucs[int(c)] = None
    try:
        macro = float(roc_auc_score(y_bin, probs, average="macro"))
        micro = float(roc_auc_score(y_bin, probs, average="micro"))
    except Exception:
        macro = micro = None
    return aucs, macro, micro


def train_and_eval_probes(train_feats, train_labels, val_feats, val_labels, test_feats, test_labels, out_prefix):
    """Fit every probe in ``CLASSIFIERS`` on the train features and score it on the test split.

    For each classifier:
    * it is deep-copied, so the shared template is never fitted;
    * it is trained on ``train_feats`` / ``train_labels``;
    * it is saved to ``f"{out_prefix}_{name}.joblib"``;
    * it is evaluated on ``test_feats``.

    ``val_feats`` / ``val_labels`` are accepted for interface compatibility
    but are not used.

    Returns:
        dict mapping probe name -> {"accuracy", "per_class_prec", "per_class_rec",
        "per_class_f1", "per_class_support", "per_class_auc", "macro_auc",
        "micro_auc"}. AUC entries may be None (see ``_probe_aucs``).
    """
    import copy

    results = {}
    for name, clf in CLASSIFIERS.items():
        print("Training probe:", name)
        clf_local = copy.deepcopy(clf)
        clf_local.fit(train_feats, train_labels)
        joblib.dump(clf_local, out_prefix + f"_{name}.joblib")

        y_pred = clf_local.predict(test_feats)
        acc = accuracy_score(test_labels, y_pred)
        prec, rec, f1, supp = precision_recall_fscore_support(test_labels, y_pred, average=None, zero_division=0)
        aucs, macro, micro = _probe_aucs(clf_local, test_feats, test_labels)

        results[name] = {
            "accuracy": float(acc),
            "per_class_prec": prec.tolist(),
            "per_class_rec": rec.tolist(),
            "per_class_f1": f1.tolist(),
            "per_class_support": supp.tolist(),
            "per_class_auc": aucs,
            "macro_auc": macro,
            "micro_auc": micro,
        }
    return results


# Load the fixed split manifest: image paths and labels (integer class
# indices) per split, plus optional class names.
with open(SPLIT_MANIFEST, "r") as f:
    split = json.load(f)

train_paths, train_labels = split["train"], split["train_labels"]
val_paths, val_labels = split["val"], split["val_labels"]
test_paths, test_labels = split["test"], split["test_labels"]

classes = split.get("classes", None)
if classes is None:
    # No names in the manifest: use "0", "1", ... up to the largest train label.
    classes = list(map(str, range(max(train_labels) + 1)))


# Evaluate every saved checkpoint of the ablation run. For each epoch dir:
#   * extract (or reuse cached) frozen features,
#   * compute a silhouette score,
#   * fit the probe suite,
#   * run a label-efficiency sweep,
#   * write per-epoch JSON files and confusion matrices.
epoch_dirs = sorted(glob.glob(os.path.join(ABLATION_BASE, "epoch_*")))
print("Found ablation epoch dirs:", epoch_dirs)
if not epoch_dirs:
    print("Warning: no epoch_* directories found under", ABLATION_BASE, "- nothing to evaluate.")

ablation_summary = []
for ed in epoch_dirs:
    try:
        epoch_name = os.path.basename(ed)
        enc_path = os.path.join(ed, "encoder.pth")
        if not os.path.exists(enc_path):
            print("No encoder.pth in", ed, "skipping")
            continue
        print("Processing", epoch_name)
        out_dir = os.path.join(OUT_ROOT, epoch_name)
        os.makedirs(out_dir, exist_ok=True)

        encoder = build_encoder(BACKBONE)
        if not try_load_encoder(encoder, enc_path):
            print("Failed to load encoder for", epoch_name)
            continue

        # Features are cached per checkpoint so reruns skip the encoder pass.
        train_feat_path = os.path.join(out_dir, "train_feats.npy")
        val_feat_path = os.path.join(out_dir, "val_feats.npy")
        test_feat_path = os.path.join(out_dir, "test_feats.npy")

        if not (os.path.exists(train_feat_path) and os.path.exists(val_feat_path) and os.path.exists(test_feat_path)):
            print("Extracting features for", epoch_name)
            tr_feats = extract_features(encoder, train_paths, batch_size=BATCH_SIZE, workers=NUM_WORKERS, save_path=train_feat_path)
            v_feats = extract_features(encoder, val_paths, batch_size=BATCH_SIZE, workers=NUM_WORKERS, save_path=val_feat_path)
            te_feats = extract_features(encoder, test_paths, batch_size=BATCH_SIZE, workers=NUM_WORKERS, save_path=test_feat_path)
        else:
            print("Loading cached features for", epoch_name)
            tr_feats = np.load(train_feat_path)
            v_feats = np.load(val_feat_path)
            te_feats = np.load(test_feat_path)

        # Silhouette of the pooled train+val+test features under the true labels.
        try:
            feats_all = np.vstack([tr_feats, v_feats, te_feats])
            lbls_all = np.array(train_labels + val_labels + test_labels)
            sil = silhouette_score(feats_all, lbls_all) if len(np.unique(lbls_all)) > 1 else None
        except Exception as e:
            print("Silhouette failed:", e)
            sil = None

        probe_results = train_and_eval_probes(tr_feats, train_labels, v_feats, val_labels, te_feats, test_labels, out_prefix=os.path.join(out_dir, "probe"))

        label_eff = {}
        if RUN_LABEL_EFFICIENCY:
            y_train_arr = np.array(train_labels)
            total = tr_feats.shape[0]
            for frac in [0.01, 0.05, 0.10, 0.25, 0.50, 1.0]:
                key = f"{int(frac*100)}%"
                n = max(1, int(total * frac))
                # Re-seeded per fraction, so every checkpoint sees the same subsets.
                rng = np.random.RandomState(42)
                sel = rng.choice(np.arange(total), size=n, replace=False)
                # A small random subset can hold a single class, which
                # LogisticRegression refuses to fit. Record None instead of
                # letting the error abort this checkpoint's entire evaluation.
                if len(np.unique(y_train_arr[sel])) < 2:
                    print(f"Label efficiency {key}: subset has fewer than 2 classes, skipped")
                    label_eff[key] = None
                    continue
                clf = LogisticRegression(max_iter=2000)
                clf.fit(tr_feats[sel], y_train_arr[sel])
                label_eff[key] = float(accuracy_score(test_labels, clf.predict(te_feats)))

            with open(os.path.join(out_dir, "label_efficiency.json"), "w") as f:
                json.dump(label_eff, f, indent=2)

        summary = {
            "epoch_dir": epoch_name,
            "enc_path": enc_path,
            "silhouette": float(sil) if sil is not None else None,
            "probe_results": probe_results,
            "label_efficiency": label_eff,
        }
        ablation_summary.append(summary)
        with open(os.path.join(out_dir, "ablation_summary.json"), "w") as f:
            json.dump(summary, f, indent=2)

        # Confusion matrices. `labels=` pins one row/column per entry of
        # `classes`, so tick labels stay aligned even when a class is absent
        # from both the test labels and the predictions.
        # NOTE(review): assumes labels are 0..len(classes)-1, as the manifest
        # fallback for `classes` implies.
        cm_labels = list(range(len(classes)))
        for name in probe_results:
            try:
                clf = joblib.load(os.path.join(out_dir, f"probe_{name}.joblib"))
            except Exception:
                continue
            y_pred = clf.predict(te_feats)
            cm = confusion_matrix(test_labels, y_pred, labels=cm_labels)
            plt.figure(figsize=(6, 5))
            sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=classes, yticklabels=classes)
            plt.title(f"Confusion Matrix - {epoch_name} - {name}")
            plt.savefig(os.path.join(out_dir, f"confusion_{name}.png"))
            plt.close()

    except Exception as e:
        print("Error processing", ed, e)

# Persist the aggregated results:
#   * the full JSON summary,
#   * a flat per-(epoch, probe) CSV table,
#   * a ZIP archive of everything under OUT_ROOT.
import csv

with open(os.path.join(OUT_ROOT, "ablation_results.json"), "w") as f:
    json.dump(ablation_summary, f, indent=2)

csv_header = ["epoch_dir", "silhouette", "probe_name", "accuracy", "macro_auc", "micro_auc"]
csv_rows = [
    [
        summ["epoch_dir"],
        summ["silhouette"],
        probe,
        res.get("accuracy", None),
        res.get("macro_auc", None),
        res.get("micro_auc", None),
    ]
    for summ in ablation_summary
    for probe, res in summ["probe_results"].items()
]
csv_path = os.path.join(OUT_ROOT, "ablation_table.csv")
with open(csv_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(csv_header)
    writer.writerows(csv_rows)

zipname = os.path.join("/kaggle/working", "task7_ablation_outputs")
zip_file = zipname + ".zip"
if os.path.exists(zip_file):
    os.remove(zip_file)
shutil.make_archive(base_name=zipname, format="zip", root_dir=OUT_ROOT)
print("Ablation outputs zipped to", zip_file)
print("Done. Outputs in", OUT_ROOT)


Found ablation epoch dirs: ['/kaggle/input/all-files/All Files/simsiam_ablation_epochs_100/epoch_020', '/kaggle/input/all-files/All Files/simsiam_ablation_epochs_100/epoch_040', '/kaggle/input/all-files/All Files/simsiam_ablation_epochs_100/epoch_060', '/kaggle/input/all-files/All Files/simsiam_ablation_epochs_100/epoch_080', '/kaggle/input/all-files/All Files/simsiam_ablation_epochs_100/epoch_085', '/kaggle/input/all-files/All Files/simsiam_ablation_epochs_100/epoch_090', '/kaggle/input/all-files/All Files/simsiam_ablation_epochs_100/epoch_095', '/kaggle/input/all-files/All Files/simsiam_ablation_epochs_100/epoch_100']
Processing epoch_020
Extracting features for epoch_020
Training probe: LogisticRegression
Training probe: SVM_RBF
Training probe: RandomForest
Training probe: DecisionTree
Training probe: MLP
Processing epoch_040
Extracting features for epoch_040
Training probe: LogisticRegression
Training probe: SVM_RBF
Training probe: RandomForest
Training probe: DecisionTree
Training

## Cell 3: Train/test ratio sweep evaluation

In [9]:
import os, json, time, shutil
from pathlib import Path
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, precision_recall_fscore_support
from sklearn.preprocessing import label_binarize
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy.random as npr


# ---- Ratio-sweep configuration ---------------------------------------------------
BACKBONE = "resnet18"
RESOLUTION = 224
BATCH_SIZE = 64
NUM_WORKERS = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUT_ROOT = "/kaggle/working/task7_ratios"
os.makedirs(OUT_ROOT, exist_ok=True)

SPLIT_MANIFEST = "/kaggle/input/simsiam-task4-archive/split_manifest.json"
# Cached base-encoder features/labels; used instead of re-encoding the images when present.
BASE_FEATURES_DIR = "/kaggle/input/all-files/All Files"
BASE_TRAIN_FEATS = os.path.join(BASE_FEATURES_DIR, "train_feats.npy")
BASE_TRAIN_LABELS = os.path.join(BASE_FEATURES_DIR, "train_labels.npy")
BASE_VAL_FEATS = os.path.join(BASE_FEATURES_DIR, "val_feats.npy")
BASE_VAL_LABELS = os.path.join(BASE_FEATURES_DIR, "val_labels.npy")
BASE_TEST_FEATS = os.path.join(BASE_FEATURES_DIR, "test_feats.npy")
BASE_TEST_LABELS = os.path.join(BASE_FEATURES_DIR, "test_labels.npy")

# Probe suite. Fixed seeds make accuracy differences across train ratios
# reflect the amount of training data rather than probe randomness.
SEED = 42
CLASSIFIERS = {
    "LogisticRegression": LogisticRegression(max_iter=2000),
    "SVM_RBF": SVC(kernel="rbf", probability=True, random_state=SEED),
    "RandomForest": RandomForestClassifier(n_estimators=100, random_state=SEED),
    "DecisionTree": DecisionTreeClassifier(random_state=SEED),
    "MLP": MLPClassifier(hidden_layer_sizes=(512,), max_iter=500, random_state=SEED),
}


class ImageDataset(Dataset):
    """Map-style dataset yielding ``(image, label, path)`` for a list of image files.

    Each file is opened with PIL, converted to RGB and, if given, passed
    through ``transform``.
    """

    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        lbl = self.labels[idx]
        # The context manager guarantees the file handle is closed, including
        # for multi-frame formats that PIL would otherwise keep open.
        with Image.open(p) as im:
            img = im.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, lbl, p

# Evaluation-time transform (no augmentation):
#   * shorter-side resize to int(RESOLUTION * 1.1),
#   * center crop to RESOLUTION,
#   * tensor conversion,
#   * normalization with the usual ImageNet per-channel statistics.
_eval_resize = int(RESOLUTION * 1.1)
eval_transform = transforms.Compose([
    transforms.Resize(_eval_resize),
    transforms.CenterCrop(RESOLUTION),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def build_encoder(backbone="resnet18"):
    """Build an untrained ResNet feature extractor with the classification head removed.

    Supported names:
    * "resnet18" (feat_dim 512)
    * "resnet50" (feat_dim 2048)

    Anything else raises ValueError("Unsupported backbone"). The returned
    ``nn.Sequential`` holds every child of the torchvision ResNet except the
    final ``fc`` layer, and exposes that layer's input width as ``.feat_dim``.
    """
    if backbone not in ("resnet18", "resnet50"):
        raise ValueError("Unsupported backbone")
    resnet = getattr(models, backbone)(weights=None)
    trunk = nn.Sequential(*list(resnet.children())[:-1])
    trunk.feat_dim = resnet.fc.in_features
    return trunk

def load_encoder_weights(encoder, ckpt_path):
    """Load weights from ``ckpt_path`` into ``encoder`` in place; return True on success.

    The checkpoint may be a bare state dict, or a dict wrapping one under the
    first present key among "encoder_state_dict", "encoder", "model_state",
    "state_dict" and "model". A pickled module found there is reduced to its
    state dict.

    If a strict load fails, the keys are remapped and the load retried:
    * "module.encoder." / "encoder." prefixes are stripped, and when any key
      has one of them, unrelated keys (e.g. projector/predictor heads of a
      full SimSiam checkpoint) are dropped;
    * otherwise a leading "module." (DataParallel) is stripped.

    On final failure the error is printed and False is returned.
    """
    ck = torch.load(ckpt_path, map_location="cpu")

    state = ck
    if isinstance(ck, dict):
        for key in ["encoder_state_dict", "encoder", "model_state", "state_dict", "model"]:
            if key in ck:
                state = ck[key]
                break
    if isinstance(state, nn.Module):
        state = state.state_dict()

    try:
        encoder.load_state_dict(state)
        return True
    except Exception:
        pass  # retry below with remapped keys

    try:
        enc_prefixes = ("module.encoder.", "encoder.")
        has_enc_prefix = any(k.startswith(enc_prefixes) for k in state.keys())
        mapped = {}
        for k, v in state.items():
            if k.startswith("module.encoder."):
                mapped[k[len("module.encoder."):]] = v
            elif k.startswith("encoder."):
                mapped[k[len("encoder."):]] = v
            elif not has_enc_prefix:
                mapped[k[len("module."):] if k.startswith("module.") else k] = v
        encoder.load_state_dict(mapped)
        return True
    except Exception as e:
        print("Failed to load encoder:", e)
        return False

def extract_features_from_paths(encoder, paths, batch_size=64, workers=2, save_path=None):
    """Encode the images at ``paths`` (in order) into a 2-D NumPy feature matrix.

    Behaviour:
    * Images are preprocessed with ``eval_transform`` and batched without
      shuffling.
    * They pass through ``encoder`` in eval mode with gradients disabled on
      ``DEVICE``.
    * Each output is flattened to one row.
    * The matrix is saved with ``np.save`` when ``save_path`` is truthy.

    An empty ``paths`` returns a ``(0, encoder.feat_dim)`` float32 array
    (width 0 if ``feat_dim`` is missing) instead of failing in ``np.vstack``.
    """
    if len(paths) == 0:
        feats = np.zeros((0, getattr(encoder, "feat_dim", 0)), dtype=np.float32)
    else:
        # Labels are irrelevant here, so placeholders are passed.
        ds = ImageDataset(paths, [0] * len(paths), transform=eval_transform)
        loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=workers)
        encoder.eval()
        enc = encoder.to(DEVICE)
        batches = []
        with torch.no_grad():
            for imgs, _, _ in loader:
                imgs = imgs.to(DEVICE)
                batches.append(enc(imgs).view(imgs.size(0), -1).cpu().numpy())
        feats = np.vstack(batches)
    if save_path:
        np.save(save_path, feats)
    return feats


def stratified_subsample_indices(y, k, random_state=42):
    """Pick ``k`` indices of ``y`` whose class mix mirrors the full label distribution.

    Returns:
        An integer index array of length <= ``k``, or None when no usable
        subset exists: fewer than 2 classes in ``y``, ``k < 2``, or a result
        covering fewer than 2 classes.

    Sklearn's stratified ``train_test_split`` is tried first. If it raises
    (e.g. ``k == len(y)``, or a class has a single member) or yields fewer
    than 2 classes, indices are allocated manually:
    * each class gets the floor of its proportional share;
    * every class gets at least one sample when ``k >= n_classes``;
    * any excess this creates is taken back from the largest allocations, so
      truncation never drops whole classes;
    * remaining slots go to the largest fractional remainders.
    """
    y = np.array(y)
    n = len(y)
    unique, counts = np.unique(y, return_counts=True)
    n_classes = len(unique)
    if n_classes < 2 or k < 2:
        return None
    rng = npr.RandomState(random_state)

    try:
        train_idx, _ = train_test_split(np.arange(n), train_size=k, stratify=y, random_state=random_state)
        if len(np.unique(y[train_idx])) >= 2:
            return train_idx
    except Exception:
        pass  # fall back to the manual allocation below

    prop = counts / counts.sum()
    desired = np.floor(prop * k).astype(int)
    if k >= n_classes:
        desired[desired == 0] = 1
    # The one-per-class minimum can push the total above k. Trim from the
    # largest allocations; they stay >= 1 because sum > k >= n_classes.
    while desired.sum() > k:
        desired[np.argmax(desired)] -= 1
    rem = int(k - desired.sum())
    if rem > 0:
        leftover = (prop * k) - desired
        for idx in np.argsort(-leftover)[:rem]:
            desired[idx] += 1

    selected = []
    for cls, cnt in zip(unique, desired):
        if cnt <= 0:
            continue
        cls_inds = np.where(y == cls)[0]
        cnt = min(int(cnt), len(cls_inds))
        selected.extend(rng.choice(cls_inds, size=cnt, replace=False).tolist())
    selected = np.array(selected, dtype=int)
    if len(selected) < 2 or len(np.unique(y[selected])) < 2:
        return None
    return selected


# Assemble one pooled feature matrix and label vector for the ratio sweep.
# Cached base features are used when available; otherwise every image in the
# split manifest is encoded with the base SimSiam encoder.
with open(SPLIT_MANIFEST, "r") as f:
    sm = json.load(f)

feats_all = None
labels_all = None

have_train = os.path.exists(BASE_TRAIN_FEATS) and os.path.exists(BASE_TRAIN_LABELS)
have_val = os.path.exists(BASE_VAL_FEATS) and os.path.exists(BASE_VAL_LABELS)
have_test = os.path.exists(BASE_TEST_FEATS) and os.path.exists(BASE_TEST_LABELS)

if have_train and have_val and have_test:
    tr = np.load(BASE_TRAIN_FEATS); tr_lbl = np.load(BASE_TRAIN_LABELS)
    v = np.load(BASE_VAL_FEATS); v_lbl = np.load(BASE_VAL_LABELS)
    te = np.load(BASE_TEST_FEATS); te_lbl = np.load(BASE_TEST_LABELS)
    feats_all = np.vstack([tr, v, te])
    labels_all = np.hstack([tr_lbl, v_lbl, te_lbl]).astype(int)
elif have_train and have_test:
    tr = np.load(BASE_TRAIN_FEATS); tr_lbl = np.load(BASE_TRAIN_LABELS)
    te = np.load(BASE_TEST_FEATS); te_lbl = np.load(BASE_TEST_LABELS)
    feats_all = np.vstack([tr, te])
    labels_all = np.hstack([tr_lbl, te_lbl]).astype(int)
else:
    print("No adequate cached features found, will extract features from images.")
    all_paths = sm["train"] + sm["val"] + sm["test"]
    all_labels = sm["train_labels"] + sm["val_labels"] + sm["test_labels"]
    BASE_ENCODER_CKPT = "/kaggle/input/simsiam-task4-archive/simsiam_encoder.pth"
    encoder = build_encoder(BACKBONE)
    if not load_encoder_weights(encoder, BASE_ENCODER_CKPT):
        raise RuntimeError("Cannot load encoder weights from " + BASE_ENCODER_CKPT)
    # Both arrays are written to OUT_ROOT once below, so no save here.
    feats_all = extract_features_from_paths(encoder, all_paths, batch_size=BATCH_SIZE, workers=NUM_WORKERS)
    labels_all = np.array(all_labels).astype(int)

if feats_all is None or labels_all is None:
    raise RuntimeError("Failed to prepare features and labels.")
# Feature rows and labels must describe the same samples. On a mismatch, the
# stratified split below would fail and fall back to a random split whose
# indices pair features with the wrong labels.
if feats_all.shape[0] != labels_all.shape[0]:
    raise RuntimeError(
        f"Feature/label count mismatch: {feats_all.shape[0]} feature rows vs {labels_all.shape[0]} labels."
    )

N = feats_all.shape[0]
print("Total samples used for ratio sweep:", N)

np.save(os.path.join(OUT_ROOT, "feats_all.npy"), feats_all)
np.save(os.path.join(OUT_ROOT, "labels_all.npy"), labels_all)


import copy

# Train/test ratio sweep over the pooled features. For each training fraction:
#   * split stratified by label (random split as a fallback);
#   * hold out 10% of the training part (its size is only reported);
#   * fit every probe on the rest and score it on the test part;
#   * run a stratified label-efficiency sweep with LogisticRegression.
ratios = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
ratio_results = []

for train_frac in ratios:
    n_train = int(round(train_frac * N))
    # Keep at least one sample on each side of the split.
    if n_train >= N:
        n_train = N - 1
    if n_train < 1:
        n_train = 1
    n_test = N - n_train
    print(f"Running ratio: train {int(train_frac*100)}% -> n_train={n_train}, n_test={n_test}")

    idxs = np.arange(N)
    try:
        train_idx, test_idx = train_test_split(idxs, train_size=n_train, stratify=labels_all, random_state=42)
    except Exception as e:
        print("Stratified split failed, falling back to random split:", e)
        train_idx, test_idx = train_test_split(idxs, train_size=n_train, random_state=42)
    X_train, y_train = feats_all[train_idx], labels_all[train_idx]
    X_test, y_test = feats_all[test_idx], labels_all[test_idx]

    val_portion = max(1, int(round(0.10 * len(X_train))))
    try:
        tr_sub_idx, val_sub_idx = train_test_split(np.arange(len(X_train)), test_size=val_portion, stratify=y_train, random_state=42)
    except Exception:
        tr_sub_idx, val_sub_idx = train_test_split(np.arange(len(X_train)), test_size=val_portion, random_state=42)
    X_tr, y_tr = X_train[tr_sub_idx], y_train[tr_sub_idx]
    X_val, y_val = X_train[val_sub_idx], y_train[val_sub_idx]

    probes_res = {}
    for name, clf in CLASSIFIERS.items():
        if len(np.unique(y_tr)) < 2:
            probes_res[name] = {"accuracy": None, "macro_auc": None, "micro_auc": None, "per_class_f1": None, "skipped_reason": "only_one_class_in_train"}
            continue
        clf_local = copy.deepcopy(clf)
        clf_local.fit(X_tr, y_tr)
        joblib.dump(clf_local, os.path.join(OUT_ROOT, f"{int(train_frac*100)}pct_{name}.joblib"))
        y_pred = clf_local.predict(X_test)
        acc = accuracy_score(y_test, y_pred)
        prec, rec, f1, sup = precision_recall_fscore_support(y_test, y_pred, average=None, zero_division=0)

        # ROC-AUC: predict_proba columns follow clf_local.classes_, so binarise
        # against that list. For two classes label_binarize returns only the
        # positive column; prepend the negative one so shapes match `probs`.
        macro = micro = None
        if hasattr(clf_local, "predict_proba"):
            try:
                probs = clf_local.predict_proba(X_test)
                y_bin = label_binarize(y_test, classes=list(clf_local.classes_))
                if y_bin.shape[1] == 1:
                    y_bin = np.hstack([1 - y_bin, y_bin])
                macro = float(roc_auc_score(y_bin, probs, average="macro"))
                micro = float(roc_auc_score(y_bin, probs, average="micro"))
            except Exception:
                macro = micro = None
        probes_res[name] = {"accuracy": float(acc), "macro_auc": macro, "micro_auc": micro, "per_class_f1": f1.tolist()}

    label_eff = {}
    total = X_tr.shape[0]
    for frac in [0.01, 0.05, 0.10, 0.25, 0.50, 1.0]:
        k = max(1, int(total * frac))
        sel = stratified_subsample_indices(y_tr, k, random_state=42)
        if sel is None:
            label_eff[f"{int(frac*100)}%"] = {"accuracy": None, "skipped": True, "reason": "insufficient_class_diversity_for_k"}
            continue
        clf = LogisticRegression(max_iter=2000)
        clf.fit(X_tr[sel], y_tr[sel])
        score = accuracy_score(y_test, clf.predict(X_test))
        label_eff[f"{int(frac*100)}%"] = {"accuracy": float(score), "skipped": False, "k_used": int(len(sel))}

    ratio_results.append({
        "train_frac": train_frac,
        "n_train": int(n_train),
        "n_val": int(len(X_val)),
        "n_test": int(n_test),
        "probes": probes_res,
        "label_efficiency": label_eff,
    })

    with open(os.path.join(OUT_ROOT, f"ratio_{int(train_frac*100)}.json"), "w") as f:
        json.dump(ratio_results[-1], f, indent=2)

        
# Persist the sweep:
#   * every ratio's results as one JSON list,
#   * a compact per-probe accuracy table as CSV,
#   * a ZIP of everything under OUT_ROOT.
with open(os.path.join(OUT_ROOT, "ratio_results.json"), "w") as f:
    json.dump(ratio_results, f, indent=2)

import csv

csv_path = os.path.join(OUT_ROOT, "ratio_table.csv")
with open(csv_path, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    probe_names = list(CLASSIFIERS.keys())
    writer.writerow(["train_pct", "n_train", "n_val", "n_test"] + [f"{pn}_acc" for pn in probe_names])
    for rr in ratio_results:
        accs = []
        for pn in probe_names:
            acc_val = rr["probes"][pn].get("accuracy")
            # Skipped probes have accuracy None -> leave the cell empty.
            accs.append("" if acc_val is None else acc_val)
        writer.writerow([int(rr["train_frac"] * 100), rr["n_train"], rr["n_val"], rr["n_test"]] + accs)

zipname = os.path.join("/kaggle/working", "task7_ratio_outputs")
zip_file = zipname + ".zip"
if os.path.exists(zip_file):
    os.remove(zip_file)
shutil.make_archive(base_name=zipname, format="zip", root_dir=OUT_ROOT)
print("Ratio sweep outputs zipped to", zip_file)
print("Done. Outputs in", OUT_ROOT)


Total samples used for ratio sweep: 1800
Running ratio: train 90% -> n_train=1620, n_test=180
Running ratio: train 80% -> n_train=1440, n_test=360
Running ratio: train 70% -> n_train=1260, n_test=540
Running ratio: train 60% -> n_train=1080, n_test=720
Running ratio: train 50% -> n_train=900, n_test=900
Running ratio: train 40% -> n_train=720, n_test=1080
Running ratio: train 30% -> n_train=540, n_test=1260
Running ratio: train 20% -> n_train=360, n_test=1440
Running ratio: train 10% -> n_train=180, n_test=1620
Ratio sweep outputs zipped to /kaggle/working/task7_ratio_outputs.zip
Done. Outputs in /kaggle/working/task7_ratios
