In [None]:
%load_ext autoreload
%autoreload 2

import pickle
import yaml
import copy
import os

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

from netCDF4 import Dataset

import wandb

from dataset import PINNDataset
from model import PINN
from visualization import GeneratorVisualizationCallback

In [None]:
# Select the SVG backend; the figures produced below are written out as .svg files.
matplotlib.use('svg')

# Global rcParams overrides applied to every figure created after this cell.
figure_params = {
    "font.family": 'Times',
    "font.size": 12,
    "font.serif": [],
    # NOTE(review): per the matplotlib docs 'none' keeps text as text in the
    # SVG instead of converting glyphs to paths — confirm this is the intent.
    "svg.fonttype": 'none',
}
matplotlib.rcParams.update(figure_params)

## Load Data and Saved Model

In [None]:
# Id of the W&B run to evaluate. Known runs:
#   "33b55t6g" - model loss
#   "7dj3sms2" - no model loss
run_tag = "33b55t6g"

api = wandb.Api()
run = api.run(f"teisberg/igarss2021/{run_tag}")

# Download everything stored under gen_model/ plus the run's config.yaml
# into downloaded_model_1d/ (replace=True is passed on every download).
for file in run.files():
    wanted = file.name.startswith('gen_model/') or file.name == 'config.yaml'
    if wanted:
        file.download(root='downloaded_model_1d/', replace=True)

In [None]:
# Locations of the files fetched from the W&B run.
model_path = 'downloaded_model_1d/gen_model'
parameter_yaml_filename = 'downloaded_model_1d/config.yaml'

#
# Parameters
#

# NOTE(review): yaml.FullLoader is used on a downloaded file; consider
# yaml.safe_load if the config only holds plain scalars/lists/dicts.
with open(parameter_yaml_filename) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Entries stored as dicts are replaced by their 'value' field; every other
# entry is kept as loaded.
for key, entry in config.items():
    if isinstance(entry, dict):
        config[key] = entry['value']

# Force the 'wandb' setting off for this evaluation notebook.
config['wandb'] = False
print(config)

#
# Data Loading
#

# NOTE(review): pickle.load can execute arbitrary code; only point
# 'input_data_filename' at trusted files.
input_data_path = config['input_data_filename']
with open(input_data_path, 'rb') as f:
    data = pickle.load(f)

dataset = PINNDataset(
    data,
    batch_size=config['batch_size'],
    mode=config['mode'],
    n_random=config['n_random_points'],
)

#
# Create model
#

# presumably gen_model_filename makes PINN restore the downloaded generator
# instead of starting untrained — TODO confirm in model.PINN
model = PINN(config, dataset, gen_model_filename=model_path)

## Whole Domain

In [None]:
# Plot measured, true and predicted ice thickness over the first evaluation
# region and save the figure as figures/results-1d-<run_tag>.svg.
visualizer_callback = GeneratorVisualizationCallback(config, dataset, model)

# Shallow-copy the first evaluation region so that overriding 'spacing'
# does not modify the entry stored in config.
roi = copy.copy(config['eval_regions'][0])
print(roi['title'])
roi['spacing'] = 100  # presumably metres, matching the axis labels — TODO confirm

# Only the first and third return values of roi_prediction are used here.
x, _, pred = visualizer_callback.roi_prediction(roi)

fig, ax = plt.subplots(figsize=(6,2))
ax.scatter(dataset.r['x'], dataset.r['h'], label='Measurements', s=8)
ax.plot(data['truth']['x'], data['truth']['h'], '--', label='True')
ax.plot(x, pred['h'], label='Predicted')

ax.set_xlabel('Distance along flowline [m]')
ax.set_ylabel('Ice thickness [m]')

ax.legend()

# Red outline spanning x 1800-3800 and y 725-800, i.e. the same window the
# zoomed figure uses for its axis limits.
rect = matplotlib.patches.Rectangle((1800, 725),
                                        2000, 75,
                                        linewidth=1, edgecolor='red', facecolor='none')
ax.add_patch(rect)

# savefig does not create missing directories; make sure the output folder
# exists so a fresh checkout does not fail with FileNotFoundError.
os.makedirs('figures', exist_ok=True)
fig.savefig(f'figures/results-1d-{run_tag}.svg', format='svg', dpi=1000)

In [None]:
# Zoomed view of the same three series, restricted to x 1800-3800 and
# y 725-800 (the window outlined in red on the whole-domain figure).
# Reuses x and pred computed by the whole-domain cell.
fig, ax = plt.subplots(figsize=(2,1))
ax.scatter(dataset.r['x'], dataset.r['h'], label='Measurements', s=8)
ax.plot(data['truth']['x'], data['truth']['h'], '--', label='True')
ax.plot(x, pred['h'], label='Predicted')
ax.set_xlim(1800,3800)
ax.set_ylim(725,800)

# savefig does not create missing directories; make sure the output folder
# exists so a fresh checkout does not fail with FileNotFoundError.
os.makedirs('figures', exist_ok=True)
fig.savefig(f'figures/results-1d-{run_tag}-zoom.svg', format='svg', dpi=1000)