# Loading pLMs from Wilke Lab Shared Folder  
This repository makes it easy to use protein language models (pLMs) from the lab's shared folder without using up all the space in your home directory.   

All models are located at:   
/stor/work/Wilke/wilkelab/pLMs_checkpoints/   


NOTE: Do not assume these models are in eval() mode after loading. If you are loading them for inference, DO NOT forget to call:   
``` 
model.eval()
```

In [None]:
# List every checkpoint file available in the shared folder (IPython shell escape)
!tree /stor/work/Wilke/wilkelab/pLMs_checkpoints/

[01;34m/stor/work/Wilke/wilkelab/pLMs_checkpoints/[00m
├── [01;34mAMPLIFY[00m
│   ├── [01;34mAMPLIFY_120M[00m
│   │   ├── amplify.py
│   │   ├── config.json
│   │   ├── config.yaml
│   │   ├── model.safetensors
│   │   ├── pytorch_model.pt
│   │   ├── rmsnorm.py
│   │   ├── rotary.py
│   │   ├── special_tokens_map.json
│   │   ├── tokenizer_config.json
│   │   └── tokenizer.json
│   └── [01;34mAMPLIFY_350M[00m
│       ├── amplify.py
│       ├── config.json
│       ├── model.safetensors
│       ├── README.md
│       ├── rmsnorm.py
│       ├── rotary.py
│       ├── special_tokens_map.json
│       ├── tokenizer_config.json
│       └── tokenizer.json
├── [01;34mESM1[00m
│   ├── esm1b_t33_650M_UR50S-contact-regression.pt
│   └── esm1b_t33_650M_UR50S.pt
├── [01;34mESM2[00m
│   ├── esm2_t12_35M_UR50D-contact-regression.pt
│   ├── esm2_t12_35M_UR50D.pt
│   ├── esm2_t30_150M_UR50D-contact-regression.pt
│   ├── esm2_t30_150M_UR50D.pt
│   ├── esm2_t33_650M_UR50D-contact-regression.pt


# Freeing the large amount of disk space used by cached ESM2 3B and 15B downloads

In [None]:
# Models downloaded from Torch Hub (via the ESM2 scripts) or Hugging Face are stored in the cache
# under your home directory. Uncomment the following lines to clear the cache if needed
# (the leading `!` runs them as shell commands inside the notebook):

#!rm -rf ~/.cache/torch/hub/checkpoints/

#!rm -rf ~/.cache/huggingface/hub/


# Loading ESM2 models.

In [4]:
import os
import esm
import torch

# Root of the shared checkpoint folder; every loader below builds paths from it
model_path = '/stor/work/Wilke/wilkelab/pLMs_checkpoints/'

# ESM2 model names this notebook expects to find as <model_path>/ESM2/<name>.pt
supported_models = [
    'esm2_t6_8M_UR50D', 'esm2_t12_35M_UR50D', 'esm2_t30_150M_UR50D',
    'esm2_t33_650M_UR50D', 'esm2_t36_3B_UR50D', 'esm2_t48_15B_UR50D',
]

# Set the device: use the GPU when one is available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Choose from the list of models above
model_name = 'esm2_t12_35M_UR50D'

# Fail early with a clear message if the name is misspelled or not an ESM2 model
if model_name not in supported_models:
    raise ValueError(
        f"Unknown ESM2 model {model_name!r}; choose one of {supported_models}"
    )

# Full path to the checkpoint file in the shared folder
checkpoint_path = os.path.join(model_path, 'ESM2', model_name + '.pt')

model, alphabet = esm.pretrained.load_model_and_alphabet(checkpoint_path)

# Move the model to the selected device
model = model.to(device)

# If loading the model for inference (e.g. extracting embeddings), switch to evaluation mode
model.eval()

ESM2(
  (embed_tokens): Embedding(33, 480, padding_idx=1)
  (layers): ModuleList(
    (0-11): 12 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=480, out_features=480, bias=True)
        (v_proj): Linear(in_features=480, out_features=480, bias=True)
        (q_proj): Linear(in_features=480, out_features=480, bias=True)
        (out_proj): Linear(in_features=480, out_features=480, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((480,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=480, out_features=1920, bias=True)
      (fc2): Linear(in_features=1920, out_features=480, bias=True)
      (final_layer_norm): LayerNorm((480,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=240, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((480,), eps=1e-05, elementw

## Example use of ESM2 models:  

In [None]:
# Extracting embeddings using the ESM2 script (scripts/extract.py) available in the ESM2 repository

model_full_path = os.path.join(model_path, 'ESM2', 'esm2_t6_8M_UR50D.pt')

# Run from the root of the ESM2 repository. In Jupyter, the leading `!` runs the
# line as a shell command and `{model_full_path}` substitutes the Python variable;
# without the braces the literal text "model_full_path" is passed to the script.
# `--repr_layers 6` selects the last layer of this 6-layer (t6) model.
#!python scripts/extract.py {model_full_path} data/prot_seqs.fasta embeddings/esm2_8M/prot_seqs/ --repr_layers 6 --include mean

# Loading ESMC models

In [None]:
from esm.models.esmc import ESMC
from esm.tokenization import get_esmc_model_tokenizers

In [None]:
# Load the ESMC models locally from checkpoint files in the shared folder.
def _build_esmc(
    model_path: str,
    device: torch.device | str,
    dtype: torch.dtype,
    d_model: int,
    n_heads: int,
    n_layers: int,
):
    """Construct an ESMC model of the given size and load its weights.

    Shared implementation behind ESMC_300M_202412 and ESMC_600M_202412.

    Args:
        model_path: Path to the ``.pth`` state-dict file to load.
        device: Device on which the model is built and the weights are loaded.
        dtype: Floating-point dtype the parameters are cast to after loading.
        d_model: Hidden (embedding) dimension.
        n_heads: Number of attention heads per block.
        n_layers: Number of transformer blocks.

    Returns:
        The ESMC model with its weights loaded. It is NOT switched to eval().
    """
    # Building inside the device context creates the parameters on `device`.
    with torch.device(device):
        model = ESMC(
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            tokenizer=get_esmc_model_tokenizers(),
        )
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model.to(dtype)


def ESMC_300M_202412(
    model_path: str,
    device: torch.device | str = "cpu",
    *,
    dtype: torch.dtype = torch.float32,
):
    """Load the ESMC 300M (2024-12) model from the state-dict at ``model_path``.

    ``dtype`` defaults to torch.float32; pass torch.bfloat16 for half-size weights.
    """
    return _build_esmc(
        model_path, device, dtype, d_model=960, n_heads=15, n_layers=30
    )


def ESMC_600M_202412(
    model_path: str,
    device: torch.device | str = "cpu",
    *,
    dtype: torch.dtype = torch.float32,
):
    """Load the ESMC 600M (2024-12) model from the state-dict at ``model_path``.

    ``dtype`` defaults to torch.float32; pass torch.bfloat16 for half-size weights.
    """
    return _build_esmc(
        model_path, device, dtype, d_model=1152, n_heads=18, n_layers=36
    )

In [None]:
model_name = 'esmc_300m_2024_12_v0.pth'
# `model_path` (shared checkpoint root) and `device` are defined in the ESM2 loading cell
checkpoint_path = os.path.join(model_path, 'ESMC', model_name)

# Build on the same device as the ESM2 example
model = ESMC_300M_202412(checkpoint_path, device=device)
model

ESMC(
  (embed): Embedding(64, 960)
  (transformer): TransformerStack(
    (blocks): ModuleList(
      (0-29): 30 x UnifiedTransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=960, out_features=2880, bias=False)
          )
          (out_proj): Linear(in_features=960, out_features=960, bias=False)
          (q_ln): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
          (k_ln): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
          (rotary): RotaryEmbedding()
        )
        (ffn): Sequential(
          (0): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=960, out_features=5120, bias=False)
          (2): SwiGLU()
          (3): Linear(in_features=2560, out_features=960, bias=False)
        )
      )
    )
    (norm): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
  )
  (sequ

In [2]:
# Use the model tokenizer. ESMC._tokenize takes a LIST of sequences: a bare
# string would be iterated character by character, producing 2048 one-residue
# sequences instead of one 2048-residue sequence.
# NOTE: run one of the ESMC loading cells first so `model` is defined.
model._tokenize(["A" * 2048])

NameError: name 'model' is not defined

In [None]:
model_name = 'esmc_600m_2024_12_v0.pth'
# `model_path` (shared checkpoint root) and `device` are defined in the ESM2 loading cell
checkpoint_path = os.path.join(model_path, 'ESMC', model_name)

# Build on the same device as the ESM2 example
model = ESMC_600M_202412(checkpoint_path, device=device)
model

ESMC(
  (embed): Embedding(64, 1152)
  (transformer): TransformerStack(
    (blocks): ModuleList(
      (0-35): 36 x UnifiedTransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=1152, out_features=3456, bias=False)
          )
          (out_proj): Linear(in_features=1152, out_features=1152, bias=False)
          (q_ln): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
          (k_ln): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
          (rotary): RotaryEmbedding()
        )
        (ffn): Sequential(
          (0): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=1152, out_features=6144, bias=False)
          (2): SwiGLU()
          (3): Linear(in_features=3072, out_features=1152, bias=False)
        )
      )
    )
    (norm): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)


# Loading AMPLIFY Models

In [None]:
import torch
from transformers import AutoModel, AutoTokenizer

In [29]:
def AMPLIFY(model_checkpoint):
    """Load an AMPLIFY model and its tokenizer from ``model_checkpoint``.

    The checkpoint directory ships its own modeling code (e.g. amplify.py,
    rmsnorm.py, rotary.py), so remote code must be trusted by both loaders.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    shared_kwargs = dict(trust_remote_code=True)
    amplify_model = AutoModel.from_pretrained(model_checkpoint, **shared_kwargs)
    amplify_tokenizer = AutoTokenizer.from_pretrained(
        model_checkpoint, **shared_kwargs
    )
    return amplify_model, amplify_tokenizer

In [None]:
model_name = 'AMPLIFY_120M/'
# `model_path` is the shared checkpoint root defined in the ESM2 loading cell
checkpoint_path = os.path.join(model_path, 'AMPLIFY', model_name)
# Pass the path built above (the previous `model_checkpoint` name was undefined)
model, tokenizer = AMPLIFY(checkpoint_path)
model

AMPLIFY(
  (encoder): Embedding(27, 640, padding_idx=0)
  (transformer_encoder): ModuleList(
    (0-23): 24 x EncoderBlock(
      (q): Linear(in_features=640, out_features=640, bias=False)
      (k): Linear(in_features=640, out_features=640, bias=False)
      (v): Linear(in_features=640, out_features=640, bias=False)
      (wo): Linear(in_features=640, out_features=640, bias=False)
      (resid_dropout): Dropout(p=0, inplace=False)
      (ffn): SwiGLU(
        (w12): Linear(in_features=640, out_features=3424, bias=False)
        (w3): Linear(in_features=1712, out_features=640, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
      (ffn_dropout): Dropout(p=0, inplace=False)
    )
  )
  (layer_norm_2): RMSNorm()
  (decoder): Linear(in_features=640, out_features=27, bias=True)
)