In [1]:
!pip install transformers



In [2]:
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, GPT2Model
from transformers import get_linear_schedule_with_warmup


In [3]:
data_list = set()
with open('com.txt','r') as f:
    text = f.read()
    text  =text.split('\n')
    data_list.update(text)

In [4]:
print(len(data_list))
count = 0.0
max_len = 0
for s in data_list:
  lena = len(s.split())
  count += lena
  if lena > max_len:
    max_len = lena
print(count / len(data_list))
print(max_len)

668
13.47754491017964
41


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
max_length = 35

class ComplDataset(Dataset):
    def __init__(self, compl, tokenizer, length):
        self.compliments = []

        for sent in data_list:
            encod_dic = tokenizer('<SOS> ' + sent + ' <EOS>', truncation=True, max_length=length,
                                  padding='max_length')
            self.compliments.append(torch.tensor(encod_dic['input_ids']))

    def __len__(self):
        return len(self.compliments)

    def __getitem__(self, idx):
        return self.compliments[idx]

In [7]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', bos_token='<SOS>', eos_token='<EOS>', pad_token = '<EOS>')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [8]:
dataset = ComplDataset(data_list, tokenizer, max_length)

In [9]:
import numpy as np

train_ids, val_ids = train_test_split(
    np.arange(len(dataset)),
    test_size=0.1,
    shuffle=True)

train = torch.utils.data.SubsetRandomSampler(train_ids)
val = torch.utils.data.SubsetRandomSampler(val_ids)

train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=train)
val_dataloader =  torch.utils.data.DataLoader(dataset, batch_size=1, sampler=val)

train_size = len(train_ids)
val_size = len(val_ids)

In [10]:
from torch.optim import AdamW

# Used documentation: https://huggingface.co/transformers/model_doc/gpt2.html

configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)
model = GPT2LMHeadModel.from_pretrained('gpt2', config=configuration)
model.resize_token_embeddings(len(tokenizer))
model = model.to(device)
model.train()

optimizer = AdamW(model.parameters(), lr=3e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=5000, num_training_steps = -1)

In [11]:
from tqdm import tqdm

c = 0
best = 100

for epoch in range(10):
    epoch_loss_train = 0
    model.train()
    for i, entity in enumerate(tqdm(train_dataloader)):
        c += 1
        input_ids = entity.to(device)

        outputs = model(input_ids, labels = input_ids)

        loss = outputs[0]

        batch_loss = loss.item()
        epoch_loss_train += batch_loss

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        model.zero_grad()

        if (c % 100 == 0):
          model.eval()
          epoch_loss_val = 0

          for entity in val_dataloader:
              with torch.no_grad():
                  input_ids = entity.to(device)
                  outputs = model(input_ids, labels = input_ids)
                  loss = outputs[0]
              batch_loss = loss.item()
              epoch_loss_val += batch_loss

          av_val_loss = epoch_loss_val / val_size

          if av_val_loss < best:
            print('Average val loss: {}'.format(av_val_loss))
            best = av_val_loss
            torch.save(model.state_dict(), 'models/GPT2.h5')
          
          model.train()


    print('Average train loss: {}'.format(epoch_loss_train / train_size))


 33%|███▎      | 99/301 [00:17<00:35,  5.74it/s]

Average val loss: 45.77679829099285


 66%|██████▌   | 199/301 [00:37<00:17,  5.76it/s]

Average val loss: 37.255958386321566


 99%|█████████▉| 299/301 [00:57<00:00,  5.78it/s]

Average val loss: 13.897107022911754


100%|██████████| 301/301 [01:00<00:00,  4.95it/s]


Average train loss: 17.487877554584067


 33%|███▎      | 98/301 [00:17<00:35,  5.77it/s]

Average val loss: 4.603687579952069


 66%|██████▌   | 198/301 [00:37<00:17,  5.76it/s]

Average val loss: 3.15096568883355


 99%|█████████▉| 298/301 [00:57<00:00,  5.76it/s]

Average val loss: 2.7454295967941853


100%|██████████| 301/301 [01:00<00:00,  4.97it/s]


Average train loss: 2.260743638243334


 32%|███▏      | 97/301 [00:16<00:35,  5.77it/s]

Average val loss: 2.433997559903273


 65%|██████▌   | 197/301 [00:37<00:18,  5.78it/s]

Average val loss: 2.276442594492613


 99%|█████████▊| 297/301 [00:57<00:00,  5.74it/s]

Average val loss: 2.0934685922380702


100%|██████████| 301/301 [01:00<00:00,  4.94it/s]


Average train loss: 1.296801243169534


 32%|███▏      | 96/301 [00:16<00:35,  5.75it/s]

Average val loss: 2.0199313306096776


 65%|██████▌   | 196/301 [00:36<00:18,  5.75it/s]

Average val loss: 1.9166038187582102


 98%|█████████▊| 296/301 [00:57<00:00,  5.78it/s]

Average val loss: 1.8446971413804525


100%|██████████| 301/301 [01:00<00:00,  4.96it/s]


Average train loss: 1.023288653962029


 32%|███▏      | 95/301 [00:16<00:35,  5.77it/s]

Average val loss: 1.7880815863609314


 65%|██████▍   | 195/301 [00:36<00:18,  5.76it/s]

Average val loss: 1.7606081993722205


 98%|█████████▊| 295/301 [00:56<00:01,  5.78it/s]

Average val loss: 1.7511682835087847


100%|██████████| 301/301 [01:00<00:00,  4.98it/s]


Average train loss: 0.9123176918450291


 31%|███       | 94/301 [00:16<00:36,  5.75it/s]

Average val loss: 1.6810610272101503


 64%|██████▍   | 194/301 [00:36<00:18,  5.72it/s]

Average val loss: 1.6632352151087861


 98%|█████████▊| 294/301 [00:57<00:01,  5.74it/s]

Average val loss: 1.619646264990764


100%|██████████| 301/301 [01:01<00:00,  4.92it/s]


Average train loss: 0.8432074449026644


 64%|██████▍   | 193/301 [00:34<00:18,  5.74it/s]

Average val loss: 1.6171343820308572


 97%|█████████▋| 293/301 [00:54<00:01,  5.77it/s]

Average val loss: 1.6132736125988747


100%|██████████| 301/301 [00:59<00:00,  5.10it/s]


Average train loss: 0.7802560542169308


 31%|███       | 92/301 [00:15<00:36,  5.76it/s]

Average val loss: 1.5933099776061612


 64%|██████▍   | 192/301 [00:36<00:18,  5.76it/s]

Average val loss: 1.584849691213067


 97%|█████████▋| 292/301 [00:56<00:01,  5.75it/s]

Average val loss: 1.5277713492735108


100%|██████████| 301/301 [01:00<00:00,  4.97it/s]


Average train loss: 0.7427779349332642


 30%|███       | 91/301 [00:15<00:36,  5.77it/s]

Average val loss: 1.508146572024075


 63%|██████▎   | 191/301 [00:35<00:19,  5.75it/s]

Average val loss: 1.4954226546323122


100%|██████████| 301/301 [00:58<00:00,  5.10it/s]


Average train loss: 0.7163445063914713


100%|██████████| 301/301 [00:56<00:00,  5.37it/s]

Average train loss: 0.6699844568025649





In [14]:
state_dict = torch.load('models/GPT2.h5')
model.load_state_dict(state_dict)

<All keys matched successfully>

In [15]:
model.eval()
generated = torch.tensor(tokenizer.encode('<SOS>')).unsqueeze(0)
generated = generated.to(device)

print(generated)

sample_outputs = model.generate(
                                generated, 
                                do_sample=True,   
                                top_k=50, 
                                max_length =35,
                                top_p=0.8, 
                                bos_token = '<SOS>',
                                eos_token = '<EOS>',
                                num_return_sequences=5
                                )


for i, sample_output in enumerate(sample_outputs):
    print("{}: {}\n\n".format(i, tokenizer.decode(sample_output, skip_special_tokens=True).replace('\n',' ')))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[50257]], device='cuda:0')
0:  that you have been with me since daybreak has been the greatest joy of my life 


1:  a moment of sunshine is enough for me to smile and to inspire you 


2:  if i say this word, i will not stop loving you and caring for you forever 


3:  that i am with you, it is an honor to be with you 


4:  in your heart, i feel like you are my guardian angel 


