In [1]:
# gptname = 'gpt2' #small
gptname = 'gpt2-xl'

colab = False
enablenews = True
external_save = True
enable_trash = True

# large
# lrmult = 0.5
# xl
lrmult = 0.1
accumn = 10
devicename = 'cpu'

if not colab:
    filepath = '.'
else:
    filepath = '/content/drive/MyDrive/bio'
    ! pip3 install transformers
    from google.colab import drive
    drive.mount('/content/drive')
 
import sys
sys.path.append(filepath + '/utils')
sys.path.append(filepath + '/dataset')

In [2]:
import argparse, json, os, pickle, random, time, numpy as np, transformers, torch

from transformers import GPT2Config, GPT2LMHeadModel,AdamW, GPT2Tokenizer
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm.notebook import tnrange, tqdm

from utils import SummarizationDataset, FineTuningDataset, get_tokenizer, generate_sample, sample_seq, set_seed, top_k_top_p_filtering

In [3]:
parser = argparse.ArgumentParser()
parser.add_argument("--lr",default=5e-5, type=float, help="learning rate")
parser.add_argument("--seed",default=42, type=int,  help="seed to replicate results")
parser.add_argument("--n_gpu",default=1, type=int,  help="no of gpu available")
parser.add_argument("--gradient_accumulation_steps",default=10, type=int, help="gradient_accumulation_steps")
parser.add_argument("--batch_size",default=1, type=int,  help="batch_size")
parser.add_argument("--num_workers",default=2, type=int,  help="num of cpus available")
parser.add_argument("--device",default=torch.device(devicename), help="torch.device object")
parser.add_argument("--num_train_epochs",default=1, type=int,  help="no of epochs of training")
if external_save:
    parser.add_argument("--model_dir", default='D:/external_weights', type=str,  help="path to save trained model")
else:
    parser.add_argument("--model_dir",default= filepath + '/weights', type=str,  help="path to save trained model")
parser.add_argument("--max_grad_norm",default=1.0, type=float, help="max gradient norm.")
parser.add_argument("--root_dir",default= filepath + '/bignews/gpt2_1024_data', type=str, help="location of json dataset.")
parser.add_argument("--ids_file",default= filepath + '/bignews/ids.json', type=str, help="location of train, valid and test file indexes")
args = parser.parse_args([])


if enablenews:
    strnews = 'ye'
else:
    strnews = 'no'
if enable_trash:
    str_trash = 'ye'
else:
    str_trash = 'no'

model_text = '_bio_{}_lr{}_accum{}_{}news_{}trash'.format(gptname, int(lrmult*100), accumn, strnews, str_trash)
model_file = os.path.join(args.model_dir, 'model' + model_text + '.bin')
config_file = os.path.join(args.model_dir, 'config' + model_text + '.json')
log_file = os.path.join(args.model_dir, 'log' + model_text + '.txt')

In [4]:
def evaluate(args, model, eval_dataset):
    sumloss = 0.0
    with torch.no_grad():
        model.eval()
        eval_dl = DataLoader(
            eval_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        loss_fct = CrossEntropyLoss()
        
        for batch in eval_dl:
            inputs, labels = batch['article'].to(args.device), batch['article'].to(args.device)
            logits = model(inputs)[0]
            shift_logits = logits[..., batch['sum_idx']:-1, :].contiguous()
            shift_labels = labels[..., batch['sum_idx']+1:].contiguous()
            sumloss += loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).item()

    return sumloss / len(eval_dataset)

In [5]:
def finetune(args, model, tokenizer, finetune_dataset, eval_dataset, model_file, config_file, log_file):
    loss_fct = CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=1*args.lr)
    scheduler = transformers.get_linear_schedule_with_warmup(
      optimizer, 0,
      2*len(finetune_dataset)*args.num_train_epochs//args.gradient_accumulation_steps)
    
    sumloss = 0.0
    log_text = ''

    for epoch in range(args.num_train_epochs):
      train_sampler = RandomSampler(finetune_dataset)
      train_dl = DataLoader(
        finetune_dataset,sampler=train_sampler,
        batch_size=args.batch_size, num_workers=args.num_workers)
      model.zero_grad()
      set_seed(args)
      epoch_iterator = tqdm(train_dl, desc="Training")
      for step, batch in enumerate(epoch_iterator):
        if step % 50 == 0:
          if step % args.gradient_accumulation_steps != 0:
            print("Gradient loss!!!")
          
          log_add = 'Step: {}; validation loss: {}'.format(
            epoch*len(finetune_dataset) + step,
            evaluate(args, model, eval_dataset))
          log_text += log_add + '\n'
          print(log_add)
          my_file = open(log_file, 'w')
          my_file.write(log_text)
          my_file.close()

        inputs, labels = batch['article'].to(args.device), batch['article'].to(args.device)

        # print(inputs.shape)
        # print(tokenizer.decode(list(inputs[0])))

        model.train()
        logits = model(inputs)[0]
        shift_logits = logits[..., batch['sum_idx']:-1, :].contiguous()
        shift_labels = labels[..., batch['sum_idx']+1:].contiguous()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        sumloss += float(loss.item())
        loss = loss/args.gradient_accumulation_steps
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        if (step + 1) % args.gradient_accumulation_steps == 0:
          optimizer.step()
          scheduler.step()  # Update learning rate schedule
          model.zero_grad()
              
        if (step + 1) % 10 == 0:
          log_add = 'Step: {}; loss: {}'.format(
            epoch*len(finetune_dataset)+step+1, sumloss/10)
          sumloss = 0.0
          log_text += log_add + '\n'
          print(log_add)
          my_file = open(log_file, 'w')
          my_file.write(log_text)
          my_file.close()
    return log_text

In [6]:
model_file_load = os.path.join(args.model_dir,  'model_data_summarization_{}_lr{}_accum{}.bin'.format(
          gptname, int(lrmult*100), accumn))
config_file_load = os.path.join(args.model_dir, 'config_data_summarization_{}_lr{}_accum{}.json'.format(
          gptname, int(lrmult*100), accumn))

config = GPT2Config.from_json_file(config_file_load)
tokenizer = get_tokenizer(gptname)
model = GPT2LMHeadModel(config)
state_dict = torch.load(model_file_load, map_location=args.device)
_ = model.load_state_dict(state_dict)
_ = model.to(args.device)
print(len(FineTuningDataset(filepath + '/dataset', tokenizer, mode='train', addnews=enablenews)))
finetune_data = FineTuningDataset(filepath + '/dataset', tokenizer, mode='train', addnews=enablenews, length=200)
if enable_trash:
    trash_data = FineTuningDataset(filepath + '/dataset', tokenizer, mode='train', addnews=enablenews, length=320, trash=True)
eval_data = FineTuningDataset(filepath + '/dataset', tokenizer, mode='valid', addnews=enablenews, length=10)

210


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Sergei\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [7]:
start = time.time()
if enable_trash:
    finetune(
        args, model, tokenizer, trash_data, eval_data,
        model_file, config_file, log_file)
log_text = finetune(
    args, model, tokenizer, finetune_data, eval_data,
    model_file, config_file, log_file)
log_add = 'Step: {}; validation loss: {}'.format(
    args.num_train_epochs*len(finetune_data),
    evaluate(args, model, eval_data))
log_text += log_add + '\n'
print(log_add)
time_text = 'Total time: {} minutes.'.format((time.time()-start)/60)
log_text += time_text + '\n'
print(time_text)
my_file = open(log_file, 'w')
my_file.write(log_text)
my_file.close()

print('Saving trained model...')
torch.save(model.state_dict(), model_file)
model.config.to_json_file(config_file)
print('Saved.')

Training:   0%|          | 0/320 [00:00<?, ?it/s]

Step: 0; validation loss: 2.017832398414612
Step: 10; loss: 2.744695484638214
Step: 20; loss: 2.39229496717453
Step: 30; loss: 2.0948919653892517
Step: 40; loss: 2.0456962108612062
Step: 50; loss: 1.905524206161499
Step: 50; validation loss: 1.4457423031330108
Step: 60; loss: 2.0527047753334045
Step: 70; loss: 1.7390436947345733
Step: 80; loss: 2.0440563321113587
Step: 90; loss: 1.7129560589790345
Step: 100; loss: 1.9657523393630982
Step: 100; validation loss: 1.4232341051101685
Step: 110; loss: 1.9156450986862184
Step: 120; loss: 1.4878210186958314
Step: 130; loss: 1.8335509598255157
Step: 140; loss: 2.1168410062789915
Step: 150; loss: 1.4607580363750459
Step: 150; validation loss: 1.4426957011222838
Step: 160; loss: 1.9130440831184388
Step: 170; loss: 1.784747564792633
Step: 180; loss: 1.7393852949142456
Step: 190; loss: 1.7665969729423523
Step: 200; loss: 1.5732300579547882
Step: 200; validation loss: 1.4460708558559419
Step: 210; loss: 1.827338981628418
Step: 220; loss: 1.599316740

Training:   0%|          | 0/200 [00:00<?, ?it/s]

Step: 0; validation loss: 1.383796513080597
Step: 10; loss: 1.6835374236106873
Step: 20; loss: 1.554909771680832
Step: 30; loss: 1.7516204297542572
Step: 40; loss: 1.6227215230464935
Step: 50; loss: 1.1976910412311554
Step: 50; validation loss: 1.384033751487732
Step: 60; loss: 1.7024429589509964
Step: 70; loss: 1.4106873512268066
Step: 80; loss: 1.4218810975551606
Step: 90; loss: 1.414726510643959
Step: 100; loss: 1.6808338314294815
Step: 100; validation loss: 1.381706714630127
Step: 110; loss: 1.4146191239356996
Step: 120; loss: 1.7321368098258971
Step: 130; loss: 1.5767818987369537
Step: 140; loss: 1.6718943536281585
Step: 150; loss: 1.4519405961036682
Step: 150; validation loss: 1.3609151184558868
Step: 160; loss: 1.6389724373817445
Step: 170; loss: 1.769489012658596
Step: 180; loss: 1.4944372475147247
Step: 190; loss: 1.5277702450752257
Step: 200; loss: 1.7421887338161468
Step: 200; validation loss: 1.3520565748214721
Total time: 2037.226572517554 minutes.
Saving trained model...
