In [1]:
import torch
import torch.nn as nn

import numpy as np 
import matplotlib.pyplot as plt

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.Alphabet import SingleLetterAlphabet

from transformers import AutoTokenizer, AutoModel, EsmForProteinFolding

import os
import copy
from tqdm import tqdm

from linear_quant import *

2023-03-30 16:26:54.761038: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-30 16:26:54.881856: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-03-30 16:26:55.429870: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-03-30 16:26:55.429936: W tensorflow/

In [2]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)

model = model.cpu()

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
model.esm = model.esm.float()
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

In [4]:
device = torch.device('cpu')
quantized_model = quantized_model.to(device)

In [5]:
seq_fasta = list(SeqIO.parse("../data/sequences_cameo.fasta", "fasta"))

In [6]:
seq_list = [seq.seq.__str__() for seq in seq_fasta]
key_list = [seq.id.__str__().split("_")[0] for seq in seq_fasta]

ecoli_tokenized = tokenizer(seq_list, padding=False, add_special_tokens=False)['input_ids']

In [7]:
from tqdm import tqdm

outputs = []

with torch.no_grad():
    for input_ids in tqdm(ecoli_tokenized):
        input_ids = torch.tensor(input_ids, device='cpu').unsqueeze(0)
        output = quantized_model(input_ids)
        outputs.append({key: val.cpu() for key, val in output.items()})
        
pdb_list = [convert_outputs_to_pdb(output) for output in outputs]
for identifier, pdb in zip(key_list, pdb_list):
    with open(f"../output/pred_origin_quant_pdb/{identifier}.pdb", "w") as f:
        f.write("".join(pdb))

100%|██████████| 139/139 [10:25:35<00:00, 270.04s/it]  


In [8]:
import os

real_pdbs = os.listdir("real_pdb")
pred_pdbs = os.listdir("pred_quant_pdb")
print(len(real_pdbs))
print(len(pred_pdbs))

139
139


In [12]:
from TMscore import TMscore

tmscore = TMscore("TMscore")

tmscore_list = []
for a, b in zip(real_pdbs, pred_pdbs):
    tmscore(os.path.join("real_pdb", a), os.path.join("pred_quant_pdb", b))
    score = tmscore.get_tm_score()
    if score is not None:
        tmscore_list.append(tmscore.get_tm_score())


7VNA.pdb 0.9134
7TCR.pdb 0.2635
7QRY.pdb 0.8576
7N3T.pdb 0.7266
7MKU.pdb 0.8882
7OB6.pdb 0.8056
7Q05.pdb 0.8936
7R09.pdb 0.9426
7U5Y.pdb 0.9422
7F0A.pdb 0.885
7QBG.pdb 0.8832
7P0H.pdb 0.5517
7ULH.pdb 0.827
7VU7.pdb 0.4124
7PNO.pdb 0.5722
7F2Y.pdb 0.8933
7RAW.pdb 0.3376
7RPS.pdb 0.8201
7ETS.pdb 0.8868
7QAO.pdb 0.3219
7PC3.pdb 0.8058
7PSG.pdb 0.8437
7O4O.pdb 0.9646
7EQH.pdb 0.8225
7SCI.pdb 0.9068
7Q4L.pdb 0.6965
7MYV.pdb 0.8819
7TNI.pdb 0.7011
7OPB.pdb 0.9863
7T7Y.pdb 0.9769
7SO5.pdb 0.9255
7W5U.pdb 0.9722
7Z79.pdb 0.9316
7VGM.pdb 0.9437
7ED6.pdb 0.8983
7N0E.pdb 0.9855
7RQF.pdb 0.6335
7TV9.pdb 0.8211
7F9H.pdb 0.7484
7OA7.pdb 0.596
7POI.pdb 0.4347
7PB4.pdb 0.8773
7KOB.pdb 0.9574
7OSW.pdb 0.246
7KO9.pdb 0.9412
7PC4.pdb 0.7963
7MLA.pdb 0.5275
7YXG.pdb 0.944
7X8V.pdb 0.8576
7QAP.pdb 0.3332
7TZG.pdb 0.4905
7PC7.pdb 0.8064
7U2R.pdb 0.9844
7MSK.pdb 0.858
7V4S.pdb 0.9606
7RCZ.pdb 0.8831
7QS2.pdb 0.9424
7UGH.pdb 0.897
7S2R.pdb 0.7594
7F3A.pdb 0.9469
7N29.pdb 0.8342
7VWT.pdb 0.9364
7TBU.pdb 0.8291

In [13]:
print(sum(tmscore_list) / len(tmscore_list))

0.792691366906475
