<a href="https://colab.research.google.com/github/takbull/U-Tokyo-Deep-Generative-Model-Spring-Seminar/blob/master/lecture_chap03_exercise_RealNVP_master.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 第3回 講義 演習 Real NVPの実装

In [0]:
%cd /root/userspace/chap03

import PIL
PIL.PILLOW_VERSION = PIL.__version__

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.utils as utils
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.utils.data as data
from torch.utils.data import DataLoader 
import torchvision
from torchvision import datasets, transforms

from utils.norm_util import get_norm_layer, get_param_groups, WNConv2d
from utils.optim_util import bits_per_dim, clip_grad_norm
from utils.shell_util import AverageMeter
from utils.resnet import ResidualBlock, ResNet

import numpy as np
import matplotlib.pyplot as plt
from IPython import display
import pylab as pl

import functools
from enum import IntEnum
import os
from tqdm import tqdm

In [0]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## 1. 前準備

### 1.1. データローダー

学習用データとしてCIFAR-10を使用します。以下でそのためのデータローダーを用意します。

In [0]:
# Mini-batch size and number of loader worker processes shared by both loaders.
batch_size = 32
num_workers = 8
cifar10_root = 'data/cifar10/'

# Training images are randomly flipped horizontally as augmentation;
# test images are only converted to tensors.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
transform_test = transforms.Compose([transforms.ToTensor()])

# Training split: shuffled every epoch.
train_set = datasets.CIFAR10(root=cifar10_root, train=True, download=True,
                             transform=transform_train)
trainloader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                         num_workers=num_workers)

# Test split: fixed order.
test_set = datasets.CIFAR10(root=cifar10_root, train=False, download=True,
                            transform=transform_test)
testloader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers)

### 1.2. 学習プロセスの表示のための関数

モデルの学習中に、bpdの推移過程と生成例を表示するための関数です。学習のときに使用します。

In [0]:
def display_process(train_bpd, samples, image_frame_dim=4, fix=True):
    """Plot the training bpd curve (left) next to a grid of generated samples (right).

    Args:
        train_bpd: sequence of bpd values, one per training iteration.
        samples: indexable collection of images; each ``samples[i]`` is passed
            to ``imshow`` (e.g. an H x W x 3 array as returned by ``sample``).
        image_frame_dim: the sample grid is ``image_frame_dim`` x ``image_frame_dim``.
            If fewer samples are given, only the available ones are drawn
            instead of raising IndexError.
        fix: unused; kept so existing calls keep working.

    Returns:
        The newly created matplotlib figure (also left as the current figure).
    """
    plt.gcf().clear()

    fig = plt.figure(figsize=(24, 15))
    fig.subplots_adjust(left=0, right=1, bottom=0, top=0.5, hspace=0.05, wspace=0.05)

    # Left half: bpd history.
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.plot(range(len(train_bpd)), train_bpd, label='train_bpd')
    ax1.set_xlabel('Iter')
    ax1.set_ylabel('bpd')
    ax1.legend(loc='upper right')
    ax1.grid(True)

    # Right half: samples on an image_frame_dim x (2 * image_frame_dim) subplot grid,
    # skipping the first image_frame_dim columns of every row (occupied by the curve).
    n_cols = image_frame_dim * 2
    num_shown = min(len(samples), image_frame_dim * image_frame_dim)
    for i in range(num_shown):
        row, col = divmod(i, image_frame_dim)
        position = row * n_cols + image_frame_dim + col + 1  # subplot indices are 1-based
        ax = fig.add_subplot(image_frame_dim, n_cols, position, xticks=[], yticks=[])
        ax.imshow(samples[i])

    return fig

## 2. Real NVPの実装

### 2.1. 関数・クラスの実装

ネットワークの定義に必要な各種関数・クラスを実装します。

#### squeezing operation

`squeeze_2x2()`は各チャネルを2×2のブロックに分割し、それらをチャネル方向に並べ直す処理（形状 $(C, H, W) \to (4C, H/2, W/2)$）を行います。  
- `alt_order=True`を指定することで順番を変えることができます。  
- `reverse=True`では逆の処理を行います。（分割されたものを元のサイズに戻す処理）

In [0]:
def squeeze_2x2(x, reverse=False, alt_order=False):
    """Trade spatial resolution for channels with a 2x2 space-to-depth reshuffle.

    Forward: (N, C, H, W) -> (N, 4C, H/2, W/2).
    reverse=True undoes it: (N, 4C, H, W) -> (N, C, 2H, 2W).

    Args:
        x: 4-D tensor (N, C, H, W). For the forward direction H and W must be
            even; for ``reverse=True`` C must be divisible by 4.
        reverse: undo a previous squeeze instead of applying one.
        alt_order: use the alternative channel ordering. Output channel
            ``j * C + ci`` holds input channel ``ci`` at 2x2 offset
            (0, 0), (1, 1), (0, 1), (1, 0) for j = 0, 1, 2, 3. Without
            ``alt_order``, output channel ``ci * 4 + 2 * dy + dx`` holds input
            channel ``ci`` at offset (dy, dx).

    Returns:
        The reshuffled tensor.
    """
    if alt_order:
        n, c, h, w = x.size()
        if reverse:
            c //= 4

        # Sub-pixel selector for each of the 4 output groups (shape (4, 1, 2, 2)).
        squeeze_matrix = torch.tensor([[[[1., 0.], [0., 0.]]],
                                       [[[0., 0.], [0., 1.]]],
                                       [[[0., 1.], [0., 0.]]],
                                       [[[0., 0.], [1., 0.]]]],
                                      dtype=x.dtype, device=x.device)

        # Build the (4c, c, 2, 2) permutation kernel in one shot:
        #   perm_weight[j * c + ci, cc] = (ci == cc) * squeeze_matrix[j, 0]
        # identity_rows: row j * c + ci is a one-hot vector at column ci.
        identity_rows = torch.eye(c, dtype=x.dtype, device=x.device).repeat(4, 1)  # (4c, c)
        # selectors: row j * c + ci is squeeze_matrix[j].
        selectors = squeeze_matrix.repeat_interleave(c, dim=0)                      # (4c, 1, 2, 2)
        perm_weight = identity_rows[:, :, None, None] * selectors                   # (4c, c, 2, 2)

        # Each output pixel receives exactly one input value, so the (transposed)
        # convolution is a pure permutation.
        if reverse:
            return F.conv_transpose2d(x, perm_weight, stride=2)
        return F.conv2d(x, perm_weight, stride=2)

    b, c, h, w = x.size()
    x = x.permute(0, 2, 3, 1)  # channels last: (b, h, w, c)

    if reverse:
        # (b, h, w, 4c') -> (b, h, w, c', 2, 2) -> interleave the 2x2 back into space.
        x = x.view(b, h, w, c // 4, 2, 2)
        x = x.permute(0, 1, 4, 2, 5, 3)
        x = x.contiguous().view(b, 2 * h, 2 * w, c // 4)
    else:
        # (b, h, w, c) -> (b, h/2, 2, w/2, 2, c) -> move the 2x2 offsets next to c.
        x = x.view(b, h // 2, 2, w // 2, 2, c)
        x = x.permute(0, 1, 3, 5, 2, 4)
        x = x.contiguous().view(b, h // 2, w // 2, c * 4)

    return x.permute(0, 3, 1, 2)  # back to channels first

In [0]:
# Example: squeeze a 1x1x4x4 tensor holding 1..16 and undo it again.
a = torch.arange(1., 17.).view(1, 1, 4, 4)          # original tensor
a_2x2 = squeeze_2x2(a)                              # squeeze
a_2x2alt = squeeze_2x2(a, alt_order=True)           # squeeze with the alternative channel order
a_reverse = squeeze_2x2(a_2x2, reverse=True)        # undo the squeeze (recovers a)
for shown in (a, a_2x2, a_2x2alt, a_reverse):
    print(shown)

#### Coupling layer

<img src="./image/figure3.png" align=left>

Real NVPでは$D$次元の$x$のうち、$x_{1:d}$と$x_{d+1:D}$がそれぞれ**Coupling layer**の入力として与えられます。  
その際の出力は以下の式のように表されます。  
$y_{1:d}=x_{1:d}$  
$y_{d+1:D} = x_{d+1:D}\odot \exp(s(x_{1:d})) + t(x_{1:d})$  
この変換におけるJacobianは下三角行列となるためその行列式は  
$\exp(\sum_{j} s(x_{1:d})_{j} )$となり、簡単に計算を行うことができます。  
また、Coupling layerは逆の計算も容易に行うことができます。  
$x_{1:d}=y_{1:d}$  
$x_{d+1:D} = (y_{d+1:D}-t(y_{1:d}))\odot \exp(-s(y_{1:d}))$

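実装上はbinary mask $b$を利用します（後述するcheckerboard pattern maskとchannel-wise masking）。  
$y = b \odot x + (1-b) \odot (x \odot \exp(s(b\odot x))+t(b\odot x))$

なお、以下の実装では平行移動を先に行う $y_{d+1:D} = (x_{d+1:D} + t(x_{1:d})) \odot \exp(s(x_{1:d}))$ という形を用いています。その逆変換は $x_{d+1:D} = y_{d+1:D} \odot \exp(-s(y_{1:d})) - t(y_{1:d})$ です。この形でもヤコビアンは三角行列であり、その行列式は同じく $\exp(\sum_{j} s(x_{1:d})_{j})$ です。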

In [0]:
class CouplingLayer(nn.Module):
    """Affine coupling layer used by Real NVP.

    The part of the input selected by the mask is kept unchanged. It is fed to
    a ResNet that predicts a log-scale ``s`` and a translation ``t`` for the
    other part. That other part is mapped as ``(x + t) * exp(s)`` in the
    forward direction and as ``x * exp(-s) - t`` in the inverse direction.
    ``s`` is squashed with ``tanh`` and multiplied by a learnable
    (weight-normalised) scalar.

    Args:
        in_channels: number of channels of the input tensor.
        mid_channels: hidden width of the s/t ResNet.
        num_blocks: number of residual blocks in the s/t ResNet.
        mask_type: ``MaskType.CHECKERBOARD`` (spatial checkerboard mask) or
            ``MaskType.CHANNEL_WISE`` (split the channels into two halves).
        reverse_mask: flips which checkerboard parity / which channel half is
            kept fixed.
    """

    def __init__(self, in_channels, mid_channels, num_blocks, mask_type, reverse_mask):
        super(CouplingLayer, self).__init__()

        # Save mask info
        self.mask_type = mask_type  # MaskType.CHECKERBOARD (=0) or MaskType.CHANNEL_WISE (=1)
        self.reverse_mask = reverse_mask  # True or False

        # Build the scale-and-translate network. With channel-wise masking it
        # only sees (and predicts s, t for) half of the channels.
        if self.mask_type == MaskType.CHANNEL_WISE:
            in_channels //= 2
        self.st_net = ResNet(in_channels, mid_channels, 2 * in_channels,
                             num_blocks=num_blocks, kernel_size=3, padding=1,
                             double_after_norm=(self.mask_type == MaskType.CHECKERBOARD))

        # Learnable scale for s
        self.scale = nn.utils.weight_norm(Scalar())

    def forward(self, x, sldj=None, reverse=True):
        """Apply the coupling (``reverse=False``) or its inverse (``reverse=True``).

        Args:
            x: input tensor; dims 2 and 3 are used as height/width and dim 1 is
                split for channel-wise masking (presumably (N, C, H, W)).
            sldj: running sum of log-determinants with one entry per sample.
                It is only updated in the forward direction, and ``None`` there
                means "start from zero". It is returned unchanged when
                ``reverse=True``.
            reverse: NOTE the default is True, i.e. the inverse direction.

        Returns:
            Tuple ``(x, sldj)`` with the transformed tensor and updated sldj.
        """
        if self.mask_type == MaskType.CHECKERBOARD:
            return self._checkerboard_forward(x, sldj, reverse)
        return self._channel_wise_forward(x, sldj, reverse)

    def _checkerboard_forward(self, x, sldj, reverse):
        """Coupling where a spatial checkerboard decides which pixels are conditioned on."""
        # Build the mask in x's dtype so half-precision inputs are not upcast.
        b = checkerboard_mask(x.size(2), x.size(3), self.reverse_mask,
                              dtype=x.dtype, device=x.device)
        s, t = self.st_net(x * b).chunk(2, dim=1)
        s = self.scale(torch.tanh(s))
        # Zero s and t where b == 1 so the conditioning pixels pass through
        # unchanged and do not contribute to the log-determinant.
        s = s * (1 - b)
        t = t * (1 - b)
        return self._scale_and_translate(x, s, t, sldj, reverse)

    def _channel_wise_forward(self, x, sldj, reverse):
        """Coupling where one channel half is conditioned on and the other is transformed."""
        if self.reverse_mask:
            x_id, x_change = x.chunk(2, dim=1)
        else:
            x_change, x_id = x.chunk(2, dim=1)

        s, t = self.st_net(x_id).chunk(2, dim=1)
        s = self.scale(torch.tanh(s))
        x_change, sldj = self._scale_and_translate(x_change, s, t, sldj, reverse)

        if self.reverse_mask:
            x = torch.cat((x_id, x_change), dim=1)
        else:
            x = torch.cat((x_change, x_id), dim=1)
        return x, sldj

    @staticmethod
    def _scale_and_translate(x, s, t, sldj, reverse):
        """Apply ``x -> (x + t) * exp(s)`` (or its inverse) and accumulate log|det J|."""
        if reverse:
            return x * s.mul(-1).exp() - t, sldj

        x = (x + t) * s.exp()
        # log|det J| of an elementwise scaling by exp(s) is the per-sample sum of s.
        log_det = s.view(s.size(0), -1).sum(-1)
        # Out-of-place so the caller's tensor is not mutated; None starts from zero.
        sldj = log_det if sldj is None else sldj + log_det
        return x, sldj

class Scalar(nn.Module):
    """Multiply the input by a single learnable scalar.

    The parameter is stored under the name ``weight``, which is the name
    ``nn.utils.weight_norm`` reparameterizes by default.
    """

    def __init__(self):
        super(Scalar, self).__init__()
        # One value drawn from N(0, 1); it broadcasts against any input shape.
        initial_value = torch.randn(1)
        self.weight = nn.Parameter(initial_value)

    def forward(self, x):
        scaled = self.weight * x
        return scaled

#### Masking schemes

<img src="./image/figure4.png" align=left>

coupling layerで使用するmasking方式には**checkerboard pattern mask**と**channel-wise masking**の2つが存在します。  
- checkerboard pattern maskでは0と1が市松模様（交互）に並んだマスクを適用します。
- channel-wise maskingでは`squeeze_2x2()`で得たチャネルを前半と後半に分け、一方をそのまま残す側（マスク1）、もう一方を変換する側（マスク0）とします。どちらの半分を変換するかは`reverse_mask`で切り替えます。

In [0]:
def checkerboard_mask(height, width, reverse=False, dtype=torch.float32,
                      device=None, requires_grad=False):
    """Return a (1, 1, height, width) checkerboard of zeros and ones.

    Entry (i, j) is ``(i + j) % 2``, so the top-left corner is 0:

        [[0, 1, 0, ..., 1, 0, 1],
         [1, 0, 1, ..., 0, 1, 0],
         ...]

    Args:
        height, width: spatial size of the mask.
        reverse: if True, return ``1 - mask`` (top-left corner becomes 1).
        dtype, device: dtype and device of the returned tensor.
        requires_grad: mark the (un-reversed) mask as requiring gradients.

    Returns:
        Tensor of shape (1, 1, height, width), ready to broadcast against
        (B, C, H, W) tensors.
    """
    row_idx = torch.arange(height, device=device).view(-1, 1)
    col_idx = torch.arange(width, device=device).view(1, -1)
    # (i + j) % 2 is the same parity as ((i % 2) + j) % 2.
    board = ((row_idx + col_idx) % 2).to(dtype)
    if requires_grad:
        board.requires_grad_()

    if reverse:
        board = 1 - board

    return board.view(1, 1, height, width)

class MaskType(IntEnum):
    """Masking scheme selector for CouplingLayer."""
    CHECKERBOARD = 0  # spatial checkerboard mask built by checkerboard_mask()
    CHANNEL_WISE = 1  # condition on one half of the channels, transform the other half

### 2.2. ネットワークの定義

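各スケールで Coupling_Layer(checkerboard) x3 -> squeezing -> Coupling_Layer(channel-wise) x3 -> unsqueezing を適用します。その後、チャネルの半分だけを次のスケールに渡し、残り半分はそれ以上変換しません。最後のスケールでは Coupling_Layer(checkerboard) x4 のみを適用します。  
このようなマルチスケール構成のネットワークを定義します（`num_scales=2`の場合、2つ目のスケールが最後のスケールになります）。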

In [0]:
class RealNVPModule(nn.Module):
    """One scale of the multi-scale Real NVP flow; recursively builds the remaining scales.

    Non-last scale (forward direction):
        3 checkerboard couplings -> squeeze -> 3 channel-wise couplings
        (4x channels, 2x hidden width) -> unsqueeze -> alternative-order squeeze.
        The first half of the channels is then processed by the next scale,
        the second half passes through untouched, and the result is unsqueezed.
    Last scale:
        4 checkerboard couplings.
    The reverse direction runs the same steps backwards.

    Args:
        scale_idx: index of this scale, ``0 <= scale_idx < num_scales``.
        num_scales: total number of scales (must be >= 1).
        in_channels: channels entering this scale.
        mid_channels: hidden width of this scale's coupling networks.
        num_blocks: residual blocks per coupling network.

    Raises:
        ValueError: if ``scale_idx`` is outside ``[0, num_scales)``. Without
            this check the recursion would never reach a last block.
    """

    def __init__(self, scale_idx, num_scales, in_channels, mid_channels, num_blocks):
        super(RealNVPModule, self).__init__()
        if not 0 <= scale_idx < num_scales:
            raise ValueError('scale_idx must satisfy 0 <= scale_idx < num_scales, '
                             'got scale_idx=%r, num_scales=%r' % (scale_idx, num_scales))

        self.is_last_block = scale_idx == num_scales - 1

        self.in_couplings = nn.ModuleList([
            CouplingLayer(in_channels, mid_channels, num_blocks, MaskType.CHECKERBOARD, reverse_mask=False),
            CouplingLayer(in_channels, mid_channels, num_blocks, MaskType.CHECKERBOARD, reverse_mask=True),
            CouplingLayer(in_channels, mid_channels, num_blocks, MaskType.CHECKERBOARD, reverse_mask=False)
        ])

        if self.is_last_block:
            self.in_couplings.append(
                CouplingLayer(in_channels, mid_channels, num_blocks, MaskType.CHECKERBOARD, reverse_mask=True))
        else:
            self.out_couplings = nn.ModuleList([
                CouplingLayer(4 * in_channels, 2 * mid_channels, num_blocks, MaskType.CHANNEL_WISE, reverse_mask=False),
                CouplingLayer(4 * in_channels, 2 * mid_channels, num_blocks, MaskType.CHANNEL_WISE, reverse_mask=True),
                CouplingLayer(4 * in_channels, 2 * mid_channels, num_blocks, MaskType.CHANNEL_WISE, reverse_mask=False)
            ])
            self.next_block = RealNVPModule(scale_idx + 1, num_scales, 2 * in_channels, 2 * mid_channels, num_blocks)

    def forward(self, x, sldj, reverse=False):
        """Run this scale (and all later ones) forward, or invert them when ``reverse=True``."""
        if reverse:
            if not self.is_last_block:
                x, sldj = self._through_next_scale(x, sldj, reverse)
                x, sldj = self._channel_wise_couplings(x, sldj, reverse)

            for coupling in reversed(self.in_couplings):
                x, sldj = coupling(x, sldj, reverse)
        else:
            for coupling in self.in_couplings:
                x, sldj = coupling(x, sldj, reverse)

            if not self.is_last_block:
                x, sldj = self._channel_wise_couplings(x, sldj, reverse)
                x, sldj = self._through_next_scale(x, sldj, reverse)

        return x, sldj

    def _channel_wise_couplings(self, x, sldj, reverse):
        """Squeeze, apply the channel-wise couplings (last-to-first when inverting), unsqueeze."""
        x = squeeze_2x2(x, reverse=False)
        couplings = reversed(self.out_couplings) if reverse else self.out_couplings
        for coupling in couplings:
            x, sldj = coupling(x, sldj, reverse)
        x = squeeze_2x2(x, reverse=True)
        return x, sldj

    def _through_next_scale(self, x, sldj, reverse):
        """Alt-order squeeze, send the first channel half to the next scale, restore the shape."""
        x = squeeze_2x2(x, reverse=False, alt_order=True)
        x, x_split = x.chunk(2, dim=1)
        x, sldj = self.next_block(x, sldj, reverse)
        x = torch.cat((x, x_split), dim=1)
        x = squeeze_2x2(x, reverse=True, alt_order=True)
        return x, sldj

In [0]:
class RealNVP(nn.Module):
    """Real NVP model for images: dequantization + logit preprocessing, then the multi-scale flow.

    Args:
        num_scales: number of scales in the multi-scale flow.
        in_channels: channels of the input images.
        mid_channels: hidden width of the first scale's coupling networks.
        num_blocks: residual blocks per coupling network.
    """

    def __init__(self, num_scales=2, in_channels=3, mid_channels=64, num_blocks=8):
        super(RealNVP, self).__init__()
        # Squashing factor `a` applied before the logit; registered as a buffer
        # so it moves together with the model.
        self.register_buffer('data_constraint', torch.tensor([0.9], dtype=torch.float32))

        self.flows = RealNVPModule(0, num_scales, in_channels, mid_channels, num_blocks)

    def forward(self, x, reverse=False):
        """Map images to latents (returns ``(z, sldj)``) or, with ``reverse=True``,
        latents back to logit space (returns ``(logits, None)``)."""
        if reverse:
            # Sampling direction: no preprocessing and no log-determinant tracking.
            return self.flows(x, None, reverse)
        logits, sldj = self._pre_process(x)
        return self.flows(logits, sldj, reverse)

    def _pre_process(self, x):
        """Dequantize ``x`` and map it to logit space.

        Assumes ``x`` holds pixel intensities in [0, 1] (as produced by the
        ToTensor transforms used for the data loaders).

        Returns:
            ``(y, sldj)``: the logits and the per-sample log-determinant of
            the squash + logit transform.
        """
        a = self.data_constraint
        # Dequantize: add uniform noise inside each of the 256 intensity bins.
        noisy = (x * 255. + torch.rand_like(x)) / 256.
        # Squash into [(1 - a) / 2, (1 + a) / 2] so the logit stays finite.
        squashed = (2 * noisy - 1) * a
        squashed = (squashed + 1) / 2
        logits = squashed.log() - (1. - squashed).log()

        # log|d logit/dp| = softplus(y) + softplus(-y); the last term equals
        # log(a) and accounts for the squashing. The 1/256 bin width is handled
        # by the loss instead.
        ldj = F.softplus(logits) + F.softplus(-logits) \
            - F.softplus((1. - a).log() - a.log())
        sldj = ldj.view(ldj.size(0), -1).sum(-1)

        return logits, sldj

ロス関数（バッチ平均の負の対数尤度）を定義します。ピクセル値が256段階に離散化されていることに対応する補正項 $-D\log 256$（$D$は1サンプルあたりの次元数）も含みます。

In [0]:
class RealNVPLoss(nn.Module):
    """Negative log-likelihood of a flow under a standard normal prior.

    Args:
        k: number of discrete levels per input dimension (256 for 8-bit
            images). It contributes ``-D * log(k)`` to each sample's
            log-likelihood, where D is the number of dimensions per sample.
    """

    def __init__(self, k=256):
        super(RealNVPLoss, self).__init__()
        self.k = k

    def forward(self, z, sldj):
        """Return the batch-mean NLL for latents ``z`` and per-sample log-determinants ``sldj``."""
        # Elementwise log N(z; 0, 1).
        log_prob = -0.5 * (z ** 2 + np.log(2 * np.pi))
        per_sample = log_prob.view(z.size(0), -1).sum(-1)
        num_dims = np.prod(z.size()[1:])
        prior_ll = per_sample - np.log(self.k) * num_dims
        log_likelihood = prior_ll + sldj
        return -log_likelihood.mean()

## 3. モデルの学習

In [0]:
# Two-scale Real NVP for 3-channel images, placed on the selected device.
model = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8).to(device)

In [0]:
# Optimization hyper-parameters.
weight_decay = 5e-5
lr = 1e-3

# Running-average bpd recorded after every training step (appended to by train()).
train_bpd = []

loss_fn = RealNVPLoss()
# NOTE(review): get_param_groups receives weight_decay and norm_suffix='weight_g';
# presumably it exempts weight-norm magnitudes from decay — confirm in utils.norm_util.
param_groups = get_param_groups(model, weight_decay, norm_suffix='weight_g')
optimizer = optim.Adam(param_groups, lr=lr)

In [0]:
def train(epoch, model, trainloader, device, optimizer, loss_fn, max_grad_norm):
    """Run one training epoch over ``trainloader``.

    After every optimizer step the running-average bits-per-dim is appended to
    the module-level ``train_bpd`` list and shown in the progress bar.
    Gradients are clipped with ``clip_grad_norm(optimizer, max_grad_norm)``
    before each step.
    """
    print('\nEpoch: %d' % epoch)
    model.train()
    meter = AverageMeter()
    num_images = len(trainloader.dataset)
    with tqdm(total=num_images) as bar:
        for images, _labels in trainloader:
            images = images.to(device)

            optimizer.zero_grad()
            z, sldj = model(images, reverse=False)
            loss = loss_fn(z, sldj)
            meter.update(loss.item(), images.size(0))
            loss.backward()
            clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            running_bpd = bits_per_dim(images, meter.avg)
            train_bpd.append(running_bpd)
            bar.set_postfix(loss=meter.avg, bpd=running_bpd)
            bar.update(images.size(0))

In [0]:
def sample(model, batch_size, device, image_shape=(3, 32, 32)):
    """Draw ``batch_size`` images from the model by inverting it on Gaussian noise.

    The inverse flow returns values in logit space. They are mapped back with a
    sigmoid. If the model has a ``data_constraint`` attribute, the squashing
    ``y = ((2x - 1) * a + 1) / 2`` applied in RealNVP._pre_process is also
    undone, and the result is clamped to [0, 1] for display.

    Args:
        model: callable as ``model(z, reverse=True)`` returning ``(x, _)``.
        batch_size: number of images to draw.
        device: device on which to sample the latent noise.
        image_shape: (C, H, W) of each latent/image; defaults to CIFAR-10's.

    Returns:
        numpy array of shape (batch_size, H, W, C) (channels last, for imshow).
    """
    # Sampling needs no gradients; without no_grad the whole inverse pass
    # would be recorded by autograd.
    with torch.no_grad():
        z = torch.randn((batch_size,) + tuple(image_shape), dtype=torch.float32, device=device)
        x, _ = model(z, reverse=True)
        x = torch.sigmoid(x)
        constraint = getattr(model, 'data_constraint', None)
        if constraint is not None:
            x = ((2 * x - 1) / constraint + 1) / 2
            # Logits outside the squashed training range map slightly outside [0, 1].
            x = x.clamp(0., 1.)
    return x.detach().cpu().numpy().transpose(0, 2, 3, 1)

In [0]:
def test(epoch, model, testloader, device, loss_fn, num_samples):
    """Evaluate on the test set, then show the bpd history and freshly drawn samples.

    Args:
        epoch: unused; kept for call compatibility.
        model, testloader, device, loss_fn: evaluation inputs.
        num_samples: number of images drawn with ``sample`` for the display.

    Returns:
        The average test loss over the epoch (previously returned None;
        existing callers ignore the return value).
    """
    model.eval()
    loss_meter = AverageMeter()
    with torch.no_grad():
        with tqdm(total=len(testloader.dataset)) as progress_bar:
            for x, _ in testloader:
                x = x.to(device)
                z, sldj = model(x, reverse=False)
                loss = loss_fn(z, sldj)
                loss_meter.update(loss.item(), x.size(0))

                progress_bar.set_postfix(loss=loss_meter.avg,
                                         bpd=bits_per_dim(x, loss_meter.avg))
                progress_bar.update(x.size(0))

        # Sample inside no_grad as well so the inverse pass builds no autograd graph.
        images = sample(model, num_samples, device)

    # Show the bpd curve together with the samples.
    display_process(train_bpd, images, image_frame_dim=4, fix=True)
    display.clear_output(wait=True)
    display.display(pl.gcf())
    plt.close()

    return loss_meter.avg

In [0]:
# Training schedule and per-epoch settings.
start_epoch = 0
num_epochs = 15
max_grad_norm = 100.   # passed to clip_grad_norm inside train()
num_samples = 64       # images drawn for the display after each epoch
best_loss = 0          # not read or updated by the loop below

end_epoch = start_epoch + num_epochs
for epoch in range(start_epoch, end_epoch):
    train(epoch, model, trainloader, device, optimizer, loss_fn, max_grad_norm)
    test(epoch, model, testloader, device, loss_fn, num_samples)

## 4.結果

目安時間: 1epochあたり20分

<img src="./image/result2.png" align=left>