<a href="https://colab.research.google.com/github/pranay8297/llm/blob/main/esm/esm_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
! pip install transformers evaluate datasets requests pandas sklearn
! pip install datasets
! pip install evaluate

In [1]:
import math
import numpy as np
import torch

from datasets import load_dataset
from dataclasses import dataclass
from einops import rearrange, repeat
from sklearn.model_selection import train_test_split

from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch.optim import AdamW
from transformers import AutoTokenizer

torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

PFGPT_VOCAB_SIZE = 384
PFGPT_HF_MODEL_PATH = 'lamm-mit/ProteinForceGPT'

@dataclass
class LoRAConfig:
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: int = 0.05
    lora_query: bool = True
    lora_key: bool = False
    lora_value: bool = True
    lora_projection: bool = False
    lora_mlp: bool = False
    lora_head: bool = False

class LoRALinear(nn.Linear):
    def __init__(self, nin, nout, lora_config):
        super().__init__(nin, nout)
        std_dev = 1 / torch.sqrt(torch.tensor(lora_config.lora_r).float())
        self.lora_A = torch.nn.Parameter(torch.randn(nin, lora_config.lora_r) * std_dev)
        self.lora_B = torch.nn.Parameter(torch.zeros(lora_config.lora_r, nout))
        self.alpha = lora_config.lora_alpha

    def forward(self, x):
        lora_x = self.alpha * (x @ self.lora_A @ self.lora_B)
        x = super().forward(x)
        return x + lora_x

def get_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(PFGPT_HF_MODEL_PATH, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]

    return (x * cos) + (rotate_half(x) * sin)

class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        inv_freq = inv_freq
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)
        )

class ESMEmbeddings(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.word_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embeddings = nn.Embedding(config.block_size, config.n_embd)

    def post_model_init(self):
        # Merge both the tokenizer vocabs - Battle of Tokenizers
        # That is create a new word embedding of pf_gpts vocab size and configs n_embd
        # Get pf GPT Tokenizer
        pfgpt_tokenizer = AutoTokenizer.from_pretrained(PFGPT_HF_MODEL_PATH, trust_remote_code=True)
        pfgpt_tokenizer.pad_token = pfgpt_tokenizer.eos_token
        pfgpt_vocab = pfgpt_tokenizer.get_vocab()

        # Get ESM Tokenizer
        esm_tokenizer = AutoTokenizer.from_pretrained(self.config.pre_trained_model_name, padding='max_length', max_length=1026)
        esm_vocab = esm_tokenizer.get_vocab()
        new_word_embeddings = nn.Embedding(PFGPT_VOCAB_SIZE, self.config.n_embd)
        torch.nn.init.normal_(new_word_embeddings.weight, std = 0.1263)

        # Find all the common keys tokens between esm tokenizer and pf_gpt tokenizer
        pfgpt_keys = set(pfgpt_vocab.keys())
        esm_keys = set(esm_vocab.keys())
        common_keys = list(pfgpt_keys.intersection(esm_keys))

        # now, copy a particular tokens embedding from ems_embedding to the new embedding that we create here
        with torch.no_grad():
            indices = []
            for key in common_keys:
                esm_embd_index = esm_tokenizer.convert_tokens_to_ids(key)
                pfg_embd_index = pfgpt_tokenizer.convert_tokens_to_ids(key)
                indices.append(pfg_embd_index)
                new_word_embeddings.weight[pfg_embd_index] = self.word_embeddings.weight[esm_embd_index]

            # Check for embedding equivalance
            assert torch.equal(new_word_embeddings.weight[pfgpt_tokenizer.convert_tokens_to_ids(common_keys)],
                               self.word_embeddings.weight[esm_tokenizer.convert_tokens_to_ids(common_keys)])

        # Create a mask for all the indecis we have copied pretrained embeddings
        # and turn requires_grad off to those embeddings that we have copied - This is
        # not possible, so instead we store the indecis and zero out the grads before optim.step()
        # hence we do not update these embeddings
        self.indices = indices

        with torch.no_grad():
            self.word_embeddings = new_word_embeddings
        self.word_embeddings.requires_grad_(True)

    def forward(self, x, attention_mask = None):
        token_embs = self.word_embeddings(x)
        # Not required as we are use rotary embeddings - Hence we do not require absolute position embeddings
        # position_embs = self.esm.embeddings.position_embeddings(torch.arange(0, x.shape[1], 1, dtype = torch.long))
        if attention_mask is not None:
            token_embs = (token_embs * attention_mask.unsqueeze(-1)).to(token_embs.dtype)
        return token_embs

@dataclass
class ESMConfig():
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    hidden_size: int = 4096 # 4 * block_size
    dropout: float = 0.0
    pre_trained_model_name: str = ''

class ESMIntermediateLayer(nn.Module):
    def __init__(self, nin, nout, lora_config, dropout = 0.0, ):
        super().__init__()

        self.dense = nn.Linear(nin, nout) if not lora_config.lora_mlp else LoRALinear(nin, nout, lora_config)
        self.act = nn.GELU(approximate = 'tanh')

    def forward(self, x):
        return self.act(self.dense(x))

class ESMOutLayer(nn.Module):
    def __init__(self, nin, nout, lora_config, dropout = 0.0, inside_attention = False):
        super().__init__()

        # 2 places used - 1. inisde the attention block  and inside the MLP
        # if used inside attention and lora_config.lora_projection is true then dense is a LoRALInear
        # elif used in mlp and lora_config.lora_mlp is true then dens is a LoRALinear again
        # else its a Linear

        if inside_attention == True and lora_config.lora_projection:
            self.dense = LoRALinear(nin, nout, lora_config)
        elif inside_attention == False and lora_config.lora_mlp:
            self.dense = LoRALinear(nin, nout, lora_config)
        else:
            self.dense = nn.Linear(nin, nout)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_scores):
        x = self.dense(x)
        x = self.dropout(x)
        x = x + attn_scores
        return x

class ESMSelfAttn(nn.Module): # Verified

    def __init__(self, config, lora_config):
        super().__init__()

        assert config.n_embd % config.n_head == 0

        self.query = nn.Linear(config.n_embd, config.n_embd) if not lora_config.lora_query else LoRALinear(config.n_embd, config.n_embd, lora_config)
        self.key = nn.Linear(config.n_embd, config.n_embd) if not lora_config.lora_key else LoRALinear(config.n_embd, config.n_embd, lora_config)
        self.value = nn.Linear(config.n_embd, config.n_embd) if not lora_config.lora_value else LoRALinear(config.n_embd, config.n_embd, lora_config)
        self.n_head = config.n_head

        attention_head_size = config.n_embd//config.n_head

        # Add a rotary embeddings here
        self.rotary_embeddings = RotaryEmbedding(dim = attention_head_size)

    def forward(self, x, attention_mask):

        # x -> (b, s, e) -> (b s, h, e/h)
        k, q, v = self.key(x), self.query(x), self.value(x)

        k = rearrange(k, 'b s (h e) -> b h s e', h = self.n_head)
        q = rearrange(q, 'b s (h e) -> b h s e', h = self.n_head)
        v = rearrange(v, 'b s (h e) -> b h s e', h = self.n_head)

        # Add rotary embeddings here for k and q tensors
        q, k = self.rotary_embeddings(q, k)

        # Attention claculation - # TODO: make is_casual true in case of finetuning - Very important
        y = F.scaled_dot_product_attention(q, k, v, attn_mask = attention_mask, is_causal = False) # flash attention
        y = rearrange(y, 'b h s e -> b s (h e)', h = self.n_head)
        return y

class ESMAttn(nn.Module): # Verified

    def __init__(self, config, lora_config):
        super().__init__() # No activation function at this level
        self.self = ESMSelfAttn(config, lora_config)
        self.output = ESMOutLayer(config.n_embd, config.n_embd, lora_config, dropout = getattr(config, 'dropout', 0.), inside_attention = True)
        self.LayerNorm = nn.LayerNorm(config.n_embd)

    def forward(self, x, attention_mask):
        inter_x = self.LayerNorm(x)
        attn = self.self(inter_x, attention_mask)
        out = self.output(attn, x)
        return out

class ESMLayers(nn.Module): # Both Init and Forward Verified - Done and Dusted

    def __init__(self, config, lora_config):
        super().__init__()
        self.attention = ESMAttn(config, lora_config)
        self.intermediate = ESMIntermediateLayer(config.n_embd, config.hidden_size, lora_config) #
        self.output = ESMOutLayer(config.hidden_size, config.n_embd, lora_config) #
        self.LayerNorm = nn.LayerNorm(config.n_embd)

    def forward(self, x, attention_mask):
        attention_op = self.attention(x, attention_mask)
        attention_op_ln = self.LayerNorm(attention_op) # This will keep the activations in check - Lets see
        inter = self.intermediate(attention_op_ln)
        out = self.output(inter, attention_op)
        return out

class ESMEncoder(nn.Module):

    def __init__(self, config, lora_config):
        super().__init__()

        # No activation functions here as well

        self.layer = nn.ModuleList([ESMLayers(config, lora_config) for _ in range(config.n_layer)])
        self.emb_layer_norm_after = nn.LayerNorm(config.n_embd)

    def forward(self, x, attention_mask = None):

        for layer in self.layer:
            x = layer(x, attention_mask)

        return self.emb_layer_norm_after(x)

class ESM(nn.Module):

    def __init__(self, config, lora_config):

        super().__init__()
        self.config = config
        self.esm = nn.ModuleDict(dict(
            embeddings = ESMEmbeddings(config),
            encoder = ESMEncoder(config, lora_config), # Done, forward - here
            final_layer = nn.Linear(config.n_embd, config.vocab_size) if not lora_config.lora_head else
                              LoRALinear(config.n_embd, config.vocab_size, lora_config)
        ))
        self.esm.final_layer.weight = self.esm.embeddings.word_embeddings.weight

        # Final Layer bias initializtion
        torch.nn.init.zeros_(self.esm.final_layer.bias) # Set the bias to 0
        # Finally one small thing is to decide weather to add an intermediate layer or not? - Thats a future discussion

    @classmethod
    def get_pretrained_config(cls, model_type = 'esm2_t33_650M_UR50D'):

        '''
        name                n_layers    n_params
        esm2_t48_15B_UR50D	48	        15B
        esm2_t36_3B_UR50D	36	        3B
        esm2_t33_650M_UR50D	33	        650M
        esm2_t30_150M_UR50D	30	        150M
        esm2_t12_35M_UR50D	12	        35M
        esm2_t6_8M_UR50D
        '''

        assert model_type in {'esm2_t36_3B_UR50D', 'esm2_t33_650M_UR50D', 'esm2_t30_150M_UR50D'}

        config_args = {
            'esm2_t36_3B_UR50D': dict(n_layer=36, n_head = 40, n_embd=2560, hidden_size=10240), # 3B params
            'esm2_t33_650M_UR50D': dict(n_layer=33, n_head = 20, n_embd=1280, hidden_size=5120), # 650M params
            'esm2_t30_150M_UR50D': dict(n_layer=30, n_head = 20, n_embd=640, hidden_size=2560), # 150M params
        }[model_type]

        config_args['vocab_size'] = 33 # always 33 for ESM Models
        config_args['block_size'] = 1026 # Always constant for ESM Models
        config_args['pre_trained_model_name'] = f"facebook/{model_type}"

        config = ESMConfig(**config_args)
        return config

    @classmethod
    def from_pretrained(cls, lora_config, model_type = 'esm2_t33_650M_UR50D', embedding_post_init = True):

        config = cls.get_pretrained_config(model_type)
        print("loading weights from pretrained gpt: %s" % model_type)

        # create a from-scratch initialized minGPT model
        model = cls(config, lora_config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        from transformers import AutoModelForSequenceClassification
        num_labels = 33
        model_hf = AutoModelForSequenceClassification.from_pretrained(config.pre_trained_model_name, num_labels = num_labels)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        sd_keys_hf = [k for k in sd_keys_hf if 'inv_freq' not in k]
        sd_keys_hf = [k for k in sd_keys_hf if 'classifier' not in k]

        ignore_keys = ['esm.contact_head.regression.weight', 'esm.contact_head.regression.bias']
        for k in sd_keys_hf:

            if k in ignore_keys: continue

            # vanilla copy over the other parameters
            try: assert sd_hf[k].shape == sd[k].shape
            except Exception as e:
              print(k)
              print(f"Mismatch in the shape of tensor while loading weights - Key: {k}, expected shape: {sd_hf[k].shape}, actual shape: {sd[k].shape if k in sd else k}")

            with torch.no_grad():
                sd[k].copy_(sd_hf[k])

        # Set the final layers bias as 0 so that it does not affect weight tying scheme
        with torch.no_grad():
            model.esm.final_layer.bias.zero_()

        # Freeze the model
        for name, param in model.named_parameters():
            if 'lora_' not in name:
                param.requires_grad = False

        if embedding_post_init:
            model.esm.embeddings.post_model_init()
            del model.esm.final_layer

            #IMP: Here we are assuming that embeddings will never have LoRA attached to it, hence we are going with Linear
            model.esm.final_layer = nn.Linear(config.n_embd, model.esm.embeddings.word_embeddings.weight.shape[0])

            model.esm.final_layer.weight = model.esm.embeddings.word_embeddings.weight

        return model

    def get_extended_attn_mask(self, attention_mask, input_shape):

        if attention_mask == None: return None
        b, s = attention_mask.shape
        # Make the attention mask braodcastable for [batch_size, n_heads, seq_len, seq_len]
        attention_mask = attention_mask[:, None, None, :]

        # Now make sure that it has negetive infinity for all the padded tokens and
        # 0 for all attention tokens as we add this mask to attention scores
        attn_mask = attention_mask.to(torch.float32)
        attn_mask = (1 - attn_mask) * (torch.finfo(torch.float32).min)
        attn_mask = attn_mask.expand(b, 1, s, s)
        return attn_mask

    def forward(self, x, y = None, attention_mask = None, output_encoder_states = True):

        # Calculate Embeddings
        x = self.esm.embeddings(x, attention_mask) # TODO: Verify the new embeddings function without doing post init and after doing post model init - Ideally both should stay the same

        # compute attention_mask for attention scores
        extended_attention_mask = self.get_extended_attn_mask(attention_mask, x.shape)

        #Do the forward pass
        x = self.esm.encoder(x, attention_mask = extended_attention_mask)
        logits = self.esm.final_layer(x)
        output = {'logits': logits}

        if output_encoder_states:
            output['encoder_output'] = x
        if y is not None:
            # Calculate loss and send it in output
            outputs = logits.view(-1, logits.size(-1))  # (bs*seq_len, 384)
            targets = y.view(-1)  # (bs*seq_len)

            # Flatten the attention mask
            attention_mask = attention_mask.view(-1)  # (bs*seq_len)

            # Calculate cross entropy loss
            loss = F.cross_entropy(outputs, targets, reduction='none')

            # Apply the mask to the loss
            masked_loss = loss * attention_mask

            # Calculate the mean loss over the actual tokens (excluding padding)
            total_loss = masked_loss.sum()
            num_tokens = attention_mask.sum()

            actual_loss = total_loss / num_tokens
            output['loss'] = actual_loss

        return output

ds = load_dataset("lamm-mit/GPTProteinPretrained")
tokenizer = get_tokenizer()

class ProtDS(Dataset):
    def __init__(self, sequences, max_len = 1026):

        self.sequences = sequences # list object
        self.max_len = max_len
        self.eos_token = '</s>' # PFGPT's eos token - End of sequence

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        if len(seq) >= self.max_len - 1:
            seq = seq[:self.max_len - 1]

        label = seq[1:] + self.eos_token

        return seq, label

def seq_collate_fn(data):
    # data is a list of tuples
    x, y = zip(*data)
    train_tokenized = tokenizer(x)
    labels_tokenized = tokenizer(y)

    padded_train_tokenized = tokenizer.pad(train_tokenized, padding = 'max_length', max_length = 1026)
    padded_labels_tokenized = tokenizer.pad(labels_tokenized, padding = 'max_length', max_length = 1026)
    padded_train_tokenized['labels'] = padded_labels_tokenized['input_ids']

    for k, v in padded_train_tokenized.items():
        try:
            padded_train_tokenized[k] = torch.tensor(v, dtype = torch.long)
        except:
            breakpoint()

    return padded_train_tokenized



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


# Dataset

In [2]:
sequences = ds['train']['text']
train_seqs, valid_seqs = train_test_split(sequences, test_size = 0.05, shuffle = True)
train_ds, valid_ds = ProtDS(train_seqs), ProtDS(valid_seqs)

# PFGPT Benchmarking

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer
ForceGPT_model_name='lamm-mit/ProteinForceGPT'

pf_tokenizer = AutoTokenizer.from_pretrained(ForceGPT_model_name, trust_remote_code=True)
pf_tokenizer.pad_token = pf_tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    ForceGPT_model_name,
    trust_remote_code=True
).to(device)

In [15]:
def pf_seq_collate_fn(data):
    # data is a list of tuples
    x, y = zip(*data)
    train_tokenized = pf_tokenizer(x)
    labels_tokenized = pf_tokenizer(y)

    padded_train_tokenized = pf_tokenizer.pad(train_tokenized, padding = 'max_length', max_length = 1026)
    padded_labels_tokenized = pf_tokenizer.pad(labels_tokenized, padding = 'max_length', max_length = 1026)
    padded_train_tokenized['labels'] = padded_labels_tokenized['input_ids']

    for k, v in padded_train_tokenized.items():
        try:
            padded_train_tokenized[k] = torch.tensor(v, dtype = torch.long)[:, :1024]
        except:
            breakpoint()

    return padded_train_tokenized
valid_dl = DataLoader(valid_ds, batch_size = 4, collate_fn = pf_seq_collate_fn)
batch = next(iter(valid_dl))

In [17]:
batch['input_ids'].shape, batch['attention_mask'].shape, batch['labels'].shape

(torch.Size([4, 1024]), torch.Size([4, 1024]), torch.Size([4, 1024]))

In [20]:
model

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(384, 1024)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-35): 36 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXSdpaAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)
          (dense): Linear(in_features=1024, out_features=1024, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)
          (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)


In [31]:
valid_dl = DataLoader(valid_ds, batch_size = 64, collate_fn = pf_seq_collate_fn)

with torch.no_grad():
    vl_losses = []
    for it, batch in enumerate(valid_dl):
        # breakpoint()
        out = model(batch['input_ids'].to(device), batch['attention_mask'].to(device))
        outputs = out.logits.view(-1, out.logits.size(-1))  # (bs*seq_len, 384)
        targets = batch['labels'].reshape(-1)  # (bs*seq_len)

        # Flatten the attention mask
        attention_mask = batch['attention_mask'].reshape(-1)  # (bs*seq_len)

        # Calculate cross entropy loss
        loss = F.cross_entropy(outputs, targets.to(device))#, reduction='none')

        # Apply the mask to the loss
        # masked_loss = loss * attention_mask.to(device)

        # # Calculate the mean loss over the actual tokens (excluding padding)
        # total_loss = masked_loss.sum()
        # num_tokens = attention_mask.sum()

        # actual_loss = total_loss / num_tokens
        # vl_losses.append(actual_loss)
        vl_losses.append(loss.item())
        progress = it/len(valid_dl)
        print(f"\r Valid Progress: {progress:.5f}%, train loss: {vl_losses[-1]:.6f}", end='')

        if (it+1)%200 == 0: break

 Valid Progress: 0.22020%, train loss: 2.241032

KeyboardInterrupt: 

In [32]:
np.mean(vl_losses)

2.042728337778974

In [34]:
target_sequences = valid_ds.sequences[:64]

In [39]:
encoded_seqs = [pf_tokenizer.encode(target_sequences[0], add_special_tokens = False) for i in target_sequences]

[86,
 104,
 116,
 120,
 104,
 113,
 102,
 104,
 63,
 80,
 72,
 84,
 71,
 78,
 72,
 81,
 78,
 84,
 86,
 79,
 83,
 86,
 78,
 92,
 81,
 83,
 78,
 72,
 89,
 72,
 72,
 74,
 85,
 92,
 84,
 73,
 90,
 79,
 71,
 74,
 78,
 73,
 73,
 72,
 68,
 87,
 74,
 71,
 83,
 71,
 78,
 72,
 83,
 92,
 86,
 76,
 89,
 76,
 83,
 83,
 83,
 81,
 89,
 87,
 74,
 85,
 79,
 75,
 79,
 74,
 75,
 68,
 90,
 71,
 87,
 86,
 80,
 84,
 71,
 87,
 76,
 87,
 85,
 80,
 78,
 85,
 80,
 84,
 74,
 92,
 71,
 89,
 79,
 90,
 79,
 83,
 74,
 80,
 71,
 75,
 68,
 74,
 76,
 68,
 87,
 84,
 68,
 78,
 89,
 72,
 68,
 85,
 79,
 78,
 72,
 86,
 74,
 87,
 81,
 85,
 92,
 72,
 79,
 74,
 85,
 72,
 78,
 73,
 79,
 72,
 78,
 68,
 90,
 72,
 90,
 78,
 72,
 72,
 92,
 68,
 68,
 73,
 76,
 85,
 86,
 84,
 90,
 72,
 78,
 79,
 74,
 79,
 74,
 79,
 71,
 92,
 86,
 85,
 72,
 85,
 73,
 87,
 79,
 71,
 72,
 74,
 79,
 86,
 71,
 68,
 89,
 85,
 72,
 89,
 73,
 89,
 78,
 79,
 92,
 72,
 78,
 74,
 79,
 76,
 92,
 85,
 74,
 72,
 92,
 76,
 76,
 81,
 90,
 71,
 83,
 86,
 87,
 78,
 87

In [35]:
generated = torch.tensor() .unsqueeze(0).to(device)

ValueError: type of None unknown: <class 'NoneType'>. Should be one of a python, numpy, pytorch or tensorflow object.

# Training

In [11]:
model = ESM.from_pretrained(LoRAConfig(lora_r = 32, lora_key = True, lora_mlp = True, lora_projection = True, lora_alpha = 16), 'esm2_t30_150M_UR50D')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

loading weights from pretrained gpt: esm2_t30_150M_UR50D


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
tokens_for_grad_update = 30000
epochs = 1
batch_size = 24
grad_accum_steps = 5

train_dl = DataLoader(train_ds, batch_size = batch_size, collate_fn = seq_collate_fn)
valid_dl = DataLoader(valid_ds, batch_size = batch_size, collate_fn = seq_collate_fn)
iterations = epochs * len(train_dl) + 5

opt = AdamW(model.parameters(), lr = 3e-04, betas = (0.9, 0.95), eps = 1e-05)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr = 6e-04, total_steps = iterations, final_div_factor=10.0)

steps_processed_after_gradstep = 0
total_tokens_trained = 0
loss = None
losses = []

vl_losses_track = {0: -np.log(1/384)} # Ideal loss at epoch 0 before any finetuning
vl_losses_all = []

for i in range(epochs):
    c = 0
    for iter, batch in enumerate(train_dl):

        outputs = model(batch['input_ids'].to(device), y = batch['labels'].to(device), attention_mask = batch['attention_mask'].to(device))
        total_tokens_trained += batch['attention_mask'].sum()

        losses.append(outputs['loss'].item())
        outputs['loss'] = outputs['loss']/grad_accum_steps
        outputs['loss'].backward()
        steps_processed_after_gradstep += 1

        if steps_processed_after_gradstep == grad_accum_steps:
            # Do a backward pass and optimizer step
            norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            lr_scheduler.step()

            steps_processed_after_gradstep = 0
            progress = iter/len(train_dl)
            print(f"Train Progress: {progress:.6f}%, train loss: {losses[-1]:.6f}, norm: {norm:.4f}")
            c += 1

        if c >= 5: # approximately for every 400k tokens trained on, lets calculate the validation loss
            c = 0
            vl_losses = []
            v_iter = 0
            with torch.no_grad():
                for vl_batch in valid_dl:
                    vl_outputs = model(vl_batch['input_ids'].to(device), y = vl_batch['labels'].to(device), attention_mask = vl_batch['attention_mask'].to(device))
                    vl_losses.append(vl_outputs['loss'].item())
                    v_iter += 1

                    if v_iter%10 == 0: break

            vl_losses_all += vl_losses
            vl_losses_track[total_tokens_trained] = np.mean(vl_losses)
            print(f"valid loss: {losses[-1]:.6f}")

Train Progress: 0.000131%, train loss: 6.474083, norm: 15.2240
Train Progress: 0.000295%, train loss: 6.337980, norm: 14.6919
Train Progress: 0.000458%, train loss: 6.048813, norm: 18.4991
Train Progress: 0.000622%, train loss: 5.744795, norm: 20.5863
Train Progress: 0.000785%, train loss: 5.480638, norm: 25.1512
valid loss: 5.480638
Train Progress: 0.000949%, train loss: 5.136886, norm: 31.5467
Train Progress: 0.001113%, train loss: 4.900673, norm: 31.3607
Train Progress: 0.001276%, train loss: 4.690476, norm: 32.7555
Train Progress: 0.001440%, train loss: 4.494570, norm: 33.0060
Train Progress: 0.001603%, train loss: 4.299559, norm: 26.9208
valid loss: 4.299559
Train Progress: 0.001767%, train loss: 4.093726, norm: 31.4063
Train Progress: 0.001931%, train loss: 3.903715, norm: 26.8552
Train Progress: 0.002094%, train loss: 3.689974, norm: 38.7530
Train Progress: 0.002258%, train loss: 3.488058, norm: 28.6293
Train Progress: 0.002421%, train loss: 3.366394, norm: 37.4899
valid loss: 3

KeyboardInterrupt: 

In [14]:
torch.save(model, 'model_finetune_v2.pt')