# Do we want to use a Transformer(Decoder), and if so, how?

The default implementation of the Transformer takes **2** inputs (the source sequence and the right-shifted target sequence), i.e. by default it is a translation model (modelling the conditional distribution of the target given the source). Since we only care about generation, what do we do with that?
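
To make the two-input interface concrete, here is a minimal sketch using `torch.nn.Transformer`. The extra `BOS` token id, the layer sizes and the random data are assumptions chosen for illustration, and positional encodings are left out for brevity, so this is not a usable model as it stands.

```python
import torch
from torch import nn

torch.manual_seed(0)

V, k, n, d_model = 5, 3, 4, 32   # vocabulary size, sequence length, batch size, model width
BOS = V                          # hypothetical extra token id marking the start of a sequence

embed = nn.Embedding(V + 1, d_model)          # +1 row for the BOS token
transformer = nn.Transformer(
    d_model=d_model, nhead=4,
    num_encoder_layers=1, num_decoder_layers=1,
    dim_feedforward=64, batch_first=True,
)
to_logits = nn.Linear(d_model, V)

src = torch.randint(V, size=(n, k))           # source sequence
tgt = torch.randint(V, size=(n, k))           # target sequence

# Right-shift the target: prepend BOS and drop the last token
bos = torch.full((n, 1), BOS)
tgt_in = torch.cat([bos, tgt[:, :-1]], dim=1)

# Causal mask: position t of the decoder may only attend to positions <= t
causal_mask = torch.triu(torch.full((k, k), float("-inf")), diagonal=1)

hidden = transformer(embed(src), embed(tgt_in), tgt_mask=causal_mask)
logits = to_logits(hidden)                    # shape (n, k, V)
loss = nn.functional.cross_entropy(logits.reshape(-1, V), tgt.reshape(-1))
```

The decoder sees `tgt_in` and is trained to predict `tgt`, so every output position depends on the full source but only on earlier target tokens.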

Ideas to try:

  1. Instead of the right-shifted target, provide noise; perhaps noise of different shapes. To be figured out: does the Transformer still learn to reconstruct the sequence? Is this use just a waste of parameters (and hence computation)?
  
  2. Can we somehow invert the TransformerEncoder? That is, learn a function which reconstructs the original sequence from the encoded one, very much like an auto-encoder would work. To be figured out: Does this properly leverage the power of the Transformer?
  
  3. Train the model to translate into itself, i.e. learn P(X, X) = P(X)? How is this useful though, especially for generation?
  
  
**Generally, we need to figure out how to use the Transformer as a generative model.**
Maybe the literature on generative models offers inspiration, or even guidance through a proper mathematical formalisation of what is to be modelled (i.e. formalise the modelling problem: the random variables, the relationships between them, and which of them are observable and which are not).
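
One common formalisation from the language-modelling literature is the autoregressive factorisation P(X) = ∏_t P(x_t | x_<t): each token is a random variable conditioned on its predecessors, and all of them are observed. Below is a rough sketch of that option, assuming a `TransformerEncoder` with a causal mask is acceptable as the "decoder-only" model. The class `CausalLM`, the `sample` helper, the `BOS` token and all sizes are hypothetical names and values for illustration, not part of `nets` or `utils`.

```python
import torch
from torch import nn

torch.manual_seed(0)

V, k, d_model = 5, 3, 32
BOS = V  # hypothetical start-of-sequence token id

class CausalLM(nn.Module):
    """Encoder stack plus causal mask: models P(x_t | x_<t) at every position."""
    def __init__(self, vocab_size, d_model, max_len):
        super().__init__()
        self.tok = nn.Embedding(vocab_size + 1, d_model)   # +1 row for BOS
        self.pos = nn.Embedding(max_len, d_model)          # learned positions
        layer = nn.TransformerEncoderLayer(
            d_model, nhead=4, dim_feedforward=64, batch_first=True)
        self.body = nn.TransformerEncoder(layer, num_layers=2)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, tokens):
        t = tokens.shape[1]
        positions = torch.arange(t, device=tokens.device)
        # -inf above the diagonal blocks attention to future positions
        mask = torch.triu(
            torch.full((t, t), float("-inf"), device=tokens.device), diagonal=1)
        h = self.body(self.tok(tokens) + self.pos(positions), mask=mask)
        return self.out(h)                                 # (batch, t, vocab_size)

@torch.no_grad()
def sample(model, num_samples, length):
    """Ancestral sampling: draw one token at a time from P(x_t | x_<t)."""
    model.eval()
    tokens = torch.full((num_samples, 1), BOS)
    for _ in range(length):
        probs = model(tokens)[:, -1].softmax(-1)
        tokens = torch.cat([tokens, torch.multinomial(probs, 1)], dim=1)
    return tokens[:, 1:]                                   # drop the BOS column

model = CausalLM(V, d_model, max_len=k + 1)
X = torch.randint(V, size=(100, k))

# Next-token objective: the input is X shifted right by one, the target is X itself
inputs = torch.cat([torch.full((len(X), 1), BOS), X[:, :-1]], dim=1)
logits = model(inputs)
loss = nn.functional.cross_entropy(logits.reshape(-1, V), X.reshape(-1))

samples = sample(model, num_samples=4, length=k)           # untrained, so close to uniform
```

Under this setup no separate source sequence is needed: the right-shifted sequence is the only input, and generation amounts to repeated sampling from the model's own predictions.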

In [None]:
from tqdm import tqdm

import matplotlib.pyplot as plt

In [None]:
import torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer, Embedding, RNN
from torch.nn import Linear, Sigmoid, Softmax
from torch.nn import NLLLoss, CrossEntropyLoss
from torch.optim import SGD

from torch.nn import MSELoss, L1Loss, BCELoss
from torch.optim import Adam

In [None]:
from nets import LM, AggregateHead, ReconstructHead

In [None]:
from utils import iter_batches, merge_and_shuffle

# Data

In [None]:
# val = 1
# ind = 0
# def val_in_seq(data):
#     return torch.tensor([val in seq for seq in data]).unsqueeze(1).float()
# def val_at_index(data):
#     return (data[:, ind] == val).unsqueeze(1).float()

# def is_sorted(data):
#     return ((data.sort(1).values == data).sum(1) == data.shape[1]).unsqueeze(1).float()

# def sum_of_seq(data):
#     return data.sum(1).unsqueeze(1).float()


# task_function = sum_of_seq    

In [None]:
n, k, V, d = 100, 3, 5, 1
X = torch.randint(V, size=(n, k))
Y = X.clone()  # X[:] would only return a view sharing X's storage

eval_X = torch.randint(V, size=(10, k))
eval_Y = eval_X.clone()

# Definitions

In [None]:
class Model(torch.nn.Module):
    def __init__(self, enc, dec):
        super().__init__()
        self.enc = enc
        self.dec = dec
        
    def forward(self, inputs):
        vectors = self.enc(inputs)
        return self.dec(vectors)

# Instantiations

In [None]:
enc = LM(V, embed_dim=64, num_layers=2)
dec = ReconstructHead(enc)
model = Model(enc, dec)


losses = []

criterion = CrossEntropyLoss()
optim = SGD(dec.parameters(), lr=0.01)

# Training

In [None]:
alpha = 0.3

model.train()

for _ in tqdm(range(2000)):
    optim.zero_grad()
    
    X_ = model(X)
    
#     print(X_.shape, X.shape, X_.view(-1, V).shape, X.view(-1).shape)
    
    # Flatten to (n * k, V) logits against (n * k,) target token ids
    loss = criterion(X_.view(-1, V), X.view(-1))
    loss.backward()
    optim.step()
    losses.append(loss.item())  # store a plain float for plotting

# Inspection

In [None]:
plt.plot(range(len(losses)), losses, "--")

In [None]:
model.eval()  # disable dropout for evaluation
with torch.no_grad():
    eval_X_ = model(eval_X)

print(eval_X_.argmax(-1), eval_X)
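
Comparing the two printed tensors by eye gets tedious quickly, so a small helper could summarise the reconstruction quality as numbers. The function name `reconstruction_accuracy` and the stand-in tensors below are assumptions for illustration; in the notebook the arguments would be `eval_X_` and `eval_X`.

```python
import torch

def reconstruction_accuracy(logits, targets):
    """Return (token-level accuracy, fraction of sequences reconstructed exactly)."""
    predictions = logits.argmax(-1)          # (batch, seq_len)
    correct = predictions == targets
    token_acc = correct.float().mean().item()
    seq_acc = correct.all(dim=1).float().mean().item()
    return token_acc, seq_acc

# Stand-in values: the second sequence has one wrong token in its last position
targets = torch.tensor([[0, 1, 2], [3, 4, 0]])
predicted_ids = torch.tensor([[0, 1, 2], [3, 4, 1]])
logits = torch.nn.functional.one_hot(predicted_ids, num_classes=5).float()

token_acc, seq_acc = reconstruction_accuracy(logits, targets)
print(f"token accuracy: {token_acc:.3f}, exact sequences: {seq_acc:.3f}")
```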