In [None]:
%load_ext autoreload
%autoreload 2
from PPI_prediction import *

In [None]:
# Run the data-preparation helpers from PPI_prediction in their required order.
# NOTE(review): judging by the names, these presumably build the ENSG id list,
# query UniProt, map ENSP to ENSG and write the interaction file — confirm
# against PPI_prediction.
preparation_steps = (
    make_ensg_id,
    uniprot_request,
    make_ensp_ensg,
    make_interaction_file,
)
for run_step in preparation_steps:
    run_step()

In [None]:
# Generate one embedding set per protein language model.
# NOTE(review): each helper presumably writes its own embeddings file — confirm
# against PPI_prediction.
make_protbert_embedding()
make_prott5_embedding()
for esm_variant in ('esm1b', 'esm2'):
    make_esm_embedding(esm_variant)

아래 셀의 `embedding_fname` 값을 수정하면 사용할 임베딩(ProtBERT, ProtT5, ESM-1b, ESM-2)을 선택할 수 있음

In [None]:
# Input files handed to make_graph.
sequence_fname = 'HI_union_Uniprot_Sequence2.txt'
interaction_fname = 'HI-union.tsv'
sub_interaction_fname = 'STRING_interactions.txt'

# Available embedding files, in the order ProtBERT, ProtT5, ESM-1b, ESM-2.
fname = [
    f"biomedical/HI_union_{embedding_tag}_embeddings.npz"
    for embedding_tag in ("protbert", "protT5", "esm1b_650M", "esm2_650M")
]
embedding_fname = fname[0]  # change the index to switch embeddings

# NOTE(review): pca is presumably the number of PCA components (0 = disabled)
# and norm toggles feature normalisation — confirm against make_graph.
pca, norm = 0, False

graph_data2 = make_graph(
    sequence_fname, interaction_fname, sub_interaction_fname,
    embedding_fname, pca, norm,
)

In [None]:
# Split the graph into train / validation / test parts using the project helper.
# NOTE(review): `spilt_data` looks like a typo of `split_data`; it is kept as-is
# because it has to match the name exported by PPI_prediction — confirm.
train_data, val_data, test_data = spilt_data(graph_data2)

In [None]:
# Derive the log file name (`fname`) and checkpoint file name (`best_model`)
# from the chosen embedding, the model type and the preprocessing options.
model = 'GCN'

# (substring looked for in embedding_fname, tag used in the output file names).
# Checked in this order; matching is case-sensitive.
embedding_tags = (
    ('esm1b', 'esm1b'),
    ('esm2', 'esm2'),
    ('T5', 'ProtT5'),
    ('bert', 'ProtBERT'),
)
embedding_tag = next(
    (tag for needle, tag in embedding_tags if needle in embedding_fname),
    None,
)
if embedding_tag is None:
    # Without this guard `fname` would silently keep its previous value
    # (the list of embedding paths) and fail later with a confusing error.
    raise ValueError(
        f"Cannot derive log file names: unrecognised embedding file {embedding_fname!r}"
    )

# PCA takes precedence over normalisation in the file-name suffix.
# The f-string formats `pca` whether it is an int or a str; the previous
# `'_pca' + pca` concatenation raised TypeError for an int value.
if pca:
    variant = f'_pca{pca}'
elif norm:
    variant = '_norm'
else:
    variant = ''

fname = f'log_{embedding_tag}_{model}{variant}.txt'
best_model = f'log_{embedding_tag}_{model}_best{variant}.pt'

In [None]:
import os, torch

# Hyperparameter search + training with early stopping, followed by a final
# evaluation of the best checkpoint on the test split.

# Tunables for the search / training loop.
N_TRIALS = 1       # number of hyperparameter configurations to try
MAX_EPOCHS = 300   # upper bound on epochs per configuration
PATIENCE = 20      # epochs without a new best validation F1 before a trial stops

in_channel = graph_data2.x.shape[1]
state_record = {}  # hyperparameter tuple -> [loss per epoch, validation F1 per epoch]
top, best_epoch = 0, 0   # best validation F1 over all trials and the epoch it occurred at
param = None             # hyperparameter tuple of the trial that produced the saved checkpoint

home = os.path.expanduser("~")
log_dir = os.path.join(home, 'biomedical', 'log_data')
os.makedirs(log_dir, exist_ok=True)  # open()/torch.save fail if the directory is missing
model_file = os.path.join(log_dir, best_model)
target = os.path.join(log_dir, fname)

# The context manager guarantees the log is closed even if training raises.
with open(target, 'w') as f:
    for i in range(N_TRIALS):

        print(i)
        hidden, layers, out, rate, drop, decay = hyperparameter_tuning(in_channel)
        f.write(f"\nIn: {in_channel} | Hidden: {hidden:3d} | Layers: {layers:2d} | Out: {out:3d} | Rate: {rate:.4f} | Drop: {drop:.2f} | Decay: {decay:.5f}\n")
        f.flush()
        key = (hidden, layers, out, rate, drop, decay)
        state_record[key] = [[], []]  # loss, val_f1

        pred_model, optimizer = set_model(in_channel, *key, model)

        # Early-stopping state is per trial; resetting both counters here keeps
        # a previous trial's patience count from leaking into this one.
        best_val, early_stop = 0, 0

        for epoch in range(1, MAX_EPOCHS + 1):

            # Train on the training split for one epoch.
            loss = train(pred_model, optimizer, train_data)

            # Validate on the validation split after every epoch.
            val_f1 = evaluate(pred_model, val_data, train_data, True)['f1']

            state_record[key][0].append(loss)
            state_record[key][1].append(val_f1)
            if val_f1 > best_val:
                best_val = val_f1
                early_stop = 0
            else:
                early_stop += 1
                if early_stop >= PATIENCE:
                    break

            # Checkpoint only when the global best improves, and remember which
            # configuration the checkpoint belongs to so it can be rebuilt later.
            if top < best_val:
                top = best_val
                best_epoch = epoch
                param = key
                torch.save(pred_model.state_dict(), model_file)

            print(f"Epoch {epoch:3d} | Loss: {loss:.4f} | Val: {val_f1:.4f} | Best: {best_val:.4f}")
            print(f"Epoch {epoch:3d} | Loss: {loss:.4f} | Val: {val_f1:.4f} | Best: {best_val:.4f}", file=f)
            f.flush()

    # Final test.
    if param is None:
        # No epoch ever beat the initial score, so nothing was saved in this run;
        # loading model_file would pick up a stale checkpoint or fail.
        raise RuntimeError("No checkpoint was saved: validation F1 never exceeded 0.")

    loss, f1 = state_record[param]  # curves of the best trial

    # Rebuild the architecture of the saved checkpoint before loading its weights.
    pred_model, optimizer = set_model(in_channel, *param, model)
    pred_model.load_state_dict(torch.load(model_file))
    test_f1 = evaluate(pred_model, test_data, train_data, True)  # dict with 'auc', 'f1', 'accuracy'
    # Written without a trailing newline, so the parameters prefix the TEST line.
    f.write(" | ".join(map(str, param)))
    print(f"TEST | Best Val : {top:.4f} (Epoch {best_epoch}) | Test AUC: {test_f1['auc']:.4f} | Test F1: {test_f1['f1']:.4f} | Test Accuracy: {test_f1['accuracy']:.4f}", file = f)

print(f"\nBest Val : {top:.4f} (Epoch {best_epoch})")
print(f"Test AUC:     {test_f1['auc']:.4f}")
print(f"Test F1:      {test_f1['f1']:.4f}")
print(f"Test Accuracy: {test_f1['accuracy']:.4f}")

In [None]:
import matplotlib.pyplot as plt

# Plot the loss and validation-F1 curves of the trial that reached the best
# validation F1 (`top`).
for key, value in state_record.items():
    if max(value[1]) == top:
        loss = value[0]
        f1 = value[1]
        # The key is (hidden, layers, out, rate, drop, decay); the starred target
        # absorbs trailing entries so the unpack does not raise ValueError.
        hidden, layers, out, rate, drop, *extra_params = key

x = list(range(1, len(loss)+1))

plt.plot(x, loss, label='Loss', marker='o')
plt.plot(x, f1, label='F1', marker='x')
plt.title(f'Hidden: {hidden}, Layers: {layers}, Out: {out}, Rate: {rate}, Drop: {drop}')
plt.xlabel('Epochs')
plt.ylabel('Evaluation')
plt.legend()
plt.show()

In [None]:
# Test on external data: rebuild the best model and load its saved weights.

# A checkpoint is written only when the global best strictly improves, so the
# first trial (in insertion order) whose best validation F1 equals `top` is the
# one the checkpoint belongs to.
for key, value in state_record.items():
    if max(value[1]) == top:
        loss = value[0]
        f1 = value[1]
        param = key  # full hyperparameter tuple, including weight decay
        break

# Build the model exactly as the training cell does and load the checkpoint it
# saved to `model_file`; the earlier 5-value unpack of the 6-element key raised
# ValueError, and 'best.pt' is never written by this notebook.
pred_model, optimizer = set_model(in_channel, *param, model)
pred_model.load_state_dict(torch.load(model_file))


