In [None]:
import xarray as xr 
import matplotlib.pyplot as plt 
import numpy as np
import tensorflow as tf

#plot parameters that I personally like, feel free to make these your own.
import matplotlib
import matplotlib.patheffects as path_effects


# Text-outline path effects: a thin black stroke (pe1) and a thin white stroke (pe2).
# Each is a one-element list, ready to pass as `path_effects=` to text artists.
pe1, pe2 = (
    [path_effects.withStroke(linewidth=1.5, foreground=stroke_color)]
    for stroke_color in ("k", "w")
)

# Personal default styling applied to every figure drawn in this notebook.
matplotlib.rcParams.update({
    'axes.facecolor': [0.9, 0.9, 0.9],  # light grey background on the axes face
    'axes.labelsize': 14,               # font sizes are in points
    'axes.titlesize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'legend.facecolor': 'w',
    'savefig.transparent': False,
})

# Render inline figures at 2x pixel density ("retina") so they stay sharp on
# high-DPI screens. This is an IPython magic: it only works inside a notebook/IPython kernel.
%config InlineBackend.figure_format = 'retina'

from keras_unet_collection import models

2023-10-18 16:00:15.101734: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# Directory holding the pre-built training / validation NetCDF files.
# Edit this single constant to point the notebook at another copy of the data
# (previously the absolute path was duplicated in each open_dataset call).
TRAINING_DATA_DIR = '/pl/active/ATOC_SynopticMet/data/ar_data/Research3/Data/Training_data'

ds_train_xr = xr.open_dataset(f'{TRAINING_DATA_DIR}/train.nc')
ds_val_xr = xr.open_dataset(f'{TRAINING_DATA_DIR}/validate.nc')
   

In [None]:
# Build tf.data pipelines of (features, labels_2d) pairs; from_tensor_slices
# slices along the first axis, so each element is one sample.
ds_train = tf.data.Dataset.from_tensor_slices((ds_train_xr.features.values, ds_train_xr.labels_2d.values))
ds_val = tf.data.Dataset.from_tensor_slices((ds_val_xr.features.values, ds_val_xr.labels_2d.values))

# Shuffle ONLY the training set. A buffer the size of the whole dataset gives a
# uniform shuffle. The fixed seed makes the order reproducible across kernel
# restarts (the data are still reshuffled every epoch). The validation set is
# deliberately left in file order, so model.predict(ds_val) lines up
# index-for-index with ds_val_xr.
SHUFFLE_SEED = 42
ds_train = ds_train.shuffle(ds_train.cardinality().numpy(), seed=SHUFFLE_SEED)

# Batch both; prefetch overlaps input preparation with model compute and does
# not change element order.
batch_size = 64
ds_train = ds_train.batch(batch_size).prefetch(tf.data.AUTOTUNE)
ds_val = ds_val.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [None]:
# Pull a single (features, labels) batch from the training pipeline so shapes
# can be inspected. next(iter(...)) fails immediately with StopIteration on an
# empty dataset, instead of silently leaving `batch` undefined as a
# `for ... break` loop would.
batch = next(iter(ds_train))
tuple(tensor.shape for tensor in batch)

In [None]:
# Small U-Net: input (256, 32, 8) is channels-last (features are indexed as
# [:, :, channel] when plotted); filter counts per level; one output map with a
# sigmoid activation (per-location AR probability). No pretrained weights.
INPUT_SIZE = [256, 32, 8]
FILTER_NUM = [2, 4]
N_LABELS = 1
model = models.unet_2d(
    INPUT_SIZE,
    FILTER_NUM,
    N_LABELS,
    stack_num_down=1,
    stack_num_up=1,
    output_activation='Sigmoid',
    weights=None,
)

In [None]:
# Print the layer-by-layer architecture, output shapes and parameter counts.
model.summary()

In [None]:
# Binary cross-entropy pairs with the single-channel sigmoid output of the U-Net.
loss_fn = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer=optimizer, loss=loss_fn)

# Train for a fixed 30 epochs, evaluating on the validation pipeline each epoch.
history = model.fit(ds_train, epochs=30, validation_data=ds_val)


In [None]:
# Predict on the whole (unshuffled) validation set -- Keras' .predict is used like sklearn's.
y_preds = model.predict(ds_val)

# Distribution of every predicted probability value across all validation samples.
ax = plt.gca()
ax.hist(y_preds.reshape(-1))
ax.set_xlabel('prob of AR')
ax.set_ylabel('count')
ax.set_xlim([0, 1])

In [None]:
# Visual sanity check: one validation example's input channels, its ground-truth
# label, and the model's prediction for that SAME example.
# ds_val is never shuffled and model.predict keeps dataset order, so
# `example_idx` indexes the same sample in ds_val_xr and in y_preds.
# (Previously features/label came from a shuffled TRAINING batch while the
# prediction was validation example 27, so the two panels were unrelated.)
example_idx = 27

# The first axis is the sample axis (the one from_tensor_slices slices over).
one_example_features = ds_val_xr.features[example_idx].values
one_example_label = ds_val_xr.labels_2d[example_idx].values
# Drop the singleton output-channel axis of the 1-label model
# (assumes prediction is (H, W, 1) -- TODO confirm).
one_example_pred = np.squeeze(y_preds[example_idx])

# First four feature channels, each with its own colormap.
feature_cmaps = ['Blues', 'turbo', 'Spectral_r', 'Greys_r']
fig, axes = plt.subplots(1, len(feature_cmaps), figsize=(20, 5))
for channel, (ax, cmap) in enumerate(zip(axes, feature_cmaps)):
    ax.imshow(one_example_features[:, :, channel], cmap=cmap)
    ax.set_title(f'feature channel {channel}')
fig.tight_layout()

fig, axes = plt.subplots(1, 2, figsize=(10, 5), facecolor='w')
pm = axes[0].imshow(one_example_label)
fig.colorbar(pm, ax=axes[0])
axes[0].set_title(f'label (val example {example_idx})')
# Fix the colour scale to [0, 1] (sigmoid output range) so low-confidence maps
# are not auto-stretched to look like confident detections.
pm = axes[1].imshow(one_example_pred, vmin=0, vmax=1)
fig.colorbar(pm, ax=axes[1], label='AR_prob')
axes[1].set_title(f'prediction (val example {example_idx})')

fig.tight_layout()