# Image Super-Resolution via Iterative Refinement

#### With additional functions built on top of the original code

*Code Writer: Chaeeun Ryu*

References:
- https://nn.labml.ai/diffusion/ddpm/unet.html
- https://github.com/KiUngSong/Generative-Models/tree/main
- https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement/
- https://arxiv.org/abs/2102.09672
- https://arxiv.org/pdf/2104.14951.pdf

In [4]:
import torch, torchvision
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from einops import rearrange, repeat
from tqdm.notebook import tqdm
from functools import partial
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import math, os, copy

# Expose only physical GPU 1 to this process. The variable only takes effect if it is
# set before CUDA is initialised, which is why it sits right after the imports.
_VISIBLE_GPU = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = _VISIBLE_GPU

# Define U-net Architecture:

The reverse diffusion process is approximated with a U-Net.<br>
SR3's U-Net: U-Net backbone + positional encoding of the noise level (used in place of the timestep) + multi-head self-attention

**Swish activation function**

In [5]:
class Swish(nn.Module):
    """Swish activation: returns ``x * sigmoid(x)`` element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * x

In [7]:
# PositionalEncoding Source： https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py
# PositionalEncoding Source： https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py
class PositionalEncoding(nn.Module):
    """Sinusoidal embedding of a continuous noise level.

    For ``half = dim // 2`` frequencies ``f_k = 1e4 ** (-k / half)``, the output
    is ``[sin(level * f), cos(level * f)]`` concatenated on the last axis, so the
    trailing size is ``2 * (dim // 2)``. A new axis is inserted at position 1 of
    ``noise_level`` before broadcasting against the frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, noise_level):
        half = self.dim // 2
        ratios = torch.arange(half, dtype=noise_level.dtype,
                              device=noise_level.device) / half
        freqs = torch.exp(-math.log(1e4) * ratios).unsqueeze(0)
        angles = noise_level.unsqueeze(1) * freqs
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

In [8]:
class FeatureWiseAffine(nn.Module):
    """Inject a noise-level embedding into a feature map, channel by channel.

    The embedding is linearly projected and reshaped to ``(batch, C, 1, 1)`` so it
    broadcasts over the spatial axes. With ``use_affine_level`` the projection is
    split into ``(gamma, beta)`` and the result is ``(1 + gamma) * x + beta``;
    otherwise the projection is simply added to ``x``.

    Args:
        in_channels: size of the embedding vector fed to the linear layer.
        out_channels: number of channels of the feature map ``x``.
        use_affine_level: scale-and-shift (True) vs. shift only (False).
    """

    def __init__(self, in_channels, out_channels, use_affine_level=False):
        super().__init__()
        self.use_affine_level = use_affine_level
        # Affine mode needs twice the outputs: first half gamma, second half beta.
        projected_size = out_channels * (1 + self.use_affine_level)
        self.noise_func = nn.Sequential(nn.Linear(in_channels, projected_size))

    def forward(self, x, noise_embed):
        per_channel = self.noise_func(noise_embed).view(x.shape[0], -1, 1, 1)
        if not self.use_affine_level:
            return x + per_channel
        gamma, beta = per_channel.chunk(2, dim=1)
        return (1 + gamma) * x + beta

In [9]:
class Upsample(nn.Module):
    """Double the spatial size (nearest-neighbour), then apply a 3x3 conv.

    The channel count ``dim`` is unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        enlarged = self.up(x)
        return self.conv(enlarged)


class Downsample(nn.Module):
    """Halve the spatial size with a stride-2 3x3 conv (padding 1).

    The channel count ``dim`` is unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        shrunk = self.conv(x)
        return shrunk

In [10]:
class Block(nn.Module):
    """GroupNorm -> Swish -> optional Dropout -> 3x3 conv (spatial size preserved).

    Args:
        dim: input channels (must be divisible by ``groups`` for GroupNorm).
        dim_out: output channels.
        groups: number of GroupNorm groups.
        dropout: dropout probability; 0 disables dropout.
    """

    def __init__(self, dim, dim_out, groups=32, dropout=0):
        super().__init__()
        # An Identity keeps the slot occupied when dropout is off, so the conv
        # always sits at the same index inside the Sequential.
        regulariser = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        self.block = nn.Sequential(
            nn.GroupNorm(groups, dim),
            Swish(),
            regulariser,
            nn.Conv2d(dim, dim_out, 3, padding=1),
        )

    def forward(self, x):
        out = self.block(x)
        return out


class ResnetBlock(nn.Module):
    """Residual block with noise-level injection between its two ``Block``s.

    ``forward(x, time_emb)`` returns
    ``block2(noise_func(block1(x), time_emb)) + res_conv(x)``, where ``res_conv``
    is a 1x1 conv when ``dim != dim_out`` and the identity otherwise.

    Args:
        dim: input channels.
        dim_out: output channels.
        noise_level_emb_dim: size of the noise-level embedding. When ``None``
            (as ``UNet`` passes with ``with_noise_level_emb=False``) no
            embedding is injected and ``time_emb`` is ignored.
        dropout: dropout probability used inside ``block2``.
        use_affine_level: forwarded to ``FeatureWiseAffine``.
        norm_groups: GroupNorm groups for both inner blocks.
    """

    def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32):
        super().__init__()
        # nn.Linear cannot be built with in_features=None, so the projection is
        # only created when an embedding size is actually provided.
        if noise_level_emb_dim is not None:
            self.noise_func = FeatureWiseAffine(
                noise_level_emb_dim, dim_out, use_affine_level)
        else:
            self.noise_func = None

        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        self.res_conv = nn.Conv2d(
            dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb):
        h = self.block1(x)
        if self.noise_func is not None:
            h = self.noise_func(h, time_emb)
        h = self.block2(h)
        # Skip connection; res_conv matches the channel count when it changes.
        return h + self.res_conv(x)

In [11]:
class SelfAttention(nn.Module):
    """Multi-head spatial self-attention with a residual connection.

    Every spatial position attends to every position of the same feature map;
    the result is ``out(attention(norm(input))) + input``.

    Args:
        in_channel: channels of the feature map. It must be divisible by
            ``n_head`` for the per-head reshape and by ``norm_groups`` for
            GroupNorm.
        n_head: number of attention heads.
        norm_groups: GroupNorm groups applied before the q/k/v projection.
    """

    def __init__(self, in_channel, n_head=1, norm_groups=32):
        super().__init__()
        self.n_head = n_head
        self.norm = nn.GroupNorm(norm_groups, in_channel)
        # A single 1x1 conv produces queries, keys and values together.
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.out = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, input):
        b, c, h, w = input.shape
        heads = self.n_head
        qkv = self.qkv(self.norm(input)).view(b, heads, (c // heads) * 3, h, w)
        q, k, v = qkv.chunk(3, dim=2)

        # Logits between every query position (h, w) and key position (y, x).
        # NOTE: scaled by sqrt of the full channel count, not the per-head size.
        logits = torch.einsum("bnchw, bncyx -> bnhwyx", q, k).contiguous() / math.sqrt(c)
        # Softmax jointly over all key positions (flattened), then restore the grid.
        weights = torch.softmax(logits.view(b, heads, h, w, -1), -1).view(b, heads, h, w, h, w)

        attended = torch.einsum("bnhwyx, bncyx -> bnchw", weights, v).contiguous()
        return self.out(attended.view(b, c, h, w)) + input

In [13]:
class ResnetBlockWithAttn(nn.Module):
    """A ``ResnetBlock`` optionally followed by ``SelfAttention``.

    Args:
        dim: input channels.
        dim_out: output channels.
        noise_level_emb_dim: embedding size forwarded to ``ResnetBlock``.
        norm_groups: GroupNorm groups for both sub-modules.
        dropout: dropout probability forwarded to ``ResnetBlock``.
        with_attn: append self-attention over ``dim_out`` channels when True.
    """

    def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
        super().__init__()
        self.with_attn = with_attn
        self.res_block = ResnetBlock(
            dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
        if with_attn:
            self.attn = SelfAttention(dim_out, norm_groups=norm_groups)

    def forward(self, x, time_emb):
        features = self.res_block(x, time_emb)
        if not self.with_attn:
            return features
        return self.attn(features)

In [14]:
class UNet(nn.Module):
    """Noise-predicting U-Net used by SR3.

    Structure:

    - Encoder: a 3x3 stem conv. Then, for each entry of ``channel_mults``,
      ``res_blocks`` ``ResnetBlockWithAttn`` layers, followed by a stride-2
      ``Downsample`` at every level except the last.
    - Middle: one block with self-attention and one block without.
    - Decoder: mirrors the encoder with ``res_blocks + 1`` blocks per level.
      Each block consumes one stored skip feature (concatenated on channels),
      and ``Upsample`` sits between levels.
    - Self-attention is enabled at each level whose current feature-map size
      appears in ``attn_res``.

    Args:
        in_channel: input channels (default 6; presumably the condition image
            concatenated with the noisy image — confirm with the caller).
        out_channel: output channels; ``None`` falls back to ``in_channel``.
        inner_channel: base width. Level ``i`` uses
            ``inner_channel * channel_mults[i]`` channels.
        norm_groups: GroupNorm groups used by every block.
        channel_mults: per-level width multipliers, one entry per level.
        attn_res: feature-map size(s) at which to add self-attention. Either a
            bare int or an iterable of ints.
        res_blocks: residual blocks per encoder level.
        dropout: dropout probability inside residual blocks.
        with_noise_level_emb: embed the conditioning value with a sinusoidal
            encoding + MLP. When False, ``None`` is passed to every block as
            the embedding.
        image_size: input spatial size, used to track each level's
            resolution for ``attn_res``.
    """

    def __init__(
        self,
        in_channel=6,
        out_channel=3,
        inner_channel=32,
        norm_groups=32,
        channel_mults=(1, 2, 4, 8, 8),
        attn_res=(8,),
        res_blocks=3,
        dropout=0,
        with_noise_level_emb=True,
        image_size=128
    ):
        super().__init__()

        # `(8)` is just the int 8, and `x in 8` raises TypeError. Accept a bare
        # int as well as any iterable of resolutions.
        attn_res = (attn_res,) if isinstance(attn_res, int) else tuple(attn_res)

        if with_noise_level_emb:
            noise_level_channel = inner_channel
            self.noise_level_mlp = nn.Sequential(
                PositionalEncoding(inner_channel),
                nn.Linear(inner_channel, inner_channel * 4),
                Swish(),
                nn.Linear(inner_channel * 4, inner_channel)
            )
        else:
            noise_level_channel = None
            self.noise_level_mlp = None

        num_mults = len(channel_mults)
        pre_channel = inner_channel
        # Channel count of every encoder output, in order; the decoder pops these
        # to size the concatenated skip connections.
        feat_channels = [pre_channel]
        now_res = image_size
        downs = [nn.Conv2d(in_channel, inner_channel,
                           kernel_size=3, padding=1)]
        for ind in range(num_mults):
            is_last = (ind == num_mults - 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks):
                downs.append(ResnetBlockWithAttn(
                    pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn))
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            if not is_last:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
                now_res = now_res // 2
        self.downs = nn.ModuleList(downs)

        self.mid = nn.ModuleList([
            ResnetBlockWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                                dropout=dropout, with_attn=True),
            ResnetBlockWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                                dropout=dropout, with_attn=False)
        ])

        ups = []
        for ind in reversed(range(num_mults)):
            is_last = (ind < 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            # One extra block per level so that every stored encoder feature,
            # including the Downsample outputs, is consumed exactly once.
            for _ in range(0, res_blocks + 1):
                ups.append(ResnetBlockWithAttn(
                    pre_channel + feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                    dropout=dropout, with_attn=use_attn))
                pre_channel = channel_mult
            if not is_last:
                ups.append(Upsample(pre_channel))
                now_res = now_res * 2

        self.ups = nn.ModuleList(ups)

        final_out = out_channel if out_channel is not None else in_channel
        self.final_conv = Block(pre_channel, final_out, groups=norm_groups)

    def forward(self, x, time):
        """Run the U-Net.

        Args:
            x: input tensor with ``in_channel`` channels. The spatial size is
                presumably divisible by ``2 ** (len(channel_mults) - 1)`` so
                that skip shapes line up — confirm.
            time: conditioning value fed to ``noise_level_mlp`` (ignored when
                the model was built with ``with_noise_level_emb=False``).

        Returns:
            Tensor with ``out_channel`` channels.
        """
        t = self.noise_level_mlp(time) if self.noise_level_mlp is not None else None

        feats = []
        for layer in self.downs:
            if isinstance(layer, ResnetBlockWithAttn):
                x = layer(x, t)
            else:
                x = layer(x)
            feats.append(x)

        for layer in self.mid:
            if isinstance(layer, ResnetBlockWithAttn):
                x = layer(x, t)
            else:
                x = layer(x)

        for layer in self.ups:
            if isinstance(layer, ResnetBlockWithAttn):
                x = layer(torch.cat((x, feats.pop()), dim=1), t)
            else:
                x = layer(x)

        return self.final_conv(x)

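`register_buffer`: parameter가 아니어서 optimizer/backpropagation으로 업데이트되지 않는 값을 모듈에 저장할 때 사용. buffer는 `.to(device)` 호출 시 모듈과 함께 GPU로 이동하고 `state_dict`에도 저장된다. (A buffer is non-trainable state. It moves with the module across devices and is saved in its `state_dict`.)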

# Diffusion Process Framework

#### TL;DR: Predict noise! 

**Contributions**
- Huber loss added as a loss type
- Lengthened linear schedule added as a beta-schedule type

<br><br><br>

##### Diffusion Computations

Computing $x_t$ directly from $x_0$ in the **forward process**:<br><br>
$$x_t = \sqrt{\bar{\alpha}_{t}}x_0 + \sqrt{1-\bar{\alpha}_{t}}\epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$
$\alpha_t := 1-\beta_t$<br>
$\bar{\alpha}_t := \prod^t_{s=1}\alpha_s$<br>
$q(x_t|x_0) = \mathcal{N}(x_t;\sqrt{\bar{\alpha}_t}x_0, (1-\bar{\alpha}_t)I)$<br>
$x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}(x_t-\sqrt{1-\bar{\alpha}_t}\epsilon)$
$=\frac{1}{\sqrt{\bar{\alpha}_t}}x_t - \sqrt{\frac{1}{\bar{\alpha}_t}-1}\epsilon$

<br><br>
Computing the posterior mean and variance used in the **reverse process**

$$q(x_{t-1}|x_t,x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t(x_0,x_t),\tilde{\beta}_t I)$$

$\tilde{\mu}_t(x_0,x_t) = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0+\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t$<br>
$\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t$

In [18]:
class Diffusion(nn.Module):
    """Gaussian diffusion wrapper around a noise-predicting denoiser (SR3 style).

    The wrapped ``model`` is called as
    ``model(torch.cat([condition_x, x_t], dim=1), noise_level)``, or as
    ``model(x_t, noise_level)`` when no condition is given. It must return a
    noise estimate with the same shape as ``x_t``. Call ``set_loss`` and
    ``set_new_noise_schedule`` before using the sampling methods.
    """

    def __init__(self, model, device, img_size, LR_size, channels=3):
        super().__init__()
        self.channels = channels
        self.device = device
        self.model = model.to(self.device)
        self.img_size = img_size
        self.LR_size = LR_size

    def set_loss(self, type_):
        """Select the loss (summed over elements).

        ``'l1'`` selects L1, ``'l2'`` selects MSE, and anything else selects Huber.
        """
        if type_ == 'l1':
            print(f"loss type: L1")
            self.loss_func = nn.L1Loss(reduction='sum')
        elif type_ == 'l2':
            print(f"loss type: L2")
            self.loss_func = nn.MSELoss(reduction='sum')
        else:
            print(f"loss type: Huber Loss")
            self.loss_func = nn.HuberLoss(reduction='sum')

    def make_beta_schedule(self, schedule_type, n_timestep, linear_start=1e-5, linear_end=5e-3):
        """Return a float64 beta schedule of length ``n_timestep``.

        Args:
            schedule_type: ``'cosine'`` or ``'lengthened_linear'``.
            n_timestep: number of diffusion steps T.
            linear_start, linear_end: endpoints for ``'lengthened_linear'``.
                The defaults (1e-5, 5e-3) are smaller than the original
                linear schedule's (1e-4, 0.02).

        Returns:
            A torch tensor for ``'cosine'``, or a NumPy array for
            ``'lengthened_linear'``.

        Raises:
            ValueError: if ``schedule_type`` is not recognised.
        """
        if schedule_type == 'cosine':
            cosine_s = 8e-3
            timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
            # alpha_bar(t) = cos^2(((t/T + s) / (1 + s)) * pi/2), normalised so alpha_bar(0) = 1.
            alphas = torch.cos(timesteps / (1 + cosine_s) * math.pi / 2).pow(2)
            alphas = alphas / alphas[0]
            betas = 1 - alphas[1:] / alphas[:-1]
            # Clip to keep betas away from 1 near the end of the schedule.
            return betas.clamp(max=0.999)

        if schedule_type == 'lengthened_linear':
            return np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)

        raise ValueError(f"unknown schedule type: {schedule_type!r}")

    def set_new_noise_schedule(self, schedule_option):
        """Build the beta schedule and register every per-step coefficient as a buffer.

        Args:
            schedule_option: mapping with keys ``'schedule'`` and
                ``'n_timestep'``. The optional keys ``'linear_start'`` and
                ``'linear_end'`` override the ``'lengthened_linear'`` endpoints.
        """
        # Helper that turns NumPy arrays into float32 tensors on this module's device.
        to_torch = partial(torch.tensor, dtype=torch.float32, device=self.device)
        linear_kwargs = {key: schedule_option[key]
                         for key in ('linear_start', 'linear_end') if key in schedule_option}
        betas = self.make_beta_schedule(
            schedule_option['schedule'], schedule_option['n_timestep'], **linear_kwargs)
        betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}: cumulative product shifted right by one, with alpha_bar_{-1} = 1.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
        # sqrt(alpha_bar) prefixed with 1, so index t+1 gives sqrt(alpha_bar_t).
        self.sqrt_alphas_cumprod_prev = np.sqrt(np.append(1., alphas_cumprod))

        self.num_timesteps = int(len(betas))

        # Forward process q(x_t | x_0) and x_0 reconstruction.
        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))  # alpha_bar_t
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))  # alpha_bar_{t-1}
        # x_0 = pred_coef_xt * x_t - pred_coef_noise * eps
        self.register_buffer('pred_coef_xt', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('pred_coef_noise', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # Reverse-process posterior q(x_{t-1} | x_t, x_0).
        variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('variance', to_torch(variance))
        # The variance is 0 at t = 0, so clip it before taking the log.
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(variance, 1e-20))))
        self.register_buffer('posterior_mean_coef_x0', to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef_xt', to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

    def predict_x0(self, x_t, t, noise):
        """Reconstruct x_0 from x_t and a noise estimate at integer step ``t``."""
        return self.pred_coef_xt[t] * x_t - self.pred_coef_noise[t] * noise

    def q_posterior(self, x_start, x_t, t):
        """Return the posterior mean and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mu = self.posterior_mean_coef_x0[t] * x_start + self.posterior_mean_coef_xt[t] * x_t
        posterior_log_var = self.posterior_log_variance_clipped[t]
        return posterior_mu, posterior_log_var

    def q_mean_variance(self, x, t, clip_denoised: bool, condition_x=None):
        """Predict x_0 with the model, then return the posterior (mean, log-variance).

        Args:
            x: current sample x_t.
            t: integer timestep.
            clip_denoised: clamp the reconstructed x_0 to [-1, 1].
            condition_x: optional conditioning tensor, concatenated before ``x``
                on the channel axis.
        """
        batch_size = x.shape[0]
        # The model is conditioned on sqrt(alpha_bar_t), one value per batch item.
        noise_level = torch.full((batch_size, 1), float(self.sqrt_alphas_cumprod_prev[t + 1]),
                                 dtype=torch.float32, device=x.device)
        model_input = x if condition_x is None else torch.cat([condition_x, x], dim=1)
        x_recon = self.predict_x0(x, t, noise=self.model(model_input, noise_level))

        if clip_denoised:
            x_recon.clamp_(-1., 1.)
        return self.q_posterior(x_start=x_recon, x_t=x, t=t)

    @torch.no_grad()
    def p_sample(self, x, t, clip_denoised=True, condition_x=None):
        """Draw x_{t-1} from p(x_{t-1} | x_t) for integer step ``t``."""
        mean, log_variance = self.q_mean_variance(x, t, clip_denoised, condition_x=condition_x)
        # No noise is added on the final step (t == 0).
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        return mean + noise * (0.5 * log_variance).exp()
    
        

# Image Super-Resolution via Iterative Refinement (SR3)

In [None]:
class SR3