In [306]:
import gc
from copy import deepcopy
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator

from libs.util import random_mask
from libs.pconv_model import PConvUnet
from libs.properties import properties

# Settings
MAX_BATCH_SIZE = 32

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [307]:
properties_dict = properties()
length = properties_dict["length"]

In [308]:
matrix_df = pd.read_csv('./data/trafficV201301_M.csv', index_col=0, parse_dates=True)
matrix_df = np.array(matrix_df)
matrix_df.shape

(2976, 1024)

In [309]:
train_df = matrix_df.reshape(matrix_df.shape[0], length, length,1)
X_train, X_val = train_test_split(train_df, test_size = 0.1, random_state=42)

In [310]:
class DataGenerator(ImageDataGenerator):
    def flow(self, X, *args, **kwargs):
        while True:
            
            # Get augmentend image samples
            ori = next(super().flow(X, *args, **kwargs))
    
            # Get masks for each image sample
            mask = np.stack([random_mask(ori.shape[1], ori.shape[2], size=20) for _ in range(ori.shape[0])], axis=0)

            # Apply masks to all image sample
            masked = deepcopy(ori)
            masked[mask==0] = -1

            # Yield ([ori, masl],  ori) training batches
            # print(masked.shape, ori.shape)
            gc.collect()
            yield [masked, mask], ori
            
train_datagen = DataGenerator()
train_generator = train_datagen.flow(
    X_train, batch_size=MAX_BATCH_SIZE
)

# Create validation generator
val_datagen = DataGenerator()
val_generator = val_datagen.flow(
    X_val, batch_size=MAX_BATCH_SIZE
)

In [312]:
model = PConvUnet()

model.fit(
    train_generator, 
    validation_data=val_generator,
    steps_per_epoch = 500,
    validation_steps=50,
    epochs=5
)

Epoch 1/1
Epoch 2/2
Epoch 3/3
Epoch 4/4
Epoch 5/5


In [331]:
test = X_val[55,np.newaxis,:]

test_mask = random_mask(test.shape[1], test.shape[2])
test_mask = test_mask[np.newaxis,:]

test_mask[0,:,:,0].shape
test[test_mask==0] = -1

# test_mask.shape
# plt.imshow(test[0,:,:,0]*255)

test_res = model.predict([test, test_mask])
np.sum((test-test_res)**2)

13277051.806805186

In [332]:
np.sum((test-test_res)**2)

13277051.806805186

In [333]:
test

array([[[[128.  ],
         [103.5 ],
         [106.  ],
         ...,
         [155.5 ],
         [164.5 ],
         [136.  ]],

        [[ 97.  ],
         [137.5 ],
         [134.5 ],
         ...,
         [117.  ],
         [101.5 ],
         [135.17]],

        [[581.  ],
         [603.5 ],
         [738.  ],
         ...,
         [319.  ],
         [315.75],
         [121.63]],

        ...,

        [[248.5 ],
         [296.  ],
         [275.63],
         ...,
         [ 99.5 ],
         [130.5 ],
         [130.  ]],

        [[498.75],
         [503.5 ],
         [472.5 ],
         ...,
         [ 89.57],
         [130.5 ],
         [130.  ]],

        [[404.25],
         [411.  ],
         [506.5 ],
         ...,
         [114.75],
         [ 77.75],
         [ 54.19]]]])

In [334]:
test_res

array([[[[ 67.359886],
         [ 79.67486 ],
         [ 65.841095],
         ...,
         [138.99298 ],
         [138.11185 ],
         [ 93.571884]],

        [[147.45096 ],
         [236.6901  ],
         [236.04326 ],
         ...,
         [160.43773 ],
         [175.7856  ],
         [116.8552  ]],

        [[438.61798 ],
         [582.5236  ],
         [555.0709  ],
         ...,
         [228.96028 ],
         [228.94179 ],
         [131.78683 ]],

        ...,

        [[249.90799 ],
         [365.1483  ],
         [409.99185 ],
         ...,
         [ 97.713455],
         [119.54411 ],
         [ 94.05987 ]],

        [[376.00986 ],
         [500.07422 ],
         [491.2203  ],
         ...,
         [118.95691 ],
         [129.91394 ],
         [ 91.310646]],

        [[299.46558 ],
         [444.9208  ],
         [465.98157 ],
         ...,
         [106.42385 ],
         [ 84.2543  ],
         [ 49.023655]]]], dtype=float32)