In [16]:
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoTokenizer, AutoModel
from transformers import BitsAndBytesConfig
from torch.nn import functional as F
import re

In [10]:
class Tea(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="tea",
    license="mit",
):
    """
    The embedding-based Alphabet (tea) model for converting input pLM embeddings into sequences.

    The network is a small head applied position-wise to the input
    representations: linear -> GELU -> dropout -> layer norm -> linear decoder.
    The decoder emits one logit per codebook entry; the arg-max entry is mapped
    to a character. Helper methods compute a normalised Shannon entropy over the
    output distribution and convert model outputs into character sequences.

    Args:
        representation_size (int): Dimensionality of input representations.
        hidden_size (int): Hidden size for first linear transformation.
        codebook_size (int): Number of unique tokens (characters) in the alphabet.
        dropout_prob (float): Dropout probability for regularization.
        ignore_token_ids (list[int]): Token ids to ignore when constructing
            sequences. ``None`` is treated as "ignore nothing".
    """

    def __init__(
        self,
        representation_size: int,
        hidden_size: int,
        codebook_size: int,
        dropout_prob: float = 0.1,
        ignore_token_ids: list[int] = [0, 1, 2],
    ):
        super().__init__()
        self.representation_size = representation_size
        self.hidden_size = hidden_size
        self.codebook_size = codebook_size
        # Copy the ids so an instance never aliases the shared default list
        # (or a list the caller may mutate later).
        self.ignore_token_ids = (
            list(ignore_token_ids) if ignore_token_ids is not None else []
        )

        self.dense = nn.Linear(representation_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.decoder = nn.Linear(hidden_size, codebook_size)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.eps = 1e-8  # Added inside log() for numerical stability

        # Output alphabet: the first `codebook_size` symbols of this 40-character
        # list. If codebook_size exceeds 40 the list is simply shorter than the
        # codebook, and `to_sequences` would fail on an out-of-range index.
        characters = list("ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwy")
        self.characters = characters[:self.codebook_size]

    def forward(self, x):
        """
        Map representations to codebook logits.

        Args:
            x: Tensor whose last dimension is ``representation_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``codebook_size`` (unnormalised logits).
        """
        x = self.dense(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.layer_norm(x)
        x = self.decoder(x)
        return x

    def compute_shannon_entropy(self, logits):
        """
        Compute the Shannon entropy of the softmax distribution over the last axis.

        The entropy is divided by ``log(codebook_size)``, so a uniform
        distribution over the codebook yields a value close to 1.

        Args:
            logits: Tensor of logits; the softmax is taken over the last dimension.

        Returns:
            Tensor of entropies with the last dimension of ``logits`` reduced away.
        """
        # Convert logits to probabilities
        probs = F.softmax(logits, dim=-1)
        max_entropy = torch.log(
            torch.tensor(self.codebook_size, dtype=probs.dtype, device=probs.device)
        )
        entropy = -torch.sum(probs * torch.log(probs + self.eps), dim=-1) / max_entropy
        return entropy

    def to_sequences(
        self,
        input_ids,
        embeddings,
        attention_mask=None,
        logits=None,
        return_avg_entropy=False,
        return_logits=False,
        return_residue_entropy=False,
    ):
        """
        Decode a batch of embeddings (or precomputed logits) into character sequences.

        Positions whose ``input_ids`` value is listed in ``self.ignore_token_ids``
        are dropped before decoding.

        Args:
            input_ids: Integer tensor of token ids aligned with the leading
                dimensions of the logits (assumed ``(batch, seq_len)`` - confirm
                against the tokenizer used by the caller).
            embeddings: Input representations passed to ``forward`` when
                ``logits`` is not given.
            attention_mask: Accepted for API compatibility; it is currently not
                used by this method.
            logits: Optional precomputed logits; when given, ``embeddings`` is
                not evaluated.
            return_avg_entropy: Include the mean normalised entropy per sequence.
            return_logits: Include the filtered logits per sequence.
            return_residue_entropy: Include the per-position entropies per sequence.

        Returns:
            A list of strings when no ``return_*`` flag is set; otherwise a dict
            with key ``"sequences"`` plus ``"avg_entropy"``, ``"residue_entropy"``
            and/or ``"logits"`` for each flag that was requested.
        """
        if logits is None:
            logits = self(embeddings)

        # Boolean mask that is True at every position that should be kept.
        keep_mask = torch.ones_like(input_ids, dtype=torch.bool)
        for token_id in self.ignore_token_ids:
            keep_mask &= (input_ids != token_id)

        # Most likely codebook entry at each position.
        predicted_indices = torch.argmax(logits, dim=-1)

        sequences = []
        logits_list = []
        residue_entropy_list = []
        avg_entropy_list = []

        # Iterate over all sequences in the batch
        for seq_idx, seq_logits, mask in zip(predicted_indices, logits, keep_mask):
            filtered_logits = seq_logits[mask]
            # One bulk transfer via tolist() instead of one .item() call per residue.
            kept_indices = seq_idx[mask].tolist()
            sequences.append("".join(self.characters[idx] for idx in kept_indices))
            logits_list.append(filtered_logits)
            entropies = self.compute_shannon_entropy(filtered_logits)
            residue_entropy_list.append(entropies)
            avg_entropy_list.append(entropies.mean().item())

        # Plain list unless any extra output was requested. The original
        # condition referenced an undefined name (`return_entropy`) and raised
        # NameError whenever the first two flags were False.
        if not (return_avg_entropy or return_logits or return_residue_entropy):
            return sequences

        result = {"sequences": sequences}
        if return_avg_entropy:
            result["avg_entropy"] = avg_entropy_list
        if return_residue_entropy:
            result["residue_entropy"] = residue_entropy_list
        if return_logits:
            result["logits"] = logits_list
        return result


In [11]:
# Hyper-parameters of the Tea head used for this checkpoint.
config = dict(
    representation_size=1280,
    hidden_size=1280,
    codebook_size=20,
    dropout_prob=0.1,
    ignore_token_ids=[0, 1, 2],
)
# Build the model and switch it to inference mode (eval() returns the module itself).
model = Tea(**config).eval()

In [12]:
# Use CUDA when it is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Place the model on the selected device before its weights are loaded
model = model.to(device=device)

Using device: cpu


In [13]:
# Checkpoint location and the key prefix under which the head's weights are stored.
CHECKPOINT_PATH = "/scicore/home/schwede/durair0000/projects/alphabeta/AlphaBeta/data/final_ablations/entropy_01/all/checkpoints/entropy_01-all-epoch=9.ckpt"
LM_HEAD_PREFIX = "model.lm_head."

# NOTE(review): weights_only=False performs full unpickling, which can execute
# arbitrary code - only load checkpoints from a trusted source. The flag is
# spelled out because its default differs between PyTorch releases; switch to
# weights_only=True if this checkpoint contains nothing but tensors and
# primitive containers.
weights = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
state_dict = weights["state_dict"]

# Strip the prefix so the keys match Tea's parameter names. The stand-alone
# "bias" entry of the head has no counterpart in Tea and is dropped here;
# keys without the prefix are passed through unchanged.
new_state_dict = {}
for k, v in state_dict.items():
    if k.startswith(LM_HEAD_PREFIX):
        new_key = k[len(LM_HEAD_PREFIX):]
        if new_key == "bias":
            continue
        new_state_dict[new_key] = v
    else:
        new_state_dict[k] = v

# Fold the stand-alone bias into the decoder bias so that Tea's single
# `decoder.bias` carries both terms (presumably the training head added the
# extra bias after the decoder - confirm against the training code).
# Built as a new tensor so the loaded checkpoint dict is left untouched.
new_state_dict["decoder.bias"] = (
    state_dict["model.lm_head.decoder.bias"] + state_dict["model.lm_head.bias"]
)

# strict=False: keys that do not belong to Tea are reported, not raised.
model.load_state_dict(new_state_dict, strict=False)

  weights = torch.load("/scicore/home/schwede/durair0000/projects/alphabeta/AlphaBeta/data/final_ablations/entropy_01/all/checkpoints/entropy_01-all-epoch=9.ckpt", map_location=device)


<All keys matched successfully>

In [14]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
# 4-bit quantisation is only requested when CUDA is available.
bnb_config = BitsAndBytesConfig(load_in_4bit=True) if torch.cuda.is_available() else None
esm2 = AutoModel.from_pretrained(
    "facebook/esm2_t33_650M_UR50D",
    torch_dtype="auto",
    quantization_config=bnb_config,
    add_pooling_layer=False,
)
esm2.eval()

sequence_examples = ["PRTEINO", "SEQWENCE"]
# Replace U/Z/O/B/J with X and separate residues with spaces. A new name is
# used so the raw examples stay available and the cell can be re-run safely.
spaced_sequences = [
    " ".join(re.sub(r"[UZOBJ]", "X", sequence)) for sequence in sequence_examples
]

# Calling the tokenizer directly returns padded PyTorch tensors in one step.
ids = tokenizer(
    spaced_sequences,
    add_special_tokens=True,
    padding="longest",
    return_tensors="pt",
)
input_ids = ids["input_ids"].to(device)
attention_mask = ids["attention_mask"].to(device)

# torch_dtype="auto" / quantisation may yield embeddings in a different dtype
# than the Tea head's parameters; cast explicitly to avoid a dtype mismatch.
tea_dtype = next(model.parameters()).dtype
with torch.no_grad():
    x = esm2(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    x = x.to(device=device, dtype=tea_dtype)
    results = model.to_sequences(embeddings=x, input_ids=input_ids, return_avg_entropy=True)
results


{'sequences': ['TPTTHPT', 'TTGHTTTT'],
 'avg_entropy': [0.40297362208366394, 0.21876809000968933]}

# CHANGE!

In [15]:
# save locally
model.save_pretrained("tea")

# push to the hub
# The access token must have write permission for the target namespace;
# otherwise the Hub rejects the repository creation with 403 Forbidden.
model.push_to_hub("lorenzo-pantolini/tea")

# reload
# The class defined in this notebook is `Tea`; `AlphaBeta` is not defined here.
model = Tea.from_pretrained("lorenzo-pantolini/tea")

HfHubHTTPError: (Request ID: Root=1-691210c4-1eb570ec760b65ea0184194f;9fc66cf8-b604-4276-a5f4-7eb7632d2d78)

403 Forbidden: You don't have the rights to create a model under the namespace "lorenzo-pantolini".
Cannot access content at: https://huggingface.co/api/repos/create.
Make sure your token has the correct permissions.