# 2. Understanding and training PhaseNet on a local dataset

Adapted from the [PhaseNet tutorial](https://colab.research.google.com/github/seisbench/seisbench/blob/main/examples/03a_training_phasenet.ipynb) provided by the [SeisBench](https://seisbench.readthedocs.io/en/latest/) project by [Léonard Seydoux](https://sites.google.com/view/leonard-seydoux/accueil) for the [Short Course #3](https://spin-itn.eu/sc3/) of the [Innovative Training Network SPIN](https://spin-itn.eu/) (Seismological Parameters and INstrumentation).

> This notebook is largely inspired by the [PhaseNet tutorial](https://colab.research.google.com/github/seisbench/seisbench/blob/main/examples/03a_training_phasenet.ipynb) provided by the [SeisBench](https://seisbench.readthedocs.io/en/latest/) project. You can find other interesting tutorials therein. This tutorial shows how to use and re-train the [PhaseNet](https://github.com/AI4EPS/PhaseNet) model with SeisBench. Note that you can also use the distribution of PhaseNet provided in the [dedicated repository](https://github.com/AI4EPS/PhaseNet) and use another version of PyTorch adapted to your hardware. You can find a list of compatible versions of PyTorch at https://pytorch.org/get-started/locally.

<img src="https://raw.githubusercontent.com/seisbench/seisbench/main/docs/_static/seisbench_logo_subtitle_outlined.svg" width=300px alt="SeisBench logo"/>

## 1. Introduction

This tutorial works with ~5 GB of continuous seismic data from Iquique, recorded during the sequence of aftershocks that followed the $M_w$ 8.1 Iquique earthquake that occurred in northern Chile in 2014.

The data is stored in a SeisBench dataset and should be downloaded beforehand. If you have little disk space left, or already have access to another dataset, you can use that dataset instead. Note that re-training PhaseNet requires labels (e.g., _P_ and _S_ wave picks) for the training data. If you have such labels, you can use them to re-train PhaseNet. If not, you can still use the pre-trained model to make predictions on your data.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from seisbench.util import worker_seeding
from seisbench.models import PhaseNet
from seisbench.data import WaveformDataset
from seisbench.generate import GenericGenerator
import seisbench.generate as sbg
import torch
from torch.utils.data import DataLoader

import utils

In [None]:
# This is just to render plots in vector format
%config InlineBackend.figure_formats = "svg"

## 2. Data preparation

First we will get the dataset ready with the SeisBench API, using the `seisbench.data` module to load the data. All examples use the Iquique dataset, which contains 13,400 examples of picked arrivals from the aftershock sequence following the $M_w$ 8.1 Iquique earthquake that occurred in northern Chile in 2014. All waveforms are three-component records sampled at 100 Hz. The dataset contains examples of earthquakes only, and was published in Woollam et al. ([2019](https://doi.org/10.1785/0220180312)).

### 2.1 Downloading or copying the dataset

The dataset is available from the SeisBench repository and is downloaded when you instantiate the corresponding dataset class (see https://seisbench.readthedocs.io/en/stable/pages/benchmark_datasets.html for an overview of available datasets in SeisBench). However, downloading a dataset might take a while, so we recommend that you copy the data from the available USB sticks. You can choose a location to copy the dataset to, and load it with the following command.

In [None]:
# Load dataset from disk
data = WaveformDataset("../iquique", name="Iquique")

# Show loaded dataset
print(data)

### 2.2 Dataset's content

The dataset essentially consists of two things that you can directly see in the dataset directory:

- The __metadata__, stored as a csv file, and loaded into a pandas dataframe. This dataframe contains the information about the waveforms and events, such as the station, the event time, the picked phases, sampling rates, etc. It can be accessed with the `metadata` attribute of the dataset object (`data` in this notebook).

- The __waveforms__, stored as entries of an hdf5 file. They are loaded dynamically into memory, so you don't have to load all waveforms at once. Several methods are available to access the waveforms, depending on whether you want to train a model, make predictions or just have a look at the data.

#### Having a look at metadata

You can access the metadata with the `metadata` attribute of the dataset object. For example, you can have a look at the first 5 rows of the metadata with the following command:

In [None]:
data.metadata.head()

#### Having a look at waveforms

We can also have a look at a sample of the dataset and the corresponding waveform. The `get_waveforms` method returns only the waveforms as a NumPy array (`numpy.ndarray`), while the `get_sample` method returns both the waveforms and the metadata, which we can use to annotate the waveforms.

In [None]:
# Get a sample
waveforms, info = data.get_sample(0)

# Plot and annotate
utils.plot_waveforms(waveforms, info)
utils.add_picks(info)

### 2.3 Data generator

PyTorch loads data in batches through dataset and data loader objects, which is useful for training and evaluating the model. The `seisbench.generate` module provides a `GenericGenerator` class, a PyTorch-compatible dataset that produces training examples from the waveforms. The generator is instantiated with a dataset object.

With the generator we can add augmentations, which are used here to cut the input data into windows of adequate length. Note that the model only works with 3001-sample waveforms, so we need to window the waveforms. We register the augmentations with the `add_augmentations` method; here the `WindowAroundSample` and `RandomWindow` augmentations perform the windowing.
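
To make the windowing concrete, here is a minimal NumPy sketch of the two steps. It is an illustration, not SeisBench's implementation: the record length, the pick position and all variable names are assumptions, and only the window lengths mirror the values used in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 3-component record: 12,000 samples at 100 Hz, pick at sample 5,000
n_channels, n_samples, pick = 3, 12_000, 5_000
waveform = rng.standard_normal((n_channels, n_samples))

# Step 1: keep a 6,000-sample window starting 3,000 samples before the pick
samples_before, long_len = 3_000, 6_000
start = pick - samples_before
long_window = waveform[:, start:start + long_len]
pick_in_long = samples_before  # the pick sits in the middle of the long window

# Step 2: cut a random 3,001-sample window out of the long one
short_len = 3_001
offset = int(rng.integers(0, long_len - short_len + 1))
segment = long_window[:, offset:offset + short_len]
pick_in_segment = pick_in_long - offset

print(segment.shape, pick_in_segment)
```

Because the pick sits at the center of the long window and the short window covers just over half of it, any random offset keeps the pick inside the final 3001-sample segment.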

In [None]:
# Create a data generator
data_generator = sbg.GenericGenerator(data)

# Window, normalize, convert to float32 and label
augmentations = [
    sbg.WindowAroundSample(list(utils.PHASES.keys()), samples_before=3000, windowlen=6000, selection="random", strategy="variable"),
    sbg.RandomWindow(windowlen=3001, strategy="pad"),
    sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
    sbg.ChangeDtype(np.float32),
    sbg.ProbabilisticLabeller(label_columns=utils.PHASES, sigma=30, dim=0)
]

data_generator.add_augmentations(augmentations)

## 3. Loading a model

First, we investigate the pre-trained PhaseNet models, which can be loaded with `seisbench.models`. We start by listing the available pre-trained weights, which are also distributed by SeisBench. The list returns the names of the datasets that PhaseNet was trained on.

In [None]:
# List available PhaseNet models
print(f"Pretrained models for PhaseNet: {PhaseNet.list_pretrained()}")

### 3.1 Pre-trained model loading

Then, we can download one of the models to use for our experiment. A model is lightweight (several MB) and can be downloaded in a few seconds. We load it with the `from_pretrained` method, which returns a `PhaseNet` object. SeisBench models are PyTorch modules and can be used as such.

In [None]:
# Get a pretrained model
model = PhaseNet.from_pretrained("iquique")

print(dir(model))

### 3.2 Observe a forward pass

In [None]:
# Get a random sample
waveform_index = np.random.randint(0, len(data_generator))
sample = data_generator[waveform_index]
x = sample["X"]
y_true = sample["y"]
metadata = data.metadata.iloc[waveform_index]

# Forward pass (no need to calculate gradient since we just want to predict)
with torch.no_grad():
    sample_tensor = torch.tensor(x, device=model.device).unsqueeze(0)
    y_hat = model.forward(sample_tensor)
    y_hat = y_hat.cpu().numpy().squeeze()

# Plot y_hat
ax = utils.plot_waveforms_and_labels(x, y_true, y_hat, metadata=metadata)
ax[-1].set_ylabel("Prediction")

### 3.3 Errors

Here we redefine the loss that was used in the paper. This loss is the cross-entropy loss, defined by the following equation:

$$
\mathcal{L} = - \sum_c \sum_x p_{c}(x) \log \hat{p}_{c}(x)
$$

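where $p_{c}(x)$ is the target probability of the $c$-th class at the $x$-th point, and $\hat{p}_{c}(x)$ is the predicted probability of the $c$-th class at the $x$-th point. Note that the implementation below takes the mean over the points $x$ rather than the sum, which only rescales the loss by a constant factor.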
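
As a small worked example of this equation, the NumPy sketch below evaluates the loss for two made-up predictions on a toy target with 3 classes and 4 points. The arrays and the helper name `cross_entropy` are assumptions for illustration only; the function follows the summed form of the equation.

```python
import numpy as np

def cross_entropy(p_true, p_pred, eps=1e-5):
    """Sum over classes c and points x of -p_c(x) * log(p_hat_c(x))."""
    return -np.sum(p_true * np.log(p_pred + eps))

# Toy target: 3 classes (P, S, noise) over 4 points, one-hot for simplicity
p_true = np.array([
    [0.0, 1.0, 0.0, 0.0],  # P
    [0.0, 0.0, 0.0, 1.0],  # S
    [1.0, 0.0, 1.0, 0.0],  # noise
])

confident = np.where(p_true == 1.0, 0.98, 0.01)  # close to the target
uniform = np.full_like(p_true, 1.0 / 3.0)        # uninformative prediction

loss_confident = cross_entropy(p_true, confident)
loss_uniform = cross_entropy(p_true, uniform)
print(loss_confident, loss_uniform)
```

The prediction close to the target gives a loss near zero, while the uniform prediction gives roughly $4 \log 3$, one $\log 3$ per point.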

In [None]:
def loss_fn(y_pred, y_true, eps=1e-5):
    # vector cross entropy loss
    h = y_true * torch.log(y_pred + eps)
    h = h.mean(-1).sum(-1)  # Mean along sample dimension and sum along pick dimension
    h = h.mean()  # Mean over batch axis
    return -h.cpu().numpy()

We can evaluate the performance of our model over a set of samples:

In [None]:
# Collect the loss of the pre-trained model over the first 500 samples
losses = []
for waveform_index in range(500):

    sample = data_generator[waveform_index]
    x = sample["X"]
    y_true = sample["y"]

    # Forward
    with torch.no_grad():
        sample_tensor = torch.tensor(x, device=model.device).unsqueeze(0)
        y_hat = model.forward(sample_tensor)
        y_hat = y_hat.cpu().numpy().squeeze()

    # Compute the loss for this sample
    losses.append(loss_fn(torch.tensor(y_hat), torch.tensor(y_true)))

# Plot the distribution of losses over these samples
fig, ax = plt.subplots()
ax.hist(losses, bins=30)
ax.set_xlabel("Loss (cross entropy)")
ax.set_ylabel("Count")

## 4. Retrain the model

First, we need to split the data into training, development (validation) and test sets.

In [None]:
train, dev, test = data.train_dev_test()

Now we define two generators with identical augmentations, one for training, one for validation. The augmentations are:
1. Selection of a (long) window around a pick. This way, we ensure that our data always contains a pick.
1. Selection of a random window with 3001 samples, the input length of PhaseNet.
1. A normalization, consisting of demeaning and amplitude normalization.
1. A change of datatype to float32, as this is expected by the PyTorch model.
1. A probabilistic labelling of the picks, here with a width of `sigma=30` samples.
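
The normalization and labelling steps can be sketched in plain NumPy as follows. This is an illustration under stated assumptions, not SeisBench's code: the input data, the pick positions and the helper `gaussian_label` are hypothetical, and the label shape is assumed to be a Gaussian of unit height with the noise class taking the remainder.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples = 3_001
x = rng.standard_normal((3, n_samples)) * 50.0 + 10.0  # hypothetical raw counts

# Demean each channel, then divide each channel by its peak absolute amplitude
x = x - x.mean(axis=-1, keepdims=True)
x = x / (np.abs(x).max(axis=-1, keepdims=True) + 1e-10)
x = x.astype(np.float32)  # datatype expected by the PyTorch model

def gaussian_label(pick, n, sigma=30):
    """Gaussian bump of unit height centered on the pick sample."""
    t = np.arange(n)
    return np.exp(-((t - pick) ** 2) / (2.0 * sigma**2))

p_pick, s_pick = 1_000, 1_800  # hypothetical pick positions (in samples)
y_p = gaussian_label(p_pick, n_samples)
y_s = gaussian_label(s_pick, n_samples)
y_noise = np.clip(1.0 - y_p - y_s, 0.0, 1.0)  # remainder goes to the noise class
y = np.stack([y_p, y_s, y_noise]).astype(np.float32)

print(x.shape, y.shape)
```

With `sigma=30` samples at 100 Hz, each label has a standard deviation of 0.3 s around the pick.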

In [None]:
train_generator = sbg.GenericGenerator(train)
dev_generator = sbg.GenericGenerator(dev)

augmentations = [
    sbg.WindowAroundSample(list(utils.PHASES.keys()), samples_before=3000, windowlen=6000, selection="random", strategy="variable"),
    sbg.RandomWindow(windowlen=3001, strategy="pad"),
    sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
    sbg.ChangeDtype(np.float32),
    sbg.ProbabilisticLabeller(label_columns=utils.PHASES, sigma=30, dim=0)
]

train_generator.add_augmentations(augmentations)
dev_generator.add_augmentations(augmentations)

Let's visualize a few training examples. Every time you run the cell below, you'll see a different training example.

In [None]:
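# Draw one index and use it for both the sample and its metadata
index = np.random.randint(len(train_generator))
sample = train_generator[index]
metadata = train.metadata.iloc[index]
x = sample["X"]
y_true = sample["y"]

# Plot the waveforms and labels (no prediction yet, hence the NaN placeholder)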
ax = utils.plot_waveforms_and_labels(x, y_true, y_true * np.nan, metadata=metadata)
ax[-1].set_ylabel("Prediction")

SeisBench generators are pytorch datasets. Therefore, we can pass them to pytorch data loaders. These will automatically take care of parallel loading and batching. Here we create one loader for training and one for validation. We choose a batch size of 256 samples. This batch size should fit on most hardware.
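
As a quick sanity check on batching, the arithmetic below shows how many batches one epoch contains. The split size is a made-up number for illustration; the calculation assumes the data loader keeps the final, smaller batch, which is the PyTorch default (`drop_last=False`).

```python
import math

n_examples = 8_000  # hypothetical size of the training split
batch_size = 256

n_batches = math.ceil(n_examples / batch_size)          # batches per epoch
last_batch = n_examples - (n_batches - 1) * batch_size  # size of the final batch
print(n_batches, last_batch)
```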

In [None]:
batch_size = 256
num_workers = 4  # The number of worker processes used for loading data

train_loader = DataLoader(train_generator, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=worker_seeding)
dev_loader = DataLoader(dev_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers, worker_init_fn=worker_seeding)

### 4.1 Training and testing ingredients

We now have all the components for training the model. What we still need to do is define the optimizer and the loss, and write the training and validation loops.

In [None]:
learning_rate = 1e-2
epochs = 5

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
def loss_fn(y_pred, y_true, eps=1e-5):
    # vector cross entropy loss
    h = y_true * torch.log(y_pred + eps)
    h = h.mean(-1).sum(-1)  # Mean along sample dimension and sum along pick dimension
    h = h.mean()  # Mean over batch axis
    return -h

def train_loop(dataloader):
    size = len(dataloader.dataset)
    for batch_id, batch in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(batch["X"].to(model.device))
        loss = loss_fn(pred, batch["y"].to(model.device))

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_id % 5 == 0:
            loss, current = loss.item(), batch_id * batch["X"].shape[0]
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader):
    num_batches = len(dataloader)
    test_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            pred = model(batch["X"].to(model.device))
            test_loss += loss_fn(pred, batch["y"].to(model.device)).item()

    test_loss /= num_batches
    print(f"Test avg loss: {test_loss:>8f} \n")

### 4.2 Training loop

This loop performs the optimization by training (forward pass, backpropagation, weight update) on the training data. We use the Adam optimizer, which is a good default choice. We use the cross-entropy loss, which is the loss used in the paper.
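
To illustrate what one optimization step does, here is a toy NumPy version of the Adam update applied to a one-parameter problem. It is a sketch of the standard update rule, not PyTorch's implementation; the function name `adam_minimize`, the toy loss and the hyperparameter values are assumptions for illustration.

```python
import numpy as np

def adam_minimize(grad_fn, w0, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=500):
    """Minimize a scalar function with the Adam update rule."""
    w, m, v = float(w0), 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(w)                       # gradient (what backpropagation provides)
        m = beta1 * m + (1 - beta1) * g      # running mean of gradients
        v = beta2 * v + (1 - beta2) * g**2   # running mean of squared gradients
        m_hat = m / (1 - beta1**t)           # bias correction of the first moment
        v_hat = v / (1 - beta2**t)           # bias correction of the second moment
        w -= lr * m_hat / (np.sqrt(v_hat) + eps)  # weight update
    return w

# Toy loss (w - 3)^2 with gradient 2 (w - 3); the minimum is at w = 3
w_final = adam_minimize(lambda w: 2.0 * (w - 3.0), w0=0.0)
print(w_final)
```

In the training loop of this notebook, `loss.backward()` plays the role of `grad_fn` and `optimizer.step()` applies the update to every model parameter.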

### 4.3 Test loop

This loop evaluates the model on the validation data. We use the cross-entropy loss, which is the loss used in the paper.

In [None]:
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_loader)
    test_loop(dev_loader)

### 4.4 Evaluating the model

Now that we have trained the model, we can evaluate it. First, we'll check how the model does on an example from the development set. Note that the model will most likely not be fully trained after only five epochs.

In [None]:
# Draw one example from the development set, with its metadata
index = np.random.randint(len(dev_generator))
sample = dev_generator[index]
metadata = dev.metadata.iloc[index]
x = sample["X"]
y_true = sample["y"]

with torch.no_grad():
    y_pred = model(torch.tensor(x, device=model.device).unsqueeze(0))  # Add a fake batch dimension
    y_pred = y_pred[0].cpu().numpy()

# Plot the prediction against the labels of the same example
ax = utils.plot_waveforms_and_labels(x, y_true, y_pred, metadata=metadata)
ax[-1].set_ylabel("Prediction")

As a second option, we'll directly apply our model to an obspy waveform stream using the `annotate` function. For this, we are downloading waveforms through FDSN and annotating them. Note that you could use the `classify` function in a similar fashion.

Although we trained the model on data from northern Chile, the example below uses an event from Switzerland, so it also gives a first impression of how the model transfers to another region. Note that we deliberately chose a rather easy example, as the model is not fully trained after the low number of epochs. The exact performance of the model will vary from run to run, because the model training and initialization involve random aspects.
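
The probability curves returned by `annotate` still have to be turned into discrete picks, which is what `classify` is for. The sketch below shows one simple way to do this by hand: take the maximum of each run of samples above a threshold. It is an illustration only, not SeisBench's picking code; the helper `pick_peaks`, the threshold of 0.3 and the synthetic probability curve are all assumptions.

```python
import numpy as np

def pick_peaks(prob, threshold=0.3, sampling_rate=100.0):
    """Return the time (s) of the maximum in each run of samples above threshold."""
    above = prob > threshold
    padded = np.concatenate(([False], above, [False])).astype(int)
    edges = np.diff(padded)
    starts = np.flatnonzero(edges == 1)  # first sample of each run
    ends = np.flatnonzero(edges == -1)   # one past the last sample of each run
    return [float(s + np.argmax(prob[s:e])) / sampling_rate for s, e in zip(starts, ends)]

# Synthetic probability curve with two bumps, at samples 1000 and 1800
t = np.arange(3_001)
prob = (0.9 * np.exp(-((t - 1_000) ** 2) / (2 * 30**2))
        + 0.6 * np.exp(-((t - 1_800) ** 2) / (2 * 30**2)))

picks = pick_peaks(prob)
print(picks)
```

Raising the threshold trades missed picks for fewer false detections, which matters for a model that is not fully trained.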

In [None]:
from obspy.clients.fdsn import Client
from obspy import UTCDateTime

client = Client("ETH")

t = UTCDateTime("2019-11-04T00:59:46.419800Z")
stream = client.get_waveforms(network="CH", station="EMING", location="*", channel="HH?", starttime=t-30, endtime=t+50)

annotations = model.annotate(stream)

print(annotations)

fig = plt.figure()
ax = fig.subplots(2, 1, sharex=True)

offset = annotations[0].stats.starttime - stream[0].stats.starttime
for i, (trace, annotation) in enumerate(zip(stream, annotations)):
    # Use a new name so the `data` dataset loaded earlier is not overwritten
    normalized = trace.data / np.max(np.abs(trace.data))
    ax[0].plot(trace.times(), normalized + i, label=trace.stats.channel)
    # Skip the noise channel and plot only the phase probabilities
    if annotation.stats.channel[-1] != "N":
        ax[1].plot(annotation.times() + offset, annotation.data, label=annotation.stats.channel)

ax[0].set_xlabel("Time (s)")
ax[1].set_xlabel("Time (s)")
ax[1].set_ylabel("Probability")
ax[0].legend()
ax[1].legend()
ax[0].grid()
ax[1].grid()