In [1]:
# https://machinelearningmastery.com/how-to-develop-a-pix2pix-gan-for-image-to-image-translation/

In [2]:
# from models import define_gan

In [3]:
from models import *
from datagen import *
from utils import *
import os

Using TensorFlow backend.


In [20]:
# train pix2pix models
def train(d_model, g_model, gan_model, all_paths, path_labels, n_epochs=100, batch_size=1, path_results='model_performance/'):
    if not os.path.isdir(path_results):
        os.mkdir(path_results)
    
    # determine the output square shape of the discriminator
    n_patch = (d_model.output_shape[1],d_model.output_shape[2])

    bat_per_epo = int(len(all_paths)/batch_size)
    n_steps = bat_per_epo * n_epochs
    #n_steps = 5
    print('Batch per epochs = ', bat_per_epo)
    print('Total Steps = ', n_steps)
    # manually enumerate epochs
    for i in range(n_steps):
        # select a batch of real samples
        X_realA, X_realB, y_real = next(my_datagen(all_paths,path_labels, patch_shape = n_patch, batch_size=batch_size))
        # generate a batch of fake samples
        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
        
#         print(X_realA.shape)
#         print(X_realB.shape)
#         print(y_real.shape)
#         print('-----------')
#         print(X_fakeB.shape)
#         print(y_fake.shape)
        
        
        # update discriminator for real samples
        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
        # update discriminator for generated samples
        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
        # update the generator
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
        # summarize performance
        print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))
        # summarize model performance
        if i%bat_per_epo==0 and i>1:
            # select a sample of input images
            X_realA, X_realB, _ = next(my_datagen(all_paths, path_labels, patch_shape=n_patch, batch_size=3))
            # generate a batch of fake samples
            X_fakeB, _ = generate_fake_samples(g_model, X_realA, patch_shape=(1,1))
            summarize_performance(path_results, i, g_model,  X_realA, X_realB, X_fakeB, n_samples=3)


In [21]:
path_images = 'maps_dataset_subset/train_images/'
path_labels = 'maps_dataset_subset/train_maps/'

all_paths = get_paths(path_images)

image_shape = (256,256,3)

d_model, g_model, gan_model = define_gan(image_shape)



In [None]:
train(d_model, g_model, gan_model, all_paths, path_labels,batch_size=4)

Batch per epochs =  12
Total Steps =  1200
>1, d1[0.300] d2[1.019] g[84.720]
>2, d1[0.304] d2[0.768] g[77.527]
>3, d1[0.311] d2[0.591] g[80.461]
>4, d1[0.325] d2[0.466] g[69.512]
>5, d1[0.457] d2[0.446] g[65.270]
>6, d1[0.343] d2[0.387] g[69.373]
>7, d1[0.346] d2[0.370] g[63.384]
>8, d1[0.531] d2[0.416] g[62.548]
>9, d1[0.547] d2[0.762] g[55.499]
>10, d1[0.594] d2[0.604] g[51.920]
>11, d1[0.485] d2[0.344] g[52.379]
>12, d1[0.390] d2[0.328] g[49.261]
>13, d1[0.259] d2[0.291] g[44.398]
>Saved: model_performance/plot_000013.png and model_performance/model.h5
>14, d1[0.184] d2[0.271] g[46.697]
>15, d1[0.159] d2[0.235] g[41.513]
>16, d1[0.124] d2[0.158] g[38.173]
>17, d1[0.151] d2[0.135] g[37.146]
>18, d1[0.078] d2[0.170] g[44.074]
>19, d1[0.088] d2[0.092] g[34.062]
>20, d1[0.060] d2[0.087] g[37.061]
>21, d1[0.091] d2[0.074] g[28.028]
>22, d1[0.053] d2[0.076] g[31.059]
>23, d1[0.043] d2[0.064] g[31.717]
>24, d1[0.044] d2[0.046] g[26.464]
>25, d1[0.044] d2[0.061] g[26.498]
>Saved: model_perf

>200, d1[0.071] d2[0.023] g[12.218]
>201, d1[0.081] d2[0.031] g[12.539]
>202, d1[0.334] d2[0.090] g[7.725]
>203, d1[0.016] d2[0.372] g[10.438]
>204, d1[0.020] d2[0.015] g[14.053]
>205, d1[0.579] d2[0.033] g[7.824]
>Saved: model_performance/plot_000205.png and model_performance/model.h5
>206, d1[0.095] d2[0.278] g[9.731]
>207, d1[0.025] d2[0.039] g[11.083]
>208, d1[0.024] d2[0.022] g[9.122]
>209, d1[0.028] d2[0.017] g[8.557]
>210, d1[0.022] d2[0.012] g[12.636]
>211, d1[0.061] d2[0.013] g[8.087]
>212, d1[1.833] d2[0.345] g[11.163]
>213, d1[0.020] d2[0.083] g[11.931]
>214, d1[0.071] d2[0.068] g[8.016]
>215, d1[0.009] d2[0.106] g[8.366]
>216, d1[0.019] d2[0.024] g[10.789]
>217, d1[1.431] d2[0.137] g[7.903]
>Saved: model_performance/plot_000217.png and model_performance/model.h5
>218, d1[0.006] d2[0.176] g[11.271]
>219, d1[0.140] d2[0.107] g[7.352]
>220, d1[0.012] d2[0.061] g[9.979]
>221, d1[0.951] d2[0.217] g[5.550]
>222, d1[0.003] d2[0.159] g[12.299]
>223, d1[0.006] d2[0.160] g[9.801]
>22