In [1]:
import torch.nn as nn
import torch as th
from guided_diffusion.nn import (
    checkpoint,
    conv_nd,
    linear,
    avg_pool_nd,
    zero_module,
    normalization,
    timestep_embedding,
)

# Width of the raw timestep embedding. Every other size in this cell is
# derived from it so the embedding and the MLP input can never disagree.
model_channels = 128
time_embed_dim = model_channels * 4

# A batch containing the single timestep value 500, kept on the CPU.
t = th.tensor([500], device='cpu')

# Two linear layers with a SiLU in between:
# model_channels -> time_embed_dim -> time_embed_dim.
time_embed = nn.Sequential(
    linear(model_channels, time_embed_dim),
    nn.SiLU(),
    linear(time_embed_dim, time_embed_dim),
)

# `dim` must equal the input width of the first linear layer above, hence
# `model_channels` rather than a repeated literal.
emb = time_embed(timestep_embedding(timesteps=t, dim=model_channels))
emb.shape

torch.Size([1, 512])

In [2]:
from guided_diffusion.mobileTrans import MobileViT
from guided_diffusion import dist_util, logger

# Hyper-parameters for the MobileViT variants exercised in this notebook,
# keyed by variant name.
model_cfg = {
    "s": {
        "features": [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640],
        "d": [144, 192, 240],
        "expansion_ratio": 8,
        "layers": [2, 3, 4],
        "input_channels": 5,
        "output_channels": 2,
    },
}

cfg_s = model_cfg["s"]
image_size = 224
model_device = dist_util.dev()

# Channel counts are read from the config instead of being repeated as
# literals, so editing `model_cfg` is enough to change them.
mobilleVIT_model = MobileViT(
    image_size,
    cfg_s["features"],
    cfg_s["d"],
    cfg_s["layers"],
    cfg_s["expansion_ratio"],
    output_channels=cfg_s["output_channels"],
)
mobilleVIT_model.to(model_device)

# The dummy batch is created on the same device as the model; a CPU tensor
# fed to a model that was moved elsewhere fails with a device mismatch.
# NOTE(review): `input` shadows the Python builtin; the name is kept only
# because other cells may refer to it.
input = th.randn(
    (1, cfg_s["input_channels"], image_size, image_size),
    device=model_device,
)
tmpk = 200

# Inference-only smoke test: eval mode, no autograd bookkeeping.
mobilleVIT_model.eval()
with th.no_grad():
    output = mobilleVIT_model(input, time=tmpk)

print(output.shape)

torch.Size([1, 2, 224, 224])


In [None]:
from guided_diffusion.unet import UNetModel, ResBlock
# Keyword arguments for the UNet instance, gathered in one place so the
# configuration can be read (and tweaked) independently of the constructor call.
unet_kwargs = dict(
    image_size=256,
    in_channels=5,
    model_channels=128,
    out_channels=2,  # (3 if not learn_sigma else 6)
    num_res_blocks=2,
    attention_resolutions=(16,),
    dropout=0.0,
    channel_mult=(1, 1, 2, 2, 4, 4),
    num_classes=None,
    use_checkpoint=False,
    use_fp16=False,
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    resblock_updown=False,
    use_new_attention_order=False,
)
model = UNetModel(**unet_kwargs)

# A stand-alone residual block: 128 channels in and out, with an embedding
# input four times that width.
resblock_kwargs = dict(
    channels=128,
    emb_channels=4 * 128,
    dropout=0,
    out_channels=128,
    dims=2,
    use_checkpoint=False,
    use_scale_shift_norm=False,
)
res_model = ResBlock(**resblock_kwargs)

In [1]:
from guided_diffusion.mobileTrans import MobileViT
from guided_diffusion.module import *
import torch.nn as nn


class STEM(nn.Module):
    """Stem block: a 3x3 stride-2 convolution followed by an InvertedResidual.

    Args:
        input_channels: channels of the tensor fed to the convolution.
        middle_channel: channels produced by the convolution and consumed by
            the InvertedResidual block.
        output_channels: channels requested from the InvertedResidual block.
        expand_ratio: forwarded unchanged to InvertedResidual.
    """

    def __init__(self, input_channels, middle_channel, output_channels, expand_ratio):
        super().__init__()

        # Submodules are registered convolution-first so the parameter order
        # and the printed repr stay the same.
        self.conv_layer = nn.Conv2d(
            in_channels=input_channels,
            out_channels=middle_channel,
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.Inverted = InvertedResidual(
            in_channels=middle_channel,
            out_channels=output_channels,
            stride=1,
            expand_ratio=expand_ratio,
        )

        # Keep the constructor arguments around for later inspection.
        self.input_channels = input_channels
        self.middle_channel = middle_channel
        self.output_channels = output_channels
        self.expand_ratio = expand_ratio

    def forward(self, x):
        """Apply the convolution, then the InvertedResidual block, to ``x``."""
        return self.Inverted(self.conv_layer(x))


# Reference stem assembled inline: a 5 -> 16 channel 3x3 stride-2 convolution
# followed by a 16 -> 32 channel InvertedResidual block with expand_ratio 8.
stem = nn.Sequential(
    nn.Conv2d(5, 16, kernel_size=3, stride=2, padding=1),
    InvertedResidual(in_channels=16, out_channels=32, stride=1, expand_ratio=8),
)

In [5]:
# STEM declares no default for `expand_ratio`, so it has to be passed
# explicitly; 8 is the value used by the inline `stem` Sequential, which
# makes the two modules directly comparable.
tmp1 = STEM(input_channels=5, middle_channel=16, output_channels=32, expand_ratio=8)

# Show the reference stem and its total parameter count.
print(stem)
total_params1 = sum(param.numel() for param in stem.parameters())
print(total_params1)

Sequential(
  (0): Conv2d(5, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (1): InvertedResidual(
    (conv): Sequential(
      (0): ConvNormAct(
        (conv): Conv2d(16, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm_layer): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): ConvNormAct(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
        (norm_layer): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU()
      )
      (2): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)
8608


In [6]:
# Show the STEM instance and its total parameter count.
print(tmp1)
total_params2 = sum(map(lambda p: p.numel(), tmp1.parameters()))
print(total_params2)

STEM(
  (conv_layer): Conv2d(5, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (Inverted): InvertedResidual(
    (conv): Sequential(
      (0): ConvNormAct(
        (conv): Conv2d(16, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm_layer): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): ConvNormAct(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
        (norm_layer): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU()
      )
      (2): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)
8608


In [None]:
def forward(self, x, emb):
    """Run ``self.conv`` on ``x`` and add the projected embedding.

    Args:
        x: input tensor passed to ``self.conv``.
        emb: embedding tensor passed to ``self.emb_layers``.

    Returns:
        ``x + self.conv(x) + emb_out`` when ``self.use_res_connect`` is
        truthy, otherwise ``self.conv(x) + emb_out``, where ``emb_out`` is
        the projected embedding cast to ``x``'s dtype and padded with
        trailing singleton dimensions up to ``x``'s rank.
    """
    emb_out = self.emb_layers(emb).type(x.dtype)

    # Append size-1 trailing axes until the ranks match, so the addition
    # below broadcasts over the remaining dimensions of `x`.
    while emb_out.ndim < x.ndim:
        emb_out = emb_out[..., None]

    conv_out = self.conv(x)
    if self.use_res_connect:
        return x + conv_out + emb_out
    return conv_out + emb_out