In [1]:
%run w2d1.ipynb

In [2]:
from collections import OrderedDict
from typing import Callable, Dict, Optional, List, Tuple

import torch as t
from torch import nn
import torch.nn.functional as F
from torch import einsum
from einops import rearrange, reduce, repeat
import bert_tests
import matplotlib.pyplot as plt

## Tokenization

In [3]:
import transformers
# Load the pretrained `bert-base-cased` tokenizer; it is reused by the inference,
# IMDB fine-tuning and inspection cells further down.
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
# print(tokenizer("hello what's up"))
# uncased_tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
# print(uncased_tokenizer(["hello what's up"]))
# coded = uncased_tokenizer(["hello what's up"])
# uncased_tokenizer.batch_decode(coded['input_ids'])
# tokenizer.batch_decode(coded['input_ids'])
# uncased_tokenizer.batch_decode(coded['input_ids'])


## Inference

In [4]:
# load_pretrained_bert is not defined in the visible cells
# (presumably it comes from `%run w2d1.ipynb` — confirm). It returns a pair.
my_bert, pretrained_bert = load_pretrained_bert()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
def feed_bert(model: nn.Module, text: str, tokenizer, top_k: int = 10):
    """Print the `top_k` most likely tokens for every mask token in `text`.

    Args:
        model: Callable mapping a `(1, seq)` tensor of token ids to an object
            whose element `[0]` holds per-position vocabulary logits.
        text: Input string; it is tokenized with `tokenizer(text)`.
        tokenizer: Object providing `__call__` (returning a mapping with an
            "input_ids" list), `mask_token_id` and `decode`.
        top_k: Number of candidate tokens printed per mask position.

    Returns:
        None. The text, the decoded candidates and their probabilities are
        written to stdout.
    """
    input_ids: List[int] = tokenizer(text)["input_ids"]
    # Ask the tokenizer for its mask id rather than assuming a fixed vocabulary index.
    mask_id = tokenizer.mask_token_id
    mask_idxs = [idx for idx, token in enumerate(input_ids) if token == mask_id]

    batch = t.tensor([input_ids], dtype=t.long)
    # If the model exposes parameters, run on their device; plain callables stay on CPU.
    get_params = getattr(model, "parameters", None)
    first_param = next(iter(get_params()), None) if callable(get_params) else None
    if first_param is not None:
        batch = batch.to(first_param.device)

    # Pure inference: skip building the autograd graph.
    with t.no_grad():
        all_logits = model(batch)[0]

    print(text)
    for mask_idx in mask_idxs:
        logits = all_logits[mask_idx]
        probs = t.softmax(logits, dim=0)

        top_logit_idxs = t.argsort(logits, descending=True)[:top_k]
        top_logit_words = tokenizer.decode(top_logit_idxs)

        print(top_logit_words)
        print(probs[top_logit_idxs])
        print()

# Put the model in eval mode, then compare the predictions for the mask token
# with and without a trailing period after it.
my_bert.eval()
feed_bert(my_bert, "The fish loves to eat [MASK].", tokenizer, top_k=20)
feed_bert(my_bert, "The fish loves to eat [MASK]", tokenizer, top_k=20)
#feed_bert(my_bert, "The vegetarian fish loves to eat [MASK].", tokenizer, top_k=20)
#feed_bert(my_bert, "The meat-eating fish loves to eat [MASK].", tokenizer, top_k=20)
#feed_bert(my_bert, "The tiny fish loves to eat [MASK].", tokenizer, top_k=20)


The fish loves to eat [MASK].
it fish them meat food eggs honey insects too rice everything water vegetables this fruit apples him there again here
tensor([0.1738, 0.0980, 0.0947, 0.0410, 0.0336, 0.0251, 0.0134, 0.0130, 0.0126,
        0.0119, 0.0092, 0.0090, 0.0088, 0.0083, 0.0072, 0.0069, 0.0063, 0.0060,
        0.0058, 0.0054], grad_fn=<IndexBackward0>)

The fish loves to eat [MASK]
. ;!?..., : | and " but - so । because as [UNK]') with
tensor([9.4125e-01, 4.6098e-02, 1.1822e-02, 4.5820e-04, 1.2235e-04, 5.4506e-05,
        3.6213e-05, 1.6483e-05, 1.2279e-05, 9.2127e-06, 6.5461e-06, 4.6536e-06,
        3.4753e-06, 3.3669e-06, 2.9931e-06, 2.4598e-06, 1.9791e-06, 1.7764e-06,
        1.3952e-06, 1.0635e-06], grad_fn=<IndexBackward0>)



## Fine-tuning

In [6]:
# Run the bert_tests classification check against the `Bert` class.
# `Bert` is not defined in the visible cells (presumably from `%run w2d1.ipynb` — confirm).
bert_tests.test_bert_classification(Bert)

bert MATCH!!!!!!!!
 SHAPE (1, 4, 28996) MEAN: 0.003031 STD: 0.5765 VALS [-0.5742 -0.432 0.1186 -0.7165 -0.5261 0.4967 1.223 0.3165 -0.3247 -0.5716...]
bert MATCH!!!!!!!!
 SHAPE (1, 2) MEAN: 0.09479 STD: 1.411 VALS [-0.903 1.093]


In [7]:
def get_imdb_collate_fn(
    max_seq_length: int,
    tokenizer: transformers.AutoTokenizer,
    device: str,
):
    """Build a DataLoader `collate_fn` for `(label, text)` pairs.

    The returned function tokenizes the texts (truncated to `max_seq_length`,
    padded to the longest sequence in the batch) and encodes the labels as
    1 when the label equals "pos" and 0 otherwise. Both tensors are created
    directly on `device` with dtype `long`.
    """

    def _encode(reviews: List[str]):
        # Token ids only; the other tokenizer outputs are not used.
        encoded = tokenizer(
            reviews,
            padding="longest",
            max_length=max_seq_length,
            truncation=True,
        )
        return encoded['input_ids']

    def fn(raw_xs: List[Tuple[str, str]]) -> Tuple[t.Tensor, t.Tensor]:
        # Transpose the list of pairs into a tuple of labels and a tuple of texts.
        sentiments, reviews = zip(*raw_xs)

        token_ids = _encode(list(reviews))
        xs = t.tensor(token_ids, dtype=t.long, device=device)

        targets = [1 if sentiment == "pos" else 0 for sentiment in sentiments]
        ys = t.tensor(targets, dtype=t.long, device=device)

        return xs, ys

    return fn


In [8]:
from torch.utils.data import DataLoader
import torchtext

# Load the IMDB train/test splits and materialise them as lists so the next cell
# can draw subsets with random.sample and let DataLoader shuffle them.
data_train, data_test = torchtext.datasets.IMDB(root='.data', split=('train', 'test'))
data_train = list(data_train)
data_test = list(data_test)

In [9]:
import random

# Hard-coded to CUDA; the collate function creates each batch directly on this device.
device = "cuda"
# Sequences are truncated to at most 512 tokens.
collate_fn = get_imdb_collate_fn(512, tokenizer, device)

# 16 random training reviews served as a single batch.
# NOTE: random.sample is not seeded in this cell, so the subset differs between runs.
dl_train_small = DataLoader(
    random.sample(data_train, k=16),
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
)

# Full training split in batches of 8.
dl_train = DataLoader(
    data_train,
    batch_size=8,
    collate_fn=collate_fn,
    shuffle=True,
    # num_workers=0,
    # pin_memory=True,
)

# 256 random test reviews; used for the periodic accuracy read-out during fine-tuning.
dl_test_small = DataLoader(
    random.sample(data_test, k=256),
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
)

# Full test split in batches of 2.
dl_test = DataLoader(
    data_test,
    batch_size=2,
    collate_fn=collate_fn,
    shuffle=True,
)


In [10]:
from torch import optim
from tqdm import tqdm
import os

def get_accuracy(model: nn.Module, dl: DataLoader) -> float:
    """Return the classification accuracy of `model` over every batch in `dl`.

    Args:
        model: Callable returning a pair whose second element holds the class
            logits with the classes on the last axis.
        dl: Iterable of `(x, y)` batches, where `y` contains the target class
            indices.

    Returns:
        Fraction of correct predictions as a Python float, or 0.0 when `dl`
        yields no examples.

    Side effects:
        Leaves `model` in eval mode; callers that keep training must call
        `model.train()` afterwards.
    """
    num_correct: int = 0
    num_total: int = 0

    model.eval()

    # Evaluation needs no gradients; skipping the autograd graph saves memory.
    with t.no_grad():
        for x, y in dl:
            _, out = model(x)
            preds = t.argmax(out, dim=-1)

            # Accumulate as a plain int so the result is a real float, not a tensor.
            num_correct += int((preds == y).sum().item())
            num_total += len(y)

    if num_total == 0:
        return 0.0
    return num_correct / num_total

def finetune_bert_epoch(
    model: nn.Module,
    dl_train: DataLoader,
    dl_test: DataLoader,
    lr: float = 1e-5,
    optimizer: Optional[optim.Optimizer] = None,
) -> nn.Module:
    """Run one pass over `dl_train`, fine-tuning `model` with cross-entropy loss.

    Args:
        model: Callable returning a pair whose second element holds the class logits.
        dl_train: Iterable of `(x, y)` training batches.
        dl_test: Batches used for the accuracy shown in the progress bar.
        lr: Learning rate for the Adam optimizer created when `optimizer` is None.
        optimizer: Optional optimizer to reuse. Pass the same instance on every
            call to keep optimizer state across epochs; when omitted a fresh
            Adam is created for this call only.

    Returns:
        The same `model` object, updated in place.
    """
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=lr)

    # Make sure training starts in train mode regardless of what ran before.
    model.train()

    pbar = tqdm(enumerate(dl_train))
    for i, (x, y) in pbar:
        optimizer.zero_grad()
        _, out = model(x)
        loss = F.cross_entropy(input=out, target=y)
        loss.backward()
        optimizer.step()
        if i % 10 == 0:
            # get_accuracy switches the model to eval mode, so restore train mode after.
            acc = get_accuracy(model, dl_test)
            model.train()
            pbar.set_description(f'loss={loss.item():.2},acc={acc}')

    return model

# Reload the model with a 2-class head (rebinds `my_bert` from the inference cells above).
my_bert, _ = load_pretrained_bert(num_classes=2)
#for i, (name, p) in enumerate(my_bert.named_parameters()):
#    print(name)
#    p.cuda()

my_bert.cuda()
my_bert.train()

# print(get_accuracy(my_bert, dl_test_small))
import gc
# Release unreferenced Python objects and cached CUDA blocks before training starts.
gc.collect()
t.cuda.empty_cache()
epochs = 100
for i in range(epochs):
    print(i)
    # finetune_bert_epoch returns the model it was given, so `model` aliases `my_bert`.
    # NOTE(review): this call does not hand in an optimizer, so Adam is rebuilt
    # (and its state reset) at the start of every epoch.
    model = finetune_bert_epoch(my_bert, dl_train=dl_train, dl_test=dl_test_small)


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


0


loss=0.012,acc=0.9375: : 3125it [16:13,  3.21it/s]    


1


loss=0.34,acc=0.90234375: : 1245it [06:28,  3.20it/s]  


KeyboardInterrupt: 

In [None]:
%pdb

Automatic pdb calling has been turned ON


## Training from Scratch on Masked Language Modeling

In [23]:
# Rebinds data_train/data_test (previously the IMDB splits) to the WikiText2 train/test splits.
data_train, data_test = torchtext.datasets.WikiText2(root='.data', split=('train', 'test'))

def wiki_include(text: str) -> bool:
    """Return True when `text` contains at least one non-whitespace character.

    Intended as a filter predicate for dataset lines: blank or whitespace-only
    lines give False, everything else gives True.
    """
    return bool(text.strip())

# Materialise both splits as lists so later cells can use random.sample and shuffled DataLoaders.
# NOTE(review): wiki_include is defined above but not applied here, so no filtering happens.
data_train = list(data_train)
data_test = list(data_test)



In [129]:
def get_wiki_collate_fn(
    max_seq_length: int,
    tokenizer: "transformers.PreTrainedTokenizerBase",
    pred_frac: float,
    mask_frac: float,
    random_frac: float,
    device: str,
):
    """Build a DataLoader `collate_fn` that prepares masked-language-model batches.

    Args:
        max_seq_length: Sequences are truncated to this many tokens.
        tokenizer: Tokenizer providing `__call__`, `__len__`, `mask_token_id`
            and the pad/cls/eos/sep token ids (any of which may be None).
        pred_frac: Probability that a non-special token is selected for prediction.
        mask_frac: Share of selected tokens replaced by the mask token.
        random_frac: Share of selected tokens replaced by a random token id.
            The remaining `1 - mask_frac - random_frac` are left unchanged.
        device: Device on which the batch tensors are created.

    Returns:
        A function mapping a list of strings to `(xs, pred_mask, ys)`:
        `xs` is the `(batch, seq)` tensor of (partly corrupted) input ids,
        `pred_mask` is a boolean tensor of the same shape marking the selected
        positions, and `ys` is a 1-D tensor with the original ids at those
        positions in row-major order.
    """
    assert 0 <= pred_frac <= 1 and 0 <= mask_frac <= 1 and 0 <= random_frac <= 1
    assert 0 <= mask_frac + random_frac <= 1

    # Special tokens are never selected for prediction. Tokenizers may leave
    # some of these ids unset (None), so drop those before comparing.
    special_ids = [
        token_id
        for token_id in (
            getattr(tokenizer, "pad_token_id", None),
            getattr(tokenizer, "cls_token_id", None),
            getattr(tokenizer, "eos_token_id", None),
            getattr(tokenizer, "sep_token_id", None),
        )
        if token_id is not None
    ]

    def fn(texts: List[str]) -> Tuple[t.Tensor, t.Tensor, t.Tensor]:
        # TODO: Sample random substring of texts to have more data diversity?
        xs = t.tensor(
            tokenizer(
                list(texts),
                padding="longest",
                max_length=max_seq_length,
                truncation=True,
            )["input_ids"],
            dtype=t.long,
            device=device,
        )

        predictable = t.ones_like(xs, dtype=t.bool)
        for token_id in special_ids:
            predictable &= xs != token_id

        pred_mask = (t.rand_like(xs, dtype=t.float) < pred_frac) & predictable
        # masked_select copies, so ys keeps the original ids after xs is edited below.
        ys = t.masked_select(xs, pred_mask)

        # One uniform draw per position decides mask / random / unchanged.
        r = t.rand_like(xs, dtype=t.float)
        mask_mask = r < mask_frac
        random_mask = (mask_frac <= r) & (r < mask_frac + random_frac)

        xs[pred_mask & mask_mask] = tokenizer.mask_token_id

        random_input_ids = t.randint(
            low=0, high=len(tokenizer), size=xs.shape, dtype=t.long, device=device
        )
        xs[pred_mask & random_mask] = random_input_ids[pred_mask & random_mask]

        return xs, pred_mask, ys

    return fn


In [130]:
import random

# Seeds Python's `random` module (used by random.sample below). The torch RNG used
# inside the collate function is not seeded in this cell.
random.seed(0)

device = "cuda"
# Very short sequences (10 tokens). 15% of the non-special tokens are selected for
# prediction and, with mask_frac=1 / random_frac=0, every selected token becomes [MASK].
collate_fn = get_wiki_collate_fn(
    max_seq_length=10,
    tokenizer=transformers.AutoTokenizer.from_pretrained("bert-base-cased"),
    pred_frac=0.15,
    mask_frac=1,
    random_frac=0,
    device=device,
)

# Four random training lines served as one batch.
dl_train_small = DataLoader(
    random.sample(data_train, k=4),
    batch_size=4,
    collate_fn=collate_fn,
    shuffle=True,
)

# Full training split in batches of 16.
dl_train = DataLoader(
    data_train,
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
)

# 256 random test lines in batches of 16.
dl_test_small = DataLoader(
    random.sample(data_test, k=256),
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
)

# Full test split in batches of 16.
dl_test = DataLoader(
    data_test,
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
)


In [131]:
def get_mlm_accuracy(model: nn.Module, dl: DataLoader) -> float:
    """Return the accuracy of `model` on the positions selected for prediction.

    Args:
        model: Callable mapping a `(batch, seq)` id tensor to
            `(batch, seq, vocab)` logits.
        dl: Iterable of `(x, pred_mask, y)` batches, where `pred_mask` is a
            boolean `(batch, seq)` tensor and `y` holds the target ids of the
            selected positions in row-major order.

    Returns:
        Fraction of selected positions whose arg-max prediction equals the
        target, as a Python float; 0.0 when no position was selected.

    Side effects:
        Leaves `model` in eval mode; callers that keep training must call
        `model.train()` afterwards.
    """
    num_correct: int = 0
    num_total: int = 0

    model.eval()

    # Evaluation needs no gradients; skipping the autograd graph saves memory.
    with t.no_grad():
        for x, pred_mask, y in dl:
            if len(y) == 0:
                continue
            logits = model(x)
            # Boolean indexing over (batch, seq) keeps the vocab axis: (n_selected, vocab).
            pred_logits = logits[pred_mask]
            preds = t.argmax(pred_logits, dim=-1)

            # Accumulate as a plain int so the result is a real float, not a tensor.
            num_correct += int((preds == y).sum().item())
            num_total += len(y)

    if num_total == 0:
        return 0.0
    return num_correct / num_total

def mlm_epoch(
    model: nn.Module,
    dl_train: DataLoader,
    dl_test: DataLoader,
    lr: float,
    opbar=None,
    optimizer: Optional[optim.Optimizer] = None,
) -> nn.Module:
    """Run one pass over `dl_train`, training `model` on masked-token prediction.

    Args:
        model: Callable mapping a `(batch, seq)` id tensor to
            `(batch, seq, vocab)` logits.
        dl_train: Iterable of `(x, pred_mask, y)` training batches.
        dl_test: Batches used for the accuracy shown in the progress bars.
        lr: Learning rate for the Adam optimizer created when `optimizer` is None.
        opbar: Optional outer progress bar; when given, it receives the same
            loss/accuracy description as the inner bar.
        optimizer: Optional optimizer to reuse. Pass the same instance on every
            call to keep optimizer state across calls; when omitted a fresh
            Adam is created for this call only.

    Returns:
        The same `model` object, updated in place.
    """
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=lr)

    # Make sure training starts in train mode regardless of what ran before.
    model.train()

    pbar = tqdm(enumerate(dl_train), leave=False)
    for i, (x, pred_mask, y) in pbar:
        # A batch in which no token was selected has nothing to learn from.
        if len(y) == 0:
            continue
        optimizer.zero_grad()
        logits = model(x)
        # Boolean indexing over (batch, seq) keeps the vocab axis: (n_selected, vocab).
        pred_logits = logits[pred_mask]
        loss = F.cross_entropy(input=pred_logits, target=y)
        loss.backward()
        optimizer.step()
        if i % 10 == 0:
            # Evaluate once, share the text between both bars, and restore train
            # mode right away (get_mlm_accuracy leaves the model in eval mode).
            acc = get_mlm_accuracy(model, dl_test)
            model.train()
            status = f'loss={loss.item():.3f},acc={acc:.3f}'
            pbar.set_description(status)
            if opbar is not None:
                opbar.set_description(status)

    return model

In [140]:
# Small 2-layer model to train from scratch. vocab_size matches the last dimension
# (28996) of the logits printed by the bert_tests output above.
# NOTE(review): hidden_size=256 is not divisible by num_heads=12; how `Bert` splits
# the hidden dimension across heads is not visible here — confirm this is intended.
tiny_bert = Bert(
    vocab_size=28996,
    hidden_size=256,
    max_position_embeddings=512,
    type_vocab_size=2,
    dropout=0,#0.1,
    intermediate_size=1024,
    num_heads=12,
    num_layers=2,
)
tiny_bert.to(device)
# Trailing None keeps the cell from displaying the value returned by `.to(device)`.
None

In [141]:
# Overfitting sanity check: the same 4-example loader is used for training and evaluation.
tiny_bert.train()
epochs = 1000
pbar = tqdm(range(epochs))
for i in pbar:
    # print(i)
    # NOTE(review): mlm_epoch builds a new Adam optimizer on every call here, and
    # dl_train_small yields a single batch, so optimizer state never carries over
    # from one step to the next.
    mlm_epoch(tiny_bert, dl_train_small, dl_train_small, 1e-4, pbar)

loss=3.330,acc=0.000: 100%|██████████| 1000/1000 [00:49<00:00, 20.28it/s]


In [133]:
# feed_bert(tiny_bert, "The fish loves to eat [MASK].", tokenizer, top_k=20)

In [149]:
def feed_bert(model: nn.Module, text: str, tokenizer, top_k: int = 10):
    """Print the `top_k` most likely tokens for every mask token in `text`.

    Args:
        model: Callable mapping a `(1, seq)` tensor of token ids to an object
            whose element `[0]` holds per-position vocabulary logits. It may be
            an `nn.Module` or any plain callable (e.g. a lambda).
        text: Input string; it is tokenized with `tokenizer(text)`.
        tokenizer: Object providing `__call__` (returning a mapping with an
            "input_ids" list), `mask_token_id` and `decode`.
        top_k: Number of candidate tokens printed per mask position.

    Returns:
        None. The mask positions, the text, the decoded candidates and their
        probabilities (9 decimal places) are written to stdout.
    """
    input_ids: List[int] = tokenizer(text)["input_ids"]
    mask_idxs = [idx for idx, token in enumerate(input_ids) if token == tokenizer.mask_token_id]
    print(mask_idxs)

    batch = t.tensor([input_ids], dtype=t.long)
    # If the model exposes parameters, run on their device; plain callables stay on CPU.
    get_params = getattr(model, "parameters", None)
    first_param = next(iter(get_params()), None) if callable(get_params) else None
    if first_param is not None:
        batch = batch.to(first_param.device)

    # Pure inference: skip building the autograd graph.
    with t.no_grad():
        all_logits = model(batch)[0]

    print(text)
    # Use high precision while printing probabilities and always restore the
    # precision of 4 afterwards, even if decoding or printing fails.
    t.set_printoptions(precision=9)
    try:
        for mask_idx in mask_idxs:
            logits = all_logits[mask_idx]
            probs = t.softmax(logits, dim=0)

            top_logit_idxs = t.argsort(logits, descending=True)[:top_k]
            top_logit_words = tokenizer.decode(top_logit_idxs)

            print(top_logit_words)
            print(probs[top_logit_idxs])
            print()
    finally:
        t.set_printoptions(precision=4)

In [178]:
# Inspect what tiny_bert predicts on the small training batch.
tiny_bert.eval()
for (x, mask, y) in dl_train_small:
    print(tokenizer.batch_decode(x))
    print(tokenizer.decode(y))
    print("*****************************")
    logits = tiny_bert(x)
    #print(logits.shape)
    # Sort ascending over the last axis and keep the final 10 values, i.e. the ten
    # largest logits, for the first three positions of every sequence.
    print(t.sort(logits, dim=-1).values[:, :3, -10:])
    print("-----------------------")
    for xx in x:
        # NOTE(review): the decoded string already contains the special tokens and
        # feed_bert tokenizes it again, so the sequence appears to get an extra
        # [CLS]/[SEP] pair — confirm this is acceptable for a quick inspection.
        feed_bert(tiny_bert, tokenizer.decode(xx), tokenizer, top_k=5)
        print("---------------------------------------")


['[CLS] A kits [MASK] [MASK] take [MASK] human form [SEP]', '[CLS] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]', '[CLS] They [MASK] in the final of the 2012 [SEP]', '[CLS] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']
##une may on met
*****************************
tensor([[[-11.6676, -11.5691, -11.5570, -11.4961, -11.3019, -11.2182, -11.0608,
          -10.8353, -10.2363,  -9.4051],
         [-11.6676, -11.5691, -11.5570, -11.4961, -11.3019, -11.2182, -11.0608,
          -10.8353, -10.2363,  -9.4051],
         [-11.6676, -11.5691, -11.5570, -11.4961, -11.3019, -11.2182, -11.0608,
          -10.8353, -10.2363,  -9.4051]],

        [[-11.6676, -11.5691, -11.5570, -11.4961, -11.3019, -11.2182, -11.0608,
          -10.8353, -10.2363,  -9.4051],
         [-11.6676, -11.5691, -11.5570, -11.4961, -11.3019, -11.2182, -11.0608,
          -10.8353, -10.2363,  -9.4051],
         [-11.6676, -11.5691, -11.5570, -11.4961, -11.3019, -11.2182, -11.0608,
          -10.8353, -10.2363

In [None]:
# Load a pretrained model (2-class head) as a reference for comparison with tiny_bert.
test_bert, _ = load_pretrained_bert(num_classes=2)

In [184]:
# Same inspection as for tiny_bert, but with the pretrained reference model.
test_bert.eval()
for (x, mask, y) in dl_train_small:
    print(tokenizer.batch_decode(x))
    print(tokenizer.decode(y))
    print("*****************************")
    # test_bert returns a pair; only the first element is bound to `logits`
    # (and it is not used further in this cell).
    logits, _ = test_bert(x)
    #print(logits.shape)
    # print(t.sort(logits, dim=-1).values[:, :3, -10:])
    print("-----------------------")
    for xx in x:
        # The lambda adapts test_bert's pair output to the single value feed_bert indexes.
        feed_bert(lambda x : test_bert(x)[0], tokenizer.decode(xx), tokenizer, top_k=5)
        print("---------------------------------------")


['[CLS] [MASK] kits [MASK] may take on [MASK] form [SEP]', '[CLS] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]', '[CLS] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]', '[CLS] They met in the [MASK] of the 2012 [SEP]']
Aune human final
*****************************
-----------------------
[2, 4, 8]
[CLS] [MASK] kits [MASK] may take on [MASK] form [SEP]
The These All For From
tensor([0.243836567, 0.067158222, 0.028144445, 0.018736543, 0.018527528],
       grad_fn=<IndexBackward0>)

also that which and,
tensor([0.172607183, 0.137166083, 0.128834605, 0.048280545, 0.035198011],
       grad_fn=<IndexBackward0>)

any this the kit a
tensor([0.408585191, 0.173036709, 0.076863118, 0.069941096, 0.029075900],
       grad_fn=<IndexBackward0>)

---------------------------------------
[]
[CLS] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
---------------------------------------
[]
[CLS] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
------------------------------------

In [162]:
# Display the weight matrix of tiny_bert's LM-head unembedding layer.
tiny_bert.lm_head.unembedding.weight

Parameter containing:
tensor([[-0.0390,  0.1328, -0.0412,  ...,  0.1419,  0.0393,  0.0362],
        [-0.0796,  0.1322, -0.1050,  ...,  0.1515,  0.0561,  0.0403],
        [-0.0453,  0.0650, -0.0387,  ...,  0.0973,  0.1245,  0.0768],
        ...,
        [-0.1224,  0.1273, -0.0526,  ...,  0.1105,  0.0831,  0.0459],
        [-0.0134,  0.0756, -0.1486,  ...,  0.1267,  0.0428,  0.1165],
        [-0.0816,  0.1368, -0.0750,  ...,  0.0707,  0.1486,  0.1435]],
       device='cuda:0', requires_grad=True)

In [163]:
# Display the bias vector of tiny_bert's LM-head unembedding layer.
tiny_bert.lm_head.unembedding.bias

Parameter containing:
tensor([-0.0315, -0.0883, -0.0943,  ..., -0.0812, -0.0664, -0.1027],
       device='cuda:0', requires_grad=True)