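# Kenya Vegetation Health Dataset

Train a model on the Kenya dataset to forecast `VCI3M` (vegetation health) from `precip`, `t2m` and `SMsurf`, then evaluate it on the test period and inspect the resulting forecasts.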

In [1]:
%load_ext autoreload
import xarray as xr
import pandas as pd
import numpy as np
from pathlib import Path
import pickle
import seaborn as sns
import matplotlib.pyplot as plt

base_dir = Path(".").absolute().parents[0]
import sys

if sys.path[0] != base_dir.as_posix():
    sys.path = [base_dir.as_posix()] + sys.path

In [2]:
from spatio_temporal.config import Config
from spatio_temporal.training.trainer import Trainer
from spatio_temporal.training.tester import Tester
from tests.utils import (
    create_linear_ds,
    _test_sklearn_model,
    get_pollution_data_beijing,
    create_test_oxford_run_data
)
from spatio_temporal.training.eval_utils import _plot_loss_curves, save_losses

# Load the data

In [3]:
# Load the pickled Kenya dataset; the `with` block closes the file handle
# once unpickling is done.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with (base_dir / "data/kenya.pkl").open("rb") as pkl_file:
    ds = pickle.load(pkl_file)
ds

# Load the config file and override settings

In [13]:
# Start from configs/kenya.yml and override a few settings for a short CPU run.
# NOTE(review): in the printed config below the validation window
# (2000-01-31 .. 2002-12-31) overlaps the training window, which starts on
# 2002-01-01 — confirm this is intended.
cfg = Config(base_dir / "configs/kenya.yml")
overrides = {
    "n_epochs": 10,
    "autoregressive": True,
    "horizon": 1,
    "seq_length": 3,
    "device": "cpu",
    "scheduler": "cycle",
    # "input_variables": ["precip"],  # uncomment to train on precipitation only
    "target_variable": "VCI3M",
}
# Applied in the order listed above, one key at a time.
for key, value in overrides.items():
    cfg._cfg[key] = value
cfg

{'autoregressive': True,
 'batch_size': 100,
 'data_dir': PosixPath('data/kenya.pkl'),
 'device': 'cpu',
 'experiment_name': 'kenya',
 'hidden_size': 64,
 'horizon': 1,
 'input_variables': ['precip', 't2m', 'SMsurf'],
 'learning_rate': 0.001,
 'loss': 'huber',
 'n_epochs': 10,
 'num_workers': 4,
 'optimizer': 'Adam',
 'pixel_dims': ['lat', 'lon'],
 'run_dir': None,
 'scheduler': 'cycle',
 'seed': 1234,
 'seq_length': 3,
 'target_variable': 'VCI3M',
 'test_end_date': Timestamp('2020-12-31 00:00:00'),
 'test_start_date': Timestamp('2016-01-31 00:00:00'),
 'train_end_date': Timestamp('2015-12-31 00:00:00'),
 'train_start_date': Timestamp('2002-01-01 00:00:00'),
 'validation_end_date': Timestamp('2002-12-31 00:00:00'),
 'validation_start_date': Timestamp('2000-01-31 00:00:00')}

# Create the trainer and tester

In [None]:
# Build a trainer and a tester from the same config and dataset
# (the trainer is constructed first, then the tester).
trainer, tester = Trainer(cfg, ds), Tester(cfg, ds)

  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
Loading Data:  93%|█████████▎| 1458/1575 [00:04<00:00, 341.80it/s]

In [None]:
# Run training and validation. The next cell unpacks the result as
# (train_losses, valid_losses) and plots both against the epoch.
losses = trainer.train_and_validate()

# Check the training and validation losses

In [None]:
# Plot the per-epoch training and validation losses returned by
# `trainer.train_and_validate()`.
train_losses, valid_losses = losses

fig, ax = plt.subplots()
ax.plot(train_losses, label="Train", color="C0", marker="x")
ax.plot(valid_losses, label="Validation", color="C1", marker="x")
ax.set(xlabel="Epoch", ylabel="Loss", title="Training and validation loss")
# Attach the legend to this Axes explicitly (not pyplot's "current" Axes);
# the trailing semicolon hides the Legend object's repr.
ax.legend();

# Run evaluation on the test period

In [None]:
# Reload the configuration saved by a previous training run and build a
# tester from it.
# NOTE(review): the run directory is hard-coded and resolved relative to the
# current working directory (not `base_dir`); point it at the run to evaluate.
RUN_CONFIG_PATH = Path("runs") / "kenya_2302_221323" / "config.yml"
cfg = Config(RUN_CONFIG_PATH)
tester = Tester(cfg, ds)

In [None]:
# Run the tester. The forecast cells below read the newest *.nc file in
# cfg.run_dir — presumably written by this call (TODO confirm).
tester.run_test()

In [None]:
# Inspect the `mean_` statistics held by the test dataset's normalizer
# (needed to un-normalize the forecasts).
tester.test_dl.dataset.normalizer.mean_

In [None]:
# Summary statistics of the saved forecasts. The forecasts are opened here
# rather than relying on `preds` (only assigned in a later cell) so this cell
# works on a fresh kernel. Both statistics are returned together because only
# a cell's last expression is displayed.
nc_paths = sorted(cfg.run_dir.glob("*.nc"))
if not nc_paths:
    raise FileNotFoundError(f"No *.nc forecast files found in {cfg.run_dir}")
forecasts = xr.open_dataset(nc_paths[-1]).isel(horizon=0).drop_vars("horizon")
forecasts.mean(), forecasts.std()

# Inspect the output forecasts

In [None]:
# TODO: reshape forecasts from the `pixel` dimension back to (lat, lon)
# TODO: un-normalize forecasts (e.g. with the test dataset's normalizer)

In [None]:
# Open the last forecast file (in sorted file-name order, presumably the
# newest) from the run directory and keep only the first forecast horizon.
nc_paths = sorted(cfg.run_dir.glob("*.nc"))
if not nc_paths:
    raise FileNotFoundError(f"No *.nc forecast files found in {cfg.run_dir}")
xr_path = nc_paths[-1]
# After `isel` the horizon coordinate is a leftover scalar; `drop_vars`
# replaces the deprecated `Dataset.drop` for removing it.
preds = xr.open_dataset(xr_path).isel(horizon=0).drop_vars("horizon")
preds

In [None]:
# Plot every variable in `preds` over time for a few randomly chosen pixels.
# A seeded generator makes the selection reproducible (1234 is the `seed`
# shown in the training config). Pixels are drawn without replacement, so
# none is plotted twice.
# Tip: chain `.sel(time=slice("2015-01-01", "2015-02-01"))` to zoom in.
N_PIXELS_TO_PLOT = 2
rng = np.random.default_rng(1234)
n_pixels = min(N_PIXELS_TO_PLOT, preds.sizes["pixel"])
for pixel in rng.choice(preds.pixel.values, size=n_pixels, replace=False):
    fig, ax = plt.subplots(figsize=(12, 4))
    preds.sel(pixel=pixel).to_dataframe().plot(ax=ax)
    ax.set_title(str(pixel))
    ax.tick_params(axis="x", labelrotation=70)
    sns.despine(ax=ax)