In [36]:
import gc
from copy import deepcopy
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator

from libs.util import random_mask
from libs.pconv_model import PConvUnet
from libs.properties import properties

# Settings
MAX_BATCH_SIZE = 32

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
# Project-level configuration (see libs.properties).
properties_dict = properties()
# Side of the square grid each flattened traffic row is folded into by the reshape cell.
length = properties_dict["length"]

In [38]:
# Load the raw traffic table and keep only its values as a NumPy array.
# The first column becomes the (date-parsed) index and is therefore not part of
# the array; the remaining columns are the flattened grid cells.
matrix_df = np.array(
    pd.read_csv(
        './data/trafficV201301_M.csv',
        index_col=0,
        parse_dates=True,
    )
)
matrix_df.shape

(2976, 1024)

In [39]:
# Each row is a flattened length x length grid; fold it into an image-like
# (samples, height, width, channels=1) array for the convolutional model.
# Fail fast with a readable message if the configured `length` does not match
# the CSV width (reshape alone would only report an opaque size mismatch).
n_cells = matrix_df.shape[1]
assert n_cells == length * length, (
    f"expected {length * length} columns (length={length}), got {n_cells}"
)
train_df = matrix_df.reshape(matrix_df.shape[0], length, length, 1)

# NOTE(review): rows appear to be timestamps (date-parsed index). A shuffled
# split puts neighbouring, near-identical time steps in both sets and may
# flatter the validation loss; consider shuffle=False for a chronological hold-out.
X_train, X_val = train_test_split(train_df, test_size=0.1, random_state=42)

In [40]:
class DataGenerator(ImageDataGenerator):
    """ImageDataGenerator that yields ``([masked_images, masks], images)`` batches.

    Every batch drawn from the parent ``flow`` gets one fresh ``random_mask``
    per sample. Pixels where the mask is 0 are treated as holes and filled with
    the mean of the batch's non-hole (mask == 1) pixels.
    """

    def flow(self, X, *args, mask_size=20, **kwargs):
        """Yield masked training batches forever.

        Args:
            X: array passed straight to ``ImageDataGenerator.flow``; batches are
                indexed as (samples, height, width, ...) when building masks.
            *args, **kwargs: forwarded unchanged to ``ImageDataGenerator.flow``
                (e.g. ``batch_size``, ``shuffle``, ``seed``).
            mask_size: ``size`` argument passed to ``random_mask``; defaults to
                20, the value that used to be hard-coded.

        Yields:
            ``([masked, mask], ori)`` where ``ori`` is the batch from the parent
            iterator, ``mask`` stacks one mask per sample along axis 0 and
            ``masked`` is a copy of ``ori`` with hole pixels filled.
        """
        # Create the parent iterator ONCE. Re-creating it on every batch restarted
        # it from scratch each time, so with shuffle=False (or a fixed seed) it
        # kept returning the same first batch and never walked through X.
        batches = super().flow(X, *args, **kwargs)
        while True:
            # Get augmented image samples
            ori = next(batches)

            # One random mask per image sample
            mask = np.stack(
                [random_mask(ori.shape[1], ori.shape[2], size=mask_size)
                 for _ in range(ori.shape[0])],
                axis=0,
            )

            # Fill holes (mask == 0) with the mean of the valid pixels in the batch.
            # NOTE(review): boolean indexing below assumes random_mask returns
            # 0/1 values shaped like one sample (e.g. (h, w, channels)) — confirm.
            masked = ori.copy()
            valid = mask == 1
            # Guard: a batch with no valid pixels would make .mean() return NaN.
            fill_value = masked[valid].mean() if valid.any() else 0.0
            masked[mask == 0] = fill_value

            # The forced per-batch gc.collect() was dropped: it is expensive and
            # there are no longer throw-away iterators to reclaim each batch.
            yield [masked, mask], ori
            
# Neither generator is configured with augmentation options; they only attach
# random masks to batches of MAX_BATCH_SIZE samples.
train_datagen = DataGenerator()
val_datagen = DataGenerator()  # validation generator

train_generator = train_datagen.flow(X_train, batch_size=MAX_BATCH_SIZE)
val_generator = val_datagen.flow(X_val, batch_size=MAX_BATCH_SIZE)

In [41]:
# Training schedule. The masking generators loop forever, so these step counts
# are what bound each epoch and each validation pass.
STEPS_PER_EPOCH = 500
VALIDATION_STEPS = 50
EPOCHS = 5

model = PConvUnet()

# NOTE(review): the recorded run printed "Epoch 1/1" despite epochs=5, and the
# loss was ~2e8. Confirm how PConvUnet.fit handles `epochs`, and whether the raw
# traffic values should be normalized before training.
model.fit(
    train_generator,
    validation_data=val_generator,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    epochs=EPOCHS,
)

Epoch 1/1
 98/500 [====>.........................] - ETA: 1:32 - loss: 209287024.979

KeyboardInterrupt: 

In [None]:
# Presumably prints the network architecture.
# NOTE(review): PConvUnet.summary is project code (libs.pconv_model) — confirm what it reports.
model.summary()

In [None]:
import random

# Pick one validation sample at random. The old fixed bound of 200 only covered
# part of X_val.
test_num = random.randrange(len(X_val))
# Unmasked ground truth for that sample, kept separate from the masked copy.
test_truth = X_val[test_num, np.newaxis, :]
test = test_truth.copy()

# Same mask size as the training generator (size=20).
test_mask = random_mask(test.shape[1], test.shape[2], size=20)
test_mask = test_mask[np.newaxis, :]

# Fill holes the way training does: with the mean of the valid (mask == 1)
# pixels, rather than a constant -1 the model never saw during training.
valid = test_mask == 1
test[test_mask == 0] = test[valid].mean() if valid.any() else 0.0

test_res = model.predict([test, test_mask])

# Squared reconstruction error against the *unmasked* ground truth. Comparing
# with `test` would measure the distance to the masked input instead.
np.sum((test_truth - test_res) ** 2)

In [None]:
# Unmasked ground truth for the sampled validation image, with a leading batch axis.
X_val[test_num][np.newaxis]

In [None]:
# Masked input that was fed to the model (hole pixels overwritten in the prediction cell).
test

In [None]:
# Model output for the masked sample, as returned by model.predict.
test_res