In [1]:
import torch
import torch.nn as nn

import numpy as np 
import matplotlib.pyplot as plt

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.Alphabet import SingleLetterAlphabet

from transformers import AutoTokenizer, AutoModel, EsmForProteinFolding

import os
import copy
from tqdm import tqdm

from linear_quant import *

In [2]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=False)

model = model.cuda('cuda:3')
model.esm = model.esm.float()

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.weight', 'esm.contact_head.regression.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
num_params = sum(param.numel() for param in model.parameters())
print(num_params)

3527665475


In [4]:
num_params = 0

for name, parameters in model.named_parameters():
    if name.startswith("esm"):
        num_params += parameters.numel()
        # print(name,':',parameters.numel())

print(num_params)
print(num_params / 3527664034)

2838748102
0.8047104470946906


In [5]:
num_params = 0

for name, parameters in model.named_parameters():
    if name.startswith("trunk"):
        num_params += parameters.numel()
        # print(name,':',parameters.numel())

print(num_params)
print(num_params / 3527664034)

688548524
0.19518540239764795


In [6]:
num_params = 0

for name, parameters in model.named_parameters():
    if name.startswith("trunk.structure_module"):
        num_params += parameters.numel()
        # print(name,':',parameters.numel())

print(num_params)
print(num_params / 3527664034)

2019116
0.0005723662969431176


In [7]:
quant_layers = []

for key in model.state_dict().keys():
    key_size = model.state_dict()[key].size()
    # if key.startswith("esm.encoder.layer") or key.startswith("trunk.block") or key.startswith("trunk.structure_module"):
    if key.startswith("esm.encoder.") or key.startswith("trunk.block"):
    # if key.startswith("esm.encoder."):
    # if key.startswith("trunk.block"):
    # if key.startswith("trunk.block") or key.startswith("trunk.structure_module"):
        quant_layers.append(key)

In [8]:
# checkpoint, rmse = quant_checkpoint(model, quant_layers, wei_quant_scheme="pw-1")
checkpoint, rmse = quant_checkpoint(model, quant_layers)
model.load_state_dict(checkpoint)
del checkpoint


total quant RMSE: 9.4016e-04


In [9]:
rmse
# 0.015075289140647612

0.0009401647041497145

In [10]:
seq_fasta = list(SeqIO.parse("../data/sequences_cameo.fasta", "fasta"))

seq_list = [seq.seq.__str__() for seq in seq_fasta]
key_list = [seq.id.__str__().split("_")[0] for seq in seq_fasta]

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
ecoli_tokenized = tokenizer(seq_list, padding=False, add_special_tokens=False)['input_ids']

outputs = []

with torch.no_grad():
    for input_ids in tqdm(ecoli_tokenized):
        input_ids = torch.tensor(input_ids, device='cuda:3').unsqueeze(0)
        output = model(input_ids)
        outputs.append({key: val.cpu() for key, val in output.items()})
        
pdb_list = [convert_outputs_to_pdb(output) for output in outputs]
for identifier, pdb in zip(key_list, pdb_list):
    with open(f"../output/pred_quant_weight_pdb_v1/{identifier}.pdb", "w") as f:
        f.write("".join(pdb))

100%|██████████| 139/139 [13:27<00:00,  5.81s/it]


In [11]:
import os
from tqdm import tqdm
from TMscore import TMscore

real_pdbs = os.listdir("../data/cameo_real_pdb")
pred_quant_pdbs = os.listdir("../output/pred_quant_weight_pdb_v1")

tmscore = TMscore("TMscore")

tmscore_list = []
lddt_list = []
for a, b in tqdm(zip(real_pdbs, pred_quant_pdbs)):
    tmscore(os.path.join("../data/cameo_real_pdb", a), os.path.join("../output/pred_quant_weight_pdb_v1", b))
    score = tmscore.get_tm_score()
    if score is not None:
        tmscore_list.append(tmscore.get_tm_score())

tmscore_pred = sum(tmscore_list) / len(tmscore_list)       

print(tmscore_pred)

139it [01:30,  1.54it/s]

0.8010496402877697





In [12]:
# 0.8006791366906474