In [3]:
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Union
import numpy as np
from torch.distributions import Categorical
from pathlib import Path
import torch.utils.data as data
from rsrch.datasets import font_awesome, tiny_imagenet
import rsrch.utils.data as data
import torch.utils.tensorboard as tensorboard
from contextlib import contextmanager
import rsrch.utils.visual as visual
import torchvision.transforms.functional as tv_F
from tqdm.auto import tqdm


class MaskedConv2d(nn.Conv2d):
    """2D convolution whose kernel is multiplied by a fixed mask.

    The mask is stored as a buffer (so it follows the module across
    devices and is saved in the state dict) and is applied to the weights
    in place on every forward pass.

    Args:
        mask: Tensor multiplied element-wise into the convolution weight;
            it has to broadcast against the weight tensor.
        *args, **kwargs: Forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, mask: Tensor, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask: Tensor
        # Align the mask's type with the freshly created weight tensor.
        weight_typed_mask = mask.type_as(self.weight)
        self.register_buffer("mask", weight_typed_mask)

    def forward(self, x: Tensor) -> Tensor:
        """Zero the masked kernel entries, then run the regular convolution."""
        # Operate on ``.data`` so the in-place masking is not tracked by
        # autograd; the optimizer may have moved masked entries away from 0.
        self.weight.data.mul_(self.mask)
        out = super().forward(x)
        return out


class MaskedConvBlock(nn.Module):
    """Masked convolution followed by a normalisation and an activation.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        mask: Kernel mask handed to ``MaskedConv2d``.
        k: Square kernel size; the padding is ``k // 2`` with stride 1.
        norm_layer: Callable building the normalisation layer from the
            number of output channels.
        act_layer: Zero-argument callable building the activation layer.
    """

    def __init__(self, in_channels: int, out_channels: int, mask: Tensor,
                 k=3, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU):
        super().__init__()

        self.conv = MaskedConv2d(
            mask, in_channels, out_channels,
            kernel_size=k, stride=1, padding=k // 2)
        self.norm = norm_layer(out_channels)
        self.act = act_layer()

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv -> norm -> activation to ``x``."""
        features = self.conv(x)
        normalised = self.norm(features)
        return self.act(normalised)


class MaskFactory:
    """Builds the kernel masks used by the masked convolutions.

    Both masks keep every kernel row above the centre row and the part of
    the centre row that lies left of the centre; they differ only in how
    the centre position itself is treated.
    """

    def input_layer_mask(self, in_channels: int, channel_idx: int,
                   kernel_size: Tuple[int, int]):
        """Return an ``(in_channels, kh, kw)`` mask for the first layer.

        At the kernel centre only input channels with index below
        ``channel_idx`` are kept; all other channels are zeroed there.
        """
        rows, cols = kernel_size
        centre_row, centre_col = rows // 2, cols // 2

        # Start from an all-zero kernel and switch on the visible region.
        visible = torch.zeros((in_channels, rows, cols))
        visible[:, :centre_row, :] = 1
        visible[:, centre_row, :centre_col] = 1
        visible[:channel_idx, centre_row, centre_col] = 1
        return visible

    def hidden_layer_mask(self, kernel_size: Tuple[int, int]):
        """Return a ``(kh, kw)`` mask that also keeps the kernel centre."""
        rows, cols = kernel_size
        centre_row, centre_col = rows // 2, cols // 2

        visible = torch.zeros((rows, cols))
        visible[:centre_row, :] = 1
        visible[centre_row, :centre_col + 1] = 1
        return visible


@contextmanager
def eval_ctx(net: nn.Module):
    """Run the enclosed block with ``net`` in eval mode and autograd disabled.

    On exit (including when the block raises) every submodule gets back the
    ``training`` flag it had on entry, so nested use and partially frozen
    networks are handled correctly.

    Args:
        net: Module to switch to evaluation mode for the duration of the block.
    """
    # Remember the mode of every submodule: ``net.train(mode)`` would force a
    # single mode on all of them and lose any per-layer eval() settings.
    previous_modes = [(module, module.training) for module in net.modules()]

    # ``train`` is a method on nn.Module; it must be called, not assigned to.
    net.eval()
    try:
        with torch.no_grad():
            yield
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

class PixelCNN(nn.Module):
    """Autoregressive image model with one masked-conv subnet per channel.

    Each input channel has its own stack of masked convolutions that outputs
    ``num_values`` logits per pixel for that channel.

    Args:
        in_channels: Number of image channels (and of subnets).
        num_values: Number of discrete intensity levels per channel.
        num_hidden_layers: Number of hidden masked blocks in every subnet.
        hidden_dim: Feature width of the hidden blocks.
        kernel_size: Square kernel size used by all convolutions.
    """

    def __init__(self, in_channels: int, num_values: int, num_hidden_layers: int,
                 hidden_dim: int, kernel_size=3):
        super().__init__()
        self.in_channels = in_channels
        self.num_values = num_values
        self.num_hidden_layers = num_hidden_layers
        self.hidden_dim = hidden_dim
        self.kernel_size = self.k = kernel_size

        self.subnets = nn.ModuleList([
            self._make_channel_subnet(ch_idx)
            for ch_idx in range(self.in_channels)
        ])

    def _make_channel_subnet(self, ch_idx: int):
        """Build the masked convolution stack that predicts channel ``ch_idx``."""
        mask_factory = MaskFactory()
        kernel = (self.k, self.k)

        input_mask = mask_factory.input_layer_mask(
            in_channels=self.in_channels, channel_idx=ch_idx,
            kernel_size=kernel)
        input_layer = MaskedConvBlock(
            in_channels=self.in_channels, out_channels=self.hidden_dim,
            mask=input_mask, k=self.k)

        hidden_mask = mask_factory.hidden_layer_mask(kernel_size=kernel)
        hidden_layers = [
            MaskedConvBlock(in_channels=self.hidden_dim,
                            out_channels=self.hidden_dim,
                            mask=hidden_mask,
                            k=self.k)
            for _ in range(self.num_hidden_layers)
        ]

        # The output layer must be masked as well: a plain k x k convolution
        # would mix in hidden features of pixels below / to the right of the
        # one being predicted and break the autoregressive ordering.
        p = self.kernel_size // 2
        final_layer = MaskedConv2d(hidden_mask, self.hidden_dim,
                                   self.num_values, self.k, 1, p)

        return nn.Sequential(input_layer, *hidden_layers, final_layer)

    def forward(self, x: Tensor, ch_idx=None):
        """Return a ``Categorical`` over pixel values.

        Args:
            x: Image batch of shape ``[B, C_in, H, W]``.
            ch_idx: When given, only that channel's subnet is evaluated.

        Returns:
            ``Categorical`` with batch shape ``[B, C_in, H, W]``, or
            ``[B, H, W]`` when ``ch_idx`` is given, over ``num_values`` classes.
        """
        if ch_idx is None:
            outs = [subnet(x) for subnet in self.subnets]
        else:
            outs = [self.subnets[ch_idx](x)]

        outs = torch.stack(outs, dim=1)  # [B, C_in or 1, N_v, H, W]
        outs = outs.permute(0, 1, 3, 4, 2)  # [B, C_in or 1, H, W, N_v]
        if ch_idx is not None:
            outs = outs[:, 0]
        return Categorical(logits=outs)

    def predict(self, images: Tensor, start_pos: Tuple[int, int]):
        """Greedily regenerate ``images`` from ``start_pos`` onwards.

        Pixels are visited in raster order (rows, then columns, then
        channels); everything before ``start_pos`` is copied from the input
        and every later value is replaced by the model's most likely class.

        Args:
            images: Float batch ``[B, C_in, H, W]``. Assumed to be scaled so
                that class ``v`` corresponds to ``v / (num_values - 1)``,
                matching the ``255 * images`` targets used for training.
            start_pos: ``(x, y)`` of the first pixel to be generated.

        Returns:
            A new tensor shaped like ``images``; the input is not modified.
        """
        ix0, iy0 = start_pos
        _, c_in, h, w = images.shape

        result = images.clone()
        # Class index -> intensity on the same scale as the network input.
        value_scale = max(self.num_values - 1, 1)

        with eval_ctx(self):
            for iy, ix, ic in np.ndindex((h, w, c_in)):
                if (iy, ix) < (iy0, ix0):
                    continue

                # Here we obtain Categorical(num_values) over [B, H, W]
                value_dist: Categorical = self(result, ch_idx=ic)
                logits_at_point = value_dist.logits[:, iy, ix]  # [B, N_v]
                preds = logits_at_point.argmax(-1)
                # Writing the raw class index would put values up to
                # num_values - 1 into an image the model expects in [0, 1].
                result[:, ic, iy, ix] = preds.to(result.dtype) / value_scale

        return result
                

class PixelCNNData:
    """Bundles the train/val datasets and their loaders for the PixelCNN.

    After construction ``train_ds`` / ``val_ds`` hold the mapped datasets and
    ``in_channels`` / ``num_values`` describe the image format.
    """

    def __init__(self):
        self._setup_tiny_imagenet()
        # self._setup_font_awesome()

    def _setup_tiny_imagenet(self):
        """Load both Tiny ImageNet splits and attach the image transforms."""
        ds_root = "../datasets/tiny-imagenet-200"
        raw_train = tiny_imagenet.TinyImageNet(root=ds_root, split="train")
        raw_val = tiny_imagenet.TinyImageNet(root=ds_root, split="val")

        def val_transform(item: font_awesome.Item) -> Tensor:
            # RGB conversion -> 32x32 centre crop -> tensor.
            rgb = item.image.convert("RGB")
            cropped = tv_F.center_crop(rgb, (32, 32))
            return tv_F.to_tensor(cropped)

        def train_transform(item: font_awesome.Item) -> Tensor:
            # Training currently uses the validation pipeline unchanged.
            return val_transform(item)

        self.train_ds = raw_train.map(train_transform)
        self.val_ds = raw_val.map(val_transform)

        self.in_channels, self.num_values = 3, 256

    def _make_loader(self, dataset, batch_size: int, shuffle: bool):
        """Create a pinned-memory loader over ``dataset``."""
        return data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            pin_memory=True,
        )

    def train_loader(self, batch_size: int):
        """Return a shuffling loader over the training split."""
        return self._make_loader(self.train_ds, batch_size, shuffle=True)

    def val_loader(self, batch_size: int):
        """Return a loader over the validation split in fixed order."""
        return self._make_loader(self.val_ds, batch_size, shuffle=False)

class Trainer:
    """Training / validation loop for a ``PixelCNN`` with TensorBoard logging."""

    def __init__(self):
        self.batch_size = 32
        self.val_batch_size = self.batch_size
        self.num_epochs = 32
        # Number of validation images that are regenerated and logged per epoch.
        self.num_val_samples = 8

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.writer = tensorboard.SummaryWriter()

    def train(self, pix: PixelCNN, pix_data: PixelCNNData):
        """Train ``pix`` for ``num_epochs`` epochs, validating after each one.

        Logs ``train/loss`` per batch (the step counter advances by the number
        of samples), and per epoch the sample-weighted mean ``val/loss`` and a
        ``val/samples`` grid of original / regenerated image pairs.

        Args:
            pix: Model to optimise; it is moved to the trainer's device.
            pix_data: Source of the training and validation loaders. Batches
                are assumed to be float images in [0, 1], since the targets
                are computed as ``255 * images``.
        """
        pix = pix.to(self.device)

        train_loader = pix_data.train_loader(self.batch_size)
        val_loader = pix_data.val_loader(self.val_batch_size)

        optim = torch.optim.Adam(pix.parameters(), lr=1e-3)

        pbar = tqdm()

        step_idx = 0
        for epoch in range(self.num_epochs):
            # reset(total=...) also updates the notebook widget's range.
            pbar.reset(total=len(train_loader))
            pbar.set_description(f"Train #{epoch}")

            for images in train_loader:
                images = images.to(self.device)
                value_dist: Categorical = pix(images)
                targets = (255 * images).int()
                loss: Tensor = -value_dist.log_prob(targets).mean()

                optim.zero_grad(set_to_none=True)
                loss.backward()
                optim.step()

                self.writer.add_scalar("train/loss", loss.item(),
                                       global_step=step_idx)
                step_idx += len(images)
                pbar.update()

            pbar.reset(total=len(val_loader))
            pbar.set_description(f"Val #{epoch}")

            val_grid = []
            num_shown = 0  # samples in val_grid (two grid entries each)
            loss_sum = 0.0
            num_seen = 0
            with eval_ctx(pix):
                for images in val_loader:
                    images = images.to(self.device)
                    value_dist: Categorical = pix(images)
                    targets = (255 * images).int()
                    batch_loss: Tensor = -value_dist.log_prob(targets).mean()
                    # Weight by batch size so a short last batch is not
                    # over-represented in the epoch mean.
                    loss_sum += batch_loss.item() * len(images)
                    num_seen += len(images)

                    rem = min(len(images), self.num_val_samples - num_shown)
                    if rem > 0:
                        h, w = images.shape[-2:]
                        # Regenerate everything from the image centre onwards.
                        preds = pix.predict(images[:rem], (w//2, h//2))
                        for image, pred in zip(images[:rem], preds):
                            val_grid.extend([image, pred])
                        num_shown += rem

                    pbar.update()

            # Mean over samples; guard against an empty validation loader.
            val_loss = loss_sum / max(num_seen, 1)
            self.writer.add_scalar("val/loss", val_loss, global_step=epoch)
            pbar.set_postfix({"val/loss": val_loss})

            if val_grid:
                val_grid = visual.make_grid(val_grid, ncols=4)
                self.writer.add_image("val/samples", tv_F.to_tensor(val_grid),
                                      global_step=epoch)
                

# Assemble the data wrapper, the model and the trainer, then start training.
pix_data = PixelCNNData()
pix = PixelCNN(
    in_channels=pix_data.in_channels,
    num_values=pix_data.num_values,
    num_hidden_layers=8,
    hidden_dim=64,
    kernel_size=5,
)
trainer = Trainer()

trainer.train(pix=pix, pix_data=pix_data)

0it [00:00, ?it/s]

KeyboardInterrupt: 