# In [None]:
import torch as th 
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim 

import random 
import os 
import cv2 
import numpy as np 

# In [None]:
class ConvBlock(nn.Module):
    """Downsampling stack of ``n_layers`` conv -> batch-norm -> LeakyReLU stages.

    The first convolution maps ``int(filters // growth_factor)`` input channels
    to ``filters`` channels and halves the spatial resolution. It does this
    either as a stride-2 convolution (``stride=True``) or with a 2x2 max pool
    applied before it (``stride=False``). The remaining ``n_layers - 1``
    convolutions map ``filters`` -> ``filters`` channels with stride 1.

    Args:
        n_layers: number of conv/norm/activation stages; must be >= 1.
        filters: output channel count of every convolution in the block.
        kernel: convolution kernel size; padding is ``(kernel - 1) // 2``.
        growth_factor: ratio of output to input channels of the block.
        moment: ``momentum`` passed to each ``BatchNorm2d``.
        stride: downsample with a stride-2 conv if True, else with max-pool.
        alpha: negative slope of the LeakyReLU.

    Raises:
        ValueError: if ``n_layers < 1``.
    """

    def __init__(self, n_layers: int, filters: int, kernel: int = 3, growth_factor: float = 2.0,
                 moment: float = 0.7, stride: bool = True, alpha: float = 0.03):
        super().__init__()
        if n_layers < 1:
            # With 0 layers, zip() in forward would silently skip the only
            # (channel-changing) convolution.
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        padding = (kernel - 1) // 2
        self.stride = stride
        # Conv2d needs an int channel count; ``//`` with a float growth_factor
        # returns a float (e.g. 128 // 2.0 == 64.0).
        in_filters = int(filters // growth_factor)

        # Conv outputs are 4-D (N, C, H, W), so BatchNorm2d is required;
        # BatchNorm1d rejects 4-D input.
        self.norm = nn.ModuleList(
            [nn.BatchNorm2d(num_features=filters, momentum=moment) for _ in range(n_layers)]
        )

        if stride:
            first = nn.Conv2d(in_filters, filters, kernel, stride=2, padding=padding)
        else:
            first = nn.Conv2d(in_filters, filters, kernel, padding=padding)

        # Build the list in execution order (channel-changing conv first).
        # nn.ModuleList has no reverse(), so append-then-reverse cannot work.
        self.conv = nn.ModuleList(
            [first] + [nn.Conv2d(filters, filters, kernel, padding=padding) for _ in range(n_layers - 1)]
        )
        self.nlin = nn.LeakyReLU(alpha)

        if not stride:
            self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Run the block.

        Args:
            x: 4-D tensor with ``int(filters // growth_factor)`` channels
                (BatchNorm2d requires 4-D input).

        Returns:
            Tensor with ``filters`` channels and roughly half the spatial size.
        """
        if not self.stride:
            x = self.pool(x)

        for conv, norm in zip(self.conv, self.norm):
            x = self.nlin(norm(conv(x)))
        return x

        

class OneShot(nn.Module):
    """Convolutional encoder mapping a 3-channel image batch to ``final``-dim vectors.

    Architecture:

    1. Stem: a stride-2 ``Conv2d(3, filters, start_kernel)``.
    2. ``n_blocks`` ``ConvBlock``s. Each multiplies the channel count by
       ``growth_factor`` and halves the spatial size.
    3. Global average pooling and flatten.
    4. ``Linear -> LeakyReLU -> Dropout -> Linear``.

    Args:
        n_blocks: number of ``ConvBlock``s after the stem.
        n_high_refine: blocks with index ``<= n_high_refine`` use
            ``n_conv_high_refine`` conv layers; later blocks use ``n_conv_end``.
        n_conv_high_refine: conv layers per early block.
        n_conv_end: conv layers per later block.
        filters: output channels of the stem convolution.
        start_kernel: kernel size of the stem convolution.
        kernel: kernel size inside each ``ConvBlock``.
        growth_factor: per-block channel multiplier.
        alpha: LeakyReLU negative slope (blocks and head).
        moment: batch-norm momentum inside the blocks.
        dense: width of the hidden dense layer.
        final: size of the output vector.
        drop: dropout probability after the hidden dense layer.
        stride: passed to ``ConvBlock`` (stride-2 conv vs. max-pool).
    """

    def __init__(self, n_blocks: int = 7, n_high_refine: int = 3, n_conv_high_refine: int = 3,
                 n_conv_end: int = 2, filters: int = 64, start_kernel: int = 5, kernel: int = 3,
                 growth_factor: float = 2.0, alpha: float = 0.07, moment: float = 0.7, dense: int = 512,
                 final: int = 100, drop: float = 0.2, stride: bool = True):
        super().__init__()

        self.conv = nn.ModuleList(
            [nn.Conv2d(3, filters, start_kernel, stride=2, padding=(start_kernel - 1) // 2)]
        )
        for c in range(n_blocks):
            filters = int(filters * growth_factor)
            # NOTE(review): ``<=`` yields n_high_refine + 1 high-refine blocks
            # (indices 0..n_high_refine); confirm whether ``<`` was intended.
            n_conv = n_conv_high_refine if c <= n_high_refine else n_conv_end
            self.conv.append(ConvBlock(n_conv, filters, kernel, growth_factor, moment, stride, alpha))

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        # ``filters`` already holds the last block's output channels (it was
        # grown inside the loop); growing it again would mis-size this layer.
        self.lden = nn.Linear(filters, dense)
        self.lvec = nn.Linear(dense, final)
        self.nlin = nn.LeakyReLU(alpha)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Encode ``x`` into a ``final``-dimensional vector per sample.

        Args:
            x: image batch, assumed (N, 3, H, W). The stem conv has 3 input
                channels; the 4-D layout is presumed from BatchNorm2d in the
                blocks.

        Returns:
            Tensor of shape (N, final).
        """
        for conv in self.conv:
            x = conv(x)

        x = self.flat(self.pool(x))
        # Apply the dense layer to x before the activation.
        x = self.drop(self.nlin(self.lden(x)))
        return self.lvec(x)

        