In [21]:
import numpy as np
import torch
import torch.nn as nn
import cv2
import os
import matplotlib.pyplot as plt

In [20]:
# Dataset layout (presumably the KITTI semantics benchmark: training/image_2
# holds RGB frames, training/semantic holds the label masks — confirm).
DATA_PATH = "../../data/"
SEMANTICS_PATH = os.path.join(DATA_PATH, "data_semantics")
TRAIN_PATH = os.path.join(SEMANTICS_PATH, "training")
TRAIN_RGB_PATH = os.path.join(TRAIN_PATH, "image_2")
TRAIN_SEMANTIC_PATH = os.path.join(TRAIN_PATH, "semantic")

# os.listdir returns entries in arbitrary, filesystem-dependent order. Sort
# both listings so rgbs[i] and semantics[i] refer to the same sample
# (assumes an image and its mask share a filename — verify on the data).
rgbs = sorted(os.listdir(TRAIN_RGB_PATH))
semantics = sorted(os.listdir(TRAIN_SEMANTIC_PATH))

assert len(rgbs) == len(semantics), (
    f"{len(rgbs)} RGB images vs {len(semantics)} semantic masks"
)

In [33]:
class HorizontalBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and LeakyReLU.

    Changes the channel count at a fixed resolution level (the "double conv"
    stage of a U-Net-style network).

    Args:
        input_dim: number of channels in the input tensor.
        output_dim: number of channels produced by both convolutions.
        padding: zero-padding applied by each convolution. The default 0
            (valid convolutions) shrinks H and W by 4 in total; pass 1 to
            keep the spatial size unchanged.
    """

    def __init__(self, input_dim, output_dim, padding=0):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=padding)
        # Each convolution gets its own BatchNorm. Sharing a single layer would
        # mix the running statistics and affine parameters of two different
        # feature distributions.
        self.bn1 = nn.BatchNorm2d(output_dim)
        self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=padding)
        self.bn2 = nn.BatchNorm2d(output_dim)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        """Map (N, input_dim, H, W) -> (N, output_dim, H - 4 + 4*padding, W - 4 + 4*padding)."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x

class DownBlock(nn.Module):
    """Downsampling stage: halves H and W with a 2x2 max-pool (stride 2).

    Odd spatial sizes are floored (e.g. 5 -> 2); the channel count is unchanged.
    """

    def __init__(self):
        super(DownBlock, self).__init__()
        pool = nn.MaxPool2d(kernel_size=2)  # stride defaults to kernel_size
        self.max_pool = pool

    def forward(self, x):
        pooled = self.max_pool(x)
        return pooled
    
class UpBlock(nn.Module):
    """Upsampling stage: a 2x2, stride-2 transposed convolution.

    Doubles H and W exactly (out = 2 * in) and maps ``input_dim`` channels to
    ``output_dim`` channels.

    Args:
        input_dim: number of channels in the input tensor.
        output_dim: number of output channels. Defaults to ``input_dim // 2``,
            the usual halving when moving up one resolution level.

    Raises:
        ValueError: if the resulting output channel count is < 1 (e.g.
            ``input_dim=1`` with the default). Without this check the block
            would be a degenerate zero-channel layer.
    """

    def __init__(self, input_dim, output_dim=None):
        super().__init__()
        if output_dim is None:
            output_dim = input_dim // 2
        if output_dim < 1:
            raise ValueError(
                f"UpBlock needs at least 1 output channel, got {output_dim} "
                f"(input_dim={input_dim})"
            )
        self.tconv = nn.ConvTranspose2d(input_dim, output_dim, kernel_size=2, stride=2)

    def forward(self, x):
        return self.tconv(x)
    
# Smoke-test the building blocks and pin the spatial arithmetic:
#   HorizontalBlock(3, 6): two unpadded 3x3 convs -> H, W shrink by 4
#   DownBlock:             2x2 max-pool           -> H, W halve
#   UpBlock(6):            2x2 stride-2 tconv     -> H, W double, channels halve
torch.manual_seed(0)  # reproducible input and weight init

with torch.no_grad():  # shape check only; no autograd graph needed
    x = torch.randn(1, 3, 360, 360)
    hb = HorizontalBlock(3, 6)
    y = hb(x)
    print(x.shape)
    print(y.shape)
    assert y.shape == (1, 6, 356, 356)

    db = DownBlock()
    y1 = db(y)
    print(y1.shape)
    assert y1.shape == (1, 6, 178, 178)

    ub = UpBlock(6)
    y2 = ub(y1)
    print(y2.shape)
    assert y2.shape == (1, 3, 356, 356)

    # A separate block for the decoder path. Reusing `hb` would share weights
    # and BatchNorm statistics between encoder and decoder.
    hb_up = HorizontalBlock(3, 6)
    y3 = hb_up(y2)
    print(y3.shape)
    assert y3.shape == (1, 6, 352, 352)

torch.Size([1, 3, 360, 360])
torch.Size([1, 6, 356, 356])
torch.Size([1, 6, 178, 178])
torch.Size([1, 3, 356, 356])
torch.Size([1, 6, 352, 352])
