In [1]:
from model import *
import numpy as np
from tensorflow.keras.optimizers import Adam, RMSprop
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
# from dataset_builder import *


def generate_real_samples(dataset, ground_trud_ds, n_samples, patch_size):
    ix = np.random.randint(0, dataset.shape[0], n_samples)
    X = dataset[ix]
    gt = ground_trud_ds[ix]
    y = np.ones((n_samples, patch_size, patch_size, 1))
    return X, y, gt


def generate_fake_samples(g_model, dataset, patch_size):
    w_noise = np.random.normal(0, 1, (dataset.shape[0], 14, 14, 1024))
    X = g_model.predict([dataset, w_noise])
    y = np.zeros((len(X), patch_size, patch_size, 1))
    return X, y


def sample_images(generator, source, target, idx):
    print(target.shape)
    target = np.uint8(target * 127.5 + 127.5)
    w_noise = np.random.normal(0, 1, (1, 14, 14, 1024))
    predicted = generator.predict([source, w_noise])
    im = np.uint8(predicted[0, ...] * 127.5 + 127.5)
    im_source = np.uint8(source[0, ...] * 255)
    print(im_source.shape)
    im_c = np.concatenate((np.squeeze(im, axis=-1), np.squeeze(target, axis=-1),
                           im_source[..., 0], im_source[..., 1]), axis=1)
    plt.imsave('./outputs/sketch_conversion' + str(idx) + '.png', im_c, cmap='terrain')



def train_gan():
    data = np.load('training_data4.npz')
    XTrain = data['x']
    YTrain = data['y']
    input_shape_gen = (XTrain.shape[1], XTrain.shape[2], XTrain.shape[3])
    input_shape_disc = (YTrain.shape[1], YTrain.shape[2], YTrain.shape[3])
    print(input_shape_gen,input_shape_disc)

    terrain_generator = UNet(input_shape_gen)
    terrain_discriminator = patch_discriminator(input_shape_disc)
    optd = Adam(0.0001, 0.5)
    terrain_discriminator.compile(loss='binary_crossentropy', optimizer=optd)
    composite_model = mount_discriminator_generator(
        terrain_generator, terrain_discriminator, input_shape_gen)
    composite_model.compile(
        loss=[
            'binary_crossentropy', 'mae'], loss_weights=[
            1, 3], optimizer=optd)

    n_epochs, n_batch, = 150, 20
    bat_per_epo = int(len(XTrain) / n_batch)
    patch_size = 15
    n_steps = bat_per_epo * n_epochs
    min_loss = 999
    avg_loss = 0
    avg_dloss = 0
    for i in range(n_steps):
        X_real, labels_real, Y_target = generate_real_samples(XTrain, YTrain, n_batch, patch_size)
        Y_target[np.isnan(Y_target)] = 0
        X_real[np.isnan(X_real)] = 0
        
        # mask 
        mask = np.random.uniform(low=0.0, high=1.0, size=X_real.shape)>0.75 # mask 75 %
        X_real = X_real* mask

        Y_fake, labels_fake = generate_fake_samples(terrain_generator, X_real, patch_size)
        w_noise = np.random.normal(0, 1, (n_batch, 14, 14, 1024))
        losses_composite = composite_model.train_on_batch(
            [X_real, w_noise], [labels_real, Y_target])

        loss_discriminator_fake = terrain_discriminator.train_on_batch(
            [Y_fake, X_real], labels_fake)
        loss_discriminator_real = terrain_discriminator.train_on_batch(
            [Y_target, X_real], labels_real)
        d_loss = (loss_discriminator_fake + loss_discriminator_real) / 2
        avg_dloss = avg_dloss + (d_loss - avg_dloss) / (i + 1)
        avg_loss = avg_loss + (losses_composite[0] - avg_loss) / (i + 1)

        if i % 100 == 0:
            print(i,'total loss:' + str(avg_loss) + ' d_loss:' + str(avg_dloss))
            sample_images(terrain_generator, X_real[0:1, ...], Y_target[0, ...], i)
        if i % 500 == 0:
            terrain_discriminator.save('terrain_discriminator' + str(i) + '.h5', True)
            terrain_generator.save('terrain_generator' + str(i) + '.h5', True)

import tensorflow as tf
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

train_gan()

Num GPUs Available:  4
(225, 225, 2) (225, 225, 1)
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 225, 225, 2) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 225, 225, 64) 1216        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 225, 225, 64) 36928       conv2d[0][0]                     
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 112, 112, 64) 0           conv2d_1[0][0]                   
___________________________________________

0 total loss:2.1351747512817383 d_loss:0.7207650244235992
(225, 225, 1)
(225, 225, 2)
100 total loss:2.4661205652916784 d_loss:0.6850746268830677
(225, 225, 1)
(225, 225, 2)
200 total loss:2.5924362489833173 d_loss:0.6488100202633673
(225, 225, 1)
(225, 225, 2)
300 total loss:2.51589955285538 d_loss:0.6595132480254806
(225, 225, 1)
(225, 225, 2)
400 total loss:2.5013800999172893 d_loss:0.6535405568555558
(225, 225, 1)
(225, 225, 2)
500 total loss:2.455791227831812 d_loss:0.6550710295473864
(225, 225, 1)
(225, 225, 2)
600 total loss:2.4407718205412317 d_loss:0.6588368582137906
(225, 225, 1)
(225, 225, 2)
700 total loss:2.4280068687297507 d_loss:0.6506595860970088
(225, 225, 1)
(225, 225, 2)
800 total loss:2.4329782456494446 d_loss:0.6541736921465274
(225, 225, 1)
(225, 225, 2)
900 total loss:2.4421808679942694 d_loss:0.6526411654392498
(225, 225, 1)
(225, 225, 2)
1000 total loss:2.432399706764299 d_loss:0.6505616202049853
(225, 225, 1)
(225, 225, 2)
1100 total loss:2.433022020731483 d_l

6600 total loss:2.1830062079187638 d_loss:0.6605158269878306
(225, 225, 1)
(225, 225, 2)
6700 total loss:2.181508651635056 d_loss:0.6605586983928402
(225, 225, 1)
(225, 225, 2)
6800 total loss:2.179604882072865 d_loss:0.6606401729084859
(225, 225, 1)
(225, 225, 2)
6900 total loss:2.1777217248231997 d_loss:0.6608169585214845
(225, 225, 1)
(225, 225, 2)
7000 total loss:2.175935820245245 d_loss:0.660916654718463
(225, 225, 1)
(225, 225, 2)
7100 total loss:2.1738276238039234 d_loss:0.6610693807891153
(225, 225, 1)
(225, 225, 2)
7200 total loss:2.171981844036566 d_loss:0.6612017705251277
(225, 225, 1)
(225, 225, 2)
7300 total loss:2.170083384908334 d_loss:0.6613639836444385
(225, 225, 1)
(225, 225, 2)
7400 total loss:2.168470307357116 d_loss:0.6614829246823043
(225, 225, 1)
(225, 225, 2)
7500 total loss:2.166494461540028 d_loss:0.6615945112142448
(225, 225, 1)
(225, 225, 2)
7600 total loss:2.1645815942086992 d_loss:0.6617606188403545
(225, 225, 1)
(225, 225, 2)
7700 total loss:2.16281161091

13400 total loss:2.107774401963555 d_loss:0.6655706597662939
(225, 225, 1)
(225, 225, 2)
13500 total loss:2.1074450834884755 d_loss:0.6655656489400222
(225, 225, 1)
(225, 225, 2)
13600 total loss:2.1068463157568003 d_loss:0.6655959302953788
(225, 225, 1)
(225, 225, 2)
13700 total loss:2.1063289830531735 d_loss:0.6656149636531827
(225, 225, 1)
(225, 225, 2)
13800 total loss:2.105871049928185 d_loss:0.6656098451758797
(225, 225, 1)
(225, 225, 2)
13900 total loss:2.105382267275049 d_loss:0.66565575249481
(225, 225, 1)
(225, 225, 2)
14000 total loss:2.1049894759607573 d_loss:0.6656933725590655
(225, 225, 1)
(225, 225, 2)
14100 total loss:2.1045490922881296 d_loss:0.6657032301178822
(225, 225, 1)
(225, 225, 2)
14200 total loss:2.1042376898534054 d_loss:0.6656991554961037
(225, 225, 1)
(225, 225, 2)
14300 total loss:2.103854368169971 d_loss:0.6657099802169202
(225, 225, 1)
(225, 225, 2)
14400 total loss:2.1034120445193296 d_loss:0.6657283722215636
(225, 225, 1)
(225, 225, 2)
14500 total loss

20100 total loss:2.081465518632925 d_loss:0.6664099198002126
(225, 225, 1)
(225, 225, 2)
20200 total loss:2.081221963366592 d_loss:0.6664406262186252
(225, 225, 1)
(225, 225, 2)
20300 total loss:2.081005414409192 d_loss:0.6664241698162437
(225, 225, 1)
(225, 225, 2)
20400 total loss:2.0806932410190635 d_loss:0.6664453078515892
(225, 225, 1)
(225, 225, 2)
20500 total loss:2.080360554781231 d_loss:0.6664377468923504
(225, 225, 1)
(225, 225, 2)
20600 total loss:2.0800630282756165 d_loss:0.6664656966229965
(225, 225, 1)
(225, 225, 2)
20700 total loss:2.0797945988430975 d_loss:0.6664635347002987
(225, 225, 1)
(225, 225, 2)
20800 total loss:2.0795125966060324 d_loss:0.6664800293371868
(225, 225, 1)
(225, 225, 2)
20900 total loss:2.079218611763532 d_loss:0.6664889527060713
(225, 225, 1)
(225, 225, 2)
21000 total loss:2.078942275247106 d_loss:0.666508513463801
(225, 225, 1)
(225, 225, 2)
21100 total loss:2.0785682928584146 d_loss:0.6665046497797089
(225, 225, 1)
(225, 225, 2)
21200 total loss:

26700 total loss:2.0671256691867965 d_loss:0.6663965052468483
(225, 225, 1)
(225, 225, 2)
26800 total loss:2.066957211597619 d_loss:0.6663926266048208
(225, 225, 1)
(225, 225, 2)
26900 total loss:2.066835112422979 d_loss:0.6663579513490563
(225, 225, 1)
(225, 225, 2)
27000 total loss:2.0666550473413294 d_loss:0.6663688046845404
(225, 225, 1)
(225, 225, 2)
27100 total loss:2.0664982525106836 d_loss:0.6663627205643752
(225, 225, 1)
(225, 225, 2)
27200 total loss:2.066335935561252 d_loss:0.6663486720366008
(225, 225, 1)
(225, 225, 2)
27300 total loss:2.066248590180036 d_loss:0.6663352927179228
(225, 225, 1)
(225, 225, 2)
27400 total loss:2.0661010448525494 d_loss:0.6663219040782973
(225, 225, 1)
(225, 225, 2)
27500 total loss:2.065984384475449 d_loss:0.6663140644170077
(225, 225, 1)
(225, 225, 2)
27600 total loss:2.0658689148939584 d_loss:0.6662915449827473
(225, 225, 1)
(225, 225, 2)
27700 total loss:2.065767572397842 d_loss:0.6662936657859505
(225, 225, 1)
(225, 225, 2)
27800 total loss

33400 total loss:2.059880448620429 d_loss:0.6656101028558262
(225, 225, 1)
(225, 225, 2)
33500 total loss:2.0598860487062893 d_loss:0.6655845568225719
(225, 225, 1)
(225, 225, 2)
33600 total loss:2.059822555092237 d_loss:0.6655666311117299
(225, 225, 1)
(225, 225, 2)
33700 total loss:2.0597546067259986 d_loss:0.6655448794572575
(225, 225, 1)
(225, 225, 2)
33800 total loss:2.059606530889168 d_loss:0.6655432999402813
(225, 225, 1)
(225, 225, 2)
33900 total loss:2.059519610662786 d_loss:0.6655327808052068
(225, 225, 1)
(225, 225, 2)
34000 total loss:2.0595763448771147 d_loss:0.6655032412708907
(225, 225, 1)
(225, 225, 2)
34100 total loss:2.0594773844453784 d_loss:0.6654942382161726
(225, 225, 1)
(225, 225, 2)
34200 total loss:2.059380060932594 d_loss:0.6654814957960508
(225, 225, 1)
(225, 225, 2)
34300 total loss:2.059278007888331 d_loss:0.665460001315405
(225, 225, 1)
(225, 225, 2)
34400 total loss:2.059237415601833 d_loss:0.6654503648809248
(225, 225, 1)
(225, 225, 2)
34500 total loss:2

40100 total loss:2.0557263742131995 d_loss:0.6645499995485462
(225, 225, 1)
(225, 225, 2)
40200 total loss:2.0556180000791366 d_loss:0.664546898827444
(225, 225, 1)
(225, 225, 2)
40300 total loss:2.0555712061512 d_loss:0.6645352170471311
(225, 225, 1)
(225, 225, 2)
40400 total loss:2.05552673549634 d_loss:0.6645148826165089
(225, 225, 1)
(225, 225, 2)
40500 total loss:2.0554528160450167 d_loss:0.664510622375094
(225, 225, 1)
(225, 225, 2)
40600 total loss:2.0553676787833326 d_loss:0.664499137394434
(225, 225, 1)
(225, 225, 2)
40700 total loss:2.0552744289272784 d_loss:0.6644911510167766
(225, 225, 1)
(225, 225, 2)
40800 total loss:2.0551974483191944 d_loss:0.6644786497630725
(225, 225, 1)
(225, 225, 2)
40900 total loss:2.055149111064901 d_loss:0.6644663131624408
(225, 225, 1)
(225, 225, 2)
41000 total loss:2.05510362556586 d_loss:0.664450825924793
(225, 225, 1)
(225, 225, 2)
41100 total loss:2.055055127715576 d_loss:0.6644364867596667
(225, 225, 1)
(225, 225, 2)
41200 total loss:2.0550

46700 total loss:2.0526399408575484 d_loss:0.6636536968351318
(225, 225, 1)
(225, 225, 2)
46800 total loss:2.0525795983462176 d_loss:0.663648791059184
(225, 225, 1)
(225, 225, 2)
46900 total loss:2.0525551400218385 d_loss:0.6636271249200403
(225, 225, 1)
(225, 225, 2)
47000 total loss:2.052507919166051 d_loss:0.6636250410581327
(225, 225, 1)
(225, 225, 2)
47100 total loss:2.052454517025895 d_loss:0.663614491604601
(225, 225, 1)
(225, 225, 2)
47200 total loss:2.052406484627471 d_loss:0.6636045288018926
(225, 225, 1)
(225, 225, 2)
47300 total loss:2.0523182819545855 d_loss:0.6635977610381383
(225, 225, 1)
(225, 225, 2)
47400 total loss:2.052254731960774 d_loss:0.6635886866664843
(225, 225, 1)
(225, 225, 2)
47500 total loss:2.052166619051577 d_loss:0.663579650082173
(225, 225, 1)
(225, 225, 2)
47600 total loss:2.05209115571262 d_loss:0.6635792773527788
(225, 225, 1)
(225, 225, 2)
47700 total loss:2.052084022054511 d_loss:0.6635678422478152
(225, 225, 1)
(225, 225, 2)
47800 total loss:2.05

53400 total loss:2.0495517550843996 d_loss:0.662888512677444
(225, 225, 1)
(225, 225, 2)
53500 total loss:2.049503688346462 d_loss:0.6628732514979895
(225, 225, 1)
(225, 225, 2)
53600 total loss:2.0494599945631635 d_loss:0.6628644594911257
(225, 225, 1)
(225, 225, 2)
53700 total loss:2.0494358246269586 d_loss:0.6628473102521308
(225, 225, 1)
(225, 225, 2)
53800 total loss:2.0493689421457204 d_loss:0.662831618463966
(225, 225, 1)
(225, 225, 2)
53900 total loss:2.0492997082665716 d_loss:0.6628232717085427
(225, 225, 1)
(225, 225, 2)
54000 total loss:2.0492610677352787 d_loss:0.662813686832839
(225, 225, 1)
(225, 225, 2)
54100 total loss:2.0492096287214983 d_loss:0.6628057156531104
(225, 225, 1)
(225, 225, 2)
54200 total loss:2.0491962280003366 d_loss:0.6627934965247001
(225, 225, 1)
(225, 225, 2)
54300 total loss:2.0491886442268297 d_loss:0.6627805049080755
(225, 225, 1)
(225, 225, 2)
54400 total loss:2.0491572492900554 d_loss:0.6627617000396667
(225, 225, 1)
(225, 225, 2)
54500 total lo

60100 total loss:2.047548305999988 d_loss:0.6619854356611795
(225, 225, 1)
(225, 225, 2)
60200 total loss:2.0475207645531124 d_loss:0.6619752576170526
(225, 225, 1)
(225, 225, 2)
60300 total loss:2.047493155841841 d_loss:0.6619574923819338
(225, 225, 1)
(225, 225, 2)
60400 total loss:2.0474622462873566 d_loss:0.6619434744196508
(225, 225, 1)
(225, 225, 2)
60500 total loss:2.0473958891156325 d_loss:0.6619351233878373
(225, 225, 1)
(225, 225, 2)
60600 total loss:2.047364863129175 d_loss:0.6619170066759473
(225, 225, 1)
(225, 225, 2)
60700 total loss:2.0473448070229976 d_loss:0.6619022856248394
(225, 225, 1)
(225, 225, 2)
60800 total loss:2.0473168229992234 d_loss:0.6618841913260565
(225, 225, 1)
(225, 225, 2)
60900 total loss:2.0472674244105793 d_loss:0.6618755435343937
(225, 225, 1)
(225, 225, 2)
61000 total loss:2.0472779296635992 d_loss:0.661859894321864
(225, 225, 1)
(225, 225, 2)
61100 total loss:2.047256926745234 d_loss:0.6618457706282228
(225, 225, 1)
(225, 225, 2)
61200 total los

KeyboardInterrupt: 