In [1]:
%%capture
!pip install xbatcher tensorboard

In [2]:
import os

# Step up one directory when the notebook was started from a path mentioning
# "notebooks" (presumably so that `import util` below resolves from the project root).
# NOTE(review): this is a substring match on the whole absolute path, so any
# ancestor directory containing "notebooks" also triggers it.
if 'notebooks' in os.getcwd():
    os.chdir(os.pardir)

import util
import torch
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from importlib import reload



Prepare data

In [3]:
from google.cloud.storage import Client

# Google Cloud project used to authenticate the storage client.
GCS_PROJECT = "forest-lst"

client = Client(project=GCS_PROJECT)
# The two string arguments are presumably the bucket name and the object
# name/prefix -- confirm against util.gcs.read_gcs_csv.
df = util.gcs.read_gcs_csv(
    client,
    "preisler_tfdata",
    "preisler-rectangular",
)

In [4]:
# Remove the `system:index` and `.geo` columns, index the remaining columns by
# (year, latitude, longitude), and convert the table to an xarray Dataset.
ds = (
    df.drop(columns=["system:index", ".geo"])
      .set_index(["year", "latitude", "longitude"])
      .to_xarray()
)

In [5]:
ds

Verify that windowing works correctly

In [6]:
# Window specification per dimension. The first entry looks like the window
# length along that dimension (4 x 4 pixels, 5 years = 4 input years + 1 target
# year); the meaning of the boolean flag is defined by WindowXarrayDataset --
# TODO confirm.
window = {
    "latitude":  [4, False],
    "longitude": [4, False],
    "year":      [5, False],
}
mort_ds = util.training.WindowXarrayDataset(ds.pct_mortality, window)

In [7]:
len(mort_ds)

121750

In [8]:
X, y = mort_ds[10]

In [9]:
y.shape

torch.Size([4, 4])

In [10]:
# Rebuild sample 10 by hand from its raw window and check that it matches what
# the dataset returned: all years except the last are the input, the last year
# is the target.
mort_window = mort_ds._get_window(10)
window_x = mort_window.isel({"year": slice(None, -1)})
window_y = mort_window.isel({"year": -1})
assert np.allclose(window_x.values, X)
assert np.allclose(window_y.values, y)

Set up model definition

In [11]:
reload(util.training)

<module 'util.training_torch' from '/home/jovyan/ForestLST/util/training_torch.py'>

In [12]:
class DamageConvLSTM(torch.nn.Module):
    '''
    ConvLSTM regressor mapping a time stack of single-channel images to one image.

    ``forward`` takes a tensor WITHOUT a channel axis -- (N, T, H, W) when
    ``batch_first=True`` -- inserts a singleton channel axis at position 2,
    runs the ConvLSTM, collapses the hidden channels of the last layer with a
    1x1 convolution, applies a sigmoid and returns a tensor of shape (N, H, W)
    with values in (0, 1).

    Because ``forward`` always adds exactly one channel, ``input_dim`` is
    expected to be 1.

    Parameters
    ----------
    input_dim, hidden_dim, kernel_size, num_layers, batch_first, bias,
    return_all_layers
        Forwarded unchanged to ``ConvLSTM``. ``hidden_dim`` may be an int or a
        per-layer list/tuple; in the latter case the final layer's size feeds
        the 1x1 output convolution.
    '''
    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super().__init__()
        self.convlstm = ConvLSTM(input_dim, hidden_dim, kernel_size, num_layers,
                                 batch_first=batch_first, bias=bias,
                                 return_all_layers=return_all_layers)

        # The output convolution consumes the hidden state of the LAST layer,
        # so a per-layer list of hidden sizes is reduced to its final entry.
        if isinstance(hidden_dim, (list, tuple)):
            out_channels = hidden_dim[-1]
        else:
            out_channels = hidden_dim

        self.conv    = torch.nn.Conv2d(out_channels, 1, kernel_size=1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, X):
        # Add a channel axis
        X = X.unsqueeze(2)
        # The ConvLSTM returns (layer_outputs, last_states); each state is an
        # (h, c) pair. Index the final entry so the deepest layer's hidden
        # state is used even when return_all_layers=True yields one entry per
        # layer (assumes entries are ordered first-to-last layer -- confirm
        # against util.convlstm).
        last_states = self.convlstm(X)[1]
        X = last_states[-1][0]
        # Convolve out the hidden dimensions
        X = self.conv(X)
        # Pass to sigmoid
        X = self.sigmoid(X)
        # Drop channel axis
        return X.squeeze(1)
        

In [13]:
from util.convlstm import ConvLSTM

# Hyperparameters for the shape-check model.
input_channel = 1   # DamageConvLSTM.forward adds exactly one channel axis
hidden_dim = 16
num_layers = 3
kernel = (3, 3)

m = DamageConvLSTM(
    input_dim=input_channel,
    hidden_dim=hidden_dim,
    kernel_size=kernel,
    num_layers=num_layers,
    batch_first=True,
)

Verify that the shapes all work out.

In [14]:
m(X.unsqueeze(0)) # with a fake batch dimension

tensor([[[0.4584, 0.4583, 0.4583, 0.4585],
         [0.4585, 0.4582, 0.4582, 0.4583],
         [0.4585, 0.4582, 0.4582, 0.4583],
         [0.4585, 0.4582, 0.4581, 0.4582]]], grad_fn=<SqueezeBackward1>)

Data pipeline

In [15]:
# Prepare data loaders. Every window looks back over 5 years, so the train /
# validation / test sets must cover temporally disjoint 5-year blocks to avoid
# leaking data between them. 2020 had very few surveys, so it is left out.
years = ds.year.values
block_a, block_b, block_c, block_d = (years[start:start + 5] for start in range(0, 20, 5))

train_years = np.concatenate([block_a, block_c])
valid_years = block_b
test_years  = block_d

print("Training years:", train_years)
print("Validation years:", valid_years)
print("Testing years:", test_years)

Training years: [2000 2001 2002 2003 2004 2010 2011 2012 2013 2014]
Validation years: [2005 2006 2007 2008 2009]
Testing years: [2015 2016 2017 2018 2019]


In [16]:
# One windowed dataset per split, built in train / valid / test order.
# Possible extension: accept windows with up to 20% NA and zero-fill those
# values in the dataloader, which would increase the number of samples.
train_wds, valid_wds, test_wds = (
    util.training.WindowXarrayDataset(ds.sel(year=split_years).pct_mortality, window)
    for split_years in (train_years, valid_years, test_years)
)

print("Count of examples")
print("Training:", len(train_wds))
print("Validation:", len(valid_wds))
print("Testing:", len(test_wds))

Count of examples
Training: 35820
Validation: 7304
Testing: 6588


In [17]:
from torch.utils.data import DataLoader

batch_size = 16

# Only the training stream is reshuffled each epoch. Validation and test data
# are iterated in a fixed order so that evaluation batches (including the
# smaller final batch) are identical from run to run.
train_loader = DataLoader(train_wds, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_wds, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_wds,  batch_size=batch_size, shuffle=False)

In [18]:
# Pull a single training batch and confirm that the model output has the same
# shape as the target.
X, y = next(iter(train_loader))
print(X.shape)
out = m(X)
print(out.shape)
assert out.shape == y.shape

torch.Size([16, 4, 4, 4])
torch.Size([16, 4, 4])


Training loop

In [19]:
reload(util.training)

<module 'util.training_torch' from '/home/jovyan/ForestLST/util/training_torch.py'>

In [20]:
import torchmetrics

model_name = "convlstm__4_4_5__1channel__temporal"

def image_mse_loss(output, target):
    """Return the squared error between `output` and `target`, summed over axis 0.

    Despite the name no mean is taken: the result keeps every axis except the
    first, holding the sum of squared differences along that first axis.
    """
    squared_error = torch.square(output - target)
    return squared_error.sum(dim=0)

# Loss, model, optimiser and metrics for the training run.
loss = torch.nn.MSELoss()
# NOTE(review): the hidden size is 8 here, not the `hidden_dim` value (16) used
# for the earlier shape check -- confirm this is intentional.
m = DamageConvLSTM(
    input_dim=input_channel,
    hidden_dim=8,
    kernel_size=kernel,
    num_layers=num_layers,
    batch_first=True,
)
opt = torch.optim.Adam(m.parameters(), lr=5e-3)
# torchmetrics.regression.R2Score() is a candidate metric that is currently disabled.
metrics = [torchmetrics.regression.ExplainedVariance()]

# Both the tensorboard history and the saved model live under logs/<model_name>/.
log_dir = os.path.join("logs", model_name)
trainer = util.training.BaseTrainer(
    m, opt, loss, train_loader, valid_loader,
    metrics=metrics,
    n_epochs=20,
    tensorboard_log=os.path.join(log_dir, "history"),
    model_log=os.path.join(log_dir, "model.pth"),
)

In [21]:
trainer.train()

Epoch 1/20
                               Train            Valid
ExplainedVariance()  tensor(-0.0570)  tensor(-0.0002)
Loss                        0.011833   tensor(0.0039)

Epoch 2/20
                               Train            Valid
ExplainedVariance()  tensor(-0.0015)  tensor(-0.0001)
Loss                        0.011209   tensor(0.0039)

Epoch 3/20
                               Train                Valid
ExplainedVariance()  tensor(-0.0013)  tensor(-6.7234e-05)
Loss                        0.011207       tensor(0.0039)

Epoch 4/20
                               Train                Valid
ExplainedVariance()  tensor(-0.0012)  tensor(-1.4134e-05)
Loss                        0.011206       tensor(0.0038)

Epoch 5/20
                               Train               Valid
ExplainedVariance()  tensor(-0.0010)  tensor(1.5818e-05)
Loss                        0.011204      tensor(0.0040)

Epoch 6/20
                               Train               Valid
ExplainedVariance()  tensor(-

In [23]:
# Evaluate the trained model on the held-out test years.
# NOTE: `r2` (and the "R2:" label below) actually hold EXPLAINED VARIANCE, which
# only coincides with R2 when the residuals have zero mean. The names are kept
# so the printed report stays comparable with earlier runs.
r2  = torchmetrics.regression.ExplainedVariance()
mse = torchmetrics.regression.MeanSquaredError()

# Switch to inference behaviour (affects layers such as dropout / batch norm).
# The model is left in eval mode afterwards.
trainer._model.eval()

with torch.no_grad():
    for (X, y) in test_loader:
        y_hat = trainer._model(X)
        # update() only accumulates state; the per-batch value that calling the
        # metric directly would also compute is not needed here.
        r2.update(y_hat, y)
        mse.update(y_hat, y)

print("Test results")
print("R2:", r2.compute())
print("MSE:", mse.compute())

Test results
R2: tensor(0.1393)
MSE: tensor(0.0227)
