<a href="https://colab.research.google.com/github/petron23/Generative_AI/blob/main/Stable_Diffusion_tutorial/sd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import os
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image


class ResidualConvBlock(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> GELU stages with an optional residual sum.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both conv stages.
        is_res: When True, the block input is added to the output of the
            second stage and the sum is divided by 1.414. When False, the
            output of the second stage is returned unchanged.
    """

    def __init__(self, in_channels, out_channels, is_res=False):
        super().__init__()

        self.same_channels = in_channels == out_channels
        self.is_res = is_res

        self.conv1 = self._conv_block(in_channels, out_channels)
        self.conv2 = self._conv_block(out_channels, out_channels)

        # 1x1 projection that lets the input be added to the output when the
        # channel counts differ. It is built here (after conv1/conv2, so their
        # initialisation order is unaffected) rather than inside forward():
        # a layer created per call would be re-initialised randomly on every
        # pass and would never be registered, trained, or saved.
        if is_res and not self.same_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.shortcut = None

    def _conv_block(self, in_channels, out_channels):
        """Return one Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> GELU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )

    def forward(self, x):
        """Apply both conv stages and, if ``is_res`` is set, the residual sum.

        Returns:
            The second stage's output when ``is_res`` is False; otherwise
            ``(residual + second_stage_output) / 1.414``, where ``residual``
            is ``x`` itself or its 1x1 projection when channel counts differ.
        """
        x1 = self.conv1(x)
        x2 = self.conv2(x1)

        if not self.is_res:
            return x2

        if self.same_channels:
            out = x + x2
        else:
            if self.shortcut is None:
                # `is_res` was switched on after construction: build the
                # projection once and register it so later calls reuse it.
                self.shortcut = nn.Conv2d(
                    x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0
                ).to(x.device)
            out = self.shortcut(x) + x2

        # 1.414 ~= sqrt(2); presumably rescales the sum of two branches.
        return out / 1.414

    def get_out_channels(self):
        """Return the ``out_channels`` attribute of the second stage's Conv2d."""
        return self.conv2[0].out_channels

    def set_out_channels(self, out_channels):
        """Overwrite the channel-count attributes on the two Conv2d layers.

        Only the ``in_channels`` / ``out_channels`` attributes are assigned;
        this method does not touch the layers' weight tensors or the
        BatchNorm2d layers.
        """
        self.conv1[0].out_channels = out_channels
        self.conv2[0].in_channels = out_channels
        self.conv2[0].out_channels = out_channels

class UnetUp(nn.Module):
    """Decoder stage: concatenate with a skip tensor, upsample, then refine.

    Args:
        in_channels: Channel count fed to the transposed convolution. Because
            ``forward`` concatenates ``x`` and ``skip`` before upsampling, this
            must equal the sum of their channel counts.
        out_channels: Channel count produced by the stage.
    """

    def __init__(self, in_channels, out_channels):
        # Must be `__init__`: a bare `__init` inside a class body is
        # name-mangled and raises AttributeError at construction time.
        super().__init__()

        # Create the upsampling block with ConvTranspose2d and ResidualConvBlock layers
        # (kernel_size=2 with stride=2 doubles the height and width).
        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.residual1 = ResidualConvBlock(out_channels, out_channels)
        self.residual2 = ResidualConvBlock(out_channels, out_channels)

    def forward(self, x, skip):
        """Return the upsampled, refined combination of ``x`` and ``skip``.

        ``x`` and ``skip`` are joined along dimension 1, so they must agree
        in every other dimension.
        """
        # Concatenate the input tensor (x) with the skip connection tensor along the channel dimension
        x = torch.cat((x, skip), 1)

        # Apply upsampling, followed by two residual convolutional blocks
        x = self.upsample(x)
        x = self.residual1(x)
        x = self.residual2(x)

        return x

class UnetDown(nn.Module):
    """Encoder stage: two convolutional blocks followed by 2x2 max pooling.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the stage.
    """

    def __init__(self, in_channels, out_channels):
        # Must be `__init__`: a bare `__init` inside a class body is
        # name-mangled and raises AttributeError at construction time.
        super().__init__()

        # Create the downsampling block with ResidualConvBlock and MaxPool2d layers
        self.residual1 = ResidualConvBlock(in_channels, out_channels)
        self.residual2 = ResidualConvBlock(out_channels, out_channels)
        # MaxPool2d(2) uses a 2x2 window with stride 2, halving height and width.
        self.downsample = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``x`` after both convolutional blocks and the pooling step."""
        # Apply two residual convolutional blocks followed by downsampling
        x = self.residual1(x)
        x = self.residual2(x)
        x = self.downsample(x)

        return x
