In [None]:
import sys
# Extend the import search path with the parent directory; the `src.*` imports
# below presumably resolve through it (assumes the notebook is run from a
# sub-folder of the project root — TODO confirm).
sys.path.append("../")
import transformers
# Lower the transformers logger to error level so informational/warning
# messages do not flood the notebook output.
transformers.logging.set_verbosity_error()

from functools import partial
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Project-local data utilities used by the cells that follow.
from src.data.dataio import DataFiles, Dataset, remove_empty_fn, truncate_fn

In [None]:
# Name of the pretrained checkpoint the tokenizer is loaded from.
PRETRAINED_MODEL = 'distilroberta-base'

# Source files are resolved from the URL list shipped with the project data.
data_files = DataFiles.from_url_file(url_file="../data/books.txt")

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=PRETRAINED_MODEL)

# Pre-bind the tokenizer and length settings so the function can be passed to
# `Dataset.map` as a single-argument callable.
# NOTE(review): presumably this cuts/fills each sample to 20 tokens — confirm
# against `truncate_fn`.
truncate_to_length = partial(
    truncate_fn,
    tokenizer=tokenizer,
    max_seq_length=20,
    fill_to_max=True,
)

# Apply the two mapping stages in sequence: `remove_empty_fn`, then truncation.
dataset = Dataset(data_files).map(remove_empty_fn).map(truncate_to_length)

# Sanity check: show the first six samples. `zip` stops once `range` runs out,
# so no additional sample is drawn from the dataset.
for _, sample in zip(range(6), dataset):
    print(sample)

In [None]:
import torch
from torch.utils.data import DataLoader
# `transformers.AdamW` is deprecated (and absent from recent transformers
# releases); the PyTorch implementation is the documented replacement.
from torch.optim import AdamW
from transformers.data.data_collator import DataCollatorForLanguageModeling

# Collator used for masked-language-model batches; 25% masking probability.
collator = DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.25)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path=PRETRAINED_MODEL)
model.to(device)
# Training mode (dropout active). Call `model.eval()` before running inference.
model.train()

loader = DataLoader(dataset, batch_size=4)

# `eps` and `weight_decay` are pinned to the defaults of the former
# `transformers.AdamW` so the optimizer behaves as it did before the switch
# (torch's own defaults are eps=1e-8, weight_decay=1e-2).
optim = AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

# Fine-tuning loop, currently disabled.
# NOTE(review): `labels` is taken from the unmasked `input_ids` and the collator
# is fed an already-batched encoding — verify before re-enabling.
# for epoch in range(1):
#     for i, batch in enumerate(loader):
#         optim.zero_grad()
#         batch = tokenizer(batch["text"], truncation=True, padding=True, return_special_tokens_mask=True, return_tensors="pt")
#         batch = batch.to(device)
#         attention_mask = batch["attention_mask"]
#         labels = batch['input_ids']

#         batch = collator(features=(batch,))
#         input_ids = batch["input_ids"].squeeze(0)
#         outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
#         loss = outputs.loss
#         print(loss)
#         loss.backward()
#         optim.step()

# model.eval()

In [None]:
import copy
import random
from abc import abstractmethod
from typing import List, Tuple
from collections import Counter


In [None]:
from src.models.masking import RandomMask, LengthBasedMask


def preview_masking(mask_fn, samples, limit=6):
    """Print the first `limit` samples before and after masking.

    Args:
        mask_fn: Callable applied to a one-element batch, i.e. a list that
            holds a single list of whitespace-split tokens.
        samples: Iterable of records exposing a 'text' entry.
        limit: Maximum number of samples to display.
    """
    # `zip` stops once `range` is exhausted, so at most `limit` samples are
    # pulled from `samples`.
    for _, sample in zip(range(limit), samples):
        input_string = [sample['text'].split()]
        print(input_string)
        print(mask_fn(input_string))


# NOTE(review): 0.5 is presumably the masking probability and '<mask>' the
# replacement token — confirm against `src.models.masking`.
randomMaskInstance = RandomMask(0.5, '<mask>')
randomMask = randomMaskInstance.mask
preview_masking(randomMask, dataset)

lengthMaskInstance = LengthBasedMask(0.5, 'all', '<mask>')
lengthMask = lengthMaskInstance.mask
preview_masking(lengthMask, dataset)

In [None]:
import torch

# Inference needs dropout disabled; otherwise the predictions change from run
# to run because the model was left in training mode.
model.eval()

for i, x in enumerate(dataset):
    input_string = [x['text'].split()]
    print('input_string is', input_string)
    masked_tokens = lengthMask(input_string)[0]
    print('masked_tokens is', masked_tokens)
    masked_sentence = ' '.join(masked_tokens)
    print('masked_sentence is', masked_sentence)

    # Tokenize and move the tensors to the model's device; feeding CPU tensors
    # to a CUDA-resident model raises a device-mismatch error.
    encoded = tokenizer([masked_sentence], return_tensors="pt").to(model.device)

    # No gradients are needed for prediction.
    with torch.no_grad():
        logits = model(**encoded)["logits"]

    # argmax over the vocabulary axis; a softmax beforehand would not change
    # which index is largest, so it is omitted.
    out = torch.argmax(logits, dim=-1)
    print('output is', tokenizer.batch_decode(out))
    if i >= 5:
        break

In [None]:
import torch

# Disable dropout so the prediction is deterministic.
model.eval()

# Tokenize the example and place it on the same device as the model to avoid a
# CPU/CUDA tensor mismatch.
encoded = tokenizer(
    ["Montreal is a <mask> city, but Toronto is <mask>."],
    return_tensors="pt",
).to(model.device)

# Gradients are not needed for inference.
with torch.no_grad():
    logits = model(**encoded)["logits"]

# Most likely token id at every position (softmax is unnecessary for argmax).
out = torch.argmax(logits, dim=-1)
tokenizer.batch_decode(out)