In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch


class VAE(nn.Module):
    """Fully-connected variational autoencoder for flattened images (MNIST by default).

    Encoder: ``input_dim -> hidden_dim -> (mu, logvar)`` of size ``latent_dim``.
    Decoder: ``latent_dim -> hidden_dim -> input_dim`` with a sigmoid output in [0, 1].
    The defaults (784 / 400 / 20) reproduce the original layer sizes and names, so
    previously saved state dicts still load unchanged.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # head producing mu
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # head producing logvar
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Encode a batch into the parameters ``(mu, logvar)`` of the latent Gaussian.

        ``x`` is flattened to ``(batch, -1)``; the flattened width must equal
        ``fc1.in_features``. Both returned tensors have shape ``(batch, latent_dim)``.
        """
        x = x.reshape(x.size(0), -1)
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        This always samples, including in eval mode.
        """
        std = torch.exp(0.5 * logvar)
        # randn_like draws eps on std's device and with std's dtype. The previous
        # torch.cuda.is_available() check put eps on the GPU even for a CPU model,
        # which caused a device mismatch.
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latent codes ``z`` into flattened outputs in [0, 1] (sigmoid)."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))  # F.sigmoid is deprecated

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar


# Build the VAE on the GPU when available (CPU otherwise) and load the trained weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
c_net = VAE().to(device)
pkl_dir = "/home/lrh/program/git/pytorch-example/mnist_autoencoder/vae.pth"  # checkpoint file (a state_dict), not a directory
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
c_net.load_state_dict(torch.load(pkl_dir, map_location=device))
c_net.eval()  # inference-only notebook; reparametrize still samples in eval mode

print(c_net)  # function-call form: valid in both Python 2 and Python 3

VAE(
  (fc1): Linear(in_features=784, out_features=400, bias=True)
  (fc21): Linear(in_features=400, out_features=20, bias=True)
  (fc22): Linear(in_features=400, out_features=20, bias=True)
  (fc3): Linear(in_features=20, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=784, bias=True)
)


In [5]:
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torchvision.utils as vutils

result_directory = "."
BATCH_SIZE = 32  # images per reconstruction grid

# MNIST training split as tensors in [0, 1] (ToTensor only, no Normalize): this matches
# the range of the decoder's sigmoid output.
mnist_train = dset.MNIST(root="/home/lrh/dataset/mnist", train=True, download=True,
                         transform=transforms.Compose([transforms.ToTensor()]))
dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE,
                                         shuffle=True, drop_last=True)

# Reconstruct one random batch. next(iter(...)) replaces a `for ... break` loop whose loop
# variable shadowed the dataset and whose counter was a batch index, not an epoch.
device = next(c_net.parameters()).device  # follow the model instead of hard-coding .cuda()
real_x, label = next(iter(dataloader))
real_x = real_x.to(device)
with torch.no_grad():  # inference only: no autograd graph needed
    fake_x, mu, logvar = c_net(real_x)
fake_x = fake_x.reshape(-1, 1, 28, 28)  # -1 instead of a hard-coded 32
# Use a distinct file name. The prior-sampling cell writes fake_samples_epoch_000.png,
# which used to collide with (and overwrite) this reconstruction grid.
vutils.save_image(fake_x.cpu(), '%s/reconstructed_samples_%03d.png' % (result_directory, 0),
                  normalize=True)

In [4]:
import torchvision.utils as vutils
result_directory = "."
N_SAMPLES = 64  # number of images in the generated grid

# Draw latent codes z ~ N(0, I) from the VAE prior and decode them into images.
device = next(c_net.parameters()).device  # follow the model instead of hard-coding .cuda()
latent_dim = c_net.fc3.in_features  # decoder input width, read from the model instead of a literal 20
z = torch.randn(N_SAMPLES, latent_dim, device=device)
with torch.no_grad():  # inference only: no autograd graph needed
    fake_x = c_net.decode(z)
fake_x = fake_x.reshape(N_SAMPLES, 1, 28, 28)
vutils.save_image(fake_x.cpu(), '%s/fake_samples_epoch_%03d.png' % (result_directory, 0),
                  normalize=True)