In [1]:
import torch
import torch.nn as nn

import numpy as np 
import matplotlib.pyplot as plt

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.Alphabet import SingleLetterAlphabet

from transformers import AutoTokenizer, AutoModel, EsmForProteinFolding

import os
import copy
from tqdm import tqdm

from piecewise_quant import *

In [2]:
# Fetch the ESMFold v1 tokenizer and folding model from the Hugging Face hub.
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")

# The model stays on the CPU in this cell; it is moved to the GPU in a later cell.
# Previously tried and currently disabled:
#   model = model.cuda('cuda:0')
#   model.esm = model.esm.half()
model = EsmForProteinFolding.from_pretrained(
    "facebook/esmfold_v1",
    low_cpu_mem_usage=False,
).cpu()

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Collect the state-dict entries that live under the ESM encoder or the
# folding-trunk blocks; only these names are handed to quant_checkpoint.
# (An earlier variant matched "esm.encoder.layer" and additionally
# "trunk.structure_module".)
quant_layers = [
    name
    for name in model.state_dict()
    if name.startswith(("esm.encoder.", "trunk.block"))
]

# quant_checkpoint returns a replacement state dict plus an RMSE figure; the
# state dict is loaded back into the model and then released.
checkpoint, rmse = quant_checkpoint(model, quant_layers)
model.load_state_dict(checkpoint)
del checkpoint


total quant RMSE: 9.4016e-04


In [4]:
# Wrap the model with quant_model_acts (excluding "base_model") and place the
# result on the first GPU.
# NOTE(review): the positional arguments 0 and True are passed through unchanged;
# presumably they select a statistics-collection mode rather than real
# quantization — confirm against piecewise_quant.quant_model_acts.
model = quant_model_acts(model, 0, True, exclude_part=["base_model"]).cuda("cuda:0")

In [5]:
# Parse the CASP14 FASTA file; only the first 50 records are used below.
seq_fasta = list(SeqIO.parse("../data/casp14.fasta", "fasta"))

# Plain-string sequences, and identifiers truncated at the first underscore.
seq_list = [str(record.seq) for record in seq_fasta[:50]]
key_list = [str(record.id).split("_")[0] for record in seq_fasta[:50]]
print(len(seq_list))

50


In [6]:
# Tokenize each sequence on its own: no padding and no special tokens, so every
# entry is a plain list of token ids with the same length as its sequence.
# NOTE: despite its name, `ecoli_tokenized` holds whatever is in `seq_list`.
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
ecoli_tokenized = tokenizer(
    seq_list,
    padding=False,
    add_special_tokens=False,
)["input_ids"]

In [7]:
outputs = []

# Feed the sequences through the model one at a time (batch size 1) with
# autograd disabled, keeping a CPU copy of every output tensor.
with torch.no_grad():
    for token_ids in tqdm(ecoli_tokenized):
        batch = torch.tensor([token_ids], device='cuda:0')  # shape (1, len(token_ids))
        result = model(batch)
        cpu_result = {}
        for name, tensor in result.items():
            cpu_result[name] = tensor.cpu()
        outputs.append(cpu_result)

# Write the activation statistics gathered during the forward passes to disk.
act_stats_save_path = '../output/stats/act_stats_v1.pth'
os.makedirs('../output/stats/', exist_ok=True)
act_dict = save_model_act_stats(model, act_stats_save_path)

100%|██████████| 50/50 [18:15<00:00, 21.91s/it]


In [8]:
# Start again from the pretrained checkpoint (on the CPU) and wrap it with
# quant_model_acts using the "pwlq-3" scheme.
# NOTE(review): the positional arguments 8 and False are presumably the
# activation bit-width and a "collect statistics" flag — confirm against
# piecewise_quant.quant_model_acts.
new_model = EsmForProteinFolding.from_pretrained(
    "facebook/esmfold_v1",
    low_cpu_mem_usage=False,
).cpu()
# Variant without an explicit quant_scheme, tried previously:
# new_model = quant_model_acts(new_model, 8, False, exclude_part=["base_model"], cali_batch_size=50)
new_model = quant_model_acts(
    new_model,
    8,
    False,
    exclude_part=["base_model"],
    cali_batch_size=50,
    quant_scheme="pwlq-3",
)

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# Apply the activation statistics stored at act_stats_save_path to new_model,
# selecting the "top_2" clipping method.
# NOTE(review): the return value is bound to `mode` and is not used in any cell
# visible here, while the following cell saves `new_model`. This presumably
# relies on load_model_act_stats modifying new_model in place — confirm.
# `mode` may also simply be a typo for `model`.
mode = load_model_act_stats(new_model, act_stats_save_path, act_clip_method="top_2")
# Alternative clipping method tried previously:
# mode = load_model_act_stats(new_model, act_stats_save_path, act_clip_method="clip_0.999")

In [10]:
# Persist the configured model object so the evaluation notebook can reload it.
# The target directory is created first; torch.save does not create missing
# parent directories and would otherwise fail on a fresh checkout.
# NOTE: this pickles the whole module object (not just a state dict), so loading
# it later requires the same class definitions to be importable.
quant_model_path = "../output/quant_acts/quant_model_8b_full_v2.pt"
os.makedirs(os.path.dirname(quant_model_path), exist_ok=True)
torch.save(new_model, quant_model_path)

In [1]:
import torch
import torch.nn as nn

import numpy as np 
import matplotlib.pyplot as plt

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.Alphabet import SingleLetterAlphabet

from transformers import AutoTokenizer, AutoModel, EsmForProteinFolding

import os
import copy
from tqdm import tqdm

from linear_quant import *
# Switch autograd off globally for this kernel session; every later forward
# pass in this notebook therefore runs without gradient tracking.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f24c18a0a30>

In [2]:
# Reload the pickled model object and move it to the first GPU.
# CAUTION: torch.load unpickles arbitrary Python objects; only load files that
# were produced by a trusted run of the quantization notebook.
new_model = torch.load("../output/quant_acts/quant_model_8b_full_v2.pt").cuda("cuda:0")

In [3]:
# Read every record of the CAMEO FASTA file.
seq_fasta = list(SeqIO.parse("../data/sequences_cameo.fasta", "fasta"))

# Plain-string sequences, and identifiers truncated at the first underscore.
seq_list = [str(record.seq) for record in seq_fasta]
key_list = [str(record.id).split("_")[0] for record in seq_fasta]

# One un-padded list of token ids per sequence, without special tokens.
# NOTE: despite its name, `ecoli_tokenized` holds whatever is in `seq_list`.
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
ecoli_tokenized = tokenizer(
    seq_list,
    padding=False,
    add_special_tokens=False,
)["input_ids"]

outputs = []

# Fold each sequence separately (batch size 1) and keep the outputs on the CPU.
with torch.no_grad():
    for token_ids in tqdm(ecoli_tokenized):
        batch = torch.tensor([token_ids], device='cuda:0')  # shape (1, len(token_ids))
        result = new_model(batch)
        cpu_result = {}
        for name, tensor in result.items():
            cpu_result[name] = tensor.cpu()
        outputs.append(cpu_result)

100%|██████████| 139/139 [1:50:33<00:00, 47.72s/it]


In [4]:
# Convert every model output to PDB text and write one file per identifier.
# The output directory is created up front; open(..., "w") does not create
# missing parent directories and would raise FileNotFoundError on a fresh run.
# NOTE: identifiers come from key_list; if two entries share an identifier the
# later file overwrites the earlier one.
pred_pdb_dir = "../output/pred_quant_combine_pdb_v2"
os.makedirs(pred_pdb_dir, exist_ok=True)

pdb_list = [convert_outputs_to_pdb(output) for output in outputs]
for identifier, pdb in zip(key_list, pdb_list):
    with open(os.path.join(pred_pdb_dir, f"{identifier}.pdb"), "w") as f:
        f.write("".join(pdb))

In [5]:
# Score every predicted structure against its reference with TM-score and
# report the mean. `os` is already imported in the notebook's first cell.
from TMscore import TMscore
from tqdm import tqdm

real_pdb_dir = "../data/cameo_real_pdb"
pred_pdb_dir = "../output/pred_quant_combine_pdb_v2"

# os.listdir returns names in arbitrary order, so the two listings are sorted
# before being paired positionally; without this, a reference file can be
# scored against the wrong prediction.
# NOTE(review): this assumes both directories use file names that sort into the
# same order (e.g. they start with the same identifier) — confirm.
real_pdbs = sorted(os.listdir(real_pdb_dir))
pred_quant_pdbs = sorted(os.listdir(pred_pdb_dir))

# zip() silently stops at the shorter list, which would also shift the pairing.
if len(real_pdbs) != len(pred_quant_pdbs):
    print(
        "WARNING: {} reference files vs {} predicted files; "
        "positional pairing is probably wrong".format(len(real_pdbs), len(pred_quant_pdbs))
    )

tmscore = TMscore("TMscore")

tmscore_list = []
lddt_list = []  # not populated in this cell; kept so the name stays defined
pair_count = min(len(real_pdbs), len(pred_quant_pdbs))
for a, b in tqdm(zip(real_pdbs, pred_quant_pdbs), total=pair_count):
    tmscore(os.path.join(real_pdb_dir, a), os.path.join(pred_pdb_dir, b))
    # Query the score once and reuse it; pairs without a score are skipped.
    score = tmscore.get_tm_score()
    if score is not None:
        tmscore_list.append(score)

# Guard against an empty result so the cell reports NaN instead of raising
# ZeroDivisionError when no pair produced a score.
if tmscore_list:
    tmscore_pred = sum(tmscore_list) / len(tmscore_list)
else:
    tmscore_pred = float("nan")
print("{}/{}".format(sum(tmscore_list), len(tmscore_list)))
print(tmscore_pred)

# 0.7945769784172667 173 top100 norm

139it [01:48,  1.29it/s]

111.34760000000006/139
0.8010618705035976





In [4]:
# 0.7495877697841726 50 top_20 act_stats_50_layernorm

# 0.6797410071942447 cali_batch_size=173 top_1 act_stats_173 8 bits
# 0.7993798561151076 cali_batch_size=173 top_1 act_stats_173 16 bits

# 0.6850496402877699 cali_batch_size=50 top_1 act_stats_50 8 bits
# 0.8006000000000004 cali_batch_size=50 top_1 act_stats_50 16 bits

#  cali_batch_size=173 top_1 act_stats_173_layernorm 8 bits
#  cali_batch_size=173 top_1 act_stats_173_layernorm 16 bits

# 0.6850496402877699 cali_batch_size=50 top_1 act_stats_50_layernorm 8 bits
# 0.7993237410071943 cali_batch_size=50 top_1 act_stats_50_layernorm 16 bits

In [5]:
# Original model (baseline)
# TM-Score: 0.80085

# Quantized module types in [nn.Linear, nn.Softmax, nn.functional.softmax, nn.Sigmoid(), nn.ReLU(), EsmFoldLinear, LayerNorm]
# TM-Score: 0.7417683453237413 # whole model quantized
# TM-Score: 0.7871827338129491 # encoder only quantized
# TM-Score: 0.7672676258992803 # trunk only quantized

# Quantized module types in [nn.Linear, nn.Softmax, nn.functional.softmax, nn.Sigmoid(), nn.ReLU(), EsmFoldLinear]
# TM-Score: 0.7969366906474817 # whole model quantized
# TM-Score: 0.7984546762589924 # encoder only quantized
# TM-Score: 0.8017553956834531 # trunk only quantized

# Quantized module types in [nn.Linear, nn.Sigmoid(), nn.ReLU(), EsmFoldLinear, LayerNorm]:
# TM-Score: 0.7426244604316544 # whole model quantized
# TM-Score: 0.7871827338129491 # encoder only quantized
# TM-Score: 0.764428057553957  # trunk only quantized

# Quantized module types in [nn.Softmax, nn.functional.softmax]:
# TM-Score: 0.8004014388489211 # whole model quantized
# TM-Score: 0.8006791366906475 # encoder only quantized
# TM-Score: 0.8004014388489211 # trunk only quantized

# Quantized module types in [LayerNorm]
# TM-Score: 0.7312964028776979 # whole model quantized
# TM-Score: 0.7850712230215826 # encoder only quantized
# TM-Score: 0.7639489208633091 # trunk only quantized

# Quantized module types in [nn.Linear, EsmFoldLinear]
# TM-Score: 0.7980323741007195 # whole model quantized
# TM-Score: 0.7984546762589924 # encoder only quantized
# TM-Score: 0.8011856115107912 # trunk only quantized
