In [None]:
def build_pwm(seqs, length=9):
    """Build a column-normalised position weight matrix (PWM) from peptides.

    Only sequences of exactly ``length`` residues made up entirely of the 20
    canonical amino acids are counted; any other sequence is skipped whole.

    Parameters
    ----------
    seqs : iterable of str
        Peptide sequences in one-letter amino-acid code.
    length : int, default 9
        Peptide length to profile; sequences of any other length are ignored.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(20, length)``. Row ``k`` corresponds to residue
        ``"ACDEFGHIKLMNPQRSTVWY"[k]``; each column sums to 1. If no sequence
        qualifies, every column is 0/0 and the result is all-NaN (numpy emits
        a RuntimeWarning).
    """
    aa = "ACDEFGHIKLMNPQRSTVWY"
    aa_index = {res: k for k, res in enumerate(aa)}   # residue -> row, O(1) lookup
    pwm = np.zeros((20, length))
    for s in seqs:
        if len(s) != length:
            continue
        # Validate the whole peptide *before* counting anything: breaking out
        # mid-sequence would leave the residues preceding a non-canonical one
        # already counted, biasing the early columns.
        if any(res not in aa_index for res in s):
            continue
        for i, res in enumerate(s):
            pwm[aa_index[res], i] += 1
    pwm = pwm / pwm.sum(axis=0, keepdims=True)          # column normalise
    return pwm

def max_identity(query, ref_set):
    """Best fractional sequence identity of ``query`` to any peptide in ``ref_set``.

    * If ``query`` itself is in ``ref_set`` the answer is 1.0.
    * Else, if any reference has the same length as ``query``, return the best
      ungapped position-by-position match fraction over those same-length
      references (references of other lengths are not considered).
    * Otherwise locally align ``query`` to every reference with Biopython
      ``pairwise2`` and normalise each score by the longer sequence length.

    Parameters
    ----------
    query : str
        Peptide to score. An empty query returns 0.0.
    ref_set : collection of str
        Reference peptides. It is iterated more than once, so pass a set or
        list rather than a one-shot iterator.

    Returns
    -------
    float
        Identity in [0, 1]; 0 when no local alignment is found.
    """
    if not query:                        # identity is undefined; avoid 0/0 below
        return 0.0
    if query in ref_set:                 # exact hit – nothing can score higher
        return 1.0
    same_len = [ref for ref in ref_set if len(ref) == len(query)]
    if same_len:                                        # fast path – same length
        # Compare only against equal-length refs: zip() silently truncates to the
        # shorter sequence, so a longer ref would otherwise be scored on its prefix.
        return max(sum(q == r for q, r in zip(query, ref)) / len(query)
                   for ref in same_len)
    # Otherwise local alignment (match=1, mismatch=0, gap open/extend=-10).
    # A gap costs as much as 10 matches, so for short peptides the
    # alignments are effectively ungapped.
    best = 0
    for ref in ref_set:
        aln = pairwise2.align.localms(query, ref, 1, 0, -10, -10, one_alignment_only=True)
        if aln:
            _, _, score, _, _ = aln[0]
            best = max(best, score / max(len(query), len(ref)))
    return best

# 1. split sets
# NOTE(review): despite its name, `testing_set` is `training_data`; the print below
# labels it the "measured" reference set that the generated (unmeasured) peptides
# are compared against — confirm this is intended.
testing_set  = training_data
generated_set = final_df[final_df['measured'].isna()][['sequence', 'HLA']]

testing_pep   = testing_set['sequence'].tolist()
generated_pep = generated_set['sequence'].tolist()

# 2. Jensen–Shannon divergence on PWMs (length 9 by default)
L = 9
pwm_test = build_pwm(testing_pep, length=L)
pwm_gen  = build_pwm(generated_pep, length=L)

# Flatten 20×L → 20L vector and add a small ε to avoid log(0). jensenshannon
# renormalises the vector; since every PWM column sums to 1, each position gets
# weight 1/L and the result is (≈) the mean of the per-position divergences.
# scipy's jensenshannon returns the JS *distance* (square root of the divergence),
# so square it to report the divergence in bits as the message below says.
eps  = 1e-9
jsd  = jensenshannon(pwm_test.flatten()+eps, pwm_gen.flatten()+eps, base=2) ** 2
print(f"JS-divergence between measured and RFdiffusion PWMs (length {L}): {jsd:.4f} bits")

# 3. max identity distribution (might take a few minutes for large sets)
ref_set = set(testing_pep)                       # deduplicated reference peptides
identity_scores = [max_identity(p, ref_set) for p in generated_pep]

# save for plotting (e.g., seaborn.histplot). Set ID_CSV_PATH to write the scores, e.g.
# '/global/scratch/users/sergiomar10/ESMCBA/analysis/max_identity_generated_vs_training.csv'
ID_CSV_PATH = None
id_series = pd.Series(identity_scores, name='max_identity_to_training')
if ID_CSV_PATH is not None:
    id_series.to_csv(ID_CSV_PATH, index=False)
id_series.describe()

In [None]:
import random, string, pandas as pd
aa = list("ACDEFGHIKLMNPQRSTVWY")

def mutate_keep_anchors(native_pep, n_mutants=50, anchors=(1,8), rng=None, alphabet=None):
    """Generate random variants of a peptide that keep its anchor residues.

    Every non-anchor position is redrawn uniformly from ``alphabet``. A position
    may therefore keep its native residue by chance, and a mutant may even equal
    ``native_pep``.

    Parameters
    ----------
    native_pep : str
        Peptide to mutate.
    n_mutants : int, default 50
        Number of variants to return.
    anchors : iterable of int, default (1, 8)
        0-indexed positions to keep (P2=1, PΩ=8 for 9-mers). Negative indices
        count from the C-terminus, so ``(1, -1)`` keeps P2 and PΩ for any length.
        Indices outside the peptide are ignored.
    rng : object with a ``choice`` method, optional
        Source of randomness, e.g. ``random.Random(seed)`` for reproducible
        mutants. Defaults to the global ``random`` module.
    alphabet : sequence of str, optional
        Residues to draw from; defaults to the module-level ``aa`` list.

    Returns
    -------
    list of str
        ``n_mutants`` peptides, each the same length as ``native_pep``.
    """
    rng = random if rng is None else rng
    alphabet = aa if alphabet is None else alphabet
    n = len(native_pep)
    # Map negative anchors to their 0-indexed position; out-of-range ones never match.
    keep = {a + n if a < 0 else a for a in anchors}
    mutants = []
    for _ in range(n_mutants):
        s = list(native_pep)
        for i in range(n):
            if i in keep:              # keep anchor residue
                continue
            s[i] = rng.choice(alphabet)
        mutants.append("".join(s))
    return mutants

# example — seed the global `random` module so the mutant set is reproducible
# under Restart & Run All
MUTANT_SEED = 0
random.seed(MUTANT_SEED)
native_peptide = "LLFGYPVYV"           # A*02:01 binder in 1AKJ
mutants = mutate_keep_anchors(native_peptide)


In [None]:
from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer
import torch, numpy as np

# ESM-2 is a bidirectional *masked* protein language model, not an autoregressive
# one. Sampling it token-by-token from the last position's logits uses outputs the
# model was never trained to produce, and transformers may not map EsmConfig to a
# causal-LM class at all. Load the masked-LM head and sample by iterative unmasking.
ESM_CHECKPOINT = "facebook/esm2_t33_650M_UR50D"
tok = AutoTokenizer.from_pretrained(ESM_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(ESM_CHECKPOINT).eval()

CANONICAL_AA = "ACDEFGHIKLMNPQRSTVWY"


def sample_9mers(n_samples=5000, temperature=1.0, length=9, batch_size=256,
                 generator=None):
    """Sample peptides from ESM-2 by left-to-right iterative unmasking.

    Each peptide starts as ``<cls> + <mask>*length + <eos>``. Positions are filled
    one at a time, left to right. Each residue is drawn from the masked-LM
    distribution at that position, conditioned on the residues already placed,
    restricted to the 20 canonical amino acids and scaled by ``temperature``.
    This is a heuristic sampler for a masked LM, not an exact draw from a joint
    sequence distribution.

    Parameters
    ----------
    n_samples : int, default 5000
        Number of peptides to return (every attempt yields one peptide).
    temperature : float, default 1.0
        Softmax temperature; must be > 0.
    length : int, default 9
        Peptide length.
    batch_size : int, default 256
        Number of peptides sampled in parallel per forward pass.
    generator : torch.Generator, optional
        For reproducible sampling; must live on the same device as ``model``.

    Returns
    -------
    list of str
        ``n_samples`` peptides of exactly ``length`` canonical residues.

    Raises
    ------
    ValueError
        If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    device = next(model.parameters()).device
    # Vocabulary ids of the 20 canonical residues. Restricting the logits to them
    # keeps special tokens and ambiguous codes (X, B, Z, U, O, ...) out of samples.
    aa_ids = torch.tensor(tok.convert_tokens_to_ids(list(CANONICAL_AA)), device=device)
    samples = []
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            n = min(batch_size, n_samples - start)
            ids = torch.full((n, length + 2), tok.mask_token_id,
                             dtype=torch.long, device=device)
            ids[:, 0] = tok.cls_token_id
            ids[:, -1] = tok.eos_token_id
            for pos in range(1, length + 1):
                logits = model(input_ids=ids).logits[:, pos, aa_ids] / temperature
                picks = torch.multinomial(torch.softmax(logits, dim=-1), 1,
                                          generator=generator)
                ids[:, pos] = aa_ids[picks.squeeze(-1)]
            # Join tokens directly; tok.decode may insert spaces between residues.
            for row in ids[:, 1:-1].tolist():
                samples.append("".join(tok.convert_ids_to_tokens(row)))
    return samples
