In [None]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import matplotlib.pyplot as plt
import pickle

In [None]:
# Mount Google Drive so the dataset and result folders under MyDrive are reachable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
#!unzip /content/drive/MyDrive/STTP/MRCDBlack.zip
!cd /content/drive/MyDrive/STTP/
!ls

drive  sample_data


In [None]:
# ---- Configuration, data pipeline and loss functions ----------------------
use_cuda = torch.cuda.is_available()
nz = 50          # size of the latent vector produced by the encoder
n_l = 5          # NOTE(review): not used by the cells in this notebook
n_channel = 3    # number of image channels
n_disc = 16      # NOTE(review): not used by the cells in this notebook
n_gen = 64       # NOTE(review): not used by the cells in this notebook
nef = 64         # base conv width of the encoder
ndf = 64         # base conv width of the decoder
image_size = 128
#dataroot="/home/user_3/Downloads/Ae/MorphAging"
dataroot = "/content/drive/MyDrive/STTP/FaceDataSTTP/"
outf = "/content/drive/MyDrive/STTP/result"
batch_size = 64

# Seed every RNG that affects training: Python's `random`, the CPU torch
# generator (used by the DataLoader shuffle and weight init) and all CUDA
# devices. The seed is printed so a run can be reproduced by hard-coding it.
manual_seed = random.randint(1, 10000)
print("Random seed:", manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
if use_cuda:
    torch.cuda.manual_seed_all(manual_seed)

# vutils.save_image writes through PIL, which does not create missing folders.
os.makedirs(outf, exist_ok=True)

# The decoder ends in Tanh (range [-1, 1]), so pixels are mapped to that same
# range. ImageNet mean/std would give roughly [-2.1, 2.6], which Tanh can't reach.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5),
                         std=(0.5, 0.5, 0.5))])
datafolder = dset.ImageFolder(root=dataroot, transform=transform)
# drop_last=True keeps every batch exactly `batch_size` images, which the
# preallocated training buffer relies on.
dataloader = torch.utils.data.DataLoader(datafolder, shuffle=True, batch_size=batch_size, drop_last=True)

nc = n_channel
# Four stride-2 stages in the encoder/decoder halve/double the spatial size.
assert image_size % 16 == 0, "image_size must be divisible by 16"
out_size = image_size // 16  # 8 for image_size=128

# Loss modules hold no parameters, so they work with tensors on any device.
# They are defined unconditionally so the notebook also runs on CPU.
BCE = nn.BCELoss()
L1 = nn.L1Loss()
CE = nn.CrossEntropyLoss()
mse = nn.MSELoss()

class _Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to an ``nz``-dim latent code.

    Four 4x4 stride-2 convolutions (each followed by ReLU) take the channel
    count nc -> nef -> 2nef -> 4nef -> 8nef and halve the spatial size each
    time. Inputs are expected to be image_size x image_size so that the
    flattened features match the final Linear layer
    (nef * 8 * out_size * out_size -> nz).
    """

    def __init__(self):
        super(_Encoder, self).__init__()
        widths = [nc, nef, nef * 2, nef * 4, nef * 8]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Conv2d(c_in, c_out, 4, 2, padding=1), nn.ReLU(True)]
        self.encoder = nn.Sequential(*stages)
        self.latent = nn.Linear(nef * 8 * out_size * out_size, nz)

    def forward(self, input):
        """Return the (N, nz) latent code for an (N, nc, H, W) batch."""
        flat = self.encoder(input).view(input.size(0), -1)
        return self.latent(flat)

encoder = _Encoder()

class _Decoder(nn.Module):
    """Decoder mapping an ``nz``-dim latent code back to an image batch.

    A Linear+ReLU projects the code to an (ndf*8, out_size, out_size) feature
    map. Four stages of nearest-neighbour 2x upsampling followed by a 3x3 conv
    then grow it to (nc, out_size*16, out_size*16). Every conv is followed by
    ReLU except the last, which uses Tanh, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super(_Decoder, self).__init__()

        self.decoder_dense = nn.Sequential(
            nn.Linear(nz, ndf * 8 * out_size * out_size),
            nn.ReLU(True)
        )

        widths = [ndf * 8, ndf * 4, ndf * 2, ndf, nc]
        last_stage = len(widths) - 2
        stages = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            stages.append(nn.UpsamplingNearest2d(scale_factor=2))
            stages.append(nn.Conv2d(c_in, c_out, 3, padding=1))
            stages.append(nn.Tanh() if idx == last_stage else nn.ReLU(True))
        self.decoder_conv = nn.Sequential(*stages)

    def forward(self, z):
        """Decode an (N, nz) latent batch into (N, nc, H, W) images."""
        n = z.size(0)
        seed_map = self.decoder_dense(z).view(n, ndf * 8, out_size, out_size)
        return self.decoder_conv(seed_map)

decoder = _Decoder()

# Move both networks to the training device, preallocate the buffer each
# batch is copied into, and create one Adam optimizer per network.
device = torch.device('cuda' if use_cuda else 'cpu')
encoder = encoder.to(device)
decoder = decoder.to(device)

# Uninitialized float32 staging buffer (the name `input` is kept because the
# training cell copies every batch into it). `Variable` has been a no-op
# wrapper since PyTorch 0.4, so plain tensors are used.
input = torch.empty(batch_size, nc, image_size, image_size,
                    dtype=torch.float32, device=device)

optimizerE = optim.Adam(encoder.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(decoder.parameters(), lr=0.0002, betas=(0.5, 0.999))
encoder.train()
decoder.train()
# Train the autoencoder with an L1 reconstruction loss for 10 epochs. After each
# epoch, reconstruct a fixed batch (captured on the first iteration) and save it
# as <outf>/reconst_epochNNN.png so progress can be compared across epochs.
if len(dataloader) == 0:
    # With drop_last=True a dataset smaller than batch_size yields no batches,
    # which would leave fixed_img_v undefined below.
    raise RuntimeError("dataloader is empty: need at least batch_size=%d images in %s"
                       % (batch_size, dataroot))

for epoch in range(10):
    for i, (img_data, _) in enumerate(dataloader):
        if epoch == 0 and i == 0:
            # First 8 images of the first batch, tiled 5 times (40 images).
            fixed_img_v = img_data[:8].repeat(5, 1, 1, 1)
            if use_cuda:
                fixed_img_v = fixed_img_v.cuda()

        optimizerE.zero_grad()
        optimizerD.zero_grad()
        # Copy into the preallocated buffer; this also moves the batch to its device.
        input.copy_(img_data)
        latent_z = encoder(input)
        recon = decoder(latent_z)
        mse_l1 = L1(recon, input)  # mean absolute error (name kept for compatibility)
        mse_l1.backward()
        optimizerE.step()
        optimizerD.step()

        if i % 100 == 0:
            vutils.save_image(input, '{}/inputs.png'.format(outf), normalize=True)
            print("epoch %d iter %d  L1 loss: %f" % (epoch + 1, i, mse_l1.item()))

    # Reconstruction snapshot: no autograd graph is needed for visualization.
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        fixed_z = encoder(fixed_img_v)
        fixed_fake = decoder(fixed_z)
    encoder.train()
    decoder.train()
    vutils.save_image(fixed_fake, '%s/reconst_epoch%03d.png' % (outf, epoch + 1), normalize=True)
