<a href="https://colab.research.google.com/github/hawaii-ai/tutorial_climate_ai/blob/main/02_classify_ENSO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ENSO Phase Prediction with Deep Learning

## Import Libraries

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import xarray as xr
from torch.utils.data import DataLoader

In [None]:
# Pick the best available compute backend -- CUDA first, then Apple's Metal
# (MPS) backend, then the CPU -- and bind it to `device`. The original cell
# only built the object and discarded it, so nothing could be moved onto the
# selected backend afterwards (e.g. with `.to(device)`).
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Echo the selection so the cell still displays it as its output.
device

## Download and Preprocess Data

In [None]:
# Open the SST-anomaly NetCDF file and display the resulting dataset as the
# cell output.
ersst_path = "data/ersst_pacific_anom.nc"
netcdf_file = xr.open_dataset(ersst_path)
netcdf_file

In [None]:
# Shape the anomaly field for a 2-D CNN, one step at a time:
#   1. insert a length-1 "channel" axis,
#   2. order the axes as (time, channel, lat, lon),
#   3. replace missing values (NaN) with 0.
ssta_with_channel = netcdf_file.ssta.expand_dims("channel")
ssta_time_first = ssta_with_channel.transpose("time", "channel", "lat", "lon")
ssta = ssta_time_first.fillna(0)
ssta

In [None]:
# Read the label table from disk and display it as the cell output. The
# "lead_1" .. "lead_3" columns are what the next cell one-hot encodes.
labels_path = "data/labels.csv"
labels_df = pd.read_csv(labels_path)
labels_df

In [None]:
# Turn the integer class ids of the three lead columns into one-hot float
# vectors. The result has one trailing axis of size 3 (the classes) appended
# to the (row, lead) layout of the selected columns.
lead_columns = ["lead_1", "lead_2", "lead_3"]
class_indices = torch.from_numpy(labels_df[lead_columns].values).long()
one_hot_labels = nn.functional.one_hot(class_indices, num_classes=3).float()

## Define the Training Set

In [None]:
# now we have to create a dataset that can be used by pytorch

class ENSODataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each input sample with its label entry.

    Args:
        ssta: Sized, indexable collection of input samples.
        labels: Indexable collection of labels, aligned with ``ssta`` by
            position.
    """

    def __init__(self, ssta, labels):
        self.ssta, self.labels = ssta, labels

    def __len__(self):
        """Return the number of samples, as reported by ``ssta``."""
        n_samples = len(self.ssta)
        return n_samples

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        sample = self.ssta[idx]
        label = self.labels[idx]
        return sample, label


# Split by position along the leading axis: the first 80% of the samples are
# used for training and the remaining 20% are held out for testing.
train_size = int(0.8 * len(ssta))

# Copy the underlying arrays so each dataset owns its own buffer.
train_inputs = ssta[:train_size].data.copy()
test_inputs = ssta[train_size:].data.copy()

train_dataset = ENSODataset(train_inputs, one_hot_labels[:train_size])
test_dataset = ENSODataset(test_inputs, one_hot_labels[train_size:])

In [None]:
# Wrap both splits in loaders that yield mini-batches of 32 samples. Only the
# training loader shuffles; the test loader keeps the dataset order.
batch_size = 32

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Defining the Model

In [None]:
class ENSOModel(nn.Module):
    """Small CNN that classifies the ENSO phase at three lead times.

    Two convolution + max-pooling stages feed two fully connected layers. The
    resulting 84-dimensional feature vector is then passed to three separate
    linear heads, one per lead time, each producing logits over 3 classes.

    The first fully connected layer expects ``16 * 8 * 20`` flattened
    features, so the input must be a ``(batch, 1, H, W)`` tensor whose spatial
    size reduces to 8 x 20 after the two conv/pool stages (e.g. 40 x 88).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 8 * 20, 120)
        self.fc2 = nn.Linear(120, 84)
        # One output head per lead time. Previously a single `fc3` layer was
        # applied three times, which made the three outputs identical and
        # prevented the model from predicting different phases for different
        # leads. `fc3` keeps its name (lead 1) for backward compatibility.
        self.fc3 = nn.Linear(84, 3)
        self.fc3_lead_2 = nn.Linear(84, 3)
        self.fc3_lead_3 = nn.Linear(84, 3)

    def forward(self, x):
        """Compute per-lead class logits.

        Args:
            x: Input tensor of shape ``(batch, 1, H, W)``.

        Returns:
            A tuple ``(lead_1, lead_2, lead_3)`` of logit tensors, each of
            shape ``(batch, 3)``.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))

        # Shared features, independent classifier per lead time.
        x_lead_1 = self.fc3(x)
        x_lead_2 = self.fc3_lead_2(x)
        x_lead_3 = self.fc3_lead_3(x)

        return x_lead_1, x_lead_2, x_lead_3

## Training the Model

In [None]:
model = ENSOModel()

# Cross-entropy is evaluated separately for each lead time and the three
# terms are summed; plain SGD with momentum minimizes the total.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

epochs = 60
train_loss = []  # sample-weighted mean training loss, one entry per epoch

for epoch in range(epochs):
    model.train()

    # Batch losses scaled by batch size, so that the epoch average below is
    # not skewed by a smaller final batch.
    weighted_batch_losses = []

    for step, (inputs, targets) in enumerate(train_dataloader):
        optimizer.zero_grad()

        # Forward pass: one set of logits per lead time.
        logits_1, logits_2, logits_3 = model(inputs)
        loss_1 = loss_fn(logits_1, targets[:, 0])
        loss_2 = loss_fn(logits_2, targets[:, 1])
        loss_3 = loss_fn(logits_3, targets[:, 2])
        loss = loss_1 + loss_2 + loss_3

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        weighted_batch_losses.append(batch_loss * inputs.size(0))

        # Progress report every 400 batches (including the first one).
        if step % 400 == 0:
            print(f"Epoch: {epoch}, Batch: {step}, Loss: {batch_loss}")

    train_loss.append(sum(weighted_batch_losses) / len(train_dataset))

In [None]:
# Draw the per-epoch training loss on the current axes and label the figure.
ax = plt.gca()
ax.plot(train_loss)
ax.set_title("Training Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
plt.show()

## Evaluating the Model

In [None]:
model.eval()

# Running count of correct test predictions, one counter per lead time.
correct_lead_1 = 0
correct_lead_2 = 0
correct_lead_3 = 0

with torch.no_grad():
    for inputs, targets in test_dataloader:
        logits_1, logits_2, logits_3 = model(inputs)

        # Predicted class = index of the largest logit; true class = index of
        # the 1 in the one-hot target for that lead.
        hits_1 = logits_1.argmax(dim=1) == targets[:, 0].argmax(dim=1)
        hits_2 = logits_2.argmax(dim=1) == targets[:, 1].argmax(dim=1)
        hits_3 = logits_3.argmax(dim=1) == targets[:, 2].argmax(dim=1)

        correct_lead_1 += hits_1.sum().item()
        correct_lead_2 += hits_2.sum().item()
        correct_lead_3 += hits_3.sum().item()

# Report each lead's accuracy as a percentage of the test-set size.
print(f"Accuracy Lead 1: {100 * correct_lead_1 / len(test_dataset):.2f}%")
print(f"Accuracy Lead 2: {100 * correct_lead_2 / len(test_dataset):.2f}%")
print(f"Accuracy Lead 3: {100 * correct_lead_3 / len(test_dataset):.2f}%")

## Forecast with December 2023

In [None]:
# Open the December 2023 file, keep only its `ssta` variable, and display it
# as the cell output.
dec2023_dataset = xr.open_dataset("data/ersst_anom_dec2023.nc")
dec2023 = dec2023_dataset.ssta
dec2023

In [None]:
# Put the model in inference mode before forecasting from the dec2023 field.
model.eval()

# Preprocess the single snapshot the same way as the training data: add a
# channel axis, order the axes as (channel, lat, lon), and zero-fill NaNs.
dec2023_chw = dec2023.expand_dims("channel").transpose("channel", "lat", "lon")
pred_input = torch.from_numpy(dec2023_chw.fillna(0).data.copy())

# The model works on batches, so prepend a batch axis of length 1.
single_sample_batch = pred_input.unsqueeze(0)

with torch.no_grad():
    model_output = model(single_sample_batch)

In [None]:
# Reduce each lead's logits to the index of the highest-scoring class, then
# print the class legend followed by one prediction per month.
class_lead_1, class_lead_2, class_lead_3 = (
    lead_logits.argmax() for lead_logits in model_output
)

print("0: Neutral, 1: El Nino, 2: La Nina")
print(f"Prediction for January 2024: {class_lead_1}")
print(f"Prediction for February 2024: {class_lead_2}")
print(f"Prediction for March 2024: {class_lead_3}")