In [9]:
import numpy as np 
import random 
import pandas as pd
from sklearn.metrics.pairwise import pairwise_distances
import os 
import json 
import trainUtils
import torch
import testUtils 
from torch.utils.data import DataLoader

In [2]:
def pick_sample(distance_matrix, num, weight, dis_weight, use_random=True, random_seed=1013):
    if use_random:
        random.seed(random_seed)
        np.random.seed(random_seed)
    picked = [] 
    for i in range(num):
        q = np.zeros_like(distance_matrix[0])
        for j in picked:
            q += distance_matrix[j] 
        q /= len(picked)+0.001 
        pick_weight = weight+q*dis_weight 
        pick_weight[picked] = 0
        if use_random:
            pick_weight = pick_weight/np.sum(pick_weight)
            picked.append(np.random.choice(len(pick_weight), p=pick_weight))
        else:
            picked.append(np.argmax(pick_weight))
    return picked

In [4]:
path = "/data2/tyfei/trainresults/ionChannels/ESMCFinal/logit2/"
with open(os.path.join(path, "config.json"), "r") as f:
    configs = json.load(f)
pretrain_model = trainUtils.loadPretrainModel(configs)
model = trainUtils.buildModel(configs, pretrain_model)

initized model for base_learning stage




In [8]:
ckpt = torch.load("/data2/tyfei/trainresults/ionChannels/ESMCFinal/logit2/last.ckpt")
model.load_state_dict(ckpt['state_dict'], strict=False)

_IncompatibleKeys(missing_keys=['esm_model.embed.weight', 'esm_model.transformer.blocks.0.attn.layernorm_qkv.1.linear.weight', 'esm_model.transformer.blocks.0.attn.out_proj.linear.weight', 'esm_model.transformer.blocks.0.ffn.0.weight', 'esm_model.transformer.blocks.0.ffn.0.bias', 'esm_model.transformer.blocks.0.ffn.1.linear.weight', 'esm_model.transformer.blocks.0.ffn.3.linear.weight', 'esm_model.transformer.blocks.1.attn.layernorm_qkv.1.linear.weight', 'esm_model.transformer.blocks.1.attn.out_proj.linear.weight', 'esm_model.transformer.blocks.1.ffn.0.weight', 'esm_model.transformer.blocks.1.ffn.0.bias', 'esm_model.transformer.blocks.1.ffn.1.linear.weight', 'esm_model.transformer.blocks.1.ffn.3.linear.weight', 'esm_model.transformer.blocks.2.attn.layernorm_qkv.1.linear.weight', 'esm_model.transformer.blocks.2.attn.out_proj.linear.weight', 'esm_model.transformer.blocks.2.ffn.0.weight', 'esm_model.transformer.blocks.2.ffn.0.bias', 'esm_model.transformer.blocks.2.ffn.1.linear.weight', 'es

In [10]:
import VirusDataset
test = trainUtils.loadPickle("/data/tyfei/datasets/ion_channel/Interprot/test885.pkl")
test_set = VirusDataset.ESM3MultiTrackDatasetTEST(test, tracks=["seq_t"]) 
dl = DataLoader(test_set, batch_size=1, shuffle=False) 
embed2, _ = testUtils.getEmbeddings(pretrain_model, dl, 6, 300000, "esmc", "mean")

100%|██████████| 885/885 [01:35<00:00,  9.26it/s]


In [11]:
q = np.array(embed2)
q = q.squeeze()
print(q.shape)

(885, 1152)


In [12]:
#prepare dataframe 
df = pd.DataFrame({"predict":np.random.uniform(0, 1, 885), "ER":np.random.uniform(0, 1, 885), "CA":np.random.uniform(0, 1, 885), "id":range(885)})
df.head()

Unnamed: 0,predict,ER,CA,id
0,0.16852,0.245961,0.718966,0
1,0.224602,0.839389,0.7502,1
2,0.655674,0.258202,0.468684,2
3,0.46476,0.526605,0.238183,3
4,0.140256,0.881562,0.170585,4


In [14]:
df = df.sort_values("predict")
df["weight"] = 1 + df["ER"]*0.2+df["CA"]*0.3
df.head()

Unnamed: 0,predict,ER,CA,id,weight
291,0.00551,0.15459,0.948633,291,1.315508
604,0.006572,0.559203,0.375007,604,1.224343
804,0.00738,0.300001,0.627633,804,1.24829
578,0.013698,0.030708,0.360003,578,1.114143
72,0.015034,0.012941,0.378243,72,1.116061


In [26]:
se = [0, 110, 220, 330, 440, 550, 660, 770, 888]
pick_num = [8, 8, 8, 8, 8, 8, 8, 8]
all_picked = []
for i in range(8):
    subdf = df.iloc[se[i]:se[i+1]] 
    embeds = q[subdf["id"].values] 
    distance_matrix = pairwise_distances(embeds, metric="euclidean")
    picked = pick_sample(distance_matrix, pick_num[i], subdf["weight"].values, 0.5)
    p = subdf.iloc[picked]
    all_picked.extend(p["id"].values)

In [27]:
len(all_picked)

64

In [25]:
distance_matrix

array([[0.        , 0.44032705, 0.6662917 , ..., 0.5475367 , 0.7406728 ,
        0.7203726 ],
       [0.44032705, 0.        , 0.59878606, ..., 0.7082523 , 0.5878906 ,
        0.8197343 ],
       [0.6662917 , 0.59878606, 0.        , ..., 0.65641963, 0.5301595 ,
        0.68961334],
       ...,
       [0.5475367 , 0.7082523 , 0.65641963, ..., 0.        , 0.9186178 ,
        0.81841576],
       [0.7406728 , 0.5878906 , 0.5301595 , ..., 0.9186178 , 0.        ,
        0.6362694 ],
       [0.7203726 , 0.8197343 , 0.68961334, ..., 0.81841576, 0.6362694 ,
        0.        ]], dtype=float32)