 # Notebook 05: Train ANNs with PyTorch



 ### Primary Goal: Train and evaluate the artificial neural networks



 #### Background



 In the paper we start by training artificial neural networks, so we will do the same here in the notebooks using PyTorch.



 Note that the model used in the paper is included in the GitHub repository, but we will also take you through the steps to build and train a similar network here.



 #### Step 1: Imports

In [None]:
!pip install gewitter-functions
!pip install torchmetrics



In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

import pandas as pd
import numpy as np
import tqdm
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects

#outlines for text
pe1 = [path_effects.withStroke(linewidth=1.5,
foreground="k")]
pe2 = [path_effects.withStroke(linewidth=1.5,
foreground="w")]

#plot parameters that I personally like, feel free to make these your own.
matplotlib.rcParams['axes.facecolor'] = [0.9,0.9,0.9]
matplotlib.rcParams['axes.labelsize'] = 14
matplotlib.rcParams['axes.titlesize'] = 14
matplotlib.rcParams['xtick.labelsize'] = 12
matplotlib.rcParams['ytick.labelsize'] = 12
matplotlib.rcParams['legend.fontsize'] = 12
matplotlib.rcParams['legend.facecolor'] = 'w'
matplotlib.rcParams['savefig.transparent'] = False
%config InlineBackend.figure_format = 'retina'

#one quick thing here, we need to set the random seed so we all get the same results no matter the computer or python session
_ = torch.manual_seed(43)


def param_summary(model: torch.nn.Module) -> None:
    """Iterate through model's trainable named parameters, print name and count"""
    print(f"{'Layer Name':<20} {'Shape':<20} {'Param #':<10}")
    print("-" * 50)

    total_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name:<20} {str(list(param.shape)):<20} {param.numel():<10}")
            total_params += param.numel()

    print("-" * 50)
    print(f"{total_params=}")


 #### Step 2: Load data

In [None]:

#load training
try:
    df_t = pd.read_csv('../datasets/sub-sevir-engineered/lowres_features_train.csv')
except FileNotFoundError:
    df_t = pd.read_csv('https://raw.githubusercontent.com/ai2es/WAF_ML_Tutorial_Part2/refs/heads/main/datasets/sub-sevir-engineered/lowres_features_train.csv')
#load validation set
try:
    df_v = pd.read_csv('../datasets/sub-sevir-engineered/lowres_features_val.csv')
except FileNotFoundError:
    df_v = pd.read_csv('https://raw.githubusercontent.com/ai2es/WAF_ML_Tutorial_Part2/refs/heads/main/datasets/sub-sevir-engineered/lowres_features_val.csv')


#make matrices for training/validation
X_t = df_t.to_numpy()[:,:36]
y_t = df_t.to_numpy()[:,36]
X_v = df_v.to_numpy()[:,:36]
y_v = df_v.to_numpy()[:,36]

# Convert to PyTorch Tensors (Float32 is standard for NNs)
X_t_tensor = torch.tensor(X_t, dtype=torch.float32)
y_t_tensor = torch.tensor(y_t, dtype=torch.float32).view(-1, 1) # Reshape to (N, 1)

X_v_tensor = torch.tensor(X_v, dtype=torch.float32)
y_v_tensor = torch.tensor(y_v, dtype=torch.float32).view(-1, 1)
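
To make the slicing and reshaping above concrete, here is a minimal sketch on a made-up array. The 3-feature layout and the `toy_*` names are assumptions for illustration only; the tutorial data has 36 feature columns followed by the label column.

```python
import numpy as np
import torch

# Hypothetical stand-in for df.to_numpy(): 4 examples, 3 features + 1 label column
toy = np.array([[0.1, 0.2, 0.3, 0.0],
                [0.4, 0.5, 0.6, 1.0],
                [0.7, 0.8, 0.9, 1.0],
                [1.0, 1.1, 1.2, 0.0]])

n_features = 3
toy_X = torch.tensor(toy[:, :n_features], dtype=torch.float32)       # all feature columns
toy_y_flat = torch.tensor(toy[:, n_features], dtype=torch.float32)   # label column, shape (N,)
toy_y = toy_y_flat.view(-1, 1)                                       # reshaped to (N, 1)

print(toy_X.shape, toy_y_flat.shape, toy_y.shape)
```

The `(N, 1)` shape matters later: the model outputs one value per example with shape `(N, 1)`, and the loss expects the targets to have the same shape as the predictions.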


 #### Step 3: Make PyTorch Dataset



 As we discussed in the previous notebook, we need to shuffle and batch the data. We will leverage `torch.utils.data.DataLoader`.

In [None]:
# make datasets
train_ds = TensorDataset(X_t_tensor, y_t_tensor)
val_ds = TensorDataset(X_v_tensor, y_v_tensor)

# batch size
batch_size = 32

# make dataloaders
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)  # Shuffle only the training
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)


 Just a few quick notes on batching: First, we technically do not need to batch the validation data if it fits in memory. The reason we do so here is to handle memory efficiently when computing predictions. Second, it is important to carefully choose how much data to include within a single batch. If your batch is too small, it could take many iterations to train the model. Conversely, a batch that is too large can overwhelm your RAM. As such, the optimal batch size will be machine dependent.
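
As a rough illustration of the trade-off, the sketch below counts how many batches (and therefore optimizer steps per epoch) different batch sizes produce. The dataset here is random and its size (100 examples) is an assumption chosen only to keep the arithmetic easy to follow.

```python
import math
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical toy dataset: 100 examples with 36 features each
n_examples, n_features = 100, 36
toy_ds = TensorDataset(torch.randn(n_examples, n_features), torch.zeros(n_examples, 1))

for bs in (8, 32, 64):
    loader = DataLoader(toy_ds, batch_size=bs, shuffle=True)
    sizes = [xb.shape[0] for xb, _ in loader]
    # len(loader) is the number of batches, i.e. optimizer steps per epoch
    print(f"batch_size={bs}: {len(loader)} batches, last batch has {sizes[-1]} examples")

batches_at_32 = len(DataLoader(toy_ds, batch_size=32))
last_batch_at_32 = [xb.shape[0] for xb, _ in DataLoader(toy_ds, batch_size=32)][-1]
```

Note that the last batch is smaller whenever the dataset size is not a multiple of the batch size, which is why the training loop weights each batch loss by its size.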

#### Step 4: Build a model

In the paper we note that neural networks do not have a "one size fits all" paradigm, where a set of default parameters can consistently achieve good results. With neural networks you *need* to do some sort of hyperparameter search. So here, while we only show one model configuration, we encourage you to play around with different model configurations to figure out which hyperparameters work best for a given prediction task. In fact, we have a modular script that leverages the TensorFlow API to help optimize your model configurations. The notebook example explaining this script is [here](#) (note: this link is currently dead).

Let's start simple and create a model with 2 layers with 2 neurons each to classify whether or not an example contains lightning (0 no lightning, 1 lightning).
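
Before building it, it can help to work out by hand how many trainable parameters such a network should have. The sketch below assumes 36 input features (the number of columns selected in Step 2) and uses an `nn.Sequential` stand-in rather than the class defined in the next cell: each `Linear` layer contributes `in_features * out_features` weights plus `out_features` biases.

```python
import torch.nn as nn

n_inputs = 36  # assumed: matches the 36 feature columns selected in Step 2

# Stand-in with the layer sizes described above: 36 -> 2 -> 2 -> 1
sketch = nn.Sequential(
    nn.Linear(n_inputs, 2), nn.ReLU(),
    nn.Linear(2, 2), nn.ReLU(),
    nn.Linear(2, 1), nn.Sigmoid(),
)

# weights + biases for each Linear layer
by_hand = (n_inputs * 2 + 2) + (2 * 2 + 2) + (2 * 1 + 1)
counted = sum(p.numel() for p in sketch.parameters() if p.requires_grad)
print(by_hand, counted)
```

If the two numbers disagree for your own architecture, a layer is probably wired with the wrong input size.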

In [None]:
class SimpleANN(torch.nn.Module):
    "My first PyTorch Model"
    def __init__(self, X_t: torch.Tensor) -> None:
        """
        Instantiate model. Define all the layers the model will use

        Args:
            X_t: input tensor for model
        Returns:
            None
        """
        super(SimpleANN, self).__init__()

        self.linear1 = nn.Linear(in_features=X_t.shape[1], out_features=2)
        self.activation1 = nn.ReLU()
        self.linear2 = nn.Linear(in_features=2, out_features=2)
        self.activation2 = nn.ReLU()
        self.output = nn.Linear(in_features=2, out_features=1)
        self.activation_output = nn.Sigmoid()

    def forward(self, x):
        """define forward pass for the model"""
        x = self.linear1(x)
        x = self.activation1(x)
        x = self.linear2(x)
        x = self.activation2(x)
        x = self.output(x)
        x = self.activation_output(x)
        return x


# pass the tensor so the argument matches the type hint (only its shape is used)
model = SimpleANN(X_t=X_t_tensor)
print(model)
param_summary(model)

 Notice the following about the above definition of the model:



 1. The input layer's `in_features` must *always* be defined, and it must match the number of input features (here, `X_t.shape[1]`)

    - This is needed so that the model is initialized with the correct number of weights.



 2. The activation functions for the input and hidden layers are `ReLU`

    - NNs require an activation function to effectively learn non-linear relationships in the data. The `ReLU` activation function is commonly used for this purpose.



 3. The output layer (last layer) has a *sigmoid* activation function

    - This is specifically for classification tasks. If we were training a regression model (as in the previous notebook), we wouldn't need this (i.e., linear activation).
    - If you have more than one output neuron (say if you have more than two possible classifications such as 'no lightning', 'some lightning', and 'lots of lightning'), then you would use *softmax* instead of sigmoid.
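
The difference between the two output activations in point 3 can be sketched on a few made-up raw outputs (logits). The values below are arbitrary; the point is only that sigmoid squashes each single output into (0, 1) independently, while softmax turns each row of several outputs into probabilities that sum to 1.

```python
import torch

# One output neuron per example (binary case): sigmoid
logits_binary = torch.tensor([[-2.0], [0.0], [3.0]])
probs_binary = torch.sigmoid(logits_binary)

# Three output neurons per example (multi-class case): softmax over the class dimension
logits_multi = torch.tensor([[2.0, 0.5, -1.0],
                             [0.0, 0.0, 0.0]])
probs_multi = torch.softmax(logits_multi, dim=1)

print(probs_binary.flatten())
print(probs_multi, probs_multi.sum(dim=1))
```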




 #### Step 5: Run some data through it



 In order to visualize the initial performance with random weights and biases, we can plug the data into the untrained model. Unlike Keras, we need to manually pass the data through the model, preferably using `no_grad` to save memory.

In [None]:
# define plotting function
def model_prediction_hist(y_preds, title=""):
    """plot histogram of y_preds"""
    fig, ax = plt.subplots(1,1)
    ax.hist(y_preds)
    ax.set_xlabel('prob of lightning')
    ax.set_ylabel('count')
    ax.set_title(title)
    ax.set_xlim([0,1])
    return ax


# Set model to evaluation mode (i.e. turn off dropout layers, etc.)
model.eval()

# Get predictions for the validation set - no need to track gradients.
with torch.no_grad():
    # We can pass the whole tensor since it fits in memory,
    # or iterate the loader if it was huge. Let's pass the tensor.
    y_preds_tensor = model(X_v_tensor)
    y_preds = y_preds_tensor.numpy() # Convert back to numpy for plotting

untrained_predictions_hist = model_prediction_hist(
    y_preds=y_preds, title="Untrained ANN predictions"
)
plt.show()
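
The effect of `torch.no_grad()` in the cell above can be checked on a tiny stand-in layer (the layer and input sizes here are arbitrary assumptions). Outside the context manager the output carries gradient-tracking information and cannot be converted to NumPy directly; inside it, the same numbers come back without that overhead.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 1)   # hypothetical stand-in for the full model
x = torch.randn(5, 3)

out_tracked = layer(x)            # autograd records this computation
with torch.no_grad():
    out_untracked = layer(x)      # same values, no graph is built

print(out_tracked.requires_grad, out_untracked.requires_grad)

# .numpy() is refused on a tensor that still requires grad
try:
    out_tracked.numpy()
    numpy_ok = True
except RuntimeError:
    numpy_ok = False

arr = out_untracked.numpy()       # works without any detach()
```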


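 The sigmoid function in the output layer ensures that all model output falls between 0 and 1. In order to plot the performance diagram, we need to calculate true positives (TP), false positives (FP), and false negatives (FN) at a range of probability thresholds. We will define a helper class for this that computes the counts for every threshold at once using broadcast tensor operations.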

In [None]:
import torch
import numpy as np
from torchmetrics.classification import StatScores


class MultiThresholdMetrics:
    """Calculate classification metrics for range of thresholds"""
    def __init__(self, thresholds: torch.Tensor, device='cpu'):
        self.thresholds = thresholds.to(device).view(1, -1) # Shape: (1, n_thresholds)
        self.device = device

    def __call__(self, y_true, y_pred) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        y_true: (N,) or (N, 1) - Ground truth labels (0 or 1)
        y_pred: (N,) or (N, 1) - Predicted probabilities
        """
        # Ensure inputs are flat and on the correct device
        y_true = y_true.to(self.device).view(-1, 1) # Shape: (N, 1)
        y_pred = y_pred.to(self.device).view(-1, 1) # Shape: (N, 1)

        # Broadcast comparison: (N, 1) >= (1, T) -> (N, T)
        # This creates a boolean matrix where each column corresponds to a threshold
        pred_labels = (y_pred >= self.thresholds).float()

        # Calculate Stats per threshold (summing over the batch dimension N)
        # True Positives: Predicted 1 AND Actual 1
        tps = (pred_labels * y_true).sum(dim=0)

        # False Positives: Predicted 1 AND Actual 0
        fps = (pred_labels * (1 - y_true)).sum(dim=0)

        # False Negatives: Predicted 0 AND Actual 1
        fns = ((1 - pred_labels) * y_true).sum(dim=0)

        return tps, fps, fns


# Instantiate
thresh = torch.tensor(np.arange(0.05, 1.05, 0.05), dtype=torch.float32)
metric_calculator = MultiThresholdMetrics(thresh)

# Calculate
tps, fps, fns = metric_calculator(torch.Tensor(y_v), y_preds_tensor)
tps, fps, fns = [x.numpy() for x in (tps, fps, fns)]

# Note: adding epsilon to avoid division by zero
eps = 1e-7
pods = tps / (tps + fns + eps)       # Probability of Detection (Recall)
srs = tps / (tps + fps + eps)        # Success Ratio (Precision)
csis = tps / (tps + fns + fps + eps) # Critical Success Index

print("PODs:", pods)
print("SRs:", srs)
print("CSIs:", csis)
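
To see what the broadcast comparison is doing, here is a worked example small enough to check by hand. The five labels, five probabilities, and two thresholds are made up for illustration. It also checks the identity CSI = 1 / (1/SR + 1/POD - 1), which is why CSI can be drawn as contours on a performance diagram.

```python
import torch

# Hypothetical labels and predicted probabilities, shape (N, 1)
y_true = torch.tensor([1., 1., 0., 0., 1.]).view(-1, 1)
y_prob = torch.tensor([0.9, 0.4, 0.6, 0.1, 0.8]).view(-1, 1)
thresholds = torch.tensor([0.3, 0.5]).view(1, -1)    # shape (1, T)

# (N, 1) >= (1, T) broadcasts to (N, T): one column of 0/1 labels per threshold
pred = (y_prob >= thresholds).float()

tp = (pred * y_true).sum(dim=0)             # predicted 1, actual 1
fp = (pred * (1 - y_true)).sum(dim=0)       # predicted 1, actual 0
fn = ((1 - pred) * y_true).sum(dim=0)       # predicted 0, actual 1

pod = tp / (tp + fn)
sr = tp / (tp + fp)
csi = tp / (tp + fp + fn)
csi_from_pod_sr = 1.0 / (1.0 / sr + 1.0 / pod - 1.0)

print(tp, fp, fn)
print(pod, sr, csi)
```

At the 0.5 threshold, for example, two lightning cases are caught, one is missed, and one non-lightning case is flagged, giving POD = SR = 2/3 and CSI = 1/2.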

In [None]:
# import some helper functions from our other directory.
# import sys
# sys.path.insert(1, '../scripts/')

#load contingency_table func
from gewitter_functions import get_contingency_table, make_performance_diagram_axis, get_acc, get_pod, get_sr, csi_from_sr_and_pod

#plot it up
def plot_performance_diagram(truth, preds, metric_calculator, title: str = "", nth_point_to_label: int = 3):
    """create performance diagram from truth and predictions"""
    # Recalculate TP/FP/FN for the supplied predictions at every threshold
    tps, fps, fns = metric_calculator(truth, preds)
    tps, fps, fns = [x.numpy() for x in [tps, fps, fns]]

    #calc x,y of performance diagram
    eps = 1e-7
    pods = tps/(tps + fns + eps)
    srs = tps/(tps + fps + eps)
    csis = tps/(tps + fns + fps + eps)  # CSI counts false positives in the denominator too

    #plot it up
    ax = make_performance_diagram_axis()
    ax.plot(np.asarray(srs)[:-1], np.asarray(pods)[:-1], '-', color='dodgerblue', markerfacecolor='w', label='ANN')

    for i, t in enumerate(thresh.numpy()):
        #plot text and marker every nth point, because every point was too many
        if np.mod(i, nth_point_to_label) == 0:
            text = np.char.ljust(str(np.round(t, 2)), width=4, fillchar='0')
            ax.plot(np.asarray(srs)[i], np.asarray(pods)[i], 's', color='dodgerblue', markerfacecolor='w')
            ax.text(np.asarray(srs)[i] + 0.02, np.asarray(pods)[i], text, path_effects=pe1, fontsize=9, color='white')

    plt.title(title)
    plt.tight_layout()
    return ax


untrained_ann_perf_diag = plot_performance_diagram(
    truth=torch.Tensor(y_v),
    preds=y_preds_tensor,
    metric_calculator=metric_calculator,
    title="Untrained ANN Performance Diagram",
    nth_point_to_label=1
    )
plt.show()


As expected, the model performance looks wonky! It is an untrained model.

#### Step 6: Train the model

Okay, let's
1. define our loss and optimizer,
2. and then write the training loop.

We will use **Binary Cross Entropy** (`BCELoss`) as the loss and **Root Mean Square Propagation** (`RMSprop`) as the optimizer.
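
For reference, binary cross entropy averages `-(y * log(p) + (1 - y) * log(1 - p))` over the examples, where `p` is the predicted probability and `y` the 0/1 label. The sketch below uses three made-up predictions to check that a hand-written version agrees with `nn.BCELoss`.

```python
import torch
import torch.nn as nn

# Hypothetical predicted probabilities and true labels, shape (N, 1)
p = torch.tensor([[0.9], [0.2], [0.6]])
y = torch.tensor([[1.0], [0.0], [0.0]])

# Binary cross entropy written out, averaged over examples
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()

# PyTorch's built-in version (default reduction is the mean)
builtin = nn.BCELoss()(p, y)

print(manual.item(), builtin.item())
```

Confident, correct predictions (0.9 for a lightning case) contribute little to the loss, while the unsure 0.6 for a no-lightning case contributes the most.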

In [None]:
# Define Loss and Optimizer
loss_fn = nn.BCELoss()
optimizer = optim.RMSprop(model.parameters(), lr=1e-3)

# Lists to keep track of losses
train_losses = []
val_losses = []

epochs = 25
for epoch in range(epochs):
    # --- Training Phase ---
    model.train()
    running_loss = 0.0
    for x_batch, y_batch in train_loader:
        # 1. Zero gradients
        optimizer.zero_grad()

        # 2. Forward pass
        preds = model(x_batch)

        # 3. Calculate loss
        loss = loss_fn(preds, y_batch)

        # 4. Backward pass
        loss.backward()

        # 5. Step
        optimizer.step()

        running_loss += loss.item() * x_batch.size(0)

    epoch_train_loss = running_loss / len(train_ds)
    train_losses.append(epoch_train_loss)

    # --- Validation Phase ---
    model.eval()
    running_val_loss = 0.0
    with torch.no_grad():
        for x_val, y_val in val_loader:
            val_preds = model(x_val)
            v_loss = loss_fn(val_preds, y_val)
            running_val_loss += v_loss.item() * x_val.size(0)

    epoch_val_loss = running_val_loss / len(val_ds)
    val_losses.append(epoch_val_loss)

    # Print progress every 5 epochs
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f}")
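
One detail in the loop above is worth spelling out: `loss.item()` is the *mean* loss over a batch, so it is multiplied by the batch size before being accumulated, and the total is divided by the dataset size. The sketch below, with made-up per-example losses, shows why this matters when the last batch is smaller than the others.

```python
import torch

# Hypothetical per-example losses, split into batches of size 2, 2, and 1
per_example_loss = torch.tensor([0.2, 0.4, 0.6, 0.8, 1.0])
batches = [per_example_loss[:2], per_example_loss[2:4], per_example_loss[4:]]

# What the training loop does: weight each batch mean by its batch size
running = 0.0
for b in batches:
    batch_mean = b.mean().item()
    running += batch_mean * b.numel()
weighted = running / per_example_loss.numel()

# Naive alternative: average the batch means directly
naive = sum(b.mean().item() for b in batches) / len(batches)

true_mean = per_example_loss.mean().item()
print(weighted, naive, true_mean)
```

The weighted version recovers the true per-example mean, while the naive average over-weights the single example in the short final batch.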




 #### Step 7: Check for overfitting



 We manually tracked the losses in lists (`train_losses` and `val_losses`) during the loop. This data is very useful for determining if a model is overfitting. Let's plot the training loss vs the validation loss:

In [None]:
plt.plot(train_losses, color='dodgerblue', label='training')
plt.plot(val_losses, color='orangered', label='validation')
plt.legend()
plt.xlabel('Epoch number')
plt.ylabel('loss')
plt.title("Simple ANN Model Training Curves")
plt.show()

 Above, you want to compare the red (validation) line to the blue (training) line. The absolute values are less important than how the two curves move relative to each other. Notice that the red line is relatively flat after about 10 epochs. This isn't much of an overfitting signal (overfitting would show up as validation loss increasing with more epochs), but the model does seem to have converged to some local minimum.
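
For contrast, here is what an overfitting signal would look like numerically. The loss values below are invented for illustration and are not from this model: training loss keeps falling while validation loss bottoms out and then climbs.

```python
import numpy as np

# Hypothetical loss curves showing overfitting
toy_train = np.array([0.70, 0.55, 0.45, 0.38, 0.33, 0.29, 0.26, 0.24])
toy_val = np.array([0.72, 0.58, 0.50, 0.47, 0.46, 0.48, 0.51, 0.55])

# Epoch (0-indexed) where validation loss is lowest
best_epoch = int(np.argmin(toy_val))

# Gap between validation and training loss grows as the model overfits
gap = toy_val - toy_train

# True if validation loss rises at every epoch after the best one
rising_after_best = bool(np.all(np.diff(toy_val[best_epoch:]) > 0))

print(best_epoch, gap.round(2), rising_after_best)
```

Applying the same `np.argmin` check to the `val_losses` list from the training loop gives a quick way to see at which epoch the validation loss was lowest.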



 #### Step 8: Check validation performance



 Now that the model is trained, let's check the new predictions and generate the performance diagram:

In [None]:
# Set model to eval mode
model.eval()
with torch.no_grad():
    y_preds_tensor = model(X_v_tensor)
    y_preds = y_preds_tensor.numpy()


model_prediction_hist(y_preds, title="Simple ANN predictions histogram")
plt.show()



 Well this looks better than before, and we can see that there is a good number of predictions near 1 (lightning) and 0 (no lightning). Hopefully these predictions align with the correct 'truth' labels. Let's take a look at the performance diagram.

In [None]:
simple_ann_perf_diag = plot_performance_diagram(
    truth=torch.Tensor(y_v),
    preds=y_preds_tensor,
    metric_calculator=metric_calculator,
    title="Simple ANN Performance Diagram"
  )
plt.show()


 MUCH BETTER. Even just a simple neural network does well here. This is probably because the task is relatively easy - remember that we could use a simple brightness temperature threshold for IR to get an 80% accuracy.



 Because this network is so small, we can actually view the learned weights and biases of each layer.

In [None]:
# Input layer weights
# note weights addressed by structure in model class definition
model.linear1.weight

In [None]:
# First hidden layer weights
model.linear2.weight


 #### Step 9: Save trained model



 Now that you have a model trained, you probably don't want to re-train it every time you need to make a prediction. In PyTorch, the standard is to save the `state_dict`, which is a dictionary containing all the learnable parameters.

In [None]:
# Save next to the notebook. To save inside the repository instead, use:
# model_path = '../datasets/models/neural_nets_from_notebooks/MyFirstNN.pt'
model_path = 'MyFirstNN.pt'
torch.save(model.state_dict(), model_path)


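 This saves the trained parameters to a `.pt` file. Saving the `state_dict` rather than the whole model object is the generally recommended approach, because the file then does not depend on how the Python class was pickled.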
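
A minimal round trip on a toy network shows what the `state_dict` contains and that restoring it reproduces the original outputs. The `make_net` helper, layer sizes, and file name are assumptions made for this sketch; a temporary directory is used so nothing is left on disk.

```python
import os
import tempfile
import torch
import torch.nn as nn

def make_net() -> nn.Module:
    """Hypothetical toy architecture; must be identical when saving and loading."""
    return nn.Sequential(nn.Linear(4, 2), nn.ReLU(), nn.Linear(2, 1), nn.Sigmoid())

torch.manual_seed(0)
original = make_net()
x = torch.randn(3, 4)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_state_dict.pt")
    torch.save(original.state_dict(), path)        # parameters only, no code

    restored = make_net()                          # fresh, randomly initialized copy
    restored.load_state_dict(torch.load(path))     # overwrite with saved parameters
    restored.eval()

original.eval()
with torch.no_grad():
    same = torch.allclose(original(x), restored(x))

keys = list(original.state_dict().keys())
print(keys, same)
```

The keys are the parameter names of each layer, which is why the architecture code has to be available (and unchanged) when the file is loaded.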



 #### Step 10: Load trained model



 To load the model, we first need to instantiate the model architecture (the code must be available), and then load the state dictionary into it.

In [None]:
# 1. Re-create the model architecture
loaded_model = SimpleANN(X_t=X_t_tensor)

# 2. Load the weights
loaded_model.load_state_dict(torch.load(model_path))
loaded_model.eval() # Don't forget to set to eval mode!

 #### Step 11: Run loaded model



 Now it's all set to run!

In [None]:
# infer on validation set with loaded model
with torch.no_grad():
    y_preds_tensor = loaded_model(X_v_tensor)
    y_preds = y_preds_tensor.numpy()


model_prediction_hist(y_preds, title="Simple ANN (from disk) predictions histogram")
plt.show()

plot_performance_diagram(
    truth=torch.Tensor(y_v),
    preds=y_preds_tensor,
    metric_calculator=metric_calculator,
    title="Simple ANN (from disk) performance diagram"
)
plt.show()

#### Step 12: Load and run a pre-trained network

Here we load a pre-trained model from disk. Note that we are loading the entire model here, not just a `state_dict`. This ANN classifier was 'traced' into TorchScript, which produces a `ScriptModule`: a serialized, deployable object representing the original model as a static graph (DAG), so the original Python class definition is not needed to load it.
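
For readers curious how such a file is produced, the sketch below traces a toy network with `torch.jit.trace`, saves it, and loads it back with `torch.jit.load`. The toy architecture and file name are assumptions for illustration; this is not the procedure used to create the pre-trained file, only an example of the same mechanism.

```python
import os
import tempfile
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy model standing in for a trained classifier
toy_model = nn.Sequential(nn.Linear(4, 2), nn.ReLU(), nn.Linear(2, 1), nn.Sigmoid())
toy_model.eval()

# Tracing runs the model once on an example input and records the operations
example_input = torch.randn(1, 4)
traced = torch.jit.trace(toy_model, example_input)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_traced.pt")
    torch.jit.save(traced, path)
    reloaded = torch.jit.load(path)   # no class definition needed at this point

reloaded.eval()
x = torch.randn(5, 4)
with torch.no_grad():
    matches = torch.allclose(toy_model(x), reloaded(x))

is_script_module = isinstance(reloaded, torch.jit.ScriptModule)
print(matches, is_script_module)
```

Because tracing records one concrete execution, it suits models like this one whose forward pass has no data-dependent branching.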


In [None]:
# load pretrained model
pretrained_model_path = "sub-sevir-ann-class-1d-eng-TorchScript.pt"
pretrained_model = torch.jit.load(pretrained_model_path)
pretrained_model.eval()

In [None]:
# Infer validation set with pretrained model
with torch.no_grad():
    # Make sure X_v_tensor is on the correct device if the model was traced on GPU
    y_preds_traced_tensor = pretrained_model(X_v_tensor)
    y_preds_traced = y_preds_traced_tensor.numpy()

# Plot histogram of predictions
model_prediction_hist(y_preds_traced, title="Traced ANN predictions histogram")
plt.show()

plot_performance_diagram(
    truth=torch.Tensor(y_v),
    preds=y_preds_traced_tensor,
    metric_calculator=metric_calculator,
    title="Traced ANN Performance Diagram"
)
plt.show()

Voila! Now you can hopefully do end-to-end neural networks. The next notebook will jump into convolutions.
