diff --git a/pytorch_Pix2Pix_cGAN.py b/pytorch_Pix2Pix_cGAN.py
new file mode 100644
index 0000000000..96c996bff0
--- /dev/null
+++ b/pytorch_Pix2Pix_cGAN.py
@@ -0,0 +1,361 @@
+"""
+    This is the code for the Pix2Pix framework: https://arxiv.org/abs/1611.07004
+
+    The basic idea of Pix2Pix is to use a conditional GAN (cGAN) to train a model
+    to translate one image representation into another.
+    E.g.: satellite -> map; original -> cartoon; scene day -> scene night; etc.
+    => the output is "conditioned" on the input image
+
+    Some details about the framework:
+    1. Training framework: Generative Adversarial Network (GAN)
+       + Input: original image I1
+       + Output: translated image I2 (size(I1) = size(I2))
+    2. Generator: U-Net
+    3. Discriminator: Convolutional Neural Network binary classifier
+"""
+
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchvision
+from torchvision import datasets, transforms
+
+"""
+    The Generator is a U-Net 256 with skip connections between Encoder and Decoder
+"""
+class generator(nn.Module):
+    def __init__(self, ngpu):
+        super(generator, self).__init__()
+        self.ngpu = ngpu
+
+        """
+        ===== Encoder =====
+
+        * The Encoder has the following architecture:
+            0) Inp3
+            1) C64
+            2) Leaky, C128, Norm
+            3) Leaky, C256, Norm
+            4) Leaky, C512, Norm
+            5) Leaky, C512, Norm
+            6) Leaky, C512, Norm
+            7) Leaky, C512
+
+        * The structure of one Encoder block is:
+            1) LeakyReLU (on the previous layer)
+            2) Conv2D
+            3) BatchNorm
+
+        where Conv2D has kernel_size=4, stride=2, padding=1 for all layers
+        """
+        self.encoder1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False)
+
+        self.encoder2 = nn.Sequential(
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(128)
+        )
+
+        self.encoder3 = nn.Sequential(
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(256)
+        )
+
+        self.encoder4 = nn.Sequential(
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(512)
+        )
+
+        self.encoder5 = nn.Sequential(
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(512)
+        )
+
+        self.encoder6 = nn.Sequential(
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(512)
+        )
+
+        self.encoder7 = nn.Sequential(
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False)
+        )
+
+        """
+        ===== Decoder =====
+
+        * The Decoder has the following architecture:
+            1) ReLU (on the latent space), DC512, Norm, Drop 0.5, Skip
+            2) ReLU, DC512, Norm, Drop 0.5, Skip
+            3) ReLU, DC512, Norm, Drop 0.5, Skip
+            4) ReLU, DC256, Norm, Skip
+            5) ReLU, DC128, Norm, Skip
+            6) ReLU, DC64, Norm, Skip
+            7) ReLU, DC3, Tanh()
+
+        * Note: Dropout is applied only in the first 3 Decoder layers
+
+        * The structure of one Decoder block is:
+            1) ReLU (on the previous layer)
+            2) ConvTranspose2D
+            3) BatchNorm
+            4) Dropout
+            5) Skip connection
+
+        where ConvTranspose2D has kernel_size=4, stride=2, padding=1 for all layers
+        """
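+        # Shape walkthrough for a 256x256 input (derived from the layers above
+        # and below, for orientation): each encoder block halves the spatial
+        # size, 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2, and each decoder
+        # block doubles it back to 256. Because forward() concatenates every
+        # skip connection along the channel dimension (dim=1), decoder blocks
+        # 2-7 receive twice the channels their predecessor produced, e.g.
+        # decoder2 sees 512 (from decoder1) + 512 (from e6) = 1024 channels,
+        # hence the in_channels=512*2 below.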
+        self.decoder1 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.Dropout(0.5)
+        )
+        # skip connection in forward()
+
+        self.decoder2 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=512*2, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.Dropout(0.5)
+        )
+        # skip connection in forward()
+
+        self.decoder3 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=512*2, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.Dropout(0.5)
+        )
+        # skip connection in forward()
+
+        self.decoder4 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=512*2, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(256)
+        )
+
+        self.decoder5 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=256*2, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(128)
+        )
+
+        self.decoder6 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=128*2, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(64)
+        )
+
+        self.decoder7 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=64*2, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.Tanh()
+        )
+
+    def forward(self, x):
+        e1 = self.encoder1(x)
+        e2 = self.encoder2(e1)
+        e3 = self.encoder3(e2)
+        e4 = self.encoder4(e3)
+        e5 = self.encoder5(e4)
+        e6 = self.encoder6(e5)
+
+        latent_space = self.encoder7(e6)
+
+        d1 = torch.cat([self.decoder1(latent_space), e6], dim=1)
+        d2 = torch.cat([self.decoder2(d1), e5], dim=1)
+        d3 = torch.cat([self.decoder3(d2), e4], dim=1)
+        d4 = torch.cat([self.decoder4(d3), e3], dim=1)
+        d5 = torch.cat([self.decoder5(d4), e2], dim=1)
+        d6 = torch.cat([self.decoder6(d5), e1], dim=1)
+
+        out = self.decoder7(d6)
+
+        return out
+
+"""
+    The Discriminator is a binary classifier with a CNN (PatchGAN) architecture.
+    It takes the input image and the (real or generated) target concatenated
+    along the channel dimension, hence in_channels=3*2 in the first layer.
+"""
+class discriminator(nn.Module):
+    def __init__(self, ngpu):
+        super(discriminator, self).__init__()
+        self.ngpu = ngpu
+
+        self.structure = nn.Sequential(
+            nn.Conv2d(in_channels=3*2, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        return self.structure(x)
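+
+"""
+    Optional shape sanity check (illustrative only, not executed by this script).
+    The discriminator is fully convolutional, so instead of a single scalar it
+    outputs a grid of per-patch real/fake probabilities (a PatchGAN); for a
+    256x256 input pair the grid is 30x30:
+
+    >>> G, D = generator(ngpu=1), discriminator(ngpu=1)
+    >>> x = torch.randn(1, 3, 256, 256)           # dummy input image
+    >>> y = G(x)                                  # translated image
+    >>> tuple(y.shape)
+    (1, 3, 256, 256)
+    >>> tuple(D(torch.cat([x, y], dim=1)).shape)  # (input, output) pair
+    (1, 1, 30, 30)
+"""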
+
+"""
+    Weight initializer: Conv weights ~ N(mean=0.0, std=0.02),
+    BatchNorm weights ~ N(1.0, 0.02), BatchNorm biases = 0
+"""
+def weights_init(m):
+    name = m.__class__.__name__
+
+    if(name.find("Conv") > -1):
+        nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif(name.find("BatchNorm") > -1):
+        nn.init.normal_(m.weight.data, 1.0, 0.02)
+        nn.init.constant_(m.bias.data, 0.0)
+
+def show_image(img, title="No title", figsize=(5, 5)):
+    img = img.numpy().transpose(1, 2, 0)
+
+    # undo the Normalize([0.5]*3, [0.5]*3) applied by the data transform
+    mean = np.array([0.5, 0.5, 0.5])
+    std = np.array([0.5, 0.5, 0.5])
+    img = img * std + mean
+    img = np.clip(img, 0, 1)
+
+    plt.figure(figsize=figsize)
+    plt.imshow(img)
+    plt.title(title)
+    plt.savefig(f'{title}.png')
+
+# training parameters
+bs = 1            # batch size of 1, as suggested by the paper
+lr = 0.0002
+beta1 = 0.5
+beta2 = 0.999
+NUM_EPOCHS = 200
+ngpu = 1
+L1_lambda = 100
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+# data loaders: each image in the "maps" dataset is a 256x512 pair, with the
+# satellite photo in the left half and the map in the right half
+data_dir = "maps"
+data_transform = transforms.Compose([
+    transforms.Resize((256, 512)),
+    transforms.CenterCrop((256, 512)),
+    transforms.RandomVerticalFlip(p=0.5),
+    transforms.ToTensor(),
+    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+])
+dataset_train = datasets.ImageFolder(root=os.path.join(data_dir, "train"), transform=data_transform)
+dataset_val = datasets.ImageFolder(root=os.path.join(data_dir, "val"), transform=data_transform)
+dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=bs, shuffle=True, num_workers=0)
+dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=24, shuffle=True, num_workers=0)
+
+# networks
+model_G = generator(ngpu)
+if(device.type == "cuda" and ngpu > 1):
+    model_G = nn.DataParallel(model_G, list(range(ngpu)))
+model_G.apply(weights_init)
+model_G.to(device)
+
+model_D = discriminator(ngpu)
+if(device.type == "cuda" and ngpu > 1):
+    model_D = nn.DataParallel(model_D, list(range(ngpu)))
+model_D.apply(weights_init)
+model_D.to(device)
+
+# Binary Cross Entropy loss
+criterion = nn.BCELoss()
+
+# Adam optimizers
+optimizerD = optim.Adam(model_D.parameters(), lr=lr, betas=(beta1, beta2))
+optimizerG = optim.Adam(model_G.parameters(), lr=lr, betas=(beta1, beta2))
+
+for epoch in range(NUM_EPOCHS):
+    print(f"Training epoch {epoch+1}")
+    for images, _ in dataloader_train:
+        # ========= Train Discriminator ===========
+        # Train on real data
+        # Maximize log(D(x,y)) <- maximize D(x,y)
+        model_D.zero_grad()
+
+        inputs = images[:, :, :, :256].to(device)   # input (satellite) half
+        targets = images[:, :, :, 256:].to(device)  # real target (map) half
+
+        real_data = torch.cat([inputs, targets], dim=1).to(device)
+        outputs = model_D(real_data)
+        labels = torch.ones(size=outputs.shape, dtype=torch.float, device=device)  # label "real" data
+
+        lossD_real = 0.5 * criterion(outputs, labels)  # divide the objective by 2 -> slow down D
+        lossD_real.backward()
+
+        # Train on fake data
+        # Maximize log(1-D(x,G(x))) <- minimize D(x,G(x))
+        gens = model_G(inputs).detach()
+
+        fake_data = torch.cat([inputs, gens], dim=1)  # generated image data
+        outputs = model_D(fake_data)
+        labels = torch.zeros(size=outputs.shape, dtype=torch.float, device=device)  # label "fake" data
+
+        lossD_fake = 0.5 * criterion(outputs, labels)  # divide the objective by 2 -> slow down D
+        lossD_fake.backward()
+
+        optimizerD.step()
+
+        # ========= Train Generator twice ============
+        # Maximize log(D(x,G(x))) and minimize the L1 distance to the target
+        for i in range(2):
+            model_G.zero_grad()
+
+            gens = model_G(inputs)
+
+            gen_data = torch.cat([inputs, gens], dim=1)  # concatenated generated data
+            outputs = model_D(gen_data)
+            labels = torch.ones(size=outputs.shape, dtype=torch.float, device=device)
+
+            # cGAN loss + weighted L1 loss (mean absolute error), as in the paper
+            lossG = criterion(outputs, labels) + L1_lambda * torch.abs(gens - targets).mean()
+            lossG.backward()
+            optimizerG.step()
+
+    if(epoch % 5 == 0):
+        torch.save(model_G, "./sat2map_model_G.pth")  # save the Generator
+        torch.save(model_D, "./sat2map_model_D.pth")  # save the Discriminator
+print("Done!")
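+
+# Note: torch.save(model, path) pickles the whole nn.Module, so loading it later
+# requires the generator/discriminator class definitions to be importable. A
+# state_dict-based alternative (a sketch, not what this script does) would be:
+#     torch.save(model_G.state_dict(), "./sat2map_model_G_state.pth")
+#     model_G.load_state_dict(torch.load("./sat2map_model_G_state.pth"))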
+
+"""*******************************************************
+    Generator Evaluation
+*******************************************************"""
+model_G = torch.load("./sat2map_model_G.pth")  # load the trained Generator
+test_imgs, _ = next(iter(dataloader_val))
+
+satellite = test_imgs[:, :, :, :256].to(device)
+maps = test_imgs[:, :, :, 256:].to(device)
+
+gen = model_G(satellite)
+
+satellite = satellite.detach().cpu()
+gen = gen.detach().cpu()
+maps = maps.detach().cpu()
+
+show_image(torchvision.utils.make_grid(satellite, padding=10), title="Pix2Pix - Input Satellite Images", figsize=(50, 50))
+show_image(torchvision.utils.make_grid(gen, padding=10), title="Pix2Pix - Generated Maps", figsize=(50, 50))
+show_image(torchvision.utils.make_grid(maps, padding=10), title="Pix2Pix - Real Maps", figsize=(50, 50))
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000..f03c27b58f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+torch>=0.4.1
+torchvision>=0.2.1
+numpy>=1.14.0
+matplotlib>=2.0.0
+pillow>=5.0.0