In [1]:
import pandas as pd
import numpy as np
import sys
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
import warnings
from IPython.display import clear_output
from multiprocessing import Pool
from time import time

warnings.filterwarnings('ignore')

from model import ADAIN
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

### Fold 0

In [13]:
fold = str(0)
station_metaq = np.load('../data/adain/fold_'+fold+'_station_metaq_data_test.npz')['arr_0']
station_dist = np.load('../data/adain/fold_'+fold+'_station_dist_data_test.npz')['arr_0']

local_met = np.load('../data/adain/fold_'+fold+'_local_met_data_test.npz')['arr_0']
local_aq0 = np.load('../data/adain/fold_'+fold+'_local_aq_data_test.npz')['arr_0']
local_stationid = np.load('../data/adain/fold_'+fold+'_local_stationids_test.npz')['arr_0']

In [14]:
model = tf.keras.models.load_model('../results/adain/trained_'+fold+'.h5')

preds0 = model.predict([local_met, station_dist, station_metaq])

mean_squared_error(local_aq0, preds0, squared=False)

30.707960314612567

In [19]:
stationids = np.unique(local_stationid)
for station in stationids:
    rows = np.where(local_stationid==station)
    err = mean_squared_error(preds0[rows],local_aq0[rows], squared=False)
    print(station,err)

1003 30.93912378694612
1005 32.53990770704477
1006 24.777214314227084
1010 31.23569927226019
1011 32.529045391544585
1014 39.83379178509891
1018 34.227025954477455
1019 45.763058625076575
1030 34.961665556272344
1035 52.92676109308956


### Fold 1

In [15]:
fold = str(1)
station_metaq = np.load('../data/adain/fold_'+fold+'_station_metaq_data_test.npz')['arr_0']
station_dist = np.load('../data/adain/fold_'+fold+'_station_dist_data_test.npz')['arr_0']

local_met = np.load('../data/adain/fold_'+fold+'_local_met_data_test.npz')['arr_0']
local_aq1 = np.load('../data/adain/fold_'+fold+'_local_aq_data_test.npz')['arr_0']
local_stationid = np.load('../data/adain/fold_'+fold+'_local_stationids_test.npz')['arr_0']

In [16]:
model = tf.keras.models.load_model('../results/adain/trained_'+fold+'.h5')

preds1 = model.predict([local_met, station_dist, station_metaq])

mean_squared_error(local_aq1, preds1, squared=False)

32.968716579457165

In [17]:
stationids = np.unique(local_stationid)
for station in stationids:
    rows = np.where(local_stationid==station)
    err = mean_squared_error(preds1[rows],local_aq1[rows], squared=False)
    print(station,err)

1002 27.5453142446929
1004 28.464204971368687
1007 21.065093389258717
1023 25.956718582577068
1025 30.08510766820561
1029 41.5760741621442
1031 38.78357541289693
1032 31.692457490634297
1033 38.912281853212896
1034 39.07033554540359


In [18]:
d1 = np.array(np.split(local_aq1, 10))
d2 = np.array(np.split(preds1, 10))

for i in range(10):
    print(mean_squared_error(d1[i], d2[i], squared=False))

21.700990724407983
40.0969200145252
26.24898106584621
26.013292590894764
31.714203587218503
42.21255767923697
19.5589825522535
22.083001031352353
24.50167723974685
56.2737121744582


### Fold 2

In [19]:
fold = str(2)
station_metaq = np.load('../data/adain/fold_'+fold+'_station_metaq_data_test.npz')['arr_0']
station_dist = np.load('../data/adain/fold_'+fold+'_station_dist_data_test.npz')['arr_0']

local_met = np.load('../data/adain/fold_'+fold+'_local_met_data_test.npz')['arr_0']
local_aq2 = np.load('../data/adain/fold_'+fold+'_local_aq_data_test.npz')['arr_0']
local_stationid = np.load('../data/adain/fold_'+fold+'_local_stationids_test.npz')['arr_0']

In [20]:
model = tf.keras.models.load_model('../results/adain/trained_'+fold+'.h5')

preds2 = model.predict([local_met, station_dist, station_metaq])

mean_squared_error(local_aq2, preds2, squared=False)

32.19474111798532

In [29]:
stationids = np.unique(local_stationid)
for station in stationids:
    rows = np.where(local_stationid==station)
    err = mean_squared_error(preds2[rows],local_aq2[rows], squared=False)
    print(station,err)

1008 31.042461571999908
1012 35.51963735378074
1016 32.64421096814803
1017 36.614491748838
1022 24.84871318995769
1024 33.3731261870793
1026 35.89425851512612
1027 33.42338187053883
1028 34.67165093331203
1036 51.350452478595074
