In [1]:
import torch
import pickle
import re
import numpy as np
from pathlib import Path
from huggingface_hub import snapshot_download
from glob import glob
from collections import OrderedDict

from MSA_Pairformer.model import MSAPairformer
from MSA_Pairformer.dataset import MSA, prepare_msa_masks, aa2tok_d


In [2]:
# Fetch the released MSA Pairformer checkpoint files from the Hugging Face Hub
# into a local cache directory. The resolved snapshot directory is the last
# expression of the cell, so it is displayed as the cell output.
weights_dir = "../model_checkpoints/"
snapshot_location = snapshot_download(
    repo_id="yakiyama/MSA-Pairformer",
    cache_dir=weights_dir,
)
Path(snapshot_location)

Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

PosixPath('../model_checkpoints/models--yakiyama--MSA-Pairformer/snapshots/10dff5245d6fe736013fdfe8be74df23be4db2cc')

In [3]:
# Locate the main checkpoint inside the local Hugging Face cache. Matches are
# sorted so the choice is deterministic if several snapshots are cached, and a
# missing file yields an explicit error instead of a bare IndexError.
model_weights_matches = sorted(
    glob("../model_checkpoints/models--yakiyama--MSA-Pairformer/snapshots/*/model.bin")
)
if not model_weights_matches:
    raise FileNotFoundError(
        "model.bin not found under ../model_checkpoints/; run the download cell first."
    )
model_weights_path = model_weights_matches[0]

# map_location="cpu" lets the checkpoint load on machines without CUDA; the
# model is moved to the target device after the state dict has been loaded.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model_weights = torch.load(model_weights_path, map_location="cpu")

# Row index at which the fused `left_right_proj` tensor is split: the first
# 2 * 256 rows are stored under `p_in`, the remaining rows under `g_in`.
TRI_MULT_SPLIT_ROWS = 256 * 2

# Keys of the triangle-multiplication sub-modules use a different naming scheme
# in the checkpoint than the one written into `new_ordered_weights` below.
TRI_MULT_KEY_PREFIX = r"core_stack\.layers\.\d+\.3\.tri_mult_(?:incoming|outgoing)"
LEFT_RIGHT_PROJ_FRAGMENT = ".fn.left_right_proj.0."
left_right_proj_re = re.compile(TRI_MULT_KEY_PREFIX + r"\.fn\.left_right_proj\.0\.")

# (compiled pattern anchored at the start of the key, old fragment, new fragment)
tri_mult_renames = [
    (re.compile(TRI_MULT_KEY_PREFIX + r"\.norm\."), ".norm.", ".norm_in."),
    (re.compile(TRI_MULT_KEY_PREFIX + r"\.fn\.to_out_norm\."), ".fn.to_out_norm.", ".norm_out."),
    (re.compile(TRI_MULT_KEY_PREFIX + r"\.fn\.out_gate\."), ".fn.out_gate.", ".g_out."),
    (re.compile(TRI_MULT_KEY_PREFIX + r"\.fn\.to_out\.0\."), ".fn.to_out.0.", ".p_out."),
]

dict_items = []
for k, v in model_weights.items():
    if left_right_proj_re.match(k):
        # One fused tensor becomes two entries. Slicing only the first axis
        # works for 2-D weights and would also work for a 1-D tensor.
        dict_items.append((k.replace(LEFT_RIGHT_PROJ_FRAGMENT, ".p_in."), v[:TRI_MULT_SPLIT_ROWS]))
        dict_items.append((k.replace(LEFT_RIGHT_PROJ_FRAGMENT, ".g_in."), v[TRI_MULT_SPLIT_ROWS:]))
        continue

    # Plain renames; keys that match no pattern are copied through unchanged.
    new_k = k
    for pattern, old_fragment, new_fragment in tri_mult_renames:
        if pattern.match(k):
            new_k = k.replace(old_fragment, new_fragment)
            break
    dict_items.append((new_k, v))

# Insertion order follows the checkpoint's original key order.
new_ordered_weights = OrderedDict(dict_items)

In [4]:
def _find_checkpoint_file(filename):
    """Return the path of `filename` inside the locally cached model snapshot.

    Matches are sorted so the choice is deterministic when several snapshots
    are cached.

    Raises:
        FileNotFoundError: if no cached snapshot contains `filename`.
    """
    matches = sorted(
        glob(f"../model_checkpoints/models--yakiyama--MSA-Pairformer/snapshots/*/{filename}")
    )
    if not matches:
        raise FileNotFoundError(
            f"{filename} not found under ../model_checkpoints/; run the download cell first."
        )
    return matches[0]


# Load the two contact-head checkpoints and merge them into the remapped
# state dict so that a single load_state_dict call covers the whole model.
# map_location="cpu" keeps this working on machines without CUDA.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
cb_head_weights_path = _find_checkpoint_file("contact.bin")
cb_head_weights = torch.load(cb_head_weights_path, map_location="cpu")
confind_head_weights_path = _find_checkpoint_file("confind_contact.bin")
confind_head_weights = torch.load(confind_head_weights_path, map_location="cpu")

new_ordered_weights.update(cb_head_weights)
new_ordered_weights.update(confind_head_weights)

# Build the model, load the merged weights, then move it to the target device
# in eval mode with bfloat16 parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MSAPairformer()
model.load_state_dict(new_ordered_weights)
model = model.to(device).eval().to(dtype=torch.bfloat16)

# Example input. The model call below needs every variable defined here, so
# this section must run for the cell to work on a fresh kernel.
msa_file = "../data/1B70_A_1B70_B.fas"
max_msa_depth = 512
max_length = 10240
# Passed to the model as `complex_chain_break_indices`; presumably the position
# separating the two chains in the paired alignment. TODO confirm.
chain_break_idx = 265
# Seed NumPy's global RNG before building the MSA (presumably the sequence
# selection involves randomness; confirm against the MSA class).
np.random.seed(42)
msa_obj = MSA(
    msa_file_path=msa_file,
    max_seqs=max_msa_depth,
    max_length=max_length,
    max_tokens=np.inf,
    diverse_select_method="hhfilter",
    hhfilter_kwargs={"binary": "hhfilter"}
)
# Prepare MSA and mask tensors
msa_tokenized_t = msa_obj.diverse_tokenized_msa
msa_onehot_t = torch.nn.functional.one_hot(msa_tokenized_t, num_classes=len(aa2tok_d)).unsqueeze(0).float().to(device)
mask, msa_mask, full_mask, pairwise_mask = prepare_msa_masks(msa_obj.diverse_tokenized_msa.unsqueeze(0))
mask, msa_mask, full_mask, pairwise_mask = mask.to(device), msa_mask.to(device), full_mask.to(device), pairwise_mask.to(device)

# Run MSA Pairformer to generate embeddings and predict contacts.
# autocast follows the selected device instead of assuming CUDA.
with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
    res = model(
        msa=msa_onehot_t.to(torch.bfloat16),
        mask=mask,
        msa_mask=msa_mask,
        full_mask=full_mask,
        pairwise_mask=pairwise_mask,
        complex_chain_break_indices=[[chain_break_idx]],
        return_seq_weights=True,
        return_pairwise_repr_layer_idx=None,
        return_msa_repr_layer_idx=None
    )