In [6]:
import os
os.environ['CUBLAS_WORKSPACE_CONFIG']=":4096:8"
import typing as T
from pathlib import Path

from tqdm import *
import torch
from esm.esmfold.v1 import esmfold
import argparse
import json
from torch.utils.data.dataset import Subset
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import time 
import wandb
from data import ClusteredDataset_inturn, StructureDataset, batch_collate_function_nocluster

import TPFold 
import nolgfold
import utils
import torch.distributed as dist
import torch.optim as optim
from openfold.utils.rigid_utils import Rigid
from openfold.utils.loss import compute_fape
import noam_opt
import omegaconf
parser = argparse.ArgumentParser()
parser.add_argument('--shuffle', type=float, default=0., help='Shuffle fraction')
parser.add_argument('--data_jsonl', type=str, default="/pubhome/xtzhang/data/test_set.jsonl", help='Path for the jsonl data')
parser.add_argument('--split_json', type=str, default="/pubhome/bozhang/data/tmpnn_v8.json",help='Path for the split json file')
parser.add_argument('--output_folder',type=str,default="/pubhome/xtzhang/result/output/",help="output folder for the log files and model parameters")
parser.add_argument('--save_folder',type=str,default="/pubhome/xtzhang/result/save/",help="output folder for the model parameters")
parser.add_argument('--description',type=str,help="description the model information into wandb")
parser.add_argument('--job_name',type=str,default="noplm_eva",help="jobname of the wandb dashboard")
parser.add_argument('--num_tags',type=int,default=6,help="num tags for the sequence")
parser.add_argument('--epochs',type=int,default=5,help="epochs to train the model")
parser.add_argument('--batch_size',type=int,default=1,help="batch size tokens")
parser.add_argument('--max_length',type=int,default=800,help="max length of the training sequence")
parser.add_argument('--max_tokens',type=int,default=400,help="max length of the training sequence")
parser.add_argument('--mask',type=float,default=1.0,help="mask fractions into input sequences")
parser.add_argument("--local_rank", default=0, help="local device ID", type=int) 
parser.add_argument('--parameters',type=str,default="/pubhome/xtzhang/result/save/no_plm_or_else_384epoch4.pt", help="parameters path")
parser.add_argument('--lr',type=float,default=5e-4, help="learning rate of Adam optimizer")
parser.add_argument('--chunk_size',type=int,default=4,help="chunk size of the model")
parser.add_argument('--world_size',type=int,default=2,help="world_size")
parser.add_argument('--pattern',type=str,default="no",help="mode")
parser.add_argument('--add_tmbed',type=bool,default=False, help="whether addtmbed")
parser.add_argument('--watch_freq',type=int,default=500, help="watch gradient")
# parser.add_argument('--pdb',default="/pubhome/xtzhang/output/pdb/800aa_noseq_nocctop", help="Path to output PDB directory", type=Path, required=True)
# parser.add_argument("-o", "--pdb", help="Path to output PDB directory", type=Path, required=True)
args = parser.parse_args(args=[])

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg=omegaconf.dictconfig.DictConfig( 
 content={'_name': 'ESMFoldConfig', 'esm_type': 'esm2_3B', 'fp16_esm': True, 'use_esm_attn_map': False, 'esm_ablate_pairwise': False, 'esm_ablate_sequence': False, 'esm_input_dropout': 0, 'trunk': {'_name': 'FoldingTrunkConfig', 'num_blocks': 48, 'sequence_state_dim': 1024, 'pairwise_state_dim': 128, 'sequence_head_width': 32, 'pairwise_head_width': 32, 'position_bins': 32, 'dropout': 0, 'layer_drop': 0, 'cpu_grad_checkpoint': False, 'max_recycles': 4, 'chunk_size': None, 'structure_module': {'c_s': 384, 'c_z': 128, 'c_ipa': 16, 'c_resnet': 128, 'no_heads_ipa': 12, 'no_qk_points': 4, 'no_v_points': 8, 'dropout_rate': 0.1, 'no_blocks': 8, 'no_transition_layers': 1, 'no_resnet_blocks': 2, 'no_angles': 7, 'trans_scale_factor': 10, 'epsilon': 1e-08, 'inf': 100000.0}}, 'embed_aa': True, 'bypass_lm': False, 'lddt_head_hid_dim': 128}
 )
model_no = nolgfold.load_model(chunk_size=64, pattern=args.pattern, model_path="/pubhome/xtzhang/result/save/no_plm_or_else_384epoch4.pt", cfg=cfg)
model_no = model_no.to(device)
model_no.eval()



ESMFold(
  (tmbed_model): tmbed2(
    (encoder_model): T5EncoderModel(
      (shared): Embedding(128, 1024)
      (encoder): T5Stack(
        (embed_tokens): Embedding(128, 1024)
        (block): ModuleList(
          (0): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                  (relative_attention_bias): Embedding(32, 32)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (1): T5LayerFF(
                (DenseReluDense): T5DenseActDense(
                  (wi): Linear(in_features=1024, out_

In [15]:
model = TPFold.load_model(chunk_size=64, pattern=args.pattern, model_path="/pubhome/xtzhang/result/save/800aa_noseq_nocctopepoch4.pt", cfg=cfg)
model = model.to(device)
model.eval()

KeyboardInterrupt: 

: 

In [14]:
jsonl_file = args.data_jsonl
dataset = StructureDataset(jsonl_file=jsonl_file, max_length=args.max_length) 
test_loader = DataLoader(
    dataset=dataset, 
    batch_size=args.batch_size, 
    collate_fn=batch_collate_function_nocluster)  

UNK token:0,too long:0, 'not_match':1


In [None]:
with torch.no_grad():
    for iteration, batch in enumerate(test_loader):            
                for key in batch:
                    batch[key] = batch[key].cuda(args.local_rank)
                C_pos, CA_pos, N_pos, seq, mask, residx, bb_pos = batch['C_pos'], batch['CA_pos'], batch['N_pos'],batch['seq'], batch['mask'], batch['residx'], batch['bb_pos']                
                output_dict = model(aa=seq, mask=mask, residx=residx)
                output_dict_no = model_no(aa=seq, mask=mask, residx=residx)
                target_frames = Rigid.from_3_points(C_pos, CA_pos, N_pos)

                loss_fape = torch.mean(compute_fape(
                                pred_frames=output_dict['pred_frames'],
                                target_frames=target_frames,
                                frames_mask=output_dict['frame_mask'],
                                pred_positions=output_dict['backbone_positions'],
                                target_positions=bb_pos,
                                positions_mask=output_dict['backbone_atoms_mask'],
                                length_scale=10,
                            ))
                loss_fape_no = torch.mean(compute_fape(
                                pred_frames=output_dict_no['pred_frames'],
                                target_frames=target_frames,
                                frames_mask=output_dict_no['frame_mask'],
                                pred_positions=output_dict_no['backbone_positions'],
                                target_positions=bb_pos,
                                positions_mask=output_dict_no['backbone_atoms_mask'],
                                length_scale=10,
                            ))
