In [23]:
import torch
from operator import itemgetter
from torch.utils.data import DataLoader
import random
import numpy as np
import math
import os




from update_utilities import update_utilities_class
import pickle

# 1 - Data Preparation

## 1.1. Examining the Data

In [24]:
# Read the entire book text (cleaned and written back by a later cell on a previous run).
with open('lord-of-the-rings-processed.txt', mode='r', encoding='utf-8') as book_file:
    text = book_file.read()



    

In [25]:
# Total size of the corpus, in characters.
print("length of the book -", len(text), "characters")

length of the book - 3729059 characters


In [26]:
# Peek at the first 100 characters of the text.
print(text[0:100])

The Music of the Ainur There was Eru, the One, who in Arda is called lluvatar; and he made first the


## 1.2. Format Data

In [27]:
# Character inventory of the text, in sorted order.
chars = sorted(set(text))
print(chars)

[' ', '!', "'", '(', ')', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '—']


In [30]:
# Characters outside the plain alphanumeric/punctuation set: candidates for cleanup.
common = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ();:.!?-,"
# sorted() makes the printed order stable. Iterating a set of str depends on the
# per-process hash seed, so list(set(...)) could print in a different order each run.
special = sorted(set(chars) - set(common))
print(special)

["'", ' ', '—']


In [37]:
# Hard-coded special characters to handle during cleanup. Positions matter:
# indices 6-9 (the curly quotes) are picked out by index in a later cell.
special = list('\n "\'®—‘’“”')

In [38]:
# Flatten line breaks into spaces.
text = text.replace("\n", " ")
# Collapse every run of spaces to a single space. One replace("  ", " ") only halves
# each run (e.g. 4 spaces -> 2), so repeat until no double space is left.
while "  " in text:
    text = text.replace("  ", " ")
# '®' stands in for 'u' in this text (presumably an encoding artifact; TODO confirm).
text = text.replace("®", "u")

In [39]:
# The curly quotes sit at positions 6-9 of `special`. Append the punctuation marks
# that attach directly to neighbouring text.
special_char = [special[i] for i in (6, 7, 8, 9)]
special_char += [",", ";", ":", "!", "?"]
special_char

['‘', '’', '“', '”', ',', ';', ':', '!', '?']

In [40]:
# Opening quotes (positions 0 and 2 of special_char): no space may follow them.
no_space_after = [special_char[i] for i in (0, 2)]
no_space_after

['‘', '“']

In [41]:
# Everything in special_char except the opening quotes (positions 0 and 2), i.e. the
# closing quotes and punctuation, which must not be preceded by a space.
# enumerate keeps the original list order explicitly instead of depending on the
# iteration order of a set of indices.
no_space_before = [ch for idx, ch in enumerate(special_char) if idx not in (0, 2)]
no_space_before

['’', '”', ',', ';', ':', '!', '?']

In [42]:
# Drop the space after an opening quote: <' sss> -> <'sss>
for opener in no_space_after:
    text = text.replace(f"{opener} ", opener)

# Drop the space before closing quotes / punctuation: <s ,> -> <s,>
for closer in no_space_before:
    text = text.replace(f" {closer}", closer)


In [43]:
# Standardize the use of quotation marks: map every straight/curly variant to "'".
for quote_mark in ('"', '‘', '’', '“', '”'):
    text = text.replace(quote_mark, "'")

In [44]:
# Persist the cleaned text. Write with the same UTF-8 encoding used when reading it:
# the text contains non-ASCII characters (e.g. '—'), and the platform default
# encoding is not guaranteed to be UTF-8.
# NOTE(review): this overwrites the same file read at the top of the notebook, so a
# re-run starts from already-cleaned text rather than the original source.
with open("lord-of-the-rings-processed.txt", "w", encoding="utf-8") as f:
    f.write(text)

## 1.3. Create Dictionary and Tokenize the Data

**Tokenizer**: map each character to an integer id and back.

In [45]:
# Re-check the character inventory after cleaning; anything not in `common` is special.
chars = sorted(set(text))
common = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ();:.!?-,"
special = [ch for ch in chars if ch not in common]
print(special)

[' ', "'", '—']


In [46]:
# encode_char = {char:i for i, char in enumerate(chars)}
# decode_char = {i:char for i, char in enumerate(chars)}
# print(len(encode_char))
# vocab_size = len(encode_char)

In [47]:
# import pickle

# with open('saved_encoder_dict.pkl','wb') as f:
#     pickle.dump(encode_char,f)

# with open('saved_decoder_dict.pkl','wb') as f:
#     pickle.dump(decode_char,f)

In [48]:
# Load the char -> id and id -> char dictionaries produced by the commented-out cell above.
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
with open('saved_encoder_dict.pkl', 'rb') as f:
    encode_char = pickle.load(f)

with open('saved_decoder_dict.pkl', 'rb') as f:
    decode_char = pickle.load(f)

# The saved vocabulary must cover every character of the freshly cleaned text;
# otherwise encode() raises KeyError later, deep inside the data pipeline.
missing_chars = set(text) - set(encode_char)
assert not missing_chars, f"characters missing from saved vocabulary: {sorted(missing_chars)}"
# The two dictionaries were built as inverses of each other; verify they still are.
assert all(decode_char.get(i) == c for c, i in encode_char.items()), "encoder/decoder dictionaries are not inverses"

print(len(encode_char))
vocab_size = len(encode_char)

74


In [49]:
def encode(string):
    """Map a string to a list of integer token ids, one per character.

    Looks each character up in the global `encode_char`; raises KeyError for a
    character outside the vocabulary.
    """
    return [encode_char[s] for s in string]


def decode(nums):
    """Map a sequence of integer token ids back to a string.

    Looks each id up in the global `decode_char`; raises KeyError for an unknown id.
    """
    return ''.join(decode_char[n] for n in nums)

In [50]:
# Encode a short sample string to sanity-check the char -> id mapping.
encode("This is good")

[40, 54, 55, 65, 0, 55, 65, 0, 53, 61, 61, 50]

In [51]:
# Decode an arbitrary id sequence to sanity-check the id -> char mapping.
decode([8, 20, 69, 44, 27])

'0?wXG'

## 1.4. Load Data and Construct Datasets and DataLoaders

**Split each book separately, so that every book contributes a fixed percentage of its text to the validation set.**

In [52]:
# update_utilities_class(file_name="general_functions.py",current_path=os.getcwd()).run()


In [53]:
# Project-local helpers; `h.train_test_split` is used below to split each book.
from general_functions import HelperFunctionsClass
h = HelperFunctionsClass()

In [54]:
# Book 1 starts at index 0; the returned end_idx is passed as starting_idx for book 2.
book1_train, book1_val, end_idx = h.train_test_split(
    text=text,
    starting_idx=0,
    ending="and an end was come for the Eldar of story and of song.",
    ratio=0.85,
)

In [55]:
# Book 2 continues from where book 1 ended.
book2_train, book2_val, end_idx = h.train_test_split(
    text=text,
    starting_idx=end_idx,
    ending="and handed him the tobacco-jar.",
    ratio=0.85,
)

In [56]:
# Book 3 continues from where book 2 ended.
book3_train, book3_val, end_idx = h.train_test_split(
    text=text,
    starting_idx=end_idx,
    ending="THE RETURN OF THE KING.",
    ratio=0.85,
)

In [57]:
# Book 4 continues from where book 3 ended.
book4_train, book4_val, end_idx = h.train_test_split(
    text=text,
    starting_idx=end_idx,
    ending="was alive but taken by the Enemy.",
    ratio=0.85,
)

In [58]:
# Book 5 continues from where book 4 ended.
# NOTE(review): the end index is bound to end_idx2 here, not end_idx as in the other splits.
book5_train, book5_val, end_idx2 = h.train_test_split(
    text=text,
    starting_idx=end_idx,
    ending="I'm back,' he said.",
    ratio=0.85,
)

In [59]:
# Concatenate the per-book splits (books in the order 3, 4, 5, 2, 1).
train_data = (
    book3_train
    + book4_train
    + book5_train
    + book2_train
    + book1_train
)
val_data = (
    book3_val
    + book4_val
    + book5_val
    + book2_val
    + book1_val
)

In [60]:
# Sanity check: total characters across both splits. Summing the two lengths gives the
# same number without building a ~3.7M-character concatenated copy.
len(train_data) + len(val_data)

3729058

In [61]:
# Encode both splits into 1-D tensors of token ids (train first, then validation).
train_data2, val_data2 = (torch.tensor(encode(split)) for split in (train_data, val_data))

In [62]:
# Number of tokens in each encoded split.
tuple(map(len, (train_data2, val_data2)))

(3170090, 558968)

**Dataset and DataLoader**

In [63]:
# update_utilities_class(file_name="custom_text_dataset.py",current_path=os.getcwd()).run()

In [64]:
from custom_text_dataset import slideTokenizedTextDataset

In [65]:
block_size = 512  # window length in tokens; also passed to the model below

# Sliding-window datasets over the encoded train / validation token streams.
train_dataset, val_dataset = (
    slideTokenizedTextDataset(full_txt=tokens, block_size=block_size)
    for tokens in (train_data2, val_data2)
)

In [66]:
# Number of sliding windows in each dataset.
tuple(len(ds) for ds in (train_dataset, val_dataset))

(3169578, 558456)

In [70]:
batch_size = 64
train_num_samples = 500000  # unused unless passed as num_samples= to the sampler below

# Shuffle the full training set without replacement.
# To subsample, add num_samples=train_num_samples to RandomSampler.
train_sampler = torch.utils.data.RandomSampler(train_dataset, replacement=False)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
    drop_last=True,
)

In [72]:
val_num_samples = 100000  # unused unless passed as num_samples= to the sampler below

# Shuffle the full validation set without replacement.
# To subsample, add num_samples=val_num_samples to RandomSampler.
val_sampler = torch.utils.data.RandomSampler(val_dataset, replacement=False)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    sampler=val_sampler,
    drop_last=True,
)

In [73]:
# Number of full batches per pass over each loader.
tuple(len(loader) for loader in (train_dataloader, val_dataloader))

(49524, 8725)

# 2 - Model Definition

In [28]:
# update_utilities_class(file_name="Transformer.py",current_path=os.getcwd()).run()

In [74]:
# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cpu'

In [75]:
import Transformer

# Transformer configured as a decoder (is_decoder=True) over the character vocabulary.
transformer = Transformer.TransformerClass(
    vocab_size=vocab_size,
    emb_dim=512,
    n_layer=8,
    num_heads=8,
    block_size=block_size,
    dropout_rate_attention=0.1,
    dropout_rate_ff=0.2,
    dropout_rate_pos_enc=0.1,
    is_decoder=True,
    ff_multiplier=4,
).to(device)


In [79]:
# Model size in millions of parameters, rounded to two decimals.
print(round(sum(param.numel() for param in transformer.parameters()) / 1e6, 2), 'M parameters')

25.28 M parameters


In [83]:
# Switch to the project directory so the relative checkpoint path below resolves.
# NOTE(review): hard-coded absolute user path, so this is not portable. It also changes the
# working directory for every later cell that uses a relative path.
os.chdir(rf"/Users/seangao/Downloads/Lord-of-the-Rings-SLM-main")
# FIXME: this points at a .txt file ("Training Information"), not a saved state_dict,
# which is why torch.load raises UnpicklingError. Point it at the model checkpoint file
# instead (its name is not visible here; TODO confirm). If that checkpoint was saved on
# GPU, loading on CPU also needs torch.load(..., map_location=device).
transformer.load_state_dict(torch.load("base_line_GPT stats/base_line_GPT - Training Information.txt"))

UnpicklingError: invalid load key, '\x0a'.

In [45]:
# update_utilities_class(file_name="loss_functions.py",current_path=os.getcwd()).run()

File copied, now the file is available to import from the destinated path


In [33]:
# update_utilities_class(file_name="train_test_loop.py",current_path=os.getcwd()).run()

File already exist in destination folder, it is now removed
File copied, now the file is available to import from the destinated path


In [84]:
from train_test_loop import train_test_loop_class

optimizer = torch.optim.AdamW(transformer.parameters(), lr=1e-5)
overwrite = False

# Single-epoch training/validation loop; progress is printed every 500 batches.
# NOTE(review): lr_rate_tuning=False, so lr_start/lr_end are presumably ignored; confirm.
train_loop = train_test_loop_class(
    model=transformer,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    test_loader=None,
    epochs=1,
    print_every_n_batch=500,
    device=device,
    model_name="base_line_GPT",
    optimizer=optimizer,
    calculate_accuracy=False,
    overwrite_message=overwrite,
    problem_type="Multiclass Classification",
    update_loss_fn=False,
    print_result=True,
    print_full=False,
    lr_rate_tuning=False,
    clip_batch=False,
    clip_batch_size=20,
    lr_start=-5,
    lr_end=-2,
)

In [85]:
# Run the configured training loop (long-running: 49524 batches per epoch on this data).
train_loop.train()

  0%|          | 0/49524 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [54]:
# train_loop.overwrite_message = False
# optimizer = torch.optim.AdamW(transformer.parameters(),lr=5e-6)
# train_loop.optimizer = optimizer

In [56]:
# Re-attach the dataloaders to the training loop. This only changes anything if the
# dataloader cells above were re-run to build new loader/sampler objects.
train_loop.train_loader, train_loop.val_loader = train_dataloader, val_dataloader