In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# !pip install torch-optimizer
import torch_optimizer as optim   
import einops
from einops import rearrange

In [2]:
class Linear_block(nn.Module):
    """One fully connected layer followed by SELU activation and dropout.

    Args:
        in_feature: size of the last dimension of the input.
        out_feature: size of the last dimension of the output.
        dropout: probability handed to ``nn.Dropout``.

    Note:
        The layers live in ``self.block`` (an ``nn.Sequential``) so the
        parameter names are ``block.0.weight`` / ``block.0.bias``.
        The SELU paper and PyTorch docs pair SELU with ``nn.AlphaDropout``;
        plain ``nn.Dropout`` is what this block uses.
    """

    def __init__(self, in_feature, out_feature, dropout):
        super().__init__()
        layers = [
            nn.Linear(in_feature, out_feature),  # affine projection
            nn.SELU(),                           # nonlinearity
            nn.Dropout(dropout),                 # active only in train() mode
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``Dropout(SELU(Linear(x)))``."""
        return self.block(x)


In [None]:
class FCnet(nn.Module):
    """Fully connected network mapping ``2 * N * N`` input values to an N x N map.

    Each sample is flattened from dim 1 onward, so it must contain exactly
    ``2 * num_pixels**2`` values. It is presumably two N x N channels
    (TODO confirm against the data pipeline). The output has shape
    ``(batch, num_pixels, num_pixels)``, with values in (0, 1) from the
    final sigmoid.

    Args:
        num_pixels: side length N of the square output map.
        dropout: dropout probability used by every hidden ``Linear_block``.
            Defaults to 0.2, the value previously hard-coded.
    """

    def __init__(self, num_pixels, dropout=0.2):
        super().__init__()

        self.N = num_pixels

        H = num_pixels * num_pixels  # number of elements in one N x N map

        # Widen 2H -> 4H, keep 4H, then narrow 4H -> 2H.
        self.block1 = Linear_block(in_feature=2 * H, out_feature=4 * H, dropout=dropout)
        self.block2 = Linear_block(in_feature=4 * H, out_feature=4 * H, dropout=dropout)
        self.block3 = Linear_block(in_feature=4 * H, out_feature=2 * H, dropout=dropout)

        # Final projection to one value per output pixel, squashed into (0, 1).
        self.last_block = nn.Sequential(
            nn.Linear(2 * H, H),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the network and reshape the result to ``(batch, N, N)``.

        Args:
            x: tensor whose dims after the first flatten to ``2 * N * N``
               features.

        Returns:
            Tensor of shape ``(batch, N, N)`` with sigmoid outputs.
        """
        x = torch.flatten(x, start_dim=1)

        x = self.block1(x)
        # block2 is applied twice with the SAME module, i.e. shared weights.
        # NOTE(review): if two independent 4H -> 4H layers were intended,
        # this is a copy-paste slip. Fixing it changes the parameter count
        # and state_dict keys, so it is left as-is here.
        x = self.block2(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.last_block(x)

        # (b, N*N) -> (b, N, N); N1 is inferred as N because H == N*N.
        return rearrange(x, 'b (N1 N2) -> b N1 N2', N2=self.N)
                           