# DND Exam 1

## Name:
## Neptun:

In [2]:
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
image = transforms.Compose([
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
])(Image.new('RGB', (224, 224), color=(0, 0, 0)))

# Architecture

Your task is to implement the following architecture!

![ARCHITECTURE.png](./model.png)

In [31]:
# Custom layers for image reconstruction
class CustomNet(nn.Module):
    def __init__(self):
        super(CustomNet, self).__init__()

        # Load a pre-trained ResNet18 model
        resnet18 = models.resnet18(pretrained=True)

        # Remove the last two layers from ResNet18 to use it as a feature extractor
        modules = list(resnet18.children())[:-2]
        self.resnet18 = nn.Sequential(*modules)
        
        # Defining several transposed convolution layers for upsampling the features
        self.upper_deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upper_deconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upper_deconv3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upper_deconv4 = nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2)

        # Max pooling layer for downscaling
        self.upper_mp5 = nn.MaxPool2d(28)

        # Additional convolution and pooling layers for further feature manipulation
        self.lower_conv1 = nn.Conv2d(512, 256, kernel_size=1, stride=2)
        self.lower_mp2 = nn.MaxPool2d(2, stride=2)
        self.lower_deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=4)
        self.lower_bn4 = nn.BatchNorm2d(128)
        self.lower_mp5 = nn.MaxPool2d(2, stride=2)
        self.lower_conv6 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)

        self.lower_conv7 = nn.Conv2d(512, 3, kernel_size=3, stride=1, padding=1)

        # Activation functions
        self.activation = nn.Sigmoid()

        # LSTM layer for sequence processing
        self.lstm = nn.LSTM(input_size=6, hidden_size=10, batch_first=True)

    def forward(self, x):
        # Forward pass through the modified ResNet18
        resnet18 = self.resnet18(x)                             # [1, 512, 7, 7]
        print(f'resnet18: {resnet18.shape}')
        
        # Forward pass through the upper branch
        upper_deconv1 = self.upper_deconv1(resnet18)            # [1, 256, 14, 14]
        print(f'upper_deconv1: {upper_deconv1.shape}')
        upper_deconv1 = nn.functional.relu(upper_deconv1)       
        print(f'upper_deconv1 relu: {upper_deconv1.shape}')
        upper_deconv2 = self.upper_deconv2(upper_deconv1)       # [1, 128, 28, 28]
        print(f'upper_deconv2: {upper_deconv2.shape}')
        upper_deconv2 = nn.functional.relu(upper_deconv2)
        print(f'upper_deconv2 relu: {upper_deconv2.shape}')
        upper_deconv3 = self.upper_deconv3(upper_deconv2)       # [1, 64, 56, 56]
        print(f'upper_deconv3: {upper_deconv3.shape}')
        upper_deconv3 = nn.functional.relu(upper_deconv3)
        print(f'upper_deconv3 relu: {upper_deconv3.shape}')
        upper_deconv4 = self.upper_deconv4(upper_deconv3)       # [1, 3, 112, 112]
        print(f'upper_deconv4: {upper_deconv4.shape}')
        upper_mp5 = self.upper_mp5(upper_deconv4)               # [1, 3, 4, 4]
        print(f'mp5: {upper_mp5.shape}')

        # Forward pass through the lower branch
        lower_conv1 = self.lower_conv1(resnet18)                # [1, 256, 4, 4]
        print(f'lower_conv1: {lower_conv1.shape}')
        lower_mp2 = self.lower_mp2(lower_conv1)                 # [1, 256, 2, 2]
        print(f'lower_mp2: {lower_mp2.shape}')
        lower_deconv3 = self.lower_deconv3(lower_mp2)           # [1, 128, 8, 8]
        print(f'lower_deconv3: {lower_deconv3.shape}')
        lower_bn4 = self.lower_bn4(lower_deconv3)               # [1, 128, 8, 8]
        print(f'lower_bn4: {lower_bn4.shape}')
        lower_mp5 = self.lower_mp5(lower_bn4)                   # [1, 128, 4, 4]
        print(f'lower_mp5: {lower_mp5.shape}')
        lower_conv6 = self.lower_conv6(lower_mp5)               # [1, 256, 4, 4]
        print(f'lower_conv6: {lower_conv6.shape}')


        # Concatenating the features
        lower_concat = torch.cat((lower_conv1, lower_conv6), dim=1) # [1, 512, 4, 4]
        print(f'lower_concat: {lower_concat.shape}')

        lower_conv7 = self.lower_conv7(lower_concat)            # [1, 3, 4, 4]
        print(f'lower_conv7: {lower_conv7.shape}')


        # Concatenating the features
        concat = torch.cat((upper_mp5, lower_conv7), dim=1)    # [1, 6, 4, 4]
        print(f'concat: {concat.shape}')


        # Reshaping for the LSTM input and processing through LSTM
        lstm_input = concat.reshape(1, 16, -1)                  # [1, 16, 6]
        print(f'lstm_input: {lstm_input.shape}')

        lstm_output, _ = self.lstm(lstm_input)                  # [1, 16, 10]
        print(f'lstm_output: {lstm_output.shape}')

        # Applying sigmoid activation function
        output = self.activation(lstm_output)
        
        return output

# Initialize the model
model = CustomNet()

# Add a batch dimension to the image if it doesn't have one
if len(image.shape) == 3:
    image = image.unsqueeze(0)  # Add batch dimension

# Move the model and image to the same device (like CPU or GPU)
model = model.to(device)
image = image.to(device)

# Set the model to evaluation mode
model.eval()

# Perform inference without calculating gradients
with torch.no_grad():
    output = model(image)

# Print the shape of the output
output.shape


resnet18: torch.Size([1, 512, 7, 7])
upper_deconv1: torch.Size([1, 256, 14, 14])
upper_deconv1 relu: torch.Size([1, 256, 14, 14])
upper_deconv2: torch.Size([1, 128, 28, 28])
upper_deconv2 relu: torch.Size([1, 128, 28, 28])
upper_deconv3: torch.Size([1, 64, 56, 56])
upper_deconv3 relu: torch.Size([1, 64, 56, 56])
upper_deconv4: torch.Size([1, 3, 112, 112])
mp5: torch.Size([1, 3, 4, 4])
lower_conv1: torch.Size([1, 256, 4, 4])
lower_mp2: torch.Size([1, 256, 2, 2])
lower_deconv3: torch.Size([1, 128, 8, 8])
lower_bn4: torch.Size([1, 128, 8, 8])
lower_mp5: torch.Size([1, 128, 4, 4])
lower_conv6: torch.Size([1, 256, 4, 4])
lower_concat: torch.Size([1, 512, 4, 4])
lower_conv7: torch.Size([1, 3, 4, 4])
concat: torch.Size([1, 6, 4, 4])
lstm_input: torch.Size([1, 16, 6])
lstm_output: torch.Size([1, 16, 10])
output: torch.Size([1, 16, 10])


torch.Size([1, 16, 10])

Expected output: [1,16,10]

In [None]:
list(output.shape) == [1,16,10]

False