In [None]:
import xarray as xr
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import matplotlib.pyplot as plt


In [None]:

data_dir = "Dataset/"

# Inclusive bounds of the region of interest, expressed in the dataset's own
# 'lat' / 'lon' coordinate values.
LAT_MIN, LAT_MAX = -2, 2
LON_MIN, LON_MAX = 234, 240

# Years scanned for the global min/max, and years that are scaled and stored.
# NOTE(review): the statistics scan starts one year earlier than the processed
# range, so 1981 influences the scaling but is never stored — confirm intended.
STATS_YEARS = range(1981, 2024 + 1)
PROCESS_YEARS = range(1982, 2024 + 1)


def load_region_sst(year):
    """Open one yearly file and return its 'sst' variable cropped to the region.

    Args:
        year: Year used to build the file name 'sst.day.mean.<year>.nc'
            inside ``data_dir``.

    Returns:
        The 'sst' DataArray restricted to LAT_MIN..LAT_MAX and LON_MIN..LON_MAX
        (bounds inclusive), loaded into memory.

    Raises:
        FileNotFoundError: If the yearly file does not exist.
    """
    file_path = os.path.join(data_dir, f'sst.day.mean.{year}.nc')
    # The context manager releases the file handle; .load() pulls the cropped
    # values into memory first so the result stays usable after closing.
    with xr.open_dataset(file_path) as dataset:
        in_region = (
            (dataset['lat'] >= LAT_MIN) & (dataset['lat'] <= LAT_MAX) &
            (dataset['lon'] >= LON_MIN) & (dataset['lon'] <= LON_MAX)
        )
        # drop=True removes the points outside the bounding box.
        return dataset['sst'].where(in_region, drop=True).load()


processed_data = {}   # str(year) -> SST scaled to [-1, 1]
region_by_year = {}   # year -> unscaled regional SST, kept so each file is read once

# Pass 1: load every year once and track the global minimum and maximum SST.
global_min_sst = float('inf')
global_max_sst = float('-inf')

for year in STATS_YEARS:
    try:
        region_sst = load_region_sst(year)

        # .item() yields plain Python floats, so the globals stay ordinary
        # scalars instead of 0-d numpy arrays.
        year_min = region_sst.min().item()
        year_max = region_sst.max().item()
        global_min_sst = min(global_min_sst, year_min)
        global_max_sst = max(global_max_sst, year_max)

        region_by_year[year] = region_sst
        print(f"Checked year {year} for global min and max.")

    except FileNotFoundError:
        print(f"File for {year} not found. Skipping this year.")
    except Exception as e:
        print(f"An error occurred while processing {year}: {e}")

print(f"Global Min SST: {global_min_sst}, Global Max SST: {global_max_sst}")

# Pass 2: scale each stored year to [-1, 1] using the global min and max.
sst_range = global_max_sst - global_min_sst
if not (np.isfinite(sst_range) and sst_range > 0):
    # Happens when no file could be loaded or all values are identical;
    # scaling would otherwise divide by zero or propagate inf/NaN.
    print("No usable SST range found (no data loaded or constant values). Skipping scaling.")
else:
    for year in PROCESS_YEARS:
        if year not in region_by_year:
            # The reason was already reported during the first pass.
            print(f"No data available for {year}. Skipping this year.")
            continue

        region_sst = region_by_year[year]

        # Report whether the cropped region contains missing values.
        if bool(region_sst.isnull().any()):
            print("The region_sst contains NaN values.")
        else:
            print("The region_sst does not contain any NaN values.")

        # Linear map: global_min_sst -> -1, global_max_sst -> +1.
        processed_data[str(year)] = (2 * (region_sst - global_min_sst) / sst_range) - 1
        print(f"Processed year {year}")


In [None]:


class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell.

    One convolution over the concatenation of the input and the previous
    hidden state produces the pre-activations of the four gates at once.

    Args:
        input_dim: Number of channels of the input tensor.
        hidden_dim: Number of channels of the hidden and cell states.
        kernel_size: Convolution kernel size, either an int or a
            (height, width) pair. Odd sizes are required for the output to
            keep the spatial size of the input, which the state update needs.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        super(ConvLSTMCell, self).__init__()
        self.hidden_dim = hidden_dim

        # Accept a bare int as a square kernel.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        kernel_h, kernel_w = kernel_size

        # Pad each spatial axis by half of its own kernel extent so that
        # non-square kernels also preserve height and width.
        self.conv = nn.Conv2d(
            input_dim + hidden_dim, 4 * hidden_dim, (kernel_h, kernel_w),
            padding=(kernel_h // 2, kernel_w // 2), bias=bias
        )

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one step.

        Args:
            input_tensor: Input with channels on dim 1.
            cur_state: Tuple (h_cur, c_cur) of the current hidden and cell
                states, each with ``hidden_dim`` channels on dim 1.

        Returns:
            Tuple (h_next, c_next) with the updated hidden and cell states.
        """
        h_cur, c_cur = cur_state
        # Stack input and previous hidden state along the channel axis.
        combined = torch.cat([input_tensor, h_cur], dim=1)
        combined_conv = self.conv(combined)
        # Channel chunks are, in order: input, forget, output, candidate.
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)
        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

class ConvLSTM(nn.Module):
    """Stack of ConvLSTM cells unrolled over the time axis of the input.

    Args:
        input_dim: Number of input channels seen by the first layer.
        hidden_dim: Number of hidden channels of every layer.
        kernel_size: Kernel size forwarded to every cell.
        num_layers: Number of stacked cells.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers):
        super(ConvLSTM, self).__init__()
        # The first layer consumes the input channels; deeper layers consume
        # the hidden state of the layer below.
        self.layers = nn.ModuleList([
            ConvLSTMCell(input_dim if i == 0 else hidden_dim, hidden_dim, kernel_size)
            for i in range(num_layers)
        ])

    def forward(self, input_tensor):
        """Run the stack over every time step.

        Args:
            input_tensor: 5-D tensor unpacked as
                (batch, time, channels, height, width).

        Returns:
            The top layer's hidden state after the last time step.

        Raises:
            ValueError: If the time axis is empty.
        """
        batch_size, seq_length, _, height, width = input_tensor.size()
        if seq_length == 0:
            # Without any step there is no output to return.
            raise ValueError("input_tensor must contain at least one time step")

        # Zero initial (hidden, cell) states; new_zeros inherits the device
        # and dtype of the input so non-float32 inputs stay consistent.
        hidden_states = [
            (input_tensor.new_zeros(batch_size, layer.hidden_dim, height, width),
             input_tensor.new_zeros(batch_size, layer.hidden_dim, height, width))
            for layer in self.layers
        ]
        for t in range(seq_length):
            x = input_tensor[:, t]
            for i, layer in enumerate(self.layers):
                h, c = hidden_states[i]
                h, c = layer(x, (h, c))
                hidden_states[i] = (h, c)
                # The hidden state of this layer feeds the next layer.
                x = h
        return x

class ConvLSTMWithOutput(nn.Module):
    """ConvLSTM stack followed by a 1x1 convolution output head.

    The head maps the ``hidden_dim`` channels returned by the ConvLSTM to
    ``n_days`` output channels.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, n_days):
        super().__init__()
        # Recurrent feature extractor.
        self.convlstm = ConvLSTM(input_dim, hidden_dim, kernel_size, num_layers)
        # Per-pixel projection from hidden channels to one channel per day.
        self.conv_output = nn.Conv2d(
            in_channels=hidden_dim,
            out_channels=n_days,
            kernel_size=(1, 1),
        )

    def forward(self, input_tensor):
        """Return the head's output, shaped [batch_size, n_days, H, W]."""
        features = self.convlstm(input_tensor)
        return self.conv_output(features)


In [None]:


class SSTDataset(Dataset):
    """Sliding-window dataset over an array whose first axis is time.

    Sample ``idx`` pairs the ``seq_length`` consecutive steps starting at
    ``idx`` with the ``n_days`` steps that immediately follow them.

    Args:
        data: Array indexed by time on axis 0 (the input slice is indexed
            with three axes, so it is expected to be 3-D).
        seq_length: Number of steps in each input window.
        n_days: Number of steps in each target window.
    """

    def __init__(self, data, seq_length, n_days):
        self.data = data
        self.seq_length = seq_length
        self.n_days = n_days

    def __len__(self):
        # Number of start positions for which both windows fit completely.
        # Clamped at zero: a negative value would make len() raise.
        return max(0, self.data.shape[0] - (self.seq_length + self.n_days - 1))

    def __getitem__(self, idx):
        """Return (input, target) float32 tensors for window ``idx``.

        The input has shape (1, seq_length, ...) and the target
        (n_days, ...), where ``...`` are the trailing axes of ``data``.

        Raises:
            IndexError: If ``idx`` is outside the valid window range.
        """
        n_windows = len(self)
        if idx < 0:
            # Support Python-style negative indexing.
            idx += n_windows
        if not 0 <= idx < n_windows:
            # Without this check out-of-range indices would silently return
            # empty or truncated windows.
            raise IndexError(f"window index {idx} out of range for {n_windows} windows")

        input_seq = self.data[idx:idx + self.seq_length]
        target_seq = self.data[idx + self.seq_length:idx + self.seq_length + self.n_days]
        # Prepend a singleton axis: (seq_length, H, W) -> (1, seq_length, H, W).
        # NOTE(review): ConvLSTM.forward in this file unpacks a batch as
        # (batch, time, channels, H, W), so this is read as ONE time step with
        # seq_length channels — confirm that is the intended layout.
        input_seq = input_seq[np.newaxis, :, :, :]
        return torch.tensor(input_seq, dtype=torch.float32), torch.tensor(target_seq, dtype=torch.float32)


In [None]:
# ---- Hyperparameters ----
seq_length = 30        # steps in each input window
n_days = 30            # steps in each target window
batch_size = 24        # samples per mini-batch
num_epochs = 10        # passes over the training set
learning_rate = 0.001  # Adam step size

# ---- Data preparation ----
# Years are split chronologically and stacked along the first axis.
train_years = [*range(1982, 2006)]
val_years = [*range(2006, 2015)]
train_data = np.concatenate(
    [processed_data[str(year)] for year in train_years],
    axis=0,
)
val_data = np.concatenate(
    [processed_data[str(year)] for year in val_years],
    axis=0,
)

train_dataset = SSTDataset(train_data, seq_length, n_days)
val_dataset = SSTDataset(val_data, seq_length, n_days)
# Only the training windows are shuffled; validation keeps its order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

# ---- Model, loss and optimiser ----
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = ConvLSTMWithOutput(
    input_dim=30,
    hidden_dim=64,
    kernel_size=(3, 3),
    num_layers=5,
    n_days=n_days,
)
model = model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Training cell.
# Hyperparameters (n_days, num_epochs, learning_rate), the data loaders and
# `device` come from the setup cell above; they are intentionally not
# re-declared here so the two cells cannot drift apart.

# Re-create the model and optimiser so that re-running this cell always
# trains from freshly initialised weights.
model = ConvLSTMWithOutput(input_dim=30, hidden_dim=64, kernel_size=(3, 3), num_layers=5, n_days=n_days).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


def run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the mean loss per sample.

    Args:
        model: Network to train or evaluate.
        loader: Iterable yielding (inputs, targets) tensor batches.
        criterion: Loss function returning a scalar tensor.
        device: Device the batches are moved to.
        optimizer: When given, the model is put in training mode and updated
            after every batch; when None the pass is evaluation only and no
            gradients are tracked.

    Returns:
        The batch-size-weighted mean loss, or NaN if the loader is empty.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, n_samples = 0.0, 0
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if training:
                loss.backward()
                optimizer.step()
            # Weight by batch size so a smaller final batch does not skew
            # the epoch average.
            total_loss += loss.item() * inputs.size(0)
            n_samples += inputs.size(0)
    return total_loss / n_samples if n_samples else float('nan')


# Training Loop
best_val_loss = float('inf')
for epoch in range(num_epochs):
    train_loss = run_epoch(model, train_loader, criterion, device, optimizer)
    val_loss = run_epoch(model, val_loader, criterion, device)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    # Keep only the weights with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "convlstm_30D_nino.pth")


In [None]:
def compute_metrics(predictions, targets):
    """
    Compute MAE, RMSE, and R² between predictions and targets.

    Both inputs may be any array-like (lists, numpy arrays). Non-floating
    inputs are converted to float64 first so that unsigned integer
    differences cannot wrap around, and all reductions accumulate in float64.

    Args:
        predictions: Predicted values.
        targets: Ground truth values, broadcast-compatible with predictions.

    Returns:
        mae: Mean Absolute Error.
        rmse: Root Mean Square Error.
        r2: Coefficient of Determination (R²); 0.0 when the targets have no
            variance, because the ratio is undefined in that case.
    """
    predictions = np.asarray(predictions)
    targets = np.asarray(targets)
    # Integer/boolean inputs would make the subtraction overflow or truncate.
    if not np.issubdtype(predictions.dtype, np.floating):
        predictions = predictions.astype(np.float64)
    if not np.issubdtype(targets.dtype, np.floating):
        targets = targets.astype(np.float64)

    errors = predictions - targets
    squared_errors = errors ** 2

    mae = np.mean(np.abs(errors), dtype=np.float64)  # Mean Absolute Error
    mse = np.mean(squared_errors, dtype=np.float64)  # Mean Squared Error
    rmse = np.sqrt(mse)  # Root Mean Squared Error

    # Compute R²
    target_mean = np.mean(targets, dtype=np.float64)
    ss_total = np.sum((targets - target_mean) ** 2, dtype=np.float64)  # Total sum of squares
    ss_residual = np.sum(squared_errors, dtype=np.float64)  # Residual sum of squares
    r2 = 1 - (ss_residual / ss_total) if ss_total > 0 else 0.0

    return float(mae), float(rmse), float(r2)


def inverse_scale(data, global_min_sst, global_max_sst):
    """Undo the [-1, 1] scaling: -1 maps to global_min_sst, +1 to global_max_sst."""
    # Shift from [-1, 1] to [0, 1] ...
    unit_interval = (data + 1) / 2
    # ... then stretch to the original range and re-add the offset.
    sst_span = global_max_sst - global_min_sst
    return unit_interval * sst_span + global_min_sst

In [None]:
# Testing Function
def test_model(model, loader, device):
    model.eval()
    all_predictions, all_targets = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            all_predictions.append(outputs.cpu().numpy())
            all_targets.append(targets.cpu().numpy())
    return np.concatenate(all_predictions), np.concatenate(all_targets)

# Test Data: the years after the validation period, windowed like the
# training data and evaluated in chronological order (no shuffling).
test_years = list(range(2015, 2025))
test_data = np.concatenate([processed_data[str(year)] for year in test_years], axis=0)
test_dataset = SSTDataset(test_data, seq_length, n_days)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Load Best Model
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load("convlstm_30D_nino.pth", map_location=device))
test_predictions, test_targets = test_model(model, test_loader, device)

# Rescale and Evaluate: metrics are computed after undoing the [-1, 1]
# scaling, i.e. in the units of the original SST values.
test_predictions_rescaled = inverse_scale(test_predictions, global_min_sst, global_max_sst)
test_targets_rescaled = inverse_scale(test_targets, global_min_sst, global_max_sst)
mae, rmse, r2 = compute_metrics(test_predictions_rescaled, test_targets_rescaled)

print(f"Test MAE: {mae:.4f} °C")
print(f"Test RMSE: {rmse:.4f} °C")
print(f"Test R²: {r2:.4f}")


In [None]:


# Function to plot heatmap for a specific day
def plot_heatmap(data, title, day_index, global_min_sst, global_max_sst, day_offset=30):
    """
    Plot a heatmap of a single SST map.

    Args:
        data: One 2-D SST map (shape: [height, width]) to draw.
        title: Title for the plot.
        day_index: Day number used in the title; ``day_offset`` is added to it.
        global_min_sst: Lower limit of the colour scale.
        global_max_sst: Upper limit of the colour scale.
        day_offset: Value added to ``day_index`` in the title. Defaults to 30,
            the offset that was previously hard-coded.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    # Fixed vmin/vmax keep the colours comparable between separate plots.
    # NOTE(review): imshow draws row 0 at the top by default; if the latitude
    # axis of `data` is ascending the map appears flipped — confirm.
    image = ax.imshow(data, cmap='coolwarm', vmin=global_min_sst, vmax=global_max_sst)
    fig.colorbar(image, ax=ax, label="Sea Surface Temperature (°C)")
    ax.set_title(f"{title} (Day {day_index + day_offset})", fontsize=14)
    # The ticks are array indices of the map, not coordinate values.
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    plt.show()


day_to_plot = 31  # Adjust this value as needed

# Map the requested day to a sample index in the test arrays.
# The test loader is not shuffled, so sample k uses test days
# k .. k + seq_length - 1 as input and the following n_days days as target.
# seq_length is taken from the hyperparameter cell and deliberately not
# redefined here, so it always matches the value used to build the test set.
aligned_day_index = day_to_plot - seq_length

# Ensure the day is valid
if aligned_day_index < 0 or aligned_day_index >= test_targets.shape[0]:
    print(f"Invalid day_to_plot: {day_to_plot}. It must be between {seq_length} and {test_targets.shape[0] + seq_length - 1}.")
else:
    # Index -1 on the second axis selects the last day of the target window,
    # for both the ground truth and the model output.
    ground_truth_heatmap = test_targets_rescaled[aligned_day_index, -1]
    prediction_heatmap = test_predictions_rescaled[aligned_day_index, -1]

    # Plot the ground truth and prediction heatmaps on the same colour scale
    plot_heatmap(ground_truth_heatmap, "Ground Truth SST", day_to_plot, global_min_sst, global_max_sst)
    plot_heatmap(prediction_heatmap, "Prediction SST (ConvLSTM)", day_to_plot, global_min_sst, global_max_sst)