In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

import pytorch_lightning as pl
from scipy.stats import dirichlet

from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from torch.utils.tensorboard import SummaryWriter, FileWriter

In [None]:
# Seed both random number generators so the synthetic data below is reproducible.
np.random.seed(0)
torch.manual_seed(0)

# Directory for TensorBoard event files and the epoch budget shared by all runs.
tb_dir = "tb_logs"
max_epochs = 100

# Synthetic inputs and integer targets, one row / entry per instance.
X_train = torch.randn((100, 8), requires_grad=True)
F_train = torch.empty(100, dtype=torch.int64).random_(1000)

# Per-instance weights: one Dirichlet draw (sums to 1) reshaped to a column,
# plus a uniform weighting of the same shape used as the reference.
omegas = torch.tensor(dirichlet.rvs(alpha=np.ones(100)), dtype=torch.float32).t()
omegas_0 = torch.ones_like(omegas) / omegas.shape[0]

# One instance per batch, served in dataset order (the DataLoader defaults, spelled out).
dataset = torch.utils.data.TensorDataset(X_train, F_train, omegas, omegas_0)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

In [None]:
class TestModel(pl.LightningModule):
    """Residual MLP whose output is projected through a fixed matrix ``V_hat``.

    The network applies four hidden stages (Linear -> LayerNorm -> Dropout ->
    ReLU) with a skip connection around stages 2-4, followed by a final linear
    layer producing ``n_e`` values per instance. ``forward`` returns those
    values multiplied by ``V_hat.T``.

    Args:
        V_hat: Tensor stored as a non-trainable parameter. Its second dimension
            must equal ``n_e`` for the matrix product in ``forward``.
        n_p: Number of input features.
        n_e: Output width of the last linear layer.
        n_hidden_1: Width of hidden stage 1.
        n_hidden_2: Width of hidden stage 2.
        n_hidden_3: Width of hidden stage 3.
        n_hidden_4: Width of hidden stage 4.

    Note:
        The skip connections add the activations of consecutive stages, so the
        hidden widths must be broadcast-compatible (in practice: equal).
    """

    def __init__(
        self,
        V_hat,
        n_p: int = 8,
        n_e: int = 8,
        n_hidden_1: int = 128,
        n_hidden_2: int = 128,
        n_hidden_3: int = 128,
        n_hidden_4: int = 128,
    ):
        super().__init__()
        # Stage 1: inputs -> hidden (dropout disabled with p=0.0)
        self.l_1 = nn.Linear(n_p, n_hidden_1)
        self.norm_1 = nn.LayerNorm(n_hidden_1)
        self.dropout_1 = nn.Dropout(p=0.0)
        # Stages 2-4: hidden -> hidden
        self.l_2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.norm_2 = nn.LayerNorm(n_hidden_2)
        self.dropout_2 = nn.Dropout(p=0.5)
        self.l_3 = nn.Linear(n_hidden_2, n_hidden_3)
        self.norm_3 = nn.LayerNorm(n_hidden_3)
        self.dropout_3 = nn.Dropout(p=0.5)
        self.l_4 = nn.Linear(n_hidden_3, n_hidden_4)
        # Normalises the output of l_4, so it is sized by n_hidden_4.
        self.norm_4 = nn.LayerNorm(n_hidden_4)
        self.dropout_4 = nn.Dropout(p=0.5)
        # Output layer
        self.l_5 = nn.Linear(n_hidden_4, n_e)

        # Fixed projection: registered as a parameter so it follows the module
        # (device moves, state_dict) but is excluded from gradient updates.
        self.V_hat = torch.nn.Parameter(V_hat, requires_grad=False)

    def forward(self, x, add_mean=False):
        """Run the network on ``x`` and project the result through ``V_hat.T``.

        Args:
            x: Input tensor whose last dimension is ``n_p``.
            add_mean: Accepted for interface compatibility; currently unused.

        Returns:
            ``l_5(...) @ V_hat.T``.
        """
        a_1 = self.l_1(x)
        a_1 = self.norm_1(a_1)
        a_1 = self.dropout_1(a_1)
        z_1 = torch.relu(a_1)

        a_2 = self.l_2(z_1)
        a_2 = self.norm_2(a_2)
        a_2 = self.dropout_2(a_2)
        z_2 = torch.relu(a_2) + z_1

        a_3 = self.l_3(z_2)
        a_3 = self.norm_3(a_3)
        a_3 = self.dropout_3(a_3)
        z_3 = torch.relu(a_3) + z_2

        # Stage 4 uses its own normalisation and dropout modules rather than
        # sharing the stage-3 ones.
        a_4 = self.l_4(z_3)
        a_4 = self.norm_4(a_4)
        a_4 = self.dropout_4(a_4)
        z_4 = torch.relu(a_4) + z_3

        z_5 = self.l_5(z_4)

        F_pred = z_5 @ self.V_hat.T

        return F_pred

    def criterion_ae(self, F_pred, F_obs, omegas):
        """Weighted sum over instances of the squared error summed along dim 1.

        Args:
            F_pred: Predictions, 2-D with instances along dim 0.
            F_obs: Targets; must broadcast against ``F_pred``. A 1-D ``F_obs``
                aligns with the LAST dimension of ``F_pred``, so per-instance
                targets should be passed as a column of shape ``(batch, 1)``.
            omegas: Per-instance weights; squeezed before being applied.

        Returns:
            Scalar tensor ``sum_i omegas_i * sum_j |F_pred_ij - F_obs_ij|^2``.
        """
        instance_misfit = torch.sum(torch.abs(F_pred - F_obs) ** 2, dim=1)
        return torch.sum(instance_misfit * omegas.squeeze())

    def configure_optimizers(self):
        """Return Adam together with a ReduceLROnPlateau schedule monitoring ``"loss"``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        scheduler = {
            "scheduler": ReduceLROnPlateau(optimizer, verbose=True),
            "reduce_on_plateau": True,
            "monitor": "loss",
        }

        # Hand the scheduler to Lightning as well; returning only the optimizer
        # would silently discard the schedule configured above.
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def training_step(self, batch, batch_idx):
        """Compute the weighted loss for one batch.

        Returns:
            Dict with the ``"loss"`` to optimise plus the batch tensors, which
            ``on_train_epoch_end`` reads back to recompute epoch-level losses.
        """
        x, f, o, o_0 = batch
        f_pred = self.forward(x)
        loss = self.criterion_ae(f_pred, f, o)

        # Publish the loss under the name that the LR scheduler (and any
        # callback monitoring "loss") looks up.
        self.log("loss", loss)

        return {"loss": loss, "x": x, "f": f, "omegas": o, "omegas_0": o_0}

    def on_train_epoch_end(self, outputs):
        """Recompute and log full-epoch losses in both eval and train mode.

        The batches returned by ``training_step`` are stacked back together and
        passed through the network twice (eval mode, then train mode). Each
        prediction is scored with the training weights and with the reference
        weights ``omegas_0``.

        NOTE(review): the ``outputs`` argument and the ``out[0]["extra"]``
        layout depend on the installed Lightning release; confirm against the
        version in use.
        """
        x = []
        f = []
        omegas_0 = []
        omegas = []
        for out in outputs[0]:
            extra = out[0]["extra"]
            x.append(extra["x"])
            f.append(extra["f"])
            omegas.append(extra["omegas"])
            omegas_0.append(extra["omegas_0"])
        x = torch.vstack(x)
        f = torch.vstack(f)
        omegas = torch.vstack(omegas)
        omegas_0 = torch.vstack(omegas_0)

        # These values are only logged, never back-propagated, so skip building
        # autograd graphs for the two full-dataset forward passes.
        with torch.no_grad():
            self.trainer.model.eval()
            f_pred_eval = self.forward(x)
            # Switch back to train mode; the module stays in train mode afterwards.
            self.trainer.model.train()
            f_pred_train = self.forward(x)
            train_loss_eval = self.criterion_ae(f_pred_eval, f, omegas)
            train_loss_train = self.criterion_ae(f_pred_train, f, omegas)
            test_loss_eval = self.criterion_ae(f_pred_eval, f, omegas_0)
            test_loss_train = self.criterion_ae(f_pred_train, f, omegas_0)

        self.log("train_loss_eval", train_loss_eval, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train_loss_train", train_loss_train, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_loss_eval", test_loss_eval, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_loss_train", test_loss_train, on_step=False, on_epoch=True, prog_bar=True)

## Torch implementation

In [None]:
# Plain-PyTorch reference run: one optimiser step per dataloader batch, then a
# full-dataset evaluation at the end of every epoch.
model = TestModel(X_train)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Per-instance targets as a column, so they broadcast across the output axis
# the same way the single-instance batches do during training. A 1-D F_train
# would instead line up with the output axis of the (instances x outputs)
# prediction matrix.
F_obs_full = F_train.unsqueeze(1)

# Running batch counter used as the x-axis for the per-step loss curve.
global_step = 0

# The context manager closes the writer even if training raises.
with SummaryWriter(tb_dir) as writer:
    for epoch in range(max_epochs):
        model.train()
        for x, f, o, o_0 in train_dataloader:
            f_hat = model(x)
            loss = model.criterion_ae(f_hat, f, o)
            writer.add_scalar("loss", loss.item(), global_step)
            global_step += 1

            # clear gradients
            optimizer.zero_grad()

            # backward
            loss.backward()

            # update parameters
            optimizer.step()

        # Epoch-level losses are only reported, so no autograd graph is needed.
        model.eval()
        with torch.no_grad():
            F_pred = model(X_train)
            loss_train = model.criterion_ae(F_pred, F_obs_full, omegas)
            loss_test = model.criterion_ae(F_pred, F_obs_full, omegas_0)
        writer.add_scalar("loss_train", loss_train, epoch)
        writer.add_scalar("loss_test", loss_test, epoch)
        print(f"epoch: {epoch}, train loss: {loss_train.item()}, test loss {loss_test.item()}")

## Set up TensorBoard logger

In [None]:
# Write Lightning's event files under the shared TensorBoard directory
# (tb_dir), in a "Lightning" sub-folder.
logger = TensorBoardLogger(tb_dir, name="Lightning")

# Lightning Implementation

In [None]:
# Fresh model for the Lightning run; the training inputs double as the fixed
# projection matrix V_hat.
model = TestModel(V_hat=X_train)

## Set up the learning-rate monitor

In [None]:
# Callback that logs the optimiser's learning rate, configured to log per step.
lr_monitor = LearningRateMonitor(logging_interval="step")

## Run for max_epochs

In [None]:
# Train for the full epoch budget, recording the learning rate via lr_monitor.
trainer = pl.Trainer(
    max_epochs=max_epochs,
    logger=logger,
    callbacks=[lr_monitor],
)
trainer.fit(model, train_dataloader)

## Run with Early Stopping

In [None]:
# Stop once the monitored "loss" has failed to improve for 5 consecutive checks.
# With strict=True Lightning raises if no metric named "loss" has been logged,
# so the module must publish its training loss under that name.
early_stop_callback = EarlyStopping(
    monitor="loss", min_delta=0.0, patience=5, verbose=False, mode="min", strict=True
)

# Build a single trainer that carries the early-stopping callback AND the epoch
# budget. Creating a second Trainer afterwards would replace this one and
# silently drop the callbacks before fit() runs.
trainer = pl.Trainer(
    callbacks=[lr_monitor, early_stop_callback],
    max_epochs=max_epochs,
    logger=logger,
)
trainer.fit(model, train_dataloader)

## Load TensorBoard extension

In [None]:
# Load the TensorBoard notebook extension, which provides the tensorboard magic.
%load_ext tensorboard

In [None]:
tensorboard --logdir="tb_logs/"