In [None]:
import torch as t
import torch.nn as nn
from torch import einsum
from einops import rearrange, reduce, repeat
import bert_tests

In [None]:
def raw_attention_pattern(token_activations, num_heads, project_query, project_key):
    """Pre-softmax scaled dot-product attention scores for every head.

    Args:
        token_activations: input fed to both projections.
        num_heads: number of heads; the projected feature dim is split into
            ``num_heads`` equal chunks of size ``head_dim``.
        project_query, project_key: callables returning (b, l, num_heads * head_dim).

    Returns:
        Tensor (b, num_heads, l, l) whose entry [b, h, k, q] is
        key_k . query_q / sqrt(head_dim): key positions index dim -2 and
        query positions index dim -1.
    """
    def split_heads(projected):
        # (b, l, h*p) -> (b, h, l, p)
        return rearrange(projected, 'b l (h p) -> b h l p', h=num_heads)

    query_proj = project_query(token_activations)
    key_proj = project_key(token_activations)
    head_dim = key_proj.shape[-1] // num_heads
    key_heads = split_heads(key_proj)
    query_heads = split_heads(query_proj)
    scores = t.einsum('b h l p, b h m p -> b h l m', key_heads, query_heads)
    return scores / t.sqrt(t.tensor(head_dim))
    

In [None]:
# Compare raw_attention_pattern with the reference in the external bert_tests module; prints "MATCH" on success.
bert_tests.test_attention_pattern_fn(raw_attention_pattern)

attention pattern raw MATCH!!!!!!!!
 SHAPE (2, 12, 3, 3) MEAN: 0.01012 STD: 0.1032 VALS [-0.08612 0.01278 -0.009718 -0.2377 0.02676 0.1858 -0.05701 -0.1389 0.07155 -0.07107...]


In [None]:
def bert_attention(token_activations, num_heads, attention_pattern, project_value, project_output):
    """Apply a raw attention pattern to the projected values and merge the heads.

    Args:
        token_activations: input to ``project_value``.
        num_heads: number of heads the value projection is split into.
        attention_pattern: raw scores (b, num_heads, key_pos, query_pos), as
            produced by raw_attention_pattern; softmax is taken over key positions.
        project_value, project_output: value and output projection callables.

    Returns:
        ``project_output`` applied to the concatenated per-head results,
        laid out as (b, query_pos, num_heads * head_dim) before projection.
    """
    values = project_value(token_activations)
    # Softmax over dim -2: for each query, the weights over key positions sum to 1.
    weights = t.nn.functional.softmax(attention_pattern, dim=-2)
    value_heads = rearrange(values, 'b l (h p) -> b h l p', h=num_heads)
    per_head = t.einsum('b h l m, b h l p -> b h m p', weights, value_heads)
    merged = rearrange(per_head, 'b h m p -> b m (h p)')
    return project_output(merged)

In [None]:
# Compare bert_attention with the reference in bert_tests; prints "MATCH" on success.
bert_tests.test_attention_fn(bert_attention)

attention MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: 0.001565 STD: 0.1138 VALS [0.2294 0.06196 -0.05333 0.0651 -0.1487 0.02752 -0.02764 -0.07989 -0.252 -0.1724...]


In [None]:
class MultiHeadedSelfAttention(nn.Module):
    """Multi-head self-attention assembled from raw_attention_pattern and bert_attention.

    Submodule names (query, keys, project_value, project_out) are the state_dict
    keys that the pretrained-weight renaming in this notebook maps onto.
    """

    def __init__(self, num_heads, hidden_size):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        # Creation order is kept: it fixes the order the layers draw from the RNG at init.
        self.query = nn.Linear(hidden_size, hidden_size)
        self.keys = nn.Linear(hidden_size, hidden_size)
        self.project_value = nn.Linear(hidden_size, hidden_size)
        self.project_out = nn.Linear(hidden_size, hidden_size)

    def forward(self, token_activations):
        """Self-attend over ``token_activations``; the last dim stays ``hidden_size``."""
        # nn.Linear modules are already callables, so they are passed directly.
        pattern = raw_attention_pattern(token_activations, self.num_heads, self.query, self.keys)
        return bert_attention(token_activations, self.num_heads, pattern, self.project_value, self.project_out)

In [None]:
# Check the full attention module; bert_tests receives the class itself (presumably instantiating it internally).
bert_tests.test_bert_attention(MultiHeadedSelfAttention)

bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001554 STD: 0.1736 VALS [-0.08316 -0.09165 -0.03188 -0.03013 0.1001 0.09549 -0.1046 0.07742 0.0424 0.05553...]


In [None]:
def bert_mlp(token_activations, linear_1, linear_2):
    """BERT feed-forward block: ``linear_2(gelu(linear_1(token_activations)))``.

    ``linear_1`` and ``linear_2`` may be any callables (typically nn.Linear layers).
    """
    hidden = linear_1(token_activations)
    activated = nn.functional.gelu(hidden)
    return linear_2(activated)

In [None]:
# Compare bert_mlp with the reference in bert_tests; prints "MATCH" on success.
bert_tests.test_bert_mlp(bert_mlp)

bert mlp MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.0001934 STD: 0.1044 VALS [-0.1153 0.1189 -0.0813 0.1021 0.0296 0.06182 0.0341 0.1446 0.2622 -0.08507...]


In [None]:
class BertMLP(nn.Module):
    """Position-wise feed-forward sublayer: Linear(in -> intermediate), GELU, Linear(intermediate -> in)."""

    def __init__(self, input_size, intermediate_size):
        super().__init__()
        # linear1/linear2 are the names the pretrained-key renaming produces ("mlp.linear1", "mlp.linear2").
        self.linear1 = nn.Linear(input_size, intermediate_size)
        self.linear2 = nn.Linear(intermediate_size, input_size)

    def forward(self, x):
        return bert_mlp(x, linear_1=self.linear1, linear_2=self.linear2)

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learned affine transform.

    y = (x - mean) / sqrt(var + eps) * weight + bias, where mean and the biased
    variance are taken over the last dimension. Gradients flow through both
    statistics, as layer normalization requires.

    Args:
        normalized_dim: size of the input's last dimension.
        eps: added to the variance so constant rows do not divide by zero. The
            default 1e-12 is presumably BERT's ``layer_norm_eps`` (confirm); it
            is small enough that outputs match the eps-free formula to rounding.
    """

    def __init__(self, normalized_dim: int, eps: float = 1e-12):
        super().__init__()
        self.weight = nn.Parameter(t.ones(normalized_dim))
        self.bias = nn.Parameter(t.zeros(normalized_dim))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        var = (centered * centered).mean(dim=-1, keepdim=True)
        # mean/var must NOT be detached: detaching leaves the forward pass unchanged
        # but silently produces wrong gradients during training.
        normed = centered / t.sqrt(var + self.eps)
        return normed * self.weight + self.bias

In [None]:
# Compare LayerNorm with the reference in bert_tests (the class is passed, not an instance); prints "MATCH" on success.
bert_tests.test_layer_norm(LayerNorm)

layer norm MATCH!!!!!!!!
 SHAPE (20, 10) MEAN: -1.431e-08 STD: 1.003 VALS [0.6906 -0.84 1.881 1.711 -0.5117 -0.9577 -0.1387 -0.6943 -0.6741 -0.4662...]


In [None]:
class BertBlock(nn.Module):
    """One post-LayerNorm transformer encoder layer.

    h = layer_norm(attention(x) + x); out = ln2(dropout(mlp(h)) + h).
    Dropout is applied only on the MLP branch.
    """

    def __init__(self, hidden_size, intermediate_size, num_heads, dropout):
        super().__init__()
        # Creation order is kept so parameter initialization consumes the RNG identically.
        self.attention = MultiHeadedSelfAttention(num_heads, hidden_size)
        self.layer_norm = LayerNorm(hidden_size)
        self.mlp = BertMLP(hidden_size, intermediate_size)
        self.ln2 = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attended = self.layer_norm(self.attention(x) + x)
        return self.ln2(self.dropout(self.mlp(attended)) + attended)

In [None]:
# Compare BertBlock with the reference in bert_tests; prints "MATCH" on success.
bert_tests.test_bert_block(BertBlock)

bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -4.139e-09 STD: 1 VALS [0.007131 -0.04372 0.6502 -0.5972 -1.097 0.7267 0.1275 -0.6035 -0.2226 0.2145...]


In [None]:
class Embedding(nn.Module):
    """Minimal lookup table mapping integer ids to learned vectors."""

    def __init__(self, vocab_size, embed_size):
        super().__init__()
        # One row per id, initialized from a standard normal.
        self.weight = nn.Parameter(t.randn(vocab_size, embed_size))

    def forward(self, x):
        """Return the rows of ``weight`` selected by ``x`` (float ids are cast to long)."""
        indices = x.long()
        return self.weight[indices]

In [None]:
# Compare Embedding with the reference in bert_tests; prints "MATCH" on success.
bert_tests.test_embedding(Embedding)

embedding MATCH!!!!!!!!
 SHAPE (2, 3, 5) MEAN: -0.06748 STD: 1.062 VALS [1.176 -0.1914 0.8212 1.047 -0.481 0.7106 -1.304 -1.307 -0.438 -0.2764...]


In [None]:
def bert_embedding(input_ids, token_type_ids, position_embedding, token_embedding, token_type_embedding, layer_norm, dropout):
    """Sum position, token and token-type embeddings, then apply layer_norm and dropout.

    Args:
        input_ids: token ids indexed as (batch, seq); seq = input_ids.shape[1].
        token_type_ids: segment ids with the same shape as input_ids.
        position_embedding: callable mapping positions 0..seq-1 to vectors; it
            must have at least seq rows.
        token_embedding, token_type_embedding: id -> vector callables.
        layer_norm, dropout: applied in that order to the summed embeddings.
    """
    # Create positions directly on input_ids' device. A "cuda"/"cpu" string would
    # lose the device index (e.g. cuda:1) and mismatch the embedding weights.
    positions = t.arange(input_ids.shape[1], device=input_ids.device)
    pos_emb = position_embedding(positions)
    tok_emb = token_embedding(input_ids)
    typ_emb = token_type_embedding(token_type_ids)
    # (seq, d) position embeddings broadcast over the batch dimension.
    return dropout(layer_norm(pos_emb + tok_emb + typ_emb))

In [None]:
# Compare bert_embedding with the reference in bert_tests; prints "MATCH" on success.
bert_tests.test_bert_embedding_fn(bert_embedding)

bert embedding MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: 4.967e-09 STD: 1 VALS [-1.319 -0.4378 -2.074 0.9679 0.9274 1.479 -0.501 -1.9 -0.212 0.7961...]


In [None]:
class BertEmbedding(nn.Module):
    """Token + position + token-type embeddings followed by LayerNorm and dropout."""

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout):
        super().__init__()
        # Creation order is kept so parameter initialization consumes the RNG identically.
        self.token_embedding = Embedding(vocab_size, hidden_size)
        self.position_embedding = Embedding(max_position_embeddings, hidden_size)
        self.token_type_embedding = Embedding(type_vocab_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = LayerNorm(hidden_size)

    def forward(self, input_ids, token_type_ids):
        return bert_embedding(
            input_ids,
            token_type_ids,
            position_embedding=self.position_embedding,
            token_embedding=self.token_embedding,
            token_type_embedding=self.token_type_embedding,
            layer_norm=self.layer_norm,
            dropout=self.dropout,
        )

In [None]:
# Compare the BertEmbedding module with the reference in bert_tests; prints "MATCH" on success.
bert_tests.test_bert_embedding(BertEmbedding)

bert embedding MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -4.553e-09 STD: 1 VALS [-0.009385 -0.4919 0.9852 -0.3535 -3.624 1.333 1.163 1.449 1.063 0.246...]


In [None]:
class Bert(nn.Module):
    """BERT encoder with a masked-language-model head.

    embedding -> num_layers BertBlocks -> Linear -> GELU -> LayerNorm -> Linear(vocab).
    forward returns vocabulary logits of shape (batch, seq, vocab_size).
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout, intermediate_size, num_heads, num_layers):
        super().__init__()
        # Submodule names are the state_dict keys the pretrained weights are renamed to;
        # creation order is kept so initialization consumes the RNG identically.
        self.embedding = BertEmbedding(vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout)
        self.transformer = t.nn.Sequential(*[BertBlock(hidden_size, intermediate_size, num_heads, dropout) for _ in range(num_layers)])
        self.mlp = nn.Linear(hidden_size, hidden_size)
        self.gelu = nn.GELU()
        self.layer_norm = LayerNorm(hidden_size)
        self.unembedding = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, token_type_ids=None):
        """Return vocabulary logits for ``input_ids``.

        Args:
            input_ids: (batch, seq) token ids; float ids are accepted (Embedding casts them).
            token_type_ids: optional segment ids of the same shape; defaults to all
                zeros (a single segment).
        """
        if token_type_ids is None:
            # zeros_like places the ids on input_ids' device. (Tensor.to is not in-place,
            # so creating them on the CPU and calling .to(...) unassigned would not move them.)
            token_type_ids = t.zeros_like(input_ids, dtype=t.long)
        hidden = self.transformer(self.embedding(input_ids, token_type_ids))
        return self.unembedding(self.layer_norm(self.gelu(self.mlp(hidden))))

In [None]:
# Compare the full Bert model with the reference in bert_tests; prints "MATCH" on success.
bert_tests.test_bert(Bert)

bert MATCH!!!!!!!!
 SHAPE (1, 4, 28996) MEAN: 0.003031 STD: 0.5765 VALS [-0.5742 -0.4321 0.1186 -0.7165 -0.5262 0.4967 1.223 0.3165 -0.3247 -0.5717...]


In [None]:
# Hyper-parameters matching the bert-base-cased checkpoint whose weights are loaded into my_bert.
bert_base_config = dict(
    vocab_size=28996,
    hidden_size=768,
    max_position_embeddings=512,
    type_vocab_size=2,
    dropout=0.1,
    intermediate_size=3072,
    num_heads=12,
    num_layers=12,
)
my_bert = Bert(**bert_base_config)
pretrained_bert = bert_tests.get_pretrained_bert()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
from collections import OrderedDict

def repl(st):
    """Map a pretrained-BERT state_dict key to the matching key in this notebook's modules.

    The rewrites are plain substring replacements, applied in this order.
    """
    renames = (
        ("pattern.project_key", "keys"),
        ("pattern.project_query", "query"),
        ("residual.mlp", "mlp.linear"),
        ("residual.layer_norm", "ln2"),
        ("lm_head.", ""),
    )
    for old, new in renames:
        st = st.replace(old, new)
    return st

skip_params = ["classification_head.weight", "classification_head.bias"]

# Rename the pretrained keys to this notebook's module names, leaving out the
# classification head (Bert has no such submodule).
d = OrderedDict(
    (repl(key), value)
    for key, value in pretrained_bert.state_dict().items()
    if key not in skip_params
)

print(my_bert.load_state_dict(d))

<All keys matched successfully>


In [None]:
# Compare my_bert's outputs with pretrained_bert's, within tolerance 0.1; prints "MATCH" on success.
bert_tests.test_same_output(my_bert, pretrained_bert, tol=0.1)

comparing Berts MATCH!!!!!!!!
 SHAPE (10, 20, 28996) MEAN: -2.732 STD: 2.414 VALS [-5.65 -6.041 -6.096 -6.062 -5.945 -5.777 -5.977 -6.015 -6.028 -5.935...]


In [None]:
import transformers

# Set before any tokenizer runs; presumably to silence the HuggingFace tokenizers
# parallelism/fork warning (TODO confirm).
%env TOKENIZERS_PARALLELISM=false
# NOTE(review): tokenizer_uncased is not used anywhere else in the visible notebook.
tokenizer_uncased = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
# Cased tokenizer, matching the bert-base-cased checkpoint the model weights come from.
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

env: TOKENIZERS_PARALLELISM=false


In [None]:
# Show the ids for one sentence: the printed output has 101/102 around the sentence
# and [MASK] encoded as 103.
inp = tokenizer(["The firetruck was painted bright [MASK]."])
print(inp)

{'input_ids': [[101, 1109, 1783, 18062, 8474, 1108, 4331, 3999, 103, 119, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}


In [None]:
def ascii_art_probs(sentence, n_top, tkizer, model=None):
    """Print the model's top-``n_top`` predictions for the [MASK] slot in ``sentence``.

    Args:
        sentence: text containing "[MASK]"; only the first mask token is scored.
        n_top: number of candidates to print (capped at the vocabulary size).
        tkizer: HuggingFace-style tokenizer: callable returning {"input_ids": [[...]]},
            with ``mask_token_id`` and ``decode``.
        model: masked-LM returning logits (batch, seq, vocab); defaults to the
            notebook-global ``my_bert`` for backward compatibility.

    Raises:
        ValueError: if the tokenized sentence contains no mask token.
    """
    if model is None:
        model = my_bert
    input_ids = tkizer([sentence])["input_ids"][0]
    mask_id = tkizer.mask_token_id
    if mask_id not in input_ids:
        raise ValueError(f"no mask token found in {sentence!r}")
    mask_idx = input_ids.index(mask_id)

    device = next((p.device for p in model.parameters()), t.device("cpu"))
    was_training = model.training
    # Dropout must be off, or the printed probabilities change from call to call.
    model.eval()
    try:
        with t.no_grad():
            logits = model(t.tensor([input_ids], dtype=t.long, device=device))
    finally:
        model.train(was_training)

    probs_mask = t.softmax(logits[0, mask_idx], dim=-1)
    top = t.topk(probs_mask, k=min(n_top, probs_mask.numel()))
    print(sentence.replace("[MASK]", "______"))
    for prob, token_id in zip(top.values.tolist(), top.indices.tolist()):
        print(f"Word: {tkizer.decode(token_id)} \t probability: {int(prob*10000)/100}%")

In [None]:
# Demo: the sentence has no trailing period, so per the output below the model's
# top guesses for the mask are punctuation.
ascii_art_probs("The fish likes to eat [MASK]", 10, tokenizer)

The fish likes to eat ______
Word: . 	 probability: 94.29%
Word: ; 	 probability: 4.72%
Word: ! 	 probability: 0.89%
Word: ? 	 probability: 0.05%
Word: ... 	 probability: 0.0%
Word: , 	 probability: 0.0%
Word: : 	 probability: 0.0%
Word: and 	 probability: 0.0%
Word: | 	 probability: 0.0%
Word: but 	 probability: 0.0%


In [None]:
class BertClassifier(nn.Module):
    """BERT encoder with both the masked-LM head and a sequence-classification head.

    forward returns ``(lm_logits, class_logits)``:
      * lm_logits (batch, seq, vocab_size) come from mlp -> gelu -> layer_norm -> unembedding,
        the same head as ``Bert``;
      * class_logits (batch, num_classes) come from dropout + a linear layer applied to the
        encoder output at position 0 (the tokenizer's leading special token, presumably [CLS]).

    NOTE(review): no attention mask is applied, so padding positions take part in
    attention like ordinary tokens.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 max_position_embeddings,
                 type_vocab_size,
                 dropout,
                 intermediate_size,
                 num_heads,
                 num_layers,
                 num_classes):
        super().__init__()
        # Same submodule names as Bert so the renamed pretrained state_dict keys line up;
        # creation order is kept so initialization consumes the RNG identically.
        self.embedding = BertEmbedding(vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout)
        self.transformer = t.nn.Sequential(*[BertBlock(hidden_size, intermediate_size, num_heads, dropout) for _ in range(num_layers)])
        self.mlp = nn.Linear(hidden_size, hidden_size)
        self.gelu = nn.GELU()
        self.layer_norm = LayerNorm(hidden_size)
        self.unembedding = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.classification_head = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, token_type_ids=None):
        """Return ``(lm_logits, class_logits)`` for (batch, seq) ``input_ids``.

        ``token_type_ids`` defaults to all zeros (a single segment).
        """
        if token_type_ids is None:
            # zeros_like places the ids on input_ids' device (Tensor.to is not in-place,
            # so an unassigned .to(...) call would leave them on the CPU).
            token_type_ids = t.zeros_like(input_ids, dtype=t.long)
        output = self.transformer(self.embedding(input_ids, token_type_ids))
        unembedding = self.unembedding(self.layer_norm(self.gelu(self.mlp(output))))
        classification = self.classification_head(self.dropout(output[:, 0, :]))
        return unembedding, classification

In [None]:
classifier = BertClassifier(
    vocab_size=28996, hidden_size=768, max_position_embeddings=512,
    type_vocab_size=2, dropout=0.1, intermediate_size=3072,
    num_heads=12, num_layers=12, num_classes=2,
)
# Every pretrained key is kept here, including classification_head.*, which BertClassifier defines.
d = OrderedDict((repl(key), value) for key, value in pretrained_bert.state_dict().items())

print(classifier.load_state_dict(d))

<All keys matched successfully>


In [None]:
import torchtext
data_train, data_test = torchtext.datasets.IMDB(root='.data', split=('train', 'test'))
# list() already materializes a fresh list of (label, text) pairs; no extra copy is needed.
data = list(data_train)

In [None]:
import einops

def preprocess(data, tokenizer, max_seq_len, batch_size):
    """Tokenize (label, text) pairs, shuffle them, and group them into batches.

    Each text is tokenized with truncation to ``max_seq_len`` and right-padded
    with 0 to exactly ``max_seq_len`` ids. Examples are shuffled *before* the
    remainder that cannot fill a whole batch is dropped, so the dropped
    examples are random rather than always the last ones in ``data`` (which
    may all share one label).

    Args:
        data: iterable of (label, text) pairs. Labels may be "neg"/"pos" or
            1/2 (presumably the encoding newer torchtext IMDB releases use;
            confirm against the installed version).
        tokenizer: callable ``tokenizer([text], max_length=..., truncation=True)``
            returning a mapping whose "input_ids" holds a single id list.
        max_seq_len: length every example is truncated/padded to.
        batch_size: examples per batch.

    Returns:
        (batches, labels): float tensors of shape (k, batch_size, max_seq_len)
        and (k, batch_size), with k = len(data) // batch_size; label 0 means
        negative and 1 means positive.

    Raises:
        ValueError: for any label not listed above, instead of silently
            treating it as positive.
    """
    label_ids = {"neg": 0, "pos": 1, 1: 0, 2: 1}
    rows, targets = [], []
    for label, text in data:
        if label not in label_ids:
            raise ValueError(f"unrecognised IMDB label: {label!r}")
        token_ids = tokenizer([text], max_length=max_seq_len, truncation=True)["input_ids"][0]
        rows.append(list(token_ids) + [0] * (max_seq_len - len(token_ids)))
        targets.append(label_ids[label])

    # Keep the default float dtype that downstream cells expect (they call .long() themselves).
    dtype = t.get_default_dtype()
    all_data = t.tensor(rows, dtype=dtype).reshape(-1, max_seq_len)  # reshape also handles no rows
    labels = t.tensor(targets, dtype=dtype)

    n_keep = len(rows) - len(rows) % batch_size
    perm = t.randperm(len(rows))[:n_keep]
    all_data = all_data[perm].reshape(-1, batch_size, max_seq_len)
    labels = labels[perm].reshape(-1, batch_size)
    return all_data, labels

In [None]:
# Batch the IMDB training split: 512-token sequences, 16 examples per batch.
training_batches, training_labels = preprocess(data, tokenizer, max_seq_len=512, batch_size=16)
print(training_batches.shape, training_labels.shape)

torch.Size([1562, 16, 512]) torch.Size([1562, 16])


In [None]:
# Fine-tune the whole classifier (encoder and both heads) on the IMDB batches for 3 epochs.
classifier.train()
# Move the model to the GPU *before* building the optimizer, as the Module.cuda docs advise.
classifier.cuda()
adam = t.optim.Adam(classifier.parameters(), 1e-5)
t.cuda.empty_cache()
num_batches = training_batches.shape[0]
batch_size = training_batches.shape[1]
for epoch in range(3):
    print("epoch", epoch)
    for batch_num in range(num_batches):
        adam.zero_grad()
        b = training_batches[batch_num].cuda()
        l = training_labels[batch_num].cuda()
        # classifier returns (lm_logits, class_logits); only the class logits are trained on.
        out = classifier(b)[1]
        out_loss = nn.functional.cross_entropy(out, l.long())
        out_loss.backward()
        adam.step()
        if batch_num % 20 == 0:
            print("batch", batch_num, "loss", out_loss.item())

epoch 0
batch 0 loss 2.6286661624908447
batch 20 loss 0.7271848320960999
batch 40 loss 0.7447864413261414
batch 60 loss 0.631729006767273
batch 80 loss 0.392570436000824
batch 100 loss 0.7503160834312439
batch 120 loss 0.44002634286880493
batch 140 loss 0.1489362120628357
batch 160 loss 0.30392324924468994
batch 180 loss 0.28741395473480225
batch 200 loss 0.21664977073669434
batch 220 loss 0.6484529972076416
batch 240 loss 0.2540418803691864
batch 260 loss 0.5189374089241028
batch 280 loss 0.0915040671825409
batch 300 loss 0.15545439720153809
batch 320 loss 0.405807763338089
batch 340 loss 0.3349575996398926
batch 360 loss 0.5228177309036255
batch 380 loss 0.08513675630092621
batch 400 loss 0.2544328570365906
batch 420 loss 0.10949930548667908
batch 440 loss 0.14662708342075348
batch 460 loss 0.41021499037742615
batch 480 loss 0.44425761699676514
batch 500 loss 0.34592345356941223
batch 520 loss 0.11588944494724274
batch 540 loss 0.21250569820404053
batch 560 loss 0.5014015436172485
ba

In [None]:
# Materialize the IMDB test split once; list() already returns a fresh list.
test_data = list(data_test)

In [None]:
# Persist the fine-tuned weights to classifier.pt (reloaded into classifier2 below).
t.save(classifier.state_dict(), "classifier.pt")

In [None]:
classifier2 = BertClassifier(vocab_size=28996, hidden_size=768, max_position_embeddings=512,
    type_vocab_size=2, dropout=0.1, intermediate_size=3072,
    num_heads=12, num_layers=12, num_classes=2)
# The checkpoint is written from the model after it was moved to CUDA for training.
# Map it to the CPU so loading works without a GPU and does not allocate a second
# device copy of the weights; load_state_dict then copies into classifier2's CPU params.
classifier2.load_state_dict(t.load("classifier.pt", map_location="cpu"))

<All keys matched successfully>

In [None]:
tokens = tokenizer(["I was shocked and confused in a good way."])
print(tokens)
# classifier2 is still in training mode after construction, so dropout would make this
# prediction random. Evaluate deterministically and without building an autograd graph.
classifier2.eval()
with t.no_grad():
    sentiment_probs = t.softmax(classifier2(t.Tensor(tokens["input_ids"]))[1], dim=-1)
sentiment_probs

{'input_ids': [[101, 146, 1108, 6764, 1105, 4853, 1107, 170, 1363, 1236, 119, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}


tensor([[0.3176, 0.6824]], grad_fn=<SoftmaxBackward0>)

In [None]:
# Batch the IMDB test split the same way as the training split.
test_batches, test_labels = preprocess(test_data, tokenizer, max_seq_len=512, batch_size=16)
print(test_batches.shape, test_labels.shape)

torch.Size([1562, 16, 512]) torch.Size([1562, 16])


In [None]:
# An un-fine-tuned copy of the pretrained weights, used as the baseline in the evaluation below.
orig_classifier = BertClassifier(
    vocab_size=28996, hidden_size=768, max_position_embeddings=512,
    type_vocab_size=2, dropout=0.1, intermediate_size=3072,
    num_heads=12, num_layers=12, num_classes=2,
)
d = OrderedDict((repl(key), value) for key, value in pretrained_bert.state_dict().items())
orig_classifier.load_state_dict(d)

<All keys matched successfully>

In [None]:
# Compare the fine-tuned classifier with the un-fine-tuned copy on the IMDB test batches.
t.cuda.empty_cache()
classifier.cuda()  # no-op if the training cell already moved it to the GPU
orig_classifier.cuda()
test_loss = nn.CrossEntropyLoss()
classifier.eval()
orig_classifier.eval()

our_losses = 0
untrained_losses = 0
our_correct = 0
untrained_correct = 0

with t.no_grad():
    for i in range(test_batches.shape[0]):
        example = test_batches[i].cuda()
        label = test_labels[i].long().cuda()
        # Index 1 selects the class logits from (lm_logits, class_logits).
        out = classifier(example)[1]
        out_orig = orig_classifier(example)[1]
        our_losses += test_loss(out, label).item()
        untrained_losses += test_loss(out_orig, label).item()
        our_correct += (out.argmax(dim=-1) == label).sum().item()
        untrained_correct += (out_orig.argmax(dim=-1) == label).sum().item()
        # Running *sums* of per-batch mean losses, reported every 10 batches.
        if i % 10 == 9:
            print(f"Our losses: {our_losses}, untrained losses: {untrained_losses}")

num_test_batches = test_batches.shape[0]
if num_test_batches > 0:
    num_test_examples = test_labels.numel()
    print(f"Mean batch loss: ours {our_losses / num_test_batches:.4f}, untrained {untrained_losses / num_test_batches:.4f}")
    print(f"Accuracy: ours {our_correct / num_test_examples:.2%}, untrained {untrained_correct / num_test_examples:.2%}")


Our losses: 2.4411317789927125, untrained losses: 6.9597057700157166
Our losses: 4.395634302869439, untrained losses: 14.038484454154968
Our losses: 6.618245888967067, untrained losses: 21.090316772460938
Our losses: 8.400710095185786, untrained losses: 28.131726503372192
Our losses: 11.236966487485915, untrained losses: 35.11941158771515
Our losses: 13.49742683628574, untrained losses: 42.1220446228981
Our losses: 14.890941345598549, untrained losses: 49.151991188526154
Our losses: 17.449096416588873, untrained losses: 56.359111964702606
Our losses: 19.597863247152418, untrained losses: 63.2793892621994
Our losses: 22.138968688901514, untrained losses: 70.4475582242012
Our losses: 24.37586442474276, untrained losses: 77.4146329164505
Our losses: 26.547236795537174, untrained losses: 84.50524258613586
Our losses: 32.004129777662456, untrained losses: 91.6532335281372
Our losses: 35.05536588234827, untrained losses: 98.58932375907898
Our losses: 37.947907514404505, untrained losses: 105

In [None]:
from torchtext.datasets import WikiText2
# Note: this rebinds data_train/data_test, which previously held the IMDB splits.
data_train, data_test = WikiText2(root='.data', split=('train', 'test'))
# Keep only the first 1000 training lines; slicing already produces a new list.
wiki_train = list(data_train)[:1000]
wiki_test = list(data_test)
print(len(wiki_train))
print(wiki_train[0:10])

1000
[' \n', ' = Valkyria Chronicles III = \n', ' \n', ' Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . \n', " The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more <unk> 

In [None]:
# Inspect one raw line; as the listing above shows, some WikiText2 lines are blank (' \n').
print(wiki_train[57])

 



In [None]:
import random

def mask_wiki(data, tokenizer, max_seq_len, mask_prob=0.15, mask_token_id=103, rng=None):
    """Tokenize each text and apply BERT-style random masking.

    Each non-special token is selected with probability ``mask_prob``. A selected
    token becomes ``mask_token_id`` 80% of the time. 10% of the time it is replaced
    by a token drawn from the unmodified, non-special corpus tokens. The remaining
    10% it is left unchanged.

    Args:
        data: iterable of raw text lines.
        tokenizer: callable ``tokenizer([text], max_length=..., truncation=True)``
            returning {"input_ids": [ids]}. If it has ``all_special_ids`` (as
            HuggingFace tokenizers do), those ids are never selected.
        max_seq_len: truncation length passed to the tokenizer.
        mask_prob: per-token selection probability.
        mask_token_id: id written at masked positions (103 is [MASK] for the cased BERT tokenizer).
        rng: object with ``random()`` and ``choice()`` such as ``random.Random(seed)``;
            defaults to the global ``random`` module.

    Returns:
        (masked_data, orig_data): lists of token-id lists with matching lengths;
        orig_data holds the unmodified tokenizations and is not aliased by masked_data.
    """
    rng = random if rng is None else rng
    orig_data = [tokenizer([text], max_length=max_seq_len, truncation=True)["input_ids"][0] for text in data]
    special_ids = set(getattr(tokenizer, "all_special_ids", ()))
    # Sample replacements from the *original* tokens so a "random" replacement can
    # never copy an already-inserted mask id or a special token.
    replacement_pool = [tok for seq in orig_data for tok in seq if tok not in special_ids]

    masked_data = []
    for seq in orig_data:
        masked = list(seq)
        for pos, tok in enumerate(seq):
            if tok in special_ids or rng.random() >= mask_prob:
                continue
            roll = rng.random()
            if roll < 0.8:
                masked[pos] = mask_token_id
            elif roll < 0.9 and replacement_pool:
                masked[pos] = rng.choice(replacement_pool)
            # otherwise: selected but left unchanged
        masked_data.append(masked)
    return masked_data, orig_data

In [None]:
import torch as t
print(wiki_train[3])
max_seq_len = 512
masked_wiki, wiki = mask_wiki(wiki_train, tokenizer, max_seq_len)
# Right-pad every sequence with 0 up to max_seq_len so the lists stack into one tensor.
masked_wiki = t.Tensor([seq + [0] * (max_seq_len - len(seq)) for seq in masked_wiki])
wiki = t.Tensor([seq + [0] * (max_seq_len - len(seq)) for seq in wiki])
print(masked_wiki.shape)

 Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . 

torch.Size([1000, 512])


In [None]:
# Boolean mask of the positions mask_wiki replaced with the [MASK] id (103).
is_mask_token = (masked_wiki == 103)

In [None]:
vocab_size = 28996
# Defined here (as well as in the training cell) so this cell runs on a fresh kernel.
loss = nn.CrossEntropyLoss()
tiny_bert = Bert(
    vocab_size=vocab_size, hidden_size=384, max_position_embeddings=512, 
    type_vocab_size=2, dropout=0.1, intermediate_size=1536, 
    num_heads=12, num_layers=2
)
tiny_bert.eval()
# Only a few sequences: logits for all of masked_wiki (1000 x 512 x vocab_size float32)
# would need about 59 GB.
n_sanity = 8
sample_masked = masked_wiki[:n_sanity]
sample_is_mask = sample_masked == 103  # 103 = [MASK]
with t.no_grad():
    probs = tiny_bert(sample_masked)  # raw logits (n_sanity, seq, vocab_size), despite the name
    masked_tokens = wiki[:n_sanity][sample_is_mask]
    print(masked_tokens)
    print(probs.shape)
    predictions = probs[sample_is_mask]  # (num_masked, vocab_size)
    targets = masked_tokens.long()
    actual = t.zeros_like(predictions)
    actual[t.arange(targets.shape[0]), targets] = 1  # one-hot view of targets, for display
    print(predictions.shape, actual.shape)
    if targets.shape[0] >= 2:
        print(predictions[[0]], actual[[0]])
        print(predictions[[1]], actual[[1]])
        print(loss(predictions[[0]], targets[[0]]))
        print(loss(predictions[[1]], targets[[1]]))
    print(loss(predictions, targets))

In [None]:
def create_labels(wiki, masked_wiki, vocab_size, mask_token_id=103):
    """One-hot targets for every masked position.

    Args:
        wiki: original token ids (float or integer dtype).
        masked_wiki: same shape as ``wiki``, with some ids replaced by ``mask_token_id``.
        vocab_size: width of each one-hot row.
        mask_token_id: id marking masked positions (103 is [MASK] for the cased BERT tokenizer).

    Returns:
        Float tensor (num_masked, vocab_size) on ``wiki``'s device; row i one-hot
        encodes the i-th masked original token in row-major order.
    """
    is_mask_token = masked_wiki == mask_token_id
    masked_tokens = wiki.masked_select(is_mask_token).long()
    num_masked = masked_tokens.shape[0]
    actual = t.zeros(num_masked, vocab_size, device=wiki.device)
    # Vectorized scatter instead of a per-token Python loop.
    actual[t.arange(num_masked, device=wiki.device), masked_tokens] = 1
    return actual

In [None]:
batch_size = 5

def _to_batches(rows):
    """Drop leading rows until the count divides batch_size, then reshape to (k, batch_size, seq)."""
    return einops.rearrange(rows[rows.shape[0] % batch_size:], "(k b) l -> k b l", b=batch_size)

batched_masked_wiki = _to_batches(masked_wiki)
batched_wiki = _to_batches(wiki)
print(batched_masked_wiki.shape, batched_wiki.shape)

In [None]:
loss = nn.CrossEntropyLoss()
tiny_bert = Bert(
    vocab_size=vocab_size, hidden_size=384, max_position_embeddings=512, 
    type_vocab_size=2, dropout=0.1, intermediate_size=1536, 
    num_heads=12, num_layers=2
)
tiny_bert.train()
# Move the model to the GPU *before* building the optimizer, as the Module.cuda docs advise.
tiny_bert.cuda()
adam = t.optim.Adam(tiny_bert.parameters(), 1e-5)
t.cuda.empty_cache()
num_batches = batched_masked_wiki.shape[0]
batch_size = batched_masked_wiki.shape[1]
for epoch in range(3):
    print("epoch", epoch)
    for batch_num in range(num_batches):
        b = batched_masked_wiki[batch_num].cuda()
        l = batched_wiki[batch_num].cuda()
        is_mask_token = (b == 103)  # 103 = [MASK]
        if not is_mask_token.any():
            # Mean cross-entropy over zero targets is NaN; stepping on it would corrupt the weights.
            continue
        adam.zero_grad()
        out = tiny_bert(b)  # (batch, seq, vocab_size) logits; a single forward pass per step
        predictions = out[is_mask_token]  # (num_masked, vocab_size)
        masked_tokens = l[is_mask_token].long()  # original ids at the masked positions
        # Class-index targets give the same loss as one-hot rows, without a Python loop.
        out_loss = loss(predictions, masked_tokens)
        out_loss.backward()
        adam.step()
        if batch_num % 20 == 0:
            print("batch", batch_num, "loss", out_loss.item())