**For Fine-tuning**
===
> -*- coding: utf-8 -*-

> Author: xinghd


## Import libraries

In [None]:
from utils.utils import Trainer, Predictor, TrainReader,TestReader,PredictReader, data_collate, get_shuffle_data, collate_fn, seed_everything
from model.model import Encoder, Decoder, ModelCat

%reload_ext autoreload
%autoreload 2

In [None]:
import os
import json
import random
import time
import timeit
import pickle
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Descriptors
from torch import nn
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
import torch
from torch.utils.data import DataLoader, random_split
from utils.word2vec import seq_to_kmers, get_protein_embedding
from gensim.models import Word2Vec
from torch.utils import data as torch_data
import warnings

### Parameters

In [None]:
# Fail fast when no CUDA device is visible: the configuration in this cell
# pins DEVICE to 'cuda:0'. An explicit raise is used instead of `assert`
# so the check is not stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("Must have available gpu")

class CFG:
    """Static settings for the fine-tuning run, read as class attributes."""

    # ---- input data -------------------------------------------------
    # Root folder of the data files.
    DATA_ROOT = "./data/"
    # CSV files holding the training and test sets.
    traindata_path = "./data/traindata_finetuning_example.csv"
    testdata_path = "./data/testdata_finetuning_example.csv"
    # Where unlabeled samples that may be positive are saved.
    U_m_P_savepath = "./data/U_m_P/"

    # ---- pretrained artefacts ---------------------------------------
    # Pretrained word2vec model.
    word2vec_path = "./model/model_pretrained/word2vec_pretrained.model"
    # Pretrained network weights; None is meant to train from scratch.
    state_dict_path = "./model/model_pretrained/PU-EPP_pretrained.pt"

    # ---- outputs ----------------------------------------------------
    # Filename suffixes for saved models and for the results log.
    modelsave_file_suffix = "funetuning_epoch.pt"
    result_file_suffix = "funetuning_log.txt"
    # Folder in which the results log is created.
    result_file_path = "./results/log/funetuning_log/"
    # Destination of the best model (selected by AUC).
    best_model_savepath = "./model/model_funetuning/"
    # Toggle Weights & Biases logging (https://wandb.ai/).
    IFwandb = False

    # ---- hardware ---------------------------------------------------
    # Number of CUDA devices visible to this process.
    gpu_number = torch.cuda.device_count()
    # Primary device.
    DEVICE = torch.device("cuda:0")

    # ---- optimisation -----------------------------------------------
    # Number of epochs and mini-batch size.
    EPOCHES = 100
    BATCH_SIZE = 30
    # Learning rate and weight decay.
    lr = 1e-4
    weight_decay = 1e-4

    # ---- architecture -----------------------------------------------
    # Number of block layers.
    layer_num = 12
    # Size of the last dimension of the protein data.
    protein_dim = 100
    # Size of the last dimension of the compound data.
    atom_dim = 46
    # Hidden dimension.
    hid_dim = 128
    # LayerNorm shape parameter.
    norm_shape = 128

    # ---- train-reader options ---------------------------------------
    # Enable PU learning.
    ifpu = True
    # Enable label smoothing.
    ifsmoothing = True
    # Data deletion starts once the AUC exceeds this value.
    del_threshold = 0.9
    # NOTE(review): undocumented in the original; consumer not visible here.
    quantile = 0.9


## WandB

In [None]:
# Optional Weights & Biases tracking, switched on through CFG.IFwandb.
# wandb is imported inside the branch, so the package is only required
# when tracking is actually enabled.
if CFG.IFwandb:
    import wandb

    wandb.login()
    wandb.init(name='finetuning', project='PU-EPP')

## Parameters

### Random Seed

In [None]:
# Seed the run with a fixed value through the project helper imported from
# utils.utils.
# NOTE(review): which random generators the helper seeds is not visible here.
seed_everything(seed=42)

## Model

In [None]:
# Build the two sub-networks from the dimensions declared in CFG.
encoder = Encoder(CFG.protein_dim, CFG.hid_dim, CFG.norm_shape)
decoder = Decoder(CFG.atom_dim, CFG.hid_dim, CFG.norm_shape)

# Combine them, move the result to the primary device, then wrap it in
# DataParallel over every CUDA device counted in CFG.gpu_number.
model = nn.DataParallel(
    ModelCat(encoder, decoder, atom_dim=CFG.atom_dim).to(CFG.DEVICE),
    device_ids=list(range(CFG.gpu_number)),
)

In [None]:
# Warm-start from the pretrained checkpoint. CFG documents that
# state_dict_path = None means "train from scratch", so loading is skipped
# in that case instead of failing inside torch.load(None).
# The weights are loaded into the DataParallel wrapper.
# NOTE(review): the checkpoint keys presumably carry the "module." prefix
# that DataParallel adds — confirm against the saved file.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
if CFG.state_dict_path is not None:
    model.load_state_dict(torch.load(CFG.state_dict_path, map_location=CFG.DEVICE))

In [None]:
# When wandb tracking is enabled, register the model with wandb.watch
# (called with log='all').
if CFG.IFwandb:
    wandb.watch(model, log='all')

## Trainer

In [None]:
# Create the trainer from the model and the CFG settings. The later cells
# use it for train(), test(), save_AUCs() and save_model().
# NOTE(review): the original comment pointed at /utils/builder.py, but
# Trainer is imported from utils.utils in this notebook — confirm.
trainer = Trainer(model, CFG)  # trainfirst=False

## Dataset & Dataloader

### Dataset 
(You can use `get_shuffle_data` to split a dataset into a training set and a test set.)

In [None]:
# Read the training and test tables from the CSV paths declared in CFG.
trainset, testset = (
    pd.read_csv(csv_path)
    for csv_path in (CFG.traindata_path, CFG.testdata_path)
)



### Dataloader

In [None]:
# Wrap the raw tables in the project dataset readers. Both readers receive
# the word2vec model path; the training reader additionally receives the
# U_m_P save path and the PU-learning switch.
traindata = TrainReader(
    data=trainset,
    word2vec_path=CFG.word2vec_path,
    U_m_P_savepath=CFG.U_m_P_savepath,
    ifpu=CFG.ifpu,
)
testdata = TestReader(word2vec_path=CFG.word2vec_path, data=testset)

# If our pretrained word2vec model does not include all the protein data in
# your dataset, you need to train a word2vec model yourself, e.g.:
# NOTE(review): train_word2vec is not imported in this notebook, and the
# path override presumably has to run before the readers above are built.

# train_word2vec(list(traindata['Protein'].unique()) + list(testdata['Protein'].unique()), './model/model_pretrained/word2vec_yourself.model') 

# CFG.word2vec_path = './model/model_pretrained/word2vec_yourself.model' 



In [None]:
# Validation batches are drawn from the test reader in dataset order
# (shuffle=False is the DataLoader default, spelled out here).
val_dataloader = DataLoader(
    testdata,
    batch_size=CFG.BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

## Train

In [None]:
# Prepare the tab-separated results log that the training loop appends to.
file_model = CFG.modelsave_file_suffix  # NOTE(review): not used in the cells shown here
reshead = 'Epoch\tTime(sec)\tLoss_train\tAUC_dev\tPre\tRecall\tPRC_dev'
file_res = CFG.result_file_path+CFG.result_file_suffix

# open(..., 'w') does not create missing parent folders, so make sure the
# log directory exists before the header is written.
log_dir = os.path.dirname(file_res)
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

if os.path.exists(file_res):
    # An earlier log with the same name is kept untouched; only warn.
    warnings.warn("----------------file name duplicate-----------------")
else:
    with open(file_res, 'w') as f:
        f.write(reshead + '\n')

In [None]:
# Fine-tuning loop. Each epoch trains on the current training set, evaluates
# on the validation loader, appends the metrics to the results log, saves
# checkpoints, and then updates the training set for the next epoch.

# Make sure the checkpoint folder exists before the first save.
save_dir = os.path.dirname(CFG.best_model_savepath)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

max_AUC_dev = 0
for epoch in range(CFG.EPOCHES):
    print('epoch: ' + str(epoch))
    # Rebuilt every epoch because traindata is modified at the end of each
    # epoch (del_U / reset_T below).
    train_dataloader = DataLoader(traindata, batch_size=CFG.BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

    # Start training.
    print('Training........................')
    torch.cuda.empty_cache()
    start = timeit.default_timer()

    # NOTE(review): U_m_P is presumably the set of unlabeled samples that
    # may be positive (see CFG.U_m_P_savepath) — confirm in Trainer.train.
    loss_train, U_m_P = trainer.train(train_dataloader)
    AUC_dev, precision, recall, PRC_dev = trainer.test(val_dataloader)

    # Named `elapsed` rather than `time` so the imported `time` module is
    # not shadowed for the rest of the notebook.
    elapsed = timeit.default_timer() - start

    reslist = [epoch, elapsed, loss_train, AUC_dev, precision, recall, PRC_dev]
    if CFG.IFwandb:
        wandb.log({'loss_train':loss_train, "AUC_dev":AUC_dev, 'pre':precision, 'recall':recall, 'prc':PRC_dev})
    trainer.save_AUCs(reslist, file_res)

    checkpoint_path = f"{CFG.best_model_savepath}_epoch{epoch}_{AUC_dev}.pt"
    is_best = AUC_dev > max_AUC_dev
    if is_best:
        trainer.save_model(model, checkpoint_path)
        max_AUC_dev = AUC_dev
    print('\t'.join(map(str, reslist)))

    # NOTE(review): this threshold is hard-coded to 0.85 while
    # CFG.del_threshold is 0.9 — confirm which value is intended.
    if AUC_dev > 0.85: #changed
        print(f"del {epoch}:{len(U_m_P)}")
        traindata.del_U(epoch, U_m_P)
    else:
        traindata.reset_T()

    # Periodic checkpoint every 5 epochs; skipped when the same file was
    # already written above as the new best model.
    if epoch % 5 == 0 and not is_best:
        trainer.save_model(model, checkpoint_path)
    torch.cuda.empty_cache()