In [None]:
from collections import OrderedDict
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import TensorDataset
from torch.optim.lr_scheduler import ExponentialLR
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics import Metric
import pytorch_lightning as pl
from scipy.stats import dirichlet
import numpy as np

def _absolute_error_update(
    preds: Tensor, target: Tensor, omegas: Tensor, area: Tensor
) -> Tensor:
    _check_same_shape(preds, target)
    diff = torch.abs(preds - target)
    sum_abs_error = torch.sum(diff * diff * area, axis=1)
    absolute_error = torch.sum(sum_abs_error * omegas.squeeze())
    return absolute_error


def _absolute_error_compute(absolute_error) -> Tensor:
    return absolute_error


def absolute_error(
    preds: Tensor, target: Tensor, omegas: Tensor, area: Tensor
) -> Tensor:
    """
    Computes squared absolute error
    Args:
        preds: estimated labels
        target: ground truth labels
        omegas: weights
        area: area of each cell
    Return:
        Tensor with absolute error
    Example:
        >>> x = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]).T
        >>> y = torch.tensor([[0, 1, 2, 1], [2, 3, 4, 4]]).T
        >>> o = torch.tensor([0.25, 0.25, 0.3, 0.2])
        >>> a = torch.tensor([0.25, 0.25])
        >>> absolute_error(x, y, o, a)
        tensor(0.4000)
    """
    sum_abs_error = _absolute_error_update(preds, target, omegas, area)
    return _absolute_error_compute(sum_abs_error)


class AbsoluteError(Metric):
    def __init__(self, compute_on_step: bool = True, dist_sync_on_step=False):
        # call `self.add_state`for every internal state that is needed for the metrics computations
        # dist_reduce_fx indicates the function that should be used to reduce
        # state from multiple processes
        super().__init__(
            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step
        )

        self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor, omegas: Tensor, area: Tensor):
        """
        Update state with predictions and targets, and area.
        Args:
            preds: Predictions from model
            target: Ground truth values
            omegas: Weights
            area: Area of each cell
        """
        sum_abs_error = _absolute_error_update(preds, target, omegas, area)
        self.sum_abs_error += sum_abs_error

    def compute(self):
        """
        Computes absolute error over state.
        """
        return _absolute_error_compute(self.sum_abs_error)

    @property
    def is_differentiable(self):
        return True


class NNEmulator(pl.LightningModule):
    def __init__(
        self,
        n_parameters,
        n_eigenglaciers,
        V_hat,
        F_mean,
        area,
        hparams,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(hparams)
        n_hidden_1 = self.hparams.n_hidden_1
        n_hidden_2 = self.hparams.n_hidden_2
        n_hidden_3 = self.hparams.n_hidden_3
        n_hidden_4 = self.hparams.n_hidden_4

        # Inputs to hidden layer linear transformation
        self.l_1 = nn.Linear(n_parameters, n_hidden_1)
        self.norm_1 = nn.LayerNorm(n_hidden_1)
        self.dropout_1 = nn.Dropout(p=0.0)
        self.l_2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.norm_2 = nn.LayerNorm(n_hidden_2)
        self.dropout_2 = nn.Dropout(p=0.5)
        self.l_3 = nn.Linear(n_hidden_2, n_hidden_3)
        self.norm_3 = nn.LayerNorm(n_hidden_3)
        self.dropout_3 = nn.Dropout(p=0.5)
        self.l_4 = nn.Linear(n_hidden_3, n_hidden_4)
        self.norm_4 = nn.LayerNorm(n_hidden_3)
        self.dropout_4 = nn.Dropout(p=0.5)
        self.l_5 = nn.Linear(n_hidden_4, n_eigenglaciers)

        self.V_hat = torch.nn.Parameter(V_hat, requires_grad=False)
        self.F_mean = torch.nn.Parameter(F_mean, requires_grad=False)

        self.register_buffer("area", area)

        self.train_ae = AbsoluteError()
        self.test_ae = AbsoluteError()

    def forward(self, x, add_mean=False):
        # Pass the input tensor through each of our operations

        a_1 = self.l_1(x)
        a_1 = self.norm_1(a_1)
        a_1 = self.dropout_1(a_1)
        z_1 = torch.relu(a_1)

        a_2 = self.l_2(z_1)
        a_2 = self.norm_2(a_2)
        a_2 = self.dropout_2(a_2)
        z_2 = torch.relu(a_2) + z_1

        a_3 = self.l_3(z_2)
        a_3 = self.norm_3(a_3)
        a_3 = self.dropout_3(a_3)
        z_3 = torch.relu(a_3) + z_2

        a_4 = self.l_4(z_3)
        a_4 = self.norm_3(a_4)
        a_4 = self.dropout_3(a_4)
        z_4 = torch.relu(a_4) + z_3

        z_5 = self.l_5(z_4)
        if add_mean:
            F_pred = z_5 @ self.V_hat.T + self.F_mean
        else:
            F_pred = z_5 @ self.V_hat.T

        return F_pred

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("NNEmulator")
        parser.add_argument("--batch_size", type=int, default=128)
        parser.add_argument("--n_hidden_1", type=int, default=128)
        parser.add_argument("--n_hidden_2", type=int, default=128)
        parser.add_argument("--n_hidden_3", type=int, default=128)
        parser.add_argument("--n_hidden_4", type=int, default=128)
        parser.add_argument("--learning_rate", type=float, default=0.01)

        return parent_parser

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), self.hparams.learning_rate, weight_decay=0.0
        )
        # This is an approximation to Doug's version:
        scheduler = {
            "scheduler": ExponentialLR(optimizer, 0.9975, verbose=True),
        }

        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, f, o, _ = batch
        f_pred = self.forward(x)
        loss = absolute_error(f_pred, f, o, self.area)

        return loss

    def validation_step(self, batch, batch_idx):
        x, f, o, o_0 = batch
        f_pred = self.forward(x)

        self.log("train_loss", self.train_ae(f_pred, f, o, self.area))
        self.log("test_loss", self.test_ae(f_pred, f, o_0, self.area))

        return {"x": x, "f": f, "f_pred": f_pred, "o": o, "o_0": o_0}

    def validation_epoch_end(self, outputs):

        self.log(
            "train_loss",
            self.train_ae,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.log(
            "test_loss",
            self.test_ae,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )


# As NNEmulator but number of hidden layers can be specified 
# with n_hidden_layers
class DNNEmulator(pl.LightningModule):
    def __init__(
        self,
        n_parameters: int,
        n_eigenglaciers: int,
        V_hat: Tensor,
        F_mean: Tensor,
        area: Tensor,
        hparams,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(hparams)
        n_layers = self.hparams.n_layers
        n_hidden = self.hparams.n_hidden

        if isinstance(n_hidden, int):
            n_hidden = [n_hidden] * (n_layers - 1)

        # Inputs to hidden layer linear transformation
        self.l_first = nn.Linear(n_parameters, n_hidden[0])
        self.norm_first = nn.LayerNorm(n_hidden[0])
        self.dropout_first = nn.Dropout(p=0.0)

        models = []
        for n in range(n_layers - 2):
            models.append(
                nn.Sequential(
                    OrderedDict(
                        [
                            ("Linear", nn.Linear(n_hidden[n], n_hidden[n + 1])),
                            ("LayerNorm", nn.LayerNorm(n_hidden[n + 1])),
                            ("Dropout", nn.Dropout(p=0.1)),
                        ]
                    )
                )
            )
        self.dnn = nn.ModuleList(models)
        self.l_last = nn.Linear(n_hidden[-1], n_eigenglaciers)

        self.V_hat = torch.nn.Parameter(V_hat, requires_grad=False)
        self.F_mean = torch.nn.Parameter(F_mean, requires_grad=False)

        self.register_buffer("area", area)

        self.train_ae = AbsoluteError()
        self.test_ae = AbsoluteError()

    def forward(self, x, add_mean=False):
        # Pass the input tensor through each of our operations

        a = self.l_first(x)
        a = self.norm_first(a)
        a = self.dropout_first(a)
        z = torch.relu(a)

        for dnn in self.dnn:
            a = dnn(z)
            z = torch.relu(a) + z

        z_last = self.l_last(z)

        if add_mean:
            F_pred = z_last @ self.V_hat.T + self.F_mean
        else:
            F_pred = z_last @ self.V_hat.T

        return F_pred

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("NNEmulator")
        parser.add_argument("--batch_size", type=int, default=128)
        parser.add_argument("--n_hidden", default=128)
        parser.add_argument("--learning_rate", type=float, default=0.01)

        return parent_parser

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), self.hparams.learning_rate, weight_decay=0.0
        )
        # This is an approximation to Doug's version:
        scheduler = {
            "scheduler": ExponentialLR(optimizer, 0.9975, verbose=True),
        }

        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, f, o, _ = batch
        f_pred = self.forward(x)
        loss = absolute_error(f_pred, f, o, self.area)

        return loss

    def validation_step(self, batch, batch_idx):
        x, f, o, o_0 = batch
        f_pred = self.forward(x)

        self.log("train_loss", self.train_ae(f_pred, f, o, self.area))
        self.log("test_loss", self.test_ae(f_pred, f, o_0, self.area))

        return {"x": x, "f": f, "f_pred": f_pred, "o": o, "o_0": o_0}

    def validation_epoch_end(self, outputs):

        self.log(
            "train_loss",
            self.train_ae,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.log(
            "test_loss",
            self.test_ae,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )


        
max_epochs = 100
num_workers = 4
hparams = {"n_hidden": 128, 
           "n_hidden_1": 128, 
           "n_hidden_2": 128, 
           "n_hidden_3": 128, 
           "n_hidden_4": 128, 
           "n_layers": 5,
           "learning_rate": 0.01}        

n_eigenglaciers = 100
n_samples = 979
n_parameters = 8
n_grid_points = 5097
X_train = torch.randn(n_samples, n_parameters)
Y_train = torch.randn(n_samples, n_grid_points)
V_hat = torch.randn(n_grid_points, n_eigenglaciers)
F_mean = torch.randn(n_grid_points)
area = torch.ones_like(F_mean) / n_grid_points

omegas = torch.Tensor(dirichlet.rvs(np.ones(n_samples))).T
omegas = omegas.type_as(X_train)
omegas_0 = torch.ones_like(omegas) / len(omegas)

training_data = TensorDataset(X_train, Y_train, omegas, omegas_0)

batch_size = 128
train_loader = torch.utils.data.DataLoader(dataset=training_data,
                                           batch_size=batch_size,
                                           shuffle=True)

# train and val data loader are the same because we use BayesBag/Bootstrapping to avoid overfitting
# by generating 50 emulators, each with different weights "omegas"

trainer_e = pl.Trainer(
    deterministic=True,
    num_sanity_val_steps=0,
    max_epochs=max_epochs,
)

e = NNEmulator(
    n_parameters,
    n_eigenglaciers,
    V_hat,
    F_mean,
    area,
    hparams,
)

trainer_e.fit(e, train_loader, train_loader)

trainer_de = pl.Trainer(
    deterministic=True,
    num_sanity_val_steps=0,
    max_epochs=max_epochs,
)

de = DNNEmulator(
    n_parameters,
    n_eigenglaciers,
    V_hat,
    F_mean,
    area,
    hparams,
)

trainer_de.fit(de, train_loader, train_loader)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

   | Name      | Type          | Params
---------------------------------------------
0  | l_1       | Linear        | 1.2 K 
1  | norm_1    | LayerNorm     | 256   
2  | dropout_1 | Dropout       | 0     
3  | l_2       | Linear        | 16.5 K
4  | norm_2    | LayerNorm     | 256   
5  | dropout_2 | Dropout       | 0     
6  | l_3       | Linear        | 16.5 K
7  | norm_3    | LayerNorm     | 256   
8  | dropout_3 | Dropout       | 0     
9  | l_4       | Linear        | 16.5 K
10 | norm_4    | LayerNorm     | 256   
11 | dropout_4 | Dropout       | 0     
12 | l_5       | Linear        | 12.9 K
13 | train_ae  | AbsoluteError | 0     
14 | test_ae   | AbsoluteError | 0     
---------------------------------------------
64.6 K    Trainable params
514 K     Non-trainable params
579 K     Total params
2.318     Total estimated model params size (MB)


Adjusting learning rate of group 0 to 1.0000e-02.
Epoch 0:  44%|███████████████████████████████████████████████████████▌                                                                       | 7/16 [00:00<00:00, 106.33it/s, loss=20.5, v_num=14]Adjusting learning rate of group 0 to 9.9750e-03.
Epoch 0:  50%|███████████████████████████████████████████████████████████████▌                                                               | 8/16 [00:00<00:00, 110.99it/s, loss=18.5, v_num=14]
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 124.20it/s, loss=18.5, v_num=14, train_loss=32.50, test_loss=32.40][A
Epoch 1:  44%|████████████████████████████████████████▎                                                   | 7/16 [00:00<00:00, 112.86it/s, loss=12.3, v_num=14, train_loss=32.50, test_loss=32.40][AAdjusting learning rate of group 0 to 9.9501e-03.
Epoch 1:  50%|██████████████████

Epoch 12:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 117.17it/s, loss=0.246, v_num=14, train_loss=1.310, test_loss=1.320][AAdjusting learning rate of group 0 to 9.6798e-03.
Epoch 12:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 121.35it/s, loss=0.241, v_num=14, train_loss=1.310, test_loss=1.320]
Validating: 0it [00:00, ?it/s][A
Epoch 12: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 129.86it/s, loss=0.241, v_num=14, train_loss=1.280, test_loss=1.290][A
Epoch 13:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 119.09it/s, loss=0.231, v_num=14, train_loss=1.280, test_loss=1.290][AAdjusting learning rate of group 0 to 9.6556e-03.
Epoch 13:  50%|█████████████████████████████████████████████                   

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Adjusting learning rate of group 0 to 1.0000e-02.



  | Name          | Type          | Params
------------------------------------------------
0 | l_first       | Linear        | 1.2 K 
1 | norm_first    | LayerNorm     | 256   
2 | dropout_first | Dropout       | 0     
3 | dnn           | ModuleList    | 50.3 K
4 | l_last        | Linear        | 12.9 K
5 | train_ae      | AbsoluteError | 0     
6 | test_ae       | AbsoluteError | 0     
------------------------------------------------
64.6 K    Trainable params
514 K     Non-trainable params
579 K     Total params
2.318     Total estimated model params size (MB)


Epoch 0:  44%|███████████████████████████████████████████████████████▌                                                                       | 7/16 [00:00<00:00, 112.30it/s, loss=18.7, v_num=15]Adjusting learning rate of group 0 to 9.9750e-03.
Epoch 0:  50%|███████████████████████████████████████████████████████████████▌                                                               | 8/16 [00:00<00:00, 117.79it/s, loss=16.9, v_num=15]
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 119.55it/s, loss=16.9, v_num=15, train_loss=34.40, test_loss=34.50][A
Epoch 1:  44%|████████████████████████████████████████▎                                                   | 7/16 [00:00<00:00, 122.04it/s, loss=10.8, v_num=15, train_loss=34.40, test_loss=34.50][AAdjusting learning rate of group 0 to 9.9501e-03.
Epoch 1:  50%|██████████████████████████████████████████████                      

Epoch 12:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 113.10it/s, loss=0.282, v_num=15, train_loss=1.420, test_loss=1.440][AAdjusting learning rate of group 0 to 9.6798e-03.
Epoch 12:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 117.88it/s, loss=0.273, v_num=15, train_loss=1.420, test_loss=1.440]
Validating: 0it [00:00, ?it/s][A
Epoch 12: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 134.33it/s, loss=0.273, v_num=15, train_loss=1.370, test_loss=1.400][A
Epoch 13:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 121.22it/s, loss=0.265, v_num=15, train_loss=1.370, test_loss=1.400][AAdjusting learning rate of group 0 to 9.6556e-03.
Epoch 13:  50%|█████████████████████████████████████████████                   

Epoch 24:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 107.29it/s, loss=0.192, v_num=15, train_loss=1.170, test_loss=1.190][AAdjusting learning rate of group 0 to 9.3934e-03.
Epoch 24:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 111.47it/s, loss=0.188, v_num=15, train_loss=1.170, test_loss=1.190]
Validating: 0it [00:00, ?it/s][A
Epoch 24: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 130.46it/s, loss=0.188, v_num=15, train_loss=1.160, test_loss=1.180][A
Epoch 25:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 115.36it/s, loss=0.183, v_num=15, train_loss=1.160, test_loss=1.180][AAdjusting learning rate of group 0 to 9.3699e-03.
Epoch 25:  50%|█████████████████████████████████████████████                   

In [22]:
Y_pred_e = e(X_train, add_mean=True)
Y_pred_de = de(X_train, add_mean=True)
print(torch.allclose(Y_pred_e, Y_pred_de, rtol=1e-3))

False


In [23]:
Y_pred_e

tensor([[ 0.5687, -0.7332,  0.2741,  ..., -1.4356, -2.9116,  1.3660],
        [ 0.1260, -0.3927,  0.9446,  ..., -0.5491, -2.8917,  1.5394],
        [ 1.7578, -1.1101,  1.3278,  ..., -3.0206, -3.5944,  2.6618],
        ...,
        [ 0.7693, -0.9358,  0.5076,  ..., -2.9734, -3.7563,  2.7836],
        [ 0.0236, -1.8986, -0.9425,  ..., -1.8011, -2.8463,  2.2073],
        [ 0.6712, -0.8488,  0.3387,  ..., -2.4844, -3.3133,  1.8092]],
       grad_fn=<AddBackward0>)

In [24]:
Y_pred_de

tensor([[ 0.5963, -0.7897,  1.4969,  ..., -1.6196, -2.9901,  4.3355],
        [ 1.5535,  0.0247, -0.0095,  ..., -1.5606, -4.7245,  2.0926],
        [ 0.0278, -1.4039,  1.0444,  ..., -1.6318, -3.6089,  2.5104],
        ...,
        [ 1.0046, -0.8259, -2.0149,  ..., -0.6841, -3.9244,  1.2808],
        [ 1.0927,  0.3701, -0.9025,  ..., -1.2624, -2.7328,  3.2205],
        [ 1.8100,  0.0811,  1.1136,  ..., -2.6391, -3.6600,  2.6641]],
       grad_fn=<AddBackward0>)