Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pytorch_Pix2Pix_cGAN.py: implementation of Pix2Pix with conditional GAN (cGAN) #14

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
361 changes: 361 additions & 0 deletions pytorch_Pix2Pix_cGAN.py
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please ignore this file in this commit and only care about requirements.txt. The correct pytorch_Pix2Pix_cGAN is in the later commit.

Original file line number Diff line number Diff line change
@@ -0,0 +1,361 @@
"""
This is the code for Pix2Pix framework: https://arxiv.org/abs/1611.07004

The basic idea of Pix2Pix is to use conditional GAN (cGAN) to train a model
to translate an image representation to another representation.
E.g: satellite -> map; original -> cartoon; scence day -> scene night; etc
=> the output is "conditioned" on the input image

Some details about the framework
1. Training framework: Generative Adversarial Network (GAN)
+ Input: original image I1
+ Output: translated image I2 (size(I1) = size(I2))
2. Generator: U-Net
3. Discriminator: Convolutional Neural Network Binary Classifier
"""

import os, time
import numpy as np
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable

"""
The Generator is a U-Net 256 with skip connections between Encoder and Decoder
"""
class generator(nn.Module):
def __init__(self, ngpu):
super(generator, self).__init__()
self.ngpu = ngpu

"""
===== Encoder ======

* Encoder has the following architecture:
0) Inp3
1) C64
2) Leaky, C128, Norm
3) Leaky, C256, Norm
4) Leaky, C512, Norm
5) Leaky, C512, Norm
6) Leaky, C512, Norm
7) Leaky, C512

* The structure of 1 encoder block is:
1) LeakyReLU(prev layer)
2) Conv2D
3) BatchNorm

Where Conv2D has kernel_size-4, stride=2, padding=1 for all layers
"""
self.encoder1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False)

self.encoder2 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128)
)

self.encoder3 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
)

self.encoder4 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512)
)

self.encoder5 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512)
)

self.encoder6 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512)
)

self.encoder7 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False)
)

"""
===== Decoder =====
* Decoder has the following architecture:
1) ReLU(from latent space), DC512, Norm, Drop 0.5 - Residual
2) ReLU, DC512, Norm, Drop 0.5, Residual
3) ReLU, DC512, Norm, Drop 0.5, Residual
4) ReLU, DC256, Norm, Residual
5) ReLU, DC128, Norm, Residual
6) ReLU, DC64, Norm, Residual
7) ReLU, DC3, Tanh()

* Note: only apply Dropout in the first 3 Decoder layers

* The structure of each Decoder block is:
1) ReLU(from prev layer)
2) ConvTranspose2D
3) BatchNorm
4) Dropout
5) Skip connection

Where ConvTranpose2D has kernel_size=4, stride=2, padding=1
"""
self.decoder1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.Dropout(0.5)
)
# skip connection in forward()

self.decoder2 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels=512*2, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.Dropout(0.5)
)
# skip connection in forward()

self.decoder3 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels=512*2, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.Dropout(0.5)
)
# skip connection in forward()

self.decoder4 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels=512*2, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
#nn.Dropout(0.5)
)

self.decoder5 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels=256*2, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
#nn.Dropout(0.5)
)

self.decoder6 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels=128*2, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
#nn.Dropout(0.5)
)

self.decoder7 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels=64*2, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh()
)

def forward(self, x):
e1 = self.encoder1(x)
e2 = self.encoder2(e1)
e3 = self.encoder3(e2)
e4 = self.encoder4(e3)
e5 = self.encoder5(e4)
e6 = self.encoder6(e5)

latent_space = self.encoder7(e6)

d1 = torch.cat([self.decoder1(latent_space), e6], dim=1)
d2 = torch.cat([self.decoder2(d1), e5], dim=1)
d3 = torch.cat([self.decoder3(d2), e4], dim=1)
d4 = torch.cat([self.decoder4(d3), e3], dim=1)
d5 = torch.cat([self.decoder5(d4), e2], dim=1)
d6 = torch.cat([self.decoder6(d5), e1], dim=1)

out = self.decoder7(d6)

return out

"""
The Discriminator is the binary classifier with CNN architecture
"""
class discriminator(nn.Module):
def __init__(self, ngpu):
super(discriminator, self).__init__()
self.ngpu = ngpu

self.structure = nn.Sequential(
nn.Conv2d(in_channels=3*2, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(in_channels=64, out_channels= 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=False),
nn.Sigmoid()
)

def forward(self, x):
return self.structure(x)

"""
weight initializer
"""
def weights_init(m):
name = m.__class__.__name__

if(name.find("Conv") > -1):
nn.init.normal_(m.weight.data, 0.0, 0.02) # ~N(mean=0.0, std=0.02)
elif(name.find("BatchNorm") > -1):
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0.0)

def show_image(img, title="No title", figsize=(5,5)):
img = img.numpy().transpose(1,2,0)
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])

img = img * std + mean
np.clip(img, 0, 1)

plt.figure(figsize=figsize)
plt.imshow(img)
plt.title(title)
plt.imsave(f'{title}.png')

# training parameters
NUM_EPOCHS=100
bs=1 # suggested by the paper
lr=0.0002
beta1=0.5
beta2=0.999
NUM_EPOCHS = 200
ngpu = 1
L1_lambda = 100
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# data_loader
data_dir = "maps"
data_transform = transforms.Compose([
transforms.Resize((256, 512)),
transforms.CenterCrop((256, 512)),
transforms.RandomVerticalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
dataset_train = datasets.ImageFolder(root=os.path.join(data_dir, "train"), transform=data_transform)
dataset_val = datasets.ImageFolder(root=os.path.join(data_dir, "val"), transform=data_transform)
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=bs, shuffle=True, num_workers=0)
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=24, shuffle=True, num_workers=0)

# network
model_G = generator(ngpu=1)
if(device == "cuda" and ngpu > 1):
model_G = nn.DataParallel(model_G, list(range(ngpu)))
model_G.apply(weights_init)
model_G.to(device)

model_D = discriminator(ngpu=1)
if(device == "cuda" and ngpu>1):
model_D = torch.DataParallel(model_D, list(range(ngpu)))
model_D.apply(weights_init)
model_D.to(device)

# Binary Cross Entropy loss
criterion = nn.BCELoss()

# Adam optimizer
optimizerD = optim.Adam(model_D.parameters(), lr=lr, betas=(beta1, beta2))
optimizerG = optim.Adam(model_G.parameters(), lr=lr, betas=(beta1, beta2))

for epoch in range(NUM_EPOCHS+1):
print(f"Training epoch {epoch+1}")
for images,_ in iter(dataloader_train):
# ========= Train Discriminator ===========
# Train on real data
# Maximize log(D(x,y)) <- maximize D(x,y)
model_D.zero_grad()

inputs = images[:,:,:,:256].to(device) # input image data
targets = images[:,:,:,256:].to(device) # real targets data

real_data = torch.cat([inputs, targets], dim=1).to(device)
outputs = model_D(real_data) # label "real" data
labels = torch.ones(size = outputs.shape, dtype=torch.float, device=device)

lossD_real = 0.5 * criterion(outputs, labels) # divide the objective by 2 -> slow down D
lossD_real.backward()

# Train on fake data
# Maximize log(1-D(x,G(x))) <- minimize D(x,G(x))
gens = model_G(inputs).detach()

fake_data = torch.cat([inputs, gens], dim=1) # generated image data
outputs = model_D(fake_data)
labels = torch.zeros(size = outputs.shape, dtype=torch.float, device=device) # label "fake" data

lossD_fake = 0.5 * criterion(outputs, labels) # divide the objective by 2 -> slow down D
lossD_fake.backward()

optimizerD.step()

# ========= Train Generator x2 times ============
# maximize log(D(x, G(x)))
for i in range(2):
model_G.zero_grad()

gens = model_G(inputs)

gen_data = torch.cat([inputs, gens], dim=1) # concatenated generated data
outputs = model_D(gen_data)
labels = torch.ones(size = outputs.shape, dtype=torch.float, device=device)

lossG = criterion(outputs, labels) + L1_lambda * torch.abs(gens-targets).sum()
lossG.backward()
optimizerG.step()

if(epoch%5==0):
torch.save(model_G, "./sat2map_model_G.pth") # save Generator's weights
torch.save(model_D, "./sat2map_model_D.pth") # save Discriminator's weights
print("Done!")


"""*******************************************************
Generator Evaluation
*******************************************************"""
model_G = torch.load("./sat2map_model_G.pth")
model_G.apply(weights_init)
test_imgs,_ = next(iter(dataloader_val))

satellite = test_imgs[:,:,:,:256].to(device)
maps = test_imgs[:,:,:,256:].to(device)

gen = model_G(satellite)
#gen = gen[0]

satellite = satellite.detach().cpu()
gen = gen.detach().cpu()
maps = maps.detach().cpu()

show_image(torchvision.utils.make_grid(satellite, padding=10), title="Pix2Pix - Input Satellite Images", figsize=(50,50))
show_image(torchvision.utils.make_grid(gen, padding=10), title="Pix2Pix - Generated Maps", figsize=(50,50))
5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
torch==0.1.12+cu80
torchvision==0.1.8+cu80
matplotlib==1.3.1
imageio==2.2.0
scipy==0.19.1