In [18]:
# Code borrowed from: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

def Normalize(in_channels, num_groups=32):
    """Build an affine GroupNorm layer (eps=1e-6).

    `in_channels` must be divisible by `num_groups` (GroupNorm requirement).
    """
    group_norm = nn.GroupNorm(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
    return group_norm

def nonlinearity(x):
    """Swish (a.k.a. SiLU) activation, element-wise: x * sigmoid(x)."""
    gate = torch.sigmoid(x)
    return x * gate

class ResnetBlock(nn.Module):
    """Pre-activation residual block with optional timestep-embedding injection.

    Computes two ``GroupNorm -> swish -> Conv3x3`` stages (dropout before the
    second convolution) and adds the result to a shortcut of the input. When
    ``in_channels != out_channels`` the shortcut is projected with either a
    3x3 convolution (``conv_shortcut=True``) or a 1x1 convolution.

    Args:
        in_channels: Channels of the input feature map. Must be divisible by
            32, the default group count used by ``Normalize``.
        out_channels: Output channels; defaults to ``in_channels``. Must also
            be divisible by 32.
        conv_shortcut: If True (and channels differ), use a 3x3 conv for the
            shortcut instead of a 1x1 conv.
        dropout: Dropout probability applied before the second convolution.
        temb_channels: Size of the timestep embedding. If ``<= 0`` no
            projection is built and ``forward`` must be called with ``temb=None``.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # Convolution block 1 (spatial size preserved: k=3, s=1, p=1).
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)

        # Timestep-embedding projection: temb_channels -> out_channels.
        # Stored as None when disabled so forward() can check for it explicitly
        # instead of failing with an AttributeError.
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        else:
            self.temb_proj = None

        # Convolution block 2.
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)

        # Shortcut projection, only needed when the channel counts differ.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        """Apply the residual block.

        Args:
            x: Input feature map of shape (b, in_channels, h, w).
            temb: Timestep embedding of shape (b, temb_channels), or None to
                skip time conditioning.

        Returns:
            Tensor of shape (b, out_channels, h, w).

        Raises:
            ValueError: If ``temb`` is given but the block was built with
                ``temb_channels <= 0`` (no projection exists).
        """
        # Convolution block 1.
        h = self.norm1(x)
        h = nonlinearity(h)
        h = self.conv1(h)

        # Inject time information: project to out_channels and broadcast the
        # per-sample bias over the spatial dimensions.
        if temb is not None:
            if self.temb_proj is None:
                raise ValueError(
                    "ResnetBlock was built with temb_channels <= 0 but received a "
                    "timestep embedding; pass temb=None or set temb_channels > 0."
                )
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        # Convolution block 2.
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Match the input channels to the output channels before the residual add.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h

In [19]:
# Smoke test: a 32 -> 128 channel block conditioned on a 128-dim time embedding.
torch.manual_seed(0)  # reproducible random inputs and dropout mask on re-run
resnet_block = ResnetBlock(in_channels=32, out_channels=128, conv_shortcut=False, dropout=0.1, temb_channels=128)
x = torch.randn(2, 32, 64, 64)
temb = torch.randn(2, 128)
y = resnet_block(x, temb)
# Channels change to out_channels; batch and spatial size are preserved.
assert y.shape == (2, 128, 64, 64), y.shape
print(x.shape, y.shape)

torch.Size([2, 32, 64, 64]) torch.Size([2, 128, 64, 64])
