**For Training**
===
> -*- coding: utf-8 -*-

> Author: xinghd



## Import libraries

In [None]:
from utils.utils import Trainer, Predictor, TrainReader,TestReader,PredictReader, data_collate, get_shuffle_data, collate_fn, seed_everything
from model.model import Encoder, Decoder, ModelCat

%reload_ext autoreload
%autoreload 2

In [None]:
import os
import json
import random
import time
import timeit
import pickle
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Descriptors
from torch import nn
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
import torch
from torch.utils.data import DataLoader, random_split
from utils.word2vec import seq_to_kmers, get_protein_embedding, train_word2vec
from gensim.models import Word2Vec
from torch.utils import data as torch_data
import warnings

### Parameters

In [None]:
# Training in this notebook is GPU-only (CFG pins DEVICE to 'cuda:0').
# Raise explicitly instead of using `assert`, which is removed when Python
# runs with -O and would let the notebook continue without a GPU.
if not torch.cuda.is_available():
    raise RuntimeError("Must have an available gpu")

class CFG:
    """Configuration namespace for the training notebook.

    Every setting is a plain class attribute. The notebook reads them as
    ``CFG.<name>`` and also hands the class itself to ``Trainer``.
    """

    # ------------------------------------------------------------------
    # Run mode
    # ------------------------------------------------------------------
    # Enable Weights & Biases logging (wandb is only imported when True).
    IFwandb = False
    # True: train the model from scratch; False: fine-tune an existing model.
    trainfirst = True

    # ------------------------------------------------------------------
    # Paths and devices
    # ------------------------------------------------------------------
    # Root directory of the data files.
    DATA_ROOT = "./data/"
    # Where unlabeled samples that may be positive are saved.
    U_m_P_savepath = "./data/U_m_P/"
    # Path of the word2vec model.
    word2vec_path = "./model/model_pretrained/word2vec_pretrained.model"
    # Path of a pre-trained model state dict; None means train from scratch.
    state_dict_path = None
    # Number of CUDA devices reported by torch.
    gpu_number = torch.cuda.device_count()
    # Primary device.
    DEVICE = torch.device("cuda:0")

    # ------------------------------------------------------------------
    # Hyperparameters
    # ------------------------------------------------------------------
    # Batch size.
    BATCH_SIZE = 4
    # Number of training epochs.
    EPOCHES = 100
    # Number of block layers.
    layer_num = 12
    # Last dimension of the protein data.
    protein_dim = 100
    # Last dimension of the compound data.
    atom_dim = 46
    # Hidden dimension.
    hid_dim = 128
    # Layer-norm shape parameter.
    norm_shape = 128

    # ------------------------------------------------------------------
    # Settings for the train reader
    # ------------------------------------------------------------------
    # Whether to use PU learning.
    ifpu = True
    # Whether to use label smoothing.
    ifsmoothing = True

    # ------------------------------------------------------------------
    # Settings for the trainer
    # ------------------------------------------------------------------
    # Learning rate.
    lr = 1e-4
    # Weight decay.
    weight_decay = 1e-4
    # Start removing potential positives from the unlabeled samples once
    # AUC exceeds this value.
    # NOTE(review): the training loop in this notebook compares AUC against
    # a literal 0.85 rather than this attribute -- confirm which is intended.
    del_threshold = 0.9

    # ------------------------------------------------------------------
    # Outputs
    # ------------------------------------------------------------------
    # Directory prefix of the result log file.
    result_file_path = "./results/log/"
    # Path prefix used when saving model checkpoints.
    best_model_savepath = "./model/model_save/"
    # File-name suffixes for saved models and for the result log.
    modelsave_file_suffix = "example_epoch.pt"
    result_file_suffix = "example_log.txt"

    # NOTE(review): not read anywhere in this notebook; presumably consumed
    # by Trainer through CFG -- confirm.
    quantile = 0.9


## WandB

In [None]:
# Optional experiment tracking: Weights & Biases is imported and initialised
# only when CFG.IFwandb is set, so the package is not needed otherwise.
if CFG.IFwandb:
    import wandb
    wandb.login()
    # Start a run named 'train' in the 'PU-EPP' project.
    wandb.init(project='PU-EPP', name='train')


## Parameters

### Random Seed

In [None]:
# Fix the random seed through the project helper imported from utils.utils.
# NOTE(review): presumably seeds the Python, NumPy and PyTorch RNGs -- confirm
# in utils.utils.
seed_everything(seed=42)

## Model

In [None]:
# Build the two sub-networks from the CFG dimensions: the Encoder is sized
# with the protein feature dimension, the Decoder with the atom feature
# dimension; both share the hidden size and the layer-norm shape.
encoder = Encoder(CFG.protein_dim, CFG.hid_dim, CFG.norm_shape)
decoder = Decoder(CFG.atom_dim, CFG.hid_dim, CFG.norm_shape)

# Combine them with ModelCat, move the result to the primary device, then
# wrap it in DataParallel over every CUDA device counted in CFG.gpu_number.
model = nn.DataParallel(
    ModelCat(encoder, decoder, atom_dim=CFG.atom_dim).to(CFG.DEVICE),
    device_ids=list(range(CFG.gpu_number)),
)

In [None]:
# When W&B logging is enabled, register the model with wandb.watch;
# log='all' asks wandb to track both gradients and parameters.
if CFG.IFwandb:
    wandb.watch(model, log='all')

## Trainer

In [None]:
# Create the project Trainer around the model. It receives the whole CFG
# class for its settings; trainfirst passes on CFG.trainfirst (train from
# scratch when True, fine-tune otherwise).
trainer = Trainer(model, CFG, trainfirst=CFG.trainfirst)

## Dataset & Dataloader

### Dataset 
(You can use `get_shuffle_data` to split the dataset into a training set and a test set.)

In [None]:
# Load the example train/test tables.
traindata = pd.read_csv('./data/example_train.csv')
testdata = pd.read_csv('./data/example_test.csv')

# The 'Protein' column stores ids (to save storage space); the JSON file maps
# each id to its value. Read it inside a context manager so the file handle is
# closed right away instead of being left to the garbage collector.
with open('./data/example_id2seq.json', 'r', encoding='utf-8') as id2seq_file:
    id2seq = json.load(id2seq_file)

# Replace every id with its mapped entry; an id missing from id2seq raises
# KeyError.
traindata['Protein'] = traindata['Protein'].map(lambda x: id2seq[x])
testdata['Protein'] = testdata['Protein'].map(lambda x: id2seq[x])

if CFG.trainfirst:
    # If you want to train the model from scratch, or the pretrained word2vec
    # model does not include all the protein data in your dataset, you need to
    # train a word2vec model yourself.
    own_word2vec_path = './model/model_pretrained/word2vec_yourself.model'
    # Arguments: protein data, word2vec model path.
    train_word2vec(
        list(traindata['Protein'].unique()) + list(testdata['Protein'].unique()),
        own_word2vec_path,
    )
    # Point the rest of the notebook at the freshly trained model.
    CFG.word2vec_path = own_word2vec_path

### Dataloader

In [None]:
# Wrap the DataFrames in the project reader classes. Note that the names
# `traindata` and `testdata` are rebound here: from this cell on they refer
# to the reader objects, not to the pandas DataFrames loaded above.
# The train reader additionally receives the U_m_P save path and the PU flag.
traindata = TrainReader(data=traindata,U_m_P_savepath=CFG.U_m_P_savepath,
                                word2vec_path=CFG.word2vec_path, ifpu=CFG.ifpu)
testdata = TestReader(data=testdata, word2vec_path=CFG.word2vec_path)

In [None]:
# Validation loader over the test reader. No `shuffle` argument is passed,
# so batches follow the reader's own order; batching uses the project
# collate_fn.
val_dataloader = DataLoader(testdata, batch_size=CFG.BATCH_SIZE,collate_fn=collate_fn)

## Train

In [None]:
# Prepare the tab-separated result log that the training loop appends to.
file_model = CFG.modelsave_file_suffix
reshead = 'Epoch\tTime(sec)\tLoss_train\tAUC_dev\tPre\tRecall\tPRC_dev'
file_res = CFG.result_file_path+CFG.result_file_suffix

# open(..., 'w') does not create missing parent directories, so make sure
# the log directory exists before the header is written.
log_dir = os.path.dirname(file_res)
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

if os.path.exists(file_res):
    # Keep the existing file untouched and only warn about the name clash.
    warnings.warn("----------------duplicate filename-----------------")
else:
    with open(file_res, 'w') as f:
        f.write(reshead + '\n')

In [None]:
# Main training loop: one train pass and one evaluation pass per epoch, with
# logging, checkpointing and the per-epoch update of the training reader.
max_AUC_dev = 0
for epoch in range(CFG.EPOCHES):
    print('epoch: ' + str(epoch))
    # The loader is rebuilt every epoch; traindata is updated at the end of
    # each iteration through del_U / reset_T.
    train_dataloader = DataLoader(traindata, batch_size=CFG.BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

    # Start training.
    print('Training........................')
    torch.cuda.empty_cache()
    start = timeit.default_timer()

    loss_train, U_m_P = trainer.train(train_dataloader)
    AUC_dev, precision, recall, PRC_dev = trainer.test(val_dataloader)

    end = timeit.default_timer()
    # Named `elapsed` rather than `time` so the imported `time` module is not
    # overwritten by a float.
    elapsed = end - start

    reslist = [epoch, elapsed, loss_train, AUC_dev, precision, recall, PRC_dev]
    if CFG.IFwandb:
        wandb.log({'loss_train':loss_train, "AUC_dev":AUC_dev, 'pre':precision, 'recall':recall, 'prc':PRC_dev})
    trainer.save_AUCs(reslist, file_res)

    # Both the best-model save and the periodic save use this same path.
    checkpoint_path = f"{CFG.best_model_savepath}_epoch{epoch}_{AUC_dev}.pt"
    is_new_best = AUC_dev > max_AUC_dev
    if is_new_best:
        trainer.save_model(model, checkpoint_path)
        max_AUC_dev = AUC_dev
    print('\t'.join(map(str, reslist)))

    # NOTE(review): the threshold is the literal 0.85, not CFG.del_threshold
    # (0.9) -- kept as is; confirm which value is intended.
    if AUC_dev > 0.85: #changed
        print(f"del {epoch}:{len(U_m_P)}")
        traindata.del_U(epoch, U_m_P)
    else:
        traindata.reset_T()

    # Periodic checkpoint every 5 epochs. Skipped when the best-model branch
    # already wrote this exact path during the current epoch.
    if epoch % 5 == 0 and not is_new_best:
        trainer.save_model(model, checkpoint_path)
    torch.cuda.empty_cache()