In [1]:
from collections import OrderedDict
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import TensorDataset
from torch.optim.lr_scheduler import ExponentialLR
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics import Metric
import pytorch_lightning as pl
from scipy.stats import dirichlet
import numpy as np

def _absolute_error_update(
    preds: Tensor, target: Tensor, omegas: Tensor, area: Tensor
) -> Tensor:
    _check_same_shape(preds, target)
    diff = torch.abs(preds - target)
    sum_abs_error = torch.sum(diff * diff * area, axis=1)
    absolute_error = torch.sum(sum_abs_error * omegas.squeeze())
    return absolute_error


def _absolute_error_compute(absolute_error) -> Tensor:
    return absolute_error


def absolute_error(
    preds: Tensor, target: Tensor, omegas: Tensor, area: Tensor
) -> Tensor:
    """
    Computes squared absolute error
    Args:
        preds: estimated labels
        target: ground truth labels
        omegas: weights
        area: area of each cell
    Return:
        Tensor with absolute error
    Example:
        >>> x = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]).T
        >>> y = torch.tensor([[0, 1, 2, 1], [2, 3, 4, 4]]).T
        >>> o = torch.tensor([0.25, 0.25, 0.3, 0.2])
        >>> a = torch.tensor([0.25, 0.25])
        >>> absolute_error(x, y, o, a)
        tensor(0.4000)
    """
    sum_abs_error = _absolute_error_update(preds, target, omegas, area)
    return _absolute_error_compute(sum_abs_error)


class AbsoluteError(Metric):
    def __init__(self, compute_on_step: bool = True, dist_sync_on_step=False):
        # call `self.add_state`for every internal state that is needed for the metrics computations
        # dist_reduce_fx indicates the function that should be used to reduce
        # state from multiple processes
        super().__init__(
            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step
        )

        self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor, omegas: Tensor, area: Tensor):
        """
        Update state with predictions and targets, and area.
        Args:
            preds: Predictions from model
            target: Ground truth values
            omegas: Weights
            area: Area of each cell
        """
        sum_abs_error = _absolute_error_update(preds, target, omegas, area)
        self.sum_abs_error += sum_abs_error

    def compute(self):
        """
        Computes absolute error over state.
        """
        return _absolute_error_compute(self.sum_abs_error)

    @property
    def is_differentiable(self):
        return True


class NNEmulator(pl.LightningModule):
    def __init__(
        self,
        n_parameters,
        n_eigenglaciers,
        V_hat,
        F_mean,
        area,
        hparams,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(hparams)
        n_hidden_1 = self.hparams.n_hidden_1
        n_hidden_2 = self.hparams.n_hidden_2
        n_hidden_3 = self.hparams.n_hidden_3
        n_hidden_4 = self.hparams.n_hidden_4

        # Inputs to hidden layer linear transformation
        self.l_1 = nn.Linear(n_parameters, n_hidden_1)
        self.norm_1 = nn.LayerNorm(n_hidden_1)
        self.dropout_1 = nn.Dropout(p=0.0)
        self.l_2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.norm_2 = nn.LayerNorm(n_hidden_2)
        self.dropout_2 = nn.Dropout(p=0.5)
        self.l_3 = nn.Linear(n_hidden_2, n_hidden_3)
        self.norm_3 = nn.LayerNorm(n_hidden_3)
        self.dropout_3 = nn.Dropout(p=0.5)
        self.l_4 = nn.Linear(n_hidden_3, n_hidden_4)
        self.norm_4 = nn.LayerNorm(n_hidden_3)
        self.dropout_4 = nn.Dropout(p=0.5)
        self.l_5 = nn.Linear(n_hidden_4, n_eigenglaciers)

        self.V_hat = torch.nn.Parameter(V_hat, requires_grad=False)
        self.F_mean = torch.nn.Parameter(F_mean, requires_grad=False)

        self.register_buffer("area", area)

        self.train_ae = AbsoluteError()
        self.test_ae = AbsoluteError()

    def forward(self, x, add_mean=False):
        # Pass the input tensor through each of our operations

        a_1 = self.l_1(x)
        a_1 = self.norm_1(a_1)
        a_1 = self.dropout_1(a_1)
        z_1 = torch.relu(a_1)

        a_2 = self.l_2(z_1)
        a_2 = self.norm_2(a_2)
        a_2 = self.dropout_2(a_2)
        z_2 = torch.relu(a_2) + z_1

        a_3 = self.l_3(z_2)
        a_3 = self.norm_3(a_3)
        a_3 = self.dropout_3(a_3)
        z_3 = torch.relu(a_3) + z_2

        a_4 = self.l_4(z_3)
        a_4 = self.norm_3(a_4)
        a_4 = self.dropout_3(a_4)
        z_4 = torch.relu(a_4) + z_3

        z_5 = self.l_5(z_4)
        if add_mean:
            F_pred = z_5 @ self.V_hat.T + self.F_mean
        else:
            F_pred = z_5 @ self.V_hat.T

        return F_pred

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("NNEmulator")
        parser.add_argument("--batch_size", type=int, default=128)
        parser.add_argument("--n_hidden_1", type=int, default=128)
        parser.add_argument("--n_hidden_2", type=int, default=128)
        parser.add_argument("--n_hidden_3", type=int, default=128)
        parser.add_argument("--n_hidden_4", type=int, default=128)
        parser.add_argument("--learning_rate", type=float, default=0.01)

        return parent_parser

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), self.hparams.learning_rate, weight_decay=0.0
        )
        # This is an approximation to Doug's version:
        scheduler = {
            "scheduler": ExponentialLR(optimizer, 0.9975, verbose=True),
        }

        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, f, o, _ = batch
        f_pred = self.forward(x)
        loss = absolute_error(f_pred, f, o, self.area)

        return loss

    def validation_step(self, batch, batch_idx):
        x, f, o, o_0 = batch
        f_pred = self.forward(x)

        self.log("train_loss", self.train_ae(f_pred, f, o, self.area))
        self.log("test_loss", self.test_ae(f_pred, f, o_0, self.area))

        return {"x": x, "f": f, "f_pred": f_pred, "o": o, "o_0": o_0}

    def validation_epoch_end(self, outputs):

        self.log(
            "train_loss",
            self.train_ae,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.log(
            "test_loss",
            self.test_ae,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )


# As NNEmulator but number of hidden layers can be specified 
# with n_hidden_layers
class DNNEmulator(pl.LightningModule):
    def __init__(
        self,
        n_parameters: int,
        n_eigenglaciers: int,
        V_hat: Tensor,
        F_mean: Tensor,
        area: Tensor,
        hparams,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(hparams)
        n_layers = self.hparams.n_layers
        n_hidden = self.hparams.n_hidden

        if isinstance(n_hidden, int):
            n_hidden = [n_hidden] * (n_layers - 1)

        # Inputs to hidden layer linear transformation
        self.l_first = nn.Linear(n_parameters, n_hidden[0])
        self.norm_first = nn.LayerNorm(n_hidden[0])
        self.dropout_first = nn.Dropout(p=0.0)

        models = []
        for n in range(n_layers - 2):
            models.append(
                nn.Sequential(
                    OrderedDict(
                        [
                            ("Linear", nn.Linear(n_hidden[n], n_hidden[n + 1])),
                            ("LayerNorm", nn.LayerNorm(n_hidden[n + 1])),
                            ("Dropout", nn.Dropout(p=0.1)),
                        ]
                    )
                )
            )
        self.dnn = nn.ModuleList(models)
        self.l_last = nn.Linear(n_hidden[-1], n_eigenglaciers)

        self.V_hat = torch.nn.Parameter(V_hat, requires_grad=False)
        self.F_mean = torch.nn.Parameter(F_mean, requires_grad=False)

        self.register_buffer("area", area)

        self.train_ae = AbsoluteError()
        self.test_ae = AbsoluteError()

    def forward(self, x, add_mean=False):
        # Pass the input tensor through each of our operations

        a = self.l_first(x)
        a = self.norm_first(a)
        a = self.dropout_first(a)
        z = torch.relu(a)

        for dnn in self.dnn:
            a = dnn(z)
            z = torch.relu(a) + z

        z_last = self.l_last(z)

        if add_mean:
            F_pred = z_last @ self.V_hat.T + self.F_mean
        else:
            F_pred = z_last @ self.V_hat.T

        return F_pred

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("NNEmulator")
        parser.add_argument("--batch_size", type=int, default=128)
        parser.add_argument("--n_hidden", default=128)
        parser.add_argument("--learning_rate", type=float, default=0.01)

        return parent_parser

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), self.hparams.learning_rate, weight_decay=0.0
        )
        # This is an approximation to Doug's version:
        scheduler = {
            "scheduler": ExponentialLR(optimizer, 0.9975, verbose=True),
        }

        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, f, o, _ = batch
        f_pred = self.forward(x)
        loss = absolute_error(f_pred, f, o, self.area)

        return loss

    def validation_step(self, batch, batch_idx):
        x, f, o, o_0 = batch
        f_pred = self.forward(x)

        self.log("train_loss", self.train_ae(f_pred, f, o, self.area))
        self.log("test_loss", self.test_ae(f_pred, f, o_0, self.area))

        return {"x": x, "f": f, "f_pred": f_pred, "o": o, "o_0": o_0}

    def validation_epoch_end(self, outputs):

        self.log(
            "train_loss",
            self.train_ae,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.log(
            "test_loss",
            self.test_ae,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )


        
max_epochs = 100
num_workers = 4
hparams = {"n_hidden": 128, 
           "n_hidden_1": 128, 
           "n_hidden_2": 128, 
           "n_hidden_3": 128, 
           "n_hidden_4": 128, 
           "n_layers": 5,
           "learning_rate": 0.01}        

n_eigenglaciers = 100
n_samples = 979
n_parameters = 8
n_grid_points = 5097
X_train = torch.randn(n_samples, n_parameters)
Y_train = torch.randn(n_samples, n_grid_points)
V_hat = torch.randn(n_grid_points, n_eigenglaciers)
F_mean = torch.randn(n_grid_points)
area = torch.ones_like(F_mean) / n_grid_points

omegas = torch.Tensor(dirichlet.rvs(np.ones(n_samples))).T
omegas = omegas.type_as(X_train)
omegas_0 = torch.ones_like(omegas) / len(omegas)

training_data = TensorDataset(X_train, Y_train, omegas, omegas_0)

batch_size = 128
train_loader = torch.utils.data.DataLoader(dataset=training_data,
                                           batch_size=batch_size,
                                           shuffle=True)

# train and val data loader are the same because we use BayesBag/Bootstrapping to avoid overfitting
# by generating 50 emulators, each with different weights "omegas"

trainer_e = pl.Trainer(
    deterministic=True,
    num_sanity_val_steps=0,
    max_epochs=max_epochs,
)

e = NNEmulator(
    n_parameters,
    n_eigenglaciers,
    V_hat,
    F_mean,
    area,
    hparams,
)

trainer_e.fit(e, train_loader, train_loader)

trainer_de = pl.Trainer(
    deterministic=True,
    num_sanity_val_steps=0,
    max_epochs=max_epochs,
)

de = DNNEmulator(
    n_parameters,
    n_eigenglaciers,
    V_hat,
    F_mean,
    area,
    hparams,
)

trainer_de.fit(de, train_loader, train_loader)


  Referenced from: /Users/andy/Library/Python/3.9/lib/python/site-packages/torchvision/image.so
  Reason: tried: '/Users/malfet/miniforge3/envs/py_39_torch-1.10.2/lib/libpng16.16.dylib' (no such file), '/Users/malfet/miniforge3/envs/py_39_torch-1.10.2/lib/libpng16.16.dylib' (no such file), '/Users/malfet/miniforge3/envs/py_39_torch-1.10.2/lib/libpng16.16.dylib' (no such file), '/Users/malfet/miniforge3/envs/py_39_torch-1.10.2/lib/libpng16.16.dylib' (no such file), '/usr/local/lib/libpng16.16.dylib' (no such file), '/usr/lib/libpng16.16.dylib' (no such file)
  warn(f"Failed to load image Python extension: {e}")
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

   | Name      | Type          | Params
---------------------------------------------
0  | l_1       | Linear        | 1.2 K 
1  | norm_1    | LayerNorm     | 256   
2  | dropout_1 | Dropout       | 0     
3  | l_2       | Linear        | 16.5 K
4  | norm_2    | LayerNo

Adjusting learning rate of group 0 to 1.0000e-02.
Epoch 0:  44%|████████████████████████████████████████████████████████                                                                        | 7/16 [00:00<00:00, 88.51it/s, loss=21.6, v_num=16]Adjusting learning rate of group 0 to 9.9750e-03.
Epoch 0:  50%|████████████████████████████████████████████████████████████████                                                                | 8/16 [00:00<00:00, 94.04it/s, loss=19.6, v_num=16]
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 110.25it/s, loss=19.6, v_num=16, train_loss=30.80, test_loss=30.80][A
Epoch 1:  44%|█████████████████████████████████████████▏                                                    | 7/16 [00:00<00:00, 101.99it/s, loss=13, v_num=16, train_loss=30.80, test_loss=30.80][AAdjusting learning rate of group 0 to 9.9501e-03.
Epoch 1:  50%|██████████████████

Epoch 12:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 110.88it/s, loss=0.263, v_num=16, train_loss=1.330, test_loss=1.350][AAdjusting learning rate of group 0 to 9.6798e-03.
Epoch 12:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 114.17it/s, loss=0.257, v_num=16, train_loss=1.330, test_loss=1.350]
Validating: 0it [00:00, ?it/s][A
Epoch 12: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 127.22it/s, loss=0.257, v_num=16, train_loss=1.300, test_loss=1.310][A
Epoch 13:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 114.68it/s, loss=0.245, v_num=16, train_loss=1.300, test_loss=1.310][AAdjusting learning rate of group 0 to 9.6556e-03.
Epoch 13:  50%|█████████████████████████████████████████████                   

Epoch 24:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 119.60it/s, loss=0.166, v_num=16, train_loss=1.140, test_loss=1.150][AAdjusting learning rate of group 0 to 9.3934e-03.
Epoch 24:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 123.24it/s, loss=0.162, v_num=16, train_loss=1.140, test_loss=1.150]
Validating: 0it [00:00, ?it/s][A
Epoch 24: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 139.79it/s, loss=0.162, v_num=16, train_loss=1.130, test_loss=1.140][A
Epoch 25:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 124.48it/s, loss=0.163, v_num=16, train_loss=1.130, test_loss=1.140][AAdjusting learning rate of group 0 to 9.3699e-03.
Epoch 25:  50%|█████████████████████████████████████████████                   

Epoch 36:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 125.10it/s, loss=0.143, v_num=16, train_loss=1.080, test_loss=1.090][AAdjusting learning rate of group 0 to 9.1154e-03.
Epoch 36:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 130.63it/s, loss=0.141, v_num=16, train_loss=1.080, test_loss=1.090]
Validating: 0it [00:00, ?it/s][A
Epoch 36: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 140.79it/s, loss=0.141, v_num=16, train_loss=1.070, test_loss=1.080][A
Epoch 37:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 119.89it/s, loss=0.144, v_num=16, train_loss=1.070, test_loss=1.080][AAdjusting learning rate of group 0 to 9.0926e-03.
Epoch 37:  50%|█████████████████████████████████████████████                   

Epoch 48:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 122.64it/s, loss=0.135, v_num=16, train_loss=1.050, test_loss=1.060][AAdjusting learning rate of group 0 to 8.8457e-03.
Epoch 48:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 127.43it/s, loss=0.134, v_num=16, train_loss=1.050, test_loss=1.060]
Validating: 0it [00:00, ?it/s][A
Epoch 48: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 143.48it/s, loss=0.134, v_num=16, train_loss=1.050, test_loss=1.060][A
Epoch 49:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 121.94it/s, loss=0.136, v_num=16, train_loss=1.050, test_loss=1.060][AAdjusting learning rate of group 0 to 8.8236e-03.
Epoch 49:  50%|█████████████████████████████████████████████                   

Epoch 60:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 119.50it/s, loss=0.133, v_num=16, train_loss=1.040, test_loss=1.040][AAdjusting learning rate of group 0 to 8.5839e-03.
Epoch 60:  50%|█████████████████████████████████████████████▌                                             | 8/16 [00:00<00:00, 123.25it/s, loss=0.13, v_num=16, train_loss=1.040, test_loss=1.040]
Validating: 0it [00:00, ?it/s][A
Epoch 60: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 135.78it/s, loss=0.13, v_num=16, train_loss=1.040, test_loss=1.040][A
Epoch 61:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 120.89it/s, loss=0.133, v_num=16, train_loss=1.040, test_loss=1.040][AAdjusting learning rate of group 0 to 8.5625e-03.
Epoch 61:  50%|█████████████████████████████████████████████                   

Epoch 72:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 118.98it/s, loss=0.133, v_num=16, train_loss=1.030, test_loss=1.040][AAdjusting learning rate of group 0 to 8.3299e-03.
Epoch 72:  50%|█████████████████████████████████████████████▌                                             | 8/16 [00:00<00:00, 123.21it/s, loss=0.13, v_num=16, train_loss=1.030, test_loss=1.040]
Validating: 0it [00:00, ?it/s][A
Epoch 72: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 134.98it/s, loss=0.13, v_num=16, train_loss=1.030, test_loss=1.030][A
Epoch 73:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 122.87it/s, loss=0.131, v_num=16, train_loss=1.030, test_loss=1.030][AAdjusting learning rate of group 0 to 8.3091e-03.
Epoch 73:  50%|█████████████████████████████████████████████                   

Epoch 84:  44%|███████████████████████████████████████▊                                                   | 7/16 [00:00<00:00, 111.44it/s, loss=0.13, v_num=16, train_loss=1.020, test_loss=1.030][AAdjusting learning rate of group 0 to 8.0835e-03.
Epoch 84:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 115.51it/s, loss=0.128, v_num=16, train_loss=1.020, test_loss=1.030]
Validating: 0it [00:00, ?it/s][A
Epoch 84: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 128.01it/s, loss=0.128, v_num=16, train_loss=1.020, test_loss=1.030][A
Epoch 85:  44%|███████████████████████████████████████▊                                                   | 7/16 [00:00<00:00, 109.44it/s, loss=0.13, v_num=16, train_loss=1.020, test_loss=1.030][AAdjusting learning rate of group 0 to 8.0632e-03.
Epoch 85:  50%|█████████████████████████████████████████████                   

Epoch 96:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 117.83it/s, loss=0.129, v_num=16, train_loss=1.020, test_loss=1.020][AAdjusting learning rate of group 0 to 7.8443e-03.
Epoch 96:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 120.46it/s, loss=0.126, v_num=16, train_loss=1.020, test_loss=1.020]
Validating: 0it [00:00, ?it/s][A
Epoch 96: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 135.34it/s, loss=0.126, v_num=16, train_loss=1.020, test_loss=1.020][A
Epoch 97:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 117.16it/s, loss=0.127, v_num=16, train_loss=1.020, test_loss=1.020][AAdjusting learning rate of group 0 to 7.8246e-03.
Epoch 97:  50%|█████████████████████████████████████████████                   

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs



Adjusting learning rate of group 0 to 1.0000e-02.



  | Name          | Type          | Params
------------------------------------------------
0 | l_first       | Linear        | 1.2 K 
1 | norm_first    | LayerNorm     | 256   
2 | dropout_first | Dropout       | 0     
3 | dnn           | ModuleList    | 50.3 K
4 | l_last        | Linear        | 12.9 K
5 | train_ae      | AbsoluteError | 0     
6 | test_ae       | AbsoluteError | 0     
------------------------------------------------
64.6 K    Trainable params
514 K     Non-trainable params
579 K     Total params
2.318     Total estimated model params size (MB)


Epoch 0:  44%|███████████████████████████████████████████████████████▌                                                                       | 7/16 [00:00<00:00, 121.95it/s, loss=17.7, v_num=17]Adjusting learning rate of group 0 to 9.9750e-03.
Epoch 0:  50%|████████████████████████████████████████████████████████████████▌                                                                | 8/16 [00:00<00:00, 124.28it/s, loss=16, v_num=17]
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 139.18it/s, loss=16, v_num=17, train_loss=40.00, test_loss=40.00][A
Epoch 1:  44%|████████████████████████████████████████▎                                                   | 7/16 [00:00<00:00, 118.19it/s, loss=10.4, v_num=17, train_loss=40.00, test_loss=40.00][AAdjusting learning rate of group 0 to 9.9501e-03.
Epoch 1:  50%|██████████████████████████████████████████████                      

Epoch 12:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 127.88it/s, loss=0.283, v_num=17, train_loss=1.450, test_loss=1.470][AAdjusting learning rate of group 0 to 9.6798e-03.
Epoch 12:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 132.74it/s, loss=0.279, v_num=17, train_loss=1.450, test_loss=1.470]
Validating: 0it [00:00, ?it/s][A
Epoch 12: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 149.85it/s, loss=0.279, v_num=17, train_loss=1.400, test_loss=1.430][A
Epoch 13:  44%|███████████████████████████████████████▊                                                   | 7/16 [00:00<00:00, 131.64it/s, loss=0.27, v_num=17, train_loss=1.400, test_loss=1.430][AAdjusting learning rate of group 0 to 9.6556e-03.
Epoch 13:  50%|█████████████████████████████████████████████                   

Epoch 24:  44%|███████████████████████████████████████▊                                                   | 7/16 [00:00<00:00, 116.71it/s, loss=0.19, v_num=17, train_loss=1.180, test_loss=1.190][AAdjusting learning rate of group 0 to 9.3934e-03.
Epoch 24:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 120.62it/s, loss=0.185, v_num=17, train_loss=1.180, test_loss=1.190]
Validating: 0it [00:00, ?it/s][A
Epoch 24: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 138.48it/s, loss=0.185, v_num=17, train_loss=1.170, test_loss=1.180][A
Epoch 25:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 122.44it/s, loss=0.186, v_num=17, train_loss=1.170, test_loss=1.180][AAdjusting learning rate of group 0 to 9.3699e-03.
Epoch 25:  50%|█████████████████████████████████████████████                   

Epoch 36:  44%|███████████████████████████████████████▊                                                   | 7/16 [00:00<00:00, 119.80it/s, loss=0.16, v_num=17, train_loss=1.100, test_loss=1.110][AAdjusting learning rate of group 0 to 9.1154e-03.
Epoch 36:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 123.10it/s, loss=0.156, v_num=17, train_loss=1.100, test_loss=1.110]
Validating: 0it [00:00, ?it/s][A
Epoch 36: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 137.81it/s, loss=0.156, v_num=17, train_loss=1.100, test_loss=1.110][A
Epoch 37:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 118.52it/s, loss=0.159, v_num=17, train_loss=1.100, test_loss=1.110][AAdjusting learning rate of group 0 to 9.0926e-03.
Epoch 37:  50%|█████████████████████████████████████████████                   

Epoch 48:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 109.18it/s, loss=0.144, v_num=17, train_loss=1.060, test_loss=1.070][AAdjusting learning rate of group 0 to 8.8457e-03.
Epoch 48:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 112.85it/s, loss=0.142, v_num=17, train_loss=1.060, test_loss=1.070]
Validating: 0it [00:00, ?it/s][A
Epoch 48: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 128.08it/s, loss=0.142, v_num=17, train_loss=1.060, test_loss=1.070][A
Epoch 49:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 111.17it/s, loss=0.147, v_num=17, train_loss=1.060, test_loss=1.070][AAdjusting learning rate of group 0 to 8.8236e-03.
Epoch 49:  50%|█████████████████████████████████████████████                   

Epoch 60:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 124.79it/s, loss=0.138, v_num=17, train_loss=1.040, test_loss=1.050][AAdjusting learning rate of group 0 to 8.5839e-03.
Epoch 60:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 127.28it/s, loss=0.134, v_num=17, train_loss=1.040, test_loss=1.050]
Validating: 0it [00:00, ?it/s][A
Epoch 60: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 138.98it/s, loss=0.134, v_num=17, train_loss=1.040, test_loss=1.050][A
Epoch 61:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 116.54it/s, loss=0.137, v_num=17, train_loss=1.040, test_loss=1.050][AAdjusting learning rate of group 0 to 8.5625e-03.
Epoch 61:  50%|█████████████████████████████████████████████                   

Epoch 72:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 114.38it/s, loss=0.134, v_num=17, train_loss=1.030, test_loss=1.040][AAdjusting learning rate of group 0 to 8.3299e-03.
Epoch 72:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 117.11it/s, loss=0.131, v_num=17, train_loss=1.030, test_loss=1.040]
Validating: 0it [00:00, ?it/s][A
Epoch 72: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 128.56it/s, loss=0.131, v_num=17, train_loss=1.030, test_loss=1.040][A
Epoch 73:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 115.21it/s, loss=0.135, v_num=17, train_loss=1.030, test_loss=1.040][AAdjusting learning rate of group 0 to 8.3091e-03.
Epoch 73:  50%|█████████████████████████████████████████████                   

Epoch 84:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 126.84it/s, loss=0.133, v_num=17, train_loss=1.030, test_loss=1.030][AAdjusting learning rate of group 0 to 8.0835e-03.
Epoch 84:  50%|█████████████████████████████████████████████▌                                             | 8/16 [00:00<00:00, 131.58it/s, loss=0.13, v_num=17, train_loss=1.030, test_loss=1.030]
Validating: 0it [00:00, ?it/s][A
Epoch 84: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 142.47it/s, loss=0.13, v_num=17, train_loss=1.030, test_loss=1.030][A
Epoch 85:  44%|███████████████████████████████████████▍                                                  | 7/16 [00:00<00:00, 125.58it/s, loss=0.131, v_num=17, train_loss=1.030, test_loss=1.030][AAdjusting learning rate of group 0 to 8.0632e-03.
Epoch 85:  50%|█████████████████████████████████████████████                   

Epoch 96:  44%|███████████████████████████████████████▊                                                   | 7/16 [00:00<00:00, 122.04it/s, loss=0.13, v_num=17, train_loss=1.020, test_loss=1.030][AAdjusting learning rate of group 0 to 7.8443e-03.
Epoch 96:  50%|█████████████████████████████████████████████                                             | 8/16 [00:00<00:00, 125.08it/s, loss=0.128, v_num=17, train_loss=1.020, test_loss=1.030]
Validating: 0it [00:00, ?it/s][A
Epoch 96: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 104.66it/s, loss=0.128, v_num=17, train_loss=1.020, test_loss=1.030][A
Epoch 97:  44%|███████████████████████████████████████▊                                                   | 7/16 [00:00<00:00, 123.18it/s, loss=0.13, v_num=17, train_loss=1.020, test_loss=1.030][AAdjusting learning rate of group 0 to 7.8246e-03.
Epoch 97:  50%|█████████████████████████████████████████████                   

In [2]:
Y_pred_e = e(X_train, add_mean=True)
Y_pred_de = de(X_train, add_mean=True)
print(torch.allclose(Y_pred_e, Y_pred_de, rtol=1e-3))

False


In [3]:
Y_pred_e

tensor([[ 0.6809,  0.1195, -0.3548,  ...,  0.3786,  0.0381,  0.2980],
        [ 0.6017,  0.3940, -0.6569,  ...,  0.0306,  0.2720, -0.0374],
        [ 0.7752, -0.0181, -0.5151,  ...,  0.3033,  0.0036,  0.3225],
        ...,
        [ 0.7611,  0.0346, -0.0890,  ...,  0.1037,  0.4672,  0.0961],
        [ 0.6522,  0.3320, -0.1450,  ...,  0.2635,  0.1432,  0.0391],
        [ 0.7730,  0.0619, -0.3079,  ...,  0.3870,  0.3697, -0.2282]],
       grad_fn=<AddBackward0>)

In [4]:
Y_pred_de

tensor([[ 0.6835, -0.0328, -0.5363,  ...,  0.5086,  0.3566,  0.1720],
        [ 1.0097,  0.1232, -0.3327,  ...,  0.4315, -0.1305, -0.1026],
        [ 0.5467, -0.2653, -0.8158,  ...,  0.0742, -0.1648,  0.2609],
        ...,
        [ 0.6366, -0.1758, -0.6316,  ...,  0.0932,  0.0277, -0.0120],
        [ 0.5793,  0.3048, -0.6393,  ...,  0.1579,  0.0185,  0.1008],
        [ 0.7769, -0.0412, -0.3727,  ...,  0.5590, -0.0187,  0.3127]],
       grad_fn=<AddBackward0>)