In [1]:
!pip install transformers



In [2]:
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, GPT2Model
from transformers import get_linear_schedule_with_warmup


In [3]:
data_list = set()
with open('compl.txt','r') as f:
    text = f.read()
    text  =text.split('\n')
    data_list.update(text)

In [4]:
print(len(data_list))
count = 0.0
max_len = 0
for s in data_list:
  lena = len(s.split())
  count += lena
  if lena > max_len:
    max_len = lena
print(count / len(data_list))
print(max_len)

663
13.470588235294118
41


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
max_length = 25

class ComplDataset(Dataset):
    def __init__(self, compl, tokenizer, length):
        self.compliments = []

        for sent in data_list:
            encod_dic = tokenizer('<SOS> ' + sent + ' <EOS>', truncation=True, max_length=length,
                                  padding='max_length')
            self.compliments.append(torch.tensor(encod_dic['input_ids']))

    def __len__(self):
        return len(self.compliments)

    def __getitem__(self, idx):
        return self.compliments[idx]

In [7]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', bos_token='<SOS>', eos_token='<EOS>', pad_token = '<EOS>')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [8]:
dataset = ComplDataset(data_list, tokenizer, max_length)

In [9]:
import numpy as np

train_ids, val_ids = train_test_split(
    np.arange(len(dataset)),
    test_size=0.2,
    shuffle=True)

train = torch.utils.data.SubsetRandomSampler(train_ids)
val = torch.utils.data.SubsetRandomSampler(val_ids)

train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=train)
val_dataloader =  torch.utils.data.DataLoader(dataset, batch_size=1, sampler=val)

train_size = len(train_ids)
val_size = len(val_ids)

In [10]:
from torch.optim import AdamW

# Used documentation: https://huggingface.co/transformers/model_doc/gpt2.html

configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)
model = GPT2LMHeadModel.from_pretrained('gpt2', config=configuration)
model.resize_token_embeddings(len(tokenizer))
model = model.to(device)
model.train()

optimizer = AdamW(model.parameters(), lr=3e-3, eps = 1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=5000, num_training_steps = -1)

In [11]:
from tqdm import tqdm

c = 0
best = 100

for epoch in range(100):
    epoch_loss_train = 0
    model.train()
    for i, entity in enumerate(tqdm(train_dataloader)):
        c += 1
        input_ids = entity.to(device)

        outputs = model(input_ids, labels = input_ids)

        loss = outputs[0]

        batch_loss = loss.item()
        epoch_loss_train += batch_loss

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        model.zero_grad()

        if (c % 100 == 0):
          model.eval()
          epoch_loss_val = 0

          for entity in val_dataloader:
              with torch.no_grad():
                  input_ids = entity.to(device)
                  outputs = model(input_ids, labels = input_ids)
                  loss = outputs[0]
              batch_loss = loss.item()
              epoch_loss_val += batch_loss

          av_val_loss = epoch_loss_val / val_size

          if av_val_loss < best:
            print('Average val loss: {}'.format(av_val_loss))
            best = av_val_loss
            torch.save(model.state_dict(), 'models/GPT2.h5')
          
          model.train()

    torch.save(model.state_dict(), 'models/GPT2-train.h5')
    print('Average train loss: {}'.format(epoch_loss_train / train_size))


 37%|███▋      | 99/265 [00:18<00:30,  5.44it/s]

Average val loss: 2.8853331006559215


 75%|███████▌  | 199/265 [00:41<00:12,  5.44it/s]

Average val loss: 2.173771455771941


100%|██████████| 265/265 [00:57<00:00,  4.61it/s]


Average train loss: 4.484271156225565


 13%|█▎        | 34/265 [00:06<00:42,  5.46it/s]

Average val loss: 2.0447496583587244


 51%|█████     | 134/265 [00:28<00:24,  5.44it/s]

Average val loss: 1.9627350492584974


 88%|████████▊ | 234/265 [00:51<00:05,  5.44it/s]

Average val loss: 1.938529326055283


100%|██████████| 265/265 [01:01<00:00,  4.28it/s]


Average train loss: 0.9605991725651724


100%|██████████| 265/265 [00:54<00:00,  4.85it/s]


Average train loss: 0.7805764688635772


100%|██████████| 265/265 [00:57<00:00,  4.59it/s]


Average train loss: 1.5998601562009667


100%|██████████| 265/265 [00:57<00:00,  4.61it/s]


Average train loss: 1.400366287073999


100%|██████████| 265/265 [00:54<00:00,  4.85it/s]


Average train loss: 0.8493150084086184


100%|██████████| 265/265 [00:57<00:00,  4.60it/s]


Average train loss: 0.584467004436367


100%|██████████| 265/265 [00:57<00:00,  4.62it/s]


Average train loss: 0.5163109631470915


100%|██████████| 265/265 [00:54<00:00,  4.86it/s]


Average train loss: 0.43366229663480005


100%|██████████| 265/265 [00:57<00:00,  4.59it/s]


Average train loss: 0.42183804916885664


100%|██████████| 265/265 [00:57<00:00,  4.59it/s]


Average train loss: 0.7622163698920664


100%|██████████| 265/265 [00:54<00:00,  4.83it/s]


Average train loss: 0.48002403457209747


100%|██████████| 265/265 [00:57<00:00,  4.59it/s]


Average train loss: 0.3977814089577153


100%|██████████| 265/265 [00:57<00:00,  4.60it/s]


Average train loss: 0.3774192527217685


100%|██████████| 265/265 [00:54<00:00,  4.87it/s]


Average train loss: 0.4094827897143814


100%|██████████| 265/265 [00:57<00:00,  4.62it/s]


Average train loss: 0.44810959336892614


100%|██████████| 265/265 [00:57<00:00,  4.61it/s]


Average train loss: 0.5196779423164871


100%|██████████| 265/265 [00:54<00:00,  4.85it/s]


Average train loss: 0.4784615492483355


100%|██████████| 265/265 [00:57<00:00,  4.61it/s]


Average train loss: 0.48074653542266704


100%|██████████| 265/265 [00:57<00:00,  4.60it/s]


Average train loss: 0.39706166797089126


100%|██████████| 265/265 [00:54<00:00,  4.85it/s]


Average train loss: 0.3978948083969782


100%|██████████| 265/265 [00:57<00:00,  4.61it/s]


Average train loss: 0.39455254626161645


 13%|█▎        | 35/265 [00:06<00:42,  5.38it/s]


KeyboardInterrupt: ignored

In [35]:
state_dict = torch.load('models/GPT2.h5')
model.load_state_dict(state_dict)

<All keys matched successfully>

In [40]:
model.eval()
generated = torch.tensor(tokenizer.encode('<SOS>')).unsqueeze(0)
generated = generated.to(device)

print(generated)

sample_outputs = model.generate(
                                generated, 
                                do_sample=True,   
                                top_k=30, 
                                max_length =25,
                                top_p=0.8, 
                                bos_token = '<SOS>',
                                eos_token = '<EOS>',
                                num_return_sequences=1
                                )


for i, sample_output in enumerate(sample_outputs):
    print("{}: {}\n\n".format(i, tokenizer.decode(sample_output, skip_special_tokens=True).replace('\n',' ')))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[50257]], device='cuda:0')
0:  the only thing that was missing from the experience was my sense of humor, the joy that came from knowing that you were




In [None]:
from google.colab import drive
drive.mount('/content/drive')