# Learning Encoder

The DreamerV3 encoder is a convolutional encoder when the input is pixels. It transforms the high-dimensional image into a lower-dimensional *latent representation* (in this case, a discrete, compressed representation; see https://arxiv.org/abs/2312.01203 regarding discrete vs. continuous representations). This latent representation is what the RSSM uses. For other kinds of input, such as state vectors (e.g., positions and velocities), an MLP encoder is used instead. Robotics tasks often use both: camera input and state vectors.

So, to start, I will focus on the convolutional encoder.

A convolution is basically a learnable filter (its weights). A *filter kernel* slides over the input, and at each spatial location it computes the dot product between the filter weights and the patch of the input it sits over, giving one value for the whole patch. That value is large when the patch resembles the filter, so it basically answers "does this patch look like the filter?" The filter is not optimized on that question directly: the weights are updated by gradient descent on whatever loss sits downstream. If firing on a patch helped reduce the loss, the weights are nudged so that the dot product on such patches gets larger. If it hurt, the weights are nudged the other way so the filter stops firing on them (the dot product gets smaller).
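
To make the "does this patch look like the filter" idea concrete, here is a minimal sketch with a hand-picked (not learned) vertical-edge filter. The filter values and the two patches are illustrative assumptions, chosen only so the arithmetic is easy to follow.

```python
import torch

# Hand-picked 3x3 filter that responds to a vertical edge (dark left, bright right).
kernel = torch.tensor([[-1.0, 0.0, 1.0],
                       [-1.0, 0.0, 1.0],
                       [-1.0, 0.0, 1.0]])

# A patch containing the same kind of edge.
edge_patch = torch.tensor([[0.0, 0.5, 1.0],
                           [0.0, 0.5, 1.0],
                           [0.0, 0.5, 1.0]])

# A patch with no structure (constant brightness).
flat_patch = torch.full((3, 3), 0.5)

# The "dot product" is an element-wise multiply followed by a sum.
edge_response = torch.sum(edge_patch * kernel)
flat_response = torch.sum(flat_patch * kernel)

print(edge_response.item())  # 3.0
print(flat_response.item())  # 0.0
```

The patch that matches the filter gives a large response, and the featureless patch gives zero.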

### Convolution

In [2]:
import torch 

# code up a basic convolution to get the idea across

# example image 
x = torch.arange(25, dtype=torch.float32).reshape(5,5)

# a 3x3 kernel w/ bias
# the kernel and bias are learnable parameters
kernel = torch.randn(3,3)
bias   = torch.randn(1)

# output after the kernel (3x3)
out_h, out_w = x.shape[0] - 3 + 1, x.shape[1] - 3 + 1
out = torch.zeros(out_h, out_w)

# basic convolution (no stride, no padding, single channel--no color)
for i in range(out_h):
    for j in range(out_w):
        patch = x[i:i+3, j:j+3] # i+3 and j+3 to make a square of size 3x3
        out[i, j] = torch.sum(patch * kernel) + bias # dot product + bias
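
As a sanity check, the loop can be compared against PyTorch's built-in `torch.nn.functional.conv2d`. This is a sketch under the assumption that the input is a single image with a single channel, so the batch and channel dimensions are added only to satisfy the built-in's expected shapes. The seed is arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.arange(25, dtype=torch.float32).reshape(5, 5)
kernel = torch.randn(3, 3)
bias = torch.randn(1)

# Hand-written sliding window (no stride, no padding, single channel).
out_h, out_w = x.shape[0] - 3 + 1, x.shape[1] - 3 + 1
manual = torch.zeros(out_h, out_w)
for i in range(out_h):
    for j in range(out_w):
        manual[i, j] = torch.sum(x[i:i+3, j:j+3] * kernel) + bias

# F.conv2d expects input (N, C_in, H, W) and weight (C_out, C_in, kH, kW),
# so add singleton batch/channel dims and strip them afterwards.
reference = F.conv2d(x[None, None], kernel[None, None], bias=bias)[0, 0]

print(manual.shape)  # torch.Size([3, 3])
print(torch.allclose(manual, reference, atol=1e-3))
```

Note that `F.conv2d` does not flip the kernel (it is technically a cross-correlation), which is why it agrees with the plain sliding dot product.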

In [5]:
# pytorch implementation

import torch.nn as nn

class Conv2d(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(kernel_size, kernel_size))
        self.bias   = nn.Parameter(torch.randn(1))


    def forward(self, x):
        h, w = x.shape
        k = self.weight.shape[0] # k is a hyperparameter (kernel_size)
        out_h, out_w = h - k + 1, w - k + 1
        out = torch.zeros(out_h, out_w)

        for i in range(out_h):
            for j in range(out_w):
                patch = x[i:i+k, j:j+k]
                out[i, j] = torch.sum(patch * self.weight) + self.bias

        return out
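
The reason for wrapping the weights in `nn.Parameter` is so that an optimizer can nudge them. Below is a minimal sketch of that nudging, assuming a toy regression task: a learnable 3x3 kernel is trained to reproduce the output of a fixed target filter. The target filter, learning rate, step count, and random data are all illustrative assumptions, and `F.conv2d` is used instead of the Python loop purely for speed.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical target: a fixed filter the learnable kernel should recover.
target_kernel = torch.tensor([[-1.0, 0.0, 1.0],
                              [-1.0, 0.0, 1.0],
                              [-1.0, 0.0, 1.0]])

# Learnable parameters, analogous to self.weight and self.bias in the module.
weight = torch.nn.Parameter(torch.randn(3, 3))
bias = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([weight, bias], lr=0.05)

x = torch.randn(16, 1, 8, 8)  # batch of random single-channel "images"
with torch.no_grad():
    y = F.conv2d(x, target_kernel[None, None])  # desired outputs

initial_loss = None
for step in range(200):
    pred = F.conv2d(x, weight[None, None], bias=bias)
    loss = F.mse_loss(pred, y)
    if initial_loss is None:
        initial_loss = loss.item()
    optimizer.zero_grad()
    loss.backward()   # gradient of the loss w.r.t. weight and bias
    optimizer.step()  # move the parameters a small step against the gradient

print(initial_loss, loss.item())
print(weight.detach())
```

After training, the learned weights should be close to the target filter. Nobody told the kernel what pattern to look for; it was shaped entirely by the loss.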



### Convolution with stride and padding

In [None]:
import torch 

x = torch.arange(25, dtype=torch.float32).reshape(5,5)
kernel = torch.randn(3,3)
bias   = torch.randn(1)


padding = 1

if padding > 0:
    x = torch.nn.functional.pad(x, (padding, padding, padding, padding)) # use pad to add zeros around the border


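# using stride 
stride = 2 # how far the kernel shifts after computing the dot product for one position
# x is already padded here, so the output size is (padded_size - kernel_size) // stride + 1
# the parentheses matter: without them, // binds tighter than - and the size comes out too large
out_h = (x.shape[0] - kernel.shape[0]) // stride + 1
out_w = (x.shape[1] - kernel.shape[1]) // stride + 1
out = torch.zeros(out_h, out_w)

# convolution with stride and padding (single channel--no color)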
for i in range(out_h):
    for j in range(out_w):
        patch = x[i*stride:i*stride+kernel.shape[0], j*stride:j*stride+kernel.shape[1]]
        out[i, j] = torch.sum(patch * kernel) + bias 
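
The output size along one axis follows the usual rule floor((size + 2 * padding - kernel_size) / stride) + 1, where `size` is the unpadded input length. Here is a small sketch of that rule, cross-checked against `torch.nn.functional.conv2d`. The helper name `conv_out_size` is my own and not part of PyTorch.

```python
import torch
import torch.nn.functional as F

def conv_out_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis (hypothetical helper, not a PyTorch API)."""
    return (size + 2 * padding - kernel_size) // stride + 1

# 5x5 input, 3x3 kernel, padding 1, stride 2: (5 + 2 - 3) // 2 + 1 = 3
print(conv_out_size(5, kernel_size=3, stride=2, padding=1))  # 3

# Cross-check against the built-in convolution.
x = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)
weight = torch.randn(1, 1, 3, 3)
out = F.conv2d(x, weight, stride=2, padding=1)
print(out.shape)  # torch.Size([1, 1, 3, 3])

# A 4x4 kernel with stride 2 and padding 1 halves the size exactly.
print(conv_out_size(64, kernel_size=4, stride=2, padding=1))  # 32
```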

In [None]:
# pytorch implementation

import torch.nn as nn

class Conv2d(nn.Module):
    def __init__(self, kernel_size, stride=1, padding=0):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(kernel_size, kernel_size))
        self.bias   = nn.Parameter(torch.randn(1))

        self.padding = padding
        self.stride = stride

    def forward(self, x):

        if self.padding > 0:
            x = torch.nn.functional.pad(x, (self.padding, self.padding, self.padding, self.padding))
        
        h, w = x.shape
        k = self.weight.shape[0] # k is a hyperparameter (kernel_size)
        # h and w are the padded sizes; parenthesize so the subtraction happens before the division
        out_h = (h - k) // self.stride + 1
        out_w = (w - k) // self.stride + 1
        out = torch.zeros(out_h, out_w)

        for i in range(out_h):
            for j in range(out_w):
                patch = x[i*self.stride:i*self.stride+k, j*self.stride:j*self.stride+k]
                out[i, j] = torch.sum(patch * self.weight) + self.bias

        return out

### Conv Encoder for DreamerV3

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvEncoder(nn.Module):

    def __init__(self, in_channels: int = 3, model_dim: int = 1024):
        super().__init__()

        chs = max(16, model_dim // 16)

        # want to increase the feature channels (dim=1) as we compress the image (dims 2 and 3)
        # the input has shape (B, in_channels, H, W)
        # each layer doubles the channels so they can carry more semantic information,
        # while stride 2 halves H and W, so each output position summarizes a larger part of the image
        # i.e., if we compress the image enough, the weights have to learn 'bigger picture' features
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels,   chs,   kernel_size=4, stride=2, padding=1), nn.SiLU(), 
            nn.Conv2d(chs,           chs*2, kernel_size=4, stride=2, padding=1), nn.SiLU(),     
            nn.Conv2d(chs*2,         chs*4, kernel_size=4, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(chs*4,         chs*8, kernel_size=4, stride=2, padding=1), nn.SiLU()
        )
        # the linear layer expects the flattened conv output, i.e., all dimensions except the batch multiplied together
        # assumes that the final feature map has H, W = 4, 4
        self.proj = nn.Linear(8 * chs * 4 * 4, model_dim)

    def forward(self, x): # expects a 64 x 64 img
        h = self.conv(x)
        if h.shape[-1] != 4 or h.shape[-2] != 4:
            h = F.adaptive_avg_pool2d(h, (4,4))
        # we need to flatten before putting it into the linear layer
        h = h.flatten(1) # flatten from dim=1 to dim=-1
        logits = self.proj(h)
        return logits



In [5]:
conv = ConvEncoder()
x = torch.randn(3, 3, 64, 64)
conv(x).shape

torch.Size([3, 1024])
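
To see where the `8 * chs * 4 * 4` input size of `self.proj` comes from, here is a sketch that traces the tensor shape through the four conv layers for a 64 x 64 input. It rebuilds the same layers with the default `chs = 64`; the SiLU activations are left out because they do not change the shape.

```python
import torch
import torch.nn as nn

chs = 64  # max(16, 1024 // 16), i.e. the default model_dim of 1024

# Same conv layers as in ConvEncoder, without the activations.
layers = [
    nn.Conv2d(3,       chs,     kernel_size=4, stride=2, padding=1),
    nn.Conv2d(chs,     chs * 2, kernel_size=4, stride=2, padding=1),
    nn.Conv2d(chs * 2, chs * 4, kernel_size=4, stride=2, padding=1),
    nn.Conv2d(chs * 4, chs * 8, kernel_size=4, stride=2, padding=1),
]

x = torch.randn(1, 3, 64, 64)
shapes = [tuple(x.shape)]
with torch.no_grad():
    for layer in layers:
        x = layer(x)
        shapes.append(tuple(x.shape))

for s in shapes:
    print(s)
# (1, 3, 64, 64)
# (1, 64, 32, 32)
# (1, 128, 16, 16)
# (1, 256, 8, 8)
# (1, 512, 4, 4)

print(x.flatten(1).shape[1])  # 8192 = 8 * chs * 4 * 4
```

Each layer halves H and W and doubles the channels, so only a 64 x 64 input lands exactly on 4 x 4; other input sizes rely on the `adaptive_avg_pool2d` fallback in `forward`.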