In [60]:
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

from src.preprocess import preprocess_folder
from src.format_data import create_dataset
from src.model import create_embedding_model
from src.losses import get_companion_std, keep_back

from tensorflow.keras.optimizers import Adam
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

%load_ext autoreload
%autoreload 2

In [61]:
# Preprocess the raw data, keep only companions above the SNR cut, and build the
# training dataset, model, optimizer and early-stopping callback.
SNR_MIN = 3            # companions with snr <= SNR_MIN are dropped
BATCH_SIZE = 2000
TRAIN_REPEAT = 20      # passed as `repeat` to create_dataset for the training set
LEARNING_RATE = 1e-5
window_size = 15       # kept lowercase: later cells reference `window_size`

cube, psf, rot_angles, table_raw = preprocess_folder(
    root='./data/fake',
    target_folder='./data/fake/preprocessed',
)

# Filter into a new name so the unfiltered table stays available for inspection.
table = table_raw[table_raw['snr'] > SNR_MIN]
assert len(table) > 0, f'no companions with snr > {SNR_MIN}; nothing to train on'

dataset = create_dataset(cube, psf, rot_angles, table, window_size=window_size,
                         batch_size=BATCH_SIZE, repeat=TRAIN_REPEAT)

model = create_embedding_model(window_size=window_size)

optimizer = Adam(LEARNING_RATE)
# NOTE(review): stock Keras `compile()` takes `loss=`, not `loss_fn=`; this presumably relies on
# a custom compile() defined in src.model — confirm before swapping in a standard Keras model.
model.compile(loss_fn=keep_back, optimizer=optimizer)

# Stop when the loss has not improved by at least 1e-4 for 50 epochs, restoring the best weights.
# No validation data is used, so this monitors the *training* loss.
es = tf.keras.callbacks.EarlyStopping(
    monitor='loss',
    min_delta=1e-4,
    patience=50,
    mode='min',
    restore_best_weights=True,
)

In [62]:
# Preview the companions kept after the SNR cut. The slice keeps the saved output small;
# drop it to see every row. Run `model.summary()` in its own cell to inspect the architecture.
table[:20]

In [None]:
%%time
hist = model.fit(dataset, epochs=10000, callbacks=[es])

In [None]:
# Training-loss curve recorded by `model.fit` (one value per epoch).
fig, ax = plt.subplots()
ax.plot(hist.history['loss'])
ax.set(title='Training loss', xlabel='epoch', ylabel='loss')
plt.show()

In [None]:
# Same cube / psf / angles / table as the training set, with repeat=1 instead of the training
# cell's 20: this is an in-sample check, not a held-out evaluation.
test_ds = create_dataset(
    cube, psf, rot_angles, table,
    window_size=window_size,
    batch_size=2000,
    repeat=1,
)

In [None]:
# The model yields two outputs. `pred` holds per-sample windows that are compared with the
# inputs below; `params` is presumably the fitted parameter output — confirm in src.model.
outputs = model.predict(test_ds)
pred, params = outputs

In [None]:
# Compare one input window with the model's reconstruction and the squared residual.
# NOTE(review): pairing pred[n] with the first batch of a fresh pass over `test_ds` assumes the
# dataset yields samples in the same order every time (no reshuffling in create_dataset).
n = 0  # index of the sample to inspect within the first batch

# NOTE(review): params[-3] is the third-from-last *row* of params; if the intent is the median
# of one parameter across samples, this should probably be params[:, -3] — confirm.
print(np.median(params[-3]))

x, _ = next(iter(test_ds))  # only the first (inputs, targets) batch is needed
win_0 = np.squeeze(np.asarray(x['windows'][n]), axis=-1)  # drop trailing size-1 axis
win_1 = np.asarray(pred[n])
# Guard against silent broadcasting in the subtraction below.
assert win_0.shape == win_1.shape, f'shape mismatch: {win_0.shape} vs {win_1.shape}'

res_square = np.square(win_0 - win_1)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(win_0)
axes[0].set_title('input window')
axes[1].imshow(win_1)
axes[1].set_title('prediction')
axes[2].imshow(res_square)
axes[2].set_title('squared residual (std {:.2f})'.format(float(np.std(res_square))))
plt.show()