In [2]:
import pandas as pd, numpy as np
import sys, os
from pathlib import Path

# Ensure we are at the project root (the folder that contains `src/`)
# If your notebook sits in the root, this is already correct.
root = Path.cwd()

# If your notebook lives somewhere else, climb up until we see 'src'
while not (root / "src").exists() and root.parent != root:
    root = root.parent

# Put project root and src/ on sys.path
sys.path.insert(0, str(root))
sys.path.insert(0, str(root / "src"))

print("Project root:", root)
print("Has src?:", (root / "src").exists())

# Force a clean import of the latest file
import importlib, src.esm_feats as esm_feats
importlib.reload(esm_feats)

# See what the module actually exports
print([n for n in dir(esm_feats) if "emb" in n or "shot" in n or "load" in n])

from src.esm_feats import embed_dataframe, load_esm1v


Project root: d:\Bioinformatics\Rosaloid
Has src?: True
['__loader__', 'embed_dataframe', 'get_embedding', 'load_esm1v', 'zero_shot_dataframe', 'zero_shot_score']


In [3]:
# ---- knobs ----
BATCH = 96          # use 24 if you want tiny
SHORTLIST = 5*BATCH # diversify beyond the strict top-k
RADIUS = 2          # trust radius (num_subs <= RADIUS)

# 1) load zero-shot table and apply trust radius
zs = pd.read_csv("gfp_dms_with_zeroshot.csv").reset_index(drop=True)
pool = zs.query("num_subs <= @RADIUS").copy()

# 2) shortlist by zero-shot score (higher is better)
short = (pool.sort_values("esm1v_zero_shot", ascending=False)
              .head(SHORTLIST)
              .reset_index(drop=True))

# 3) ensure embeddings for the shortlist (uses cache; fast on 4070S)
_ = load_esm1v(device="cuda")   # falls back to cpu if needed
short = embed_dataframe(short)  # adds 'embedding_path' and writes matrix
E = np.load(short["embedding_path"].iloc[0])   # rows align with 'short' order

# 4) diversity pick via farthest-point sampling in embedding space
def farthest_point_sampling(X, k):
    rng = np.random.default_rng(0)
    idx = [int(rng.integers(0, len(X)))]
    dist = np.full(len(X), np.inf, dtype=np.float32)
    for _ in range(1, k):
        d = np.linalg.norm(X - X[idx[-1]], axis=1)
        dist = np.minimum(dist, d)
        idx.append(int(np.argmax(dist)))
    return np.array(idx)

sel = farthest_point_sampling(E, BATCH)
round0 = short.iloc[sel].copy()

# (optional) quick sanity prints
print("Round-0 mean zero-shot:", round0["esm1v_zero_shot"].mean())
print("Avg Hamming to WT:", (round0["mutated_sequence"]
                             .apply(lambda s: sum(a!=b for a,b in zip(s, round0.iloc[0]['mutated_sequence'])))
                             .mean()))

round0.to_csv("round0_batch.csv", index=False)
print(f"Saved round0_batch.csv with {len(round0)} sequences.")



Round-0 mean zero-shot: 2.029902537663778
Avg Hamming to WT: 3.90625
Saved round0_batch.csv with 96 sequences.
