In [1]:
import torch


class TwoLayerNet(torch.nn.Module):
    """Fully connected network with one hidden layer clamped at zero."""

    def __init__(self, D_in, H, D_out):
        """
        Create the two affine layers and register them on the module:
        linear1 maps D_in features to H, linear2 maps H features to D_out.
        """
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Take a Tensor of inputs and return a Tensor of predictions. The
        hidden activation is the first layer's output with negative values
        clamped to zero; the second layer is applied to that activation.
        """
        hidden = torch.clamp(self.linear1(x), min=0)
        return self.linear2(hidden)


# Problem sizes: N samples per batch, D_in input features,
# H hidden units, D_out output features.
N = 64
D_in = 1000
H = 100
D_out = 10

# Random inputs and random regression targets.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# The network to be trained.
model = TwoLayerNet(D_in, H, D_out)

# Summed squared error as the loss, minimised with plain SGD.
# model.parameters() hands the optimizer the learnable parameters of the
# two nn.Linear members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

for t in range(500):
    # Forward pass: predictions for x, then the loss against y.
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Clear old gradients, backpropagate the loss, take one SGD step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 641.0698852539062
1 596.6747436523438
2 557.7380981445312
3 523.2838134765625
4 492.4844055175781
5 464.98486328125
6 439.8149719238281
7 416.69049072265625
8 395.1796875
9 375.2060241699219
10 356.4792175292969
11 338.8349304199219
12 322.2173156738281
13 306.3661193847656
14 291.2605285644531
15 276.9498291015625
16 263.28656005859375
17 250.20050048828125
18 237.67408752441406
19 225.7671356201172
20 214.42071533203125
21 203.572265625
22 193.23353576660156
23 183.376953125
24 173.96966552734375
25 164.99476623535156
26 156.42930603027344
27 148.25270080566406
28 140.4392852783203
29 132.94419860839844
30 125.80224609375
31 119.00061798095703
32 112.54032897949219
33 106.39104461669922
34 100.5414810180664
35 94.9902114868164
36 89.73754119873047
37 84.74893951416016
38 79.9990463256836
39 75.49420166015625
40 71.22856140136719
41 67.19147491455078
42 63.372955322265625
43 59.76323318481445
44 56.35694885253906
45 53.1414909362793
46 50.1088981628418
47 47.25053405761719
48 44.559

470 3.6409585391083965e-06
471 3.5260416098026326e-06
472 3.4132888231397374e-06
473 3.3061205613194034e-06
474 3.201116442141938e-06
475 3.1005440632725367e-06
476 3.001604909513844e-06
477 2.9067341529298574e-06
478 2.8154388473922154e-06
479 2.7262694857199676e-06
480 2.6399663966003573e-06
481 2.5568981527612777e-06
482 2.4758132894930895e-06
483 2.3980633159226272e-06
484 2.322020009160042e-06
485 2.24806376536435e-06
486 2.177251190005336e-06
487 2.1092737370054238e-06
488 2.0423960904736305e-06
489 1.9779051854129648e-06
490 1.9154119854647433e-06
491 1.8556474969955161e-06
492 1.7966224277188303e-06
493 1.740423158480553e-06
494 1.685539132267877e-06
495 1.632676571716729e-06
496 1.5816104905752582e-06
497 1.531224484097038e-06
498 1.4830255850029062e-06
499 1.4364138678502059e-06


In [4]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image


# Device configuration: use CUDA when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory that receives the sampled / reconstructed image grids.
# exist_ok=True makes the call safe to re-run and removes the gap between
# "check whether it exists" and "create it".
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

# Hyper-parameters
image_size = 784      # length of one flattened input image (28 * 28)
h_dim = 400           # width of the hidden layers
z_dim = 20            # dimensionality of the latent code
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

# MNIST training split; ToTensor() converts each image to a tensor.
# The data is downloaded into `root` if it is not already present.
dataset = torchvision.datasets.MNIST(root='../../data',
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Data loader: shuffled mini-batches of `batch_size` images.
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# VAE model
class VAE(nn.Module):
    """Variational autoencoder made of fully connected layers.

    Encoder: fc1 -> ReLU -> (fc2, fc3), giving the mean and log-variance of
    the latent distribution. Decoder: fc4 -> ReLU -> fc5 -> sigmoid.

    Args:
        image_size: number of features in one (flattened) input.
        h_dim: width of the hidden layer in both encoder and decoder.
        z_dim: dimensionality of the latent code.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)   # latent mean head
        self.fc3 = nn.Linear(h_dim, z_dim)   # latent log-variance head
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)

    def encode(self, x):
        """Return (mu, log_var) for inputs whose last dimension is image_size."""
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, log_var):
        """Draw z = mu + eps * std with eps ~ N(0, I) and std = exp(log_var / 2).

        Sampling eps separately keeps z differentiable with respect to mu
        and log_var.
        """
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to reconstructions with values in [0, 1]."""
        h = F.relu(self.fc4(z))
        # torch.sigmoid is used instead of F.sigmoid: it computes the same
        # values, and F.sigmoid raises a deprecation warning on a range of
        # PyTorch releases.
        return torch.sigmoid(self.fc5(h))

    def forward(self, x):
        """Encode x, sample a latent code, decode it.

        Returns:
            (x_reconst, mu, log_var)
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

# Build the model from the hyper-parameters defined above so that the network
# and the sampling code below (which uses z_dim) cannot drift apart.
model = VAE(image_size, h_dim, z_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass: flatten every image to a vector of image_size values
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Compute reconstruction loss and kl divergence
        # For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43
        # reduction='sum' is the current spelling of the deprecated
        # size_average=False: the loss is summed over every element.
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 10 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" 
                   .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))

    # End of epoch: write image grids without tracking gradients.
    with torch.no_grad():
        # Save the sampled images: decode latent codes drawn from N(0, I)
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # Save the reconstructed images. x is the last mini-batch left over
        # from the inner loop; inputs and reconstructions are placed side by
        # side along the width axis.
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))



Epoch[1/15], Step [10/469], Reconst Loss: 35607.0117, KL Div: 3252.1628
Epoch[1/15], Step [20/469], Reconst Loss: 30020.0273, KL Div: 1089.5189
Epoch[1/15], Step [30/469], Reconst Loss: 27580.6211, KL Div: 1174.6149
Epoch[1/15], Step [40/469], Reconst Loss: 26698.7949, KL Div: 645.2008
Epoch[1/15], Step [50/469], Reconst Loss: 26766.0840, KL Div: 617.2636
Epoch[1/15], Step [60/469], Reconst Loss: 26262.5547, KL Div: 700.7596
Epoch[1/15], Step [70/469], Reconst Loss: 25164.1992, KL Div: 808.3029
Epoch[1/15], Step [80/469], Reconst Loss: 24524.2715, KL Div: 857.1606
Epoch[1/15], Step [90/469], Reconst Loss: 23485.5879, KL Div: 1264.2985
Epoch[1/15], Step [100/469], Reconst Loss: 22249.1855, KL Div: 1290.5806
Epoch[1/15], Step [110/469], Reconst Loss: 22326.3262, KL Div: 1375.0464
Epoch[1/15], Step [120/469], Reconst Loss: 20346.1992, KL Div: 1678.9233
Epoch[1/15], Step [130/469], Reconst Loss: 20621.3887, KL Div: 1729.7217
Epoch[1/15], Step [140/469], Reconst Loss: 19616.2246, KL Div: 18

Epoch[3/15], Step [220/469], Reconst Loss: 11947.2988, KL Div: 3039.6143
Epoch[3/15], Step [230/469], Reconst Loss: 12028.2783, KL Div: 3044.8308
Epoch[3/15], Step [240/469], Reconst Loss: 11262.1514, KL Div: 2913.2280
Epoch[3/15], Step [250/469], Reconst Loss: 11590.3047, KL Div: 3048.4224
Epoch[3/15], Step [260/469], Reconst Loss: 11555.4795, KL Div: 3061.5991
Epoch[3/15], Step [270/469], Reconst Loss: 11841.4736, KL Div: 2988.6853
Epoch[3/15], Step [280/469], Reconst Loss: 11411.2979, KL Div: 3004.0178
Epoch[3/15], Step [290/469], Reconst Loss: 11631.3428, KL Div: 3039.4556
Epoch[3/15], Step [300/469], Reconst Loss: 11496.9375, KL Div: 3139.2642
Epoch[3/15], Step [310/469], Reconst Loss: 11533.1143, KL Div: 3017.4001
Epoch[3/15], Step [320/469], Reconst Loss: 11960.7695, KL Div: 3145.2212
Epoch[3/15], Step [330/469], Reconst Loss: 11250.1475, KL Div: 2952.7268
Epoch[3/15], Step [340/469], Reconst Loss: 11622.8809, KL Div: 3061.2671
Epoch[3/15], Step [350/469], Reconst Loss: 11316.88

KeyboardInterrupt: 