From ded224bce2ce3a04be84bfa8705a95336964ab65 Mon Sep 17 00:00:00 2001 From: Anh Nhu Date: Sun, 21 Jan 2024 12:20:21 -0500 Subject: [PATCH 1/2] Add requirements.txt --- pytorch_Pix2Pix_cGAN.py | 263 ++++++++++++++++++++++++++++++++++++++++ requirements.txt | 5 + 2 files changed, 268 insertions(+) create mode 100644 pytorch_Pix2Pix_cGAN.py create mode 100644 requirements.txt diff --git a/pytorch_Pix2Pix_cGAN.py b/pytorch_Pix2Pix_cGAN.py new file mode 100644 index 0000000000..d3ab62ec1c --- /dev/null +++ b/pytorch_Pix2Pix_cGAN.py @@ -0,0 +1,263 @@ +import os, time +import matplotlib.pyplot as plt +import itertools +import pickle +import imageio +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.autograd import Variable + +# G(z) +class generator(nn.Module): + # initializers + def __init__(self, d=128): + super(generator, self).__init__() + self.deconv1 = nn.ConvTranspose2d(100, d*8, 4, 1, 0) + self.deconv1_bn = nn.BatchNorm2d(d*8) + self.deconv2 = nn.ConvTranspose2d(d*8, d*4, 4, 2, 1) + self.deconv2_bn = nn.BatchNorm2d(d*4) + self.deconv3 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1) + self.deconv3_bn = nn.BatchNorm2d(d*2) + self.deconv4 = nn.ConvTranspose2d(d*2, d, 4, 2, 1) + self.deconv4_bn = nn.BatchNorm2d(d) + self.deconv5 = nn.ConvTranspose2d(d, 1, 4, 2, 1) + + # weight_init + def weight_init(self, mean, std): + for m in self._modules: + normal_init(self._modules[m], mean, std) + + # forward method + def forward(self, input): + # x = F.relu(self.deconv1(input)) + x = F.relu(self.deconv1_bn(self.deconv1(input))) + x = F.relu(self.deconv2_bn(self.deconv2(x))) + x = F.relu(self.deconv3_bn(self.deconv3(x))) + x = F.relu(self.deconv4_bn(self.deconv4(x))) + x = F.tanh(self.deconv5(x)) + + return x + +class discriminator(nn.Module): + # initializers + def __init__(self, d=128): + super(discriminator, self).__init__() + self.conv1 = nn.Conv2d(1, d, 4, 2, 1) + self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1) + self.conv2_bn = nn.BatchNorm2d(d*2) + self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1) + self.conv3_bn = nn.BatchNorm2d(d*4) + self.conv4 = nn.Conv2d(d*4, d*8, 4, 2, 1) + self.conv4_bn = nn.BatchNorm2d(d*8) + self.conv5 = nn.Conv2d(d*8, 1, 4, 1, 0) + + # weight_init + def weight_init(self, mean, std): + for m in self._modules: + normal_init(self._modules[m], mean, std) + + # forward method + def forward(self, input): + x = F.leaky_relu(self.conv1(input), 0.2) + x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2) + x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2) + x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2) + x = F.sigmoid(self.conv5(x)) + + return x + +def normal_init(m, mean, std): + if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d): + m.weight.data.normal_(mean, std) + m.bias.data.zero_() + +fixed_z_ = torch.randn((5 * 5, 100)).view(-1, 100, 1, 1) # fixed noise +fixed_z_ = Variable(fixed_z_.cuda(), volatile=True) +def show_result(num_epoch, show = False, save = False, path = 'result.png', isFix=False): + z_ = torch.randn((5*5, 100)).view(-1, 100, 1, 1) + z_ = Variable(z_.cuda(), volatile=True) + + G.eval() + if isFix: + test_images = G(fixed_z_) + else: + test_images = G(z_) + G.train() + + size_figure_grid = 5 + fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5)) + for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)): + ax[i, j].get_xaxis().set_visible(False) + ax[i, j].get_yaxis().set_visible(False) + + for k in range(5*5): + i 
= k // 5 + j = k % 5 + ax[i, j].cla() + ax[i, j].imshow(test_images[k, 0].cpu().data.numpy(), cmap='gray') + + label = 'Epoch {0}'.format(num_epoch) + fig.text(0.5, 0.04, label, ha='center') + plt.savefig(path) + + if show: + plt.show() + else: + plt.close() + +def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'): + x = range(len(hist['D_losses'])) + + y1 = hist['D_losses'] + y2 = hist['G_losses'] + + plt.plot(x, y1, label='D_loss') + plt.plot(x, y2, label='G_loss') + + plt.xlabel('Iter') + plt.ylabel('Loss') + + plt.legend(loc=4) + plt.grid(True) + plt.tight_layout() + + if save: + plt.savefig(path) + + if show: + plt.show() + else: + plt.close() + +# training parameters +batch_size = 128 +lr = 0.0002 +train_epoch = 20 + +# data_loader +img_size = 64 +transform = transforms.Compose([ + transforms.Scale(img_size), + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) +]) +train_loader = torch.utils.data.DataLoader( + datasets.MNIST('data', train=True, download=True, transform=transform), + batch_size=batch_size, shuffle=True) + +# network +G = generator(128) +D = discriminator(128) +G.weight_init(mean=0.0, std=0.02) +D.weight_init(mean=0.0, std=0.02) +G.cuda() +D.cuda() + +# Binary Cross Entropy loss +BCE_loss = nn.BCELoss() + +# Adam optimizer +G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999)) +D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999)) + +# results save folder +if not os.path.isdir('MNIST_DCGAN_results'): + os.mkdir('MNIST_DCGAN_results') +if not os.path.isdir('MNIST_DCGAN_results/Random_results'): + os.mkdir('MNIST_DCGAN_results/Random_results') +if not os.path.isdir('MNIST_DCGAN_results/Fixed_results'): + os.mkdir('MNIST_DCGAN_results/Fixed_results') + +train_hist = {} +train_hist['D_losses'] = [] +train_hist['G_losses'] = [] +train_hist['per_epoch_ptimes'] = [] +train_hist['total_ptime'] = [] +num_iter = 0 + +print('training start!') +start_time = time.time() +for epoch in range(train_epoch): + D_losses = [] + G_losses = [] + epoch_start_time = time.time() + for x_, _ in train_loader: + # train discriminator D + D.zero_grad() + + mini_batch = x_.size()[0] + + y_real_ = torch.ones(mini_batch) + y_fake_ = torch.zeros(mini_batch) + + x_, y_real_, y_fake_ = Variable(x_.cuda()), Variable(y_real_.cuda()), Variable(y_fake_.cuda()) + D_result = D(x_).squeeze() + D_real_loss = BCE_loss(D_result, y_real_) + + z_ = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1) + z_ = Variable(z_.cuda()) + G_result = G(z_) + + D_result = D(G_result).squeeze() + D_fake_loss = BCE_loss(D_result, y_fake_) + D_fake_score = D_result.data.mean() + + D_train_loss = D_real_loss + D_fake_loss + + D_train_loss.backward() + D_optimizer.step() + + # D_losses.append(D_train_loss.data[0]) + D_losses.append(D_train_loss.data[0]) + + # train generator G + G.zero_grad() + + z_ = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1) + z_ = Variable(z_.cuda()) + + G_result = G(z_) + D_result = D(G_result).squeeze() + G_train_loss = BCE_loss(D_result, y_real_) + G_train_loss.backward() + G_optimizer.step() + + G_losses.append(G_train_loss.data[0]) + + num_iter += 1 + + epoch_end_time = time.time() + per_epoch_ptime = epoch_end_time - epoch_start_time + + + print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), train_epoch, per_epoch_ptime, torch.mean(torch.FloatTensor(D_losses)), + torch.mean(torch.FloatTensor(G_losses)))) + p = 'MNIST_DCGAN_results/Random_results/MNIST_DCGAN_' + str(epoch + 1) + '.png' + 
fixed_p = 'MNIST_DCGAN_results/Fixed_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
+    show_result((epoch+1), save=True, path=p, isFix=False)
+    show_result((epoch+1), save=True, path=fixed_p, isFix=True)
+    train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
+    train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))
+    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)
+
+end_time = time.time()
+total_ptime = end_time - start_time
+train_hist['total_ptime'].append(total_ptime)
+
+print("Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), train_epoch, total_ptime))
+print("Training finish!... save training results")
+torch.save(G.state_dict(), "MNIST_DCGAN_results/generator_param.pkl")
+torch.save(D.state_dict(), "MNIST_DCGAN_results/discriminator_param.pkl")
+with open('MNIST_DCGAN_results/train_hist.pkl', 'wb') as f:
+    pickle.dump(train_hist, f)
+
+show_train_hist(train_hist, save=True, path='MNIST_DCGAN_results/MNIST_DCGAN_train_hist.png')
+
+images = []
+for e in range(train_epoch):
+    img_name = 'MNIST_DCGAN_results/Fixed_results/MNIST_DCGAN_' + str(e + 1) + '.png'
+    images.append(imageio.imread(img_name))
+imageio.mimsave('MNIST_DCGAN_results/generation_animation.gif', images, fps=5)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000..f03c27b58f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+torch==0.1.12+cu80
+torchvision==0.1.8+cu80
+matplotlib==1.3.1
+imageio==2.2.0
+scipy==0.19.1

From 47110e334d3ed7f9875f010b995d699a7308b6bc Mon Sep 17 00:00:00 2001
From: Anh Nhu
Date: Sun, 21 Jan 2024 13:01:53 -0500
Subject: [PATCH 2/2] Implement Pix2Pix cGAN framework

---
 pytorch_Pix2Pix_cGAN.py | 570 +++++++++++++++++++++++-----------------
 1 file changed, 334 insertions(+), 236 deletions(-)

diff --git a/pytorch_Pix2Pix_cGAN.py b/pytorch_Pix2Pix_cGAN.py
index d3ab62ec1c..96c996bff0 100644
--- a/pytorch_Pix2Pix_cGAN.py
+++ b/pytorch_Pix2Pix_cGAN.py
@@ -1,4 +1,21 @@
+"""
+    This is the code for the Pix2Pix framework: https://arxiv.org/abs/1611.07004
+
+    The basic idea of Pix2Pix is to use a conditional GAN (cGAN) to train a model
+    to translate one image representation into another,
+    e.g., satellite -> map; original -> cartoon; scene day -> scene night; etc.
+    => the output is "conditioned" on the input image
+
+    Some details about the framework
+    1. Training framework: Generative Adversarial Network (GAN)
+        + Input: original image I1
+        + Output: translated image I2 (size(I1) = size(I2))
+    2. Generator: U-Net
+    3. Discriminator: Convolutional Neural Network Binary Classifier
+"""
+
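+# A compact sketch of the combined objective trained below, where lambda is the
+# L1_lambda = 100 weight defined later (cf. the paper's objective):
+#   G* = arg min_G max_D  E[log D(x, y)] + E[log(1 - D(x, G(x)))] + lambda * E[||y - G(x)||_1]
+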
 import os, time
+import numpy as np
 import matplotlib.pyplot as plt
 import itertools
 import pickle
@@ -7,257 +24,338 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+import torchvision
 from torchvision import datasets, transforms
 from torch.autograd import Variable
 
-# G(z)
+"""
+    The Generator is a U-Net 256 with skip connections between the Encoder and the Decoder
+"""
 class generator(nn.Module):
-    # initializers
-    def __init__(self, d=128):
+    def __init__(self, ngpu):
         super(generator, self).__init__()
-        self.deconv1 = nn.ConvTranspose2d(100, d*8, 4, 1, 0)
-        self.deconv1_bn = nn.BatchNorm2d(d*8)
-        self.deconv2 = nn.ConvTranspose2d(d*8, d*4, 4, 2, 1)
-        self.deconv2_bn = nn.BatchNorm2d(d*4)
-        self.deconv3 = nn.ConvTranspose2d(d*4, d*2, 4, 2, 1)
-        self.deconv3_bn = nn.BatchNorm2d(d*2)
-        self.deconv4 = nn.ConvTranspose2d(d*2, d, 4, 2, 1)
-        self.deconv4_bn = nn.BatchNorm2d(d)
-        self.deconv5 = nn.ConvTranspose2d(d, 1, 4, 2, 1)
-
-    # weight_init
-    def weight_init(self, mean, std):
-        for m in self._modules:
-            normal_init(self._modules[m], mean, std)
-
-    # forward method
-    def forward(self, input):
-        # x = F.relu(self.deconv1(input))
-        x = F.relu(self.deconv1_bn(self.deconv1(input)))
-        x = F.relu(self.deconv2_bn(self.deconv2(x)))
-        x = F.relu(self.deconv3_bn(self.deconv3(x)))
-        x = F.relu(self.deconv4_bn(self.deconv4(x)))
-        x = F.tanh(self.deconv5(x))
-
-        return x
-
+        self.ngpu = ngpu
+
+        """
+            ===== Encoder =====
+
+            * The Encoder has the following architecture:
+                0) Inp3
+                1) C64
+                2) Leaky, C128, Norm
+                3) Leaky, C256, Norm
+                4) Leaky, C512, Norm
+                5) Leaky, C512, Norm
+                6) Leaky, C512, Norm
+                7) Leaky, C512
+
+            * The structure of one encoder block is:
+                1) LeakyReLU(prev layer)
+                2) Conv2D
+                3) BatchNorm
+
+            Where Conv2D has kernel_size=4, stride=2, padding=1 for all layers
+        """
+        self.encoder1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False)
+
+        self.encoder2 = nn.Sequential(
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(128)
+        )
+
+        self.encoder3 = nn.Sequential(
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(256)
+        )
+
+        self.encoder4 = nn.Sequential(
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(512)
+        )
+
+        self.encoder5 = nn.Sequential(
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(512)
+        )
+
+        self.encoder6 = nn.Sequential(
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(512)
+        )
+
+        self.encoder7 = nn.Sequential(
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False)
+        )
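+
+        # Shape bookkeeping for an assumed 256x256 RGB input (these are the
+        # feature maps reused by the skip connections in forward() below):
+        #   e1: 64x128x128, e2: 128x64x64, e3: 256x32x32, e4: 512x16x16,
+        #   e5: 512x8x8, e6: 512x4x4, latent (encoder7): 512x2x2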
+
+        """
+            ===== Decoder =====
+            * The Decoder has the following architecture:
+                1) ReLU(from latent space), DC512, Norm, Drop 0.5, Skip
+                2) ReLU, DC512, Norm, Drop 0.5, Skip
+                3) ReLU, DC512, Norm, Drop 0.5, Skip
+                4) ReLU, DC256, Norm, Skip
+                5) ReLU, DC128, Norm, Skip
+                6) ReLU, DC64, Norm, Skip
+                7) ReLU, DC3, Tanh()
+
+            * Note: Dropout is only applied in the first 3 Decoder layers
+
+            * The structure of each Decoder block is:
+                1) ReLU(from prev layer)
+                2) ConvTranspose2D
+                3) BatchNorm
+                4) Dropout
+                5) Skip connection
+
+            Where ConvTranspose2D has kernel_size=4, stride=2, padding=1
+        """
+        self.decoder1 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.Dropout(0.5)
+        )
+        # skip connection in forward()
+
+        self.decoder2 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=512*2, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.Dropout(0.5)
+        )
+        # skip connection in forward()
+
+        self.decoder3 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=512*2, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.Dropout(0.5)
+        )
+        # skip connection in forward()
+
+        self.decoder4 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=512*2, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(256)
+        )
+
+        self.decoder5 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=256*2, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(128)
+        )
+
+        self.decoder6 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=128*2, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(64)
+        )
+
+        self.decoder7 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(in_channels=64*2, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.Tanh()
+        )
+
+    def forward(self, x):
+        e1 = self.encoder1(x)
+        e2 = self.encoder2(e1)
+        e3 = self.encoder3(e2)
+        e4 = self.encoder4(e3)
+        e5 = self.encoder5(e4)
+        e6 = self.encoder6(e5)
+
+        latent_space = self.encoder7(e6)
+
+        d1 = torch.cat([self.decoder1(latent_space), e6], dim=1)
+        d2 = torch.cat([self.decoder2(d1), e5], dim=1)
+        d3 = torch.cat([self.decoder3(d2), e4], dim=1)
+        d4 = torch.cat([self.decoder4(d3), e3], dim=1)
+        d5 = torch.cat([self.decoder5(d4), e2], dim=1)
+        d6 = torch.cat([self.decoder6(d5), e1], dim=1)
+
+        out = self.decoder7(d6)
+
+        return out
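+
+# Each decoder block after decoder1 takes doubled in_channels (e.g. 512*2) because
+# its input is torch.cat([previous decoder output, mirrored encoder feature], dim=1).
+# A quick sanity-check sketch (illustrative only, not executed here):
+#   g = generator(ngpu=1)
+#   assert g(torch.randn(1, 3, 256, 256)).shape == (1, 3, 256, 256)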
+
+"""
+    The Discriminator is a binary classifier with a CNN architecture
+"""
 class discriminator(nn.Module):
-    # initializers
-    def __init__(self, d=128):
+    def __init__(self, ngpu):
         super(discriminator, self).__init__()
-        self.conv1 = nn.Conv2d(1, d, 4, 2, 1)
-        self.conv2 = nn.Conv2d(d, d*2, 4, 2, 1)
-        self.conv2_bn = nn.BatchNorm2d(d*2)
-        self.conv3 = nn.Conv2d(d*2, d*4, 4, 2, 1)
-        self.conv3_bn = nn.BatchNorm2d(d*4)
-        self.conv4 = nn.Conv2d(d*4, d*8, 4, 2, 1)
-        self.conv4_bn = nn.BatchNorm2d(d*8)
-        self.conv5 = nn.Conv2d(d*8, 1, 4, 1, 0)
-
-    # weight_init
-    def weight_init(self, mean, std):
-        for m in self._modules:
-            normal_init(self._modules[m], mean, std)
-
-    # forward method
-    def forward(self, input):
-        x = F.leaky_relu(self.conv1(input), 0.2)
-        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
-        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
-        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
-        x = F.sigmoid(self.conv5(x))
-
-        return x
-
-def normal_init(m, mean, std):
-    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
-        m.weight.data.normal_(mean, std)
-        m.bias.data.zero_()
-
-fixed_z_ = torch.randn((5 * 5, 100)).view(-1, 100, 1, 1)    # fixed noise
-fixed_z_ = Variable(fixed_z_.cuda(), volatile=True)
-def show_result(num_epoch, show = False, save = False, path = 'result.png', isFix=False):
-    z_ = torch.randn((5*5, 100)).view(-1, 100, 1, 1)
-    z_ = Variable(z_.cuda(), volatile=True)
-
-    G.eval()
-    if isFix:
-        test_images = G(fixed_z_)
-    else:
-        test_images = G(z_)
-    G.train()
-
-    size_figure_grid = 5
-    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
-    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
-        ax[i, j].get_xaxis().set_visible(False)
-        ax[i, j].get_yaxis().set_visible(False)
-
-    for k in range(5*5):
-        i = k // 5
-        j = k % 5
-        ax[i, j].cla()
-        ax[i, j].imshow(test_images[k, 0].cpu().data.numpy(), cmap='gray')
-
-    label = 'Epoch {0}'.format(num_epoch)
-    fig.text(0.5, 0.04, label, ha='center')
-    plt.savefig(path)
-
-    if show:
-        plt.show()
-    else:
-        plt.close()
-
-def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
-    x = range(len(hist['D_losses']))
-
-    y1 = hist['D_losses']
-    y2 = hist['G_losses']
-
-    plt.plot(x, y1, label='D_loss')
-    plt.plot(x, y2, label='G_loss')
-
-    plt.xlabel('Iter')
-    plt.ylabel('Loss')
-
-    plt.legend(loc=4)
-    plt.grid(True)
-    plt.tight_layout()
-
-    if save:
-        plt.savefig(path)
-
-    if show:
-        plt.show()
-    else:
-        plt.close()
+        self.ngpu = ngpu
+
+        self.structure = nn.Sequential(
+            nn.Conv2d(in_channels=3*2, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=False),
+            nn.BatchNorm2d(512),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        return self.structure(x)
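+
+# For a 256x256 input pair (3*2 = 6 channels after concatenating the input image
+# with the target or generated image), this stack outputs a 1x30x30 grid of
+# per-patch real/fake probabilities (the paper's 70x70 PatchGAN) rather than a
+# single scalar per image; the BCE labels in the training loop below are created
+# with this output shape.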
+
+"""
+    weight initializer
+"""
+def weights_init(m):
+    name = m.__class__.__name__
+
+    if(name.find("Conv") > -1):
+        nn.init.normal_(m.weight.data, 0.0, 0.02)   # ~N(mean=0.0, std=0.02)
+    elif(name.find("BatchNorm") > -1):
+        nn.init.normal_(m.weight.data, 1.0, 0.02)
+        nn.init.constant_(m.bias.data, 0.0)
+
+def show_image(img, title="No title", figsize=(5,5)):
+    img = img.numpy().transpose(1,2,0)
+    mean = np.array([0.5, 0.5, 0.5])
+    std = np.array([0.5, 0.5, 0.5])
+
+    img = img * std + mean      # undo the Normalize transform
+    img = np.clip(img, 0, 1)    # np.clip returns a new array, so keep the result
+
+    plt.figure(figsize=figsize)
+    plt.imshow(img)
+    plt.title(title)
+    plt.savefig(f'{title}.png') # save the current figure (plt.imsave would need the array, not just a filename)
 
 # training parameters
-batch_size = 128
-lr = 0.0002
-train_epoch = 20
+bs = 1          # batch size of 1, as suggested by the paper
+lr = 0.0002
+beta1 = 0.5
+beta2 = 0.999
+NUM_EPOCHS = 200
+ngpu = 1
+L1_lambda = 100
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # data_loader
-img_size = 64
-transform = transforms.Compose([
-    transforms.Scale(img_size),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+data_dir = "maps"
+data_transform = transforms.Compose([
+    transforms.Resize((256, 512)),
+    transforms.CenterCrop((256, 512)),
+    transforms.RandomVerticalFlip(p=0.5),   # vertical flip keeps the left/right (input/target) layout intact
+    transforms.ToTensor(),
+    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
 ])
-train_loader = torch.utils.data.DataLoader(
-    datasets.MNIST('data', train=True, download=True, transform=transform),
-    batch_size=batch_size, shuffle=True)
+# each sample is a side-by-side composite: satellite (left half) | map (right half)
+dataset_train = datasets.ImageFolder(root=os.path.join(data_dir, "train"), transform=data_transform)
+dataset_val = datasets.ImageFolder(root=os.path.join(data_dir, "val"), transform=data_transform)
+dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=bs, shuffle=True, num_workers=0)
+dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=24, shuffle=True, num_workers=0)
 
 # network
-G = generator(128)
-D = discriminator(128)
-G.weight_init(mean=0.0, std=0.02)
-D.weight_init(mean=0.0, std=0.02)
-G.cuda()
-D.cuda()
+model_G = generator(ngpu)
+if(device.type == "cuda" and ngpu > 1):     # device is a torch.device, so compare its .type
+    model_G = nn.DataParallel(model_G, list(range(ngpu)))
+model_G.apply(weights_init)
+model_G.to(device)
+
+model_D = discriminator(ngpu)
+if(device.type == "cuda" and ngpu > 1):
+    model_D = nn.DataParallel(model_D, list(range(ngpu)))
+model_D.apply(weights_init)
+model_D.to(device)
 
 # Binary Cross Entropy loss
-BCE_loss = nn.BCELoss()
+criterion = nn.BCELoss()
 
 # Adam optimizer
-G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
-D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
-
-# results save folder
-if not os.path.isdir('MNIST_DCGAN_results'):
-    os.mkdir('MNIST_DCGAN_results')
-if not os.path.isdir('MNIST_DCGAN_results/Random_results'):
-    os.mkdir('MNIST_DCGAN_results/Random_results')
-if not os.path.isdir('MNIST_DCGAN_results/Fixed_results'):
-    os.mkdir('MNIST_DCGAN_results/Fixed_results')
-
-train_hist = {}
-train_hist['D_losses'] = []
-train_hist['G_losses'] = []
-train_hist['per_epoch_ptimes'] = []
-train_hist['total_ptime'] = []
-num_iter = 0
-
-print('training start!')
-start_time = time.time()
-for epoch in range(train_epoch):
-    D_losses = []
-    G_losses = []
-    epoch_start_time = time.time()
-    for x_, _ in train_loader:
-        # train discriminator D
-        D.zero_grad()
-
-        mini_batch = x_.size()[0]
-
-        y_real_ = torch.ones(mini_batch)
-        y_fake_ = torch.zeros(mini_batch)
-
-        x_, y_real_, y_fake_ = Variable(x_.cuda()), Variable(y_real_.cuda()), Variable(y_fake_.cuda())
-        D_result = D(x_).squeeze()
-        D_real_loss = BCE_loss(D_result, y_real_)
-
-        z_ = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1)
-        z_ = Variable(z_.cuda())
-        G_result = G(z_)
-
-        D_result = D(G_result).squeeze()
-        D_fake_loss = BCE_loss(D_result, y_fake_)
-        D_fake_score = D_result.data.mean()
-
-        D_train_loss = D_real_loss + D_fake_loss
-
-        D_train_loss.backward()
-        D_optimizer.step()
-
-        # D_losses.append(D_train_loss.data[0])
-        D_losses.append(D_train_loss.data[0])
-
-        # train generator G
-        G.zero_grad()
-
-        z_ = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1)
-        z_ = Variable(z_.cuda())
-
-        G_result = G(z_)
-        D_result = D(G_result).squeeze()
-        G_train_loss = BCE_loss(D_result, y_real_)
-        G_train_loss.backward()
-        G_optimizer.step()
-
-        G_losses.append(G_train_loss.data[0])
-
-        num_iter += 1
-
-    epoch_end_time = time.time()
-    per_epoch_ptime = epoch_end_time - epoch_start_time
-
-
-    print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), train_epoch, per_epoch_ptime, torch.mean(torch.FloatTensor(D_losses)),
-                                                              torch.mean(torch.FloatTensor(G_losses))))
-    p = 'MNIST_DCGAN_results/Random_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
-    fixed_p = 'MNIST_DCGAN_results/Fixed_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
-    show_result((epoch+1), save=True, path=p, isFix=False)
-    show_result((epoch+1), save=True, path=fixed_p, isFix=True)
-    train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
-    train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))
-    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)
-
-end_time = time.time()
-total_ptime = end_time - start_time
-train_hist['total_ptime'].append(total_ptime)
-
-print("Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), train_epoch, total_ptime))
-print("Training finish!... save training results")
-torch.save(G.state_dict(), "MNIST_DCGAN_results/generator_param.pkl")
-torch.save(D.state_dict(), "MNIST_DCGAN_results/discriminator_param.pkl")
-with open('MNIST_DCGAN_results/train_hist.pkl', 'wb') as f:
-    pickle.dump(train_hist, f)
-
-show_train_hist(train_hist, save=True, path='MNIST_DCGAN_results/MNIST_DCGAN_train_hist.png')
-
-images = []
-for e in range(train_epoch):
-    img_name = 'MNIST_DCGAN_results/Fixed_results/MNIST_DCGAN_' + str(e + 1) + '.png'
-    images.append(imageio.imread(img_name))
-imageio.mimsave('MNIST_DCGAN_results/generation_animation.gif', images, fps=5)
+optimizerD = optim.Adam(model_D.parameters(), lr=lr, betas=(beta1, beta2))
+optimizerG = optim.Adam(model_G.parameters(), lr=lr, betas=(beta1, beta2))
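+
+# Shape of one training iteration (this implementation's choices):
+#   - each batch from dataloader_train is a 256x512 composite; the left 256
+#     columns are the input x, the right 256 columns the target y
+#   - D step: BCE on D(x, y) vs. "real" labels and on D(x, G(x).detach())
+#     vs. "fake" labels, each halved to slow D down relative to G
+#   - G step (run twice per D step): BCE of D(x, G(x)) vs. "real" labels,
+#     plus the L1_lambda-weighted L1 distance between G(x) and y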
+
+for epoch in range(NUM_EPOCHS):
+    print(f"Training epoch {epoch+1}")
+    for images, _ in dataloader_train:
+        # ========= Train Discriminator ===========
+        # Train on real data
+        # Maximize log(D(x,y))  <-  maximize D(x,y)
+        model_D.zero_grad()
+
+        inputs = images[:,:,:,:256].to(device)      # input image data (left half)
+        targets = images[:,:,:,256:].to(device)     # real target data (right half)
+
+        real_data = torch.cat([inputs, targets], dim=1)
+        outputs = model_D(real_data)
+        labels = torch.ones(size=outputs.shape, dtype=torch.float, device=device)   # label "real" data
+
+        lossD_real = 0.5 * criterion(outputs, labels)   # divide the objective by 2 -> slow down D
+        lossD_real.backward()
+
+        # Train on fake data
+        # Maximize log(1-D(x,G(x)))  <-  minimize D(x,G(x))
+        gens = model_G(inputs).detach()     # detach so this step only updates D
+
+        fake_data = torch.cat([inputs, gens], dim=1)    # generated image data
+        outputs = model_D(fake_data)
+        labels = torch.zeros(size=outputs.shape, dtype=torch.float, device=device)  # label "fake" data
+
+        lossD_fake = 0.5 * criterion(outputs, labels)   # divide the objective by 2 -> slow down D
+        lossD_fake.backward()
+
+        optimizerD.step()
+
+        # ========= Train Generator x2 times ============
+        # Maximize log(D(x, G(x)))
+        for i in range(2):
+            model_G.zero_grad()
+
+            gens = model_G(inputs)
+
+            gen_data = torch.cat([inputs, gens], dim=1)     # concatenated generated data
+            outputs = model_D(gen_data)
+            labels = torch.ones(size=outputs.shape, dtype=torch.float, device=device)
+
+            # BCE against "real" labels plus the mean L1 distance to the target
+            # (mean, not sum, so that L1_lambda = 100 weights it as in the paper)
+            lossG = criterion(outputs, labels) + L1_lambda * torch.abs(gens - targets).mean()
+            lossG.backward()
+            optimizerG.step()
+
+    if(epoch % 5 == 0):
+        torch.save(model_G, "./sat2map_model_G.pth")    # save the full Generator module
+        torch.save(model_D, "./sat2map_model_D.pth")    # save the full Discriminator module
+print("Done!")
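+
+# Hypothetical single-image usage of the trained generator (a sketch using the
+# names defined above; not executed here):
+#   img, _ = dataset_val[0]                                   # one 256x512 composite
+#   with torch.no_grad():
+#       fake_map = model_G(img[None, :, :, :256].to(device))  # left half -> generated map
+#   show_image(fake_map[0].cpu(), title="sat2map_example")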
Satellite Images", figsize=(50,50)) +show_image(torchvision.utils.make_grid(gen, padding=10), title="Pix2Pix - Generated Maps", figsize=(50,50))