In [3]:
import torch, math, re
import numpy as np, pandas as pd
from transformers import AutoTokenizer, AutoModelForMaskedLM

# -------------------- Helper -------------------- #
def get_token_index_from_sequence(input_ids_cpu, special_ids, residue_pos):
    """
    Map 1-based residue_pos → token index by skipping special tokens.
    input_ids_cpu: 1D torch.LongTensor on CPU.
    special_ids: set of token IDs treated as 'special' (CLS, SEP, PAD, etc).
    """
    seq_counter = 0
    for idx, tid in enumerate(input_ids_cpu.tolist()):
        if tid not in special_ids:
            seq_counter += 1
            if seq_counter == residue_pos:
                return idx
    raise ValueError(f"Couldn't map residue {residue_pos} to a token index")

# -------------------- HeatmapGenerator -------------------- #
class HeatmapGenerator:
    def __init__(self, model, tokenizer):
        # pick device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model  = model.to(self.device)
        self.tokenizer = tokenizer
        self.special_ids = set(tokenizer.all_special_ids)
        self.mask_id = tokenizer.mask_token_id
        self.model.eval()

    def _get_logits(self, ids, mask=None):
        # ids, mask are already on self.device
        with torch.no_grad():
            if mask is not None:
                return self.model(ids, attention_mask=mask).logits
            else:
                return self.model(ids).logits

    def llrData(self, sequence, start_pos=1, end_pos=None):
        if end_pos is None:
            end_pos = len(sequence)
        aas = list("ACDEFGHIKLMNPQRSTVWY")
        heatmap = np.zeros((20, end_pos - start_pos + 1), dtype=float)

        # 1) Tokenize on CPU
        enc_cpu = self.tokenizer(sequence, return_tensors="pt")
        input_ids_cpu = enc_cpu["input_ids"][0]                       # 1D CPU
        attention_cpu = enc_cpu.get("attention_mask", None)          # 2D CPU or None

        # 2) Move to device
        input_ids = enc_cpu["input_ids"].to(self.device)             # [1, L] GPU/CPU
        attention_mask = (attention_cpu.to(self.device)
                          if attention_cpu is not None else None)

        for col, pos in enumerate(range(start_pos, end_pos+1)):
            # a) map residue → token index
            tok_i = get_token_index_from_sequence(input_ids_cpu,
                                                  self.special_ids,
                                                  pos)
            # b) wild‑type token ID BEFORE masking
            wt_id = input_ids[0, tok_i].item()

            # c) mask and forward
            masked = input_ids.clone()
            masked[0, tok_i] = self.mask_id

            logits = self._get_logits(masked, attention_mask)
            log_probs = torch.log_softmax(logits[0, tok_i], dim=-1)

            lp_wt = log_probs[wt_id].item()
            for i, aa in enumerate(aas):
                aa_id = self.tokenizer.convert_tokens_to_ids(aa)
                hm_val = (log_probs[aa_id].item() - lp_wt
                          if aa_id is not None and aa_id < log_probs.size(0)
                          else 0.0)
                heatmap[i, col] = hm_val

        cols = [str(p) for p in range(start_pos, end_pos+1)]
        return pd.DataFrame(heatmap, index=aas, columns=cols)

# -------------------- Single‑mutation Function -------------------- #
def compute_mutation_llr(model, tokenizer, sequence, mutation, device):
    wt, pos, mut = re.match(r"([A-Z])(\d+)([A-Z])", mutation).groups()
    pos = int(pos)
    if sequence[pos-1] != wt:
        print(f"[!] expected '{wt}' at {pos}, found '{sequence[pos-1]}'")

    # 1) Tokenize on CPU → for mapping
    enc_cpu = tokenizer(sequence, return_tensors="pt")
    input_ids_cpu = enc_cpu["input_ids"][0]
    special_ids = set(tokenizer.all_special_ids)
    tok_i = get_token_index_from_sequence(input_ids_cpu, special_ids, pos)

    # 2) Move inputs to device
    enc = {k: v.to(device) for k, v in enc_cpu.items()}
    input_ids = enc["input_ids"]
    attention_mask = enc.get("attention_mask", None)

    # 3) Mask & forward
    masked = input_ids.clone()
    masked[0, tok_i] = tokenizer.mask_token_id

    model.to(device).eval()
    with torch.no_grad():
        out = (model(masked, attention_mask=attention_mask)
               if attention_mask is not None else model(masked))
        log_probs = torch.log_softmax(out.logits[0, tok_i], dim=-1)

    wt_id  = tokenizer.convert_tokens_to_ids(wt)
    mut_id = tokenizer.convert_tokens_to_ids(mut)
    lp_wt  = log_probs[wt_id].item()
    lp_mut = log_probs[mut_id].item()
    return lp_mut - lp_wt, lp_wt, lp_mut

# -------------------- Comparison Runner -------------------- #
def compare_llr_methods(model, tokenizer, sequence, mutations):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    generator = HeatmapGenerator(model, tokenizer)

    print("="*60)
    for mut in mutations:
        llr_fn, lp_wt, lp_mut = compute_mutation_llr(
            model, tokenizer, sequence, mut, device
        )
        pos = int(re.match(r".(\d+).", mut).group(1))
        df = generator.llrData(sequence, start_pos=pos, end_pos=pos)
        llr_hm = df.loc[mut[-1], str(pos)]

        print(f"Mutation {mut}:")
        print(f"  [Function]     LLR: {llr_fn:.6f}, wt: {lp_wt:.6f}, mut: {lp_mut:.6f}")
        print(f"  [HeatmapData]  LLR: {llr_hm:.6f}")
        print(f"  ΔLLR = {abs(llr_fn - llr_hm):.6f}")
        print("-"*60)

# ESM-2

In [6]:
# Import packages
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
from peft import PeftModel, PeftConfig
from autoamp.evolveFinetune import *
import torch
from tqdm import tqdm
import math
from Bio import SeqIO 
import json
import warnings
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from transformers import PreTrainedTokenizer

# Example inputs
base_model_name = "facebook/esm2_t30_150M_UR50D" 
tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=True)

adapter_checkpoint = "/home/sdowell/scratch/Thesis/ADP1/runs/esm2_dgoa_finetune_1/checkpoint-3000"

# Load models
model_pretrained = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_with_adapter = AutoModelForMaskedLM.from_pretrained(base_model_name)
model_finetuned = PeftModel.from_pretrained(model_with_adapter, adapter_checkpoint)

DgoA_seq = (
    'MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKA'
    'LIGAGTVLKPEQVDALARMGCQLIVTPNIHSEVIRRAVGYGMTVCPGCATATEAFTALEA'
    'GAQALKIFPSSAFGPQYIKALKAVLPSDIAVFAVGGVTPENLAQWIDAGCAGAGLGSDLY'
    'RAGQSVERTAQQAAAFVKAYREAVQ'
)

# List of mutations provided as strings
mutations = ['F33I','D58N','A75V','Q72H','V85A','V154F','Y180F']

print("PRETRAINED ESM2")
compare_llr_methods(model_pretrained, tokenizer, DgoA_seq, mutations)
print("FINETUNED ESM2")
compare_llr_methods(model_finetuned, tokenizer, DgoA_seq, mutations)

PRETRAINED ESM2
Mutation F33I:
  [Function]     LLR: 1.238841, wt: -2.160510, mut: -0.921669
  [HeatmapData]  LLR: 1.238841
  ΔLLR = 0.000000
------------------------------------------------------------
Mutation D58N:
  [Function]     LLR: -3.135317, wt: -0.580529, mut: -3.715846
  [HeatmapData]  LLR: -3.135317
  ΔLLR = 0.000000
------------------------------------------------------------
Mutation A75V:
  [Function]     LLR: -1.263491, wt: -1.756014, mut: -3.019505
  [HeatmapData]  LLR: -1.263491
  ΔLLR = 0.000000
------------------------------------------------------------
Mutation Q72H:
  [Function]     LLR: -1.220701, wt: -0.489650, mut: -1.710351
  [HeatmapData]  LLR: -1.220701
  ΔLLR = 0.000000
------------------------------------------------------------
Mutation V85A:
  [Function]     LLR: -5.180855, wt: -0.158098, mut: -5.338953
  [HeatmapData]  LLR: -5.180855
  ΔLLR = 0.000000
------------------------------------------------------------
Mutation V154F:
  [Function]     LLR: -9.

# ProGen2

In [14]:
# Import packages
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM
from peft import PeftModel, PeftConfig
from autoamp.evolveFinetune import *
import torch
from tqdm import tqdm
import math
from Bio import SeqIO 
import json
import warnings
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from transformers import PreTrainedTokenizer

# Example inputs
base_model_name = "hugohrban/progen2-small" 
tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=True)

adapter_checkpoint = "/home/sdowell/scratch/Thesis/ADP1/runs/progen2_dgoa_finetune_1/checkpoint-3000"

# Load models
model_pretrained = AutoModelForCausalLM.from_pretrained(base_model_name)
model_with_adapter = AutoModelForCausalLM.from_pretrained(base_model_name)
model_finetuned = PeftModel.from_pretrained(model_with_adapter, adapter_checkpoint)

DgoA_seq = (
    'MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKA'
    'LIGAGTVLKPEQVDALARMGCQLIVTPNIHSEVIRRAVGYGMTVCPGCATATEAFTALEA'
    'GAQALKIFPSSAFGPQYIKALKAVLPSDIAVFAVGGVTPENLAQWIDAGCAGAGLGSDLY'
    'RAGQSVERTAQQAAAFVKAYREAVQ'
)

# List of mutations provided as strings
mutations = ['F33I','D58N','A75V','Q72H','V85A','V154F','Y180F']

print("PRETRAINED PROGEN2")
compare_llr_methods(model_pretrained, tokenizer, DgoA_seq, mutations)
print("FINETUNED PROGEN2")
compare_llr_methods(model_finetuned, tokenizer, DgoA_seq, mutations)

The repository for hugohrban/progen2-small contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/hugohrban/progen2-small.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y
The repository for hugohrban/progen2-small contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/hugohrban/progen2-small.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y


ProGenForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


The repository for hugohrban/progen2-small contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/hugohrban/progen2-small.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y


PRETRAINED PROGEN2
Using model type: ESM on cuda


TypeError: can't assign a NoneType to a torch.cuda.LongTensor

In [9]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from transformers import PreTrainedTokenizer

class HeatmapGenerator:
    def __init__(self, model=None, tokenizer=None, model_name=None, is_progen=False):
        """
        Initialize the HeatmapGenerator with either:
        1. A pre-loaded model and tokenizer
        2. A model_name to load from HuggingFace
        
        Args:
            model: A pre-loaded model instance
            tokenizer: A pre-loaded tokenizer
            model_name: HuggingFace model name (used only if model and tokenizer are None)
            is_progen: Set to True when using ProGen models that don't support MLM
        """
        self.is_progen = is_progen
        
        if model is not None and tokenizer is not None:
            self.model = model
            self.tokenizer = tokenizer
        elif model_name is not None:
            if "progen" in model_name.lower():
                self.is_progen = True
                from transformers import AutoModelForCausalLM, AutoTokenizer
                self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
                self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
            else:
                from transformers import EsmForMaskedLM, EsmTokenizer
                self.tokenizer = EsmTokenizer.from_pretrained(model_name)
                self.model = EsmForMaskedLM.from_pretrained(model_name)
            if hasattr(self.model, 'eval'):
                self.model.eval()
        else:
            raise ValueError("Either provide a model and tokenizer, or a model_name")

        # device & specials
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.special_ids = set(self.tokenizer.all_special_ids)
        self.mask_token_id = self.tokenizer.mask_token_id

        print(f"Using model type: {'ProGen' if self.is_progen else 'ESM'} on {self.device}")

    def _get_logits(self, input_ids, attention_mask=None):
        with torch.no_grad():
            input_ids = input_ids.to(self.device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.device)
                out = self.model(input_ids, attention_mask=attention_mask)
            else:
                out = self.model(input_ids)
            # extract logits regardless of output type
            return getattr(out, "logits", out[0] if isinstance(out, tuple) else out)

    @staticmethod
    def _map_residue_to_token(input_ids, special_ids, residue_pos):
        """
        Given a 1D tensor of input_ids on CPU and set of special_ids,
        return the token index corresponding to the residue_pos (1-based),
        skipping any special tokens.
        """
        seq_count = 0
        for idx, tid in enumerate(input_ids.tolist()):
            if tid not in special_ids:
                seq_count += 1
                if seq_count == residue_pos:
                    return idx
        raise ValueError(f"Could not map residue {residue_pos} to a token index")

    def llrData(self, protein_sequence, start_pos=1, end_pos=None):
        seq_len = len(protein_sequence)
        if end_pos is None or end_pos > seq_len:
            end_pos = seq_len

        aas = list("ACDEFGHIKLMNPQRSTVWY")
        heatmap = np.zeros((20, end_pos - start_pos + 1), dtype=float)

        # tokenize once if MLM, or nothing if CLM
        if not self.is_progen:
            # Masked LM: we need full sequence tokenization
            enc = self.tokenizer(protein_sequence, return_tensors="pt")
            input_ids_cpu = enc["input_ids"][0]
            attention_cpu = enc.get("attention_mask", None)
            input_ids = enc["input_ids"].to(self.device)
            attention_mask = attention_cpu.to(self.device) if attention_cpu is not None else None
        else:
            # For ProGen, we will retokenize prefixes each time
            special = None  # unused

        for col, pos in enumerate(range(start_pos, end_pos+1)):
            try:
                if not self.is_progen:
                    # ESM path (masked LM)
                    tok_idx = self._map_residue_to_token(input_ids_cpu,
                                                         self.special_ids, pos)
                    wt_id   = input_ids[0, tok_idx].item()
                    masked  = input_ids.clone()
                    masked[0, tok_idx] = self.mask_token_id
                    logits  = self._get_logits(masked, attention_mask)
                    log_probs = torch.log_softmax(logits[0, tok_idx], dim=-1)
                else:
                    # ProGen path (causal LM)
                    # Build prefix up to pos-1
                    prefix = protein_sequence[: pos-1] if pos>1 else ""
                    enc_prefix = self.tokenizer(prefix, return_tensors="pt")
                    ids_cpu = enc_prefix["input_ids"][0]
                    # map pos-> next-token index = len(ids_cpu)-1
                    wt_res = protein_sequence[pos-1]
                    # wildtype token ID
                    wt_id = self.tokenizer.convert_tokens_to_ids(wt_res)
                    # move to device
                    ids = enc_prefix["input_ids"].to(self.device)
                    mask = enc_prefix.get("attention_mask", None)
                    logits_full = self._get_logits(ids, mask)
                    # next-token logits are at last position
                    log_probs = torch.log_softmax(logits_full[0, -1], dim=-1)

                log_wt = log_probs[wt_id].item()
                for i, aa in enumerate(aas):
                    aa_id = self.tokenizer.convert_tokens_to_ids(aa)
                    heatmap[i, col] = (
                        log_probs[aa_id].item() - log_wt
                        if aa_id is not None and aa_id < log_probs.size(0)
                        else 0.0
                    )

            except Exception as e:
                print(f"Error at pos {pos}: {e}")
                heatmap[:, col] = 0.0

        cols = [str(p) for p in range(start_pos, end_pos+1)]
        return pd.DataFrame(heatmap, index=aas, columns=cols)


    def generate_heatmap(self, protein_sequence, start_pos=1, end_pos=None,
                         figsize=(10, 5), cmap="viridis", tick_interval=5, title=None):
        """
        Plots the heatmap of log_prob_mutant - log_prob_wildtype.
        """
        df = self.llrData(protein_sequence, start_pos, end_pos)
        fig, ax = plt.subplots(figsize=figsize)
        cax = ax.imshow(df.values, cmap=cmap, aspect="auto")

        ax.set_xticks(range(df.shape[1]))
        ax.set_xticklabels(
            [col if i % tick_interval == 0 else "" for i, col in enumerate(df.columns)],
            rotation=90, fontsize=8
        )
        ax.set_yticks(range(len(df.index)))
        ax.set_yticklabels(df.index)

        ax.set_xlabel("Position")
        ax.set_ylabel("Amino Acid")
        model_type = "ProGen" if self.is_progen else "ESM"
        ax.set_title(title or f"LLR Heatmap ({model_type})")
        plt.colorbar(cax, ax=ax, label="LLR (mutant vs WT)")
        plt.tight_layout()
        return fig

    def save_heatmap_data(self, protein_sequence, filename, start_pos=1, end_pos=None):
        """
        Compute the LLR data and save it to CSV.
        Returns the DataFrame for further use.
        """
        df = self.llrData(protein_sequence, start_pos, end_pos)
        df.to_csv(filename)
        print(f"Saved LLR data to {filename}")
        return df


In [10]:
pwd

'/home/sdowell/scratch/Thesis/ADP1/results'

In [12]:
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM, 
    AutoModelForCausalLM
)
from peft import PeftModel

# ─── Your existing imports & data ─────────────────────────────────────
# (you already ran these)
from autoamp.evolveFinetune import *  # if needed
from Bio import SeqIO
import numpy as np, pandas as pd

DgoA_seq = (
    'MQWQTKLPLIAILRGITPDEALAHVGAVIDAGFDAVEIPLNSPQWEQSIPAIVDAYGDKA'
    'LIGAGTVLKPEQVDALARMGCQLIVTPNIHSEVIRRAVGYGMTVCPGCATATEAFTALEA'
    'GAQALKIFPSSAFGPQYIKALKAVLPSDIAVFAVGGVTPENLAQWIDAGCAGAGLGSDLY'
    'RAGQSVERTAQQAAAFVKAYREAVQ'
)

# List of mutations provided as strings
mutations = ['F33I','D58N','A75V','Q72H','V85A','V154F','Y180F']

# ─── 1. ESM‑2 models (masked LM) ────────────────────────────────────────
esm_name = "facebook/esm2_t30_150M_UR50D"
esm_tok  = AutoTokenizer.from_pretrained(esm_name, use_fast=True)

# pretrained
esm_pre   = AutoModelForMaskedLM.from_pretrained(esm_name)
esm_pre_gen = HeatmapGenerator(model=esm_pre, tokenizer=esm_tok)

# finetuned via PEFT adapter
esm_base   = AutoModelForMaskedLM.from_pretrained(esm_name)
adapter_checkpoint = "/home/sdowell/scratch/Thesis/ADP1/runs/esm2_dgoa_finetune_1/checkpoint-3000"
esm_ft     = PeftModel.from_pretrained(esm_base, adapter_checkpoint)
esm_ft_gen = HeatmapGenerator(model=esm_ft, tokenizer=esm_tok)

# ─── 2. ProGen2 models (causal LM) ─────────────────────────────────────
progen_name = "hugohrban/progen2-small"
progen_tok  = AutoTokenizer.from_pretrained(progen_name, trust_remote_code=True)

# pretrained ProGen2
progen_pre   = AutoModelForCausalLM.from_pretrained(progen_name, trust_remote_code=True)
progen_pre_gen = HeatmapGenerator(model=progen_pre, tokenizer=progen_tok, is_progen=True)

# finetuned ProGen2 via PEFT
progen_adapter_checkpoint = "/home/sdowell/scratch/Thesis/ADP1/runs/progen2_dgoa_finetune_1/checkpoint-3000"
progen_base = AutoModelForCausalLM.from_pretrained(progen_name, trust_remote_code=True)
progen_ft   = PeftModel.from_pretrained(progen_base, progen_adapter_checkpoint)
progen_ft_gen = HeatmapGenerator(model=progen_ft, tokenizer=progen_tok, is_progen=True)

gene_seq = {
    "DgoA": DgoA_seq
}

# ── Loop & save CSVs across the full sequence ─────────────────────────
for model_label, gen in [
    ("esm_pretrained",    esm_pre_gen),
    ("esm_finetuned",     esm_ft_gen),
    ("progen_pretrained", progen_pre_gen),
    ("progen_finetuned",  progen_ft_gen),
]:
    for gene, seq in gene_seq.items():
        filename = f"{gene}_{model_label}_full_sequence.csv"
        print(f"Saving full‐sequence LLRs to {filename} ...")
        # omit start_pos/end_pos → defaults to entire sequence
        df = gen.save_heatmap_data(
            protein_sequence=DgoA_seq,
            filename=filename
        )
        # (optional) inspect the head
        print(df.head(), "\n")


Using model type: ESM on cuda
Using model type: ESM on cuda
Using model type: ProGen on cuda
Using model type: ProGen on cuda
Saving full‐sequence LLRs to DgoA_esm_pretrained_full_sequence.csv ...
Saved LLR data to DgoA_esm_pretrained_full_sequence.csv
           1         2         3         4         5         6         7  \
A  -8.607591 -0.254883  0.598394 -1.159047 -0.223114 -1.143873  0.027223   
C -11.709294 -4.224277 -2.918569 -3.368530 -3.439480 -4.316178 -2.283886   
D -10.143125  0.022967 -0.416268 -1.278674  0.414902 -1.062482 -3.250323   
E  -9.222314 -0.473002  0.438081 -0.981727  0.850894 -1.277184 -2.456563   
F -10.871401 -2.358262 -0.152274 -0.042889 -2.167486 -4.575073 -1.300926   

          8         9         10  ...       196       197       198       199  \
A -4.770309 -2.947813  -4.447848  ... -2.200759 -1.534761  0.464405  0.000000   
C -7.264739 -3.379026  -4.942312  ... -3.540176 -2.611518 -4.847313 -5.613589   
D -5.378175 -8.959444 -10.487106  ... -7.156523