<a href="https://colab.research.google.com/github/sokrypton/roscon2024/blob/main/finetune_esm2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#**finefine esm2** (tutorial for wednesday)
Special Thanks to Amelie Schreiber
https://github.com/Amelie-Schreiber/esm2_loras

In [1]:
model_name = "esm2_t6_8M_UR50D" # @param ["esm2_t33_650M_UR50D", "esm2_t30_150M_UR50D", "esm2_t12_35M_UR50D", "esm2_t6_8M_UR50D"]

In [2]:
%%time
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from transformers import EsmForSequenceClassification, EsmForTokenClassification, AutoTokenizer

trainable_params = lambda x: sum(p.numel() for p in x.parameters() if p.requires_grad)


DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = EsmForTokenClassification.from_pretrained(f"facebook/{model_name}",
                                                  num_labels=1,
                                                  hidden_dropout_prob=0.15)

tokenizer = AutoTokenizer.from_pretrained(f"facebook/{model_name}")
model = model.to(DEVICE)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


CPU times: user 2.17 s, sys: 356 ms, total: 2.52 s
Wall time: 3.93 s


In [4]:
trainable_params(model)

7737722

In [None]:
model

EsmForTokenClassification(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.15, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-5): 6 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.15, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05, 

In [3]:
#@markdown ##Get DATA
batch_size = 32 # @param {"type":"integer"}
max_crop_len = 512 # @param {"type":"integer"}

!wget -qnc https://github.com/sokrypton/roscon2024/raw/main/af2bind_data_0.pkl
import pickle
with open("af2bind_data_0.pkl", "rb") as handle:
  DATA = pickle.load(handle)

import numpy as np
import pickle
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset

# Helper function to pad sequences
def pad_sequence(seq, max_len, pad_value=0):
    pad_size = max(0, max_len - len(seq))
    return np.pad(seq, (0, pad_size), 'constant', constant_values=pad_value)[:max_len]

class CustomProteinDataset(Dataset):
    def __init__(self, inputs, attention_masks, outputs, masks, max_crop_len=128):
        self.inputs = inputs
        self.attention_masks = attention_masks
        self.outputs = outputs
        self.masks = masks
        self.max_crop_len = max_crop_len

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_ids = self.inputs[idx]
        attention_mask = self.attention_masks[idx]
        output = self.outputs[idx]
        mask = self.masks[idx]

        # Calculate the true length of the sequence (where attention_mask == 1)
        true_len = int(np.sum(attention_mask))

        # Determine the crop length (if the true length is less than max_crop_len, use true_len)
        crop_len = min(self.max_crop_len, true_len)

        # Randomly sample a crop starting index
        if true_len > crop_len:
            start_idx = np.random.randint(0, true_len - crop_len + 1)
        else:
            start_idx = 0

        # Crop the sequences
        input_ids = input_ids[start_idx:start_idx + crop_len]
        attention_mask = attention_mask[start_idx:start_idx + crop_len].astype(np.float32)
        output = output[start_idx:start_idx + crop_len].astype(np.float32)
        mask = mask[start_idx:start_idx + crop_len].astype(np.float32)

        # Pad the cropped sequences to max_crop_len
        input_ids = pad_sequence(input_ids, self.max_crop_len)
        attention_mask = pad_sequence(attention_mask, self.max_crop_len)
        output = pad_sequence(output, self.max_crop_len)
        mask = pad_sequence(mask, self.max_crop_len)

        return torch.tensor(input_ids), torch.tensor(attention_mask), torch.tensor(output), torch.tensor(mask)

# Create DataLoaders
dataloaders = []
for v in range(3):  # train/test/validation
    dataset = CustomProteinDataset(DATA["inputs"][v], DATA["attention_masks"][v],
                                   DATA["outputs"][v], DATA["masks"][v],
                                   max_crop_len=max_crop_len)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=(v == 0))
    dataloaders.append(dataloader)

In [4]:
#@markdown ##Training Code

def compute_loss(logits, labels, mask):
  """Compute masked loss."""
  loss = nn.BCEWithLogitsLoss(reduction='none')(logits, labels)
  masked_loss = loss * mask
  mean_loss = masked_loss.sum() / mask.sum()
  return mean_loss

def train_one_epoch(model, dataloader, optimizer):
  """Train the model for one epoch."""
  model.train()
  total_loss = 0

  for batch in dataloader:
    inputs_batch, attention_masks_batch, true_labels_batch, mask_batch = [x.to(DEVICE) for x in batch]

    # Forward pass
    outputs = model(input_ids=inputs_batch, attention_mask=attention_masks_batch)
    logits = outputs.logits.squeeze(-1)

    # Compute loss
    mean_loss = compute_loss(logits, true_labels_batch, mask_batch)

    # Backward pass and optimization
    optimizer.zero_grad()
    mean_loss.backward()
    optimizer.step()

    total_loss += mean_loss.item()

  average_loss = total_loss / len(dataloader)
  return average_loss

def validate(model, dataloader):
    """Validate the model."""
    model.eval()
    total_loss = 0

    with torch.no_grad():
      for batch in dataloader:
        inputs_batch, attention_masks_batch, true_labels_batch, mask_batch = [x.to(DEVICE) for x in batch]

        # Forward pass
        outputs = model(input_ids=inputs_batch, attention_mask=attention_masks_batch)
        logits = outputs.logits.squeeze(-1)

        # Compute loss
        mean_loss = compute_loss(logits, true_labels_batch, mask_batch)

        total_loss += mean_loss.item()

    average_loss = total_loss / len(dataloader)
    return average_loss

def train_model(model, train_dataloader, test_dataloader, num_epochs, optimizer):
    """Train and validate the model."""
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_dataloader, optimizer)
        test_loss = validate(model, test_dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    print("Training complete.")

In [7]:
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
num_epochs = 20  # Number of epochs to train
train_model(model, dataloaders[0], dataloaders[1], num_epochs, optimizer)

Epoch 1/20, Train Loss: 0.3418, Test Loss: 0.2845
Epoch 2/20, Train Loss: 0.3179, Test Loss: 0.2827
Epoch 3/20, Train Loss: 0.3076, Test Loss: 0.2754
Epoch 4/20, Train Loss: 0.2998, Test Loss: 0.2737
Epoch 5/20, Train Loss: 0.2943, Test Loss: 0.2698
Epoch 6/20, Train Loss: 0.2901, Test Loss: 0.2686
Epoch 7/20, Train Loss: 0.2843, Test Loss: 0.2684
Epoch 8/20, Train Loss: 0.2806, Test Loss: 0.2681
Epoch 9/20, Train Loss: 0.2736, Test Loss: 0.2737
Epoch 10/20, Train Loss: 0.2657, Test Loss: 0.2861
Epoch 11/20, Train Loss: 0.2512, Test Loss: 0.2865
Epoch 12/20, Train Loss: 0.2442, Test Loss: 0.2897
Epoch 13/20, Train Loss: 0.2340, Test Loss: 0.3044
Epoch 14/20, Train Loss: 0.2193, Test Loss: 0.3309
Epoch 15/20, Train Loss: 0.2132, Test Loss: 0.3183
Epoch 16/20, Train Loss: 0.2029, Test Loss: 0.3780
Epoch 17/20, Train Loss: 0.2003, Test Loss: 0.3446
Epoch 18/20, Train Loss: 0.1851, Test Loss: 0.3667
Epoch 19/20, Train Loss: 0.1681, Test Loss: 0.3600
Epoch 20/20, Train Loss: 0.1600, Test Lo

In [5]:
# https://github.com/huggingface/peft
!pip -q install --no-dependencies peft

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/296.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m296.4/296.4 kB[0m [31m19.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [6]:
from peft import LoraConfig, get_peft_model, TaskType

In [7]:
config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    target_modules=["query", "key", "value"],
    r=4,
    lora_dropout=0.15,
)
model = get_peft_model(model, config)

In [None]:
model

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): EsmForTokenClassification(
      (esm): EsmModel(
        (embeddings): EsmEmbeddings(
          (word_embeddings): Embedding(33, 320, padding_idx=1)
          (dropout): Dropout(p=0.15, inplace=False)
          (position_embeddings): Embedding(1026, 320, padding_idx=1)
        )
        (encoder): EsmEncoder(
          (layer): ModuleList(
            (0-5): 6 x EsmLayer(
              (attention): EsmAttention(
                (self): EsmSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=320, out_features=320, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.15, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=320, out_features=4, bias=False)
                    )
                    (lora_B): ModuleDict(
             

In [10]:
trainable_params(model)

46401

In [11]:
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
num_epochs = 20  # Number of epochs to train
train_model(model, dataloaders[0], dataloaders[1], num_epochs, optimizer)

Epoch 1/20, Train Loss: 0.5323, Test Loss: 0.3209
Epoch 2/20, Train Loss: 0.3232, Test Loss: 0.2884
Epoch 3/20, Train Loss: 0.3218, Test Loss: 0.2860
Epoch 4/20, Train Loss: 0.3197, Test Loss: 0.2852
Epoch 5/20, Train Loss: 0.3158, Test Loss: 0.2852
Epoch 6/20, Train Loss: 0.3183, Test Loss: 0.2845
Epoch 7/20, Train Loss: 0.3184, Test Loss: 0.2839
Epoch 8/20, Train Loss: 0.3171, Test Loss: 0.2823
Epoch 9/20, Train Loss: 0.3155, Test Loss: 0.2776
Epoch 10/20, Train Loss: 0.3076, Test Loss: 0.2751
Epoch 11/20, Train Loss: 0.3012, Test Loss: 0.2762
Epoch 12/20, Train Loss: 0.3050, Test Loss: 0.2808
Epoch 13/20, Train Loss: 0.3021, Test Loss: 0.2722
Epoch 14/20, Train Loss: 0.2976, Test Loss: 0.2719
Epoch 15/20, Train Loss: 0.2939, Test Loss: 0.2721
Epoch 16/20, Train Loss: 0.2932, Test Loss: 0.2714
Epoch 17/20, Train Loss: 0.2893, Test Loss: 0.2746
Epoch 18/20, Train Loss: 0.2892, Test Loss: 0.2818
Epoch 19/20, Train Loss: 0.2889, Test Loss: 0.2708
Epoch 20/20, Train Loss: 0.2872, Test Lo