In [2]:
import functools
import json
import pickle

import numpy as np
from scipy.stats import spearmanr
from tqdm import tqdm

def _read_json(path):
    """Open *path* in text mode and return the parsed JSON document."""
    with open(path, "r") as handle:
        return json.load(handle)


# Inputs for the analysis below:
#   shuffled -- the protein ids to process, iterated in file order by the driver loop.
#   seq_dict -- lookup keyed by those ids; only len(seq_dict[pdb]) is used below,
#               presumably the full sequence string -- confirm against the data file.
shuffled = _read_json('../data/shuffled_selected_protein.json')
seq_dict = _read_json('../data/full_seq_dict.json')

@functools.lru_cache(maxsize=1)
def _top_pair_blocks(pdb):
    """Load *pdb*'s ESM2 Jacobian and MSA inverse covariance, restricted to its top-L pairs.

    Residue pairs (i, j) with i < j are ranked by the MSA "apc" score in
    descending order, and the first L = len(seq_dict[pdb]) of them are kept.

    Returns:
        (a_top, b_top), the Jacobian and inverse-covariance 20x20 blocks for
        those pairs. b_top has shape (k, 20, 20), where k = min(L, L*(L-1)/2).
        a_top is assumed to have the same shape (jac presumably (L, 20, L, 20);
        TODO confirm), because the caller applies one boolean mask to both.
        The arrays are shared through the cache, so callers must not mutate them.

    Cached with maxsize=1 because the driver calls correlation_calculation
    once per threshold for the same pdb in a row. This avoids re-reading and
    re-sorting the pickles for every n_std. lru_cache does not cache
    exceptions, so a failing pdb is retried on every call.
    """
    with open('../results/inv_cov_msa/' + pdb + '_inv_cov_msa.pkl', 'rb') as f:
        msa_tmp = pickle.load(f)
    with open('../results/esm2_jac/' + pdb + '_esm2_jac.pkl', 'rb') as f:
        jac = pickle.load(f)

    L = len(seq_dict[pdb])
    # "ic" is stored flat with 21 states per position (presumably 20 amino
    # acids + gap -- confirm); drop the 21st state to get 20x20 blocks.
    ic = msa_tmp["ic"].reshape(L, 21, L, 21)[:, :20, :, :20]
    idx = np.triu_indices(L, 1)
    # Advanced indices separated by a slice: NumPy puts the pair axis first,
    # giving (n_pairs, 20, 20).
    a = jac[idx[0], :, idx[1], :]
    b = ic[idx[0], :, idx[1], :]
    order = msa_tmp["apc"][idx[0], idx[1]].argsort()[::-1]  # strongest APC first
    cutoff = L
    top = order[:cutoff]
    return a[top], b[top]


def correlation_calculation(pdb, n_std):
    """Spearman correlation between ESM2 Jacobian and MSA inverse-covariance couplings.

    Uses the top-L residue pairs of *pdb* ranked by the MSA "apc" score. Only
    the 20x20 coupling entries whose inverse-covariance magnitude exceeds
    ``n_std`` times the root-mean-square of those entries are kept.

    Args:
        pdb: key into seq_dict and stem of the two per-protein result pickles.
        n_std: threshold in units of the RMS. 0 keeps every non-zero IC entry.

    Returns:
        The Spearman statistic between the kept Jacobian and IC entries. It can
        be NaN, e.g. for constant input. Returns None if any step raised; the
        error is printed instead of propagated, so one bad protein does not
        abort the batch.
    """
    try:
        a_top, b_top = _top_pair_blocks(pdb)
        # Root-mean-square about zero (not a mean-centred standard deviation).
        b_std = np.sqrt(np.square(b_top).mean())
        mask = np.abs(b_top) > b_std * n_std
        # Boolean-mask indexing already yields 1-D arrays.
        return spearmanr(a_top[mask], b_top[mask]).statistic
    except Exception as e:
        # Deliberate best-effort boundary for the batch loop: report and move on.
        print(f"An error occurred with {pdb} for n_std={n_std}: {e}")
        return None

# Thresholds (in units of the IC RMS) evaluated for every protein.
N_STD_VALUES = range(0, 5)

# cc_info maps pdb -> {n_std: statistic or None}. It is filled incrementally so
# that an interrupted notebook run still keeps the proteins finished so far.
cc_info = {}
for pdb in tqdm(shuffled):
    per_threshold = cc_info.setdefault(pdb, {})
    per_threshold.update(
        (threshold, correlation_calculation(pdb, threshold))
        for threshold in N_STD_VALUES
    )

# Persist results. json turns the int threshold keys into strings and writes
# NaN statistics as the non-standard token NaN (json's default allow_nan=True).
with open('../results/3B_L_std.json', 'w') as out_file:
    json.dump(cc_info, out_file)


100%|██████████| 1431/1431 [1:02:46<00:00,  2.63s/it]
