In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
from torchvision.utils import save_image

In [None]:
class Discriminator(nn.Module):
    def __init__(self, nc=1, ld=100):
        super(Discriminator, self).__init__()
        self.seq_z = nn.Sequential(
            nn.Flatten(3),
            
            nn.ConvTranspose2d(ld, 512, 2, 2, 0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.ConvTranspose2d(512, 512, 2, 2, 0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        self.seq_x = nn.Sequential(
            nn.Conv2d(nc, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            #nn.Conv2d(512, 1024, 4, 2, 1),
            #nn.BatchNorm2d(1024),
            #nn.LeakyReLU(0.2, inplace=True),
        )
        
        self.seq_xz = nn.Sequential(
            nn.Conv2d(512, 1, 4),
            nn.Flatten(),
            nn.Sigmoid()
        )
    def forward(self, x, z):
        x = self.seq_x(x)
        z = self.seq_z(z)
        return self.seq_xz(x+z)

In [None]:
class Generator(nn.Module):
    def __init__(self, nc=1, ld=100):
        super(Generator, self).__init__()
        self.seq = nn.Sequential(
            #nn.ConvTranspose2d(ld, 1024, 4, 2, 0),
            #nn.BatchNorm2d(1024),
            #nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(ld, 512, 4, 2, 0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(128, nc, 4, 2, 1),
            nn.Tanh()
        )
    def forward(self, x):
        return self.seq(x)

In [None]:
class Encoder(nn.Module):
    def __init__(self, nc=1, ld=100):
        super(Encoder, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(nc, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            
            #nn.Conv2d(512, 1024, 4, 2, 1),
            #nn.BatchNorm2d(1024),
            #nn.ReLU(inplace=True),
            
            #nn.Conv2d(1024, ld, 4),
            nn.Conv2d(512, ld, 4),
            nn.Tanh()
        )
    def forward(self, x):
        return self.seq(x)

In [None]:
SEED = 0  # global torch seed: drives weight init, latent sampling and loader shuffling
torch.manual_seed(SEED)

In [None]:
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU;
# the hard-coded 'cuda' made every later `.to(device)` fail there.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# RGB images and a 100-dimensional latent code.
num_channels, latent_dim = 3, 100

In [None]:
# Preprocessing: the models' shape arithmetic assumes 32x32 inputs, and
# Normalize((0.5,)*3, (0.5,)*3) maps ToTensor's [0, 1] range to [-1, 1],
# the same range as the Generator's Tanh output.
cifar_transform = Compose([
    Resize(32),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = CIFAR10(root='.', download=True, transform=cifar_transform)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

In [None]:
model_kwargs = dict(nc=num_channels, ld=latent_dim)
# Construction order (G, D, E) determines how the seeded RNG is consumed by
# weight initialisation, so keep it fixed.
G = Generator(**model_kwargs).to(device)
D = Discriminator(**model_kwargs).to(device)
E = Encoder(**model_kwargs).to(device)

In [None]:
# Adversarial term: D already ends in a Sigmoid, so BCE is applied to probabilities.
# Reconstruction term: squared error summed (not averaged) over all elements.
adver_criterion, recon_criterion = (
    nn.BCELoss().to(device),
    nn.MSELoss(reduction='sum').to(device),
)

In [None]:
LR = 0.0002  # shared Adam learning rate (default betas)
# One independent Adam optimizer per network, created in the order D, G, E.
D_optimizer, G_optimizer, E_optimizer = (
    optim.Adam(net.parameters(), lr=LR) for net in (D, G, E)
)

In [None]:
# Fixed batch of 64 latents, drawn uniformly from [-1, 1] like the training prior,
# used to render comparable sample grids after every epoch. Sized by `latent_dim`
# (not a literal 100) so it stays consistent with the models if the config changes.
fixed_latent = (2*torch.rand(64, latent_dim, 1, 1)-1).to(device)
outdir = 'cifar10_output'
os.makedirs(outdir, exist_ok=True)

In [None]:
# Per-iteration loss history: discriminator, generator, encoder, identity/reconstruction.
losses = {name: [] for name in ('D', 'G', 'E', 'I')}

In [None]:
# BiGAN-style training. D scores (image, latent) pairs: encoder pairs (x, E(x))
# are labelled real (1) and generator pairs (G(z), z) fake (0). G and E are both
# adversaries of D. On odd iterations, E and G also receive a reconstruction
# ("identity") update. After each epoch, a sample grid, whole-module checkpoints
# of G/D/E and the pickled loss history are written to `outdir`.
for epoch in range(200):
    for idx, (x, _) in enumerate(dataloader):
        batch_size = x.shape[0]
        x_real = x.to(device)  # loader tensors carry no autograd history; no detach needed

        # --- Train D ---
        D.zero_grad()

        z_real = E(x_real).detach()
        # Uniform prior on [-1, 1], matching the Encoder's Tanh range.
        z_fake = 2*torch.rand(batch_size, latent_dim, 1, 1).to(device)-1
        x_fake = G(z_fake).detach()

        real_pred = D(x_real, z_real)
        fake_pred = D(x_fake, z_fake)

        d_real_target = torch.ones(batch_size, 1).to(device)
        d_fake_target = torch.zeros(batch_size, 1).to(device)

        D_loss = adver_criterion(fake_pred, d_fake_target) + adver_criterion(real_pred, d_real_target)
        D_loss.backward()
        D_optimizer.step()

        # --- Train G: make D call (G(z), z) real ---
        # Gradients that also land in D here are cleared by the next D.zero_grad().
        G.zero_grad()

        z_fake = 2*torch.rand(batch_size, latent_dim, 1, 1).to(device)-1
        x_fake = G(z_fake)
        fake_pred = D(x_fake, z_fake)

        G_loss = adver_criterion(fake_pred, d_real_target)
        G_loss.backward()
        G_optimizer.step()

        # --- Train E: make D call (x, E(x)) fake ---
        # E is D's adversary (mirror image of the G step), so its target is 0.
        # A target of 1 would have E cooperate with D rather than fool it.
        E.zero_grad()

        z_real = E(x_real)
        real_pred = D(x_real, z_real)

        E_loss = adver_criterion(real_pred, d_fake_target)
        E_loss.backward()
        E_optimizer.step()

        # --- Latent identity loss (applied on odd iterations only) ---
        update_identity = idx % 2 == 1
        E.zero_grad()
        G.zero_grad()

        # Build an autograd graph only when I_loss will be backpropagated; on
        # other iterations it is computed purely for logging.
        with torch.set_grad_enabled(update_identity):
            z_real = E(x_real)
            x_recon = G(z_real)
            z_recon = E(x_recon)
            I_loss = recon_criterion(z_recon, z_real) + recon_criterion(x_recon, x_real)

        if update_identity:
            I_loss.backward()  # graph is used once; no need to retain it
            E_optimizer.step()
            G_optimizer.step()

        losses['D'].append(D_loss.item())
        losses['G'].append(G_loss.item())
        losses['E'].append(E_loss.item())
        losses['I'].append(I_loss.item())

    # G outputs Tanh values in [-1, 1]; save_image expects [0, 1] and would clip
    # the negative half, so rescale. No gradients are needed for sampling.
    with torch.no_grad():
        samples = G(fixed_latent)
    save_image(samples * 0.5 + 0.5, f'{outdir}/fixed_{epoch+1}.png')
    torch.save(G, f'{outdir}/G_{epoch+1}.pth')
    torch.save(D, f'{outdir}/D_{epoch+1}.pth')
    torch.save(E, f'{outdir}/E_{epoch+1}.pth')
    with open(f'{outdir}/losses.dat', 'wb') as fp:
        pickle.dump(losses, fp)