In [45]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(convolution => [GroupNorm] => LeakyReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(1, out_channels, 1e-3),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(1, out_channels, 1e-3),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            DoubleConv(in_channels, out_channels),
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
    
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2, skip=True):
        x1 = self.up(x1)
        
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        
        if skip:
            x = torch.cat([x2, x1], dim=1)
        else:
            x = x1
        return self.conv(x)
    
class RepeatDownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, repeat=1):
        super().__init__()
        self.block = [Down(in_channels, out_channels)]
        for i in range(0, repeat-1):
            self.block.append(DoubleConv(out_channels, out_channels))
        self.block = nn.ModuleList(self.block)
    
    def forward(self, x):
        for layer in self.block:
            x = layer(x)
        return x
    
class RepeatBlock(nn.Module):
    def __init__(self, in_channels, out_channels, repeat=1):
        super().__init__()
        self.block = []
        for i in range(0, repeat):
            self.block.append(DoubleConv(out_channels, out_channels))
        self.block = nn.ModuleList(self.block)
    
    def forward(self, x):
        for layer in self.block:
            x = layer(x)
        return x

class RepeatUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, repeat=1):
        super().__init__()
        self.block = [Up(in_channels, out_channels)]
        for i in range(0, repeat-1):
            self.block.append(DoubleConv(out_channels, out_channels))
        self.block = nn.ModuleList(self.block)
    
    def forward(self, x1, x2, skip=True):
        x = self.block[0](x1, x2, skip)
        for layer in self.block[1:]:
            x = layer(x)
        return x
    

class UNet(nn.Module):
    def __init__(self, in_channels, d, repeat=4, skip=True):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.d = d
        self.skip = skip
        
        # Encoder path
        self.encoder = []
        for depth in range(0, d):
            in_dim = 2**depth * self.in_channels
            self.encoder.append(RepeatDownBlock(in_dim, 2*in_dim, repeat=repeat))
        self.encoder = nn.ModuleList(self.encoder)
        
        bottleneck_dim = 2*in_dim
        self.bottleneck = RepeatBlock(bottleneck_dim, bottleneck_dim, repeat=2*repeat)

        # Decoder path
        self.decoder = []
        for depth in range(d-1, -1, -1):
            in_dim = 2**(depth+1) * self.in_channels
            in_ch_ = 2*in_dim if skip else in_dim
            self.decoder.append(RepeatUpBlock(in_ch_, in_dim//2, repeat=repeat)) #due to skip connections
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.decoder.append(DoubleConv(self.in_channels, self.in_channels))
        self.decoder = nn.ModuleList(self.decoder)
        
    def forward(self, x):
        enc_outputs = []
        for layer in self.encoder:
            x = layer(x)
            enc_outputs.append(x)
        x = self.bottleneck(x)
        for i, layer in enumerate(self.decoder[:-1]):
            x = layer(x, enc_outputs[-(i+1)], skip=self.skip)
        x = self.upsample(x)
        x = self.decoder[-1](x)
        return x

In [50]:
def count_parameters(model, verbose=True):
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if num_params >= 1e6:
        num_params /= 1e6
        suffix = "M"
    elif num_params >= 1e3:
        num_params /= 1e3
        suffix = "K"
    else:
        suffix = ""
    if verbose:
        print(f"Number of trainable parameters: {num_params:.2f}{suffix}")
    return num_params

In [51]:
model = UNet(in_channels=128, d=2, repeat=2, skip=True)
inputs = torch.randn(10, 128, 70, 70)
out = model(inputs)
print(out.shape)

torch.Size([10, 128, 70, 70])


In [52]:
count_parameters(model)

Number of trainable parameters: 34.68M


34.67904

In [53]:
model = UNet(in_channels=128, d=2, repeat=2, skip=False)
inputs = torch.randn(10, 128, 70, 70)
out = model(inputs)
print(out.shape)

torch.Size([10, 128, 70, 70])


In [54]:
count_parameters(model)

Number of trainable parameters: 33.20M


33.20448

In [131]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VelAutoEncoder(nn.Module):
    def __init__(self, input_channel, encoder_channels, decoder_channels):
        super(VelAutoEncoder, self).__init__()
        
        encoder_layers = nn.ModuleList()
        decoder_layers = nn.ModuleList()
        
        enlayer1 = nn.Sequential(nn.Conv2d(input_channel, encoder_channels[0],
                                                kernel_size=1, stride=1, padding=0, 
                                                padding_mode='reflect'),
                                     nn.BatchNorm2d(encoder_channels[0]),
                                     nn.LeakyReLU(0.2, inplace=True))
        encoder_layers.append(enlayer1)
        
        enlayer2 = nn.Sequential(nn.Conv2d(encoder_channels[0], encoder_channels[1],
                                                kernel_size=1, stride=1, padding=0,
                                                padding_mode='reflect'),
                                     nn.BatchNorm2d(encoder_channels[1]),
                                     nn.LeakyReLU(0.2, inplace=True))
        
        encoder_layers.append(enlayer2)
        
        enlayer3 = nn.Sequential(nn.Conv2d( encoder_channels[1], encoder_channels[2],
                                                kernel_size=(3, 3), stride=1, padding=0,
                                                padding_mode='reflect'),
                                     nn.BatchNorm2d(encoder_channels[2]),
                                     nn.LeakyReLU(0.2, inplace=True))
        
        encoder_layers.append(enlayer3)
        
        enlayer4 = nn.Sequential(nn.Conv2d(encoder_channels[2], encoder_channels[3],
                                                kernel_size=(3, 3), stride=(2,2), padding=(1,1),
                                                padding_mode='reflect'),
                                     nn.BatchNorm2d(encoder_channels[3]),
                                     nn.LeakyReLU(0.2, inplace=True))
        
        encoder_layers.append(enlayer4)
        
        enlayer5 = nn.Sequential(nn.Conv2d(encoder_channels[3], encoder_channels[4],
                                                kernel_size=(3, 3), stride=(2,2), padding=0,
                                                padding_mode='reflect'),
                                     nn.BatchNorm2d(encoder_channels[4]),
                                     nn.LeakyReLU(0.2, inplace=True))
        
        encoder_layers.append(enlayer5)
        
        enlayer5 = nn.Sequential(nn.Conv2d(encoder_channels[4], encoder_channels[5],
                                                kernel_size=(3, 3), stride=1, padding=(1,1),
                                                padding_mode='reflect'),
                                     nn.BatchNorm2d(encoder_channels[5]),
                                     nn.LeakyReLU(0.2, inplace=True))
        
        encoder_layers.append(enlayer5)
        
        enlayer6 = nn.Upsample((16,16), mode='bilinear')
        
        encoder_layers.append(enlayer6)
        
        
        delayer1 = nn.Sequential(nn.ConvTranspose2d(encoder_channels[-1], decoder_channels[0],
                                                kernel_size=(3,3), stride=1, padding=0,padding_mode='zeros'),
                             nn.BatchNorm2d(decoder_channels[0]),
                             nn.Tanh())
        
        decoder_layers.append(delayer1)
        
        delayer2 = nn.Sequential(nn.ConvTranspose2d(decoder_channels[0], decoder_channels[1],
                                                kernel_size=(3,3), stride=(2,2), padding=0,padding_mode='zeros'),
                             nn.BatchNorm2d(decoder_channels[1]),
                             nn.Tanh())
        
        decoder_layers.append(delayer2)
        
        delayer3 = nn.Sequential(nn.ConvTranspose2d(decoder_channels[1], decoder_channels[2],
                                                kernel_size=(3,3), stride=1, padding=0,padding_mode='zeros'),
                             nn.BatchNorm2d(decoder_channels[2]),
                             nn.Tanh())
        
        decoder_layers.append(delayer3)
        
        delayer4 = nn.Sequential(nn.ConvTranspose2d(decoder_channels[2], decoder_channels[3],
                                                kernel_size=1, stride=1, padding=0,padding_mode='zeros'),
                             nn.BatchNorm2d(decoder_channels[3]),
                             nn.Tanh())
        
        decoder_layers.append(delayer4)
        
        delayer5 = nn.Sequential(nn.ConvTranspose2d(decoder_channels[3], decoder_channels[4],
                                                kernel_size=1, stride=1, padding=0,padding_mode='zeros'),
                             nn.BatchNorm2d(decoder_channels[4]),
                             nn.Tanh(),
                             nn.Conv2d(decoder_channels[4], decoder_channels[4], kernel_size=1, stride=1)
                             )
        decoder_layers.append(delayer5)
        
        delayer5 = nn.Upsample((70,70), mode='bilinear')
        
        decoder_layers.append(delayer5)
        
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        
    def forward(self, x):
        x = self.encoder_layers(x)
        x = self.decoder_layers(x)
        return x
    
    def embedding(self, x):
        for layer in self.encoder_layers:
            x = layer(x)
            print(x.shape)
        return x
    
    def decoder(self, x):
        for layer in self.decoder_layers:
            x = layer(x)
            print(x.shape)
        return x

In [132]:
ae = VelAutoEncoder(input_channel=1, encoder_channels=[8, 16, 32, 64, 128, 256], decoder_channels=[128, 64, 32, 16, 1])

In [133]:
inputs = torch.randn(10, 1, 70, 70)
out = ae.embedding(inputs)

torch.Size([10, 8, 70, 70])
torch.Size([10, 16, 70, 70])
torch.Size([10, 32, 68, 68])
torch.Size([10, 64, 34, 34])
torch.Size([10, 128, 16, 16])
torch.Size([10, 256, 16, 16])
torch.Size([10, 256, 16, 16])


In [122]:
inputs = torch.randn(10, 128, 32, 32)
out = ae.decoder(inputs)

torch.Size([10, 128, 34, 34])
torch.Size([10, 64, 69, 69])
torch.Size([10, 32, 71, 71])
torch.Size([10, 16, 71, 71])
torch.Size([10, 1, 71, 71])
torch.Size([10, 1, 70, 70])


In [123]:
params_dict = { 
                  'in_dim':(70, 70),
                  'latent_dim':(64,64),
                  'in_channels':1,
                  'encoder_blocks':{
                       "0": {
                            "out_channels":8,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.LeakyReLU",
                                "activation_params":{
                                    "negative_slope":0.2,
                                    "inplace":True,
                                },
                            },
                            "padding_mode":'reflect',
                           },
                        "1": {
                            "out_channels":16,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.LeakyReLU",
                                "activation_params":{
                                    "negative_slope":0.2,
                                    "inplace":True,
                                },
                            },
                            "padding_mode":'reflect',
                           },
                      "2": {
                            "out_channels":32,
                            "kernel_size":(3,3), 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.LeakyReLU",
                                "activation_params":{
                                    "negative_slope":0.2,
                                    "inplace":True,
                                },
                            },
                            "padding_mode":'reflect',
                           },
                      "3": {
                            "out_channels":64,
                            "kernel_size":(3,3), 
                            "stride":(2,2),
                            "padding":(1,1),
                            "activation":{
                                "activation_fn":"nn.LeakyReLU",
                                "activation_params":{
                                    "negative_slope":0.2,
                                    "inplace":True,
                                },
                            },
                            "padding_mode":'reflect',
                           },
                      "4": {
                            "out_channels":128,
                            "kernel_size":(3,3), 
                            "stride":(2,2),
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.LeakyReLU",
                                "activation_params":{
                                    "negative_slope":0.2,
                                    "inplace":True,
                                },
                            },
                            "padding_mode":'reflect',
                           },
                        "5": {
                            "out_channels":256,
                            "kernel_size":(3,3), 
                            "stride":1,
                            "padding":(1,1),
                            "activation":{
                                "activation_fn":"nn.LeakyReLU",
                                "activation_params":{
                                    "negative_slope":0.2,
                                    "inplace":True,
                                },
                            },
                            "padding_mode":'reflect',
                           },
                      },
                  'decoder_blocks': {
                      "0": {
                            "out_channels":128,
                            "kernel_size":(3,3), 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.Tanh",
                                "activation_params":{
                                },
                            },
                            "padding_mode":'zeros',
                           },
                        "1": {
                            "out_channels":64,
                            "kernel_size":(3,3), 
                            "stride":(2,2),
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.Tanh",
                                "activation_params":{
                                },
                            },
                            "padding_mode":'zeros',
                           },
                      "2": {
                            "out_channels":32,
                            "kernel_size":(3,3), 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.Tanh",
                                "activation_params":{
                                },
                            },
                            "padding_mode":'zeros',
                           },
                      "3": {
                            "out_channels":16,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.Tanh",
                                "activation_params":{
                                },
                            },
                            "padding_mode":'zeros',
                           },
                      "4": {
                            "out_channels":1,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.Tanh",
                                "activation_params":{
                                },
                            },
                            "padding_mode":'zeros',
                           },
                  },
                  "last_conv2d":{
                      "kernel_size":1, 
                      "stride":1,
                      "padding":0,
                    }
                }
                

In [124]:
import json

with open('../configs/velocity_config_latent_dim_32.json', 'w') as file:
    json.dump(params_dict, file)

In [1]:
vel_params_70 = { 
                  'in_dim':(70, 70),
                  'latent_dim':(70,70),
                  'in_channels':1,
                  'encoder_blocks':{
                       "0": {
                            "out_channels":8,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.LeakyReLU",
                                "activation_params":{
                                    "negative_slope":0.2,
                                    "inplace":True,
                                },
                            },
                            "padding_mode":'reflect',
                           },
                        "1": {
                            "out_channels":16,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.LeakyReLU",
                                "activation_params":{
                                    "negative_slope":0.2,
                                    "inplace":True,
                                },
                            },
                            "padding_mode":'reflect',
                           },
                      "2": {
                            "out_channels":32,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.LeakyReLU",
                                "activation_params":{
                                    "negative_slope":0.2,
                                    "inplace":True,
                                },
                            },
                            "padding_mode":'reflect',
                           },
                      "3": {
                            "out_channels":64,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.LeakyReLU",
                                "activation_params":{
                                    "negative_slope":0.2,
                                    "inplace":True,
                                },
                            },
                            "padding_mode":'reflect',
                           },
                      "4": {
                            "out_channels":128,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.LeakyReLU",
                                "activation_params":{
                                    "negative_slope":0.2,
                                    "inplace":True,
                                },
                            },
                            "padding_mode":'reflect',
                           },
                      },
                  'decoder_blocks': {
                      "0": {
                            "out_channels":128,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.Tanh",
                                "activation_params":{
                                },
                            },
                            "padding_mode":'zeros',
                           },
                        "1": {
                            "out_channels":64,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.Tanh",
                                "activation_params":{
                                },
                            },
                            "padding_mode":'zeros',
                           },
                      "2": {
                            "out_channels":32,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.Tanh",
                                "activation_params":{
                                },
                            },
                            "padding_mode":'zeros',
                           },
                      "3": {
                            "out_channels":16,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.Tanh",
                                "activation_params":{
                                },
                            },
                            "padding_mode":'zeros',
                           },
                      "4": {
                            "out_channels":1,
                            "kernel_size":1, 
                            "stride":1,
                            "padding":0,
                            "activation":{
                                "activation_fn":"nn.Tanh",
                                "activation_params":{
                                },
                            },
                            "padding_mode":'zeros',
                           },
                  },
                  "last_conv2d":{
                      "kernel_size":1, 
                      "stride":1,
                      "padding":0,
                    }
                }
                

In [2]:
act

NameError: name 'act' is not defined

In [18]:

forward_params = { 
        'WaveformNet':{
            'in_channels': 1, #WaveformNet 
            'encoder_channels': [32, 64, 128, 256, 512], #WaveformNet 
            'decoder_channels': [256, 128, 64, 5],#WaveformNet 
            },
         'FNO':{
             'modes1': 30, #FNO 
             'modes2': 30, #FNO 
             'width': 32, #FNO 
             'out_dim':5 #FNO
          },  
    }

In [23]:
forward_params.setdefault('fancymodel', {})
forward_params['fancymodel']['cfg_path'] = "e;;p"

In [24]:
forward_params

{'WaveformNet': {'in_channels': 1,
  'encoder_channels': [32, 64, 128, 256, 512],
  'decoder_channels': [256, 128, 64, 5],
  'cfg_path': 'e;;p'},
 'FNO': {'modes1': 30, 'modes2': 30, 'width': 32, 'out_dim': 5},
 'fancymodel': {'cfg_path': 'e;;p'}}

In [17]:
forward_params['iUnet'] = "test"

In [3]:
from torchsummary import summary

In [68]:
device = torch.device("cuda")

In [66]:
summary(amp_ae.to(device), (5, 1000, 70))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 8, 334, 70]             288
       BatchNorm2d-2           [-1, 8, 334, 70]              16
         LeakyReLU-3           [-1, 8, 334, 70]               0
            Conv2d-4          [-1, 16, 165, 70]             912
       BatchNorm2d-5          [-1, 16, 165, 70]              32
         LeakyReLU-6          [-1, 16, 165, 70]               0
            Conv2d-7           [-1, 32, 82, 70]           2,592
       BatchNorm2d-8           [-1, 32, 82, 70]              64
         LeakyReLU-9           [-1, 32, 82, 70]               0
           Conv2d-10           [-1, 64, 78, 70]          10,304
      BatchNorm2d-11           [-1, 64, 78, 70]             128
        LeakyReLU-12           [-1, 64, 78, 70]               0
           Conv2d-13          [-1, 128, 74, 70]          41,088
      BatchNorm2d-14          [-1, 128,

In [76]:
eval("nn.Tanh")

torch.nn.modules.activation.Tanh

In [78]:
amp_params_70

{'in_dim': (1000, 70),
 'latent_dim': (70, 70),
 'in_channels': 5,
 'encoder_blocks': {'0': {'out_channels': 8,
   'kernel_size': (7, 1),
   'stride': (3, 1),
   'padding': (3, 0),
   'activation': {'activation_fn': 'nn.LeakyReLU',
    'activation_params': {'negative_slope': 0.2}},
   'padding_mode': 'zeros'},
  '1': {'out_channels': 16,
   'kernel_size': (7, 1),
   'stride': (2, 1),
   'padding': (1, 0),
   'activation': {'activation_fn': 'nn.LeakyReLU',
    'activation_params': {'negative_slope': 0.2}},
   'padding_mode': 'zeros'},
  '2': {'out_channels': 32,
   'kernel_size': (5, 1),
   'stride': (2, 1),
   'padding': (1, 0),
   'activation': {'activation_fn': 'nn.LeakyReLU',
    'activation_params': {'negative_slope': 0.2}},
   'padding_mode': 'zeros'},
  '3': {'out_channels': 64,
   'kernel_size': (5, 1),
   'stride': (1, 1),
   'padding': (0, 0),
   'activation': {'activation_fn': 'nn.LeakyReLU',
    'activation_params': {'negative_slope': 0.2}},
   'padding_mode': 'zeros'},
  '4

In [2]:
import json

with open('../configs/velocity_config_latent_dim_70.json', 'w') as file:
    json.dump(vel_params_70, file)

In [7]:
from torchsummary import summary

In [65]:
import os
import json
import torch
import torch.nn as nn

class ConvBlockLegacy(nn.Module):
    """(convolution => [BatchNorm] => LeakyReLU) * 2"""
    def __init__(self, 
                 in_channels, 
                 out_channels,
                 activation_fn=nn.LeakyReLU,
                 activation_params={},
                 kernel_size=(1, 1), 
                 stride=(1, 1), 
                 padding=(0, 0),
                 padding_mode='zeros',
                 transpose_conv=False,
                ):
        super().__init__()
        conv_fn = nn.ConvTranspose2d if transpose_conv else nn.Conv2d
        self.conv_block = nn.Sequential(
            conv_fn(in_channels, out_channels, kernel_size=kernel_size, 
                    stride=stride, padding=padding, padding_mode=padding_mode),
            nn.BatchNorm2d(out_channels),
            activation_fn(**activation_params),
        )
    def forward(self, x):
        return self.conv_block(x)

class ConvBlock(nn.Module):
    """(convolution => [GroupNorm] => LeakyReLU) * 2"""
    def __init__(self, 
                 in_channels, 
                 out_channels, 
                 kernel_size=(1, 1), 
                 stride=(1, 1), 
                 padding=(0, 0),
                 negative_slope=0.01,
                 groups=1,
                 transpose_conv=False,
                ):
        super().__init__()
        conv_fn = nn.ConvTranspose2d if transpose_conv else nn.Conv2d
        self.conv_block = nn.Sequential(
            conv_fn(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.GroupNorm(groups, out_channels, 1e-3),
            nn.LeakyReLU(negative_slope=negative_slope, inplace=True),
        )
    def forward(self, x):
        return self.conv_block(x)

class AutoEncoder(nn.Module):
    def __init__(self, cfg_path="./configs/", cfg_name='amplitude_config_latent_dim_70.json'):
        super(AutoEncoder, self).__init__()
        
        config_file = os.path.join(cfg_path, cfg_name)
        with open(config_file, 'r') as file:
            cfg = json.load(file)
        
        encoder_layers = nn.ModuleList()
        decoder_layers = nn.ModuleList()
        
        for i, key in enumerate(cfg["encoder_blocks"].keys()):
            enc_dict = cfg["encoder_blocks"][key]
            activation_fn = eval(enc_dict["activation"]["activation_fn"])
            activation_params = enc_dict["activation"]["activation_params"]
            in_channels = cfg["in_channels"] if i==0 else cfg["encoder_blocks"][str(i-1)]["out_channels"]
            out_channels = enc_dict["out_channels"]
            conv_block = ConvBlockLegacy(
                             in_channels=in_channels, 
                             out_channels=out_channels,
                             activation_fn=activation_fn,
                             activation_params=activation_params,
                             kernel_size=enc_dict["kernel_size"], 
                             stride=enc_dict["stride"], 
                             padding=enc_dict["padding"],
                             padding_mode=enc_dict["padding_mode"],
                             transpose_conv=False,          
                    )
            encoder_layers.append(conv_block)
        encoder_layers.append(nn.Upsample(cfg["latent_dim"], mode='bilinear'))
        
        for i, key in enumerate(cfg["decoder_blocks"].keys()):
            dec_dict = cfg["decoder_blocks"][key]
            activation_fn = eval(dec_dict["activation"]["activation_fn"])
            activation_params = dec_dict["activation"]["activation_params"]
            in_channels = out_channels if i==0 else cfg["decoder_blocks"][str(i-1)]["out_channels"]
            out_channels = dec_dict["out_channels"]
            conv_block = ConvBlockLegacy(
                             in_channels=in_channels, 
                             out_channels=out_channels,
                             activation_fn=activation_fn,
                             activation_params=activation_params,
                             kernel_size=dec_dict["kernel_size"], 
                             stride=dec_dict["stride"], 
                             padding=dec_dict["padding"],
                             padding_mode=dec_dict["padding_mode"],
                             transpose_conv=True,          
                    )
            decoder_layers.append(conv_block)
        
        #last convolutional block after decoder
        last_conv = nn.Conv2d(
                             in_channels=out_channels, 
                             out_channels=out_channels,
                             kernel_size=cfg["last_conv2d"]["kernel_size"], 
                             stride=cfg["last_conv2d"]["stride"], 
                             padding=cfg["last_conv2d"]["padding"],        
                    )
        decoder_layers.append(last_conv)
        decoder_layers.append(nn.Upsample(cfg["in_dim"], mode='bilinear'))
        
        self.encoder_layers = nn.Sequential(*encoder_layers)
        self.decoder_layers = nn.Sequential(*decoder_layers)
    
    def forward(self, x):
        x = self.encoder_layers(x)
        x = self.decoder_layers(x)
        return x
    
    def embedding(self, x):
        x = self.encoder_layers(x)
        return x
    
    def decoder(self, x):
        x = self.decoder_layers(x)
        return x

In [75]:
autoencoder = AutoEncoder(cfg_path="/home/darka/projects/Latent_Bijectivity/configs/", 
                          cfg_name='amplitude_config_latent_dim_70.json'
                         )

In [76]:
summary(autoencoder.to(device), (5, 1000, 70))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 8, 334, 70]             288
       BatchNorm2d-2           [-1, 8, 334, 70]              16
         LeakyReLU-3           [-1, 8, 334, 70]               0
   ConvBlockLegacy-4           [-1, 8, 334, 70]               0
            Conv2d-5          [-1, 16, 165, 70]             912
       BatchNorm2d-6          [-1, 16, 165, 70]              32
         LeakyReLU-7          [-1, 16, 165, 70]               0
   ConvBlockLegacy-8          [-1, 16, 165, 70]               0
            Conv2d-9           [-1, 32, 82, 70]           2,592
      BatchNorm2d-10           [-1, 32, 82, 70]              64
        LeakyReLU-11           [-1, 32, 82, 70]               0
  ConvBlockLegacy-12           [-1, 32, 82, 70]               0
           Conv2d-13           [-1, 64, 78, 70]          10,304
      BatchNorm2d-14           [-1, 64,