# Task 1

Points: 20

You will be working with RGBNet: a network that takes a pixel's position as input and outputs the R, G, B values of that pixel.
RGBNet is trained on a fixed image. Your tasks are:

1. (14 points) Fill in the gaps in the code that creates embeddings in two ways:
    - Learned embedding of size 64 (7 points)
    - Positional embedding of size 64 (7 points)


Please note that your code should train within 1 minute and report a training loss below 15 in each case.
2. (6 points) Visualize the output of the network for each encoding. Does it resemble the input image?

In [None]:
import math
import urllib
from typing import Literal

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Source image that RGBNet is trained to reproduce; it is downloaded over the
# network (urlopen) when the notebook runs.
IMG_URL = "https://i.natgeofe.com/k/8fa25ea4-6409-47fb-b3cc-4af8e0dc9616/red-eyed-tree-frog-on-leaves-3-2.jpg"

In [None]:
# Download the reference image, decode it as a 3-channel RGB array and shrink it
# so that fitting every pixel stays fast.
# `import urllib` alone does not guarantee the `request` submodule is loaded.
import urllib.request

with urllib.request.urlopen(IMG_URL) as url_response:
    raw_bytes = url_response.read()

# IMREAD_COLOR always yields a 3-channel BGR image (alpha dropped, grayscale
# expanded), so the BGR->RGB conversion below and the 3-output network match.
img = cv2.imdecode(np.frombuffer(raw_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
if img is None:
    raise ValueError(f"Could not decode the image downloaded from {IMG_URL}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# INTER_AREA is the recommended interpolation for strong downscaling (less aliasing).
img = cv2.resize(img, (0, 0), fx=0.01, fy=0.01, interpolation=cv2.INTER_AREA)
# NOTE: despite the names, im_w is the number of rows (img.shape[0]) and im_h the
# number of columns (img.shape[1]); train() builds its coordinate grid in this order.
im_w, im_h = img.shape[0], img.shape[1]

plt.imshow(img)
plt.axis('off')
plt.show()

In [None]:
class NaiveEncoding(nn.Module):
    """Baseline encoding: one affine map from the raw 2-D coordinate to 64 features.

    The integer coordinates are only cast to float; no normalisation or
    non-linearity is applied before the linear layer.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features=2, out_features=64)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project coordinates ``x`` (last dim 2) to a 64-wide embedding."""
        coords = x.to(dtype=torch.float32)
        return self.linear(coords)


class LearnedEncoding(nn.Module):
    """Learned (lookup-table) embedding of integer pixel coordinates.

    Each of the two coordinates has its own trainable ``nn.Embedding`` table of
    width ``dim // 2``. The two looked-up vectors are concatenated, so the output
    width is ``dim`` (64 by default, the input width RGBNet's ``fc1`` expects).

    Args:
        max_positions: Number of rows in each table. Every coordinate value must
            lie in ``[0, max_positions)``; ``nn.Embedding`` raises on indices
            outside that range, so this must exceed both image dimensions.
        dim: Output embedding width; must be even.

    Raises:
        ValueError: If ``dim`` is not even.
    """

    def __init__(self, max_positions: int = 1024, dim: int = 64) -> None:
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"dim must be even, got {dim}")
        # Half of the output width per axis so that the concatenation is `dim` wide.
        self.row_embedding = nn.Embedding(max_positions, dim // 2)
        self.col_embedding = nn.Embedding(max_positions, dim // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed coordinate pairs.

        Args:
            x: Integer tensor of shape ``(..., 2)``; ``x[..., 0]`` indexes the
                first table and ``x[..., 1]`` the second.

        Returns:
            Float tensor of shape ``(..., dim)``.
        """
        # Embedding lookups require integer (long) indices.
        idx = x.long()
        return torch.cat(
            [self.row_embedding(idx[..., 0]), self.col_embedding(idx[..., 1])],
            dim=-1,
        )


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal (Fourier-feature) encoding of pixel coordinates.

    Each of the two coordinates ``p`` is mapped to
    ``[sin(p*w_0), ..., sin(p*w_{F-1}), cos(p*w_0), ..., cos(p*w_{F-1})]``
    with ``F = dim // 4`` geometrically spaced frequencies
    ``w_k = max_period ** (-k / F)`` (the Transformer scheme). The two per-axis
    vectors are concatenated, giving ``dim`` features (64 by default). The
    highest frequency is 1 rad/pixel, below the pi rad/pixel Nyquist limit for
    integer coordinates. There are no trainable parameters.

    Args:
        dim: Output width; must be a positive multiple of 4.
        max_period: Base of the geometric frequency progression. Larger values
            add lower frequencies (longer spatial periods).

    Raises:
        ValueError: If ``dim`` is not a positive multiple of 4.
    """

    def __init__(self, dim: int = 64, max_period: float = 10000.0) -> None:
        super().__init__()
        if dim <= 0 or dim % 4 != 0:
            raise ValueError(f"dim must be a positive multiple of 4, got {dim}")
        num_freqs = dim // 4
        exponents = torch.arange(num_freqs, dtype=torch.float32) / num_freqs
        freqs = torch.exp(-math.log(max_period) * exponents)
        # A buffer (not a Parameter): it follows .to(device) but is never trained.
        # persistent=False keeps it out of state_dict, since the constructor
        # arguments fully determine it.
        self.register_buffer("freqs", freqs, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode coordinates.

        Args:
            x: Tensor of shape ``(..., 2)`` holding coordinate pairs (any numeric dtype).

        Returns:
            Float tensor of shape ``(..., dim)``.
        """
        angles = x.to(self.freqs.dtype).unsqueeze(-1) * self.freqs  # (..., 2, F)
        features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (..., 2, 2F)
        return features.flatten(start_dim=-2)  # (..., 4F) == (..., dim)


# Define the network
class RGBNet(nn.Module):
    """Coordinate MLP: pixel position -> 64-d encoding -> 128 -> 128 -> 3 outputs.

    Hidden layers use softplus. The last layer is purely linear, so the
    predicted R, G, B values are unbounded.

    Args:
        encoding_type: Name of the encoding module placed in front of the MLP.

    Raises:
        ValueError: If ``encoding_type`` is not a supported name.
    """

    def __init__(self, encoding_type: Literal["naive", "learned", "positional"]) -> None:
        super().__init__()
        # Only the selected encoder class is instantiated.
        encoders = (
            ("naive", NaiveEncoding),
            ("learned", LearnedEncoding),
            ("positional", PositionalEncoding),
        )
        chosen = next((cls for name, cls in encoders if encoding_type == name), None)
        if chosen is None:
            raise ValueError("Wrong encoding type!")
        self.encoding = chosen()
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map pixel coordinates ``x`` to predicted R, G, B values (last dim 3)."""
        hidden = F.softplus(self.fc1(self.encoding(x)))
        hidden = F.softplus(self.fc2(hidden))
        return self.fc3(hidden)


In [None]:
def train(
    used_embedding: Literal["naive", "learned", "positional"],
    num_epochs: int = 300,
    batch_size: int = 32,
    lr: float = 0.01,
) -> torch.nn.Module:
    """Fit an RGBNet to the global image ``img`` and return the trained model.

    Every pixel is one training sample. The input is its ``(row, col)``
    coordinate over ``range(im_w) x range(im_h)``, and the target is its
    float32 R, G, B value. Samples are reshuffled every epoch and optimised with
    AdamW on the MSE loss.

    Args:
        used_embedding: Encoding type forwarded to ``RGBNet``.
        num_epochs: Number of passes over all pixels.
        batch_size: Pixels per optimisation step.
        lr: AdamW learning rate.

    Returns:
        The trained model, on CUDA if available, otherwise on CPU.

    Note:
        The printed per-epoch value is the sum of per-batch *mean* losses
        divided by the number of pixels. That is roughly the per-pixel MSE
        divided by ``batch_size``, not the MSE itself. It is kept as in the
        original template. NOTE(review): the task's "loss below 15" target
        presumably refers to this printed number; confirm before changing it.
    """
    # Instantiate the model and move it to the GPU (if available).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RGBNet(encoding_type=used_embedding).to(device)

    criterion = nn.MSELoss(reduction="mean")
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    # One coordinate per pixel. cartesian_prod enumerates (row, col) in the same
    # row-major order in which flatten() lays out the pixel colours, so X[k] and
    # y[k] describe the same pixel.
    X = torch.cartesian_prod(torch.arange(im_w), torch.arange(im_h)).to(device)
    y = torch.flatten(torch.tensor(img, dtype=torch.float32), start_dim=0, end_dim=1).to(device)
    num_pixels = X.size(0)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        # Fresh random order each epoch, generated on the same device as the data.
        perm = torch.randperm(num_pixels, device=device)
        for start in range(0, num_pixels, batch_size):
            idx = perm[start:start + batch_size]
            X_batch, y_batch = X[idx], y[idx]

            optimizer.zero_grad()
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/num_pixels}')
    return model

In [None]:
def visualize_model_output(model: RGBNet) -> None:
    """Show the image predicted by ``model`` next to the original global ``img``.

    The model is queried at every ``(row, col)`` coordinate of ``img``, using the
    same row-major grid that ``train`` builds. Predictions are clamped to the
    0-255 range and shown as uint8 beside the original, so the two can be
    compared by eye.

    Args:
        model: A trained network mapping coordinates to 3 colour values.
    """
    device = next(model.parameters()).device
    rows, cols = img.shape[0], img.shape[1]
    coords = torch.cartesian_prod(torch.arange(rows), torch.arange(cols)).to(device)

    # Switch to eval mode for inference and restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    with torch.no_grad():
        pred = model(coords)
    model.train(was_training)

    # The regression head is unbounded; clamp into the uint8 range imshow expects.
    pred_img = (
        pred.clamp(0, 255).round().to(torch.uint8).reshape(rows, cols, 3).cpu().numpy()
    )

    encoding_name = type(getattr(model, "encoding", model)).__name__
    _, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(10, 4))
    ax_true.imshow(img)
    ax_true.set_title("Original")
    ax_pred.imshow(pred_img)
    ax_pred.set_title(f"Predicted ({encoding_name})")
    for ax in (ax_true, ax_pred):
        ax.axis("off")
    plt.show()

In [None]:
# IMPORTANT:
# the provided training code works only for used_embedding = "naive";
# once the encodings are implemented, both training and visualization
# must also work for used_embedding = "learned" and used_embedding = "positional".
used_embedding = "naive"

# Seed PyTorch's RNG, which drives weight initialisation and per-epoch shuffling.
torch.manual_seed(0)
model = train(used_embedding)
visualize_model_output(model)