In [1]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

import numpy as np

In [2]:
# Prefer the GPU when one is present, otherwise fall back to the CPU.
# A hard-coded "cuda" device makes every later `.to(device)` call fail on
# machines without CUDA, so the notebook could not run there at all.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Seed PyTorch's global random number generator so that the random draws made
# further down (the `torch.rand` input image and the `torch.randn_like` noise
# inside VAE.forward) are repeatable on a fresh top-to-bottom run.
torch.manual_seed(seed=1)

<torch._C.Generator at 0x7f87c03c4f10>

In [5]:
class GLU(nn.Module):
    """Gated linear block built from two parallel linear projections.

    The same input is sent through two independent ``nn.Linear`` layers; one
    result is squashed with a sigmoid, the other is rectified with a ReLU,
    and the two are multiplied element-wise.

    Args:
        c1: number of input features of both linear layers.
        c2: number of output features of both linear layers.
    """

    def __init__(self, c1, c2):
        super(GLU, self).__init__()
        # The attribute names `s` and `g` become state_dict keys, so they must
        # stay as they are for saved checkpoints to keep loading.
        self.s = nn.Linear(c1, c2)
        self.g = nn.Linear(c1, c2)

    def forward(self, x):
        """Return ``sigmoid(s(x)) * relu(g(x))`` for the input ``x``."""
        gate = self.s(x).sigmoid()
        value = self.g(x).relu()
        return gate * value
    
class Encoder(nn.Module):
    """Two-layer MLP that maps a 784-feature input to 50 values in (0, 1).

    Layer sizes are fixed: 784 -> 400 (ReLU) -> 50 (sigmoid). The result is
    named ``phase`` by the caller (see ``VAE.forward``).
    """

    def __init__(self):
        super(Encoder, self).__init__()
        # `fc1` / `fc2` are state_dict keys; renaming them would break
        # checkpoint loading.
        self.fc1 = nn.Linear(784, 400)
        self.fc2 = nn.Linear(400, 50)

    def forward(self, x):
        """Encode ``x`` (last dimension 784) into 50 sigmoid activations."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden).sigmoid()
    
class Decoder(nn.Module):
    """Maps a 100-feature signal to 784 values in (0, 1).

    The input first passes through a ``GLU(100, 400)`` block and then through
    a 400 -> 784 linear layer followed by a sigmoid.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        # `fc1` / `fc2` are state_dict keys; keep the names for checkpoints.
        self.fc1 = GLU(100, 400)
        self.fc2 = nn.Linear(400, 784)

    def forward(self, x):
        """Decode ``x`` (last dimension 100) into 784 sigmoid activations."""
        gated = self.fc1(x)
        return self.fc2(gated).sigmoid()

class Key(nn.Module):
    """Two-layer MLP that maps a 10-feature vector to 50 values in (0, 1).

    Layer sizes are fixed: 10 -> 50 (ReLU) -> 50 (sigmoid).
    """

    def __init__(self):
        super(Key, self).__init__()
        # `fc1` / `fc2` are state_dict keys; keep the names for checkpoints.
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 50)

    def forward(self, x):
        """Return 50 sigmoid activations for ``x`` (last dimension 10)."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden).sigmoid()
class VAE(nn.Module):
    """Encodes an image as a sum of 50 sinusoids and decodes that signal.

    Sub-modules:
        e: ``Encoder`` producing 50 per-sample values used as phases.
        d: ``Decoder`` turning the 100-sample noisy signal into 784 outputs.
        amplitude: ``Key`` producing 50 per-sample values from ``c``. Despite
            the attribute name, these values multiply ``t`` inside the sine,
            i.e. they are used as frequencies.

    The attribute names are state_dict keys and must not be renamed if
    existing checkpoints are to keep loading.
    """

    def __init__(self):
        super(VAE, self).__init__()

        self.e = Encoder()
        self.d = Decoder()
        self.amplitude = Key()

    def forward(self, x, c, t):
        """Run the full encode -> sinusoid synthesis -> decode pipeline.

        Args:
            x: input reshaped internally to ``(N, 784)``.
            c: input to the ``Key`` network; must yield ``N * 50`` values.
            t: time axis, broadcast against tensors of shape ``(N, 50, 100)``.

        Returns:
            A tuple ``(reconstruction, w, phase)`` where ``reconstruction``
            is the decoder output and ``w`` / ``phase`` are the tensors after
            being repeated to shape ``(N, 50, 100)``.
        """
        flat = x.view(-1, 784)
        batch = flat.shape[0]

        # One value per sinusoid, (N, 50) -> (N, 50, 1), then copied along a
        # new trailing axis of length 100 to line up with the time axis.
        w = self.amplitude(c).view(batch, 50, 1).repeat(1, 1, 100)
        phase = self.e(flat).view(batch, 50, 1).repeat(1, 1, 100)

        # Evaluate all 50 sinusoids over t and superpose them (sum over dim 1).
        waves = torch.sin(2 * np.pi * w * t + np.pi * phase)
        signal = waves.sum(dim=1).view(batch, 100)

        # Standard-normal noise is always added; there is no check of
        # `self.training`, so this also happens after `model.eval()`.
        noise = torch.randn_like(signal)
        noisy_signal = noise + signal

        return self.d(noisy_signal), w, phase

In [6]:
# Instantiate the network, then move its parameters onto the selected device.
model = VAE()
model = model.to(device)

In [7]:
# Restore the saved weights into the model.
# `map_location=device` remaps the stored tensors onto the device in use, so a
# checkpoint that was written from a GPU can still be loaded on a CPU-only
# machine (without it, torch.load tries to restore onto the original device).
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
model.load_state_dict(
    torch.load('checkpoints/mnist/fft_400.pt', map_location=device)
)

In [8]:
import matplotlib

In [9]:
%matplotlib notebook

In [10]:
import matplotlib.pyplot as plt

In [13]:
# Put the model in evaluation mode. Note that the noise added inside
# VAE.forward is unconditional, so eval() does not switch it off.
model.eval()

# Time axis 0, 1, ..., 99 as float32 on the model's device. Its length (100)
# has to match the trailing axis used inside VAE.forward.
t = torch.arange(100, dtype=torch.float32, device=device)

In [118]:
# All-zero conditioning vector for a batch of one; its 10 entries match the
# input width of the Key network (nn.Linear(10, 50)).
# Presumably one slot per digit class -- TODO confirm against the training code.
c = torch.zeros((1, 10), device=device)

In [123]:
# Uniform random values in [0, 1) shaped like a single 1x28x28 image, used as
# the model input. The sample is drawn on the CPU first and only then moved,
# so the values come from the CPU random generator.
data = torch.rand(1, 1, 28, 28)
data = data.to(device)

In [124]:
# Switch on entry 0 of the conditioning vector for every row of the batch.
# NOTE(review): this writes into `c` in place; re-running the cell with a
# different index without re-creating `c` leaves the earlier entry set too.
c[:, 0] = 1.0

In [125]:
# Single forward pass for inspection only. Gradients are never needed here, so
# run under no_grad() to avoid building the autograd graph (less memory, and
# the outputs can be converted to NumPy without an explicit detach()).
with torch.no_grad():
    rx, w, phase = model(data, c, t)

# The decoder emits 784 values per sample; view the single sample as a
# 28x28 image on the host.
img = rx.cpu().numpy().reshape(28, 28)
print(img.shape)
# plt.imshow(img)  # uncomment to display the reconstruction


(28, 28)
