# Learning the Decoder

The convolutional decoder starts from discrete representations, but convolutions operate on continuous values. Therefore, I have to convert these discrete representations to continuous ones first. This is done by embedding the categorical variables (the groups, each of which contains n different classes).

I get one embedding for each class within a group (a category, such as 'color'). So if there were 4 colors (red, blue, yellow, green), each color would have its own embedding vector, of a size of our choosing. Each group gets its own embedding table. So if there were 3 groups (color, shape, texture), each of these would have its own nn.Embedding.
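
As a minimal sketch of the lookup for a single group, assuming a hypothetical 'color' group with the 4 classes above and an embedding size of 8 (a value chosen only for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical group: 4 color classes, each mapped to an 8-dim vector.
colors = ["red", "blue", "yellow", "green"]
color_table = nn.Embedding(num_embeddings=len(colors), embedding_dim=8)

# Looking up a class is just indexing a row of the (learnable) weight matrix.
yellow_idx = torch.tensor(colors.index("yellow"))
yellow_vec = color_table(yellow_idx)

print(color_table.weight.shape)  # torch.Size([4, 8])
print(yellow_vec.shape)          # torch.Size([8])
```

The table is randomly initialized and its rows are trained along with the rest of the network, so the vectors only become meaningful after training.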

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# --- what we're expecting as input ---
# we want the integer indices associated with each group, e.g., if there were 4 groups each with 4 classes to pick from
# then we'd have something like the following:
x = torch.tensor([0, 1, 3, 2])
print("For only one sample:")
for i in range(len(x)):
    print(f"For group {i+1}, {x[i]} is the index associated with the chosen class.")

# normally there are batches, so we change the input as such if there are 3 samples in one batch:
y = torch.tensor([[0, 1, 3, 2],
                  [1, 2, 3, 0],
                  [1, 1, 2, 2]])

For only one sample:
For group 1, 0 is the index associated with the chosen class.
For group 2, 1 is the index associated with the chosen class.
For group 3, 3 is the index associated with the chosen class.
For group 4, 2 is the index associated with the chosen class.


In [None]:
# --- where the input comes from ---
# from the posterior (Encoder) or prior (Dynamics Predictor) network, we get logits
# these logits are used to find the biggest one for each group
# e.g., if one group (G) has 4 classes (C), then there will be 4 logits, each associated with a class
# the largest logit will be chosen, which would be the discrete choice,
# i.e., the largest logit marks the position of the 1 in the one-hot vector
# (a Gumbel-softmax sample adds noise to the logits, so a sampled one-hot can differ from the plain argmax)

# --- EXAMPLE ---
# logits = dynamics.dist(h_t) # dynamics is the instantiated class and .dist() is the method that gets the logits
# z_idx  = logits.argmax(dim=-1)
# z_idx
# torch.tensor([[0, 1, 3, 2],
#               [1, 2, 3, 0],
#               [1, 1, 2, 2]])
# this would give (B, G) size with each G being an index that picks out the class
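
The commented example above can be tried out with random logits standing in for the network output. This is only a sketch: the `(B, G, C)` shape of the logits is an assumption about what `dynamics.dist(h_t)` returns, and `torch.randn` replaces the real network.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, G, C = 3, 4, 4  # batch size, number of groups, classes per group

# Stand-in for `dynamics.dist(h_t)`; the (B, G, C) shape is assumed.
logits = torch.randn(B, G, C)

# Pick the class with the largest logit in every group.
z_idx = logits.argmax(dim=-1)

# The equivalent one-hot view: a single 1 per group, at the argmax position.
z_one_hot = F.one_hot(z_idx, num_classes=C)

print(z_idx.shape)      # torch.Size([3, 4])
print(z_one_hot.shape)  # torch.Size([3, 4, 4])
```

`z_idx` has integer dtype (`torch.int64`), which is what `nn.Embedding` expects as input.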

In [9]:
x = torch.tensor([[0, 1, 3, 2],
                  [1, 2, 3, 0],
                  [1, 1, 2, 2]])

n_classes = 4
embed_dim = 16
n_cats    = 4
batch     = 3

tables = nn.ModuleList([nn.Embedding(n_classes, embed_dim) for _ in range(n_cats)])

pieces = []
for i, tab in enumerate(tables):
    pieces.append(tab(x[:, i]))

In [18]:
pieces[0]

tensor([[ 0.6862, -0.4291,  0.8057, -0.0298, -0.7715,  1.0482, -0.5261,  1.4336,
          1.0581,  2.6074,  0.0080,  1.7126,  1.1000,  1.2689, -0.1735,  1.5911],
        [-0.4210, -0.5019, -0.1908, -0.8345, -0.7349,  0.5490,  0.4359,  0.8548,
          0.9642,  1.3847, -0.1931,  0.3786,  1.2145,  0.6541,  0.2604,  0.4738],
        [-0.4210, -0.5019, -0.1908, -0.8345, -0.7349,  0.5490,  0.4359,  0.8548,
          0.9642,  1.3847, -0.1931,  0.3786,  1.2145,  0.6541,  0.2604,  0.4738]],
       grad_fn=<EmbeddingBackward0>)

So in the example above, each discrete index is mapped to a 16-dimensional embedding. pieces[0] contains the embeddings for the first column of x (i.e., x[:, 0] = [0, 1, 1]), which is the class chosen for the first group in each of the 3 samples. That is why the second and third rows of pieces[0] are identical: both samples picked class 1 for the first group. So the 0 in the first row of x has a 16-dimensional embedding attached to it, which is the continuous representation of the discrete choice produced earlier by the encoder. Every index in x gets an embedding, making up a total of 3 * 4 = 12 embedding vectors, or 3 * 4 * 16 = 192 values.
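
The shapes involved can be checked directly. The sketch below rebuilds the same setup as the cell above (the variable `flat` is a name introduced here for illustration) and shows that `pieces[i]` corresponds to column `i` of `x`, one row per sample:

```python
import torch
import torch.nn as nn

x = torch.tensor([[0, 1, 3, 2],
                  [1, 2, 3, 0],
                  [1, 1, 2, 2]])

n_classes, embed_dim, n_cats = 4, 16, 4
tables = nn.ModuleList([nn.Embedding(n_classes, embed_dim) for _ in range(n_cats)])
pieces = [tab(x[:, i]) for i, tab in enumerate(tables)]

print(x[:, 0])          # tensor([0, 1, 1]) -> indices fed to the first table
print(len(pieces))      # 4 -> one tensor per group
print(pieces[0].shape)  # torch.Size([3, 16]) -> (batch, embed_dim)

# Samples 2 and 3 both chose class 1 in group 1, so their rows match.
print(torch.equal(pieces[0][1], pieces[0][2]))  # True

# Concatenating along the last dim gives one flat vector per sample.
flat = torch.cat(pieces, dim=-1)
print(flat.shape)       # torch.Size([3, 64]) -> (batch, n_cats * embed_dim)
```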

In [23]:
# pytorch implementation

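class DiscreteEmbedding(nn.Module):
    """Maps (B, n_cats) integer class indices to (B, n_cats * embed_dim) floats."""

    def __init__(self, n_cats, n_classes, embed_dim):
        super().__init__()

        # one embedding table per group, each with n_classes rows
        self.tables = nn.ModuleList([
            nn.Embedding(n_classes, embed_dim) for _ in range(n_cats)
        ])

    def forward(self, x):
        # x[:, i] holds the chosen class index of group i for every sample
        pieces = [tab(x[:, i]) for i, tab in enumerate(self.tables)]
        # join the per-group embeddings into one vector per sample
        return torch.cat(pieces, dim=-1)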
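
A short usage sketch for the module, with the class repeated so the snippet runs on its own. The `.sum()` loss is a placeholder used only to show that gradients reach the embedding tables; it is not part of any real training objective.

```python
import torch
import torch.nn as nn


class DiscreteEmbedding(nn.Module):
    def __init__(self, n_cats, n_classes, embed_dim):
        super().__init__()
        self.tables = nn.ModuleList([
            nn.Embedding(n_classes, embed_dim) for _ in range(n_cats)
        ])

    def forward(self, x):
        pieces = [tab(x[:, i]) for i, tab in enumerate(self.tables)]
        return torch.cat(pieces, dim=-1)


x = torch.tensor([[0, 1, 3, 2],
                  [1, 2, 3, 0],
                  [1, 1, 2, 2]])

embed = DiscreteEmbedding(n_cats=4, n_classes=4, embed_dim=16)
out = embed(x)
print(out.shape)  # torch.Size([3, 64])

# Placeholder loss: checks that the tables receive gradients.
out.sum().backward()
print(embed.tables[0].weight.grad.shape)  # torch.Size([4, 16])
```

Only the rows that were actually looked up get a non-zero gradient. For the first table that means classes 0 and 1, since `x[:, 0]` is `[0, 1, 1]`.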
            

TODO: finish decoder