In [3]:
import warnings
# NOTE(review): this silences every warning category for the whole kernel,
# including deprecation and numerical warnings raised by the libraries below.
warnings.filterwarnings('ignore')

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import cv2
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from glob import glob
from tqdm.autonotebook import tqdm, trange
from skimage.color import rgb2lab, lab2rgb
from PIL import Image

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from sklearn.model_selection import train_test_split

In [4]:
# Seed the torch (CPU and CUDA) and NumPy random number generators.
seed = 42
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
# Stand-alone torch generator seeded with the same value; it is not referenced
# by any other cell visible in this notebook.
gen = torch.Generator()
gen.manual_seed(seed)

In [5]:
# Global configuration.
BATCH_SIZE = 8
# Working resolution used by `size_transforms`.
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
# Square resolution used by `upsize_transforms`.
UP_SIZE = 512
epochs = 20

# presumably the checkpoint file name; it is not used in the visible cells.
PATH = 'model.pth'

# Prefer the GPU when one is available; this global is read by the dataset's
# collate function and by the model construction cell.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Currently using "{device.upper()}" device.')

In [6]:
# Resize to the working resolution. `height` must receive IMAGE_HEIGHT and
# `width` IMAGE_WIDTH, so the transform stays correct for non-square sizes.
size_transforms = A.Compose([A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, p=1),
])
# Converts an HWC numpy image into a CHW torch tensor (used by ColorDataset).
final_transforms = ToTensorV2(p=1.0)
# Resize to the square up-sampled resolution.
upsize_transforms = A.Resize(height=UP_SIZE, width=UP_SIZE, p=1)

In [7]:
# Collect the COCO test image paths in a stable (sorted) order.
coco = glob(r'../input/cocotest2014/test2014/' + '*.jpg')
coco = sorted([str(x) for x in coco])

df = pd.DataFrame(data={'color': coco, 'name': np.zeros(len(coco))})
# clean data from grayscale images
gray_indices = []
for i in trange(len(df)):
    # `Image.open` is lazy and keeps the file open; the context manager closes
    # the handle immediately so scanning the dataset does not exhaust file
    # descriptors.
    with Image.open(df.loc[i, 'color']) as img:
        if img.mode == 'L':
            gray_indices.append(i)
# Re-bind instead of mutating in place; the original index labels are kept.
df = df.drop(gray_indices)

# Hold out a fixed number of images for validation.
train, valid = train_test_split(df, test_size=4000, shuffle=True, random_state=seed)
print(f'Train size: {len(train)}, valid size: {len(valid)}')

In [8]:
# Reproducible 1000-image subset; the dataset objects built below use it.
example = df.sample(1000, random_state=seed)

In [9]:
def shift_right(x):
    """Shift `x` one step along dim 2, filling the vacated first slot with zeros.

    `torch.roll` would wrap the last column around into position 0, which in
    an autoregressive decoder exposes the final pixel of a row to the first
    one. Zero padding keeps the shift strictly causal.

    Args:
        x: tensor with at least 3 dimensions (callers pass B x rows x cols x D).

    Returns:
        A new tensor of the same shape where ``out[:, :, 0] == 0`` and
        ``out[:, :, c] == x[:, :, c - 1]`` for ``c >= 1``.
    """
    pad = torch.zeros_like(x[:, :, :1])
    return torch.cat((pad, x[:, :, :-1]), dim=2)

def shift_down(x):
    """Shift `x` one step along dim 1, filling the vacated first slot with zeros.

    `torch.roll` would wrap the last row around into row 0, letting the first
    row see information from the bottom of the image. Zero padding keeps the
    shift strictly causal.

    Args:
        x: tensor with at least 2 dimensions (callers pass B x rows x cols x D).

    Returns:
        A new tensor of the same shape where ``out[:, 0] == 0`` and
        ``out[:, r] == x[:, r - 1]`` for ``r >= 1``.
    """
    pad = torch.zeros_like(x[:, :1])
    return torch.cat((pad, x[:, :-1]), dim=1)

def create_mask(batch_size, W):
    """Return a (batch_size, W, W) int32 lower-triangular mask (diagonal included)."""
    ones = torch.ones(batch_size, W, W)
    return torch.tril(ones).int()

def positionalencoding2d(d_model, height, width, batch_size, device=None):
    """Build a fixed 2-D sine/cosine positional encoding.

    The first half of the channels encodes the column index and the second
    half the row index; within each half, sine and cosine channels alternate.

    Args:
        d_model: number of channels; must be divisible by 4.
        height: number of rows.
        width: number of columns.
        batch_size: number of copies stacked along the leading dimension.
        device: optional device on which to build the encoding (default: the
            current default device, as before).

    Returns:
        Float tensor of shape (batch_size, height, width, d_model).

    Raises:
        ValueError: if `d_model` is not divisible by 4.
    """
    if d_model % 4 != 0:
        raise ValueError("Cannot use 2D sin/cos positional encoding with "
                         "a dimension that is not divisible by 4 (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width, device=device)
    # Each spatial axis gets half of the channels.
    half = d_model // 2
    div_term = torch.exp(torch.arange(0., half, 2, device=device) *
                         -(float(np.log(10000.0)) / half))
    pos_w = torch.arange(0., width, device=device).unsqueeze(1)
    pos_h = torch.arange(0., height, device=device).unsqueeze(1)
    # Column encoding: constant along the height axis.
    pe[0:half:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:half:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    # Row encoding: constant along the width axis.
    pe[half::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[half + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

    # Channels-last layout, replicated for every batch element.
    return pe.permute(1, 2, 0).repeat(batch_size, 1, 1, 1)

def convertTo3bit(x, N):
    """Quantize values in [0, 255] onto N + 1 evenly spaced levels.

    Each value is snapped to the nearest of the levels ``k * 255 / N`` and the
    result is returned as a long tensor on the 0..255 scale.
    """
    level_index = torch.round(x * (N / 255))
    snapped = torch.round(level_index * (255 / N))
    return snapped.long()

def intTo3bit(value):
    """Split a code in [0, 511] into its three base-8 digits.

    Works for plain integers and for integer tensors (element-wise). The input
    is never modified: the previous implementation aliased the argument and
    used in-place subtraction, which silently overwrote tensor inputs.

    Args:
        value: int or integer tensor.

    Returns:
        Tuple ``(c1, c2, c3)`` with ``value == 64 * c1 + 8 * c2 + c3``
        (``c1`` is the most significant digit).
    """
    c1 = value // 64
    remainder = value - c1 * 64
    c2 = remainder // 8
    c3 = remainder - c2 * 8
    return (c1, c2, c3)

def bitsToInt(channels):
    """Pack three 0..255 channel values into a single base-8 code.

    Each channel is quantized to 8 levels (0..7) and the digits are combined
    with channel 0 as the MOST significant digit, i.e.
    ``code = 64 * d0 + 8 * d1 + d2``. This is the digit order that
    `intTo3bit` decodes, so encoding followed by decoding returns the channels
    in their original order.

    Args:
        channels: indexable whose first three items are tensors (scalars or
            equally shaped tensors).

    Returns:
        Tensor of codes in [0, 511] with the shape of one channel.
    """
    return sum(torch.round(channels[k] * (7 / 255)) * 8 ** (2 - k) for k in range(3))

def toOneChannel(x):
    """Convert a batch of 3-channel images into per-pixel base-8 codes.

    Args:
        x: tensor of shape (n, rows, cols, channels) with values on the
            0..255 scale; only the first three channels are read.

    Returns:
        Default-dtype tensor of shape (n, rows, cols), allocated on the
        default device, holding `bitsToInt` of every pixel.
    """
    n, r, c, _ = x.shape
    res = torch.zeros(n, r, c)
    # One vectorized call replaces the former Python loop over every pixel.
    res.copy_(bitsToInt(x.unbind(-1)))
    return res

In [10]:
# Shape check: the encoding is laid out as (batch, height, width, d_model).
positionalencoding2d(32, 256, 256, 8).shape

In [11]:
class ColorDataset(Dataset):
    """Image dataset that yields a lightness map plus colour targets.

    Args:
        df: DataFrame whose 'color' column holds image file paths.
        convert: output mode, one of "bit", "upscale" or "all"
            (see `__getitem__`).
        image_size: side length of the square resize applied to every image
            (256 or 512 in this notebook).
    """
    def __init__(self, df, convert, image_size=256):
        self.df = df
        self.image_size = image_size
        self.transforms = A.Resize(height=image_size, width=image_size, p=1)
        self.convert = convert

    def __len__(self):
        return len(self.df)

    def __getitem__(self, ix):
        """Load, resize and convert the image at position `ix`.

        Returns:
            "bit":     (L, bit_image)
            "all":     (L, bit_image, color_image)
            "upscale": (L, scaled_image, color_image), where `scaled_image`
                       is `color_image` bilinearly resized to half of
                       `image_size`.
            `L` is the first Lab channel truncated to integers; `color_image`
            is the resized RGB image as an int32 tensor in H x W x 3 layout.

        Raises:
            OSError: if the image file cannot be read.
            ValueError: if `convert` is not a supported mode.
        """
        row = self.df.iloc[ix].squeeze()
        color_image = cv2.imread(row['color'])
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if color_image is None:
            raise OSError(f"Could not read image file: {row['color']!r}")
        color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
        color_image = self.transforms(image=color_image)['image']
        color_image = np.array(color_image)

        img_lab = rgb2lab(color_image).astype("float32")
        # `final_transforms` (module level) turns the HWC array into a CHW tensor.
        img_lab = final_transforms(image=img_lab)['image']
        # Channel 0 is the lightness channel; `.long()` truncates it to integers.
        L = img_lab[0].long()
        color_image = torch.as_tensor(color_image, dtype=torch.int32)

        if self.convert == 'bit':
            bit_image = convertTo3bit(color_image, 7).long()
            return L, bit_image
        elif self.convert == 'all':
            bit_image = convertTo3bit(color_image, 7).long()
            return L, bit_image, color_image
        elif self.convert == 'upscale':  # image_size = 512
            half = self.image_size // 2
            # interpolate expects N x C x H x W float input.
            scaled_image = nn.functional.interpolate(
                color_image.unsqueeze(0).permute(0, 3, 1, 2).float(),
                (half, half),
                mode="bilinear",
            ).long()
            # Drop only the batch axis (a bare squeeze() would also drop
            # spatial axes of size 1), then return to H x W x C.
            scaled_image = scaled_image.squeeze(0).permute(1, 2, 0)
            return L, scaled_image, color_image
        else:
            raise ValueError(f'Unknown transformations provided: {self.convert!r}.')

    def collate_fn(self, batch):
        """Stack each field of the samples into one batched tensor on `device`.

        NOTE(review): moving tensors to the device inside the collate function
        can be a problem with multi-process data loading on CUDA; kept as is
        because callers may rely on receiving device tensors.
        """
        fields = zip(*batch)
        return [torch.stack([tensor.to(device) for tensor in field]) for field in fields]

In [12]:
# One dataset per output mode, all built on the same 1000-image sample.
# The 'upscale' variant is resized to 512 px; the other two use the default 256 px.
upscale_ds = ColorDataset(example, 'upscale', image_size=512)
all_ds = ColorDataset(example, 'all')
bit_ds = ColorDataset(example, 'bit')

In [13]:
# Shape of the quantized colour target of the first sample
# (presumably 256 x 256 x 3, i.e. H x W x C — confirm against the output).
bit_ds[0][1].shape

In [14]:
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear (width D)."""
    def __init__(self, D):
        super(MLP, self).__init__()

        self.fc1 = nn.Linear(D,D)
        self.fc2 = nn.Linear(D,D)
        self.relu = nn.ReLU()

    def forward(self,x):
        # The non-linearity has to sit between the two projections: two
        # linear layers applied back to back collapse into a single linear
        # map, and a trailing ReLU would force the residual branch to be
        # non-negative.
        return self.fc2(self.relu(self.fc1(x)))

class AttentionBlock(nn.Module):
    """Axial self-attention block with four heads followed by an MLP.

    Attention is computed independently inside every row (``t == 'row'``) or
    inside every column (any other value of ``t``) of a B x rows x cols x D
    input. Every head uses full D x D projections; the four head outputs are
    concatenated and projected back to D.

    The projection layers keep their original attribute names (Q1, K1, V1,
    ..., V4), so parameter names in the state dict are unchanged.
    """
    def __init__(self, t, D):
        super(AttentionBlock, self).__init__()

        self.D = D
        self.Type = t

        # Registered in the order Q1, K1, V1, Q2, ... V4.
        for head in range(1, 5):
            for proj in "QKV":
                setattr(self, f"{proj}{head}", nn.Linear(D, D, bias=False))
        # Merges the four concatenated head outputs back to width D.
        self.out = nn.Linear(4*D,D)

        # The same LayerNorm is applied before attention and before the MLP.
        self.LN = nn.LayerNorm(D)

        self.mlp = MLP(D)

    def _attend_line(self, line):
        """Apply attention + MLP (both with residuals) to one B x L x D line."""
        normed = self.LN(line)
        heads = []
        for head in range(1, 5):
            q = getattr(self, f"Q{head}")(normed)
            k = getattr(self, f"K{head}")(normed)
            v = getattr(self, f"V{head}")(normed)
            # B x L x L scaled dot-product attention weights.
            weights = torch.softmax(torch.matmul(q, k.transpose(1, 2)) / np.sqrt(self.D), dim=-1)
            heads.append(torch.matmul(weights, v))
        attended = self.out(torch.cat(heads, 2)) + line
        return self.mlp(self.LN(attended)) + attended

    def forward(self, input_):
        """Run the block over every row (or column) of `input_`.

        Args:
            input_: tensor of shape B x rows x cols x D.

        Returns:
            Tensor with the same shape as `input_`.
        """
        column_wise = self.Type != 'row'
        # Column attention is row attention on the spatially transposed input.
        lines = input_.transpose(1, 2) if column_wise else input_
        out = torch.empty_like(lines)
        for idx in range(lines.shape[1]):
            out[:, idx] = self._attend_line(lines[:, idx])
        return out.transpose(1, 2) if column_wise else out

In [15]:
class MLPConditional(nn.Module):
    """Feed-forward network whose output is scaled and shifted by a context.

    Computes ``scale * fc2(relu(fc1(x))) + shift``.
    """
    def __init__(self, D):
        super(MLPConditional, self).__init__()

        self.fc1 = nn.Linear(D,D)
        self.fc2 = nn.Linear(D,D)
        self.relu = nn.ReLU()

    def forward(self, x, conv1_context_h, conv2_context_h):
        # The non-linearity sits between the two projections; two linear
        # layers applied back to back would collapse into one linear map.
        h = self.fc2(self.relu(self.fc1(x)))
        y = conv1_context_h * h + conv2_context_h
        return y

class AttentionBlockConditional(nn.Module):
    """Context-conditioned axial attention block with four heads.

    Works like an axial attention block over rows (``t == 'row'``) or columns
    (any other value of ``t``), except that queries, keys, values and the MLP
    output are each modulated as ``scale * value + shift``, where scale and
    shift are 1x1 convolutions of the context tensor.

    The projection layers keep their original attribute names (Q1 ... V4), so
    parameter names in the state dict are unchanged.
    """
    def __init__(self, t, D):

        super(AttentionBlockConditional, self).__init__()

        self.D = D
        self.Type = t

        # Registered in the order Q1, K1, V1, Q2, ... V4.
        for head in range(1, 5):
            for proj in "QKV":
                setattr(self, f"{proj}{head}", nn.Linear(D, D, bias=False))
        # Merges the four concatenated head outputs back to width D.
        self.out = nn.Linear(4*D,D)

        # Scale (conv1_*) and shift (conv2_*) for q/k/v ("z") and the MLP ("h").
        self.conv1_z = nn.Conv2d(D,D,kernel_size=1,bias=False)
        self.conv2_z = nn.Conv2d(D,D,kernel_size=1,bias=False)
        self.conv1_h = nn.Conv2d(D,D,kernel_size=1,bias=False)
        self.conv2_h = nn.Conv2d(D,D,kernel_size=1,bias=False)

        self.LN = nn.LayerNorm(D)
        # Not used by forward(); kept so the attribute remains available.
        self.mean_avg_pool = nn.AvgPool1d(kernel_size=1)

        self.mlp = MLPConditional(D)

    def _attend_line(self, line, scale_z, shift_z, scale_h, shift_h, line_mask):
        """Apply conditional attention + MLP to one B x L x D line.

        `line_mask` is None or a B x L tensor; positions whose value is <= 0
        are hidden as attention keys.
        """
        normed = self.LN(line)
        blocked = None
        if line_mask is not None:
            # B x 1 x L: the same set of hidden keys for every query.
            blocked = (line_mask <= 0).to(line.device).unsqueeze(1)

        heads = []
        for head in range(1, 5):
            q = getattr(self, f"Q{head}")(normed) * scale_z + shift_z
            k = getattr(self, f"K{head}")(normed) * scale_z + shift_z
            v = getattr(self, f"V{head}")(normed) * scale_z + shift_z
            scores = torch.matmul(q, k.transpose(1, 2)) / np.sqrt(self.D)
            if blocked is not None:
                # Masking has to replace the logits of hidden keys.
                # Multiplying the scores by the mask does not hide anything.
                # A finite fill value keeps a fully masked line free of NaNs
                # (its softmax becomes uniform).
                scores = scores.masked_fill(blocked, -1e9)
            heads.append(torch.matmul(torch.softmax(scores, dim=-1), v))

        attended = self.out(torch.cat(heads, 2)) + line
        return self.mlp(self.LN(attended), scale_h, shift_h) + attended

    def forward(self, input_, context, mask = None):
        """Run the block over every row (or column) of `input_`.

        Args:
            input_: tensor of shape B x rows x cols x D.
            context: conditioning tensor of the same shape as `input_`.
            mask: optional B x rows x cols tensor. Positive entries mark
                visible positions; entries <= 0 (e.g. -1e9) are excluded as
                attention keys along the attended axis.

        Returns:
            Tensor with the same shape as `input_`.
        """
        # Conv2d expects channels in dim 1; swap them in and back out.
        ctx = context.transpose(1,-1)
        modulation = [
            self.conv1_z(ctx).transpose(1,-1),
            self.conv2_z(ctx).transpose(1,-1),
            self.conv1_h(ctx).transpose(1,-1),
            self.conv2_h(ctx).transpose(1,-1),
        ]

        column_wise = self.Type != 'row'
        if column_wise:
            # Column attention is row attention on the spatially transposed tensors.
            input_ = input_.transpose(1, 2)
            modulation = [m.transpose(1, 2) for m in modulation]
            if mask is not None:
                mask = mask.transpose(1, 2)

        out = torch.empty_like(input_)
        for idx in range(input_.shape[1]):
            line_mask = None if mask is None else mask[:, idx]
            out[:, idx] = self._attend_line(
                input_[:, idx], *(m[:, idx] for m in modulation), line_mask
            )

        return out.transpose(1, 2) if column_wise else out

In [16]:
class TransformerEncoder(nn.Module):
    """One axial encoder layer: row attention followed by column attention."""
    def __init__(self, D):
        super().__init__()

        self.D = D
        self.RowAttention = AttentionBlock('row', D)
        self.ColumnAttention = AttentionBlock('col', D)

    def forward(self, input_):
        """Apply the row block, then feed its output to the column block."""
        return self.ColumnAttention(self.RowAttention(input_))

class GrayscaleEncoder(nn.Module):
    """Encoder over the embedded input; currently a single TransformerEncoder layer."""

    def __init__(self, D):
        super().__init__()

        self.D = D

        # Only one layer is active; no further layers are stacked.
        self.TransformerEncoder_Layer1 = TransformerEncoder(D)

    def forward(self, embedding_x_g):
        """
        embedding_x_g : B * M * N * D
        returns       : B * M * N * D
        """
        return self.TransformerEncoder_Layer1(embedding_x_g)

In [17]:
class TransformerDecoderInner(nn.Module):
    """Masked conditional row attention used by the inner decoder."""
    def __init__(self, D):
        super(TransformerDecoderInner, self).__init__()

        self.D = D
        self.ConditionalRowAttention = AttentionBlockConditional('row',D)

    def forward(self, input_, ctx_encoder_decoder, j):
        """Run conditional row attention with columns after `j` hidden.

        Args:
            input_: tensor of shape B x rows x cols x D (callers pass a single
                row whose embeddings were already shifted right).
            ctx_encoder_decoder: context tensor with the same shape as `input_`.
            j: index of the column currently being predicted.

        Returns:
            Tensor with the same shape as `input_`.
        """
        # Build the mask on the input's device so it can be combined with it.
        mask = torch.ones(input_.shape[:-1], device=input_.device)
        # Column j stays visible: after the right shift it carries pixel
        # j - 1, which is already known. Only columns after j are hidden.
        # This matches the outer decoder (`i+1:`) and avoids masking every
        # column when j == 0.
        mask[:,:,j+1:] = -1e9

        return self.ConditionalRowAttention(input_, ctx_encoder_decoder, mask)

class InnerDecoder(nn.Module):
    """
    Row-level decoder: produces one row, a pixel at a time.
    """
    def __init__(self, D):
        super().__init__()
        self.TransformerDecoderInner_Layer1 = TransformerDecoderInner(D)

    def forward(self, emb_x_s_c, ctx_encoder_decoder, j):
        """
        Pipeline this layer belongs to:
            z = o + ShiftRight(e)
            h = MaskedRow(z)
            p(xij) = Dense(h)
        Delegates to the single inner transformer layer for column `j`.
        """
        layer = self.TransformerDecoderInner_Layer1
        return layer(emb_x_s_c, ctx_encoder_decoder, j)

In [18]:
class TransformerDecoderOuter(nn.Module):
    """Conditional row attention followed by masked conditional column attention.

    The row-attention result does not depend on the row index `i`, so it is
    cached in `self.row` and reused while the same input and context tensor
    objects are passed in. Setting `self.row = None` forces a recomputation.
    """
    def __init__(self, D):
        super(TransformerDecoderOuter, self).__init__()

        self.D = D
        self.ConditionalRowAttention = AttentionBlockConditional('row',D)
        self.ConditionalColumnAttention = AttentionBlockConditional('col',D)
        self.row = None
        # Tensors the cached `self.row` was computed from (identity check only).
        self._row_inputs = (None, None)

    def forward(self, input_, ctx_grayscale_encoder, i):
        """Return column attention over the cached row attention, hiding rows after `i`.

        Args:
            input_: tensor of shape B x rows x cols x D.
            ctx_grayscale_encoder: context tensor with the same shape.
            i: index of the last visible row.

        Returns:
            Tensor with the same shape as `input_`.
        """
        # Build the mask on the input's device so it can be combined with it.
        mask = torch.ones(input_.shape[:-1], device=input_.device)
        mask[:,i+1:,:] = -1e9

        # Recompute when nothing is cached or when a different tensor object
        # arrives (new batch / new channel). Without this check the first
        # batch's activations would be reused for every later call.
        # NOTE: in-place edits of the same tensor object are not detected.
        cached_input, cached_ctx = self._row_inputs
        if self.row is None or cached_input is not input_ or cached_ctx is not ctx_grayscale_encoder:
            self.row = self.ConditionalRowAttention(input_, ctx_grayscale_encoder)
            self._row_inputs = (input_, ctx_grayscale_encoder)

        out = self.ConditionalColumnAttention(self.row, ctx_grayscale_encoder, mask)

        return out

class OuterDecoder(nn.Module):
    """Image-level decoder: one outer transformer layer followed by a downward shift."""
    def __init__(self, D):
        super().__init__()
        self.D = D
        self.TransformerDecoderOuter_Layer1 = TransformerDecoderOuter(D)

    def forward(self, emb_x_s_c, ctx_grayscale_encoder, i):
        """
        Steps:
            e   = Embeddings(x)            (computed by the caller)
            s_o = MaskedColumn(Row(e))     (the outer transformer layer)
            o   = ShiftDown(s_o)
        """
        s_o = self.TransformerDecoderOuter_Layer1(emb_x_s_c, ctx_grayscale_encoder, i)
        return shift_down(s_o)

In [19]:
class ColTranSpatialUpsampler(nn.Module):
    """Predicts a full-resolution colour image from a low-resolution one.

    The low-resolution colour image is bilinearly resized to H x W; each colour
    channel is then embedded, combined with the embedded grayscale image and
    passed through the encoder to produce per-pixel logits over NColor values.

    Args:
        D: embedding width (must be divisible by 4 for the positional encoding).
        NColor: number of intensity values per channel.
        H, W: output resolution.
    """

    def __init__(self, D, NColor, H, W):
        super(ColTranSpatialUpsampler, self).__init__()

        self.D = D
        self.H = H
        self.W = W
        self.NColor = NColor

        self.embedding_x_g = nn.Embedding(NColor,D) # BATCH * H * W * D
        # One embedding table per colour channel. ModuleList is the right
        # container for modules that are indexed rather than chained; the
        # parameter names (embedding_x_rgb.0.weight, ...) stay the same.
        self.embedding_x_rgb = nn.ModuleList([nn.Embedding(NColor,D) for _ in range(3)])

        self.grayscale_encoder = GrayscaleEncoder(D)
        self.linear = nn.Linear(D,NColor)

    def forward(self, x_g, x_s):
        """
        x_s : B * M * N * 3   low-resolution colour image (integer intensities)
        x_g : B * H * W       grayscale image (integer intensities)
        return
        (prediction, logits) with shapes B * H * W * 3 and B * H * W * 3 * NColor
        """
        # interpolate needs channels-first float input; `.long()` truncates
        # the resized intensities back to integer indices.
        x_s = nn.functional.interpolate(x_s.permute(0,3,1,2).float(),size=(self.H, self.W),mode="bilinear").permute(0,2,3,1).long()
        batch,row,col,channel = x_s.shape
        # Keep every working tensor on the device of the inputs.
        dev = x_g.device
        pe = positionalencoding2d(self.D, row, col, batch).to(dev)

        emb_g = pe + self.embedding_x_g(x_g)
        out = torch.zeros(batch, row, col, channel, self.NColor, device=dev)

        for k in range(channel):
            emb_k = pe + self.embedding_x_rgb[k](x_s[:,:,:,k])
            out_encoder = self.grayscale_encoder(emb_g + emb_k)
            out[:,:,:,k] = self.linear(out_encoder)

        return out.argmax(-1), out

In [20]:
class ColTranColorUpsampler(nn.Module):
    """Predicts full-depth colour intensities from a coarsely quantized image.

    Each colour channel is embedded, combined with the embedded grayscale
    image and passed through the encoder to produce per-pixel logits over
    NColor values.

    Args:
        D: embedding width (must be divisible by 4 for the positional encoding).
        NColor: number of intensity values per channel.
    """

    def __init__(self, D, NColor):
        super(ColTranColorUpsampler, self).__init__()

        self.D = D
        self.NColor = NColor

        self.embedding_x_g = nn.Embedding(NColor,D) # BATCH * M * N * D
        # One embedding table per colour channel. ModuleList is the right
        # container for modules that are indexed rather than chained; the
        # parameter names (embedding_x_rgb.0.weight, ...) stay the same.
        self.embedding_x_rgb = nn.ModuleList([nn.Embedding(NColor,D) for _ in range(3)])

        self.grayscale_encoder = GrayscaleEncoder(D)

        self.linear = nn.Linear(D,NColor)

    def forward(self, x_g, x_s_c):
        """
        x_s_c : B * M * N * 3   coarse colour image (integer intensities)
        x_g   : B * M * N       grayscale image (integer intensities)
        return:
        (prediction, logits) with shapes B * M * N * 3 and B * M * N * 3 * NColor
        """
        batch,row,col,channel = x_s_c.shape
        # Keep every working tensor on the device of the inputs.
        dev = x_g.device
        pe = positionalencoding2d(self.D, row, col, batch).to(dev)
        emb_g = pe + self.embedding_x_g(x_g)

        out = torch.zeros(batch, row, col, channel, self.NColor, device=dev)

        for k in range(channel):
            emb_k = pe + self.embedding_x_rgb[k](x_s_c[:,:,:,k])
            out_encoder = self.grayscale_encoder(emb_g + emb_k)
            out[:,:,:,k] = self.linear(out_encoder)

        return out.argmax(-1), out

In [21]:
class ColTranCore(nn.Module):
    """Autoregressive core: predicts a coarse (8 levels per channel) colour image.

    Args:
        D: embedding width (must be divisible by 4 for the positional encoding).
        nb_colors: number of entries in the intensity embedding tables.
    """

    def __init__(self, D, nb_colors):
        super(ColTranCore, self).__init__()

        self.D = D
        self.nb_colors = nb_colors

        self.embedding_x_g = nn.Embedding(nb_colors, D) # H * W * D
        self.embedding_x_s_c = nn.Embedding(nb_colors, D) # H * W * D

        self.grayscale_encoder = GrayscaleEncoder(D)
        self.outer_decoder = OuterDecoder(D)
        self.inner_decoder = InnerDecoder(D)

        # 512 = 8**3 joint colour codes (three base-8 digits).
        self.out_inner_decoder = nn.Linear(D,512)
        self.out_grayscale_encoder = nn.Linear(D,512)

    def sampling(self, proba):
        """Draw one colour code per pixel and decode it to three intensities.

        Args:
            proba: B * M * N * L tensor of per-pixel probabilities over codes.

        Returns:
            Float tensor B * M * N * 3 with intensities in [0, nb_colors - 1].
        """
        # Eight evenly spaced intensities; the top level is nb_colors - 1 so
        # it remains a valid index into the embedding tables.
        value = torch.linspace(0, self.nb_colors - 1, 8, device=proba.device).round().long()
        # Sample every pixel in one call instead of a Python loop per pixel.
        codes = torch.distributions.Categorical(proba).sample()
        # Digits come back most-significant first and fill channels 0, 1, 2.
        digits = intTo3bit(codes)
        return torch.stack([value[d] for d in digits], dim=-1).float()

    def forward(self, x_g, x_s_c=None):
        """
        x_g : B * M * N
        x_s_c : B * M * N * 3 (optional; when omitted, x_g is used as a
                single conditioning channel)
        return
        (sample, projection, grayscale_projection)
        sample : B * M * N * 3
        projection : B * M * N * 8^3 (logits, summed over channels)
        grayscale_projection : B * M * N * 8^3
        """
        # x_g has no channel axis; add one so both cases unpack to 4 dims.
        colors = x_g.unsqueeze(-1) if x_s_c is None else x_s_c
        batch, row, col, channel = colors.shape
        # Keep every working tensor on the device of the inputs.
        dev = x_g.device
        pe = positionalencoding2d(self.D, row, col, batch).to(dev)

        out_g = pe + self.embedding_x_g(x_g)
        projection = torch.zeros(batch, row, col, 512, device=dev)

        outer_layer = self.outer_decoder.TransformerDecoderOuter_Layer1

        for k in range(channel):
            x_s_ck = self.embedding_x_s_c(colors[:,:,:,k]) + pe
            # NOTE(review): the encoder is re-applied to its own output for
            # every channel; confirm this is intended rather than encoding once.
            out_g = self.grayscale_encoder(out_g)
            # Loop-invariant for all rows of this channel.
            shifted = shift_right(x_s_ck)

            # Drop the cached row attention so it is recomputed for this
            # channel. The decoder module itself must NOT be re-created here:
            # that would replace its weights with fresh, untrained ones that
            # are unknown to the optimizer and live on the default device.
            outer_layer.row = None

            for i in range(row):
                out_o = self.outer_decoder(x_s_ck, out_g, i)

                context_i = (out_g + out_o)[:,i].unsqueeze(1)
                input_i = context_i + shifted[:,i].unsqueeze(1)

                for j in range(col):
                    out_i = self.inner_decoder(input_i, context_i, j)
                    # B * M * N * 8**3
                    projection[:,i,j] += self.out_inner_decoder(out_i[:,0,j])

        # Release the cached activations (and their autograd graph).
        outer_layer.row = None

        x_hat_s_c = self.sampling(projection.softmax(-1))

        return x_hat_s_c, projection, self.out_grayscale_encoder(out_g)

In [22]:
criterion = nn.CrossEntropyLoss()
# Low-resolution size (M x N) and full-resolution size (H x W).
M = N = 64
H = W = 256
# Embedding width; positionalencoding2d requires it to be divisible by 4.
D = 32 # 512

# The second argument (256) is the size of the intensity embedding tables.
core = ColTranCore(D, 256).to(device)
color = ColTranColorUpsampler(D, 256).to(device)
spatial = ColTranSpatialUpsampler(D, 256, H, W).to(device)

# One optimizer per sub-model. Each optimizer only tracks the parameters that
# exist at this point.
optimizer_core = torch.optim.RMSprop(core.parameters(), lr=3e-4)
optimizer_color = torch.optim.RMSprop(color.parameters(), lr=3e-4)
optimizer_spatial = torch.optim.RMSprop(spatial.parameters(), lr=3e-4)

In [23]:
# Smoke test of the core model on random inputs.
# Inputs are created on the same device as the model, and the forward pass
# runs without autograd so the long decoding loop does not retain a graph.
x_g = torch.randint(0, 255, [1,M,N], device=device)
x_s_c = torch.randint(0, 255, [1,M,N,3], device=device)

core.eval()
with torch.no_grad():
    out1 = core(x_g, x_s_c)

In [24]:
# Shapes of (sampled image, logits, grayscale projection) returned by the core.
out1[0].shape, out1[1].shape, out1[2].shape

In [25]:
# Smoke test of the colour upsampler on random inputs.
# Inputs are created on the same device as the model, and the forward pass
# runs without autograd since only the output shapes are inspected.
x_g = torch.randint(0, 255, [1,M,N], device=device)
x_s = torch.randint(0, 255, [1,M,N,3], device=device)

color.eval()
with torch.no_grad():
    out2 = color(x_g, x_s)

In [26]:
# Shapes of (argmax prediction, logits) returned by the colour upsampler.
out2[0].shape, out2[1].shape

In [None]:
# Smoke test of the spatial upsampler: full-resolution grayscale input plus a
# low-resolution colour image.
# Inputs are created on the same device as the model, and the forward pass
# runs without autograd since only the outputs are inspected.
x_g = torch.randint(0, 255, [1,H,W], device=device)
x_s = torch.randint(0, 255, [1,M,N,3], device=device)

spatial.eval()
with torch.no_grad():
    out3 = spatial(x_g, x_s)