In [1]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange
# from models.model import *
import matplotlib.pyplot as plt
import os

In [2]:
torch.cuda.is_available()

True

In [3]:
from torchvision import datasets, transforms

train_transform = transforms.Compose([
    transforms.ToTensor(),
#     transforms.Resize((32, 32), antialias=False),
#     transforms.CenterCrop(224)
])
train_dataset = datasets.CIFAR10(root='./datasets', train=True, transform=train_transform, download=True)
val_dataset = datasets.CIFAR10(root='./datasets', train=False, transform=train_transform, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)


img, lab = next(iter(train_dataloader))
print(img.shape, lab.shape)

torch.Size([1, 3, 32, 32]) torch.Size([1])


In [51]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange
import torch.nn as nn


skip_info = {}

def time_encoding2d(dim, h, w, t):
    assert dim%2 == 0
    pos_emb = torch.zeros(dim, h, w)
    h_dim = int(dim/2)
    pos_emb[0::2, :, :] = torch.sin(t.view(1, -1, 1).repeat(h_dim, h, w))/torch.pow(10000, torch.arange(0, h_dim)*2/dim).view(-1, 1, 1).repeat(1, h, w)
    pos_emb[1::2, :, :] = torch.cos(t.view(1, -1, 1).repeat(h_dim, h, w))/torch.pow(10000, torch.arange(0, h_dim)*2/dim).view(-1, 1, 1).repeat(1, h, w)
    
    return pos_emb


class Sequential(nn.Sequential):
    def forward(self, *args):
        for module in self._modules.values():
            if type(args) == tuple:
                args = module(*args)
            else:
                args = module(args)
        return args


class PixCNNPP(nn.Module):
    
    def __init__(self, num_blocks=2, num_res=4, channels=[4, 16, 32, 64, 128, 256, 512], sz=32):
        super().__init__()                
        self.model = []

        self.upscale = Sequential(
            nn.Conv2d(in_channels=3, out_channels=4, kernel_size=1, stride=1),
            # nn.Dropout2d(p=0.1),
            nn.GroupNorm(num_groups=4, num_channels=4),
            nn.ReLU(),
        )
        
        self.downscale = Sequential(
            nn.Conv2d(in_channels=4, out_channels=3, kernel_size=1, stride=1),
            # nn.Dropout2d(p=0.1),
            # nn.GroupNorm(num_groups=3, num_channels=3),
            # nn.ReLU(),
        )
        
        ## downsample
        for res in range(num_res):
            for block in range(num_blocks):
                if block == 0:
                    self.model.append(DownsampleBlock(in_channels=channels[res], out_channels=channels[res+1], idx=None))
                elif block == num_blocks-1:
                    self.model.append(DownsampleBlock(in_channels=channels[res+1], out_channels=channels[res+1], idx=res))
                else:
                    self.model.append(DownsampleBlock(in_channels=channels[res+1], out_channels=channels[res+1], idx=None))

            self.model.append(ReduceBlock(in_channels=channels[res+1], out_channels=channels[res+1]))
            sz /= 2

        ## Bottleneck
        self.model.append(
            Sequential(
                DownsampleBlock(in_channels=channels[num_res], out_channels=channels[num_res], idx=None),
                DownsampleBlock(in_channels=channels[num_res], out_channels=channels[num_res], idx=None)
            )
        )

        ## upsample
        for res in reversed(range(1, num_res+1)):
            self.model.append(IncreaseBlock(in_channels=channels[res], out_channels=channels[res]))
            sz *= 2
            for block in range(num_blocks):
                if block == 0:
                    self.model.append(UpsampleBlock(in_channels=channels[res], out_channels=channels[res-1], idx=res-1))
                else:
                    self.model.append(UpsampleBlock(in_channels=channels[res-1], out_channels=channels[res-1], idx=None))

            self.model.append(ReduceBlock(in_channels=channels[res+1], out_channels=channels[res+1]))

        self.model = Sequential(*self.model)
        
        
    def forward(self, x, t):
        global skip_info
        skip_info = {}

        x = self.upscale(x)
        x, _ = self.model(x, torch.tensor([t], requires_grad=False))
        x = self.downscale(x)

        return x
        

class ReduceBlock(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.block = Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1),
            # nn.Dropout2d(p=0.1),
            nn.GroupNorm(num_groups=out_channels, num_channels=out_channels),
            nn.ReLU()
        )

    
    def forward(self, x, t):
        print("Red", self.block(x).shape)
        return self.block(x), t


class IncreaseBlock(nn.Module):

    def __init__(self, in_channels, out_channels, idx=None):
        super().__init__()

        self.block = Sequential(
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            # nn.Dropout2d(p=0.1),
            nn.GroupNorm(num_groups=out_channels, num_channels=out_channels),
            nn.ReLU()
        )

    
    def forward(self, x, t):
        print("Inc", self.block(x).shape)
        z = self.block(x)
        return z, t


class DownsampleBlock(nn.Module):
    
    def __init__(self, in_channels, out_channels, idx=None):
        super().__init__()   
        
        self.idx = idx
        global skip_info
        
        self.block = Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            # nn.Dropout2d(p=0.1),
            nn.GroupNorm(num_groups=out_channels, num_channels=out_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            # nn.Dropout2d(p=0.1),
            nn.GroupNorm(num_groups=out_channels, num_channels=out_channels)
        )
        self.relu = nn.ReLU()
        
        self.downsample=False
        if in_channels != out_channels:
            self.downsample=True
            self.downsample_block = Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1),
                # nn.Dropout2d(p=0.1),
                nn.GroupNorm(num_groups=out_channels, num_channels=out_channels)
            )

        
    def forward(self, x, t):
        B, C, H, W = x.shape
        pos_emb = time_encoding2d(C, H, W, t)
        x = x + pos_emb.to(x)
        
        if self.downsample:
            z = self.relu(self.downsample_block(x) + self.block(x))
        else:
            z = self.relu(x + self.block(x))
            
        if self.idx is not None:
            skip_info[self.idx] = z
        
        print("Down", z.shape)
        
        return z, t
    
    
class UpsampleBlock(nn.Module):
    
    def __init__(self, in_channels, out_channels, idx=None):
        super().__init__()    
        
        self.idx = idx
        in_ch = in_channels
        if self.idx is not None:
            in_ch = 2*in_channels
                                      
        self.block = Sequential(
            nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_channels, kernel_size=3, stride=1, padding=1, output_padding=0),
            # nn.Dropout2d(p=0.1),
            nn.GroupNorm(num_groups=out_channels, num_channels=out_channels),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, output_padding=0),
            # nn.Dropout2d(p=0.1),
            nn.GroupNorm(num_groups=out_channels, num_channels=out_channels)
        )
        self.relu = nn.ReLU()
        
        self.upsample=False
        if in_channels != out_channels:
            self.upsample=True
            self.upsample_block = Sequential(
                nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_channels, kernel_size=3, stride=1, padding=1, output_padding=0),
                # nn.Dropout2d(p=0.1),
                nn.GroupNorm(num_groups=out_channels, num_channels=out_channels)
            )
        
        
    def forward(self, x, t):
        global skip_info
        
#         for key, value in skip_info.items():
#             print(key, value.shape)
        
        B, C, H, W = x.shape
        pos_emb = time_encoding2d(C, H, W, t)
        x = x + pos_emb.to(x)
        
        if self.idx is not None:
            x = torch.cat((x, skip_info[self.idx]), dim=1)
        
        if self.upsample:
            z = self.relu(self.upsample_block(x) + self.block(x))
        else:
            z = self.relu(x + self.block(x))
        print("Up", z.shape)
        
        return z, t

In [52]:
img, lab = next(iter(train_dataloader))
model = PixCNNPP()
t = 0

out = model(img, t)
print(out.shape)

# model.load_state_dict(torch.load('./weights/model_40.pth'))

Down torch.Size([1, 16, 32, 32])
Down torch.Size([1, 16, 32, 32])
Red torch.Size([1, 16, 16, 16])
Down torch.Size([1, 32, 16, 16])
Down torch.Size([1, 32, 16, 16])
Red torch.Size([1, 32, 8, 8])
Down torch.Size([1, 64, 8, 8])
Down torch.Size([1, 64, 8, 8])
Red torch.Size([1, 64, 4, 4])
Down torch.Size([1, 128, 4, 4])
Down torch.Size([1, 128, 4, 4])
Red torch.Size([1, 128, 2, 2])
Down torch.Size([1, 128, 2, 2])
Down torch.Size([1, 128, 2, 2])
Inc torch.Size([1, 128, 4, 4])
Up torch.Size([1, 64, 4, 4])
Up torch.Size([1, 64, 4, 4])


RuntimeError: Given groups=1, weight of size [256, 256, 3, 3], expected input[1, 64, 4, 4] to have 256 channels, but got 64 channels instead

In [None]:
######## INFERENCE

L2_loss = torch.nn.MSELoss()

with torch.no_grad():
    model.eval()
    maxT = 1000
    beta_t = torch.arange(0, 2000, 2000/maxT)/(1e5)
    alpha_bar_t = torch.ones(beta_t.shape)
    for i in range(beta_t.shape[0]):
        if i == 0:
            alpha_bar_t[i] = 1-beta_t[i]
        else:
            alpha_bar_t[i] = alpha_bar_t[i-1] * (1-beta_t[i])
    img, lab = next(iter(train_dataloader))
    xt = torch.normal(torch.zeros(img.shape), torch.ones(img.shape))
    
    for t in range(maxT-1, 0, -1):
        epsilon_theta = model(xt, t)    
        if t>0:
            z = torch.normal(torch.zeros(img.shape), torch.ones(img.shape))
        else:
            z = torch.zeros(img.shape)
            
        xt = (xt - (1-beta_t[t])/torch.sqrt(1-alpha_bar_t[t]) * epsilon_theta)/torch.sqrt(1-beta_t[t]) + torch.sqrt(beta_t[t])*z
        xt = xt - torch.min(torch.min(xt, dim=3, keepdim=True).values, dim=2, keepdim=True).values
        xt = xt / torch.max(torch.max(xt, dim=3, keepdim=True).values, dim=2, keepdim=True).values
        xt = xt*2 - 1
        
        print(t)
        img_show = (xt + 1)/2
        plt.imshow(img_show[0].permute(1, 2, 0))
        plt.show()        

In [None]:
# maxT = 200
# # L2_loss = torch.nn.MSELoss()
# beta_t = torch.arange(0, 1000, 1000/maxT)/(1e4)
# print(beta_t)
# alpha_bar_t = 1

# img, lab = next(iter(train_dataloader))
# img = 2*img-1

# plt.imshow(((img[0]+1)/2).permute(1, 2, 0))
# plt.show()

# for t in range(1, maxT):
#     alpha_bar_t = (1-beta_t[t]) * alpha_bar_t
    
#     if t %10 == 0:
#         print(t)
#         ## forward process
#         epsilon = torch.normal(torch.zeros(img.shape), torch.sqrt((1-alpha_bar_t)).repeat(img.shape))
#         img_t = torch.sqrt(alpha_bar_t)*img + epsilon

#         plt.imshow(((img_t[0]+1)/2).permute(1, 2, 0))
#         plt.show()

# #         ## reverse estimation
# #         epsilon_theta = model(img_t, t)    
# #         L2_loss(epsilon_theta, epsilon)    

In [None]:
a = torch.tensor([1,2,3])
b = torch.tensor([4, 5, 6])

gx, gy = torch.meshgrid(a, b)
torch.tensor(list(zip(gx.flatten(), gy.flatten())))
# print(gx, gy)

In [None]:
torch.arange(0, 2)

In [None]:
blk = nn.Sequential(nn.Conv2d(1, 2, 3),
                   nn.Linear(4, 5))

for mod in blk._modules.values():
    print(mod)

In [None]:
def func(*args, **kwargs):
    print(args)
    print(kwargs)
    
func(1,2,3)

In [None]:
a = (50,)
print(*a)

In [None]:
import numpy as np
np.random.choice(np.arange(1000))

In [None]:
a = torch.tensor([[1, 2, 3], [4,5,6]])
torch.min(a, dim=1).values