In [97]:
import tensorflow as tf
from tensorflow.keras import (layers)

In [98]:
x = tf.random.normal((2,5,5,11))
print(f"first row, first column 11 channel values : {x[0,0,0,:]}")

first row, first column 11 channel values : [ 0.41925356 -0.5413031   0.58713526  0.6599484   1.6821845  -2.568531
 -0.7737217   0.5822169   0.4465645  -1.505228   -0.6275001 ]


In [99]:
conv_layer =  layers.Conv2D(3, kernel_size=1, kernel_initializer="zeros")
new_x =conv_layer(x)
print(f"new_x shape : {tf.shape(new_x)}")

new_x shape : [2 5 5 3]


In [100]:
print(f"new_x one value : {new_x[0,0,0,:]}")

new_x one value : [0. 0. 0.]


In [101]:
x2=conv_layer(x)
print(f"x2 : {x2[0,0,0,:]}")

x2 : [0. 0. 0.]


In [102]:
x3 = conv_layer(x)
print(f"x3 : {x3[0,0,0,:]}")

x3 : [0. 0. 0.]


# Torch stuff

In [103]:
import torch

print(torch.log(torch.exp(torch.tensor(2.0))))
print(torch.linspace(torch.log(torch.tensor(1.0)), torch.log(torch.tensor(1000.0)),16))

tensor(2.)
tensor([0.0000, 0.4605, 0.9210, 1.3816, 1.8421, 2.3026, 2.7631, 3.2236, 3.6841,
        4.1447, 4.6052, 5.0657, 5.5262, 5.9867, 6.4472, 6.9078])


In [109]:
def sinusoidal_embedding(x):
    """
    param x: (N,1) shape random normal noise
    """
    frequencies = torch.linspace(torch.log(torch.tensor(1.0)),torch.log(torch.tensor(1000.0)),16)
    frequencies = frequencies.unsqueeze(0).unsqueeze(2).unsqueeze(3)
    angular_speeds = 2*torch.pi*torch.exp(frequencies)*x
    return torch.cat((torch.sin(angular_speeds),torch.cos(angular_speeds)),dim=1)


In [112]:
noise=torch.randn((5,3,32,32))
sin_embeddings = sinusoidal_embedding(noise)

RuntimeError: The size of tensor a (16) must match the size of tensor b (3) at non-singleton dimension 1

In [106]:
embeddings = sin_embeddings.unsqueeze(2).unsqueeze(3)
print(embeddings.shape)

torch.Size([16, 2, 1, 1, 5, 1])


In [107]:
import torch.nn as nn

upsampled_noise = nn.Upsample(scale_factor=64,mode="nearest")(embeddings)



NotImplementedError: Input Error: Only 3D, 4D and 5D input Tensors supported (got 6D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got nearest)

In [None]:
upsampled_noise.shape

In [None]:
upsampled_noise[0,0,:,:]

In [None]:
noisy_image = torch.randn(5,3,64,64)

In [None]:
x = nn.Conv2d(in_channels=3,out_channels=32,kernel_size=3, padding=1, stride=1)(noisy_image)
print(x.shape)

In [None]:
x = torch.cat([upsampled_noise,x],dim=1)
print(x.shape)

In [None]:
x = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, stride=1)(x)
print(x.shape)

In [None]:
x = nn.AvgPool2d(kernel_size=2)(x)
print(x.shape)

In [None]:
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


def ResidualBlock(n_channels):
    def apply(x):
        in_channels = x.shape[1]
        if n_channels==in_channels:
            residual = x
        else:
            residual = nn.Conv2d(in_channels=in_channels, out_channels=n_channels, kernel_size=1)(x)
        x=nn.BatchNorm2d(num_features=in_channels,affine=False)(x)
        x = nn.Conv2d(in_channels=in_channels,out_channels=n_channels, kernel_size=3, padding=1, stride=1)(x)
        x = Swish()(x)
        x = nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=3, padding=1, stride=1)(x)
        return x + residual
    return apply

In [None]:
x = ResidualBlock(64)(x)
print(x.shape)

In [None]:
def DownBlock(n_channels,block_depth):
    def apply(x):
        x, skips = x

        for _ in range(block_depth):
            x = ResidualBlock(n_channels)(x)
            skips.append(x)
        x = nn.AvgPool2d(kernel_size=2)(x)
        return x
    return apply

In [None]:
skips=[]
x = DownBlock(96,2)([x,skips])
print(x.shape)

In [None]:
def UpBlock(n_channels, block_depth):
    def apply(x):
        x, skips = x
        x = nn.Upsample(scale_factor=2,mode="bilinear")(x)
        for _ in range(block_depth):
            x = torch.cat([x,skips.pop()],dim=1)
            x = ResidualBlock(n_channels)(x)
        return x
    return apply

In [None]:
x = UpBlock(96,2)([x,skips])

In [None]:
print(x.shape)

In [None]:
skips

In [None]:
class Unet(nn.Module):
    def __init__(self, in_channels, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=1)
        self.upsample1 = nn.Upsample(scale_factor=64, mode="nearest")

        self.skips=[]

        self.downblock1 = DownBlock(32,2)
        self.downblock2 = DownBlock(64,2)
        self.downblock3 = DownBlock(96,2)

        self.residual1 = ResidualBlock(128)
        self.residual2 = ResidualBlock(128)

        self.upblock1 = UpBlock(96,2)
        self.upblock2 = UpBlock(64, 2)
        self.upblock3 = UpBlock(32,2)
        
        self.conv_last = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1)
    
    def forward(self, x):
        noisy_images,noise_variances = x
        noise_embeddings = sinusoidal_embedding(noise_variances)
        noise_embeddings = noise_embeddings.unsqueeze(2).unsqueeze(3)
        noise_embeddings = self.upsample1(noise_embeddings)
        
        x = self.conv1(noisy_images)
        x = torch.cat([x, noise_embeddings], dim=1)

        x = self.downblock1([x, self.skips])
        x = self.downblock2([x, self.skips])
        x = self.downblock3([x, self.skips])

        x = self.residual1(x)
        x = self.residual2(x)

        x = self.upblock1([x, self.skips])
        x = self.upblock2([x, self.skips])
        x = self.upblock3([x, self.skips])

        x = self.conv_last(x)

        return x


In [None]:
N=5
noise_variances = torch.randn(N,1)
noisy_images = torch.randn(N, 3, 64, 64)

unet = Unet(in_channels=3)

out = unet([noisy_image, noise_variances])

print(out.shape)

# Normalization layer

In [None]:
import torch

# Simulate some input data (batch_size, num_channels, height, width)
N, C, H, W = 5, 3, 64, 64
dummy_train = torch.randn((N, C, H, W))

# Compute mean and standard deviation across the channel dimension
mean = dummy_train.mean(dim=[0, 2, 3])
std = dummy_train.std(dim=[0, 2, 3])

print("Mean:", mean)
print("Standard Deviation:", std)


In [None]:
import torch
from torchvision import transforms

# Define the normalization transform
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Simulate some input data (batch_size, num_channels, height, width)
input_data = torch.randn(16, 3, 32, 32)

# Apply the normalization
# Note: The input data should be of type torch.FloatTensor and in the range [0, 1] if it's an image
normalized_data = normalize(input_data)

print(normalized_data.shape)  # Should be (16, 3, 32, 32)


In [None]:
class Normalizer:
    def __init__(self) -> None:
        self.mean=None
        self.std=None
    
    def adapt(self,train):
        self.mean=train.mean(dim=[0,2,3])
        self.std=train.std(dim=[0,2,3])
    
    def normalize(self, x):
        normalize_transform = transforms.Normalize(mean=self.mean, std=self.std)
        return normalize_transform(x)

In [None]:
normalizer = Normalizer()
normalizer.adapt(input_data)
out = normalizer.normalize(input_data)
print(normalizer.mean, normalizer.std)
print(out.shape, out.mean(dim=[0,2,3]))

# Diffusion schedules

In [None]:
def offset_cosine_diffusion_schedule(diffusion_times):
    min_rate = 0.02
    max_rate = 0.95
    start_angle = torch.arccos(torch.tensor(max_rate))
    end_angle = torch.arccos(torch.tensor(min_rate))
    diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
    signal_rates = torch.cos(diffusion_angles)
    noise_rates = torch.sin(diffusion_angles)
    return noise_rates, signal_rates


In [None]:
diffusion_times = torch.rand((5,1))
noise_rates,signal_rates = offset_cosine_diffusion_schedule(diffusion_times)
print(noise_rates.shape)

# Diffusion model

In [None]:
class DiffusionModel(nn.Module):
    def __init__(self, in_channels, adapted_normalizer:Normalizer, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.unet = Unet(in_channels)
        self.ema_unet = Unet(in_channels)
        self.normalizer = adapted_normalizer
    
    
    def forward(self,images):
        images = self.normalizer.normalize(images)
        batch_size, n_channels, height, width = images.size()
        noises = torch.randn((batch_size, n_channels, height, width))
        diffusion_times = torch.rand(batch_size,1)
        noise_rates, signal_rates = offset_cosine_diffusion_schedule(diffusion_times)
        noise_rates = noise_rates.unsqueeze(2).unsqueeze(3)
        signal_rates = signal_rates.unsqueeze(2).unsqueeze(3)
        noisy_images = signal_rates * images + noise_rates * noises
        if self.training:
            pred_noises = self.unet([noisy_images, noises**2])
        else:
            pred_noises = self.ema_unet([noisy_images, noises**2])
        return pred_noises

In [None]:
model = DiffusionModel(3,normalizer)
out = model(input_data)