In [2]:
import torch
from vq_vae import DAVIS
from vq_vae import Model as Model_VQ_VAE


device = torch.device("cuda:3")
num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2
embedding_dim = 64
num_embeddings = 512
commitment_cost = 0.25
decay = 0.99
learning_rate = 1e-3

vq_vae = Model_VQ_VAE(num_hiddens, num_residual_layers, num_residual_hiddens,
                      num_embeddings, embedding_dim,
                      commitment_cost, decay).to(device)
vq_vae.load_state_dict(torch.load('out/model_46k.pt'))


# Dataset Parameters
num_frames = 12  # each video uses 10 consecutive frames for training
davis_root = '/playpen-raid2/qinliu/data/DAVIS'

trainset = DAVIS(root=davis_root, num_frames=num_frames, train=True)
valset = DAVIS(root=davis_root, num_frames=num_frames, train=False)

train_loader = torch.utils.data.DataLoader(trainset, batch_size=1,
                                           shuffle=True, num_workers=1)

val_loader = torch.utils.data.DataLoader(valset, batch_size=1,
                                         shuffle=True, num_workers=1)

vq_vae.eval()

inputs, _ = next(iter(train_loader))
inputs = torch.squeeze(inputs, dim=0)
inputs = inputs.to(device)
inputs_shape = inputs.shape

code_indices = vq_vae.get_code_indices(inputs)
code_indices = code_indices.view(num_frames, inputs_shape[2] // 4, inputs_shape[3] // 4)

embeddings = vq_vae._vq_vae.quantize(code_indices)

print(code_indices.shape, embeddings.shape)


torch.Size([98304]) torch.Size([98304, 64])


In [4]:
import torch.nn as nn
import torch.nn.functional as F


# train PixelCNN to generate new images

class MaskedConv2d(nn.Conv2d):
    """
    Implements a conv2d with mask applied on its weights.
    
    Args:
        mask (torch.Tensor): the mask tensor.
        in_channels (int): Number of channels in the input image.
        out_channels (int):  Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel
    """
    
    def __init__(self, mask, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.register_buffer('mask', mask[None, None])
        
    def forward(self, x):
        self.weight.data *= self.mask # mask weights
        return super().forward(x)
    

class VerticalStackConv(MaskedConv2d):

    def __init__(self, mask_type, in_channels, out_channels, kernel_size, **kwargs):
        # Mask out all pixels below. For efficiency, we could also reduce the kernel
        # size in height (k//2, k), but for simplicity, we stick with masking here.
        self.mask_type = mask_type
        
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        mask = torch.zeros(kernel_size)
        mask[:kernel_size[0]//2, :] = 1.0
        if self.mask_type == "B":
            mask[kernel_size[0]//2, :] = 1.0

        super().__init__(mask, in_channels, out_channels, kernel_size, **kwargs)
        

class HorizontalStackConv(MaskedConv2d):

    def __init__(self, mask_type, in_channels, out_channels, kernel_size, **kwargs):
        # Mask out all pixels on the left. Note that our kernel has a size of 1
        # in height because we only look at the pixel in the same row.
        self.mask_type = mask_type
        
        if isinstance(kernel_size, int):
            kernel_size = (1, kernel_size)
        assert kernel_size[0] == 1
        if "padding" in kwargs:
            if isinstance(kwargs["padding"], int):
                kwargs["padding"] = (0, kwargs["padding"])
        
        mask = torch.zeros(kernel_size)
        mask[:, :kernel_size[1]//2] = 1.0
        if self.mask_type == "B":
            mask[:, kernel_size[1]//2] = 1.0

        super().__init__(mask, in_channels, out_channels, kernel_size, **kwargs)
        
class GatedMaskedConv(nn.Module):

    def __init__(self, in_channels, kernel_size=3, dilation=1):
        """
        Gated Convolution block implemented the computation graph shown above.
        """
        super().__init__()
        
        padding = dilation * (kernel_size - 1) // 2
        self.conv_vert = VerticalStackConv("B", in_channels, 2*in_channels, kernel_size, padding=padding,
                                          dilation=dilation)
        self.conv_horiz = HorizontalStackConv("B", in_channels, 2*in_channels, kernel_size, padding=padding,
                                             dilation=dilation)
        self.conv_vert_to_horiz = nn.Conv2d(2*in_channels, 2*in_channels, kernel_size=1)
        self.conv_horiz_1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, v_stack, h_stack):
        # Vertical stack (left)
        v_stack_feat = self.conv_vert(v_stack)
        v_val, v_gate = v_stack_feat.chunk(2, dim=1)
        v_stack_out = torch.tanh(v_val) * torch.sigmoid(v_gate)

        # Horizontal stack (right)
        h_stack_feat = self.conv_horiz(h_stack)
        h_stack_feat = h_stack_feat + self.conv_vert_to_horiz(v_stack_feat)
        h_val, h_gate = h_stack_feat.chunk(2, dim=1)
        h_stack_feat = torch.tanh(h_val) * torch.sigmoid(h_gate)
        h_stack_out = self.conv_horiz_1x1(h_stack_feat)
        h_stack_out = h_stack_out + h_stack

        return v_stack_out, h_stack_out
    
    
class GatedPixelCNN(nn.Module):
    
    def __init__(self, in_channels, channels, out_channels):
        super().__init__()
        
        # Initial first conv with mask_type A
        self.conv_vstack = VerticalStackConv("A", in_channels, channels, 3, padding=1)
        self.conv_hstack = HorizontalStackConv("A", in_channels, channels, 3, padding=1)
        # Convolution block of PixelCNN. use dilation instead of 
        # downscaling used in the encoder-decoder architecture in PixelCNN++
        self.conv_layers = nn.ModuleList([
            GatedMaskedConv(channels),
            GatedMaskedConv(channels, dilation=2),
            GatedMaskedConv(channels)
        ])
        
        # Output classification convolution (1x1)
        self.conv_out = nn.Conv2d(channels, out_channels, kernel_size=1)
        
    def forward(self, x):
        # first convolutions
        v_stack = self.conv_vstack(x)
        h_stack = self.conv_hstack(x)
        # Gated Convolutions
        for layer in self.conv_layers:
            v_stack, h_stack = layer(v_stack, h_stack)
        # 1x1 classification convolution
        # Apply ELU before 1x1 convolution for non-linearity on residual connection
        out = self.conv_out(F.elu(h_stack))
        return out

In [6]:
# Train GatedPixelCNN to learn the prior of latent code indices
num_training_updates = 500000


pixelcnn = GatedPixelCNN(num_embeddings, 128, num_embeddings)
pixelcnn = pixelcnn.cuda()

optimizer = torch.optim.Adam(pixelcnn.parameters(), lr=1e-3)

# train pixelcnn
print_freq = 500
for i in range(num_training_updates):
    inputs, _ = next(iter(train_loader))
    inputs = torch.squeeze(inputs, dim=0)
    inputs = inputs.to(device)
    inputs_shape = inputs.shape

    indices = vq_vae.get_code_indices(inputs)
    indices = indices.view(
        num_frames, inputs_shape[2] // 4, inputs_shape[3] // 4)

    indices = indices.cuda()
    one_hot_indices = F.one_hot(
        indices, num_embeddings).float().permute(0, 3, 1, 2).contiguous()

    outputs = pixelcnn(one_hot_indices)

    loss = F.cross_entropy(outputs, indices)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # if (i + 1) % print_freq == 0 or (i + 1) == len(train_loader):
    print("\t [{}/{}]: loss {}".format(i, num_training_updates, loss.item()))


	 [0/500000]: loss 6.236263275146484
	 [1/500000]: loss 6.204369068145752
	 [2/500000]: loss 6.169195175170898
	 [3/500000]: loss 6.123173236846924
	 [4/500000]: loss 6.053569316864014
	 [5/500000]: loss 5.939473628997803
	 [6/500000]: loss 5.760213851928711
	 [7/500000]: loss 5.528297424316406
	 [8/500000]: loss 5.3283538818359375
	 [9/500000]: loss 5.256496906280518
	 [10/500000]: loss 5.213814735412598
	 [11/500000]: loss 5.139461040496826
	 [12/500000]: loss 5.0683417320251465
	 [13/500000]: loss 5.022518157958984
	 [14/500000]: loss 4.996445655822754
	 [15/500000]: loss 4.978377342224121
	 [16/500000]: loss 4.959958553314209
	 [17/500000]: loss 4.937673091888428
	 [18/500000]: loss 4.910678386688232
	 [19/500000]: loss 4.879382133483887
	 [20/500000]: loss 4.844651699066162
	 [21/500000]: loss 4.818471431732178
	 [22/500000]: loss 4.804126262664795
	 [23/500000]: loss 4.763067722320557
	 [24/500000]: loss 4.736823558807373
	 [25/500000]: loss 4.721305847167969
	 [26/500000]: loss 

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f81f7afb1f0>
Traceback (most recent call last):
  File "/playpen-raid/qinliu/tools/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/playpen-raid/qinliu/tools/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1322, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/playpen-raid/qinliu/tools/anaconda3/lib/python3.8/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/playpen-raid/qinliu/tools/anaconda3/lib/python3.8/multiprocessing/popen_fork.py", line 44, in wait
    if not wait([self.sentinel], timeout):
  File "/playpen-raid/qinliu/tools/anaconda3/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/playpen-raid/qinliu/tools/anaconda3/lib/python3.8/selectors.py", line 415, in se

	 [55/500000]: loss 4.040699481964111
	 [56/500000]: loss 4.018599033355713
	 [57/500000]: loss 3.9957942962646484
	 [58/500000]: loss 3.972363233566284
	 [59/500000]: loss 3.9487712383270264
	 [60/500000]: loss 3.925366163253784
	 [61/500000]: loss 3.9021499156951904
	 [62/500000]: loss 3.878460645675659
	 [63/500000]: loss 3.854994058609009
	 [64/500000]: loss 3.8317644596099854
	 [65/500000]: loss 3.8082275390625
	 [66/500000]: loss 3.784783363342285
	 [67/500000]: loss 3.761204481124878
	 [68/500000]: loss 3.737516403198242
	 [69/500000]: loss 3.7143232822418213
	 [70/500000]: loss 3.6914737224578857
	 [71/500000]: loss 3.6692025661468506
