In [250]:
import gc
from copy import deepcopy
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator

from libs.util import random_mask
from libs.pconv_model import PConvUnet
from libs.properties import properties

# Settings
MAX_BATCH_SIZE = 32

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [251]:
properties_dict = properties()
length = properties_dict["length"]

In [252]:
matrix_df = pd.read_csv('./data/trafficV201301_M.csv', index_col=0, parse_dates=True)
matrix_df = np.array(matrix_df)
matrix_df.shape

(2976, 1024)

In [253]:
train_df = matrix_df.reshape(matrix_df.shape[0], length, length,1)
X_train, X_val = train_test_split(train_df, test_size = 0.1, random_state=42)

In [254]:
class DataGenerator(ImageDataGenerator):
    def flow(self, X, *args, **kwargs):
        while True:
            
            # Get augmentend image samples
            ori = next(super().flow(X, *args, **kwargs))
    
            # Get masks for each image sample
            mask = np.stack([random_mask(ori.shape[1], ori.shape[2], size=0) for _ in range(ori.shape[0])], axis=0)

            # Apply masks to all image sample
            masked = deepcopy(ori)
            masked[mask==0] = -1

            # Yield ([ori, masl],  ori) training batches
            # print(masked.shape, ori.shape)
            gc.collect()
            yield [masked, mask], ori
            
train_datagen = DataGenerator()
train_generator = train_datagen.flow(
    X_train, batch_size=MAX_BATCH_SIZE
)

# Create validation generator
val_datagen = DataGenerator()
val_generator = val_datagen.flow(
    X_val, batch_size=MAX_BATCH_SIZE
)

In [256]:
model = PConvUnet()

model.fit(
    train_generator, 
    validation_data=val_generator,
    steps_per_epoch = 50,
    validation_steps=50,
    epochs=1
)

Epoch 1/1
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
 1/50 [..............................] - ETA: 9:35 - loss: 295400.312532768
 2/50 [>.............................] - ETA: 4:53 - loss: 320692.765632768
 3/50 [>.............................] - ETA: 3:17 - loss: 341027.812532768
 4/50 [=>............................] - ETA: 2:29 - loss: 319920.445332768
 6/50 [==>...........................] - ETA: 1:38 - loss: 325115.838532768
 7/50 [===>..........................] - ETA: 1:24 - loss: 320108.584832768
 8/50 [===>..........................] - ETA: 1:14 - loss: 317914.171932768
 9/50 [====>.........................] - ETA: 1:06 - loss: 318724.618132768
10/50 [=====>........................] - ETA: 59s - loss: 319511.3906 32768
11/50 [=====>........................] - ETA: 54s - loss: 318290.031232768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
32768
327

In [257]:
test = X_val[9,np.newaxis,:]

test_mask = random_mask(test.shape[1], test.shape[2])
test_mask = test_mask[np.newaxis,:]

test_mask[0,:,:,0].shape
test[test_mask==0] = -1

# test_mask.shape
# plt.imshow(test[0,:,:,0]*255)

test_res = model.predict([test, test_mask])
test_res.sum()

177.03925

In [258]:
test_res

array([[[[6.3336766e-01],
         [5.5583358e-01],
         [5.3973895e-01],
         ...,
         [5.7288700e-01],
         [5.3796852e-01],
         [1.1081165e-01]],

        [[6.7021870e-07],
         [1.0091960e-10],
         [4.8795878e-10],
         ...,
         [5.5498197e-03],
         [1.6098699e-02],
         [2.2438752e-02]],

        [[2.6856709e-04],
         [8.0704112e-13],
         [9.0779146e-12],
         ...,
         [2.2887174e-02],
         [5.3930599e-03],
         [4.9100653e-03]],

        ...,

        [[1.3514141e-02],
         [7.9915859e-02],
         [6.4885080e-02],
         ...,
         [8.4660989e-01],
         [8.3301985e-01],
         [6.8189448e-01]],

        [[9.0191938e-02],
         [1.0811681e-01],
         [1.6225349e-02],
         ...,
         [8.8258755e-01],
         [8.1635296e-01],
         [7.2080505e-01]],

        [[1.2827500e-03],
         [6.3185287e-01],
         [3.6745891e-01],
         ...,
         [7.0890480e-01],
        