In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from mol import discretized_mix_logistic_loss, sample_from_discretized_mix_logistic

class Decoder(nn.Module):
    def __init__(self, size, z_dim, h_dims=[32, 64, 128, 256, 512], **kwargs):
        super().__init__()
        
        size = size // 2 ** len(h_dims)
        self.size = size
        in_dim = h_dims[-1]
        self.in_dim = in_dim
        self.linear = nn.Linear(z_dim, in_dim*size**2)
        
        h_dims = h_dims[:-1]
        convs = []
        for h_dim in h_dims[::-1]:
            conv = nn.Sequential(nn.ConvTranspose2d(in_dim, h_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
                                 nn.BatchNorm2d(h_dim),
                                 nn.LeakyReLU())
            convs.append(conv)
            in_dim = h_dim
        self.convs = nn.Sequential(*convs)
        self.out_conv = nn.Sequential(nn.ConvTranspose2d(h_dim, h_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
                                      nn.BatchNorm2d(h_dim),
                                      nn.LeakyReLU(),
                                      nn.Conv2d(h_dim, 100, kernel_size=3, padding=1))
        
    def forward(self, data, **kwargs):
        # x : (b, c, h, w)
        
        z = data['z']
        y = self.linear(z)
        y = y.reshape(y.shape[0], self.in_dim, self.size, self.size)
        y = self.convs(y)
        l = self.out_conv(y)
        data['l'] = l
        recon_loss = torch.mean(discretized_mix_logistic_loss(data['x'].permute(0, 2, 3, 1), l.permute(0, 2, 3, 1)))
        data['recon_loss'] = recon_loss
        return data
    
    def sample(self, z):
        y = self.linear(z)
        y = y.reshape(y.shape[0], self.in_dim, self.size, self.size)
        y = self.convs(y)
        l = self.out_conv(y)
        sample = sample_from_discretized_mix_logistic(l.permute(0, 2, 3, 1), 10)
        sample = sample.permute(0, 3, 1, 2)
        return sample
    

In [25]:
decoder = Decoder(64, 128)
data = {'x': torch.randn(2, 3, 64, 64),
        'z': torch.randn(2, 128)
       }
data = decoder(data)
print(data['recon_loss'])
print(data['l'].shape)

tensor(4.8418, grad_fn=<MeanBackward0>)
torch.Size([2, 100, 64, 64])


In [26]:
sample = decoder.sample(data['z'])
print(sample.min(), sample.max())

tensor(-1., grad_fn=<MinBackward1>) tensor(1., grad_fn=<MaxBackward1>)


In [27]:
sample.shape

torch.Size([2, 3, 64, 64])