In [None]:
import torch
import torch.nn.functional
import torchvision.datasets  # MNIST is a class inside this package, not a submodule
import torchvision.transforms
from torch import nn
from torchvision.datasets import MNIST
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
#VQ-VAE model
class VQVAE(nn.Module):
    """Small convolutional VQ-VAE for single-channel images.

    Pipeline: encoder -> 1x1 conv down to the code dimension -> nearest
    codebook vector for every spatial position -> 1x1 conv back to 4
    channels -> decoder.

    Args:
        num_codes: number of vectors in the codebook (default 3).
        code_dim: dimensionality of each codebook vector (default 2).
        commitment_factor: weight of the commitment term in the quantisation
            loss (default 0.2).
    """

    def __init__(self, num_codes=3, code_dim=2, commitment_factor=0.2):
        super().__init__()
        #encoder: two stride-2 convolutions, 1 -> 16 -> 4 channels
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
        )
        self.commitment_factor = commitment_factor
        #1x1 conv projecting encoder features to the codebook dimensionality
        self.before_quantisation = nn.Conv2d(4, code_dim, 1)
        #codebook vectors, one row per code, initialised close to zero
        self.codebook = nn.Parameter(torch.randn(num_codes, code_dim) * 0.0001)
        #1x1 conv projecting quantised vectors back to the decoder width
        self.following_quantisation = nn.Conv2d(code_dim, 4, 1)
        #decoder: two stride-2 transposed convolutions, 4 -> 16 -> 1 channels
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, data):
        """Encode, quantise and decode a batch of images.

        Args:
            data: float tensor of shape (B, 1, H, W).

        Returns:
            tuple ``(decoded, q_loss, idx)``:
                decoded: reconstruction produced by the decoder; values lie in
                    [-1, 1] because the last layer is ``Tanh``.
                q_loss: scalar tensor, codebook loss plus
                    ``commitment_factor`` times the commitment loss.
                idx: long tensor of shape (B, H' * W') holding the index of
                    the selected codebook vector for every latent position,
                    where H' x W' is the encoder's output resolution.
        """
        #first encode
        encode_data = self.encoder(data)
        #set to dimensionality of codebook
        q_input = self.before_quantisation(encode_data)
        B, C, H, W = q_input.shape
        #channels last, one row per latent position; permute() yields a
        #non-contiguous tensor, so reshape() is required here instead of view()
        q_input = q_input.permute(0, 2, 3, 1).reshape(B, H * W, C)
        #distance from every latent vector to every codebook vector: (B, H*W, K)
        d = torch.cdist(q_input, self.codebook.unsqueeze(0).expand(B, -1, -1))
        #index of the closest codebook vector per position: (B, H*W)
        idx = torch.argmin(d, dim=-1)
        #gather the selected codebook vectors: (B*H*W, C)
        q_out = torch.index_select(self.codebook, 0, idx.view(-1))
        #flatten q_input to 2 dimensions so it lines up with q_out
        q_input = q_input.reshape(-1, C)
        #commitment loss -- move the encoder output towards the codebook vectors
        commit_loss = torch.mean((q_out.detach() - q_input) ** 2)
        #code loss -- move the codebook vectors towards the encoder output
        code_loss = torch.mean((q_out - q_input.detach()) ** 2)
        q_loss = code_loss + self.commitment_factor * commit_loss
        #straight-through estimator: the forward value is the codebook vector,
        #but the gradient flows back to q_input as if quantisation were identity
        q_out = q_input + (q_out - q_input).detach()
        #reshape to B, C, H, W
        q_out = q_out.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
        #generate image
        decoded = self.following_quantisation(q_out)
        decoded = self.decoder(decoded)
        return decoded, q_loss, idx

In [None]:
#lstm to sample from vqvae
class LSTMSampler(nn.Module):
    """Autoregressive LSTM over sequences of codebook indices.

    Each index is embedded, passed through a batch-first LSTM and a small
    two-layer head that emits one score per codebook entry for every position
    of the input sequence.

    Args:
        codebook: width of the index embedding fed to the LSTM (default 2).
        hidden_size: LSTM hidden size and width of the head (default 4).
        num_codebook: number of distinct code indices, i.e. the vocabulary
            size and the number of output logits (default 3).
    """

    def __init__(self, codebook=2, hidden_size=4, num_codebook=3):
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        self.codebook = codebook
        self.hidden_size = hidden_size
        self.num_codebook = num_codebook
        self.embedding = nn.Embedding(num_codebook, codebook)
        # batch_first=True: inputs and outputs are laid out (batch, seq, feature).
        self.lstm = nn.LSTM(input_size=codebook, hidden_size=hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        # One logit per codebook entry so the output can be scored against
        # integer code indices.
        self.linear1 = nn.Linear(hidden_size, num_codebook)

    def forward(self, x):
        """Score the next code index at every position.

        Args:
            x: long tensor of code indices, shape (batch, seq_len).

        Returns:
            Float tensor of shape (batch, seq_len, num_codebook); position
            ``t`` depends only on ``x[:, :t + 1]``.
        """
        x = self.embedding(x)
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is used.
        x, _ = self.lstm(x)
        x = self.linear(x)
        x = self.relu(x)
        x = self.linear1(x)
        return x

In [None]:
#train code
def train(model, epochs, criterion, optimiser, scheduler, train_loader):
    """Train a model on reconstruction loss plus its own quantisation loss.

    Args:
        model: module whose forward returns a 3-tuple
            ``(decoded, q_loss, extra)``; the third element is ignored.
        epochs: number of passes over ``train_loader``.
        criterion: reconstruction loss, called as ``criterion(decoded, d)``.
        optimiser: optimiser holding the model parameters.
        scheduler: learning-rate scheduler stepped once at the end of every
            epoch, or ``None`` to disable scheduling. A ``ReduceLROnPlateau``
            is stepped with the epoch's mean training loss.
        train_loader: iterable yielding ``(data, label)`` batches; the labels
            are not used.

    Returns:
        list[float]: mean training loss of each epoch (``nan`` for an epoch
        that produced no batches).
    """
    model.train()
    # Send batches to the device the model lives on, so data and weights can
    # never disagree; a parameter-free model leaves the batch where it is.
    first_param = next(model.parameters(), None)
    epoch_losses = []
    for epoch in range(epochs):
        print('Training epoch ', epoch)
        running_loss = 0.0
        num_batches = 0
        for d, _label in train_loader:
            if first_param is not None:
                d = d.to(first_param.device)
            optimiser.zero_grad()
            decoded, q_loss, _ = model(d)
            loss = criterion(decoded, d) + q_loss
            loss.backward()
            optimiser.step()
            running_loss += loss.item()
            num_batches += 1
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(mean_loss)
        if scheduler is not None:
            # ReduceLROnPlateau needs the monitored metric; other schedulers
            # take no argument.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(mean_loss)
            else:
                scheduler.step()
    return epoch_losses

In [None]:
#train the sampler on code indices produced by the trained model
def train_sampler(gen_model, model, epochs, criterion, optimiser, train_loader):
    """Train a next-index predictor on code indices from a frozen generator.

    For every batch the generator is run without gradients to obtain a
    sequence of code indices ``en``. ``model`` is then trained with teacher
    forcing: it receives ``en[:, :-1]`` and is scored against ``en[:, 1:]``.

    Args:
        gen_model: module whose forward returns a 3-tuple whose last element
            is an integer tensor of code indices, indexed here as
            (batch, seq_len). It is put in eval mode and never updated.
        model: module mapping (batch, seq_len - 1) indices to scores of shape
            (batch, seq_len - 1, num_classes).
        epochs: number of passes over ``train_loader``.
        criterion: class-index loss called as ``criterion(scores, targets)``
            with scores flattened to (N, num_classes) and targets to (N,),
            e.g. ``nn.CrossEntropyLoss``.
        optimiser: optimiser holding the parameters of ``model``.
        train_loader: iterable yielding ``(data, label)`` batches; the labels
            are not used.

    Returns:
        list[float]: mean training loss of each epoch (``nan`` for an epoch
        that produced no batches).

    Raises:
        ValueError: if ``model`` does not return one score vector per target
            position.
    """
    gen_model.eval()
    model.train()
    # Data follows the generator's device and the indices follow the
    # sampler's, so the two modules may live on different devices.
    gen_param = next(gen_model.parameters(), None)
    sampler_param = next(model.parameters(), None)
    epoch_losses = []
    for epoch in range(epochs):
        print('Training epoch', epoch)
        running_loss = 0.0
        num_batches = 0
        for d, _label in train_loader:
            if gen_param is not None:
                d = d.to(gen_param.device)
            with torch.no_grad():
                _, _, en = gen_model(d)
            if sampler_param is not None:
                en = en.to(sampler_param.device)
            # Teacher forcing: predict index t + 1 from indices up to t.
            x = en[:, :-1]
            y = en[:, 1:]
            optimiser.zero_grad()
            o = model(x)
            if o.shape[:-1] != y.shape:
                raise ValueError(
                    'model output must have shape targets.shape + (num_classes,); '
                    'got %s for targets of shape %s' % (tuple(o.shape), tuple(y.shape))
                )
            # PyTorch losses take (input, target); flatten so class-index
            # losses see (N, num_classes) scores and (N,) targets.
            loss = criterion(o.reshape(-1, o.size(-1)), y.reshape(-1))
            loss.backward()
            optimiser.step()
            running_loss += loss.item()
            num_batches += 1
        epoch_losses.append(running_loss / num_batches if num_batches else float('nan'))
    return epoch_losses