In [1]:
import time
import numpy as np
import torchmetrics

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as Data
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Variable

import transformers
from transformers import GPT2Tokenizer
from transformers import GPT2LMHeadModel
 

In [2]:
# choose train mode
global mode
# mode = 'pretrain'
mode = 'finetune'
# mode = 'else'


In [3]:
# device and tokenizer

device = torch.device('cuda:1')


tokenizer = GPT2Tokenizer('./vocab_file/vocab.json', './vocab_file/merges.txt')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
if mode == 'pretrain':
    tokenizer.save_pretrained('./save_model/pretrain_model/pretrained-gpt-10-64raw-50epochs') # when pretrain model
print(tokenizer)

PreTrainedTokenizer(name_or_path='', vocab_size=23, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': '[PAD]'})


In [4]:
#定义数据集
class Dataset(torch.utils.data.Dataset):

    def __init__(self):
        if mode == 'finetune':
#             with open('./data/ADP3_amp.txt') as f:  # when finetune
            with open('./data/ADP3_amp.txt') as f:  # when finetune
                lines = f.readlines()
        elif mode == 'pretrain':
            with open('./data/pretrain_data/uniprot10-63.txt') as f:  # when pretrain model 
                lines = f.readlines()
        else:
            print('train mode error')
            with open('./data/ADP3_amp.txt') as f:
                lines = f.readlines()
        lines = [i.strip() for i in lines]

        self.lines = lines

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, i):
        return self.lines[i]


global val_split
if mode == 'pretrain':
    val_split = 0.01
elif mode =='finetune':
    val_split = 0.1
else:
    val_split = 0.4

shuffle_dataset = True
random_seed = 42

dataset = Dataset()
dataset_size = len(dataset)

indices = list(range(dataset_size))
split = int(np.floor(val_split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = Data.SubsetRandomSampler(train_indices)
val_sampler = Data.SubsetRandomSampler(val_indices)

def collate_fn(data):
    data = tokenizer.batch_encode_plus(data,
                                       padding=True,
                                       truncation=True,
                                       max_length=64,
                                       return_tensors='pt')

    data['labels'] = data['input_ids'].clone()

    return data



train_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=64, 
    sampler=train_sampler,
    collate_fn=collate_fn,
    drop_last=True,)


val_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=64, 
    sampler=val_sampler,
    collate_fn=collate_fn,
    drop_last=True,)

# for i, data in enumerate(val_loader):
    

#     for k, v in data.items():
#         print(k, v.shape, v)

len(train_loader)

89

In [5]:
# define GPT model

from transformers import GPT2Model, GPT2Config

# Initializing a GPT2 configuration
configuration = GPT2Config(n_layer=12, 
                           n_head=12,
                           n_embd=768)

# print(configuration)

# Initializing a model from the configuration
if mode == 'pretrain':
    model = GPT2LMHeadModel(configuration)  # when pretrain model
elif mode == 'finetune':
    model = GPT2LMHeadModel.from_pretrained('./save_model/pretrain_model/pretrained-GPT-10-64washed-30epochs')
else:
    pass
#     model = GPT2LMHeadModel.from_pretrained('./save_model/pretrain_model/pretrained-GPT-10-64raw-20epochs/')
# model = torch.load('./save_model/pretrained-gpt-10-48-30epochs/pytorch_model.bin')  # pretrain model use


In [6]:
from torch.optim import AdamW
from transformers.optimization import get_scheduler
from torchmetrics import Perplexity
from torchmetrics import Accuracy

accuracy = Accuracy(task="multiclass",num_classes=23,ignore_index=23)
accuracy = accuracy.to(device)

perplexity = Perplexity(ignore_index=23).to(device)

epochs = 100

train_loss_list = []
val_loss_list = []
train_acc_list =[]
val_acc_list = []
train_pp_list = []
val_pp_list = []
    
#训练
def train():
    global model
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    
    optimizer = AdamW(model.parameters(), lr=1e-6)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=400,
                              num_training_steps=len(train_loader)*epochs,
                              optimizer=optimizer)

    model.train()
    print('开始训练')
    start_time = time.time()

    for epoch in range(epochs):
        train_loss = []
        train_accuracy = []
        train_pp = []
        val_loss = []
        val_accuracy = []
        val_pp = []
        
        for batch_idx, batch_data in enumerate(train_loader):
            batch_data = batch_data.to(device)
            out = model(**batch_data)
#             print(out.keys())
            loss = out['loss']
            train_loss.append(loss.item())
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            optimizer.zero_grad()
            model.zero_grad()
            
            labels = batch_data['labels'][:, 1:]
            outs = out['logits'].argmax(dim=2)[:, :-1]
#             print(out)
#             print(labels)
            
            perplexity_train = perplexity(out['logits'][:,:-1,:], labels)
            train_acc = accuracy(outs, labels)
            
            train_accuracy.append(train_acc.tolist())
            train_pp.append(perplexity_train.tolist())
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            
            if batch_idx % 50 == 0:
                print('train_batch: {:3d}  loss:{:.4f}  accuracy:{:.4f}  perplexity:{:.4f}'
                      .format(batch_idx, loss.item(), train_acc.item(), perplexity_train.item()))
            
        for batch_idx, batch_data in enumerate(val_loader):
            batch_data = batch_data.to(device)
            out = model(**batch_data)
            loss = out['loss']
            labels = batch_data['labels'][:, 1:]
            outs = out['logits'].argmax(dim=2)[:, :-1]
            val_acc = accuracy(outs, labels)
            val_loss.append(loss.item())
            val_accuracy.append(val_acc.tolist())
            
            perplexity_test = perplexity(out['logits'][:,:-1,:], labels)
            val_pp.append(perplexity_test.tolist())
#             total_perplexity_test += perplexity_test * 64 # len(inputs)
#         perplexity_test = total_perplexity_test / len(batch_data)
#         print('Epoch: %d, Test Perplexity: %.4f' % (epoch+1, perplexity_test))
        
            if batch_idx % 50 == 0:
                print('val_batch:   {:3d}  loss:{:.4f}  accuracy:{:.4f}  perplexity:{:.4f}'
                      .format(batch_idx, loss.item(), val_acc.item(), perplexity_test.item()))
            
        train_loss_list.append(np.mean(train_loss))
        train_acc_list.append(np.mean(train_accuracy))
        train_pp_list.append(np.mean(train_pp))
        val_loss_list.append(np.mean(val_loss))
        val_acc_list.append(np.mean(val_accuracy))
        val_pp_list.append(np.mean(val_pp))
        train_time = time.time()
        print('第{}代训练完成,历时{}秒'.format(epoch+1,train_time-start_time))
        print('epoch {} mean training loss:{:.4f}'.format(epoch+1, np.mean(train_loss)))
        print('epoch {} mean training accuracy:{:.4f}'.format(epoch+1, np.mean(train_accuracy)))
        print('epoch {} mean training perplexity:{:.4f}'.format(epoch+1, np.mean(train_pp)))
        print('epoch {} mean val loss:{:.4f}'.format(epoch+1, np.mean(val_loss)))
        print('epoch {} mean val accuracy:{:.4f} '.format(epoch+1, np.mean(val_accuracy)))
        print('epoch {} mean val perplexity:{:.4f} '.format(epoch+1, np.mean(val_pp)))
        print(' ')
        
    
    end_time = time.time()
    print('训练结束,训练时长：',end_time-start_time, '秒')   
    
    

In [7]:
# 训练模型

train()


开始训练
train_batch:   0  loss:1.4211  accuracy:0.1414  perplexity:16.4079
train_batch:  50  loss:1.3451  accuracy:0.1506  perplexity:16.0595
val_batch:     0  loss:1.5209  accuracy:0.1799  perplexity:14.9772
第1代训练完成,历时23.507112979888916秒
epoch 1 mean training loss:1.4416
epoch 1 mean training accuracy:0.1502
epoch 1 mean training perplexity:16.1885
epoch 1 mean val loss:1.4703
epoch 1 mean val accuracy:0.1629 
epoch 1 mean val perplexity:15.5706 
 
train_batch:   0  loss:1.3644  accuracy:0.1555  perplexity:15.4248
train_batch:  50  loss:1.3489  accuracy:0.1950  perplexity:13.9398
val_batch:     0  loss:1.3611  accuracy:0.1779  perplexity:14.4641
第2代训练完成,历时46.66934895515442秒
epoch 2 mean training loss:1.3717
epoch 2 mean training accuracy:0.1689
epoch 2 mean training perplexity:14.9685
epoch 2 mean val loss:1.4001
epoch 2 mean val accuracy:0.1867 
epoch 2 mean val perplexity:14.1032 
 
train_batch:   0  loss:1.4946  accuracy:0.1905  perplexity:13.8138
train_batch:  50  loss:1.2920  accura

train_batch:   0  loss:1.0542  accuracy:0.3235  perplexity:9.1769
train_batch:  50  loss:1.0117  accuracy:0.3649  perplexity:8.1135
val_batch:     0  loss:1.1677  accuracy:0.2898  perplexity:10.1705
第20代训练完成,历时469.920743227005秒
epoch 20 mean training loss:1.0221
epoch 20 mean training accuracy:0.3755
epoch 20 mean training perplexity:7.8159
epoch 20 mean val loss:1.1571
epoch 20 mean val accuracy:0.3278 
epoch 20 mean val perplexity:9.2370 
 
train_batch:   0  loss:0.9048  accuracy:0.4454  perplexity:6.3453
train_batch:  50  loss:1.0944  accuracy:0.3843  perplexity:7.5128
val_batch:     0  loss:1.1457  accuracy:0.3366  perplexity:8.9830
第21代训练完成,历时493.42784786224365秒
epoch 21 mean training loss:1.0158
epoch 21 mean training accuracy:0.3797
epoch 21 mean training perplexity:7.6974
epoch 21 mean val loss:1.1563
epoch 21 mean val accuracy:0.3314 
epoch 21 mean val perplexity:9.0814 
 
train_batch:   0  loss:1.0837  accuracy:0.3245  perplexity:8.9451
train_batch:  50  loss:1.1997  accuracy

val_batch:     0  loss:1.1369  accuracy:0.3528  perplexity:8.6844
第47代训练完成,历时1104.7645363807678秒
epoch 47 mean training loss:0.8785
epoch 47 mean training accuracy:0.4617
epoch 47 mean training perplexity:5.8810
epoch 47 mean val loss:1.1160
epoch 47 mean val accuracy:0.3703 
epoch 47 mean val perplexity:8.4300 
 
train_batch:   0  loss:0.9822  accuracy:0.4403  perplexity:6.1965
train_batch:  50  loss:0.9438  accuracy:0.4361  perplexity:6.3305
val_batch:     0  loss:1.0498  accuracy:0.3892  perplexity:7.8886
第48代训练完成,历时1128.2565503120422秒
epoch 48 mean training loss:0.8753
epoch 48 mean training accuracy:0.4637
epoch 48 mean training perplexity:5.8434
epoch 48 mean val loss:1.1327
epoch 48 mean val accuracy:0.3670 
epoch 48 mean val perplexity:8.4554 
 
train_batch:   0  loss:0.9128  accuracy:0.4578  perplexity:6.1568
train_batch:  50  loss:0.7667  accuracy:0.5339  perplexity:4.7162
val_batch:     0  loss:1.1130  accuracy:0.3508  perplexity:8.7464
第49代训练完成,历时1151.7446439266205秒
epoch 4

train_batch:   0  loss:0.8394  accuracy:0.5025  perplexity:5.2963
train_batch:  50  loss:0.8293  accuracy:0.4967  perplexity:5.3396
第66代训练完成,历时1551.1488945484161秒
epoch 66 mean training loss:0.8243
epoch 66 mean training accuracy:0.4938
epoch 66 mean training perplexity:5.2845
epoch 66 mean val loss:1.1232
epoch 66 mean val accuracy:0.3704 
epoch 66 mean val perplexity:8.4698 
 
train_batch:   0  loss:0.9469  accuracy:0.4469  perplexity:6.0518
train_batch:  50  loss:0.8694  accuracy:0.4613  perplexity:5.8648
val_batch:     0  loss:1.2907  accuracy:0.3510  perplexity:9.5989
第67代训练完成,历时1574.6456258296967秒
epoch 67 mean training loss:0.8238
epoch 67 mean training accuracy:0.4928
epoch 67 mean training perplexity:5.2689
epoch 67 mean val loss:1.1111
epoch 67 mean val accuracy:0.3779 
epoch 67 mean val perplexity:8.3641 
 
train_batch:   0  loss:0.8742  accuracy:0.4690  perplexity:5.7445
train_batch:  50  loss:0.6869  accuracy:0.5280  perplexity:4.6488
val_batch:     0  loss:1.0619  accurac

train_batch:   0  loss:0.8217  accuracy:0.5040  perplexity:5.0919
train_batch:  50  loss:0.8185  accuracy:0.4987  perplexity:5.1080
val_batch:     0  loss:1.2195  accuracy:0.3356  perplexity:9.8599
第86代训练完成,历时2021.2709577083588秒
epoch 86 mean training loss:0.7954
epoch 86 mean training accuracy:0.5100
epoch 86 mean training perplexity:4.9948
epoch 86 mean val loss:1.1181
epoch 86 mean val accuracy:0.3795 
epoch 86 mean val perplexity:8.5176 
 
train_batch:   0  loss:0.9119  accuracy:0.4537  perplexity:5.9325
train_batch:  50  loss:0.7770  accuracy:0.5145  perplexity:4.9067
val_batch:     0  loss:1.1981  accuracy:0.3568  perplexity:9.1928
第87代训练完成,历时2044.7435276508331秒
epoch 87 mean training loss:0.7953
epoch 87 mean training accuracy:0.5108
epoch 87 mean training perplexity:4.9786
epoch 87 mean val loss:1.1252
epoch 87 mean val accuracy:0.3770 
epoch 87 mean val perplexity:8.6112 
 
train_batch:   0  loss:0.9115  accuracy:0.4629  perplexity:5.8637
train_batch:  50  loss:0.7547  accurac

In [8]:
# # 保存模型

# model = model.to('cpu')

# if mode == 'pretrain':
#     model.save_pretrained('./save_model/pretrain_model/pretrained-GPT-10-64raw-20epochs/')  # when pretrain model
# elif mode == 'finetune':
#     torch.save(model, './save_model/finetune-model/finetune_with_AMPbert_data/finetune-10-48-GPT-'+str(epochs)+'epochs')  # finetune model
# else:
#     pass
    

In [None]:
# 困惑度趋势图

import matplotlib.pyplot as plt

save_dir = '/home/xms/AMP-master/generate_file/train_acc_loss/'

x1 = [(x+1) for x in range(len(train_pp_list))]
x2 = [(x+1) for x in range(len(val_pp_list))]
y1 = train_pp_list
y2 = val_pp_list

plt.plot(x1, y1, label="AMP training perplexity")
plt.plot(x1, y2, label="AMP val_perplexity")
plt.xlabel('step')
plt.ylabel('perplexity')
plt.title('AMP train perplexity show')
plt.legend()
plt.savefig(save_dir+str(epochs)+'_AMP train perplexity show')
plt.show()

In [None]:
# # 结果作图

# import matplotlib.pyplot as plt

# x1 = [(x+1) for x in range(len(train_loss_list))]
# x2 = [(x+1) for x in range(len(val_acc_list))]
# y1 = train_loss_list
# y2 = train_acc_list
# y3 = val_loss_list
# y4 = val_acc_list

# plt.plot(x1, y1, label="AMP training loss")
# plt.plot(x1, y3, label="AMP val_loss")
# plt.xlabel('step')
# plt.ylabel('loss')
# plt.title('AMP train losses show')
# plt.legend()
# if mode == 'pretrain':
#     plt.savefig('/xms/AMP-master/generate_file/train_graphical_result-2023/pretrain_result/'+
#                 mode+'10-48washed-'+str(epochs)+'epochs_loss.jpg')
# elif mode == 'finetune':
#     plt.savefig('/xms/AMP-master/generate_file/train_graphical_result-2023/finetune_result/'+
#                 mode+'10-48washed-'+str(epochs)+'epochs_loss.jpg')
# else:
#     pass
# plt.show()


# plt.plot(x2, y2, label="AMP train_acc curse")
# plt.plot(x2, y4, label="AMP val_acc curse")
# plt.xlabel('step')
# plt.ylabel('acc')
# plt.title('AMP val_acc show')
# plt.legend()
# if mode == 'pretrain':
#     plt.savefig('/xms/AMP-master/generate_file/train_graphical_result-2023/pretrain_result/'+
#                 mode+'10-48washed-'+str(epochs)+'epochs_accuracy.jpg')
# elif mode == 'finetune':
#     plt.savefig('/xms/AMP-master/generate_file/train_graphical_result-2023/finetune_result/'+
#                 mode+'10-48washed-'+str(epochs)+'epochs_accuracy.jpg')
# else:
#     pass

# plt.show()

In [None]:
import torch
from torchmetrics import Perplexity
preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
print(preds)
target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
print(target)
target[0, 6:] = -100
perp = Perplexity(ignore_index=-100)
perp(preds, target)


In [None]:
from torchmetrics import Accuracy

accuracy = Accuracy(task="multiclass",num_classes=23,ignore_index=23)

a = torch.tensor([[1,2],[12,23]])
b = torch.tensor([[2,2],[12,23]])
print(accuracy(a,b))