In [3]:
import os
from pathlib import Path
import sys
import numpy as np
import torch
import open_clip
import matplotlib.pyplot as plt
import csv
import pandas as pd

# sys.path.append(str(Path.cwd().parent))

from scripts.few_shot import (
    ensure_features,
    load_labels,
    prototype_classifier,
    train_linear_probe,
    topk_acc,
    balanced_acc,
)

In [None]:


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def zero_shot_scores(label_names, backbone, pretrained, device):
    """Build unit-length text embeddings, one per label, with an open_clip model.

    Note that, despite its name, this function returns the text embedding
    matrix itself; turning it into scores (e.g. by multiplying image features
    with it) is left to the caller.

    Args:
        label_names: iterable of label names; each one is inserted into the
            prompt template "a photo of <name>".
        backbone: open_clip model name, used both to create the model and to
            look up the matching tokenizer.
        pretrained: pretrained-weights tag forwarded to open_clip.
        device: device the model is created on and the tokens are moved to.

    Returns:
        A NumPy float array holding the L2-normalised text embeddings,
        transposed so that labels run along the last axis ([D, K]).
    """
    clip_model, _, _ = open_clip.create_model_and_transforms(
        backbone, pretrained=pretrained, device=device
    )
    clip_model.eval()  # inference only
    tokenize = open_clip.get_tokenizer(backbone)

    captions = [f"a photo of {name}" for name in label_names]

    with torch.no_grad():
        token_ids = tokenize(captions).to(device)
        embeddings = clip_model.encode_text(token_ids)
        # Scale every embedding to unit L2 norm along the feature axis.
        unit_embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        # Transpose so the result is laid out as [D, K].
        text_matrix = unit_embeddings.float().cpu().numpy().T
    return text_matrix

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Notebook configuration
class Args:
    """Static configuration for the few-shot evaluation notebook.

    All dataset/split/label paths are derived from ``project_root`` so the
    notebook can run on another machine by setting the ``AI4GOOD_ROOT``
    environment variable before this cell is executed. When the variable is
    unset, every path is identical to the original hard-coded value.
    """

    # Root of the AI4GOOD working tree; override with AI4GOOD_ROOT.
    project_root = os.environ.get("AI4GOOD_ROOT", "/home/c/dkorot/AI4GOOD")

    # Inputs: image root, train/test split CSVs and the label table.
    data_root = os.path.join(project_root, "provided_dir", "datasets", "mushroom", "merged_dataset")
    train_csv = os.path.join(project_root, "ai4good-mushroom", "splits", "train.csv")
    test_csv = os.path.join(project_root, "ai4good-mushroom", "splits", "test.csv")
    labels = os.path.join(project_root, "ai4good-mushroom", "labels.tsv")

    # open_clip model name and pretrained-weights tag.
    backbone = "ViT-B-32-quickgelu"
    pretrained = "openai"

    # Shot counts to evaluate; the evaluation loop only uses entries > 0.
    shots = [1, 5]

    # save_dir is forwarded to ensure_features; results_dir receives the CSV table.
    save_dir = "features"
    results_dir = "results"

    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Base seed for torch / NumPy; the per-shot seed is seed + shot.
    seed = 42

    # Hyper-parameters passed to train_linear_probe.
    epochs = 100
    lr = 1e-2
    batch_size = 32

# Instantiate the notebook configuration used by every later cell.
args = Args()

# Seed both NumPy's global RNG and PyTorch with the same base seed so that
# repeated runs of the notebook are comparable.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Make sure the output directory for result tables exists (no error if it
# is already there; missing parent directories are created as well).
os.makedirs(args.results_dir, exist_ok=True)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Label table: the list of class names plus a name -> id lookup.
label_names, name2id = load_labels(args.labels)
K = len(label_names)  # number of classes

# Load cached features
# NOTE(review): ensure_features presumably computes and caches the features
# when no cache exists yet - confirm in scripts/few_shot.py. The third return
# value is not used by this notebook.
X_tr, y_tr, _ = ensure_features(
    "train",
    args.backbone,
    args.data_root,
    args.train_csv,
    args.labels,
    save_dir=args.save_dir,
    pretrained=args.pretrained,
)
X_te, y_te, _ = ensure_features(
    "test",
    args.backbone,
    args.data_root,
    args.test_csv,
    args.labels,
    save_dir=args.save_dir,
    pretrained=args.pretrained,
)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# text_emb = zero_shot_scores(label_names, args.backbone, args.pretrained, args.device)

# Result rows, one tuple per (shot, model): (shot, model, top1, balanced_acc, notes).
records = []

# The zero-shot baseline below is currently disabled; it is kept for reference.
# # Compute zero-shot baseline if 0 in shots
# if 0 in args.shots:
#     scores_zs = X_te @ text_emb
#     yhat_zs = np.argmax(scores_zs, axis=1)
#     zs_top1 = (yhat_zs == y_te).mean()
#     zs_bal = balanced_acc(y_te, yhat_zs, K)
#     records.append((0, "zero-shot", zs_top1, zs_bal, None))
#     print(f"Zero-shot top1={zs_top1:.4f}, balanced={zs_bal:.4f}")
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
from scripts.few_shot import sample_few_shot_indices

# Evaluate every positive shot count in ascending order.
for shot in sorted(s for s in args.shots if s > 0):
    # Each shot count gets its own RNG, seeded with base seed + shot.
    rng = np.random.default_rng(args.seed + int(shot))
    # NOTE(review): presumably draws `shot` support examples per class - confirm
    # against sample_few_shot_indices in scripts/few_shot.py.
    sup_idx, sup_labels = sample_few_shot_indices(y_tr, K, shot, rng)
    X_sup, y_sup = X_tr[sup_idx], sup_labels

    # --- Prototype classifier: score test features against the prototypes ---
    prototypes = prototype_classifier(X_sup, y_sup, K)
    scores_proto = X_te @ prototypes.T
    yhat_proto = np.argmax(scores_proto, axis=1)
    proto_hits = yhat_proto == y_te
    proto_top1 = proto_hits.mean()
    proto_bal = balanced_acc(y_te, yhat_proto, K)
    records.append((shot, "prototype", proto_top1, proto_bal, None))

    # --- Linear probe trained on the support set only (no validation split) ---
    model = train_linear_probe(
        X_sup,
        y_sup,
        X_val=None,
        y_val=None,
        epochs=args.epochs,
        lr=args.lr,
        batch_size=args.batch_size,
        device=args.device,
        seed=args.seed + int(shot),
    )
    model.eval()
    with torch.no_grad():
        # Score the whole test set in a single forward pass.
        xt = torch.from_numpy(X_te).float().to(args.device)
        logits = model(xt).cpu().numpy()
    yhat_lin = np.argmax(logits, axis=1)
    lin_hits = yhat_lin == y_te
    lin_top1 = lin_hits.mean()
    lin_bal = balanced_acc(y_te, yhat_lin, K)
    records.append((shot, "linear", lin_top1, lin_bal, None))

    print(f"shot={shot}: proto top1={proto_top1:.4f}, lin top1={lin_top1:.4f}")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# The table's file name embeds the backbone (spaces replaced by underscores)
# and the file is placed inside the results directory.
table_name = f"few_shot_table_{args.backbone.replace(' ','_')}.csv"
out_csv = os.path.join(args.results_dir, table_name)

# Write the header followed by one row per collected record.
with open(out_csv, "w", newline="", encoding="utf-8") as f:
    w = csv.writer(f)
    w.writerow(["shot", "model", "top1", "balanced_acc", "notes"])
    w.writerows(records)

print(f"✅ Wrote table -> {out_csv}")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~



  1%|▏                                     | 60/10774 [00:28<1:25:18,  2.09it/s]
Traceback (most recent call last):
  File "/home/c/dkorot/AI4GOOD/ai4good-mushroom/scripts/dump_features.py", line 70, in <module>
    main()
  File "/home/c/dkorot/AI4GOOD/ai4good-mushroom/scripts/dump_features.py", line 55, in main
    ims.append(preprocess(img))
               ^^^^^^^^^^^^^^^
  File "/home/c/dkorot/miniconda3/envs/torch_env/lib/python3.12/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
    img = t(img)
          ^^^^^^
  File "/home/c/dkorot/miniconda3/envs/torch_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/c/dkorot/miniconda3/envs/torch_env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "

AssertionError: Feature dump failed

In [None]:

from scripts.few_shot import (
    ensure_features,
    load_labels,
    prototype_classifier,
    train_linear_probe,
    topk_acc,
    balanced_acc,
)