In [1]:
from transformer.config import FetchFromPretrained as ConfigFromPretrained

In [2]:
from transformer.tokenizer import FetchFromPretrained as TokenizerFromPretrained

In [3]:
from datasets import load_dataset

In [4]:
import torch

In [5]:
# Train on the GPU when torch can see one; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
device

'cuda'

In [6]:
# Book corpus from the Hugging Face Hub. Only its 'train' split (column 'text') is used;
# train/test/val are carved out of it by row index further down.
data = load_dataset('AlekseyKorshuk/books')
data

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 741
    })
})

In [7]:
# Hub checkpoint whose config and tokenizer are fetched below (this notebook builds the
# model itself from the config only).
model_ckpt = 'bert-base-uncased'

In [8]:
# Start from the checkpoint's config; several fields are overridden in a later cell.
config = ConfigFromPretrained(model_ckpt=model_ckpt).fetch()
config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.39.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [9]:
# Shrink the BERT-base config into a small model and attach runtime settings.
# `compress_layer_size`, `device` and `dtype` are not among the fields printed above;
# presumably they are read by the project's TextGenerator — TODO confirm.
_config_overrides = {
    'hidden_size': 128,
    'intermediate_size': 1024,
    'compress_layer_size': 32,
    'max_position_embeddings': 512,  # also the training context length used by get_batch
    'num_attention_heads': 16,
    'num_hidden_layers': 16,
    'device': device,
    'dtype': torch.float32,
}
for _field, _value in _config_overrides.items():
    setattr(config, _field, _value)

In [10]:
# Number of context windows sampled per batch by get_batch.
batch_size = 8

In [11]:
# Tokenizer matching the checkpoint; used to encode book text and decode samples.
tokenizer = TokenizerFromPretrained(model_ckpt=model_ckpt).fetch()
tokenizer

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [12]:
import pandas as pd

In [13]:
import matplotlib.pyplot as plt

In [14]:
import seaborn as sns

In [15]:
import numpy as np

In [16]:
import os

In [17]:
import pickle

In [18]:
import random

In [19]:
import re

In [20]:
# Number of books (rows) in the dataset; these row indices are split into train/test/val below.
files_count = len(data['train'])
files_count

741

In [21]:
# Row-index boundaries for an 80% / 10% / 10% train / test / val split of the books.
split_index_1, split_index_2 = (int(files_count * fraction) for fraction in (0.8, 0.9))
split_index_1, split_index_2

(592, 666)

In [22]:
def get_data_from_random_files(split):
    """Tokenize one randomly chosen book from the given split.

    Books are partitioned by row index of ``data['train']`` into half-open ranges:
    train -> [0, split_index_1), test -> [split_index_1, split_index_2),
    val -> [split_index_2, files_count).

    Args:
        split: one of 'train', 'test', 'val'.

    Returns:
        The tokenizer's ``input_ids[0]`` for the whole book (no special tokens).

    Raises:
        ValueError: if `split` is not a known split name.
    """
    # Half-open ranges via randrange: random.randint is inclusive at both ends, which let
    # the boundary books appear in two splits and could pick row `files_count` (one past
    # the last row) for 'val'.
    ranges = {
        'train': (0, split_index_1),
        'test': (split_index_1, split_index_2),
        'val': (split_index_2, files_count),
    }
    if split not in ranges:
        raise ValueError(f"split must be one of {sorted(ranges)}, got {split!r}")
    start, stop = ranges[split]
    file_id = random.randrange(start, stop)

    file = data['train'][file_id]['text']
    return tokenizer(file, return_tensors='pt', add_special_tokens=False).input_ids[0]

In [23]:
def get_batch(split):
    """Sample `batch_size` next-token windows from one random book of `split`.

    Each row of the inputs is a slice of `config.max_position_embeddings` token ids,
    and the matching target row is the same slice shifted one token to the right.

    Returns:
        (inputs, targets), both stacked, cast to int64 and moved to `device`.
    """
    tokens = get_data_from_random_files(split=split)
    window = config.max_position_embeddings
    starts = torch.randint(len(tokens) - window, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(tokens[start:start + window])
        targets.append(tokens[start + 1:start + 1 + window])

    def to_device_long(batch):
        return batch.to(torch.int64).to(device)

    return to_device_long(torch.stack(inputs)), to_device_long(torch.stack(targets))

In [24]:
# Sanity check: a whole book is far longer than the tokenizer's model_max_length, hence the
# warning below; get_batch only ever feeds max_position_embeddings-long slices to the model.
get_data_from_random_files('train')

Token indices sequence length is longer than the specified maximum sequence length for this model (1665100 > 512). Running this sequence through the model will result in indexing errors


tensor([ 3342,  2008,  1010,  ..., 12316,  9681,  1012])

In [25]:
# Sanity check: each target row is its input row shifted left by one token.
get_batch('train')

(tensor([[ 2075,  1010,  7653,  ...,  3238,  1010,  2021],
         [ 7168, 13061, 17326,  ...,  2067,  2026,  2214],
         [ 3432,  2938,  1010,  ...,  2002,  2001,  2763],
         ...,
         [ 2026,  8102,  1012,  ..., 17081,  1012,  2016],
         [ 6497,  2408,  1996,  ...,  2063,  1012,  1036],
         [ 1037, 14401,  1036,  ...,  1005,  1049,  2025]], device='cuda:0'),
 tensor([[ 1010,  7653,  2008,  ...,  1010,  2021,  2014],
         [13061, 17326, 18968,  ...,  2026,  2214,  2814],
         [ 2938,  1010,  4895,  ...,  2001,  2763,  2074],
         ...,
         [ 8102,  1012,  2043,  ...,  1012,  2016,  3368],
         [ 2408,  1996,  5532,  ...,  1012,  1036,  1036],
         [14401,  1036,  1036,  ...,  1049,  2025,  2469]], device='cuda:0'))

In [26]:
from transformer.head.text_generator_decoder_only import TextGenerator

In [27]:
# Build the decoder-only generator from the modified config (only the config is passed here).
generator = TextGenerator(config)

In [28]:
# Move the model to `device`. NOTE(review): if TextGenerator is a torch nn.Module, .to()
# moves it in place and returns self, so `gen` and `generator` are the same object — confirm.
gen = generator.to(device)

In [29]:
# Training-loop cadence: estimate_loss() runs every `eval_iters` steps.
eval_iters = 100

In [30]:
# Number of batches per split that estimate_loss() evaluates.
eval_iter_ticks = 5

In [31]:
# Smoke test: one forward pass on a training batch; the model returns (logits, loss).
X, Y = get_batch('train')
logits, loss = gen(X, Y, tokenizer)
logits, logits.size(), loss

(tensor([[[-0.3887, -0.1111,  0.6715,  ...,  0.4277, -0.5717, -0.1654],
          [ 0.3556,  1.0960,  0.6213,  ..., -0.1883, -0.4950,  0.4808],
          [-0.6133,  0.0990, -0.2922,  ..., -0.2802, -0.5862,  0.0142],
          ...,
          [-0.5669, -0.3371, -0.4981,  ...,  0.0988,  0.2552,  0.6113],
          [ 0.0875, -0.0600, -0.3122,  ..., -0.9416,  0.3699,  0.0823],
          [-0.4890,  0.2113,  1.3513,  ..., -0.0512, -0.7785, -0.0725]],
 
         [[ 0.5806,  0.1446,  0.4918,  ...,  0.3851, -0.6560,  0.5583],
          [ 0.3842,  0.7758, -0.5038,  ..., -1.0940, -1.0805,  0.5588],
          [-0.2339, -0.3715,  0.7282,  ..., -0.2762,  0.3370,  0.3887],
          ...,
          [-0.9291,  0.9337, -0.0767,  ..., -0.1166, -0.2772,  0.3225],
          [-0.1502,  0.0225,  0.4395,  ..., -0.0625, -0.9919,  0.5122],
          [-0.1836,  0.3707,  0.4989,  ..., -0.0991, -0.0219,  0.7326]],
 
         [[ 0.0176, -0.7595, -0.0051,  ...,  1.2122, -0.1668, -0.1664],
          [ 0.0707, -0.0090,

In [32]:
# When True (and build_pretrained is False), the training and checkpoint-writing cells are skipped.
pretrained = False

In [33]:
# When True, train and write text_generator.pkl even if `pretrained` is True.
build_pretrained = False

In [42]:
def test(prompt='She was going to '):
    prompt_ids = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).input_ids
    generated_ids = gen.generate(prompt_ids.to(device), max_new_tokens=20)
    result = tokenizer.decode(generated_ids[0])
    print(result)

In [35]:
def test_deep(logits, targets):
    # focus only on the last time step
    logits = logits[:, -1, :] # becomes (B, C)
    # apply softmax to obtain probabilities
    probs = torch.nn.functional.softmax(logits, dim=-1).to(config.device) # (B, C)
    index_next = torch.multinomial(probs, num_samples=1) # (B, C)

    print(tokenizer.decode(index_next.squeeze(-1)))
    print(tokenizer.decode(targets[:, -1]))

In [36]:
@torch.no_grad()
def estimate_loss():
    """Checkpoint the model (when training) and estimate the mean loss on train and val.

    Returns:
        dict mapping 'train' and 'val' to a 0-d tensor: the mean loss over
        `eval_iter_ticks` freshly sampled batches of that split.

    Side effects:
        - Writes 'text_generator.pkl' when `(not pretrained) or build_pretrained`.
        - Prints a sampled-vs-target comparison per batch via test_deep.
        - Switches `gen` to eval mode and back to train mode. The switch back happens even
          if evaluation raises, so the training loop never continues in eval mode.
    """
    out = {}
    if (not pretrained) or build_pretrained:
        # NOTE(review): pickling the whole module ties the checkpoint to the class
        # definition; torch.save(gen.state_dict(), ...) would be more portable.
        with open('text_generator.pkl', 'wb') as handler:
            pickle.dump(gen, handler)
    gen.eval()
    try:
        for split in ['train', 'val']:
            # Size the buffer by the number of batches actually evaluated. Sizing it by
            # `eval_iters` left trailing zeros that shrank the reported mean.
            losses = torch.zeros(eval_iter_ticks)
            for k in range(eval_iter_ticks):
                X, Y = get_batch(split)
                # Same call signature as the training loop's forward pass.
                logits, loss = gen(X, Y, tokenizer)
                losses[k] = loss.item()
                test_deep(logits, Y)
            out[split] = losses.mean()
    finally:
        gen.train()
    return out
    

In [37]:
# AdamW learning rate used by the training loops.
learning_rate = 5e-4

In [38]:
# Upper bound on training steps; the recorded run below was stopped early with a KeyboardInterrupt.
iterations = 1000000

In [39]:
# Optionally write the current model to text_generator.pkl before training starts.
# NOTE(review): whole-module pickle; torch.save(gen.state_dict(), ...) would be more portable.
if build_pretrained:
    with open('text_generator.pkl', 'wb') as handler:
        pickle.dump(gen, handler)

In [40]:
# Main training loop: one AdamW step per sampled batch, with periodic evaluation/checkpointing.
if (not pretrained) or build_pretrained:
    # AdamW with decoupled weight decay and the AMSGrad variant, over the current model.
    optimizer = torch.optim.AdamW(gen.parameters(), lr=learning_rate, weight_decay=0.01, amsgrad=True)
    gen.train()

    # `step` rather than `iter`: a top-level `iter` would shadow the builtin for the rest
    # of the kernel session.
    for step in range(iterations):
        try:
            # Periodic evaluation (estimate_loss also checkpoints the model).
            if (step != 0) and (step % eval_iters == 0):
                losses = estimate_loss()
                train_loss = losses['train']
                val_loss = losses['val']
                print(f'Loss at step = {step} for train data is {train_loss:.4f} for val it is {val_loss:.4f}')

            # sample a batch of data
            xb, yb = get_batch('train')

            # evaluate the loss
            logits, loss = gen.forward(xb, yb, tokenizer)

            if not torch.isnan(loss).any():
                test_deep(logits, yb)
                print(step, loss)
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(gen.parameters(), max_norm=15.0)  # disabled
                optimizer.step()
            else:
                print('NaN loss')
                print(step, loss)
                input()  # deliberate pause so a NaN loss can be inspected before continuing
        except Exception as e:
            # Best effort: keep training past a bad batch, but report the exception type so
            # failures are diagnosable. Also make sure a failure inside estimate_loss cannot
            # leave the model stuck in eval mode.
            print(f'{type(e).__name__}: {e}')
            gen.train()

spectacular burkina skatingeth umm authorship τ rosalie
' lieutenant know told su what get threatening
0 tensor(10.4805, device='cuda:0', grad_fn=<NllLossBackward0>)
acidicwe spaces caroliness mai 智 artist
nodded ever around movie get you cuffs talk
1 tensor(10.2242, device='cuda:0', grad_fn=<NllLossBackward0>)
##ท portssibility apartheid boardsvre [unused326] norm
onely to the smiled rental on it
2 tensor(10.0011, device='cuda:0', grad_fn=<NllLossBackward0>)
homesteadtina faerie approachedfalls allowed theatre singers
family favors that hoped himt we,
3 tensor(9.9684, device='cuda:0', grad_fn=<NllLossBackward0>)
jinidae nouvelle overturned realises charcoal panther nair
those should. us ask him down forced
4 tensor(9.4289, device='cuda:0', grad_fn=<NllLossBackward0>)
hit ramsey cautious tired soothe kneesren.
out stripes hang who'am could wish
5 tensor(9.4832, device='cuda:0', grad_fn=<NllLossBackward0>)
administer flora bearer concluding strong damian defense seal
his captain are tim

KeyboardInterrupt: 

In [46]:
# Qualitative check: continue a long, multi-sentence prompt with the current model.
test('Summer season is upon us and you are all set to head to the beach or to an outdoor event. Outfit. Check, Sunglasses. Check. Hat. Check. Sunscreen? Well, check the ingredients before applying it. While sunscreen is essential for protecting your skin')

summer season is upon us and you are all set to head to the beach or to an outdoor event. outfit. check, sunglasses. check. hat. check. sunscreen? well, check the ingredients before applying it. while sunscreen is essential for protecting your skin the cross doesnt fulfill the physical fresh book anymore. by the time get sucked, you cant


In [47]:
# Qualitative check: continue a very short prompt.
test('An apple is')

an apple is going on. we need you to know the thing you can if i can do the same. maybe


In [51]:
# Qualitative check: continue a prompt that invites an explanation.
test('This was the greatest outcome for me because')

this was the greatest outcome for me because i'm here,'both men said. i'd already forgiven him on the lower side


In [53]:
# Total number of model parameters (elements across all parameter tensors).
sum(map(torch.numel, gen.parameters()))

14286138

In [58]:
# Qualitative check: continue a short narrative prompt.
prompt = 'They wanted to fight and '
test(prompt)

they wanted to fight and be together. claws and butter on juvenile in was a continual texas worker sitting. ` ` did


In [59]:
# Qualitative check: continue a longer scene-setting prompt.
prompt = 'He was walking on the middle of road when a car started coming at high speed behind him '
test(prompt)

he was walking on the middle of road when a car started coming at high speed behind him and aquaticing around the diner. at some point his formation had shook their beer hard by now as


In [None]:
# Checkpoint the trained model to text_generator.pkl (whole-module pickle).
if (not pretrained) or build_pretrained:
    with open('text_generator.pkl', 'wb') as handler:
        pickle.dump(gen, handler)

In [None]:
# Reload the pickled model, rebinding `gen` to a new object.
# NOTE(review): pickle.load can execute arbitrary code from the file; only load checkpoints you wrote.
# NOTE(review): an optimizer created before this cell still references the previous `gen`'s
# parameters; recreate it before training the reloaded model.
with open('text_generator.pkl', 'rb') as handler:
    gen = pickle.load(handler)

In [None]:
# Resume training, typically after reloading `gen` from text_generator.pkl.
if (not pretrained) or build_pretrained:
    # Recreate the optimizer for the CURRENT `gen`. When `gen` has been rebound to a freshly
    # unpickled object, an optimizer created earlier still holds the old object's parameters.
    # optimizer.step() would then never update the model trained here. Trade-off: AdamW's
    # moment estimates start fresh.
    optimizer = torch.optim.AdamW(gen.parameters(), lr=learning_rate, weight_decay=0.01, amsgrad=True)
    gen.train()

    # `step` rather than `iter`: a top-level `iter` would shadow the builtin for the rest
    # of the kernel session.
    for step in range(iterations):
        try:
            # Periodic evaluation (estimate_loss also checkpoints the model).
            if (step != 0) and (step % eval_iters == 0):
                losses = estimate_loss()
                train_loss = losses['train']
                val_loss = losses['val']
                print(f'Loss at step = {step} for train data is {train_loss:.4f} for val it is {val_loss:.4f}')

            # sample a batch of data
            xb, yb = get_batch('train')

            # evaluate the loss
            logits, loss = gen.forward(xb, yb, tokenizer)
            if not torch.isnan(loss).any():
                test_deep(logits, yb)
                print(step, loss)
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(gen.parameters(), max_norm=15.0)  # disabled
                optimizer.step()
            else:
                print('NaN loss')
                print(step, loss)
                input()  # deliberate pause so a NaN loss can be inspected before continuing
        except Exception as e:
            # Best effort: keep training past a bad batch, but report the exception type so
            # failures are diagnosable. Also make sure a failure inside estimate_loss cannot
            # leave the model stuck in eval mode.
            print(f'{type(e).__name__}: {e}')
            gen.train()