In [None]:
%run bert.ipynb

In [3]:
import transformers
from tqdm.notebook import tqdm
# Module-level tokenizer for the pretrained "bert-base-cased" vocabulary; the
# training loop and the probing helper further down read this global directly.
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

In [57]:
def make_bert():
    """Construct a fresh, untrained ``Bert`` with this notebook's fixed small configuration.

    ``Bert`` is not defined in this notebook (presumably it comes from
    ``%run bert.ipynb`` -- confirm). The configuration is 2 layers, 6 heads,
    hidden size 384, intermediate size 1536, dropout 0.1 and 2 output classes.

    Returns:
        A new ``Bert`` instance.
    """
    hyperparameters = dict(
        # presumably the size of the "bert-base-cased" vocabulary -- TODO confirm
        vocab_size=28996,
        hidden_size=384,
        max_position_embeddings=512,
        type_vocab_size=2,
        dropout=0.1,
        intermediate_size=1536,
        num_heads=6,
        num_layers=2,
        num_classes=2,
    )
    return Bert(**hyperparameters)

In [5]:
import torchtext
import random
from torch.utils.data import DataLoader
# Load the three WikiText2 splits; root='.data' is the directory handed to
# torchtext for the dataset files.
data_train, data_valid, data_test = torchtext.datasets.WikiText2(root='.data', split=('train', 'valid', 'test'))

# Materialise the training split into a list (presumably so that DataLoader
# can shuffle it and report a length -- confirm against the torchtext version).
data_train_list = list(data_train)

train_dataloader = DataLoader(data_train_list, batch_size=16, shuffle=True)
test_dataloader = DataLoader(data_test, batch_size=16)
# Validation reads the held-out validation split; previously this was built
# from data_test, which left data_valid unused and duplicated the test loader.
valid_dataloader = DataLoader(data_valid, batch_size=16)

In [58]:
# Build the model, then move its parameters onto the GPU as a separate step.
bert = make_bert()
bert = bert.cuda()
# Loss callable that is passed to train() further down in the notebook.
lossfn = nn.CrossEntropyLoss()

In [59]:
# Step size used by the Adam optimizer below.
lr = 1e-4
optimizer = torch.optim.Adam(params=bert.parameters(), lr=lr)

In [None]:
import matplotlib.pyplot as plt
from IPython import display
import datetime
def train(model, optimizer, data, lossfn, epochs=1, max_seq_len=512,
          mask_prob=0.15, log_every=100, save_every=5000,
          save_dir="saved_tiny_bert"):
    """Train ``model`` on a masked-token prediction objective.

    Each batch of raw strings is tokenized with the module-level ``tokenizer``.
    A random subset of the non-padding positions is replaced by the
    tokenizer's mask token, and ``lossfn`` is evaluated only at those
    positions against the original token ids.

    Args:
        model: Module whose call returns a pair ``(output, classifications)``.
            ``output`` is indexed with a ``(batch, seq)`` boolean mask, so it
            is assumed to hold per-position vocabulary logits -- confirm
            against the model definition.
        optimizer: Optimizer stepping ``model``'s parameters.
        data: Iterable yielding batches of strings accepted by ``tokenizer``.
        lossfn: Callable ``(logits, target_ids)`` returning a scalar loss tensor.
        epochs: Number of passes over ``data``.
        max_seq_len: Token sequences are truncated to this length.
        mask_prob: Per-position threshold for replacing a token by the mask token.
        log_every: Every this many optimizer steps, the mean loss since the
            previous report is printed and the loss curve is redrawn.
        save_every: On a reporting step whose step count is a multiple of this
            value, the whole model is written with ``torch.save``. It is only
            checked on reporting steps, so it should be a multiple of
            ``log_every``.
        save_dir: Directory for checkpoints; created if it does not exist.

    Returns:
        None. Progress is reported through printed text and a matplotlib
        figure; the model is left in training mode.
    """
    import os  # local import: only needed to create and join the checkpoint path

    model.train()
    # Follow the device the model already lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    pad_id = tokenizer.pad_token_id
    mask_id = tokenizer.mask_token_id

    avg_losses = []
    loss_buffer = []
    i = 0  # number of optimizer steps taken so far, across all epochs
    for epoch in range(epochs):
        for X in tqdm(data):
            # token processing
            tokens = tokenizer(X, padding='longest', max_length=max_seq_len, truncation=True)
            # the original input
            unmasked_tokens = torch.tensor(tokens.input_ids, dtype=torch.long, device=device)

            # Choose the positions to hide; padding is never selected.
            is_padding = unmasked_tokens == pad_id
            selected = torch.rand(unmasked_tokens.shape, device=device) <= mask_prob
            masked_mask = selected & ~is_padding
            if not masked_mask.any():
                # Nothing was masked (e.g. a batch of very short lines). The loss
                # over an empty selection is not a number, so skip the batch; it
                # does not count as a step.
                continue
            masked_tokens = unmasked_tokens.masked_fill(masked_mask, mask_id)

            optimizer.zero_grad()
            output, _classifications = model(masked_tokens)
            expected_output_at_masks = unmasked_tokens[masked_mask]
            unnormed_probs_at_masks = output[masked_mask]
            loss = lossfn(unnormed_probs_at_masks, expected_output_at_masks)

            loss.backward()
            optimizer.step()
            # Keep a plain Python float so that averaging, plotting and the
            # checkpoint filename work with numbers rather than device tensors.
            loss_buffer.append(loss.item())
            i += 1
            if i % log_every == 0:
                avg_loss = sum(loss_buffer) / len(loss_buffer)
                avg_losses.append(avg_loss)
                loss_buffer.clear()
                if i % save_every == 0:
                    time_string = datetime.datetime.now().strftime('%y%m%d-%H-%M')
                    os.makedirs(save_dir, exist_ok=True)
                    # Save the model being trained (the argument), not a global.
                    torch.save(model, os.path.join(save_dir, f"{time_string}_loss{avg_loss:.4f}.pt"))

                print(f"{i=} {epoch=} loss={avg_loss}")
                # Clear the figure first so each refresh draws a single curve
                # instead of stacking a new line on top of every earlier one.
                plt.clf()
                plt.plot(avg_losses)
                display.display(plt.gcf())
                display.clear_output(wait=True)
                             

# Run the masked-token training loop for 100 passes over the shuffled training loader.
train(bert, optimizer, train_dataloader, lossfn, epochs=100)

In [92]:
def ascii_art_probs(model, sentence, k=5):
    """Print the ``k`` most probable tokens for every mask token in ``sentence``.

    The sentence is encoded with the module-level ``tokenizer`` and run through
    ``model`` as a batch of one. For each position holding the tokenizer's mask
    token, one line is printed: a list of ``(token, probability)`` pairs, most
    probable first, with probabilities rounded to two decimals.

    Args:
        model: Module whose call returns ``(output, classifications)``;
            ``output`` is assumed to be ``(1, seq, vocab)`` logits for a
            single-sentence batch -- confirm against the model definition.
        sentence: Text containing the tokenizer's mask token (e.g. ``[MASK]``).
        k: Number of candidates to print per masked position.

    Returns:
        None. Results are printed. The model is switched to eval mode and
        left there.
    """
    model.eval()
    # Run on the device of the model that was passed in, not a global one.
    device = next(model.parameters()).device
    token_ids = torch.tensor([tokenizer.encode(sentence)], dtype=torch.long, device=device)
    # Locate mask positions using the tokenizer's own id rather than a literal.
    mask_positions = (token_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        unnormalized_output, _classifications = model(token_ids)

    probs_at_masks = torch.softmax(unnormalized_output[0, mask_positions], dim=-1)
    top_k = torch.topk(probs_at_masks, k, dim=-1)
    for candidate_ids, candidate_probs in zip(top_k.indices, top_k.values):
        candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_ids.tolist())
        rounded_probs = [round(p, 2) for p in candidate_probs.tolist()]
        print(list(zip(candidate_tokens, rounded_probs)))

# Probe the model with a feedback-form style prompt. The adjacent string
# literals are joined into one sentence by implicit concatenation.
ascii_art_probs(
    bert,
    (
        'MLAB stands for "Machine [MASK] Alignment Bootcamp." '
        "Here are my [MASK] to the feedback form: "
        "My name is [MASK]. "
        "The best part about MLAB so far has been [MASK]. "
        "The worst part about MLAB so far has been [MASK]. "
        "What I would personally like to get out of MLAB is [MASK]. "
        "Overall, MLAB is [MASK]. "
        "You could make MLAB better by [MASK] [MASK]. "
        "I would prefer that we spend [MASK] time on lecture. "
        "My interactions with the teaching assistants has been [MASK]."
    ),
)

[('the', 0.14), ('"', 0.1), ('a', 0.06), ('that', 0.04), ('in', 0.03)]
[('referred', 0.08), ('due', 0.04), ('able', 0.02), ('similar', 0.01), ('according', 0.01)]
[('possible', 0.02), ('a', 0.02), ('well', 0.02), ('not', 0.01), ('"', 0.01)]
[('known', 0.04), ('well', 0.01), ('seen', 0.01), ('released', 0.01), ('him', 0.01)]
[('used', 0.01), ('written', 0.01), ('found', 0.01), ('able', 0.01), ('based', 0.01)]
[('Mr', 0.02), ('a', 0.01), ('Dr', 0.01), ('the', 0.01), ('possible', 0.01)]
[('St', 0.02), ('the', 0.01), ('No', 0.01), ('a', 0.01), ('it', 0.01)]
[('the', 0.45), ('his', 0.11), ('a', 0.05), ('her', 0.03), ('their', 0.03)]
[('"', 0.06), ('time', 0.05), ('years', 0.05), ('critics', 0.01), ('.', 0.01)]
[('the', 0.12), ('a', 0.05), ('this', 0.04), ('to', 0.03), ('his', 0.02)]
[('released', 0.01), ('used', 0.01), ('described', 0.01), ('written', 0.01), ('available', 0.01)]
[('The', 0.16), ('He', 0.12), ('It', 0.08), ('This', 0.05), ('In', 0.04)]
[('the', 0.12), (',', 0.04), ('in', 0.0