In [1]:
import os
import logging

import numpy as np
import xarray as xr
import torch

from tqdm import tqdm_notebook

import pytassim
from pytassim.generator import observation_generator

In [2]:
logger = logging.getLogger()
rnd = np.random.RandomState(42)
data_path = '/scratch/local1/Data/neural_nets/neural_assim/data/'

# General settings

In [3]:
lead_time = 2

# Load the data

In [4]:
train_vr1_path = os.path.join(data_path, 'train_vr1.nc')
train_ens_path = os.path.join(data_path, 'train_ens.nc')
test_vr1_path = os.path.join(data_path, 'test_vr1.nc')
test_ens_path = os.path.join(data_path, 'test_ens.nc')

In [5]:
train_ens = xr.open_dataarray(train_ens_path)
train_vr1 = xr.open_dataarray(train_vr1_path)
test_ens = xr.open_dataarray(test_ens_path)
test_vr1 = xr.open_dataarray(test_vr1_path)

## Obs settings

In [6]:
obs_bias = 0
obs_std = 0.5
obs_var = np.sqrt(obs_std)
obs_random = rnd.normal
obs_indices = [6, 17, 25, 36, 13]
obs_timestep = 1

In [7]:
train_obs = observation_generator(
    train_vr1, obs_random, obs_indices, obs_timestep,
    time_axis='time', loc=obs_bias, scale=obs_std
)
test_obs = observation_generator(
    test_vr1, obs_random, obs_indices, obs_timestep,
    time_axis='time', loc=obs_bias, scale=obs_std
)

100%|██████████| 5/5 [00:00<00:00, 88.21it/s]
100%|██████████| 5/5 [00:00<00:00, 215.29it/s]


In [8]:
obs_cov = xr.DataArray(
    data=np.sqrt(obs_std)*np.identity(len(obs_indices)),
    coords={
        'obs_grid_1': train_obs.obs_grid_1.values,
        'obs_grid_2': train_obs.obs_grid_1.values,
    },
    dims=('obs_grid_1', 'obs_grid_2')
)

train_obs = xr.Dataset(
    {
        'observations': train_obs,
        'covariance': obs_cov
    }
)

test_obs = xr.Dataset(
    {
        'observations': test_obs,
        'covariance': obs_cov
    }
)

In [9]:
def observation_operator(self, state):
    pseudo_obs = state.sel(var_name='x')
    # Nearest neighbor interpolation
    pseudo_obs = pseudo_obs.sel(
        time=self.ds.time.values,
        grid=self.ds.obs_grid_1.values,
        method='nearest')
    pseudo_obs = pseudo_obs.rename(grid='obs_grid_1')
    pseudo_obs['time'] = self.ds.time.values
    pseudo_obs['obs_grid_1'] = self.ds.obs_grid_1.values
    return pseudo_obs

In [10]:
train_hx = train_ens.sel(time=lead_time).drop('time').rename(analysis='time')
train_hx['time'] += lead_time

test_hx = test_ens.sel(time=lead_time).drop('time').rename(analysis='time')
test_hx['time'] += lead_time

In [11]:
train_obs = train_obs.sel(time=train_hx.time.values[:-1])
test_obs = test_obs.sel(time=test_hx.time.values[:-1])

In [12]:
train_obs.obs.operator = observation_operator
test_obs.obs.operator = observation_operator

In [13]:
train_hx = train_obs.obs.operator(train_hx)
test_hx = test_obs.obs.operator(test_hx)

In [14]:
train_mean, train_perts = train_hx.state.split_mean_perts()
test_mean, test_perts = test_hx.state.split_mean_perts()

In [15]:
train_innov = train_obs['observations'] - train_mean
test_innov = test_obs['observations'] - test_mean

# Define model etc

In [16]:
class VAETKF(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gain_wo_op = torch.nn.Linear(
            len(obs_indices), len(obs_indices)
        )
        
    def forward(self, innov, hx_perts):
        ana_mean = torch.matmul(self.gain_wo_op(innov), torch.transpose(hx_perts, dim0=1, dim1=2))
        recon_innov = torch.matmul(ana_mean, hx_perts)
        return ana_mean, recon_innov

In [17]:
model = VAETKF().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [18]:
def loss_function(recon_innov, innov, ana_mean):
    mse_scaled = torch.mean(torch.sum(innov - recon_innov, dim=1)**2 / obs_var)
    kl_scaled = torch.mean(49 * torch.sum(ana_mean, dim=1)**2)
    return mse_scaled + kl_scaled

# Train model

In [19]:
BATCH_SIZE = 64
EPOCHS = 35
train_samples = len(train_innov.time)
nr_iters_p_epoch = train_samples//BATCH_SIZE

In [20]:
train_innov_t = torch.tensor(train_innov.values).float().cuda()
test_innov_t = torch.tensor(test_innov.values).float().cuda()

In [21]:
train_perts_t = torch.tensor(train_perts.values).float().cuda()
test_perts_t = torch.tensor(test_perts.values).float().cuda()

In [24]:
for e in tqdm_notebook(range(EPOCHS)):
    np_ind = rnd.choice(
        train_samples, size=nr_iters_p_epoch*BATCH_SIZE
    )
    train_innov_epoch = train_innov_t[np_ind]
    train_perts_epoch = train_perts_t[np_ind]
    running_loss = 0
    for it in range(nr_iters_p_epoch):
        train_innov_batch = train_innov_epoch[it*BATCH_SIZE:(it+1)*BATCH_SIZE]
        train_perts_batch = train_perts_epoch[it*BATCH_SIZE:(it+1)*BATCH_SIZE]
        
        optimizer.zero_grad()
        ana_mean, recon_innov = model(train_innov_batch, train_perts_batch)
        loss = loss_function(recon_innov, train_innov_batch, ana_mean)
        loss.backward()
        optimizer.step()
    test_ana_mean, test_recon_innov = model(test_innov_t, test_perts_t)
    old_loss = loss.item()
    val_loss = loss_function(test_recon_innov, test_innov_t, test_ana_mean)
    new_loss = val_loss.item()
    print(old_loss)
    print(new_loss)
    print('Finished epoch {0:d}, loss:{1:.2f}'.format(e+1, new_loss))
    if new_loss > old_loss:
        break

HBox(children=(IntProgress(value=0, max=35), HTML(value='')))

18328.435546875
2096.043212890625
Finished epoch 1, loss:2096.04
4789.330078125
1372.94384765625
Finished epoch 2, loss:1372.94
861.02685546875
1625.4134521484375
Finished epoch 3, loss:1625.41



In [33]:
model(test_innov_t, test_perts_t)[0].size()

torch.Size([400, 400, 50])

In [31]:
w_mean.shape

torch.Size([400, 50])