In [1]:
import torch
from torch import nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from torchvision import datasets
from torchvision import transforms
import torchvision.utils as vutils
import numpy as np
from comet_ml import Experiment

In [2]:
experiment = Experiment(api_key="E3oWJUSFulpXpCUQfc5oGz0zY", project_name="pytorch-avb")

jupyter comet_ml enable 
COMET INFO: old comet version (1.0.29) detected. current: 1.0.31 please update your comet lib with command: `pip install --no-cache-dir --upgrade comet_ml`
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/santient/pytorch-avb/7cf132e5da3c4224982654749e11e7bf



In [3]:
cudnn.benchmark = True

In [None]:
# representation_size = 2
# input_size = 4
# n_samples = 2000
# batch_size = 500
# gen_hidden_size = 200
# enc_hidden_size = 200
# disc_hidden_size = 200

In [None]:
# n_samples_per_batch = n_samples//input_size

# y = np.array([i for i in range(input_size)  for _ in range(n_samples_per_batch)])

# d = np.identity(input_size)
# x = np.array([d[i] for i in y], dtype=np.float32)

In [None]:
# print(x[[10, 58 ,610, 790, 1123, 1258, 1506, 1988]])

In [4]:
device = torch.device("cuda:0")

In [None]:
# class VAE(nn.Module):
#     def __init__(self):
#         super(VAE, self).__init__()
#         self.gen_l1 = torch.nn.Linear(representation_size, gen_hidden_size)
#         self.gen_l2 = torch.nn.Linear(gen_hidden_size, input_size)
        
#         self.enc_l1 = torch.nn.Linear(input_size+representation_size, 
#                                       enc_hidden_size)
#         self.enc_l2 = torch.nn.Linear(enc_hidden_size, representation_size)
        
#         self.disc_l1 = torch.nn.Linear(input_size+representation_size, 
#                                        disc_hidden_size)
#         self.disc_l2 = torch.nn.Linear(disc_hidden_size, 1)
        
#         self.relu = torch.nn.ReLU()
#         self.sigmoid = torch.nn.Sigmoid()
        
#     def sample_prior(self, s):
#         if self.training:
#             m = torch.zeros((s.data.shape[0], representation_size))
#             std = torch.ones((s.data.shape[0], representation_size))
#             d = Variable(torch.normal(m,std)).cuda()
#         else:
#             d = Variable(torch.zeros((s.data.shape[0], representation_size))).cuda()
        
#         return d
    
#     def discriminator(self, x,z):
#         i = torch.cat((x, z), dim=1).cuda()
#         h = self.relu(self.disc_l1(i))
#         return self.disc_l2(h)
    
#     def sample_posterior(self, x):
#         i = torch.cat((x, self.sample_prior(x)), dim=1).cuda()
#         h = self.relu(self.enc_l1(i))
#         return self.enc_l2(h)
    
#     def decoder(self, z):
#         i = self.relu(self.gen_l1(z))
#         h = self.sigmoid(self.gen_l2(i))
#         return h
    
#     def forward(self, x):
#         z_p = self.sample_prior(x)
        
#         z_q = self.sample_posterior(x)
#         log_d_prior = self.discriminator(x, z_p)
#         log_d_posterior = self.discriminator(x, z_q)
#         disc_loss = torch.mean(
#             torch.nn.functional.binary_cross_entropy_with_logits(
#             log_d_posterior, torch.ones_like(log_d_posterior)
#         )
#         + torch.nn.functional.binary_cross_entropy_with_logits(
#             log_d_prior, torch.zeros_like(log_d_prior))
#         )
        
#         x_recon = self.decoder(z_q)
#         recon_liklihood = -torch.nn.functional.binary_cross_entropy(
#                                                 x_recon, x)*x.data.shape[0]
        
#         gen_loss = torch.mean(log_d_posterior)-torch.mean(recon_liklihood)
        
#         return disc_loss, gen_loss

In [5]:
img_size = 64
channels = 3
latent_dim = 64

In [6]:
class AVB(nn.Module):
    def __init__(self):
        super(AVB, self).__init__()
        
        # height and width of downsampled image
        self.ds_size = img_size // 2**4
        
        self.gen_proj = nn.Linear(latent_dim, 256*self.ds_size**2)
        self.gen_blocks = nn.Sequential(
            nn.BatchNorm2d(256),
#             nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(256, 256, 3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(256, 128, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            
            nn.BatchNorm2d(128, 0.8),
#             nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(128, 64, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            
            nn.BatchNorm2d(64, 0.8),
#             nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(64, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(64, 32, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            
            nn.BatchNorm2d(32, 0.8),
#             nn.Upsample(scale_factor=2),
            nn.ConvTranspose2d(32, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(32, channels, 3, stride=1, padding=1),
            nn.Sigmoid()
        )
        
        self.enc_proj = nn.Linear(latent_dim, img_size**2)
        self.enc_blocks = nn.Sequential(
            nn.Conv2d(channels+1, 32, 3, 1, 1),
            nn.Conv2d(32, 32, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
#             nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(32, 0.8),
            
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.Conv2d(64, 64, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
#             nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(64, 0.8),
            
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.Conv2d(128, 128, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
#             nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(128, 0.8),
            
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.Conv2d(256, 256, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
#             nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(256, 0.8)
        )
        self.enc_layer = nn.Linear(256*self.ds_size**2, latent_dim)
        
        self.dis_proj = nn.Linear(latent_dim, img_size**2)
        self.dis_blocks = nn.Sequential(
            nn.Conv2d(channels+1, 32, 3, 1, 1),
            nn.Conv2d(32, 32, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
#             nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(32, 0.8),
            
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.Conv2d(64, 64, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
#             nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(64, 0.8),
            
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.Conv2d(128, 128, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
#             nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(128, 0.8),
            
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.Conv2d(256, 256, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
#             nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(256, 0.8)
        )
        self.dis_layer = nn.Sequential(
            nn.Linear(256*self.ds_size**2, 1),
            nn.Sigmoid()
        )
        
    def sample_prior(self, s):
        if self.training:
            m = torch.zeros((s.data.shape[0], latent_dim))
            std = torch.ones((s.data.shape[0], latent_dim))
            d = Variable(torch.normal(m,std))
        else:
            d = Variable(torch.zeros((s.data.shape[0], latent_dim)))
        return d.cuda()
    
    def discriminator(self, x,z):
        z_proj = self.dis_proj(z)
        z_proj = z_proj.view(z_proj.shape[0], 1, img_size, img_size)
        i = torch.cat((x, z_proj), dim=1)
        h = self.dis_blocks(i)
        h = h.view(h.shape[0], 256*self.ds_size**2)
        out = self.dis_layer(h)
        return out
    
    def sample_posterior(self, x):
        prior_proj = self.enc_proj(self.sample_prior(x))
        prior_proj = prior_proj.view(prior_proj.shape[0], 1, img_size, img_size)
        i = torch.cat((x, prior_proj), dim=1)
        h = self.enc_blocks(i)
        h = h.view(h.shape[0], 256*self.ds_size**2)
        out = self.enc_layer(h)
        return out
    
    def decoder(self, z):
        z_proj = self.gen_proj(z)
        z_proj = z_proj.view(z_proj.shape[0], 256, self.ds_size, self.ds_size)
        out = self.gen_blocks(z_proj)
        return out
    
    def forward(self, x):
        z_p = self.sample_prior(x)
        
        z_q = self.sample_posterior(x)
        log_d_prior = self.discriminator(x, z_p)
        log_d_posterior = self.discriminator(x, z_q)
        dis_loss = torch.mean(
            torch.nn.functional.binary_cross_entropy_with_logits(
            log_d_posterior, torch.ones_like(log_d_posterior)
        )
        + torch.nn.functional.binary_cross_entropy_with_logits(
            log_d_prior, torch.zeros_like(log_d_prior))
        )
        
        x_recon = self.decoder(z_q)
        recon_likelihood = -torch.nn.functional.binary_cross_entropy(
                                                x_recon, x)*x.data.shape[0]
        
        gen_loss = torch.mean(log_d_posterior)-torch.mean(recon_likelihood)
        
        return dis_loss, gen_loss

In [7]:
model = AVB().cuda()

In [8]:
print(model)

AVB(
  (gen_proj): Linear(in_features=64, out_features=4096, bias=True)
  (gen_blocks): Sequential(
    (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): ConvTranspose2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (2): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (6): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (9): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (10): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11)

In [9]:
dis_params = []
gen_params = []
for name, param in model.named_parameters():
    if 'dis' in name:
        dis_params.append(param)
    else:
        gen_params.append(param)

In [10]:
dis_optimizer = torch.optim.Adam(dis_params, lr=1e-3)
gen_optimizer = torch.optim.Adam(gen_params, lr=1e-3)

In [11]:
dataroot = "/home/santiago/Downloads/celebA/"

batch_size = 128
workers = 4
dataset = datasets.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.CenterCrop(128),
                               transforms.Resize(img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=int(workers))

In [12]:
fixed_noise = torch.randn(batch_size, latent_dim, device=device)

In [13]:
log_interval = 1
sample_interval = 500
epochs = 100

In [14]:
batches_done = 0

In [15]:
with experiment.train():
    for epoch in range(epochs):
        model.train()

        for i, data in enumerate(dataloader, 0):
            data = Variable(data[0].cuda(), requires_grad=False)

            dis_loss, gen_loss = model(data)

            gen_optimizer.zero_grad()
            gen_loss.backward(retain_graph=True)
            gen_optimizer.step()

            dis_optimizer.zero_grad()
            dis_loss.backward(retain_graph=True)
            dis_optimizer.step()

            if (i % log_interval == 0) and (epoch % 1 == 0):
                print('Train Epoch: {} [{}/{}]\td_loss: {:.6f}\tg_loss: {:.6f}'.format(
                    epoch, i * batch_size, len(dataset),
                    dis_loss.data[0] / len(data), gen_loss.data[0] / len(data)), end='\r', flush=True)
                experiment.log_metric("d_loss", dis_loss.data[0] / len(data), step=batches_done)
                experiment.log_metric("g_loss", gen_loss.data[0] / len(data), step=batches_done)

            if (i % sample_interval == 0) and (epoch % 1 == 0):
                vutils.save_image(data,
                            '../avb/images/real_samples.png',
                            normalize=True)
                fake = model.decoder(fixed_noise)
                vutils.save_image(fake.detach(),
                            '../avb/images/fake_samples_step_%03d.png' % batches_done,
                            normalize=True)
                # do checkpointing
                torch.save(model.state_dict(), '../avb/checkpoints/avb_step_%d.pth' % batches_done)
                torch.save(gen_optimizer.state_dict(), '../avb/checkpoints/gen_opt_step_%d.pth' % batches_done)
                torch.save(dis_optimizer.state_dict(), '../avb/checkpoints/dis_opt_step_%d.pth' % batches_done)

            batches_done += 1
        print("Epoch {} done!".format(epoch))

Train Epoch: 0 [0/202599]	d_loss: 0.011315	g_loss: 0.716230



Epoch 0 done!0 [202496/202599]	d_loss: 0.009771	g_loss: -13.462034
Epoch 1 done!1 [202496/202599]	d_loss: 0.009771	g_loss: -14.289164
Epoch 2 done!2 [202496/202599]	d_loss: 0.009771	g_loss: -13.746735
Epoch 3 done!3 [202496/202599]	d_loss: 0.009771	g_loss: -14.212850
Epoch 4 done!4 [202496/202599]	d_loss: 0.009771	g_loss: -13.454019
Epoch 5 done!5 [202496/202599]	d_loss: 0.009771	g_loss: -13.627545
Epoch 6 done!6 [202496/202599]	d_loss: 0.009771	g_loss: -13.152081
Epoch 7 done!7 [202496/202599]	d_loss: 0.009771	g_loss: -13.807567
Epoch 8 done!8 [202496/202599]	d_loss: 0.009771	g_loss: -13.210717
Epoch 9 done!9 [202496/202599]	d_loss: 0.009771	g_loss: -12.999068
Epoch 10 done!0 [202496/202599]	d_loss: 0.009771	g_loss: -13.798413
Epoch 11 done!1 [202496/202599]	d_loss: 0.009771	g_loss: -13.287612
Epoch 12 done!2 [202496/202599]	d_loss: 0.009771	g_loss: -13.218492
Epoch 13 done!3 [202496/202599]	d_loss: 0.009771	g_loss: -12.596121
Epoch 14 done!4 [202496/202599]	d_loss: 0.009771	g_loss: -

In [16]:
vutils.save_image(data,
            '../avb/images/real_samples.png',
            normalize=True)
fake = model.decoder(fixed_noise)
vutils.save_image(fake.detach(),
            '../avb/images/fake_samples_step_%03d.png' % batches_done,
            normalize=True)
# do checkpointing
torch.save(model.state_dict(), '../avb/checkpoints/avb_step_%d.pth' % batches_done)
torch.save(gen_optimizer.state_dict(), '../avb/checkpoints/gen_opt_step_%d.pth' % batches_done)
torch.save(dis_optimizer.state_dict(), '../avb/checkpoints/dis_opt_step_%d.pth' % batches_done)

In [None]:
# END

In [None]:
def train(epoch, log_interval=1, sample_interval=100):
    model.train()
    
    for i, data in enumerate(dataloader, 0):
        data = Variable(data[0].cuda(), requires_grad=False)
        
        dis_loss, gen_loss = model(data)
        
        gen_optimizer.zero_grad()
        gen_loss.backward(retain_graph=True)
        gen_optimizer.step()
        
        dis_optimizer.zero_grad()
        dis_loss.backward(retain_graph=True)
        dis_optimizer.step()
        
        if (i % log_interval == 0) and (epoch % 1 == 0):
            print('Train Epoch: {} [{}/{}]\td_loss: {:.6f}\tg_loss: {:.6f}'.format(
                epoch, i * batch_size, len(dataset),
                dis_loss.data[0] / len(data), gen_loss.data[0] / len(data)), flush=True)
        
        if (i % sample_interval == 0) and (epoch % 1 == 0):
            vutils.save_image(data,
                        '../avb/images/real_samples.png',
                        normalize=True)
            fake = model.decoder(fixed_noise)
            vutils.save_image(fake.detach(),
                        '../avb/images/fake_samples_step_%03d.png' % batches_done,
                        normalize=True)
            # do checkpointing
            torch.save(model.state_dict(), '../avb/checkpoints/avb_step_%d.pth' % batches_done)
            torch.save(gen_optimizer.state_dict(), '../avb/checkpoints/gen_opt_step_%d.pth' % batches_done)
            torch.save(dis_optimizer.state_dict(), '../avb/checkpoints/dis_opt_step_%d.pth' % batches_done)
        
        batches_done += 1
#     ind = np.arange(x.shape[0])
#     for i in range(batches_per_epoch):
#         data = torch.from_numpy(x[np.random.choice(ind, size=batch_size)])
#         data = Variable(data.cuda(), requires_grad=False)
        
        
#         discrim_loss, gen_loss= model(data)
        
#         gen_optimizer.zero_grad()
#         gen_loss.backward(retain_graph=True)
#         gen_optimizer.step()
        
#         disc_optimizer.zero_grad()
#         discrim_loss.backward(retain_graph=True)
#         disc_optimizer.step()
#         if (i % log_interval == 0) and (epoch % 1 ==0):
#             #Print progress
#             print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}\tLoss: {:.6f}'.format(
#                 epoch, i * batch_size, batch_size*batches_per_epoch,
#                 discrim_loss.data[0] / len(data), gen_loss.data[0] / len(data)))

#     print('====> Epoch: {} done!'.format(
#           epoch))

In [None]:
epochs = 100

In [None]:
for epoch in range(epochs):
    train(epoch)

In [None]:
for epoch in range(1, 15):
    train(epoch)

In [None]:
data = Variable(torch.from_numpy(x), requires_grad=False).cuda()

model.train()
zs = model.sample_posterior(data).data.cpu().numpy()

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.scatter(zs[:,0], zs[:, 1], c=y)

In [None]:
data = Variable(torch.from_numpy(x), requires_grad=False).cuda()
model.eval()
zs = model.sample_posterior(data).data.cpu().numpy()

plt.scatter(zs[:,0], zs[:, 1], c=y)

In [None]:
test_point = np.array([0.5, 0.6], dtype=np.float32).reshape(1,-1)
test_point = Variable(torch.from_numpy(test_point), requires_grad=False).cuda()
s = model.decoder(test_point)
s.data