implement batching
siarez committed Dec 24, 2017
1 parent 0e2f59f commit 1515f30
Showing 1 changed file with 43 additions and 25 deletions.
vae.py
@@ -1,17 +1,20 @@
import torch
from torch.optim import Adam
+from torch.utils.data import Dataset, DataLoader
from sklearn.datasets import fetch_mldata
from tqdm import tqdm
from logger import Logger
import numpy as np

+BATCH_SIZE = 3
pz_size = 40
-mid_size = 100
+mid_size = 200
image_size = 28 * 28

x = torch.randn(image_size)
loss_func = torch.nn.MSELoss()

+#TODO: Check for CUDA
+dtype_var = torch.cuda.FloatTensor
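A minimal sketch of the check the TODO asks for, using torch.cuda.is_available() to fall back to CPU tensors when no GPU is present (not part of the commit; the hard-coded .cuda() calls elsewhere would need the same guard):

# Sketch: resolve the TODO by picking the tensor type at import time.
dtype_var = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor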

class VAE(torch.nn.Module):
def __init__(self):
@@ -23,18 +26,20 @@ def __init__(self):
self.relu_activate = torch.nn.ReLU()
self.middle_out_layer = torch.nn.Linear(int(pz_size / 2), mid_size)
self.output_layer = torch.nn.Linear(mid_size, image_size)
+        self.normal_tensor = torch.zeros(int(pz_size/2)).cuda()

def forward(self, x):
mu_sigma = self.relu_activate(self.middle_layer(self.relu_activate(self.input_layer(x))))
-        mu, logvar = torch.chunk(mu_sigma, 2)
+        mu, logvar = torch.chunk(mu_sigma, 2, dim=1)
std = logvar.mul(0.5).exp()

# Calculating hidden representation `z`
-        normal_sample = torch.autograd.Variable(
-            torch.normal(torch.zeros(int(pz_size / 2)), torch.ones(int(pz_size / 2))),
-            requires_grad=False)
-        self.z = torch.mul(normal_sample, std) + mu
+        # normal_sample = torch.autograd.Variable(
+        #     torch.normal(torch.zeros(int(pz_size / 2)), torch.ones(int(pz_size / 2))),
+        #     requires_grad=False).cuda()
+        self.normal_tensor.normal_()
+        # self.z = torch.mul(normal_sample, std) + mu
+        self.z = torch.mul(torch.autograd.Variable(self.normal_tensor, requires_grad=False), std) + mu
# Calculating output
return self.activation_layer(self.output_layer(self.relu_activate(self.middle_out_layer(self.z)))), mu, logvar
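This forward pass is the reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I), which keeps z differentiable with respect to mu and logvar. One subtlety: the pre-allocated normal_tensor has shape (pz_size/2,), so a single noise vector is broadcast across the whole batch. A sketch of per-sample noise under the same 0.3-era API (Tensor.new and normal_ both exist there; names are illustrative):

# Sketch: independent N(0, 1) noise for every sample in the batch,
# allocated with the same dtype/device as std (hypothetical variant).
noise = std.data.new(std.size()).normal_()
z = torch.autograd.Variable(noise, requires_grad=False) * std + mu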

@@ -45,48 +50,61 @@ def forward_classical(self, x):
# Calculating output
return self.activation_layer(self.output_layer(self.relu_activate(self.middle_out_layer(self.z))))

+class MyData(Dataset):
+    def __init__(self, input):
+        self.input = input
+
+    def __len__(self):
+        return self.input.shape[0]
+
+    def __getitem__(self, item):
+        return self.input[item]


def kl_loss(mu, logvar):
return -0.5*torch.sum(1 + logvar - mu.pow(2) - torch.exp(logvar))
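kl_loss is the closed-form KL divergence between the approximate posterior N(mu, diag(sigma^2)) and the standard normal prior, summed over every latent unit and every sample in the batch. Note that torch.nn.MSELoss averages over elements by default while this term sums over the batch, so the two loss components scale differently with BATCH_SIZE. A batch-averaged variant could look like this (a sketch; the name kl_loss_mean is mine, not the commit's):

# Sketch: KL term averaged over the batch so its scale matches the
# element-averaged MSE reconstruction term (hypothetical variant).
def kl_loss_mean(mu, logvar):
    kl_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return kl_per_sample.mean()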

def _prepare_data():
mnist = fetch_mldata('MNIST original', data_home='./')
-    mnist_tensor = torch.ByteTensor(mnist.data).type(torch.FloatTensor)
+    mnist_tensor = torch.ByteTensor(mnist.data).type(dtype_var)
rnd_idx = np.random.choice(len(mnist.data), len(mnist.data), replace=False) # pick index values randomly
mnist_tensor = mnist_tensor[rnd_idx, :] # randomize the data set
mnist_tensor = mnist_tensor / 255 # normalize
-    print(mnist_tensor.shape)
-    return mnist_tensor
+    dataset = MyData(mnist_tensor)
+    print(mnist_tensor.shape)
+    return dataset
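Since MyData is a map-style Dataset, DataLoader could also take over the shuffling that _prepare_data does by hand with rnd_idx. A sketch (shuffle=True reshuffles at the start of every epoch):

# Sketch: an alternative to the manual rnd_idx permutation above --
# let DataLoader reshuffle the wrapped dataset itself.
dataset = _prepare_data()
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
for batch in loader:
    print(batch.shape)  # torch.Size([BATCH_SIZE, 784]), except a smaller final batch
    break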

-def train(data_var, epochs=5):
-    loss_trace = torch.zeros(epochs, data_var.shape[0])
+def train(data_set, epochs=5):
+    # loss_trace = torch.zeros(epochs, len(data_set))
+    loss_trace = []
+    data_loader = DataLoader(data_set, batch_size=BATCH_SIZE)
for i in tqdm(range(epochs)):
-        for j in tqdm(range(data_var.shape[0])):
-            output, mu, logvar = vae(data_var[j][:])  # forward pass
-            loss = loss_func(output, data_var[j][:]) + kl_loss(mu, logvar)  # calculate loss
-            loss_trace[i][j] = loss.data[0]
+        for batch_raw in tqdm(data_loader):
+            batch = torch.autograd.Variable(batch_raw)
+            output, mu, logvar = vae(batch)  # forward pass
+            loss = loss_func(output, batch) + kl_loss(mu, logvar)  # calculate loss
+            loss_trace.append(torch.sum(loss.data))
logger.scalar_summary("loss", loss.data[0], i + 1)
loss.backward() # back propagate loss to calculate the deltas (the "gradient")
optim.step() # use the learning rates and the gradient to update parameters
-            if j % 5000 == 0:
-                log_images(data_var[j][:], output, (j+1) + data_var.shape[0]*(i + 1))
+            # if j % 5000 == 0:
+            #     log_images(batch.data[], output, (j+1) + data_var.shape[0]*(i + 1))

return loss_trace
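One thing this loop never does is reset gradients: without optim.zero_grad(), PyTorch accumulates gradients across batches, so every step applies a running sum of all previous gradients rather than the current batch's. A sketch of the conventional ordering inside the inner loop:

# Sketch: the standard PyTorch update pattern with an explicit gradient reset.
optim.zero_grad()   # clear gradients left over from the previous batch
loss.backward()     # backpropagate to compute fresh gradients
optim.step()        # apply the update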


def log_images(input, output, epoch):
-    reshaped_in = input.view(28, 28).data.numpy()
-    reshaped_out = output.view(28, 28).data.numpy()
+    reshaped_in = input.view(28, 28).data.cpu().numpy()
+    reshaped_out = output.view(28, 28).data.cpu().numpy()
logger.image_summary('prediction', [reshaped_in, reshaped_out], epoch)
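The commented-out call in train() still passes a whole batch, but log_images expects single 784-element vectors. With batching, one sample would need to be selected first, e.g. (a sketch; the step argument here is illustrative):

# Sketch: adapt the old per-sample image logging to batched training
# by logging the first element of the current batch.
log_images(batch[0], output[0], i + 1)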


logger = Logger("./logs")

-vae = VAE()
+vae = VAE().cuda()
optim = Adam(vae.parameters(), lr=0.00001)  # lr: learning rate; weight_decay (not set here) would add an L2 penalty
-mnist_tensor = _prepare_data()
-input_var = torch.autograd.Variable(mnist_tensor)
-trace = train(data_var=input_var)
+mnist_dataset = _prepare_data()
+trace = train(data_set=mnist_dataset)
# normal distribution definition
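The dangling "normal distribution definition" comment hints at sampling from the prior. Once trained, new digits can be generated by decoding z ~ N(0, I) through the decoder half of the network, mirroring the tail of forward(). A hedged sketch reusing the layers defined above (not part of the commit; activation_layer is the output nonlinearity referenced in forward):

# Sketch: generate a digit by decoding prior noise through the decoder layers.
z = torch.autograd.Variable(torch.zeros(1, int(pz_size / 2)).normal_().cuda(),
                            requires_grad=False)
generated = vae.activation_layer(vae.output_layer(vae.relu_activate(vae.middle_out_layer(z))))
image = generated.view(28, 28).data.cpu().numpy()  # 28x28 array, ready for plotting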
