In [1]:
import os
import esm
import torch
import argparse
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader
from dataset.dataset import ProteinsDataset, protein_collate_fn
from transformers import BertModel, BertTokenizer

In [2]:
def symmetrize(x):
    """Return ``x`` plus its transpose over the last two axes.

    The result is symmetric in its final two dimensions; used for contact
    prediction.
    """
    swapped = x.transpose(-1, -2)
    return x + swapped


def apc(x):
    """Apply average product correction (APC) over the last two dimensions.

    Each entry has subtracted from it the product of its row sum and column
    sum divided by the total sum of the matrix; used for contact prediction.
    """
    grand_total = x.sum((-1, -2), keepdims=True)
    row_totals = x.sum(-1, keepdims=True)
    col_totals = x.sum(-2, keepdims=True)
    correction = row_totals * col_totals
    # Divide in place so no second full-size temporary is allocated.
    correction.div_(grand_total)
    return x - correction

In [3]:
### Load the pretrained ESM-1b checkpoint (`esm1b_t33_650M_UR50S`) and its alphabet
pre_trained_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
# This notebook only runs inference, so switch off training-mode behaviour
# (e.g. dropout) to get deterministic outputs, as the fair-esm README does.
pre_trained_model.eval()
# Batch converter derived from the model's alphabet (fair-esm API); used to
# tokenize input sequences for this model.
batch_converter = alphabet.get_batch_converter()

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm1b_t33_650M_UR50S.pt" to /mnt/nvme/home/bbabatun/.cache/torch/hub/checkpoints/esm1b_t33_650M_UR50S.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm1b_t33_650M_UR50S-contact-regression.pt" to /mnt/nvme/home/bbabatun/.cache/torch/hub/checkpoints/esm1b_t33_650M_UR50S-contact-regression.pt


In [7]:
# Pick the inference device: prefer the second GPU (the original setup), then
# fall back to the first GPU, then to CPU, so the notebook still runs on
# machines without two GPUs.
if torch.cuda.device_count() > 1:
    INFER_DEVICE = torch.device("cuda:1")
elif torch.cuda.is_available():
    INFER_DEVICE = torch.device("cuda:0")
else:
    INFER_DEVICE = torch.device("cpu")
# nn.Module.to moves parameters in place and returns the same module, so
# `atten_infer` and `pre_trained_model` are one object, now on INFER_DEVICE.
atten_infer = pre_trained_model.to(INFER_DEVICE)

In [10]:
# Sanity check: load each exported TorchScript head and print its module
# structure. The heads presumably consume atten_infer's attention maps
# (inferred from the "atten_*" file names; TODO confirm), so they are loaded
# onto the same device as that model instead of a hard-coded one.
model_list = ["contact_jits/atten_single_contact_gpu.pth",
              "contact_jits/atten_only_contact_gpu.pth",
              "contact_jits/atten_all_contact_gpu.pth",
              "contact_jits/atten_sgl_dist_gpu.pth",
              "contact_jits/atten_only_dist_gpu.pth",
              "contact_jits/atten_all_dist_gpu.pth"]
jit_device = next(atten_infer.parameters()).device
for model_path in model_list:
    # map_location remaps storages while loading, so files saved on some other
    # device still load here without an extra .to() copy afterwards.
    model = torch.jit.load(model_path, map_location=jit_device)
    print(f"=== {model_path} ===")
    print(model)
    del model  # drop the reference before loading the next head

RecursiveScriptModule(
  original_name=Network
  (conv1): RecursiveScriptModule(original_name=Conv2d)
  (bn1): RecursiveScriptModule(original_name=InstanceNorm2d)
  (relu): RecursiveScriptModule(original_name=ReLU)
  (firstblock): RecursiveScriptModule(
    original_name=FirstBlock
    (conv1): RecursiveScriptModule(original_name=Conv2d)
    (bn1): RecursiveScriptModule(original_name=InstanceNorm2d)
    (relu): RecursiveScriptModule(original_name=ReLU)
    (conv2): RecursiveScriptModule(original_name=Conv2d)
    (bn2): RecursiveScriptModule(original_name=InstanceNorm2d)
    (dp1): RecursiveScriptModule(original_name=Dropout)
    (relu2): RecursiveScriptModule(original_name=ReLU)
    (conv3): RecursiveScriptModule(original_name=Conv2d)
    (bn3): RecursiveScriptModule(original_name=InstanceNorm2d)
  )
  (secondblock): RecursiveScriptModule(
    original_name=SecondBlock
    (relu0): RecursiveScriptModule(original_name=ReLU)
    (conv1): RecursiveScriptModule(original_name=Conv2d)
    (b