In [None]:
import xarray as xr
import numpy as np
import os
import matplotlib.pyplot as plt
import numpy as np

data_dir = "Dataset/"

# Scaled regional SST for each successfully processed year, keyed by the
# year as a string (e.g. "2015").
processed_data = {}

# Inclusive latitude / longitude bounds of the region of interest.
LAT_BOUNDS = (-2, 2)
LON_BOUNDS = (234, 240)


def _subset_region(dataset):
    """Return the 'sst' variable restricted to the region of interest.

    Points outside LAT_BOUNDS / LON_BOUNDS are dropped. The result is loaded
    into memory so it stays usable after the source file has been closed.
    """
    in_region = (
        (dataset['lat'] >= LAT_BOUNDS[0]) & (dataset['lat'] <= LAT_BOUNDS[1]) &
        (dataset['lon'] >= LON_BOUNDS[0]) & (dataset['lon'] <= LON_BOUNDS[1])
    )
    return dataset['sst'].where(in_region, drop=True).load()


# Calculate global minimum and maximum SST values across all years
global_min_sst = float('inf')
global_max_sst = float('-inf')

# NOTE(review): this pass starts at 1981 while the scaling pass below starts
# at 1982 — confirm that the one-year difference is intentional.
for year in range(1981, 2024 + 1):
    year_str = str(year)
    file_path = os.path.join(data_dir, f'sst.day.mean.{year}.nc')

    try:
        # The context manager closes the file handle even when an error occurs,
        # so the loop does not accumulate one open file per year.
        with xr.open_dataset(file_path) as dataset:
            region_sst = _subset_region(dataset)

        year_min = region_sst.min().values
        year_max = region_sst.max().values
        global_min_sst = min(global_min_sst, year_min)
        global_max_sst = max(global_max_sst, year_max)

        print(f"Checked year {year} for global min and max.")

    except FileNotFoundError:
        print(f"File for {year} not found. Skipping this year.")
    except Exception as e:
        # Best effort: report the problem and continue with the remaining years.
        print(f"An error occurred while processing {year}: {e}")

print(f"Global Min SST: {global_min_sst}, Global Max SST: {global_max_sst}")

# Without a finite, non-degenerate range the min-max scaling below would divide
# by zero or infinity and silently store NaN; fail loudly instead.
if (not np.isfinite(global_min_sst) or not np.isfinite(global_max_sst)
        or global_max_sst == global_min_sst):
    raise RuntimeError(
        f"Cannot scale SST: no usable min/max range was found under {data_dir!r} "
        f"(min={global_min_sst}, max={global_max_sst})."
    )

# Process and scale SST data for each year using global min and max
for year in range(1982, 2024 + 1):
    year_str = str(year)
    file_path = os.path.join(data_dir, f'sst.day.mean.{year}.nc')

    try:
        with xr.open_dataset(file_path) as dataset:
            region_sst = _subset_region(dataset)

        # Min-max scale: global_min_sst maps to -1 and global_max_sst to +1.
        region_sst_scaled = (2 * (region_sst - global_min_sst) / (global_max_sst - global_min_sst)) - 1

        # Store the processed data in the dictionary with year as key
        processed_data[year_str] = region_sst_scaled
        print(f"Processed year {year}")

    except FileNotFoundError:
        print(f"File for {year} not found. Skipping this year.")
    except Exception as e:
        print(f"An error occurred while processing {year}: {e}")


In [None]:
test_years = list(range(2015, 2025))

# The loading cell skips years whose file is missing or unreadable, so check
# up front and report every absent year at once instead of failing on the
# first bare dictionary lookup.
missing_test_years = [year for year in test_years if str(year) not in processed_data]
if missing_test_years:
    raise KeyError(
        f"No processed SST data for test year(s) {missing_test_years}; "
        "check the loading cell output for skipped years."
    )

# Stack the yearly arrays along their first axis into one continuous array.
test_data = np.concatenate([processed_data[str(year)] for year in test_years], axis=0)

In [None]:
def average_model_n_days_iterative(test_data, seq_length, n_days):
    """
    Iterative moving-average baseline forecast.

    For every start index ``t`` with a full history of ``seq_length`` steps and
    a full horizon of ``n_days`` steps, the next step is predicted as the mean
    of the current window. That prediction is then appended to the window (and
    the oldest entry dropped) to predict the following step, and so on.

    Args:
        test_data: Array-like whose first axis is the step (day) axis. Any
            number of trailing dimensions is supported, including none.
        seq_length: Number of past steps in the averaging window (>= 1).
        n_days: Number of future steps predicted per window (>= 1).

    Returns:
        predictions: Array of shape (num_windows, n_days, *test_data.shape[1:]).
        targets: Ground truth for the same steps, same shape as predictions.
        Both are empty arrays when test_data is too short for a single window.

    Raises:
        ValueError: If seq_length or n_days is smaller than 1.
    """
    if seq_length < 1:
        raise ValueError(f"seq_length must be >= 1, got {seq_length}")
    if n_days < 1:
        raise ValueError(f"n_days must be >= 1, got {n_days}")

    test_data = np.asarray(test_data)
    total_days = test_data.shape[0]
    predictions = []
    targets = []

    for t in range(seq_length, total_days - n_days + 1):
        # Initialize the input window using ground truth. No copy is needed:
        # the window is only ever replaced by a new array, never modified.
        input_window = test_data[t - seq_length:t]
        future_preds = []
        targets.append(test_data[t:t + n_days])  # Ground truth for the same range

        for _ in range(n_days):
            # Compute the average over the input window
            avg_prediction = np.mean(input_window, axis=0)
            future_preds.append(avg_prediction)

            # Shift the window: drop the oldest step and append the prediction.
            # expand_dims restores the step axis for any number of trailing dims.
            input_window = np.concatenate(
                [input_window[1:], np.expand_dims(avg_prediction, axis=0)], axis=0
            )

        predictions.append(np.stack(future_preds, axis=0))

    predictions = np.array(predictions)
    targets = np.array(targets)
    return predictions, targets


In [None]:
def compute_metrics(predictions, targets):
    """
    Compute MAE, RMSE, and R² between predictions and targets.

    All three metrics are computed over every element at once (no per-axis
    breakdown). Inputs are converted to floating point first, so plain Python
    lists are accepted and integer inputs cannot wrap around on subtraction.

    Args:
        predictions: Predicted values; array-like that broadcasts against
            `targets` (normally the same shape).
        targets: Ground truth values; array-like.

    Returns:
        mae: Mean Absolute Error.
        rmse: Root Mean Square Error.
        r2: Coefficient of Determination (R²); 0 when the targets have no
            variance (total sum of squares is not positive).
    """
    predictions = np.asarray(predictions)
    targets = np.asarray(targets)
    # Only cast non-float inputs, so float32 data keeps its dtype and memory use.
    if not np.issubdtype(predictions.dtype, np.floating):
        predictions = predictions.astype(float)
    if not np.issubdtype(targets.dtype, np.floating):
        targets = targets.astype(float)

    errors = predictions - targets
    mae = np.mean(np.abs(errors))  # Mean Absolute Error
    mse = np.mean(errors ** 2)  # Mean Squared Error
    rmse = np.sqrt(mse)  # Root Mean Squared Error

    # Compute R²
    ss_total = np.sum((targets - targets.mean()) ** 2)  # Total sum of squares
    ss_residual = np.sum(errors ** 2)  # Residual sum of squares
    r2 = 1 - (ss_residual / ss_total) if ss_total > 0 else 0

    return mae, rmse, r2


def inverse_scale(data, global_min_sst, global_max_sst):
    """
    Map values from the [-1, 1] scaled range back to the original range.

    -1 maps to `global_min_sst`, +1 maps to `global_max_sst`, and values in
    between are mapped linearly.

    Args:
        data: Scaled value(s); a scalar or an array.
        global_min_sst: Value that was scaled to -1.
        global_max_sst: Value that was scaled to +1.

    Returns:
        The value(s) in the original range, same shape as `data`.
    """
    unit_fraction = (data + 1) / 2  # position within the range, 0..1
    value_span = global_max_sst - global_min_sst
    return unit_fraction * value_span + global_min_sst

In [None]:
# Evaluate the moving-average baseline on the test set
seq_length = 30  # number of past days averaged for each prediction
n_days = 30      # number of future days predicted per window

average_predictions, average_targets = average_model_n_days_iterative(
    test_data, seq_length=seq_length, n_days=n_days
)

# Map predictions and ground truth back from the scaled range before scoring,
# so the errors are expressed in the original SST range.
ground_truth_rescaled = inverse_scale(average_targets, global_min_sst, global_max_sst)
average_predictions_rescaled = inverse_scale(average_predictions, global_min_sst, global_max_sst)

mae, rmse, r2 = compute_metrics(
    predictions=average_predictions_rescaled,
    targets=ground_truth_rescaled,
)

print(f"Average Model - MAE: {mae:.4f} °C")
print(f"Average Model - RMSE: {rmse:.4f} °C")
print(f"Average Model - R²: {r2:.4f}")

In [None]:
# Heatmap plotting helper
def plot_heatmap_baseline(data, title, day_index, global_min_sst, global_max_sst):
    """
    Show a 2-D field as a heatmap with a fixed colour range.

    Args:
        data: 2-D array to draw with `imshow`.
        title: Text shown before the day label in the figure title.
        day_index: Day number; the title displays `day_index + 30`.
        global_min_sst: Lower limit of the colour scale.
        global_max_sst: Upper limit of the colour scale.

    Note:
        No extent is passed to `imshow`, so the axis ticks are array indices
        even though the axes are labelled "Longitude" and "Latitude".
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    heatmap = ax.imshow(data, cmap='coolwarm', vmin=global_min_sst, vmax=global_max_sst)
    fig.colorbar(heatmap, ax=ax, label="Sea Surface Temperature (°C)")
    ax.set_title(f"{title} (Day {day_index + 30})", fontsize=14)
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    plt.show()

In [None]:

seq_length = 30
n_days = 30  # Predict 30 days into the future
day_to_plot = 31  # The specific day for heatmap plotting

# The model produces one forecast window per start day t, for
# t = seq_length .. total_days - n_days, and window i corresponds to
# t = seq_length + i. The number of windows is therefore
# total_days - seq_length - n_days + 1.
aligned_day_index = day_to_plot - seq_length
num_windows = max(0, test_data.shape[0] - seq_length - n_days + 1)

# Compare against the real window count. Comparing against
# total_days - n_days + 1 would accept indices up to seq_length too large and
# fail later with an IndexError.
if aligned_day_index < 0 or aligned_day_index >= num_windows:
    print(f"Invalid day_to_plot: {day_to_plot}")
else:
    # Average Model Predictions
    avg_predictions, avg_targets = average_model_n_days_iterative(test_data, seq_length, n_days)
    avg_predictions_rescaled = inverse_scale(avg_predictions, global_min_sst, global_max_sst)
    avg_targets_rescaled = inverse_scale(avg_targets, global_min_sst, global_max_sst)

    # Extract the last forecast step of the selected window, plus its ground truth
    avg_prediction_heatmap = avg_predictions_rescaled[aligned_day_index, -1]  # Last day prediction
    ground_truth_heatmap = avg_targets_rescaled[aligned_day_index, -1]

    # Plot Heatmaps
    plot_heatmap_baseline(ground_truth_heatmap, "Ground Truth SST", day_to_plot, global_min_sst, global_max_sst)
    plot_heatmap_baseline(avg_prediction_heatmap, "Average Model Prediction", day_to_plot, global_min_sst, global_max_sst)
