In [1]:
from tqdm import tqdm

import matplotlib.pyplot as plt

In [2]:
import torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer, Embedding, Linear, Softmax, NLLLoss, RNN, ELU
from torch.optim import SGD

In [3]:
class TrainLoop:
    """Minimal full-batch training loop around a ``torch.nn.Module``.

    The loop uses ``NLLLoss`` as criterion and plain SGD over the parameters
    of the model it is given. The loss of every optimisation step is appended
    to ``self.losses``.
    """

    def __init__(self, model):
        """Set up criterion, optimiser and bookkeeping for ``model``.

        Args:
            model: the module to train; all of ``model.parameters()`` are
                handed to the optimiser.
        """
        self.model = model
        self.criterion = NLLLoss()
        # Optimise the model that was passed in. Reading a notebook-global
        # `lm` here left the head's own weights untrained and failed with a
        # NameError on a fresh kernel.
        self.optimiser = SGD(self.model.parameters(), lr=0.1)  # 0.01

        self.losses = []
        self.i = 0

    def train(self, n, train_in, train_out, eval_out=None, verbose=0):
        """Run ``n`` full-batch gradient steps.

        Args:
            n: number of optimisation steps.
            train_in: input passed unchanged to the model in every step.
            train_out: targets; like the prediction, its first two
                dimensions are flattened into one before the loss is taken.
            eval_out: data handed to ``eval_with`` for the periodic
                evaluation; defaults to ``train_in``.
            verbose: if non-zero, evaluate every ``verbose`` steps (step 0 is
                skipped) and plot the loss curve once training has finished.
        """
        self.model.train()
        self.train_data = train_in

        if eval_out is None:
            eval_out = train_in

        for i in tqdm(range(n)):
            self.optimiser.zero_grad()

            predicted = self.model(train_in)
            loss = self.criterion(predicted.flatten(0, 1), train_out.flatten(0, 1))
            loss.backward()
            self.optimiser.step()

            self.losses.append(loss.detach().item())
            # Updated inside the loop so that n == 0 leaves `self.i` untouched
            # instead of raising on an unbound loop variable.
            self.i = i

            if verbose and i and i % verbose == 0:
                self.eval_with(eval_out)

        if verbose:
            self.show()

    def show(self):
        """Plot the recorded loss per optimisation step as a dashed line."""
        plt.plot(range(len(self.losses)), self.losses, "--")

    def eval_with(self, test=None, average=True):
        """Print the fraction of positions where the model reproduces its input.

        The model is run on ``test`` and the argmax over its last output
        dimension is compared element-wise with ``test`` itself.

        Args:
            test: 2-D tensor fed to the model; defaults to the input of the
                most recent ``train`` call.
            average: unused; kept so existing callers keep working.
        """
        if test is None:
            test = self.train_data
        n, k = test.shape

        # Evaluate with dropout disabled and without building a graph, then
        # restore whatever mode the model was in before.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                predicted = self.model(test).argmax(-1)
        finally:
            self.model.train(was_training)

        acc = (test == predicted).sum() / (n * k)
        print("Accuracy: ", round(acc.item(), 3))

---
# Pure LM -- TransformerEncoder which Learns to Copy

In [None]:
class LM(torch.nn.Module):
    """Token embedding followed by a single-layer ``TransformerEncoder``."""

    def __init__(self, vocab_dim, embed_dim=4, nhead=None):
        """Build the embedding table and the encoder.

        Args:
            vocab_dim: number of distinct token ids.
            embed_dim: embedding size, also used as the encoder's ``d_model``.
            nhead: number of attention heads; must divide ``embed_dim``.
                Defaults to ``embed_dim // 2`` (at least 1).
        """
        super().__init__()

        self.vocab_dim = vocab_dim
        self.embed_dim = embed_dim
        self.emb = Embedding(vocab_dim, self.embed_dim)

        # nhead needs to divide d_model (embedding dimension)
        if nhead is None:
            nhead = max(1, self.embed_dim // 2)

        # batch_first=True: inputs arrive as (batch, seq), so the embedded
        # tensor is (batch, seq, embed_dim). With the default layout the
        # encoder would treat the batch axis as the sequence and attend
        # across samples instead of across positions.
        self.encoder_layer = TransformerEncoderLayer(
            d_model=self.embed_dim, nhead=nhead, batch_first=True
        )
        self.encoder = TransformerEncoder(self.encoder_layer, num_layers=1)

    def forward(self, x):
        """Embed the token ids in ``x`` and run them through the encoder.

        Returns:
            Tensor of shape ``(x.shape[0], x.shape[1], self.embed_dim)``.
        """
        x_emb = self.emb(x)
        return self.encoder(x_emb)

In [None]:
class Head(torch.nn.Module):
    """Linear read-out applied to the output of a wrapped language model."""

    def __init__(self, lm, out_dim):
        """
        Args:
            lm: module exposing ``embed_dim``; its output is fed to the
                linear layer.
            out_dim: number of output features of the linear layer.
        """
        super().__init__()
        self.lm = lm
        self.in_dim = lm.embed_dim
        self.out_dim = out_dim

        self.linear = Linear(in_features=self.in_dim, out_features=self.out_dim)

    def forward(self, lm_input):
        """Run ``lm_input`` through the wrapped model and project it."""
        # Use the model this head was built with, not whatever `lm` happens
        # to be bound to in the notebook namespace.
        return self.linear(self.lm(lm_input))


class ReconstructHead(Head):
    """Head that projects back to vocabulary size and applies a softmax."""

    def __init__(self, lm):
        """
        Args:
            lm: module exposing ``embed_dim`` and ``vocab_dim``.
        """
        super().__init__(lm, out_dim=lm.vocab_dim)
        # NOTE(review): NLLLoss (used by TrainLoop in this notebook) expects
        # log-probabilities; if this head is trained with it, LogSoftmax is
        # probably what is wanted -- confirm before changing the output.
        self.sigma = Softmax(dim=-1)

    def forward(self, lm_input):
        """Return a distribution over the vocabulary along the last axis."""
        # required to fan back out to vocabulary dimensionality from embedding dimensionality
        back_projection = super().forward(lm_input)
        return self.sigma(back_projection)


class RNNHead(Head):
    """Head that summarises the model output with an RNN into one value."""

    def __init__(self, lm, hidden_size=4):
        """
        Args:
            lm: module exposing ``embed_dim``.
            hidden_size: size of the RNN's hidden state.
        """
        super().__init__(lm, out_dim=1)

        self.h = hidden_size
        # No `dropout` argument: with num_layers=1 it has no effect and only
        # triggers a warning.
        self.rnn = RNN(
            input_size=self.in_dim,
            hidden_size=self.h,
            num_layers=1,
            bidirectional=False,
            batch_first=True,
        )
        # Replaces the linear layer created by Head.__init__: the read-out
        # here consumes the RNN's hidden state, not the model output.
        self.linear = Linear(in_features=self.h, out_features=self.out_dim)
        self.elu = ELU()

    @staticmethod
    def aggregate_hidden_layers(rnn_hidden, method="concat"):
        """Collapse the per-layer hidden states into one vector per sample.

        Args:
            rnn_hidden: tensor of shape
                ``(num_layers * num_directions, batch, hidden_size)``.
            method: ``"concat"`` gives
                ``(batch, num_layers * num_directions * hidden_size)``;
                ``"sum"`` adds the layers up and gives ``(batch, hidden_size)``.

        Raises:
            ValueError: if ``method`` is neither ``"concat"`` nor ``"sum"``.
        """
        per_sample = rnn_hidden.transpose(1, 0)  # (batch, layers, hidden_size)
        if method == "concat":
            return per_sample.flatten(1, 2)
        if method == "sum":
            # Sum over the layer axis (dim 1), keeping hidden_size intact.
            return per_sample.sum(1)
        raise ValueError(
            f"unknown aggregation method {method!r}; expected 'concat' or 'sum'"
        )

    def forward(self, lm_input):
        """Predict one value per input sequence.

        Returns:
            Tensor with ``self.out_dim`` (= 1) values per sample, passed
            through ELU.
        """
        lm_out = self.lm(lm_input)

        rnn_out, rnn_hidden = self.rnn(lm_out)
        hidden_vec = self.aggregate_hidden_layers(rnn_hidden, method="concat")
        return self.elu(self.linear(hidden_vec))
        
        

In [None]:
# toy test data
# Fixed seed so that Restart & Run All regenerates exactly the same data.
torch.manual_seed(0)

n, k, V = 100, 4, 6  # rows, entries per row, exclusive upper bound of the values
vecs = torch.randint(V, size=(n, k))

# Targets derived from `vecs`, one per toy task.
sums = vecs.sum(-1).reshape(-1, 1)   # row sums as a column, shape (n, 1)
sorts = vecs.sort(-1)[0]             # each row sorted ascending
remainders = vecs.remainder(5)       # element-wise value mod 5
reverses = vecs.flip((-1, ))         # each row reversed

In [None]:
# Shared language model over a vocabulary of size V with 100-dim embeddings.
lm = LM(V, embed_dim=100)

# Two heads wrapping the same `lm` instance.
sort_head = ReconstructHead(lm)
rnn_head = RNNHead(lm)

# NOTE(review): TrainLoop uses NLLLoss, while RNNHead is built with out_dim=1;
# it looks like `sort_head` may have been intended here -- confirm.
trainer = TrainLoop(rnn_head)

In [None]:
# 1000 full-batch steps with the row sums as targets; verbose=0 turns off the
# periodic evaluation and the final loss plot inside `train`.
# NOTE(review): `sums` has shape (n, 1) -- confirm this is what the trainer's
# NLLLoss expects for the model being trained.
trainer.train(1000, train_in=vecs, train_out=sums, verbose=0)
print()
# Evaluate on 100 freshly sampled vectors.
trainer.eval_with(torch.randint(V, size=(100, k)))

In [None]:
# Plot the loss recorded for every training step so far.
trainer.show()

In [None]:
# Three freshly sampled vectors for a qualitative look at the predictions.
eval_vecs = torch.randint(V, size=(3, k))

# Displayed side by side: inputs, sort_head's argmax prediction, reversed inputs.
eval_vecs, sort_head(eval_vecs).argmax(-1), eval_vecs.flip((-1, ))

In [None]:
# A row counts as sorted when all k of its entries equal the sorted row.
is_sorted = (vecs == sorts).sum(-1) == k
is_sorted

# Things to Try

 - implement transformer heads: sum of vector, sort/reverse vector
 

 - Positional Encodings (see [PyTorch tutorial](https://pytorch.org/tutorials/beginner/transformer_tutorial.html))
 - Masking Attention (see [pytorch tutorial](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)) <br>
   -> potentially has effect equivalent to skip-grams
   
   
 - extract embeddings (of vocabulary, of vector) from LM <br>
   e.g. something along the lines of `list(lm.emb.parameters())[0].detach()`

# DEV: RNN Head Training

In [None]:
def iter_batches(batch_size, train_in, train_out, shuffle=False):
    """Yield aligned ``(batch_in, batch_out)`` slices of at most ``batch_size`` rows.

    Args:
        batch_size: maximum number of rows per batch; must be positive.
        train_in: input tensor, sliced along its first dimension.
        train_out: target tensor with the same number of rows as ``train_in``.
        shuffle: if true, apply one random permutation to both tensors
            before slicing, so input/target pairs stay aligned.

    Yields:
        Tuples ``(batch_in, batch_out)``. The last batch may be smaller than
        ``batch_size``; nothing is yielded for empty input.

    Raises:
        ValueError: if ``batch_size`` is not positive or the two tensors
            differ in their number of rows.
    """
    if batch_size <= 0:
        # A batch size of 0 would otherwise never advance through the data.
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_rows = train_in.shape[0]
    if train_out.shape[0] != n_rows:
        raise ValueError(
            f"train_in and train_out differ in length: {n_rows} vs {train_out.shape[0]}"
        )

    if shuffle:
        order = torch.randperm(n_rows, device=train_in.device)
        train_in, train_out = train_in[order], train_out[order]

    for start in range(0, n_rows, batch_size):
        stop = start + batch_size
        yield train_in[start:stop], train_out[start:stop]

In [None]:
# Regression setup: the toy vectors as inputs, their row sums cast to float as targets.
train_in = vecs; train_out = sums.float()
# Separate set of 10 freshly sampled vectors with their row sums as float targets.
eval_in = torch.randint(V, size=(10, k)); eval_out = eval_in.sum(-1).reshape(-1, 1).float()

In [None]:
# Fresh model with 10-dim embeddings and a new RNN head on top of it;
# this rebinds the names `lm` and `rnn_head` used earlier in the notebook.
lm = LM(V, embed_dim=10)
rnn_head = RNNHead(lm)

# Loss values are collected here, one entry per optimisation step.
losses = []

In [None]:
rnn_head.train()

from torch.nn import MSELoss, L1Loss
from torch.optim import Adam

# Summed absolute error between predicted and true row sums.
criterion = L1Loss(reduction="sum")
# Optimise every parameter of the head (its RNN and linear read-out as well as
# the wrapped LM). Passing only lm.parameters() left the head's own layers at
# their random initialisation.
optimiser = SGD(rnn_head.parameters(), lr=0.01)

for i in tqdm(range(500)):
    optimiser.zero_grad()

    # Accumulate the loss over all mini-batches and take a single gradient
    # step per pass over the data.
    loss = 0
    for batch_in, batch_out in iter_batches(90, train_in, train_out):
        sum_pred = rnn_head(batch_in)
        loss += criterion(sum_pred, batch_out)

    loss.backward()
    optimiser.step()
    losses.append(loss.detach().item())


In [None]:
# Dashed curve of the collected loss values, one point per optimisation step.
plt.plot(range(len(losses)), losses, "--")

In [None]:
# Displayed side by side: held-out inputs, the head's output for them, and the true sums.
eval_in, rnn_head(eval_in), eval_out, 

---
# Using Huggingface's transformers library (BERT, GPT-2, ...)

In [None]:
import torch
from transformers import BertConfig, BertModel

In [None]:
# Configuration object created with its default arguments.
conf = BertConfig()

# Model built directly from the config (no `from_pretrained` call here).
model = BertModel(conf)

In [None]:
# One sequence of three ids, shape (1, 3).
tt = torch.tensor([[1,2,3]])
# return_dict=True so the result exposes named fields such as `last_hidden_state`.
out = model(tt, return_dict=True)

In [None]:
# Inspect the shape of the final hidden states returned for `tt`.
out.last_hidden_state.shape

In [None]:
# Two rows of four ids flattened into a single 1-D tensor of 8 elements.
tt = torch.tensor([[1,2,3,4], [5,6,7,8]]).reshape(-1)