In [1]:
# gptname = 'gpt2' #small
gptname = 'gpt2-xl'
external_save = True
colab = False

# best for gpt2 large
# lrmult = 0.5

# best for gpt2 xl
lrmult = 0.1

accumn = 10
devicename = 'cpu'

if not colab:
    filepath = '.'
else:
    filepath = '/content/drive/MyDrive/bio'
    ! pip3 install transformers
    from google.colab import drive
    drive.mount('/content/drive')
 
import sys
sys.path.append(filepath + '/utils')

In [2]:
import argparse, json, os, pickle, random, time, numpy as np, transformers, torch

from transformers import GPT2Config, GPT2LMHeadModel,AdamW, GPT2Tokenizer
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm.notebook import tqdm

from utils import SummarizationDataset, FineTuningDataset, get_tokenizer, generate_sample, sample_seq, set_seed, top_k_top_p_filtering

In [3]:
parser = argparse.ArgumentParser()
parser.add_argument("--lr",default=5e-5, type=float, help="learning rate")
parser.add_argument("--seed",default=42, type=int,  help="seed to replicate results")
parser.add_argument("--n_gpu",default=1, type=int,  help="no of gpu available")
parser.add_argument("--gradient_accumulation_steps",default=accumn, type=int, help="gradient_accumulation_steps")
parser.add_argument("--batch_size",default=1, type=int,  help="batch_size")
parser.add_argument("--num_workers",default=2, type=int,  help="num of cpus available")
parser.add_argument("--device",default=torch.device(devicename), help="torch.device object")
parser.add_argument("--num_train_epochs",default=1, type=int,  help="no of epochs of training")
if external_save:
    parser.add_argument("--model_dir", default='D:/external_weights', type=str,  help="path to save trained model")
else:
    parser.add_argument("--model_dir",default= filepath + '/weights', type=str,  help="path to save trained model")
parser.add_argument("--max_grad_norm",default=1.0, type=float, help="max gradient norm.")
parser.add_argument("--root_dir",default= filepath + '/bignews/gpt2_1024_data', type=str, help="location of json dataset.")
parser.add_argument("--ids_file",default= filepath + '/bignews/ids.json', type=str, help="location of train, valid and test file indexes")
args = parser.parse_args([])

model_text = '_data_summarization_{}_lr{}_accum{}'.format(gptname, int(lrmult*100), accumn)
model_file = os.path.join(args.model_dir, 'model' + model_text + '.bin')
config_file = os.path.join(args.model_dir, 'config' + model_text + '.json')
log_file = os.path.join(args.model_dir, 'log' + model_text + '.txt')

In [4]:
def train(args, model, tokenizer, train_dataset, model_file, config_file, log_file):
  model.train()
  train_sampler = RandomSampler(train_dataset)
  train_dl = DataLoader(
    train_dataset, sampler=train_sampler,
    batch_size=args.batch_size, num_workers=args.num_workers)
  loss_fct = CrossEntropyLoss()
  optimizer = AdamW(model.parameters(), lr=args.lr*lrmult)
  # best for gpt2 large
  # scheduler = transformers.get_constant_schedule_with_warmup(optimizer, 100)
  scheduler = transformers.get_constant_schedule_with_warmup(optimizer, 10)
  
  sumloss = 0.0
  log_text = ''
 
  model.zero_grad()
  set_seed(args)
  epoch_iterator = tqdm(train_dl, desc="Training")
  for step, batch in enumerate(epoch_iterator):
    inputs, labels = batch['article'].to(args.device), batch['article'].to(args.device)
    logits = model(inputs)[0]
    shift_logits = logits[..., batch['sum_idx']:-1, :].contiguous()
    shift_labels = labels[..., batch['sum_idx']+1:].contiguous()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    sumloss += float(loss.item())
    loss /= args.gradient_accumulation_steps
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
    
    if (step + 1) % args.gradient_accumulation_steps == 0:
      optimizer.step()
      scheduler.step()  # Update learning rate schedule
      model.zero_grad()
      
    if (step + 1) % 200 == 0:
      log_add = 'Step: {}; loss: {}'.format(step+1, sumloss/200)
      sumloss = 0.0
      log_text += log_add + '\n'
      print(log_add)
      my_file = open(log_file, 'w')
      my_file.write(log_text)
      my_file.close()

      torch.save(model.state_dict(), model_file)
      model.config.to_json_file(config_file)
  
  return log_text 

In [5]:
# load pretrained GPT2
tokenizer = get_tokenizer(gptname)
model = GPT2LMHeadModel.from_pretrained(gptname)
_ = model.resize_token_embeddings(len(tokenizer))
_ = model.to(args.device)
train_data = SummarizationDataset(
    args.root_dir, args.ids_file, tokenizer)

In [6]:
start = time.time()
log_text = train(args, model, tokenizer, train_data, model_file, config_file, log_file)
time_text = 'Total time: {} minutes.'.format((time.time()-start)/60)
log_text += time_text + '\n'
print(time_text)
my_file = open(log_file, 'w')
my_file.write(log_text)
my_file.close()

print('Saving trained model...')
torch.save(model.state_dict(), model_file)
model.config.to_json_file(config_file)
print('Saved.')

Training:   0%|          | 0/4055 [00:00<?, ?it/s]

Step: 200; loss: 2.213261762857437
Step: 400; loss: 1.8163666179776192
Step: 600; loss: 1.637988668680191
Step: 800; loss: 1.5505161045491695
Step: 1000; loss: 1.5445746886730194
Step: 1200; loss: 1.6098970843851566
Step: 1400; loss: 1.519104880541563
Step: 1600; loss: 1.5559707699716092
Step: 1800; loss: 1.4829402489960193
Step: 2000; loss: 1.5120421469956637
Step: 2200; loss: 1.6211941618472339
Step: 2400; loss: 1.4789573391526938
Step: 2600; loss: 1.5224540742486716
Step: 2800; loss: 1.4764484215155245
Step: 3000; loss: 1.507630355609581
Step: 3200; loss: 1.494736626893282
Step: 3400; loss: 1.446859211428091
Step: 3600; loss: 1.5236246314668096
Step: 3800; loss: 1.4448084862693213
Step: 4000; loss: 1.5419006405724212
Total time: 3791.489837058385 minutes.
Saving trained model...
Saved.
