In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from disentangled_representations.src.models.UNet import UNet
from disentangled_representations.src.models.UNet_parts import SeparableConv, DoubleNonLinearConv

unet = UNet(
    in_channels=1,
    out_channels=1,
    channels=[32, 64, 128, 256],
    conv_block_down=DoubleNonLinearConv,
    conv_block_up=SeparableConv,
)
_ = unet.to("mps")

In [11]:
from torchinfo import summary
import torch

x = torch.randn(1, 1, 256, 256).to("mps")
p = torch.randn(1, 2 * 256).to("mps")

summary(
    unet,
    input_size=[tuple(x.shape), tuple(p.shape)],
    device='mps',
    col_names=("output_size", "num_params", "mult_adds")
)

Layer (type:depth-idx)                             Output Shape              Param #                   Mult-Adds
UNet                                               [1, 1, 256, 256]          --                        --
├─DoubleNonLinearConv: 1-1                         [1, 32, 256, 256]         --                        --
│    └─Sequential: 2-1                             [1, 32, 256, 256]         --                        --
│    │    └─Conv2d: 3-1                            [1, 32, 256, 256]         288                       18,874,368
│    │    └─BatchNorm2d: 3-2                       [1, 32, 256, 256]         64                        64
│    │    └─ReLU: 3-3                              [1, 32, 256, 256]         --                        --
│    │    └─Conv2d: 3-4                            [1, 32, 256, 256]         9,216                     603,979,776
│    │    └─BatchNorm2d: 3-5                       [1, 32, 256, 256]         64                        64
│    │    └─ReLU: 3-6 

In [68]:
unet.eval()
out = unet(x, p)

In [9]:
from disentangled_representations.src.models.transient_encoders import EfficientNetB0VariationalTransientEncoder
d = 64

EB0_variational_encoder = EfficientNetB0VariationalTransientEncoder(1, d)
EB0_variational_encoder.to("mps")

EfficientNetB0VariationalTransientEncoder(
  (model): EfficientNet(
    (conv_stem): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNormAct2d(
      32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): SiLU(inplace=True)
    )
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn1): BatchNormAct2d(
            32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): SiLU(inplace=True)
          )
          (aa): Identity()
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): SiLU(inplace=True)
            (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
       

In [16]:
EB0_variational_encoder(x)

tensor([[ 0.0331, -0.0022, -0.0639, -0.0606,  0.0263, -0.0605,  0.0366,  0.1020,
         -0.0763,  0.0357,  0.0555, -0.0493, -0.0489, -0.0256,  0.0506, -0.0138,
          0.0109, -0.0119, -0.0970,  0.1170, -0.0798,  0.0286,  0.0038, -0.0631,
          0.0350, -0.0676, -0.0816,  0.0598, -0.0249,  0.0393,  0.0142,  0.0555,
          0.0995,  0.0356, -0.1142, -0.0570, -0.0284,  0.0496,  0.0667,  0.0107,
          0.0507,  0.0074, -0.0319,  0.0010, -0.0900, -0.0201,  0.0501, -0.0498,
         -0.0021, -0.0104,  0.0659, -0.0512, -0.0332, -0.0757,  0.0222, -0.0419,
         -0.0326, -0.0512,  0.0209, -0.0772, -0.0666,  0.0272, -0.0313,  0.0618,
         -0.0410, -0.0339,  0.0193, -0.0817, -0.0419, -0.0803, -0.0201,  0.0260,
         -0.0766, -0.0487, -0.0138, -0.0241, -0.0598,  0.0208, -0.1074, -0.0428,
         -0.0408,  0.0213, -0.0775,  0.0002,  0.0491,  0.0565, -0.0401,  0.0861,
          0.0648,  0.0176, -0.0181, -0.0042, -0.0700,  0.0444,  0.0066, -0.0117,
          0.1363,  0.0482,  