In [None]:
import os
import pandas as pd
import numpy as np
import h5py
import torch
from transformers import AutoTokenizer
from scipy.stats import spearmanr
from sklearn.metrics import roc_auc_score

import sys
sys.path.insert(0, "/cluster/pixstor/xudong-lab/yuexu/SeqDance-main/SeqDance-main/model/")
from config import config # please first download the dataset and fill in the config.py file with the path where you downloaded the dataset
from model import ESMwrap

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Tokenizer published with the facebook/esm2_t12_35M_UR50D checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Build the wrapper. It is deliberately NOT moved to `device` here: the
# `from_pretrained` call below rebinds `dance_model`, and that result is the
# one transferred to `device`, so an earlier transfer would be wasted work
# (and wasted GPU memory if `from_pretrained` returns a new instance).
esm2_select = 'model_35M'
model_select = 'seqdance' # or 'esmdance'
dance_model = ESMwrap(esm2_select, model_select)

# Load the pretrained weights from the Hugging Face Hub.
# NOTE(review): `model_select` is 'seqdance' but the repo id below is
# "ChaoHou/ESMDance" -- confirm which checkpoint is actually intended.
dance_model = dance_model.from_pretrained("ChaoHou/ESMDance")
dance_model = dance_model.to(device)
# Switch to evaluation mode before running inference.
dance_model.eval()

# Sequence table. The 'test'-label filter below is left switched off, so rows
# with every label stay available for the name lookup.
_file_paths = config['file_path']
df = pd.read_csv(_file_paths['train_df_path'])
# df = df[df['label'] == 'test']

# Read-only handle on the HDF5 feature store; it is kept open because the
# per-protein pair features are read from it further down.
h5py_read = h5py.File(_file_paths['h5py_path'], mode='r')
# Truncation limit handed to the tokenizer.
max_len = 274

# Protein under analysis; its name keys both the sequence table and the
# '<name>_pair_feature' dataset in the HDF5 file.
pro = '3ic3_B'
seq = df.loc[df['name'] == pro, 'seq'].values[0]

# Tokenize first so the pair features can be cropped to the token count.
raw_input = tokenizer(seq, truncation=True, max_length=max_len, return_tensors="pt")
length = raw_input['input_ids'].size(1)

# Load the full pair-feature array, then keep its leading length x length corner.
pair_f = h5py_read[f'{pro}_pair_feature'][:][:length, :length]

# Forward pass without gradient tracking, asking the model for its attention map.
with torch.no_grad():
    output = dance_model(raw_input.to(device), return_attention_map=True)

# Entry 0 of the returned attention map (presumably the single batch item --
# TODO confirm), copied to host memory as a NumPy array.
atten = output['attention_map'][0].cpu().numpy()

# Restrict the analysis to position pairs (i, j) whose sequence separation
# |i - j| is greater than 5.
residue_idx = np.arange(length)
seq_separation = np.abs(residue_idx[:, None] - residue_idx[None, :])
row_indices, col_indices = np.nonzero(seq_separation > 5)

# Channel 9 of the pair features is used as the co-movement signal.
co_move = pair_f[:, :, 9]
# The interaction frequency is stored as pow(x, 1/3) in the file (that is the
# value the model is trained on), so the first nine channels are cubed before
# being summed into one interaction value per pair.
# NOTE(review): if padded entries are -1 in these channels as well, they end
# up as -9 after cubing and summing, not -1 -- confirm the downstream masking.
inter = (pair_f[:, :, :9] ** 3).sum(axis=-1)

# One list per metric; each list receives one value per attention head.
atten_scores = {
    metric: []
    for metric in (
        'co_move_topL_ratio',
        'co_move_posi_spearman',
        'co_move_neg_spearman',
        'inter_topL_ratio',
        'inter_auroc',
    )
}

# Score every attention head (last axis of `atten`) against the co-movement
# values of the selected position pairs.
f = co_move[row_indices, col_indices]
mask = f != -1 # -1 is padding
f_flat = f[mask]

# Everything below depends only on the co-movement values, not on the head,
# so it is computed once instead of once (or several times) per head.
abs_f_flat = np.abs(f_flat)
posi_mask = f_flat > 0
neg_mask = f_flat < 0
f_posi = f_flat[posi_mask]
f_neg = f_flat[neg_mask]

for k in range(atten.shape[-1]):
    att = atten[:, :, k][row_indices, col_indices]
    att_flat = att[mask]
    # Fold change: mean |co-movement| of the `length` highest-attention pairs
    # divided by the mean over all remaining pairs.
    # NOTE(review): `length` is the tokenizer's token count; if that includes
    # special tokens, "top L" is slightly larger than the residue count -- confirm.
    top_L_indices = np.argsort(-att_flat)[:length]  # Negative sign for descending sort
    in_top_L = np.zeros(f_flat.shape[0], dtype=bool)
    in_top_L[top_L_indices] = True
    top_L_mean = np.mean(abs_f_flat[top_L_indices])
    other_mean = np.mean(abs_f_flat[~in_top_L])
    # The small epsilon guards against division by zero.
    atten_scores['co_move_topL_ratio'].append(top_L_mean / (other_mean + 1e-8))
    # Spearman correlation between attention and co-movement, computed
    # separately over positively and negatively co-moving pairs.
    atten_scores['co_move_posi_spearman'].append(spearmanr(att_flat[posi_mask], f_posi)[0])
    atten_scores['co_move_neg_spearman'].append(spearmanr(att_flat[neg_mask], f_neg)[0])