In [None]:
import torch

In [None]:
# Install Hugging Face `datasets` into the running kernel's environment (%pip, not !pip).
# NOTE(review): consider pinning a version so Restart & Run All stays reproducible.
%pip install datasets

In [None]:
import datasets

In [None]:
# Fetch the list of datasets known to the `datasets` library and display its size.
# NOTE(review): `datasets.list_datasets` is deprecated in newer `datasets` releases
# (`huggingface_hub.list_datasets` is the suggested replacement) — confirm against
# the installed version before relying on this cell.
all_ds = datasets.list_datasets()
len(all_ds)

In [None]:
# Check that 'oscar' is in the listing before loading it in the next cell.
'oscar' in all_ds

In [None]:
# Load the deduplicated Latin ('la') configuration of the OSCAR corpus.
la_dataset = datasets.load_dataset(path='oscar', name='unshuffled_deduplicated_la')

In [None]:
# Display the loaded dataset object; its splits are indexed by name (e.g. 'train').
la_dataset

In [None]:
# Peek at the first training record; the sharding cell reads its 'text' field.
la_dataset['train'][0]

In [None]:
from tqdm.auto import tqdm

SHARD_SIZE = 10_000  # documents per la_<n>.txt file


def write_shard(docs, index):
    """Write ``docs`` one per line to ``la_{index}.txt`` (UTF-8, overwriting)."""
    with open(f'la_{index}.txt', 'w', encoding='utf-8') as fp:
        fp.write('\n'.join(docs))


text_data = []
file_count = 0

for record in tqdm(la_dataset['train']):
    # One document per line. Replace \r as well as \n: reading the file back in
    # text mode (universal newlines) would also split on a bare \r.
    text_data.append(record['text'].replace('\r', ' ').replace('\n', ' '))
    if len(text_data) == SHARD_SIZE:
        write_shard(text_data, file_count)
        text_data = []
        file_count += 1

# Flush the final partial shard. Skipping an empty buffer avoids writing an
# empty la_<n>.txt when the corpus size is an exact multiple of SHARD_SIZE.
if text_data:
    write_shard(text_data, file_count)

In [None]:
from pathlib import Path

In [None]:
# Collect only the la_<n>.txt shards written by the sharding cell, so unrelated
# *.txt files in the working directory are never fed to the tokenizer.
# Sorting by (name length, name) yields numeric order for la_<n>.txt and makes
# slices such as paths[:100] reproducible across runs and filesystems.
paths = [
    str(p)
    for p in sorted(Path('./').glob('la_*.txt'), key=lambda p: (len(p.name), p.name))
]
paths[:5]

In [None]:
# Training the tokenizer

#BPE Tokenizer we are breaking into bytes. We don't need [UNK] tokenizer for unknown tokens

%pip install -update tokenizers

In [None]:
from tokenizers import ByteLevelBPETokenizer

# Untrained byte-level BPE tokenizer; it is trained on the text shards in the next cell.
tokenizer = ByteLevelBPETokenizer()

In [None]:
tokenizer.train(
    files=paths[:100],
    vocab_size=30_522,
    min_frequency=2,
    # NOTE: the mlm() masking helper assumes these end up as ids 0-4 in list
    # order (<mask> == 4) — confirm via rb_tokenizer.mask_token_id.
    special_tokens=['<s>', '<pad>', '</s>', '<unk>', '<mask>'],
)

In [None]:
import os

# exist_ok=True keeps this cell re-runnable: os.mkdir raises FileExistsError
# when the directory already exists (e.g. on Restart & Run All).
os.makedirs('latintokens', exist_ok=True)

In [None]:
# Persist the trained tokenizer files into latintokens/, where RobertaTokenizerFast loads them from.
tokenizer.save_model('latintokens')

In [None]:
# Install Hugging Face `transformers` (RoBERTa tokenizer/model classes used below).
# NOTE(review): consider pinning a version for reproducibility.
%pip install transformers

In [None]:
from transformers import RobertaTokenizerFast

In [None]:
# Wrap the saved byte-level BPE files in a RoBERTa fast tokenizer, which provides
# the padding/truncation/return_tensors interface used in later cells.
rb_tokenizer = RobertaTokenizerFast.from_pretrained('latintokens')

In [None]:
# Sanity check: encode a short Latin phrase, padded to 12 tokens.
rb_tokenizer('quam hi sunt, u?', padding='max_length', max_length=12)

In [None]:
# labels == the original (unmasked) input_ids; the model input is a masked copy.

# input_ids -> MLM (Masked Language Modeling) masking function

import torch

def mlm(tensor, mask_prob=0.15, mask_token_id=4):
    """Randomly replace non-special token ids with the mask id, in place.

    Args:
        tensor: integer tensor of token ids (any shape); modified in place.
        mask_prob: probability that each eligible position is masked.
        mask_token_id: id written into the selected positions. The default 4
            assumes the tokenizer was trained with special tokens in the order
            <s>, <pad>, </s>, <unk>, <mask>.

    Returns:
        The same ``tensor`` object, after masking.
    """
    # Uniform draws in [0, 1), created on the tensor's own device so the
    # boolean indexing below never mixes CPU and accelerator tensors.
    rand = torch.rand(tensor.shape, device=tensor.device)
    # Ids 0, 1, 2 (<s>, <pad>, </s>) are never masked.
    mask_arr = (rand < mask_prob) & (tensor > 2)
    # Boolean-mask assignment replaces the former per-row nonzero() loop.
    tensor[mask_arr] = mask_token_id
    return tensor

In [None]:
from pathlib import Path

# Re-collect the la_<n>.txt shards only, in a deterministic numeric-friendly
# order ((name length, name) sorts la_2 before la_10), so unrelated *.txt files
# in the working directory are never tokenised as training data.
paths = [
    str(p)
    for p in sorted(Path('./').glob('la_*.txt'), key=lambda p: (len(p.name), p.name))
]
paths[:5]

In [None]:
from tqdm.auto import tqdm

input_ids, mask, labels = [], [], []

# Tokenise every shard (one document per line), padded/truncated to 512 tokens.
for shard_path in tqdm(paths):
    with open(shard_path, 'r', encoding='utf-8') as shard_file:
        docs = shard_file.read().split('\n')
    enc = rb_tokenizer(docs, max_length=512, padding='max_length', truncation=True, return_tensors='pt')
    # Unmasked ids become the targets; a masked copy becomes the model input.
    # NOTE(review): labels keep every position (padding included); HF's MLM
    # collator instead sets unmasked positions to -100 — consider that if the
    # loss turns out to be dominated by <pad> tokens.
    masked_ids = mlm(enc['input_ids'].detach().clone())
    labels.append(enc['input_ids'])
    mask.append(enc['attention_mask'])
    input_ids.append(masked_ids)

In [None]:
# Stack the per-shard tensors into single (num_docs, 512) tensors.
# NOTE: not idempotent — afterwards these names hold tensors rather than lists,
# so re-running this cell requires re-running the tokenisation cell first.
input_ids, mask, labels = [torch.cat(parts) for parts in (input_ids, mask, labels)]

In [None]:
# First ten (masked) input ids of the first document; masked positions show id 4.
input_ids[0][:10]


In [None]:
# Bundle the model inputs under the keyword names RobertaForMaskedLM accepts.
encodings = dict(
    input_ids=input_ids,
    attention_mask=mask,
    labels=labels,
)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a dict of row-aligned tensors.

    Item ``i`` is a dict with the same keys as ``encodings``, each holding
    row ``i`` of the corresponding tensor.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        # Row count is taken from the 'input_ids' tensor.
        return self.encodings['input_ids'].size(0)

    def __getitem__(self, i):
        item = {}
        for name, values in self.encodings.items():
            item[name] = values[i]
        return item

In [None]:
# Wrap the encodings dict so a DataLoader can batch it.
dataset = Dataset(encodings)

In [None]:
# Shuffled mini-batches of 16 documents for training.
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=16, shuffle=True)

In [None]:
from transformers import RobertaConfig

In [None]:
config = RobertaConfig(
    vocab_size=rb_tokenizer.vocab_size,  # size the embeddings to the trained tokenizer
    # Presumably 512 (the tokenisation max_length) + 2 for RoBERTa's
    # position-id offset — confirm against the RoBERTa docs.
    max_position_embeddings=514,
    num_hidden_layers=6,
    num_attention_heads=12,
    hidden_size=768,
    type_vocab_size=1,
)

In [None]:
from transformers import RobertaForMaskedLM

In [None]:
# Freshly initialised (untrained) RoBERTa masked-LM built from the config above.
model = RobertaForMaskedLM(config=config)

In [None]:
# Pick the best available accelerator. torch device strings are lowercase:
# the previous 'CPU' fallback is rejected by torch.device / Tensor.to.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(device)

In [None]:
# Move the model's parameters to the selected device (the cell also displays the model).
model.to(device)

In [None]:
# transformers.AdamW is deprecated (and removed in recent releases) in favour of
# torch.optim.AdamW. Note torch's defaults differ (weight_decay=0.01, eps=1e-8
# vs. 0.0 and 1e-6 in transformers); pass them explicitly to reproduce old runs.
from torch.optim import AdamW

In [None]:
# Switch the model to training mode before the optimisation loop.
model.train()

In [None]:
# Optimiser over all model parameters with a fixed learning rate of 1e-4.
optim = AdamW(params=model.parameters(), lr=1e-4)

In [None]:
from tqdm.auto import tqdm

# Epoch count referenced by the training-loop cell.
epochs = 1

In [None]:
# Masked-LM training loop: one pass over `dataloader` per epoch.
for epoch in range(epochs):
    loop = tqdm(dataloader, leave=True)
    for batch in loop:
        optim.zero_grad()
        # Distinct names keep the module-level input_ids / mask / labels
        # tensors intact instead of overwriting them with the last batch.
        batch_input_ids = batch['input_ids'].to(device)
        batch_mask = batch['attention_mask'].to(device)
        batch_labels = batch['labels'].to(device)
        outputs = model(batch_input_ids, attention_mask=batch_mask, labels=batch_labels)

        loss = outputs.loss
        loss.backward()  # was `backwards()`, which raises AttributeError
        optim.step()

        loop.set_description(f'Epoch: {epoch}')
        loop.set_postfix(loss=loss.item())