In [1]:
%load_ext autoreload
%load_ext line_profiler
%autoreload 2

In [2]:
import os
import sys
import numpy as np
import numexpr as ne
import pandas as pd
import torch
import gpytorch
from gpytorch import means, kernels, likelihoods, distributions, lazy
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import xarray as xr
import tqdm
import utils as utils


base_dir = os.path.join(os.getcwd(), '..')
sys.path.append(base_dir)

from src.fair import run, get_params
from src.fair.ancil import get_gas_params, get_thermal_params
from src.preprocessing import load_emissions_dataset, load_response_dataset
from src.structures import Scenario, ScenarioDataset
import utils_spatial as spatial

<IPython.core.display.Javascript object>

In [3]:
# Scenario keys: the model is fit on the training scenarios and evaluated on the held-out ssp245.
train_keys = ['historical', 'ssp126', 'ssp370', 'ssp585']
test_keys = ['ssp245']
keys = [*train_keys, *test_keys]

# One emissions dataset and one response dataset per scenario key.
inputs = dict((key, load_emissions_dataset(f'../data/inputs_{key}.nc')) for key in keys)
outputs = dict((key, load_response_dataset(f'../data/outputs_{key}.nc')) for key in keys)

In [4]:
# Make non-spatial scenarios
def make_scenario(name, hist_scenario=None):
    """Build a non-spatial ``Scenario`` for ``name`` from the loaded ``inputs``/``outputs``.

    The arrays returned by ``utils.extract_arrays`` are converted to float32 torch
    tensors; the emission array is transposed before being stored.

    Args:
        name: scenario key present in both ``inputs`` and ``outputs``.
        hist_scenario: optional historical ``Scenario`` this one is attached to.
    """
    def to_tensor(array):
        return torch.from_numpy(array).float()

    time, _, emission, tas = utils.extract_arrays(inputs[name], outputs[name])
    return Scenario(
        name=name,
        timesteps=to_tensor(time),
        emissions=to_tensor(emission).T,
        tas=to_tensor(tas),
        hist_scenario=hist_scenario,
    )

# Non-spatial scenarios: every SSP is attached to the shared historical scenario.
hist_scenario = make_scenario('historical')
ssps = ['ssp126', 'ssp245', 'ssp370', 'ssp585']
scenarios = dict(historical=hist_scenario)
for name in ssps:
    scenario = make_scenario(name, hist_scenario)
    scenarios[name] = scenario

train_scenarios = ScenarioDataset(
    scenarios=[scenarios[key] for key in train_keys],
    hist_scenario=hist_scenario,
)
test_scenarios = ScenarioDataset(
    scenarios=[scenarios[key] for key in test_keys],
    hist_scenario=hist_scenario,
)

In [5]:
# Compute the global FaIR forcing time series Fdet for each scenario
def compute_forcings(scenario_dataset):
    """Run FaIR on every scenario of ``scenario_dataset`` and collect its summed forcing.

    For each scenario, FaIR is run on the full (historical + scenario) timesteps and
    emissions. ``res['RF'].T`` is summed over its last axis, trimmed with
    ``scenario.trim_hist`` and stored as a float32 tensor.

    Returns:
        dict mapping each ``Scenario`` object to its forcing tensor.
    """
    base_kwargs = get_params()
    forcings = {}
    for scenario in scenario_dataset.scenarios.values():
        fair_output = run(scenario.full_timesteps.numpy(),
                          scenario.full_emissions.T.numpy(),
                          base_kwargs)
        summed_forcing = scenario.trim_hist(fair_output['RF'].T.sum(axis=-1))
        forcings[scenario] = torch.from_numpy(summed_forcing).float()
    return forcings

train_Fdet = compute_forcings(train_scenarios)
test_Fdet = compute_forcings(test_scenarios)

In [6]:
# Create dummy (spatially uniform) grids of the FaIR thermal parameters d_j and q_j:
# every pixel of the lat x lon grid gets the same global per-box values.
base_kwargs = get_params()

d = base_kwargs['d']
q = base_kwargs['q']

# Size of the spatial grid of the emission maps (96 x 144 pixels).
N_LAT, N_LON = 96, 144

# Repeat the per-box vector over a (lat, lon) grid -> shape (N_LAT, N_LON, n_box).
# The number of boxes is taken from the parameters rather than hard-coded to 3.
d_map = np.tile(d, (N_LAT, N_LON, 1))
q_map = np.tile(q, (N_LAT, N_LON, 1))

In [7]:
# Compute spatialised mean thermal responses based on Fdet
def step_temperature(S_old, F, q, d, dt=1):
    """Advance the FaIR thermal boxes by one step of length ``dt``.

    Element-wise update:
        decay = exp(-dt / d)
        S_new = q * F * (1 - decay) + S_old * decay
        T     = sum over the last (box) axis of (S_old + S_new) / 2

    ``q``, ``d`` and ``S_old`` may be 1-D per-box arrays or spatial maps whose box
    axis is last (e.g. (lat, lon, n_box)); ``F`` and ``dt`` broadcast against them.
    Every input goes through ``np.asarray``, so 0-d torch tensors are accepted.

    Args:
        S_old: box state at the previous step.
        F: forcing for this step.
        q, d: per-box thermal parameters (same trailing box axis as ``S_old``).
        dt: step length.

    Returns:
        (S_new, T): the updated box state, and the box-summed mid-step response
        (the box axis removed).
    """
    S_old, F, q, d, dt = (np.asarray(x) for x in (S_old, F, q, d, dt))
    decay_factor = np.exp(-dt / d)
    S_new = q * F * (1 - decay_factor) + S_old * decay_factor
    # Sum over the box axis, which is the LAST axis both for 1-D parameters and for
    # (lat, lon, n_box) maps; summing over axis 0 would collapse latitude instead.
    T = np.sum((S_old + S_new) / 2, axis=-1)
    return S_new, T

def compute_spatial_mjs(Fdet, d_map, q_map, timesteps):
    """Integrate ``step_temperature`` over ``timesteps`` for every pixel and box.

    The step length is the difference between consecutive timesteps, floored at 1
    (so the first step has length 1). Step 0 is updated from an all-zero state.

    Returns:
        float array of shape ``(len(timesteps),) + d_map.shape`` holding the box
        state after each step.
    """
    S = np.zeros((len(timesteps),) + d_map.shape)
    prev_t = timesteps[0]
    for i, t in enumerate(timesteps):
        step_length = max(1, t - prev_t)
        S_prev = S[i - 1] if i > 0 else S[0]
        S[i], _ = step_temperature(S_prev, Fdet[i], q_map, d_map, step_length)
        prev_t = t
    return S

# Box responses of the first training scenario, shape (time, lat, lon, n_box).
Fdet = train_Fdet[train_scenarios[0]]
mjs = compute_spatial_mjs(Fdet=Fdet, d_map=d_map, q_map=q_map,
                          timesteps=train_scenarios[0].timesteps)

In [8]:
# Make spatial scenarios
# Spatial counterparts of the scenarios, all attached to the spatial historical run.
spatial_hist_scenario = spatial.make_scenario(inputs, outputs, 'historical')
ssps = ['ssp126', 'ssp245', 'ssp370', 'ssp585']
spatial_scenarios = dict(historical=spatial_hist_scenario)
for name in ssps:
    scenario = spatial.make_scenario(inputs, outputs, name, spatial_hist_scenario)
    spatial_scenarios[name] = scenario

spatial_train_scenarios = ScenarioDataset(
    scenarios=[spatial_scenarios[key] for key in train_keys],
    hist_scenario=spatial_hist_scenario,
)
spatial_test_scenarios = ScenarioDataset(
    scenarios=[spatial_scenarios[key] for key in test_keys],
    hist_scenario=spatial_hist_scenario,
)

In [9]:
# Kernel intended for the standardised spatial emission maps:
# gpytorch's random-Fourier-feature kernel with 50 samples.
rff = gpytorch.kernels.RFFKernel(num_samples=50)

In [10]:
# Standardise emissions with the spatial training dataset's statistics:
# the whole dataset, and its first scenario on its own.
mu, sigma = spatial_train_scenarios.mu_emissions, spatial_train_scenarios.sigma_emissions
scenario_emissions_std = (spatial_train_scenarios[0].full_emissions - mu) / sigma
dataset_emissions_std = (spatial_train_scenarios.full_emissions - mu) / sigma

In [11]:
# Flatten the standardised (time, lat, lon, channel) emission maps into
# (n_points, n_channels) rows. The channel count is read from the tensor instead of
# hard-coding 4, so a change in channels cannot silently scramble the rows.
foo = dataset_emissions_std.reshape(-1, dataset_emissions_std.size(-1))
bar = scenario_emissions_std.reshape(-1, scenario_emissions_std.size(-1))
# A dense kernel evaluation is infeasible at this size:
# foo.size(0) x bar.size(0) is 12,690,432 x 2,280,960 entries.
# K = rff(foo, bar).evaluate()

In [None]:
# WORK IN PROGRESS: sketch of a version of compute_I_scenario that evaluates the
# kernel one slice at a time instead of materialising the full kernel matrix.
# This cell does not run yet: the `?` placeholders are not valid Python.
# TODO: pick the dataset to integrate over.
scenario_dataset = ?
# TODO: pick the kernel (e.g. `rff`).
kernel = ?

# NOTE(review): mu/sigma are the *inputs* statistics but are applied to
# `full_emissions` below; compute_I_scenario pairs mu_inputs with full_inputs,
# and the standardisation cell uses mu_emissions -- confirm which pair is intended.
mu, sigma = scenario_dataset.mu_inputs, scenario_dataset.sigma_inputs
dataset_emissions_std = (scenario_dataset.full_emissions - mu) / sigma
# Flatten all but the last (channel) axis: (n_points, n_channels).
flat_dataset_emissions_std = dataset_emissions_std.reshape(-1, dataset_emissions_std.size(-1))
flat_dataset_size = flat_dataset_emissions_std.size(0)


for scenario in scenario_dataset.scenarios.values():
    scenario_emissions_std = (scenario.full_emissions - mu) / sigma
    # FIXME: `.size(-1, ...)` should be `.reshape(-1, ...)` to flatten like the
    # dataset above; `.size` does not take these arguments.
    flat_scenario_emissions_std = scenario_emissions_std.size(-1, scenario_emissions_std.size(-1))
    flat_scenario_size = flat_scenario_emissions_std.size(0)

    # Running integral state: (flat dataset points, flat scenario points, len(d)).
    I_old = torch.zeros((flat_dataset_size, flat_scenario_size, len(d)))
    # TODO: placeholder, not used yet.
    Kj = ?

    # FIXME: `scenario_emission_std` is undefined (typo for `scenario_emissions_std`).
    # NOTE(review): after flattening, `flat_scenario_emissions_std[t]` selects a single
    # point (one pixel of one time step), not time step t -- rows need to be grouped
    # per time step first. Also, I_new is not stored anywhere, only carried over,
    # and `step_I` is not defined in the visible notebook.
    for t in range(1, len(scenario_emission_std)):
        flat_scenario_emissions_std_t = flat_scenario_emissions_std[t]
        K_new = kernel(flat_dataset_emissions_std, flat_scenario_emissions_std_t)
        # TODO: `d?` placeholder -- compute_I_scenario passes `d.unsqueeze(0)` here.
        I_new = step_I(I_old, K_new, d?)
        I_old = I_new.squeeze()
        

In [14]:
def compute_I(scenario_dataset, kernel, d):
    """Concatenate ``compute_I_scenario`` over every scenario of ``scenario_dataset``.

    Per-scenario results are joined along their second-to-last axis (the axis that
    indexes the scenario's timesteps), in the dataset's scenario order.
    """
    per_scenario = []
    for scenario in scenario_dataset.scenarios.values():
        per_scenario.append(compute_I_scenario(scenario_dataset, scenario, kernel, d))
    return torch.cat(per_scenario, dim=-2)


def compute_I_scenario(scenario_dataset, scenario, kernel, d):
    """Accumulate the kernel between the dataset and one scenario over its timesteps.

    The dataset's and the scenario's ``full_inputs`` are standardised with the
    dataset's ``mu_inputs`` / ``sigma_inputs``. The dense kernel matrix between them
    (n_dataset x n_scenario_steps) is then integrated step by step along the scenario
    axis with ``step_I``, once per entry of ``d``.

    Args:
        scenario_dataset: dataset providing ``mu_inputs``, ``sigma_inputs``, ``full_inputs``.
        scenario: scenario whose ``full_inputs`` form the second kernel argument.
        kernel: gpytorch kernel; ``kernel(x1, x2).evaluate()`` must give a dense matrix.
        d: tensor of per-box timescales (assumed 1-D -- TODO confirm).

    Returns:
        Tensor of shape (n_dataset, n_scenario_steps, len(d)); column 0 stays zero.
    """
    mu, sigma = scenario_dataset.mu_inputs, scenario_dataset.sigma_inputs
    scenario_emissions_std = (scenario.full_inputs - mu) / sigma
    dataset_emissions_std = (scenario_dataset.full_inputs - mu) / sigma

    # NOTE: materialises the full dense kernel matrix; memory grows as
    # n_dataset * n_scenario_steps.
    K = kernel(dataset_emissions_std, scenario_emissions_std).evaluate().unsqueeze(-1)
    I = torch.zeros((K.size(0), K.size(1), len(d)))
    d_row = d.unsqueeze(0)  # loop-invariant
    for t in range(1, len(scenario_emissions_std)):
        I_new = step_I(I[:, t - 1], K[:, t], d_row)
        # reshape_as rather than squeeze(): squeeze() also drops a genuine size-1
        # box axis (len(d) == 1), and the resulting (N,) tensor cannot be assigned
        # into the (N, 1) slice.
        I[:, t] = I_new.reshape_as(I[:, t])
    return I


def compute_covariance(scenario_dataset, I, q, d):
    """Assemble the per-scenario ``compute_covariance_scenario`` blocks for the dataset.

    Blocks are concatenated along their second-to-last axis in the dataset's
    scenario order, and the result is passed through ``scenario_dataset.trim_hist``.
    """
    blocks = [
        compute_covariance_scenario(scenario_dataset, scenario, I, q, d)
        for scenario in scenario_dataset.scenarios.values()
    ]
    return scenario_dataset.trim_hist(torch.cat(blocks, dim=-2))


def compute_covariance_scenario(scenario_dataset, scenario, I, q, d):
    """Run the ``step_kernel`` recursion over one scenario's rows of ``I``.

    The rows of ``I`` selected by ``scenario_dataset.full_slices[scenario.name]`` are
    accumulated in order along axis 0, starting from zeros. The accumulated tensor
    is passed through ``scenario.trim_hist`` and its first two axes are swapped.
    """
    I_rows = I[scenario_dataset.full_slices[scenario.name]]
    q_row, d_row = q.unsqueeze(0), d.unsqueeze(0)
    acc = torch.zeros_like(I_rows)
    for step in range(1, I_rows.size(0)):
        acc[step] = step_kernel(acc[step - 1], I_rows[step], q_row, d_row)
    return scenario.trim_hist(acc).permute(1, 0, 2)

torch.Size([918, 96, 144, 4])

In [15]:
scenario_emissions_std.size()

torch.Size([165, 96, 144, 4])

In [16]:
foo.size()

torch.Size([12690432, 4])

In [17]:
bar.size()

torch.Size([2280960, 4])

What do I need to do to run a simple experiment:
- Take the usual GP model
- ~~Can I use `ScenarioDataset`? Probably not — that's where I need to start, because it will be useful either way.~~
- Allow d and q to be functions
- Will it be compatible with all the sub-functions? Probably not — need to be careful with that.
- ~~Could evaluate d and q, then flatten into 3 × n_pixel boxes and reshape back. A good trick to avoid writing too much code in case it fails (or even as a first attempt).~~ → Cannot use the code as is because it won't scale (flattened, the dataset-vs-scenario kernel matrix would be 12,690,432 × 2,280,960 entries); maybe start with only a subset of latitudes/longitudes.