In [0]:


!pip install torch==1.1.0
!pip install torchvision==0.2.1
import sys
print(sys.version) # python 3.6
import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.utils as vutils
print(torch.__version__) 

%matplotlib inline
import matplotlib.pyplot as plt
import os, time

import itertools
import pickle
import imageio
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from tqdm import tqdm

# You can use whatever display function you want. This is a really simple one that makes decent visualizations
def show_imgs(x,epochs,iterations, new_fig=True):
    grid = vutils.make_grid(x.detach().cpu(), nrow=8, normalize=False, pad_value=0.3)
    grid = grid.transpose(0,2).transpose(0,1) # channels as last dimension
    if new_fig:
        plt.figure()
    plt.imshow(grid.numpy())
    #plt.text(0, -20, 'Epoch: ' + str(epochs) + ', ' + 'Iteration: ' + str(iterations), fontsize=20)
    plt.savefig('/content/drive/My Drive/GAN_stuff/image_' + str(epochs) + '_' + str(iterations) + '.png')

from google.colab import drive
drive.mount('/content/drive')

3.6.8 (default, Oct  7 2019, 12:59:55) 
[GCC 8.3.0]
1.1.0
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:


# helper function to initialize the weights using a normal distribution. 
# this was done in the original work (instead of xavier) and has been shown
# to help GAN performance
def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

class Generator(nn.Module):
    # initializers
    def __init__(self, d=32):
        super().__init__()
        self.conv0 = nn.Conv2d(3, 1024 , 2, 4 ,2 )
        self.conv1 = nn.Conv2d(1024, 1024 , 2, 4 ,2 )
        self.deconv1 = nn.ConvTranspose2d(1024, d*8, 4, 1, 0)
        self.deconv1_bn = nn.BatchNorm2d(d*8)
        self.deconv2 = nn.ConvTranspose2d(d*8, d*4, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(d*4)
        self.deconv3 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(d*2)
        self.deconv4 = nn.ConvTranspose2d(d*2, d, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm2d(d)
        #self.deconv5 = nn.ConvTranspose2d(d, 3, 4, 2, 1)
        self.deconv5 = nn.ConvTranspose2d(d,3,6,2,37)

    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, x):

        #print("G " + str(x.size()))
        x = F.relu(self.conv0(x))
        #print("G " + str(x.size()))
        x = F.relu(self.conv1(x))
        #print("G " + str(x.size()))
        x = F.relu(self.deconv1_bn(self.deconv1(x)))
        #print("G " + str(x.size()))
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        #print("G " + str(x.size()))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        #print("G " + str(x.size()))
        x = F.relu(self.deconv4_bn(self.deconv4(x)))
        #print("G " + str(x.size()))
        x = torch.tanh(self.deconv5(x))
        #print("G " + str(x.size()))
        
        return x

class Discriminator(nn.Module):
    # initializers
    def __init__(self, d=32):
        super().__init__()
        self.conv0 = nn.Conv2d(3, d , 2, 4 ,2 )
        self.conv1 = nn.Conv2d(d, d, 4, 2, 1)
        self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(d*2)
        self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(d*4)
        self.conv4 = nn.Conv2d(d*4, d*8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm2d(d*8)
        self.conv5 = nn.Conv2d(d*8, 1, 4, 1, 0)

    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, x):
        #print("D " + str(x.size()))
        x = F.leaky_relu(self.conv0(x), 0.2)
        #print("D " + str(x.size()))
        x = F.leaky_relu(self.conv1(x), 0.2)
        #print("D " + str(x.size()))
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        #print("D " + str(x.size()))
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        #print("D " + str(x.size()))
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        #print("D " + str(x.size()))
        x = torch.sigmoid(self.conv5(x))
        #print("D " + str(x.size()))
        return x

#####
# instantiate a Generator and Discriminator according to their class definition.
#####
D=Discriminator()
G=Generator()


In [0]:
batch_size = 32
lr = 0.0002
train_epoch = 3

import urllib.request
from zipfile import ZipFile
from torch.utils import data
from os import path
import imageio
img_size = 250

#download the data, and change the filepath
url='https://thleats-bucket.s3.us-east-2.amazonaws.com/CS/celeba-dataset.zip'
url2='https://thleats-bucket.s3.us-east-2.amazonaws.com/celeba_dataset_trans.zip'
location = '/content/celeba-dataset.zip'
location2='/content/celeba_dataset_trans.zip'

if path.exists(location):
  print('already downloaded!')
else:
  print('downloading')
  urllib.request.urlretrieve(url,location)
# Create a ZipFile Object and load sample.zip in it
  with ZipFile(location, 'r') as zipObj:
    # Extract all the contents of zip file in current directory
    zipObj.extractall()

if path.exists(location2):
  print('already downloaded!')
else:
  print('downloading')
  urllib.request.urlretrieve(url2,location2)
# Create a ZipFile Object and load sample.zip in it
  with ZipFile(location2, 'r') as zipObj:
    # Extract all the contents of zip file in current directory
    zipObj.extractall('/content/celeba_dataset_trans/celeba_dataset_trans')



dataset=datasets.ImageFolder(root='/content/img_align_celeba/',
                                      transform=transforms.Compose([transforms.Resize(img_size),
                                      transforms.CenterCrop(img_size),
                                      transforms.ToTensor(),
                                      ]))

dataset2=datasets.ImageFolder(root='/content/celeba_dataset_trans/',
                                      transform=transforms.Compose([transforms.Resize(img_size),
                                      transforms.ToTensor(),
                                      ]))


##### Create the dataloader #####
class Dataset(data.Dataset):
  'Characterizes a dataset for PyTorch'
  def __init__(self,dataset,dataset2):
    'Initialization'
    self.dataset1=dataset
    self.dataset2=dataset2
  def __len__(self):
    'Denotes the total number of samples'
    return len(self.dataset2)
    #return 1024
  def __getitem__(self, index):
    'Generates one sample of data'
    # Select sample
    x,_ = self.dataset1[index+1]
    x2,_=self.dataset2[index] 
    Y = index
    return x, x2, Y

thing=Dataset(dataset,dataset2)
params={'batch_size':batch_size,'shuffle':True}
training_generator=data.DataLoader(thing,**params)

xbatch, x2, _ = iter(training_generator).next()
xbatch.shape
D(xbatch)
D(xbatch).shape

already downloaded!
already downloaded!


torch.Size([32, 1, 1, 1])

In [0]:
G = Generator(32)
D = Discriminator(32)
G.weight_init(mean=0.0, std=0.02)
D.weight_init(mean=0.0, std=0.02)
G = G.cuda()
D = D.cuda()

# Binary Cross Entropy loss
BCE_loss = nn.BCELoss()

# Adam optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

In [0]:
num_iter = 0
fixed_z_ = torch.randn(32,100,1,1) 
collect_x_gen = []
train_epoch=50
import pdb
count=0
for epoch in range(train_epoch):
    D_losses = []
    G_losses = []
    epoch_start_time = time.time()
    for x_, x2_,_ in tqdm(training_generator):
        if count==0:
          fixed_z_=x_
          fixed_z_true=x2_
        ######################### train discriminator D ###############################
        ###############################################################################
        if x_.size()[0]==32:
          D.zero_grad()
          
          mini_batch = x_.size()[0]
          ##Set optimizer grads to zero
          D_optimizer.zero_grad()
          G_optimizer.zero_grad()
          #create a random noise
          #z = torch.randn(mini_batch,100,1,1)
          z=x_
          #create the zeros and ones vector for real and fake
          y_real=torch.ones(x_.size(0)).cuda()
          y_fake=torch.zeros(x_.size(0)).cuda()
          #Pass through discriminiator - train it to recognize real images
          D_result=D(x2_.cuda()).squeeze(-1).squeeze(-1)
          #find the real loss for the discriminator
          D_real_loss=BCE_loss(D_result.squeeze(-1),y_real)
          #pass the noise through the generator
          #pdb.set_trace()
          G_result=G(z.cuda())
          #pass the Generated data through the discriminator
          D_result_G=D(G_result)
          #pdb.set_trace()
          #calculate how well the discriminator does at recognizing fake images
          D_fake_loss=BCE_loss(D_result_G.squeeze(-1).squeeze(-1).squeeze(-1),y_fake)
          #calculate the total loss (real + fake) - basically - how good is the discriminator at seeing real from fake
          D_train_loss=D_real_loss+D_fake_loss
          #backpropagation on teh network
          D_train_loss.backward()
          #treain the network
          D_optimizer.step()
          #record the losses
          D_losses.append(D_train_loss.item())
          #rezero the optimizers
          D_optimizer.zero_grad()
          G_optimizer.zero_grad()

          ######################### train generator G ###############################
          ###############################################################################
          G.zero_grad()
          #create more noise
          #z_new = torch.randn(32,100,1,1)
          #pass the noise through the generator
          G_result_G=G(z.cuda())
          #pass the generated data through the discriminator
          D_result_2=D(G_result_G)
          #find how good the generator is at generating fakes
          G_train_loss=BCE_loss(D_result_2.squeeze(-1).squeeze(-1).squeeze(-1),y_real)
          #calculate the gradients
          G_train_loss.backward()
          #train the network
          G_optimizer.step()    
          #record the stuff
          G_losses.append(G_train_loss.item())
          

          num_iter += 1

      # generate a fixed_z_ image and save
          if num_iter%8==0:
            x_gen = G(fixed_z_.cuda())
            collect_x_gen.append(x_gen.detach().clone())
            epoch_end_time = time.time()
            per_epoch_ptime = epoch_end_time - epoch_start_time

            # print out statistics
            print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), train_epoch, per_epoch_ptime, torch.mean(torch.FloatTensor(D_losses)),
                                                                      torch.mean(torch.FloatTensor(G_losses))))
            
            show_imgs(x_gen,epoch,num_iter)

 11%|█         | 7/65 [00:09<01:18,  1.35s/it]

[1/50] - ptime: 10.79, loss_d: 1.742, loss_g: 0.979


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
 23%|██▎       | 15/65 [00:20<01:07,  1.35s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[1/50] - ptime: 21.83, loss_d: 1.528, loss_g: 1.198


 35%|███▌      | 23/65 [00:31<00:56,  1.35s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[1/50] - ptime: 32.89, loss_d: 1.460, loss_g: 1.335


 48%|████▊     | 31/65 [00:42<00:45,  1.34s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[1/50] - ptime: 43.92, loss_d: 1.428, loss_g: 1.358


 60%|██████    | 39/65 [00:53<00:34,  1.33s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[1/50] - ptime: 54.85, loss_d: 1.403, loss_g: 1.345


 72%|███████▏  | 47/65 [01:04<00:24,  1.34s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[1/50] - ptime: 65.82, loss_d: 1.403, loss_g: 1.305


 85%|████████▍ | 55/65 [01:15<00:13,  1.33s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[1/50] - ptime: 76.78, loss_d: 1.408, loss_g: 1.263


 97%|█████████▋| 63/65 [01:26<00:02,  1.34s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[1/50] - ptime: 87.76, loss_d: 1.415, loss_g: 1.236


100%|██████████| 65/65 [01:28<00:00,  1.03s/it]
 11%|█         | 7/65 [00:09<01:16,  1.33s/it]

[2/50] - ptime: 10.63, loss_d: 1.525, loss_g: 0.945


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
 23%|██▎       | 15/65 [00:20<01:06,  1.33s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[2/50] - ptime: 21.54, loss_d: 1.509, loss_g: 0.945


 35%|███▌      | 23/65 [00:31<00:55,  1.33s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[2/50] - ptime: 32.44, loss_d: 1.498, loss_g: 0.930


 48%|████▊     | 31/65 [00:42<00:45,  1.33s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[2/50] - ptime: 43.35, loss_d: 1.496, loss_g: 0.909


 60%|██████    | 39/65 [00:52<00:34,  1.34s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[2/50] - ptime: 54.31, loss_d: 1.499, loss_g: 0.889


 72%|███████▏  | 47/65 [01:03<00:24,  1.33s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[2/50] - ptime: 65.25, loss_d: 1.491, loss_g: 0.871


 85%|████████▍ | 55/65 [01:14<00:13,  1.34s/it]

[2/50] - ptime: 76.21, loss_d: 1.489, loss_g: 0.864


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
 97%|█████████▋| 63/65 [01:25<00:02,  1.33s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[2/50] - ptime: 87.13, loss_d: 1.482, loss_g: 0.861


100%|██████████| 65/65 [01:27<00:00,  1.03s/it]
 11%|█         | 7/65 [00:09<01:17,  1.33s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[3/50] - ptime: 10.65, loss_d: 1.403, loss_g: 0.823


 23%|██▎       | 15/65 [00:20<01:07,  1.34s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[3/50] - ptime: 21.62, loss_d: 1.419, loss_g: 0.835


 35%|███▌      | 23/65 [00:31<00:56,  1.33s/it]

[3/50] - ptime: 32.56, loss_d: 1.422, loss_g: 0.820


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
 48%|████▊     | 31/65 [00:42<00:45,  1.35s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[3/50] - ptime: 43.66, loss_d: 1.423, loss_g: 0.821


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[3/50] - ptime: 54.56, loss_d: 1.426, loss_g: 0.817


 72%|███████▏  | 47/65 [01:04<00:24,  1.34s/it]

[3/50] - ptime: 65.51, loss_d: 1.426, loss_g: 0.812


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
 85%|████████▍ | 55/65 [01:15<00:13,  1.35s/it]

[3/50] - ptime: 76.53, loss_d: 1.427, loss_g: 0.805


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
 97%|█████████▋| 63/65 [01:26<00:02,  1.34s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[3/50] - ptime: 87.49, loss_d: 1.428, loss_g: 0.800


100%|██████████| 65/65 [01:27<00:00,  1.03s/it]
 11%|█         | 7/65 [00:09<01:16,  1.32s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[4/50] - ptime: 10.62, loss_d: 1.420, loss_g: 0.758


 23%|██▎       | 15/65 [00:20<01:06,  1.33s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[4/50] - ptime: 21.52, loss_d: 1.422, loss_g: 0.772


 35%|███▌      | 23/65 [00:31<00:55,  1.32s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[4/50] - ptime: 32.35, loss_d: 1.425, loss_g: 0.770


 48%|████▊     | 31/65 [00:41<00:45,  1.34s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[4/50] - ptime: 43.32, loss_d: 1.428, loss_g: 0.767


 60%|██████    | 39/65 [00:52<00:34,  1.33s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[4/50] - ptime: 54.25, loss_d: 1.427, loss_g: 0.769


 66%|██████▌   | 43/65 [00:58<00:29,  1.36s/it]

In [0]:
D_result.size()
