# Refine ESM-Inpaintign pipeline

Key point : ground truth backbone frame be fed into structure module rather than to use black hole initialization.

In [31]:
import torch
import torch.nn as nn
import esm
import numpy as np
import esm_inpaint.esm.esmfold.v1.esmfold as ESM
import esm_inpaint.utils as utils
import esm_inpaint.modules as modules

In [2]:
# origin esmfold

model = esm.pretrained.esmfold_v1()
model = model.eval().cuda()

# Optionally, uncomment to set a chunk size for axial attention. This can help reduce memory.
# Lower sizes will have lower memory requirements at the cost of increased speed.
# model.set_chunk_size(128)

sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
# Multimer prediction can be done with chains separated by ':'

with torch.no_grad():
    output = model.infer_pdb(sequence)

with open("result.pdb", "w") as f:
    f.write(output)

import biotite.structure.io as bsio
struct = bsio.load_structure("result.pdb", extra_fields=["b_factor"])
print(struct.b_factor.mean())  # this will be the pLDDT
# 88.3

88.28930830039526


1. class esmfold 
- self.lm_head 输出序列 已经有了
- self.distance_embedding 给pair-wise 输入结构，需要自己定义

In [2]:
model_path = "/root/.cache/torch/hub/checkpoints/esmfold_3B_v1.pt"
model_data = torch.load(str(model_path), map_location="cpu") #读取一个pickle文件为一个dict

In [5]:
cfg = model_data["cfg"]["model"]
model_state = model_data["model"]
model = ESM.ESMFold(esmfold_config=cfg)
model.load_state_dict(model_state, strict=False)

In [6]:
esm_data = utils.load_jsonl("./esm_inpaint/data/chain_set.jsonl")

In [6]:
test_data = torch.from_numpy(np.stack([i for i in esm_data[0]['coords'].values()],axis=-2))

In [7]:
test_data.unsqueeze_(0)

tensor([[[[    nan,     nan,     nan],
          [    nan,     nan,     nan],
          [    nan,     nan,     nan],
          [    nan,     nan,     nan]],

         [[    nan,     nan,     nan],
          [    nan,     nan,     nan],
          [    nan,     nan,     nan],
          [    nan,     nan,     nan]],

         [[    nan,     nan,     nan],
          [    nan,     nan,     nan],
          [    nan,     nan,     nan],
          [    nan,     nan,     nan]],

         ...,

         [[10.3340, 38.2370, 18.8380],
          [11.5050, 37.8970, 19.6330],
          [12.6560, 37.2870, 18.8430],
          [13.4520, 36.5200, 19.3900]],

         [[12.7790, 37.6800, 17.5800],
          [13.8600, 37.1940, 16.7440],
          [15.1600, 37.9140, 17.0640],
          [15.1860, 39.1280, 17.2060]],

         [[16.2050, 37.1030, 17.2270],
          [17.5680, 37.4520, 17.6440],
          [17.7590, 37.1340, 19.1240],
          [18.6790, 36.3400, 19.4270]]]], dtype=torch.float64)

In [8]:
bb_rotation,bb_translation = utils.get_bb_frames(test_data)
bb_frame = torch.zeros((*bb_rotation.shape[:-2],4,4))
bb_frame[...,:3,:3] = bb_rotation
bb_frame[...,:3,3] = bb_translation # [B, L, 4, 4]

In [9]:
import openfold.utils.rigid_utils as rigid_utils

In [10]:
frame = rigid_utils.Rigid.from_tensor_4x4(bb_frame)
rotate_points = frame[...,None].apply(test_data)

In [11]:
test_data.shape

torch.Size([1, 330, 4, 3])

In [12]:
bb_rotation.shape

torch.Size([1, 330, 3, 3])

In [13]:
bb_translation.unsqueeze(-2).shape

torch.Size([1, 330, 1, 3])

In [14]:
# point : [B, L, n_p, 3]
# frame : [B, L, 4, 4] tensor ---- slicing时应当转化为 [B, L, 1, 3, 3] 
# 对其 point [*, 3] ---- frame [*, 4,4]
points = (test_data @ bb_rotation.transpose(-1,-2)) + bb_translation.unsqueeze(-2)

In [15]:
points = utils.nan_to_num(points)
rotate_points = utils.nan_to_num(rotate_points)

In [None]:
(points - rotate_points) < 1e-4

## 统计模型参数

esm2,2842768487(3B)  
folding trunk,686106624(700M)    
structure module,2019116(20M)    
other,791655  


In [25]:
esm_total = 0
trunk_block = 0
trunk_sture_module = 0
other = 0
for i in model.state_dict().keys():
    if i.startswith("esm"):
        esm_total += ((model.state_dict())[i]).numel()
    elif i.startswith("trunk.blocks"):
        trunk_block += ((model.state_dict())[i]).numel()
    elif i.startswith("trunk.structure_module"):
        trunk_sture_module += ((model.state_dict())[i]).numel()
    else:
        other += ((model.state_dict())[i]).numel()

In [32]:
print(f"esm2 has {esm_total}\nfolding trunk has {trunk_block}\nstructure module has {trunk_sture_module}\nother has {other}\n")

esm2 has 2842768487
folding trunk has 686106624
structure module has 2019116
other has 791655



## 验证ESM-Inpainting POC

### 数据集构建 --- collate_batch + dataloader函数

aa_type : 0-19

x : 20 --- 只对输入有用

padding : 21 --- 输入输出都有用  
padding mask 的1表示这一块计算loss,0表示这一块不计算loss


In [24]:
batch_list = [esm_data[i] for i in range(0,3)]
batch_to_collate = []
for i in range(len(batch_list)):
    coord = np.stack([batch_list[i]['coords'][atom] for atom in batch_list[i]['coords'].keys()],axis=-2)
    confidence = None
    seq = batch_list[i]["seq"]
    batch_to_collate.append((coord,confidence,seq))

In [33]:
restypes = [
    "A",
    "R",
    "N",
    "D",
    "C",
    "Q",
    "E",
    "G",
    "H",
    "I",
    "L",
    "K",
    "M",
    "F",
    "P",
    "S",
    "T",
    "W",
    "Y",
    "V",
]
restype_order = {restype: i for i, restype in enumerate(restypes)}

In [26]:
res = utils.CoordBatchConverter.collate_dense_tensors([torch.tensor(i[0]) for i in batch_to_collate],-1)

In [37]:
seq = utils.CoordBatchConverter.collate_dense_tensors([torch.tensor([restype_order[res] for res in i[2]]) for i in batch_to_collate],21)

In [51]:
len_seq = 10
bert_mask_fraction = np.random.uniform(low=0.5, high=0.8)
bert_mask = torch.tensor([False for _ in range(10)] if np.random.random() < bert_mask_fraction else True) # 0, mask; 1 unmask 

In [53]:
bert_mask_fraction

0.7873464911436472

In [45]:
padding_mask = seq != 21

In [54]:
padding_mask.dtype

torch.bool

In [3]:
cfg = model_data["cfg"]["model"]
model = modules.esm_inpaint(cfg)

In [None]:
model_state = model_data["model"]
model.esmfold.load_state_dict(model_state, strict=False)

In [None]:
x