# In [None]:
import torch, numpy as np, pandas as pd, uuid, pickle
from pathlib import Path
from tqdm import tqdm
from sklearn.decomposition import IncrementalPCA
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
# 1. Load model
# Prefer the GPU when one is visible; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# ESM-C 300M checkpoint, moved to the chosen device and switched to eval mode.
model = ESMC.from_pretrained("esmc_300m")
model = model.to(device).eval()

# Ask model.logits() to also return embeddings alongside the logits.
config = LogitsConfig(return_embeddings=True)
print("ESM-C ready")
# 2. Helper: return mean-pooled (960-D) AND per-residue tensor (L×960)

@torch.inference_mode()
def embed_seq(seq: str):
    """Embed one protein sequence with the module-level ESM-C ``model``.

    Parameters
    ----------
    seq : str
        Amino-acid sequence (already normalised/padded by the caller).

    Returns
    -------
    pooled : np.ndarray
        Mean of ``full`` over the residue axis, shape ``(D,)``.
    full : np.ndarray
        float32 per-position embeddings, shape ``(L, D)``.

    Raises
    ------
    RuntimeError
        If the model returned no embeddings (``config`` must be built with
        ``return_embeddings=True``).
    ValueError
        If the embeddings cannot be reduced to a 2-D ``(tokens, D)`` matrix.
    """
    prot = ESMProtein(sequence=seq)
    toks = model.encode(prot).to(device)
    out = model.logits(toks, config)
    if out.embeddings is None:
        raise RuntimeError(
            "model.logits returned no embeddings; "
            "use LogitsConfig(return_embeddings=True)"
        )

    # numpy can read neither CUDA tensors nor bfloat16, so bring the
    # embeddings to host memory as float32 before converting.
    emb = torch.as_tensor(out.embeddings).to(device="cpu", dtype=torch.float32)

    # A single input comes back with a size-1 batch axis; drop it so that
    # pooling below averages over positions, not over the batch.
    if emb.ndim == 3 and emb.shape[0] == 1:
        emb = emb[0]
    if emb.ndim != 2:
        raise ValueError(
            f"expected (tokens, dim) embeddings, got shape {tuple(emb.shape)}"
        )

    # NOTE(review): the tokenizer presumably wraps the sequence in BOS/EOS
    # tokens; when the row count says so, keep only the residue rows so that
    # `full` is (L, D) and the pool covers residues only.
    if emb.shape[0] == len(seq) + 2:
        emb = emb[1:-1]

    full = emb.numpy()
    pooled = full.mean(axis=0)
    return pooled, full
# 3. Read + normalise + pad sequences
# Column holding the raw sequences, and the input file to read.
# NOTE(review): presumably a JSON array of records with a "sequence" field — confirm.
SEQ_COL     = "sequence"
INPUT_JSON  = "../data/raw/cd98_test.json"

df = pd.read_json(Path(INPUT_JSON))

def normalise(x):
    """Coerce one raw sequence cell into a plain string.

    * list / tuple of string chunks -> chunks concatenated with no separator
    * dict with a "sequence" key    -> ``str`` of that value
    * anything else                 -> ``str(x)``
    """
    if isinstance(x, (list, tuple)):
        text = "".join(x)
    elif isinstance(x, dict) and "sequence" in x:
        text = str(x["sequence"])
    else:
        text = str(x)
    return text

# Coerce every cell to a plain string, then right-pad each one with PAD_CHAR
# so all sequences share a single length.
# NOTE(review): the pad characters are passed to the model along with the real
# residues, so they may influence the residue embeddings and are included in
# the mean pool — confirm this is intended.
PAD_CHAR = "X"
df[SEQ_COL] = df[SEQ_COL].map(normalise)
max_len = df[SEQ_COL].str.len().max()
df[SEQ_COL] = [s.ljust(max_len, PAD_CHAR) for s in df[SEQ_COL]]
print(f"All sequences padded to length {max_len}")
# 4. Embed every sequence: keep the pooled vector, the saved tensor's path,
#    and a flattened copy of the tensor for PCA.
SAVE_DIR = Path("../data/processed/full_tensors")
SAVE_DIR.mkdir(parents=True, exist_ok=True)

pooled_vecs = []
tensor_paths = []
flat_vecs = []

for seq in tqdm(df[SEQ_COL], desc="Embedding"):
    pooled, tensor = embed_seq(seq)

    # Persist the full per-position matrix under a random hex file name.
    fname = SAVE_DIR / (uuid.uuid4().hex + ".npy")
    np.save(fname, tensor)

    tensor_paths.append(str(fname))
    pooled_vecs.append(pooled)
    flat_vecs.append(tensor.flatten())  # 1-D copy, consumed by the PCA step

# 5. Fit Incremental PCA → n_components dims and transform
batch_size = 16
n_components = 16
pca = IncrementalPCA(n_components=n_components, batch_size=batch_size)

n_samples = len(flat_vecs)
if n_samples < n_components:
    raise ValueError(
        f"IncrementalPCA needs at least n_components={n_components} samples, "
        f"got {n_samples}"
    )

# partial_fit rejects any batch with fewer rows than n_components, so use
# batches of at least that size and fold a short trailing batch into the one
# before it (otherwise e.g. N=20 with batch 16 leaves a 4-row batch that fails).
step = max(batch_size, n_components)
batch_starts = list(range(0, n_samples, step))
if n_samples - batch_starts[-1] < n_components:
    batch_starts.pop()
batch_bounds = list(zip(batch_starts, batch_starts[1:] + [n_samples]))

for lo, hi in batch_bounds:
    pca.partial_fit(np.stack(flat_vecs[lo:hi]))

# Transform batch-wise so we never hold a second full-size stacked copy of
# every flattened tensor.
pca_vecs = np.vstack(
    [pca.transform(np.stack(flat_vecs[lo:hi])) for lo, hi in batch_bounds]
)  # (N, n_components)

# persist PCA model (pickle: only ever reload this file from a trusted source)
pca_path = f"../data/processed/pca_{n_components}.pkl"
with open(pca_path, "wb") as f:
    pickle.dump(pca, f)
print(f"Saved PCA model to {pca_path}")

# 6. Attach the new columns and write the augmented table as JSON records.
# NOTE(review): the sequence column still holds the PAD_CHAR-padded strings
# at this point, so the written JSON contains padded sequences — confirm.
pca_col = f"embedding_pca{n_components}"
df["embedding_mean"] = [vec.tolist() for vec in pooled_vecs]
df[pca_col] = [vec.tolist() for vec in pca_vecs]
df["tensor_path"] = tensor_paths

OUTPUT_JSON = "../data/processed/cd98_test_embeds.json"
df.to_json(OUTPUT_JSON, orient="records", force_ascii=False)
print("✓ wrote", OUTPUT_JSON)

# Quick spot-check on the first record.
sample = df.iloc[0]
print(
    "Mean-pooled len:", len(sample["embedding_mean"]),
    "| PCA len:", len(sample[pca_col]),
    "| tensor_path:", sample["tensor_path"],
)
