In [30]:
import wandb
import torch
import wandb, pandas as pd, re, os, glob, warnings
import os
from pathlib import Path
import einops
import torch.nn.functional as F
import wandb
import torch, einops
import torch.nn.functional as F
from pathlib import Path
import pandas as pd


In [31]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [37]:
# ==== USER CONFIG ====
ENTITY  = "yuulia-volkova-algoverse"
PROJECT  = "outlines_linear_probe"
RUN_ID = "rajjwtho"

# dir for checkpoints and normalizers 
LOCAL_DIR = f"artifacts/{RUN_ID}/"
RUN_PATH = f"yuulia-volkova-algoverse/outlines_linear_probe/{RUN_ID}" 

RESIDUALS_PATH = "/workspace/hdd_cache/tensors/llama-3b"
EMBEDS_PATH     = "/workspace/ALGOVERSE/yas/yulia/parascopes/src/yulia/outlines/results/llama-3b-outlines-embeddings_new"
HF_REPO_ID     = "yulia-volkova/llama-3b-outlines-embeddings_new"  # only used if you enable fallback
DTYPE          = torch.float32

# Evaluate the "last 4 chunks"
CHUNK_IDS = list(range(996, 1000))



In [38]:

# Check logged artifacts for the run 
api = wandb.Api()
 
run = api.run(RUN_PATH)
print("Run name:", run.name, "| Run id:", run.id)

print("\n=== logged_artifacts ===")
logged = list(run.logged_artifacts())
print("count:", len(logged))
for a in logged:
    print(a.name, "| type:", a.type, "| version:", a.version)

print("\n=== used_artifacts ===")
used = list(run.used_artifacts())
print("count:", len(used))
for a in used:
    creator = a.logged_by().name if a.logged_by() else "<unknown>"
    print(a.name, "| type:", a.type, "| version:", a.version, "| created by:", creator)

Run name: probe-local-resids-all-chunks-one-epoch | Run id: rajjwtho

=== logged_artifacts ===
count: 6
epoch-001-chunks.txt:v2 | type: metadata | version: v2
res_normalizer.pt:v3 | type: asset | version: v3
embed_normalizer.pt:v3 | type: asset | version: v3
checkpoint-epoch-1:v3 | type: model | version: v3
run-rajjwtho-history:v0 | type: wandb-history | version: v0
run-rajjwtho-events:v0 | type: wandb-events | version: v0

=== used_artifacts ===
count: 0


In [39]:
import os, re, shutil
import pandas as pd
import wandb

CHECKPOINT_DIR = LOCAL_DIR + "/checkpoints"



os.makedirs(LOCAL_DIR, exist_ok=True)

api = wandb.Api()
run = api.run(RUN_PATH)
print("Run:", run.name, "| id:", run.id)

# pick best epoch
hist = run.history(pandas=True)
val_col_candidates = ["epoch/val_loss","val/loss","val_loss","val/mse","epoch/val_mse"]
val_col = next((c for c in val_col_candidates if c in hist.columns and hist[c].notna().any()), None)
h2 = hist[["epoch", val_col]].dropna()
best_epoch = int(h2.loc[h2[val_col].idxmin(), "epoch"])
print(f"Best epoch: {best_epoch}  ({val_col}={float(h2.loc[h2[val_col].idxmin(), val_col]):.6f})")

# find its artifact
chosen_art = None
for art in run.logged_artifacts():
    if art.type == "model" and art.name.startswith("checkpoint-epoch-"):
        m = re.match(r"checkpoint-epoch-(\d+):", art.name)
        if m and int(m.group(1)) == best_epoch:
            chosen_art = art; break
assert chosen_art is not None, "No checkpoint artifact found for that epoch."

print("Chosen artifact:", chosen_art.name)

# build a stable local cache filename (with .pt extension for convenience)
base_name = chosen_art.name.split(":")[0]        # e.g. "checkpoint-epoch-16"
local_ckpt = os.path.join(CHECKPOINT_DIR, base_name + ".pt")

if os.path.exists(local_ckpt):
    print("Using cached checkpoint:", local_ckpt)
    ckpt_path = local_ckpt
else:
    print("⬇️ Downloading from W&B (one time)...")
    ckpt_dir = chosen_art.download(root=CHECKPOINT_DIR)

    manifest_keys = list(chosen_art.manifest.entries.keys())
    if not manifest_keys:
        raise RuntimeError("Artifact has no files in manifest.")
    rel_path = manifest_keys[0]                
    src_path = os.path.join(ckpt_dir, rel_path)
    if not os.path.exists(src_path):
        # fallback: search deeply
        cand = None
        for root, _, files in os.walk(ckpt_dir):
            for f in files:
                if f.startswith(base_name):     
                    cand = os.path.join(root, f); break
            if cand: break
        if cand is None:
            # last resort: any file
            files = [os.path.join(ckpt_dir, f) for f in os.listdir(ckpt_dir)]
            if not files:
                raise RuntimeError("Downloaded artifact folder is empty.")
            cand = files[0]
        src_path = cand

    # copy to the stable cache path with .pt extension (even if original had none)
    shutil.copy2(src_path, local_ckpt)
    ckpt_path = local_ckpt
    print("Saved to:", ckpt_path)

print("Final checkpoint path:", ckpt_path)

# sanity-load
import torch
state = torch.load(ckpt_path, map_location="cpu")
print("Checkpoint keys:", list(state.keys()))


Run: probe-local-resids-all-chunks-one-epoch | id: rajjwtho
Best epoch: 1  (epoch/val_loss=0.537032)
Chosen artifact: checkpoint-epoch-1:v3
⬇️ Downloading from W&B (one time)...


[34m[1mwandb[0m: Downloading large artifact 'checkpoint-epoch-1:v3', 2052.02MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:00:07.4 (277.0MB/s)


Saved to: artifacts/rajjwtho//checkpoints/checkpoint-epoch-1.pt
Final checkpoint path: artifacts/rajjwtho//checkpoints/checkpoint-epoch-1.pt
Checkpoint keys: ['epoch', 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'train_loss', 'val_loss']


In [40]:
import train_probe as core  

# Load probe with state at best checkpoint

state = torch.load(ckpt_path, map_location="cpu")

N_LAYERS = 57         
D_MODEL  = 3072       
probe = core.LinearProbe(n_layers=N_LAYERS, d_model=D_MODEL)
probe.load_state_dict(state["model_state_dict"], strict=True)
probe.eval()

print("Loaded epoch (0-based):", state.get("epoch"))


Loaded epoch (0-based): 0


In [41]:
# download normalizers for the run
NORMALIZERS_DIR = LOCAL_DIR + "/normalizers"
res_art   = next(a for a in run.logged_artifacts() if a.name.startswith("res_normalizer.pt"))
embed_art = next(a for a in run.logged_artifacts() if a.name.startswith("embed_normalizer.pt"))

res_dir   = res_art.download(root=NORMALIZERS_DIR)
embed_dir = embed_art.download(root=NORMALIZERS_DIR)

from train_probe import Normalizer
res_norm  = Normalizer(**torch.load(os.path.join(NORMALIZERS_DIR,   "res_normalizer.pt"),   map_location="cpu"))
emb_norm  = Normalizer(**torch.load(os.path.join(NORMALIZERS_DIR, "embed_normalizer.pt"), map_location="cpu"))


[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [42]:
import sonar_utils

# Initialize the SONAR models

text2vec, vec2text = sonar_utils.init_sonar()

[SONAR] init on device=cuda, dtype=torch.bfloat16
2025-09-24 10:04:21,603 - Using the cached checkpoint of text_sonar_basic_encoder. Set `force` to `True` to download again.
2025-09-24 10:04:28,134 - Using the cached tokenizer of text_sonar_basic_encoder. Set `force` to `True` to download again.
2025-09-24 10:04:28,202 - Using the cached checkpoint of text_sonar_basic_decoder. Set `force` to `True` to download again.
2025-09-24 10:04:36,986 - Using the cached tokenizer of text_sonar_basic_encoder. Set `force` to `True` to download again.


In [43]:
def eval_chunks(chunk_ids, batch_size=32):
    import torch, torch.nn.functional as F, einops, os
    from pathlib import Path

    global probe, res_norm, emb_norm, RESIDUALS_PATH, EMBEDS_PATH, DTYPE

    # Make sure the probe is on a device and in eval mode
    if next(probe.parameters()).device.type == "cpu" and torch.cuda.is_available():
        probe.to("cuda")
    probe.eval()

    mse_all, cos_all, examples = [], [], []

    for chunk_id in chunk_ids:
        res_path = Path(RESIDUALS_PATH) / f"res_data_{chunk_id:03d}.pt"
        res_list = torch.load(res_path, map_location="cpu")

        emb_path = Path(EMBEDS_PATH) / f"outlines_{chunk_id:03d}.pt"
        embeds = torch.load(emb_path, map_location="cpu").to(dtype=DTYPE)

        assert len(res_list) == 1000 and len(embeds) == 1000, f"Count mismatch in chunk {chunk_id}"

        # The single source of truth for device is the model's device
        model_dev = next(probe.parameters()).device

        for i in range(0, 1000, batch_size):
            res_batch, emb_batch = [], []
            end = min(i + batch_size, 1000)
            for j in range(i, end):
                r = res_list[j]["res"].to(dtype=DTYPE)               # CPU tensor
                x = einops.rearrange(r[:, :1, :], "l p d -> p l d")  # [1,L,D] on CPU
                x = res_norm.normalize(x)                            # normalize on CPU

                y = emb_norm.normalize(embeds[j])                    # [1024] on CPU
                y = y.unsqueeze(0)                                   # [1,1024] on CPU

                res_batch.append(x)
                emb_batch.append(y)

            # NOW move to the model's device (not a global variable)
            X = torch.cat(res_batch, dim=0).to(model_dev, non_blocking=True)
            Y = torch.cat(emb_batch, dim=0).to(model_dev, non_blocking=True)

            # Safety check (remove after first successful run)
            assert next(probe.parameters()).device == X.device, \
                f"Probe on {next(probe.parameters()).device}, X on {X.device}"

            with torch.no_grad():
                P = probe(X)                             # [B, 1024]
                mse = ((P - Y) ** 2).mean(dim=1)
                cos = F.cosine_similarity(P, Y, dim=1)

            mse_all.extend(mse.detach().cpu().tolist())
            cos_all.extend(cos.detach().cpu().tolist())

            if len(examples) < 12:
                keep = min(3, P.shape[0])
                examples.append({
                    "chunk": chunk_id,
                    "pred": P[:keep].detach().cpu(),
                    "gold": Y[:keep].detach().cpu(),
                })

    mse_t = torch.tensor(mse_all)
    cos_t = torch.tensor(cos_all)
    print(f"Eval on chunks {chunk_ids}:")
    print(f"  MSE  mean={mse_t.mean().item():.4f} | median={mse_t.median().item():.4f}")
    print(f"  COS  mean={cos_t.mean().item():.4f} | median={cos_t.median().item():.4f}")
    return examples


In [44]:

examples = eval_chunks([996], batch_size=32)

Eval on chunks [996]:
  MSE  mean=0.5437 | median=0.5377
  COS  mean=0.6798 | median=0.6833


In [45]:
import torch

def _vec2text_device_dtype():
    """
    Infer the actual device & dtype used by the SONAR decoder weights.
    """
    # Newer sonar has .model; older might differ — try a few options
    for attr in ["model", "_model", "decoder", "_decoder"]:
        m = getattr(vec2text, attr, None)
        if m is not None:
            print(attr)
            try:
                p = next(m.parameters())
                return p.device, p.dtype
            except Exception:
                pass

    # Absolute fallback: try attributes or default to CUDA bf16 if available
    dev = getattr(vec2text, "device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    dt  = getattr(vec2text, "dtype", torch.bfloat16 if dev.type == "cuda" else torch.float32)
    return dev, dt


def decode_embeddings(tensors, target_lang="eng_Latn"):
    """
    tensors: list[1D torch.Tensor] or a single 2D torch.Tensor [B, 1024]
    Cast inputs to the actual decoder models device & dtype before predict().
    """
    # normalize to list of 1D tensors
    if isinstance(tensors, torch.Tensor):
        if tensors.ndim == 1:
            batch = [tensors]
        elif tensors.ndim == 2:
            batch = [tensors[i] for i in range(tensors.size(0))]
        else:
            raise ValueError("Expected 1D or 2D tensor for embeddings.")
    else:
        batch = list(tensors)

    dec_device, dec_dtype = _vec2text_device_dtype()

    # IMPORTANT: single-step cast to both device & dtype
    batch = [t.detach().to(device=dec_device, dtype=dec_dtype) for t in batch]

    # SONAR expects a list of tensors
    return vec2text.predict(batch, target_lang=target_lang)


In [46]:
# Test run 
fake = torch.randn(1024)
print(decode_embeddings([fake, fake]))


model
['administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative administrative ad

In [47]:

DEVICE = next(probe.parameters()).device
# restore embeddings for decoding from their normalized values
def emb_restore(z_norm: torch.Tensor) -> torch.Tensor:
    return z_norm * (emb_norm.std + 1e-6) + emb_norm.mean



def eval_and_decode_chunks(chunks, per_chunk=10, batch_size=32, csv_path=f"decoded_outputs_{RUN_ID}_best_epoch_{best_epoch}.csv"):
    probe.eval().to(DEVICE)

    all_mse, all_cos = [], []
    rows = []  

    for chunk_id in chunks:
        res_path = Path(RESIDUALS_PATH) / f"res_data_{chunk_id:03d}.pt"
        emb_path = Path(EMBEDS_PATH)    / f"outlines_{chunk_id:03d}.pt"
        res_list = torch.load(res_path, map_location="cpu")
        embeds   = torch.load(emb_path,  map_location="cpu").to(dtype=DTYPE)

        assert len(res_list) == 1000 and len(embeds) == 1000, f"Count mismatch in chunk {chunk_id}"

        # --- evaluate whole chunk (normalized metrics) ---
        for i in range(0, 1000, batch_size):
            res_batch, emb_batch = [], []
            end = min(i + batch_size, 1000)
            for j in range(i, end):
                r = res_list[j]["res"].to(dtype=DTYPE)                 # [L, P, D]
                x = einops.rearrange(r[:, :1, :], "l p d -> p l d")    # [1, L, D]
                x = res_norm.normalize(x)                              # CPU
                y = emb_norm.normalize(embeds[j]).unsqueeze(0)         # [1, 1024]
                res_batch.append(x); emb_batch.append(y)

            X = torch.cat(res_batch, 0).to(DEVICE, non_blocking=True)  # [B, L, D]
            Y = torch.cat(emb_batch, 0).to(DEVICE, non_blocking=True)  # [B, 1024]
            with torch.no_grad():
                P = probe(X)
                all_mse.extend(((P - Y) ** 2).mean(dim=1).cpu().tolist())
                all_cos.extend(F.cosine_similarity(P, Y, dim=1).cpu().tolist())

        # --- decode ONLY `per_chunk` samples (evenly spaced across the 1000) ---
        if per_chunk <= 0:
            continue
        step = max(1000 // per_chunk, 1)
        decode_idxs = [min(k * step, 999) for k in range(per_chunk)]

        print(f"\n=== Chunk {chunk_id} : decoding {len(decode_idxs)} samples ===")
        for idx in decode_idxs:
            r = res_list[idx]["res"].to(dtype=DTYPE)
            x = einops.rearrange(r[:, :1, :], "l p d -> p l d")  # [1, L, D]
            x_n = res_norm.normalize(x)
            y_n = emb_norm.normalize(embeds[idx]).unsqueeze(0)   # [1, 1024]

            with torch.no_grad():
                y_hat_n = probe(x_n.to(DEVICE)).squeeze(0)

            mse = F.mse_loss(y_hat_n, y_n.to(DEVICE)).item()
            cos = F.cosine_similarity(y_hat_n.unsqueeze(0), y_n.to(DEVICE), dim=1).item()
            print(f"[{idx:04d}] MSE={mse:.6f} | COS={cos:.4f}")

            # restore embeddings for decoding from their normalized values
            y_hat_rest = emb_restore(y_hat_n)
            y_gt_rest  = emb_restore(y_n.squeeze(0))

            decoded = decode_embeddings([y_gt_rest, y_hat_rest])

            print("Original decoded:",  decoded[0])
            print("Predicted decoded:", decoded[1])
            print("-" * 80)

            
            rows.append({
                "chunk": chunk_id,
                "index": idx,
                "mse": mse,
                "cosine": cos,
                "decoded_original": decoded[0],
                "decoded_predicted": decoded[1],
            })

    
    mse_t = torch.tensor(all_mse); cos_t = torch.tensor(all_cos)
    print("\n=== Overall metrics across chunks", chunks, "===")
    print(f"MSE  mean={mse_t.mean().item():.4f} | median={mse_t.median().item():.4f}")
    print(f"COS  mean={cos_t.mean().item():.4f} | median={cos_t.median().item():.4f}")

    
    if rows:
        df = pd.DataFrame(rows)
        df.to_csv(csv_path, index=False)
        print(f"\n Saved decoded outputs to {csv_path}")
        return df
    else:
        print("No decoded samples collected.")
        return None


In [None]:

eval_and_decode_chunks([996, 997, 998, 999], per_chunk=10, batch_size=32)


=== Chunk 996 : decoding 10 samples ===
[0000] MSE=0.626499 | COS=0.6626
model


  mse = F.mse_loss(y_hat_n, y_n.to(DEVICE)).item()


Original decoded: Summary: 1.History map of the Helsinki University Library - reflects Finnish history and politics 2.Early Swedish identity and national identity - Influenced by Swedish and Russian history through the formation of the map 3.Use of national identity and sovereignty to demonstrate autocracy 4.Impacts on historical events - Finnish war and national revival 5.Philosophical heritage and the preservation of Finland's national identity - Essential to understanding the significance of Finnish maps
Predicted decoded: Summary: 1. map of Dutch urban areas - 2. historical preservation of libraries - 3. contexts of national security and migration - 4. conservation of archives - 3. historical influence - 4. mapping of archives and national security - 5. cultural issues - 5. the role of European cartography - 5. historical preservation and museums - 5. the importance of national information - 7. the management of national archives and geography
--------------------------------------