# Learning Encoder

When the input is pixels, the DreamerV3 encoder is a convolutional encoder. It transforms the high-dimensional image into a lower-dimensional *latent representation* (in this case, a discrete, compressed representation; see https://arxiv.org/abs/2312.01203 regarding discrete vs. continuous representations). This latent representation is what the RSSM uses. For other kinds of inputs, such as state vectors (e.g., positions and velocities), an MLP encoder is used instead. Robotics tasks typically use both (camera input and state vectors).
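To make the two input paths concrete, here is a minimal sketch of an encoder that handles both an image and a state vector. The class name `TinyEncoder`, the layer sizes, the 64x64 image size, and the 4-dimensional state are all assumptions for illustration. This is not DreamerV3's actual architecture, and it leaves out the discretization step entirely.

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Hypothetical encoder: a CNN for images, an MLP for state vectors."""

    def __init__(self, state_dim=4, embed_dim=32):
        super().__init__()
        # image path: two strided convolutions, then a linear projection
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1),   # 64x64 -> 32x32
            nn.SiLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.SiLU(),
            nn.Flatten(),
            nn.Linear(32 * 16 * 16, embed_dim),
        )
        # state path: a single hidden layer
        self.mlp = nn.Sequential(nn.Linear(state_dim, embed_dim), nn.SiLU())

    def forward(self, image, state):
        # concatenate the two embeddings into one feature vector
        return torch.cat([self.cnn(image), self.mlp(state)], dim=-1)

encoder = TinyEncoder()
image = torch.randn(2, 3, 64, 64)  # batch of 2 RGB images
state = torch.randn(2, 4)          # batch of 2 state vectors
embedding = encoder(image, state)  # shape (2, 64)
```

The point is only that each input type gets its own sub-network and the results are merged before anything downstream sees them.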

So to start, I will focus on the convolutional encoder.

A convolutional layer is basically a learnable filter (its weights). A *filter kernel* slides over the input, and at each spatial location it computes the dot product between the filter weights and the patch of the input it sits over, giving one value for the whole patch. That value basically answers "does this patch look like the filter?": the more the patch resembles the weights, the larger the dot product. The filter is not optimized against that question directly, though. During training, the gradient of the downstream loss decides the update. If a larger response on a patch would lower the loss, the weights are nudged toward that patch (making the dot product larger). If it would raise the loss, the weights are nudged the other way so the filter stops firing on it (making the dot product smaller).
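Written out for a single-channel input with no stride and no padding, the sliding dot product is the following. The symbols are my own notation: $x$ is the $H \times W$ input, $K$ is the $k \times k$ kernel, and $b$ is the bias.

$$
\text{out}[i, j] = b + \sum_{u=0}^{k-1} \sum_{v=0}^{k-1} K[u, v] \, x[i + u, \, j + v],
\qquad 0 \le i \le H - k, \quad 0 \le j \le W - k
$$

The output therefore has size $(H - k + 1) \times (W - k + 1)$, so a $5 \times 5$ input with a $3 \times 3$ kernel gives a $3 \times 3$ output. Strictly speaking this is cross-correlation (the kernel is not flipped), which is what deep learning libraries call convolution.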

In [None]:
import torch 

# code up a basic convolution to get the idea across

# example image 
x = torch.arange(25, dtype=torch.float32).reshape(5,5)

# a 3x3 kernel w/ bias
# the kernel and bias are learnable parameters
kernel = torch.randn(3,3)
bias   = torch.randn(1)

# output size with no stride or padding: (5 - 3 + 1) x (5 - 3 + 1) = 3x3
k = kernel.shape[0]
out_h, out_w = x.shape[0] - k + 1, x.shape[1] - k + 1
out = torch.zeros(out_h, out_w)

# basic convolution (no stride, no padding, single channel)
for i in range(out_h):
    for j in range(out_w):
        patch = x[i:i+k, j:j+k] # the k x k window whose top-left corner is (i, j)
        out[i, j] = torch.sum(patch * kernel) + bias # dot product + bias
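As a sanity check, the loop can be compared against PyTorch's built-in `torch.nn.functional.conv2d`. This is a minimal sketch that repeats the setup so it runs on its own; the fixed seed and the tolerance are my own choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the run is repeatable

x = torch.arange(25, dtype=torch.float32).reshape(5, 5)
kernel = torch.randn(3, 3)
bias = torch.randn(1)

# manual sliding-window version (no stride, no padding, single channel)
k = kernel.shape[0]
out_h, out_w = x.shape[0] - k + 1, x.shape[1] - k + 1
manual = torch.zeros(out_h, out_w)
for i in range(out_h):
    for j in range(out_w):
        manual[i, j] = (x[i:i+k, j:j+k] * kernel).sum() + bias.item()

# F.conv2d expects input of shape (N, C_in, H, W) and weight of shape
# (C_out, C_in, kH, kW), so add the two leading dimensions and strip them after
reference = F.conv2d(x[None, None], kernel[None, None], bias=bias)[0, 0]

# loose tolerance: the two versions sum the products in a different order
matches = torch.allclose(manual, reference, rtol=1e-4, atol=1e-4)
print(matches)
```

The only real difference is bookkeeping: the built-in version works on batches and channels, so a single 2D image has to be wrapped in batch and channel dimensions first.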

In [5]:
# the same convolution wrapped as a PyTorch module (single channel, no stride, no padding)

import torch.nn as nn

class Conv2d(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(kernel_size, kernel_size))
        self.bias   = nn.Parameter(torch.randn(1))


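    def forward(self, x):
        # x is a single 2D image of shape (h, w), with no batch or channel dimension
        h, w = x.shape
        k = self.weight.shape[0] # k is a hyperparameter (kernel_size)
        out_h, out_w = h - k + 1, w - k + 1
        # allocate the output with the same dtype and device as the input
        out = x.new_zeros((out_h, out_w))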

        for i in range(out_h):
            for j in range(out_w):
                patch = x[i:i+k, j:j+k]
                out[i, j] = torch.sum(patch * self.weight) + self.bias

        return out
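To see the weight nudging from earlier in action, here is a small training sketch. It uses the built-in `nn.Conv2d` rather than the hand-written class so the block stands on its own. The target filter (a vertical-edge detector), the learning rate, the batch size, and the number of steps are all assumptions picked for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# hypothetical target: a fixed vertical-edge filter the layer should learn to imitate
target_kernel = torch.tensor([[1., 0., -1.],
                              [1., 0., -1.],
                              [1., 0., -1.]])

conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)
optimizer = torch.optim.SGD(conv.parameters(), lr=0.05)

losses = []
for step in range(200):
    x = torch.randn(8, 1, 5, 5)  # random batch of 5x5 single-channel "images"
    with torch.no_grad():
        y = F.conv2d(x, target_kernel[None, None])  # response of the target filter
    loss = F.mse_loss(conv(x), y)  # how far the learned filter's response is from it
    optimizer.zero_grad()
    loss.backward()   # gradient of the loss w.r.t. the kernel weights and bias
    optimizer.step()  # nudge the weights in the direction that lowers the loss
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.6f}")
print(conv.weight.detach()[0, 0])  # should end up near target_kernel
```

Nothing here tells the filter what an edge is. The weights only move because the gradient of the loss says which direction reduces the error.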

