In [None]:
import tensorflow as tf
from torchvision import datasets
import pathlib
import torch
#from scipy.misc import imresize
#from PIL import Image
from torchvision import transforms
import numpy as np
import itertools
import imageio
from matplotlib import pyplot as plt

In [None]:
def get_data(data_key):
  """Fetch one of the pix2pix datasets and return the folder it unpacks to.

  data_key may be the dataset name itself or a numeric shortcut:
    0 -> 'facades', 1 -> 'maps', 2 -> 'edges2shoes'.
  Any other value is used verbatim as the dataset name.

  Returns a pathlib.Path pointing at <download dir>/<dataset name>.
  """
  # Translate numeric shortcuts into dataset names; names pass through.
  name = data_key
  for shortcut, dataset in ((0, 'facades'), (1, 'maps'), (2, 'edges2shoes')):
    if name == shortcut:
      name = dataset

  # Keras downloads the archive (and is asked to extract it).
  archive_path = tf.keras.utils.get_file(
      fname=f"{name}.tar.gz",
      origin=f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{name}.tar.gz',
      extract=True)

  # NOTE(review): assumes the archive is extracted next to the downloaded
  # file, into a folder named after the dataset -- confirm for the installed
  # Keras version.
  return pathlib.Path(archive_path).parent / name

In [None]:
def load_data(data_key):
  """Build data loaders for a pix2pix dataset.

  The dataset folder is read twice with ImageFolder, so every sub-folder
  becomes a class. The first copy is reduced to the images of the 'train'
  folder; the second copy keeps every image that is NOT in 'train' (i.e. all
  remaining sub-folders together).

  Args:
    data_key: dataset name or numeric shortcut, forwarded to get_data.

  Returns:
    (train_load, test_load): two DataLoaders with batch size 1 and
    shuffle=True. Images are converted to tensors and normalised with
    mean 0.5 / std 0.5 per channel.

  Raises:
    KeyError: if the dataset folder has no 'train' sub-folder.
  """
  path = get_data(data_key)
  transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
  ])
  train_set = datasets.ImageFolder(path, transform)
  rest_set = datasets.ImageFolder(path, transform)
  ind_train = train_set.class_to_idx['train']

  # Filter each copy in one pass instead of deleting entries while iterating.
  for dataset, keep_train in ((train_set, True), (rest_set, False)):
    kept = [item for item in dataset.imgs if (item[1] == ind_train) == keep_train]
    # Slice-assign so the list object is mutated in place: any other
    # attribute of the dataset that refers to the same list stays in sync.
    dataset.imgs[:] = kept
    # Keep the label list consistent with the filtered samples.
    dataset.targets = [label for _, label in kept]

  train_load = torch.utils.data.DataLoader(train_set, 1, shuffle=True)
  test_load = torch.utils.data.DataLoader(rest_set, 1, shuffle=True)

  return train_load, test_load
  

In [None]:
train_load, test_load = load_data(0)

# Sanity check: walk each loader once and count how many batches it yields.
n = 0
for n, (x, _) in enumerate(train_load, start=1):
  pass

m = 0
for m, (y, _) in enumerate(test_load, start=1):
  pass

print('n: ', n)
print('m: ', m)

Downloading data from http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz
n:  400
m:  206


In [None]:
from re import X
def prepare_image(x_,size=256,order = True,crop= True):
  """Split a side-by-side image pair and apply random crop / flip augmentation.

  The input holds two pictures next to each other along the last (width)
  axis; they are separated at column `size`.

  Works as-is for facades; for maps use size=600; for edges2shoes use
  order=False and crop=False.

  Args:
    x_: 4-D tensor indexed as [batch, channel, row, column].
    size: column at which the pair is split.
    order: if True the right part is returned first and the left part second;
      if False the left part is returned first.
    crop: if True a random 256x256 window is cut from both parts (same window
      for both). When size == 256 the parts are first resized to 286 so that
      there is room to crop.

  Returns:
    (x_, y_): the two parts after augmentation. With probability 0.3 both are
    mirrored left-right.

  Raises:
    ValueError: if crop is True and a part is smaller than 256 in height or
      width (raised by np.random.randint).
  """
  if order:
    y_ = x_[:, :, :, 0:size]
    x_ = x_[:, :, :, size:]
  else:
    y_ = x_[:, :, :, size:]
    x_ = x_[:, :, :, 0:size]

  # Enlarge 256-wide parts so the random crop below has room to move.
  if size == 256 and crop:
    resizer = transforms.Resize(286)
    x_ = resizer(x_)
    y_ = resizer(y_)

  # Random crop: the same 256x256 window is taken from both parts.
  if crop:
    height = x_.size()[2]
    width = x_.size()[3]
    # randint's upper bound is exclusive, so "+ 1" makes the last valid
    # offset reachable and lets an exactly-256 dimension through (offset 0).
    rand1 = np.random.randint(0, height - 256 + 1)
    rand2 = np.random.randint(0, width - 256 + 1)
    x_ = x_[:, :, rand1: 256 + rand1, rand2: 256 + rand2]
    y_ = y_[:, :, rand1: 256 + rand1, rand2: 256 + rand2]

  # Occasionally mirror both parts along the width axis. torch.flip handles
  # every sample in the batch and any spatial size, and keeps values exact.
  if torch.rand(1)[0] < 0.3:
    x_ = torch.flip(x_, dims=[3])
    y_ = torch.flip(y_, dims=[3])

  return x_, y_



In [None]:
def test_prepare_image():
  """Smoke check: run prepare_image over the facades training loader.

  Prints the tensor type of the first returned part for every batch; nothing
  is asserted.
  """
  train_loader, _unused_test_loader = load_data(0)
  for batch, _ in train_loader:
    first_part, _second_part = prepare_image(batch)
    print(first_part.type())


In [None]:
def create_picture(G,x_,y_,epoch,path): 
  """Save a three-panel preview: input, generator output, target.

  Only the first sample of the batch is drawn.

  Args:
    G: callable applied to x_ to produce the generated image batch.
    x_: input batch, indexed [batch, channel, row, column].
    y_: target batch with the same layout.
    epoch: epoch number, written below the panels as 'Epoch <n>'.
    path: file the figure is saved to.
  """
  # Preview only: no autograd graph is needed for this forward pass.
  with torch.no_grad():
    cr_im = G(x_)

  def to_image(batch):
    # First sample, channel-first -> channel-last, then (v + 1) / 2
    # (assumes values lie in [-1, 1] so the result is in [0, 1]).
    return (batch[0].detach().cpu().numpy().transpose(1, 2, 0) + 1) / 2

  # Always one row of three panels, so `ax` is 1-D whatever the batch size.
  size_figure_grid = 3
  fig, ax = plt.subplots(1, size_figure_grid, figsize=(5, 5))
  for axis, batch in zip(ax, (x_, cr_im, y_)):
    axis.imshow(to_image(batch))

  # Draw the epoch label (it was previously built but never shown).
  label = 'Epoch {0}'.format(epoch)
  fig.text(0.5, 0.04, label, ha='center')

  fig.savefig(path)
  # Close this specific figure so repeated calls do not accumulate figures.
  plt.close(fig)


In [None]:
def plot_results(history, root, model, epoch,loss_key):
  """Assemble the per-epoch preview images into a GIF and plot the losses.

  Args:
    history: dict with the lists 'D_losses' and 'G_losses'.
    root: path prefix (string) that file names are appended to; it is
      concatenated directly, so it should end with a path separator.
    model: model name used in the file names.
    epoch: number of epochs; frames 1..epoch are read.
    loss_key: identifier used in the frame folder and GIF file names.

  Reads  <root>Fixed_results/loss_key<loss_key>/<model><e>.png for e = 1..epoch.
  Writes <root><model><loss_key>generate_animation.gif and
         <root><model>train_hist.png.
  """
  frame_prefix = root + 'Fixed_results/loss_key' + str(loss_key) + '/' + model
  images = []
  for e in range(epoch):
    # Use the loop index (1-based) so each epoch's own frame is loaded;
    # using `epoch` here would read the same file on every iteration.
    img_name = frame_prefix + str(e + 1) + '.png'
    images.append(imageio.imread(img_name))
  imageio.mimsave(root + model+ str(loss_key) + 'generate_animation.gif', images, fps=2)

  x = range(len(history['D_losses']))
  y1 = history['D_losses']
  y2 = history['G_losses']

  # Draw on a dedicated figure so leftovers on pyplot's current figure
  # cannot leak into the saved plot.
  fig, ax = plt.subplots()
  ax.plot(x, y1, label='D_loss')
  ax.plot(x, y2, label='G_loss')
  ax.set_xlabel('Iteration')
  ax.set_ylabel('Losses')

  ax.legend(loc=4)
  ax.grid(True)
  fig.tight_layout()
  fig.savefig(root + model + 'train_hist.png')
  plt.close(fig)
