In [None]:
from task import CreateDataBeforeBatch,TMPDataset,CreateLable,MapAtomNode,node_accuracy,GaussianSmoothing,batchdata
from data_utils import ProcessRawData,ParseStructure
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from egnnmodel import EGNNModel
import numpy as np
import wandb
import seaborn as sns

In [None]:
# Start a new wandb run to track this script.
# Hyperparameters and run metadata recorded alongside the run.
wandb_run_config = {
    "learning_rate": 0.001,
    "architecture": "egnn",
    "dataset": "protein 3D structures ",
    "epochs": 50,
    'batch_size': 1,
    'hidden_channels': 128,
    'weight_decay': 1e-4,
}

wandb.init(
    project="DL_egnn_new",            # wandb project the run is logged under
    entity="transmembrane-topology",  # wandb user / team name
    group="CV setup1",                # group used to compare related runs
    name="epoch50_size1 ",            # run name. NOTE(review): trailing space kept as-is
    config=wandb_run_config,
)

sns.set_style("whitegrid")

In [None]:
# One-time data preparation (commented out; uncomment to regenerate the data):
# split the data into CV setups 1-5, download the PDB files, and parse them.

# file_name = "DeepTMHMM.3line"
# path='/work3/s230027/DL/codebase/'
# processor = ProcessRawData(path,file_name)
# processor.run() # split data and download the pdb

# processor = ParseStructure(path)
# processor.run() # parse pdb and store them


In [None]:
file_name = "DeepTMHMM.3line"
path = '/work3/s230027/DL/codebase/'
batch_size = 100   # large batches: these loaders are only used to build the label dicts
setup = 'setup1'   # cross-validation split to use (setup1 ... setup5)

processsor = CreateDataBeforeBatch(path)
(train_data_dict_before_batch,
 val_data_dict_before_batch,
 test_data_dict_before_batch) = processsor.get_data(setup)

## Dataloaders for processing labels.
# Identity collate: every batch is a plain list of sample dicts.
label_loader_kwargs = dict(batch_size=batch_size, shuffle=True,
                           collate_fn=lambda x: x, pin_memory=True)

train_dataset = TMPDataset(train_data_dict_before_batch)
train_data_loader = DataLoader(train_dataset, **label_loader_kwargs)

val_dataset = TMPDataset(val_data_dict_before_batch)
val_data_loader = DataLoader(val_dataset, **label_loader_kwargs)

In [None]:
# Put the training labels together: run the label builder on every batch of the
# training split and merge the per-batch dictionaries (keyed by protein name).
train_residual_level_label, train_atom_levl_label = {}, {}
train_dismatch_index_pred, train_dismatch_index_type = {}, {}

for data_batch in train_data_loader:
    batchname = [sample['name'] for sample in data_batch]
    labelprocessor = CreateLable(batchname, data_batch, path, file_name)
    # NOTE(review): df_train ends up holding the value returned for the LAST
    # batch only; that is fine only if labeldispatcher returns a split-wide
    # frame rather than a per-batch one -- confirm in task.py.
    (atom_level_label_dict, redidual_level_label_dict,
     dismatch_index_pred, dismatch_index_type,
     df_train, _, _) = labelprocessor.labeldispatcher(setup, subset='train')

    for merged, batch_part in ((train_atom_levl_label, atom_level_label_dict),
                               (train_residual_level_label, redidual_level_label_dict),
                               (train_dismatch_index_pred, dismatch_index_pred),
                               (train_dismatch_index_type, dismatch_index_type)):
        merged.update(batch_part)

In [None]:
# Put the validation labels together: run the label builder on every batch of
# the validation split and merge the per-batch dictionaries (keyed by protein name).
val_residual_level_label, val_atom_levl_label = {}, {}
val_dismatch_index_pred, val_dismatch_index_type = {}, {}

for data_batch in val_data_loader:
    batchname = [sample['name'] for sample in data_batch]
    labelprocessor = CreateLable(batchname, data_batch, path, file_name)
    # NOTE(review): df_val ends up holding the value returned for the LAST
    # batch only; that is fine only if labeldispatcher returns a split-wide
    # frame rather than a per-batch one -- confirm in task.py.
    (atom_level_label_dict, redidual_level_label_dict,
     dismatch_index_pred, dismatch_index_type,
     _, df_val, _) = labelprocessor.labeldispatcher(setup, subset='val')

    for merged, batch_part in ((val_atom_levl_label, atom_level_label_dict),
                               (val_residual_level_label, redidual_level_label_dict),
                               (val_dismatch_index_pred, dismatch_index_pred),
                               (val_dismatch_index_type, dismatch_index_type)):
        merged.update(batch_part)

In [None]:
# Dataloaders for the model: one protein per batch.
# Identity collate: each batch is a plain list of sample dicts, assembled into
# model input later by batchdata().
batch_size = 1

train_dataset = TMPDataset(train_data_dict_before_batch)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                               collate_fn=lambda x: x, pin_memory=True)

# Validation order does not change the epoch averages, so it is not shuffled.
# A fixed order keeps the per-step validation logs comparable across epochs and
# avoids drawing from the global torch RNG on every validation pass (a shuffled
# DataLoader seeds its sampler from that RNG each time it is iterated).
val_dataset = TMPDataset(val_data_dict_before_batch)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                             collate_fn=lambda x: x, pin_memory=True)

In [None]:
# Fix the torch RNG so weight initialisation, dropout and training-set shuffling
# are reproducible across runs. NOTE(review): pick whatever seed value you prefer.
SEED = 42
torch.manual_seed(SEED)

# NOTE(review): 20000 looks like an assumed upper bound on atoms per protein,
# scaled by batch_size (1 here) -- confirm against how EGNNModel uses max_len.
max_len = 20000 * batch_size

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# out_dim=6: one logit per class scored by CrossEntropyLoss in the training loop.
model = EGNNModel(out_dim=6, max_len=max_len, num_layers=5, emb_dim=128,
                  residual=True, dropout=0.1).to(device)

criterion = nn.CrossEntropyLoss()
# Read the optimiser hyperparameters from the wandb run config, so the values
# logged with the run cannot drift from the ones actually used.
optimizer = optim.Adam(model.parameters(),
                       lr=wandb.config['learning_rate'],
                       weight_decay=wandb.config['weight_decay'])

In [None]:
# Training / validation loop.
# Epoch count comes from the wandb run config (50) so it matches what is logged.
total_epochs = wandb.config['epochs']

global_step = 0

# Per-epoch averages, kept for the export cell after training.
epoch_atom_level_accuracy_record_train = []
epoch_loss_record_train = []
epoch_residual_level_accuracy_record_train = []
epoch_atom_level_accuracy_record_val = []
epoch_loss_record_val = []
epoch_residual_level_accuracy_record_val = []

# Gaussian smoothing of the class logits along the atom axis.
# NOTE(review): GaussianSmoothing(6, 29, 5) is assumed to mean
# (channels=6 classes, kernel_size=29, sigma=5) -- confirm against task.py.
SMOOTHING_KERNEL_SIZE = 29
# 14 on each side: reflect-padding so the smoothed output keeps one value per atom.
SMOOTHING_PAD = SMOOTHING_KERNEL_SIZE // 2
smoothing = GaussianSmoothing(6, SMOOTHING_KERNEL_SIZE, 5)


def smooth_logits(prediction):
    """Smooth per-atom class logits along the atom axis.

    Args:
        prediction: (num_atoms, num_classes) logits from the model.

    Returns:
        (num_atoms, num_classes) smoothed logits on ``device``.
    """
    num_atoms, num_classes = prediction.shape
    # (N, C) -> (1, C, N), channels-first for the 1-D smoothing. This must be a
    # transpose: reshaping (N, C) to (1, C, N) only reinterprets memory and mixes
    # class scores of different atoms into the same channel.
    # Smoothing runs on CPU because the smoothing module is not moved to `device`.
    channels_first = prediction.to('cpu').transpose(0, 1).unsqueeze(0)
    padded = F.pad(channels_first, (SMOOTHING_PAD, SMOOTHING_PAD), mode='reflect')
    smoothed = smoothing(padded)
    # (1, C, N) -> (N, C), again by transposing rather than reshaping.
    return smoothed.reshape(num_classes, num_atoms).transpose(0, 1).to(device)


def gather_labels(batchname, atom_label_dict, residual_label_dict):
    """Collect the labels of all proteins in a batch.

    Returns:
        (atom_labels, residual_labels): the per-protein dense atom-label tensors
        concatenated along dim 0 and moved to ``device``, and the residue labels
        flattened into one Python list.
    """
    atom_labels = torch.cat([atom_label_dict[name].to_dense() for name in batchname]).to(device)
    residual_labels = [value for name in batchname for value in residual_label_dict[name]]
    return atom_labels, residual_labels


def batch_accuracies(prediction_Gauss, atom_labels, residual_labels, batchname,
                     dismatch_index_pred, dismatch_index_type, df):
    """Return (atom_level_accuracy, residual_level_accuracy) for one batch."""
    _, predicted = torch.max(prediction_Gauss, 1)
    atom_level_accuracy = (predicted == atom_labels).sum().item() / atom_labels.size(0)
    # MapAtomNode presumably maps atom predictions to per-residue (node)
    # predictions, which node_accuracy then scores against the residue labels.
    processor = MapAtomNode(predicted.cpu(), batchname, dismatch_index_pred, dismatch_index_type, df)
    residual_level_accuracy = node_accuracy(processor.map_atom_node(), residual_labels)
    return atom_level_accuracy, residual_level_accuracy


for epoch in range(total_epochs):
    # Validation below switches the model to eval mode, so training mode
    # (dropout active) must be restored at the start of every epoch.
    model.train()

    epoch_atom_level_accuracy_train = []
    epoch_loss_train = []
    epoch_residual_level_accuracy_train = []

    # train
    for data_batch in train_data_loader:
        global_step += 1
        batchname = [sample['name'] for sample in data_batch]
        atom_levl_label, residual_level_label = gather_labels(
            batchname, train_atom_levl_label, train_residual_level_label)

        data = batchdata(data_batch)  # assemble the samples according to batch_size

        optimizer.zero_grad()
        outputs = model(data.to(device))
        prediction_Gauss = smooth_logits(outputs["node_embedding"])
        loss = criterion(prediction_Gauss, atom_levl_label)
        loss.backward()
        optimizer.step()

        # Atom-level and residue-level accuracy for this batch.
        loss_value = loss.item()
        atom_level_accuracy, residual_level_accuracy = batch_accuracies(
            prediction_Gauss.detach(), atom_levl_label, residual_level_label, batchname,
            train_dismatch_index_pred, train_dismatch_index_type, df_train)

        # A single log call, so all step metrics share one wandb step.
        wandb.log({'train_loss_step': loss_value,
                   'train_atom_level_accuracy_step': atom_level_accuracy,
                   'train_residual_level_accuracy_step': residual_level_accuracy,
                   'global_step': global_step})

        epoch_loss_train.append(loss_value)
        epoch_atom_level_accuracy_train.append(atom_level_accuracy)
        epoch_residual_level_accuracy_train.append(residual_level_accuracy)

    mean_loss_train = np.mean(epoch_loss_train)
    mean_atom_acc_train = np.mean(epoch_atom_level_accuracy_train)
    mean_residual_acc_train = np.mean(epoch_residual_level_accuracy_train)
    epoch_loss_record_train.append(mean_loss_train)
    epoch_atom_level_accuracy_record_train.append(mean_atom_acc_train)
    epoch_residual_level_accuracy_record_train.append(mean_residual_acc_train)
    wandb.log({'train_loss_epoch': mean_loss_train,
               'train_atom_level_accuracy_epoch': mean_atom_acc_train,
               'train_residual_level_accuracy_epoch': mean_residual_acc_train,
               'global_step': global_step})

    # val
    model.eval()
    with torch.no_grad():
        epoch_atom_level_accuracy_val = []
        epoch_loss_val = []
        epoch_residual_level_accuracy_val = []

        for data_batch in val_data_loader:
            batchname = [sample['name'] for sample in data_batch]
            atom_levl_label, residual_level_label = gather_labels(
                batchname, val_atom_levl_label, val_residual_level_label)

            data = batchdata(data_batch)  # assemble the samples according to batch_size
            outputs = model(data.to(device))
            prediction_Gauss = smooth_logits(outputs["node_embedding"])
            loss_value = criterion(prediction_Gauss, atom_levl_label).item()

            atom_level_accuracy, residual_level_accuracy = batch_accuracies(
                prediction_Gauss, atom_levl_label, residual_level_label, batchname,
                val_dismatch_index_pred, val_dismatch_index_type, df_val)

            epoch_loss_val.append(loss_value)
            epoch_atom_level_accuracy_val.append(atom_level_accuracy)
            epoch_residual_level_accuracy_val.append(residual_level_accuracy)

            wandb.log({'val_loss_step': loss_value,
                       'val_atom_level_accuracy_step': atom_level_accuracy,
                       'val_residual_level_accuracy_step': residual_level_accuracy,
                       'global_step': global_step})

        mean_loss_val = np.mean(epoch_loss_val)
        mean_atom_acc_val = np.mean(epoch_atom_level_accuracy_val)
        mean_residual_acc_val = np.mean(epoch_residual_level_accuracy_val)
        epoch_loss_record_val.append(mean_loss_val)
        epoch_atom_level_accuracy_record_val.append(mean_atom_acc_val)
        epoch_residual_level_accuracy_record_val.append(mean_residual_acc_val)
        wandb.log({'val_loss_epoch': mean_loss_val,
                   'val_atom_level_accuracy_epoch': mean_atom_acc_val,
                   'val_residual_level_accuracy_epoch': mean_residual_acc_val,
                   'global_step': global_step})


# Close the wandb run, then persist the trained weights and the per-epoch curves.
wandb.finish()


print("Finished training.")

torch.save(model.state_dict(), '/work3/s230027/DL/result/egnn/egnn_model_size1_epoch50.pth')

for record_name, record in (
    ('epoch_residual_level_accuracy_record_train', epoch_residual_level_accuracy_record_train),
    ('epoch_residual_level_accuracy_record_val', epoch_residual_level_accuracy_record_val),
    ('epoch_loss_record_train', epoch_loss_record_train),
    ('epoch_loss_record_val', epoch_loss_record_val),
):
    print(record_name, record)

# Each CSV has two rows (row 0 = train, row 1 = val) and one column per epoch.
node_acc_results = np.vstack([epoch_residual_level_accuracy_record_train,
                              epoch_residual_level_accuracy_record_val])
np.savetxt("/work3/s230027/DL/result/egnn/CVsetup1_residual_acc_results.csv", node_acc_results,
           delimiter=',', comments="", fmt='%s')

loss_results = np.vstack([epoch_loss_record_train, epoch_loss_record_val])
np.savetxt("/work3/s230027/DL/result/egnn/CVsetup1_loss_results.csv", loss_results,
           delimiter=',', comments="", fmt='%s')


   
        