# Libraries

In [1]:
import xarray as xr
from datetime import datetime

import torch

from aurora import AuroraSmall, Batch, Metadata, rollout
import matplotlib.pyplot as plt

from pathlib import Path

import cdsapi
import numpy as np
from sklearn.metrics import root_mean_squared_error
import gcsfs

from torch.utils.data import Dataset
from aurora import Batch, Metadata
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Check whether PyTorch can see a CUDA device.
torch.cuda.is_available()

  return torch._C._cuda_getDeviceCount() > 0


False

In [1]:
import torch
# Re-check CUDA visibility after (re-)importing torch.
print(torch.cuda.is_available())


False


  return torch._C._cuda_getDeviceCount() > 0


In [2]:
import torch
# Report the installed PyTorch build and whether a CUDA device is usable right now
# (a CUDA-enabled build can still report False when no GPU is visible).
print(torch.__version__)
print(torch.cuda.is_available())


2.6.0+cu124
False


# Clear memory

In [3]:
# torch.cuda.empty_cache()

# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:1024"

# torch.cuda.memory_allocated()

In [4]:

# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [5]:
# torch.backends.cuda.enable_mem_efficient_sdp(False)
# torch.backends.cuda.enable_flash_sdp(False)
# torch.backends.cuda.enable_math_sdp(True)

# Load the model

In [6]:
model = AuroraSmall()

# Load the checkpoint onto the CPU. map_location="cpu" remaps any CUDA tensors stored
# in the checkpoint, so loading also works when no CUDA device is available (this
# notebook's CUDA checks report False). The model is moved to its inference device later.
model.load_state_dict(torch.load('../model/aurora.pth', map_location="cpu"))

<All keys matched successfully>

In [7]:
# # model.half()
# tensor = torch.tensor.half()

# Data

## World

In [8]:
# Public WeatherBench2 ERA5 store, read anonymously from Google Cloud Storage
# (per its name: 6-hourly data on a 1440x721 grid).
ERA5_ZARR_URL = (
    'gs://weatherbench2/datasets/era5/'
    '1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr'
)

fs = gcsfs.GCSFileSystem(token="anon")
store = fs.get_mapper(ERA5_ZARR_URL)

# Consolidated metadata; chunks=None opens the arrays without dask (data is read on access).
full_era5 = xr.open_zarr(store=store, consolidated=True, chunks=None)

### Subset data from 2022

In [9]:
start_time = '2022-01-01'
end_time = '2023-01-31'
# The slice step is positional: xarray keeps every `data_inner_steps`-th time entry.
# NOTE(review): the store is 6-hourly (its name contains "-6h-"), so a step of 6 puts
# consecutive samples 36 h apart, not 6 h. ERA5ZarrDataset feeds consecutive entries to
# Aurora as its two history frames, and the plots label the forecasts as 6 h / 12 h
# ahead -- confirm this step is intended (a step of 1 keeps every 6-hourly entry).
data_inner_steps = 6  # keep every 6th time entry (see NOTE above)

sliced_era5_world = (
    full_era5
    # Disabled steps (variable subset, forcing time shift, eager .compute()) kept for reference:
    # [['geopotential', 'specific_humidity', 'temperature', 'u_component_of_wind', 'v_component_of_wind']]
    # .pipe(
    #     xarray_utils.selective_temporal_shift,
    #     variables=model.forcing_variables,
    #     time_shift='24 hours',
    # )
    .sel(time=slice(start_time, end_time, data_inner_steps))
    # .compute()
)

### Surface variables

In [10]:
# Surface variables read by ERA5ZarrDataset (mapped there to Aurora's "2t", "10u", "10v", "msl").
surface_vars = ['2m_temperature', '10m_u_component_of_wind', '10m_v_component_of_wind', 'mean_sea_level_pressure']

# Keep only these variables from the time-sliced world dataset.
surf_vars_ds = sliced_era5_world[surface_vars]


### Atmospheric variables

In [11]:
# Pressure-level variables read by ERA5ZarrDataset (mapped there to Aurora's "t", "u", "v", "q", "z").
# NOTE: the misspelled name "atmostpheric_variables" is kept because later cells use it.
atmostpheric_variables = ["temperature", "u_component_of_wind", "v_component_of_wind", "specific_humidity", "geopotential"]
atmos_vars_ds = sliced_era5_world[atmostpheric_variables]

## Static variables

In [12]:
# Static fields, mapped in ERA5ZarrDataset to Aurora's "lsm", "slt" and "z" static variables.
# NOTE(review): taken from the time-sliced dataset; assumes these variables carry no
# time dimension in the store -- confirm.
static_variables = ["land_sea_mask", "soil_type", "geopotential_at_surface"]
static_vars_ds = sliced_era5_world[static_variables]

In [13]:
# surf_vars_ds

In [14]:
# static_vars_ds["geopotential_at_surface"].values.shape

## Create batches

In [15]:

class ERA5ZarrDataset(Dataset):
    """Map-style dataset that turns ERA5 xarray datasets into Aurora ``Batch`` objects.

    Item ``idx`` uses time index ``i = self.time_indices[idx]``. The time entries
    ``i - 1`` and ``i`` become the two history frames of the batch, and the metadata
    time is the timestamp of entry ``i``. Every batch has a leading batch axis of size 1.

    Args:
        surf_vars_ds: dataset with ``2m_temperature``, ``10m_u_component_of_wind``,
            ``10m_v_component_of_wind`` and ``mean_sea_level_pressure`` over ``time``,
            plus ``latitude`` / ``longitude`` coordinates.
        atmos_vars_ds: dataset with ``temperature``, ``u_component_of_wind``,
            ``v_component_of_wind``, ``specific_humidity`` and ``geopotential`` over
            ``time``, plus a ``level`` coordinate.
        static_vars_ds: dataset with ``geopotential_at_surface``, ``soil_type`` and
            ``land_sea_mask``.
        sequence_length: first time index to use. It is clamped to at least 1
            because each sample also needs the previous entry ``i - 1``.

    Indexing past the end raises ``IndexError`` (from ``range``). That is what ends a
    plain ``for batch in dataset`` loop.
    """

    def __init__(self, surf_vars_ds, atmos_vars_ds, static_vars_ds, sequence_length):
        self.surf_vars_ds = surf_vars_ds
        self.atmos_vars_ds = atmos_vars_ds
        self.static_vars_ds = static_vars_ds
        self.sequence_length = sequence_length
        # Start at >= 1 so that i - 1 is a real previous step. With a start of 0, the
        # first sample used index -1, which wraps around to the LAST time entry.
        self.time_indices = range(max(sequence_length, 1), len(surf_vars_ds.time))

    def __len__(self):
        return len(self.time_indices)

    @staticmethod
    def _history(ds, name, i):
        """Return entries ``i - 1`` and ``i`` of ``ds[name]`` as a tensor with a leading batch axis."""
        # isel reads only the two requested time steps. Calling .values before indexing
        # would load the variable's entire time range on every call.
        return torch.from_numpy(ds[name].isel(time=[i - 1, i]).values[None])

    def __getitem__(self, idx):
        i = self.time_indices[idx]
        surf, atmos, static = self.surf_vars_ds, self.atmos_vars_ds, self.static_vars_ds

        surf_vars = {
            "2t": self._history(surf, "2m_temperature", i),
            "10u": self._history(surf, "10m_u_component_of_wind", i),
            "10v": self._history(surf, "10m_v_component_of_wind", i),
            "msl": self._history(surf, "mean_sea_level_pressure", i),
        }

        static_vars = {
            "z": torch.from_numpy(static["geopotential_at_surface"].values),
            "slt": torch.from_numpy(static["soil_type"].values),
            "lsm": torch.from_numpy(static["land_sea_mask"].values),
        }

        atmos_vars = {
            "t": self._history(atmos, "temperature", i),
            "u": self._history(atmos, "u_component_of_wind", i),
            "v": self._history(atmos, "v_component_of_wind", i),
            "q": self._history(atmos, "specific_humidity", i),
            "z": self._history(atmos, "geopotential", i),
        }

        metadata = Metadata(
            lat=torch.from_numpy(surf.latitude.values),
            lon=torch.from_numpy(surf.longitude.values),
            # datetime64[s] makes .item() return a datetime.datetime. The tuple holds one
            # value per batch element (the batch size is 1 here).
            time=(surf.time.values[i].astype("datetime64[s]").item(),),
            atmos_levels=tuple(int(level) for level in atmos.level.values),
        )

        return Batch(surf_vars=surf_vars, static_vars=static_vars, atmos_vars=atmos_vars, metadata=metadata)


In [16]:
#   metadata=Metadata(
#         lat=torch.from_numpy(surf_vars_ds.latitude.values),
#         lon=torch.from_numpy(surf_vars_ds.longitude.values),
#         # Converting to `datetime64[s]` ensures that the output of `tolist()` gives
#         # `datetime.datetime`s. Note that this needs to be a tuple of length one:
#         # one value for every batch element.
#         time=(surf_vars_ds.valid_time.values.astype("datetime64[s]").tolist()[i],),
#         atmos_levels=tuple(int(level) for level in atmos_vars_ds.pressure_level.values),
#     ),
# )

In [17]:

# class ERA5ZarrDataset(Dataset):
#     def __init__(self, surf_vars_ds, atmos_vars_ds, static_vars_ds, sequence_length):
#         self.surf_vars_ds = surf_vars_ds
#         self.atmos_vars_ds = atmos_vars_ds
#         self.static_vars_ds = static_vars_ds
#         self.sequence_length = sequence_length
#         self.time_indices = range(sequence_length, len(surf_vars_ds.time))

#     def __len__(self):
#         return len(self.time_indices)

#     def __getitem__(self, idx):
#         i = self.time_indices[idx]
#         time_slice = slice(i - self.sequence_length, i)

#         surf_vars = {
#             "2t": torch.from_numpy(self.surf_vars_ds["2m_temperature"].isel(time=time_slice).values[None]),
#             "10u": torch.from_numpy(self.surf_vars_ds["10m_u_component_of_wind"].isel(time=time_slice).values[None]),
#             "10v": torch.from_numpy(self.surf_vars_ds["10m_v_component_of_wind"].isel(time=time_slice).values[None]),
#             "msl": torch.from_numpy(self.surf_vars_ds["mean_sea_level_pressure"].isel(time=time_slice).values[None]),
#         }

#         static_vars = {
#             "z": torch.from_numpy(self.static_vars_ds["geopotential_at_surface"].values),
#             "slt": torch.from_numpy(self.static_vars_ds["soil_type"].values),
#             "lsm": torch.from_numpy(self.static_vars_ds["land_sea_mask"].values),
#         }

#         atmos_vars = {
#             "t": torch.from_numpy(self.atmos_vars_ds["temperature"].isel(time=time_slice).values[None]),
#             "u": torch.from_numpy(self.atmos_vars_ds["u_component_of_wind"].isel(time=time_slice).values[None]),
#             "v": torch.from_numpy(self.atmos_vars_ds["v_component_of_wind"].isel(time=time_slice).values[None]),
#             "q": torch.from_numpy(self.atmos_vars_ds["specific_humidity"].isel(time=time_slice).values[None]),
#             "z": torch.from_numpy(self.atmos_vars_ds["geopotential"].isel(time=time_slice).values[None]),
#         }

#         metadata = Metadata(
#             lat=torch.from_numpy(self.surf_vars_ds.latitude.values),
#             lon=torch.from_numpy(self.surf_vars_ds.longitude.values),
#             time=tuple(self.surf_vars_ds.time.isel(time=time_slice).values.astype("datetime64[s]").tolist(),),
#             atmos_levels=tuple(int(level) for level in self.atmos_vars_ds.level.values),
#         )

#         return Batch(surf_vars=surf_vars, static_vars=static_vars, atmos_vars=atmos_vars, metadata=metadata)


In [18]:
# One Aurora batch per time index of the world subset; sequence_length=0 asks for the
# earliest possible start index.
world_batches = ERA5ZarrDataset(
    surf_vars_ds,
    atmos_vars_ds,
    static_vars_ds,
    sequence_length=0,
)

In [19]:
# world_batches

### South Africa Data

In [20]:
# sa_area = [
#     -22.12,  # North latitude
#     15.125,   # West longitude
#     -37.875,  # South latitude
#     34.875    # East longitude
# ]

# lat_min, lon_min, lat_max, lon_max = sa_area

In [21]:
# South Africa subset: the same timestamps as the world subset, cropped to a 64 x 80 box.
# Deriving it from `sliced_era5_world` (rather than re-slicing `full_era5` with a second
# copy of the time settings) keeps both subsets on identical timestamps. The RMSE plots
# pair world and South Africa samples by position, so the timestamps must match.

# Inclusive bounds on the 0.25 degree grid:
# (-22.00 - -37.75) / 0.25 + 1 = 64 latitude rows, (35.00 - 15.25) / 0.25 + 1 = 80 longitude columns.
lat_max = -22.00
lat_min = -37.75

lon_min = 15.25
lon_max = 35.00

sliced_era5_SA = sliced_era5_world.sel(
    # Sliced from max to min latitude: assumes the store's latitude axis is descending
    # (north to south). The assertions below catch an empty or mis-sized result.
    latitude=slice(lat_max, lat_min),
    longitude=slice(lon_min, lon_max),
)

assert sliced_era5_SA.sizes["latitude"] == 64, dict(sliced_era5_SA.sizes)
assert sliced_era5_SA.sizes["longitude"] == 80, dict(sliced_era5_SA.sizes)


In [22]:
# Same variable groups as for the world data, selected from the South Africa subset.
surf_vars_ds_SA = sliced_era5_SA[surface_vars]
atmos_vars_ds_SA = sliced_era5_SA[atmostpheric_variables]
static_vars_ds_SA = sliced_era5_SA[static_variables]

In [23]:
# One Aurora batch per time index of the South Africa subset; sequence_length=0 asks
# for the earliest possible start index.
SA_batches = ERA5ZarrDataset(
    surf_vars_ds_SA,
    atmos_vars_ds_SA,
    static_vars_ds_SA,
    sequence_length=0,
)

## Predictions Function

In [24]:
def predict_fn(model, batch, steps=2, device=None):
    """Roll ``model`` forward ``steps`` times from ``batch`` and return the predictions on the CPU.

    Args:
        model: Aurora model. It is switched to eval mode and moved to ``device`` for
            the rollout, then moved back to the CPU (also if the rollout raises).
        batch: Aurora ``Batch`` with the initial state. A copy is moved to ``device``;
            the caller's batch is not modified.
        steps: number of autoregressive rollout steps (default 2).
        device: device to run on. Defaults to ``"cuda"`` when a CUDA device is
            available and ``"cpu"`` otherwise.

    Returns:
        list of predicted ``Batch`` objects, one per rollout step, on the CPU.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model.eval()
    model.to(device)  # nn.Module.to moves the parameters in place
    try:
        with torch.inference_mode():
            preds = [pred.to("cpu") for pred in rollout(model, batch.to(device), steps=steps)]
    finally:
        # Free device memory between calls, even when the rollout fails.
        model.to("cpu")
    return preds

# RMSEs World dataset

In [25]:
def rmse_fn(model, actual_batch, var_name, var_type="surface", atmos_level_idx=0):
    """RMSE of each rollout step of ``model`` started from ``actual_batch``, for one variable.

    Predictions come from ``predict_fn``. The reference field is cropped to the
    prediction's number of latitude rows. For the global grid the prediction is expected
    to have one row fewer than the input, which the previous ``[:-1, :]`` crop assumed.
    When the shapes already agree, the crop is a no-op.

    NOTE(review): the reference fields are taken from ``actual_batch`` itself, i.e. from
    the input history frames, not from the true state at the forecast times. Surface
    variables compare every step with history frame 0. Atmospheric variables compare
    step ``k`` with history frame ``k``. Confirm this is the intended comparison.

    Args:
        model: Aurora model passed to ``predict_fn``.
        actual_batch: Aurora ``Batch`` used as both the initial state and the reference.
        var_name: key in ``surf_vars`` / ``atmos_vars`` (e.g. ``"2t"`` or ``"t"``).
        var_type: ``"surface"`` or ``"atmosphere"``.
        atmos_level_idx: pressure-level index used when ``var_type == "atmosphere"``.

    Returns:
        ``(rmses, dates)``: one RMSE per rollout step, and the matching
        ``pred.metadata.time[0]`` values.

    Raises:
        ValueError: if ``var_type`` is neither ``"surface"`` nor ``"atmosphere"``.
    """
    if var_type not in ("surface", "atmosphere"):
        raise ValueError(f"var_type must be 'surface' or 'atmosphere', got {var_type!r}")

    predictions = predict_fn(model, batch=actual_batch)
    two_steps_rmse = []
    pred_dates = []
    for i, pred in enumerate(predictions):
        if var_type == "surface":
            prediction = pred.surf_vars[var_name][0, 0].numpy()
            actual = actual_batch.surf_vars[var_name][0, 0].numpy()
        else:  # Atmospheric variable at the requested pressure level
            prediction = pred.atmos_vars[var_name].squeeze()[atmos_level_idx, :, :].numpy().squeeze()
            actual = actual_batch.atmos_vars[var_name].squeeze()[i, atmos_level_idx, :, :].numpy()
        # Match the prediction's latitude rows before comparing.
        actual = actual[: prediction.shape[0], :]
        two_steps_rmse.append(root_mean_squared_error(actual.flatten(), prediction.flatten()))
        pred_dates.append(pred.metadata.time[0])
    return two_steps_rmse, pred_dates

# RMSEs South Africa dataset

In [26]:
def rmse_fn_sa(model, actual_batch, var_name, var_type="surface", atmos_level_idx=0):
    """RMSE of each rollout step for one variable on the South Africa subset.

    The reference field is compared at full size; no latitude row is cropped.

    NOTE(review): the reference fields come from ``actual_batch`` itself, i.e. the input
    history frames, not the observed state at the forecast times. Surface variables use
    history frame 0 for every step. Atmospheric variables use frame ``step``.

    Args:
        model: Aurora model passed to ``predict_fn``.
        actual_batch: Aurora ``Batch`` used as both the initial state and the reference.
        var_name: key in ``surf_vars`` / ``atmos_vars``.
        var_type: ``"surface"`` or ``"atmosphere"``. Any other value yields empty lists.
        atmos_level_idx: pressure-level index used for atmospheric variables.

    Returns:
        ``(rmses, dates)``: one RMSE and one ``pred.metadata.time[0]`` per rollout step.
    """
    rmses, dates = [], []
    for step, pred in enumerate(predict_fn(model, batch=actual_batch)):
        if var_type == "surface":
            predicted = pred.surf_vars[var_name][0, 0].numpy()
            observed = actual_batch.surf_vars[var_name][0, 0].numpy()
        elif var_type == "atmosphere":
            predicted = pred.atmos_vars[var_name].squeeze()[atmos_level_idx, :, :].numpy().squeeze()
            observed = actual_batch.atmos_vars[var_name].squeeze()[step, atmos_level_idx, :, :].numpy()
        else:
            continue
        rmses.append(root_mean_squared_error(observed.flatten(), predicted.flatten()))
        dates.append(pred.metadata.time[0])
    return rmses, dates

# Plot RMSEs

In [27]:
# def plot_rmses(rmses_world, rmses_sa, 
#             figsize=(12, 8), fontsize=18,
#             alpha=0.2, date_ranges=None, 
#             title="Two steps forward prediction: RMSES"):

#       fig, ax = plt.subplots(figsize=figsize, dpi=1000)

#       ax.plot(np.array(rmses_world)[:,0], label="World first time step prediction")
#       ax.plot(np.array(rmses_sa)[:,0], label="South Africa first time step prediction")
#       ax.plot(np.array(rmses_world)[:,1], label="World second time step prediction")
#       ax.plot(np.array(rmses_sa)[:,1], label="Soth Africa second time step prediction")
      
#       date_times = [dt for pair in date_ranges for dt in pair]

#       formatted_dates = [
#       f"{dt.strftime('%Y-%m-%d')} ({dt.strftime('%H:%M')})"
#       for dt in date_times
#       ]
#       ax.set_xticks(formatted_dates, rotation=30, ha='right')

#       ax.legend(
#             title="RMSES\n", title_fontsize=fontsize, markerscale=20,
#             bbox_to_anchor=(1.06, 0.8), loc="center left" , frameon=False
#                               )


#       ax.set_xlabel("Dates")
#       ax.set_ylabel("RMSEs")
#       ax.set_title(title, fontsize=fontsize, pad=20)
#       print('\n\n') ; plt.show()



In [27]:

def plot_rmses(variable, rmses_world, rmses_sa,
               figsize=(12, 8), fontsize=18,
               alpha=0.2, date_ranges=None,
               title="Two steps forward prediction: RMSEs",
               save_path="../report/rmses_world_SA", dpi=1000):
    """Plot per-sample world vs. South Africa RMSEs for both rollout steps and save the figure.

    NOTE: a later cell in this notebook redefines ``plot_rmses``. When the notebook runs
    top to bottom, the plotting cells call that later definition.

    Args:
        variable: short variable name, used in the output file names.
        rmses_world, rmses_sa: sequences of ``[step1_rmse, step2_rmse]`` pairs, one per
            sample. Both must have the same length because they share the x axis.
        figsize, fontsize: figure size and title/legend-title font size.
        alpha: accepted for call compatibility; not used.
        date_ranges: optional sequence of ``(date1, date2)`` datetime pairs, one per
            sample. ``date1`` labels the x ticks. Without it the x axis shows sample indices.
        title: figure title.
        save_path: directory for ``rmse-<variable>.{pdf,png,svg}``; created if missing.
        dpi: figure resolution.
    """
    rmses_world = np.asarray(rmses_world)
    rmses_sa = np.asarray(rmses_sa)
    x_indices = np.arange(len(rmses_world))

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    ax.plot(x_indices, rmses_world[:, 0], label="World 6 hours forward prediction")
    ax.plot(x_indices, rmses_sa[:, 0], label="South Africa 6 hours forward prediction")
    ax.plot(x_indices, rmses_world[:, 1], label="World 12 hours forward prediction")
    ax.plot(x_indices, rmses_sa[:, 1], label="South Africa 12 hours forward prediction")

    if date_ranges is not None:
        # Label at most 6 evenly spaced ticks with the first date of each pair.
        tick_labels = [first.strftime('%Y-%m-%d (%H:%M)') for first, _ in date_ranges]
        tick_positions = np.linspace(0, len(tick_labels) - 1, min(6, len(tick_labels)), dtype=int)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels([tick_labels[i] for i in tick_positions], rotation=30, ha='right')

    ax.legend(
        title="RMSEs\n", title_fontsize=fontsize, markerscale=1.5,
        bbox_to_anchor=(0.98, 0.98), loc="upper left", frameon=False
    )

    ax.set_xlabel("Dates")
    ax.set_ylabel("RMSEs")
    ax.set_title(title, fontsize=fontsize, pad=20)

    # The legend is anchored outside the axes, so save with a tight bounding box to keep
    # it in the output files.
    out_dir = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    for ext in ("pdf", "png", "svg"):
        fig.savefig(out_dir / f"rmse-{variable}.{ext}", bbox_inches="tight")

    plt.show()


In [28]:

def plot_rmses(variable, rmses_world, rmses_sa,
               figsize=(12, 8), fontsize=18,
               date_ranges=None,
               title="Two Steps Forward Prediction: RMSEs",
               save_path="../report/rmses_world_SA",
               alpha=None, dpi=300):
    """Plot world vs. South Africa RMSEs for both rollout steps and save the figure.

    Args:
        variable: short variable name used in the output file names.
        rmses_world, rmses_sa: sequences of ``[step1_rmse, step2_rmse]`` pairs (labelled
            6h and 12h), one per sample. Both must have the same length because they
            share the x axis.
        figsize, fontsize: figure size and base font size (labels use smaller sizes).
        date_ranges: optional sequence of ``(date1, date2)`` datetime pairs, one per
            sample. ``date2`` labels the x ticks. Without it the x axis shows sample indices.
        title: figure title.
        save_path: directory for ``rmse-<variable>.{pdf,png,svg}``; created if missing.
        alpha: accepted so existing calls passing ``alpha=...`` keep working; not used.
        dpi: resolution of the figure and of the saved PNG.
    """
    rmses_world = np.asarray(rmses_world)
    rmses_sa = np.asarray(rmses_sa)
    x_indices = np.arange(len(rmses_world))

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # One colour per region; solid = first step (6h), dashed = second step (12h).
    ax.plot(x_indices, rmses_world[:, 0], label="Global RMSE (6h Forecast)", color="blue", linestyle="-", linewidth=2)
    ax.plot(x_indices, rmses_sa[:, 0], label="South Africa RMSE (6h Forecast)", color="orange", linestyle="-", linewidth=2)
    ax.plot(x_indices, rmses_world[:, 1], label="Global RMSE (12h Forecast)", color="blue", linestyle="--", linewidth=2)
    ax.plot(x_indices, rmses_sa[:, 1], label="South Africa RMSE (12h Forecast)", color="orange", linestyle="--", linewidth=2)

    if date_ranges is not None:
        # Label at most 6 evenly spaced ticks with the second date of each pair.
        tick_labels = [second.strftime('%Y-%m-%d (%H:%M)') for _, second in date_ranges]
        tick_positions = np.linspace(0, len(tick_labels) - 1, min(6, len(tick_labels)), dtype=int)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels([tick_labels[i] for i in tick_positions], rotation=30, ha='right')

    # Legend to the right of the axes; bbox_inches="tight" below keeps it in the saved files.
    ax.legend(title="Forecast Horizon", title_fontsize=fontsize-2, fontsize=fontsize-4,
              bbox_to_anchor=(1.05, 1), loc="upper left", frameon=False)

    ax.set_xlabel("Forecast Date", fontsize=fontsize-2)
    ax.set_ylabel("Root Mean Squared Error (RMSE)", fontsize=fontsize-2)
    ax.set_title(title, fontsize=fontsize, pad=20)

    out_dir = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f"rmse-{variable}.pdf", bbox_inches="tight")
    fig.savefig(out_dir / f"rmse-{variable}.png", bbox_inches="tight", dpi=dpi)
    fig.savefig(out_dir / f"rmse-{variable}.svg", bbox_inches="tight")

    plt.show()


In [30]:
# data = iter(world_batches)

In [31]:
# data0 = next(data)

In [32]:
# data0.atmos_vars["t"].squeeze()[0,0,:,:].numpy()[:-1,:].shape

In [33]:
# data0.static_vars["lsm"].shape

In [34]:
# data0.surf_vars["10v"].shape

In [35]:
# data0.surf_vars["2t"].shape

In [36]:
# preds = predict_fn(model, data0)

In [37]:
# preds[0].surf_vars["2t"].shape

In [38]:
# preds.surf_vars["10v"][1,:,:,:].shape

In [39]:
# data0.metadata.time

# Surface Variables

## Two-meter temperature in K: 2t

In [40]:
# data0.surf_vars["2t"][1].shape

In [29]:
# 2 m temperature ("2t") RMSEs on the world data: one list of per-step RMSEs per batch,
# with the matching forecast dates.
rmses_world_2t = []
dates_world_2t = []
for world_batch in world_batches:
    step_rmses, step_dates = rmse_fn(model, world_batch, "2t", var_type="surface")
    rmses_world_2t.append(step_rmses)
    dates_world_2t.append(step_dates)

: 

In [33]:
# plt.plot(np.array(rmses)[:,0], label="1")
# plt.plot(np.array(rmses)[:,1], label="2")
# plt.legend()

In [34]:
# rmse_fn(model, data0, "2t")

In [35]:
# i=0
# for batch in world_batches:
#    i+=1
# i 

In [23]:
# data = iter(sa_batches)
# data0 = next(data)

In [24]:
# data0.surf_vars["2t"].shape

In [26]:
# 2 m temperature ("2t") RMSEs on the South Africa subset, one list of per-step RMSEs per batch.
rmses_SA_2t = []
dates_SA_2t = []
for sa_batch in SA_batches:
    step_rmses, step_dates = rmse_fn_sa(model, sa_batch, "2t", var_type="surface")
    rmses_SA_2t.append(step_rmses)
    dates_SA_2t.append(step_dates)

In [None]:
# plt.plot(np.array(rmses)[:,0], label="w1")
# plt.plot(np.array(rmses)[:,1], label="w2")
# plt.plot(np.array(rmses_sa)[:,0], label="sa1")
# plt.plot(np.array(rmses_sa)[:,1], label="sa2")
# plt.legend()

In [None]:
# print(sliced_era5.latitude.shape, sliced_era5.longitude.shape)


In [None]:
# `alpha` is not passed: plot_rmses does not use it, and the last definition rejects it.
plot_rmses("2t", rmses_world_2t, rmses_SA_2t,
           figsize=(12, 8), fontsize=18,
           date_ranges=dates_world_2t,
           title="Two-meter temperature two steps forward prediction: RMSEs")

## Ten-meter eastward wind speed in m/s: u10

In [66]:
# 10 m eastward wind ("10u") RMSEs on the world data, one list of per-step RMSEs per batch.
rmses_world_u10 = []
dates_world_u10 = []
for world_batch in world_batches:
    step_rmses, step_dates = rmse_fn(model, world_batch, "10u", var_type="surface")
    rmses_world_u10.append(step_rmses)
    dates_world_u10.append(step_dates)

In [67]:
# 10 m eastward wind ("10u") RMSEs on the South Africa subset, one list of per-step RMSEs per batch.
rmses_SA_u10 = []
dates_SA_u10 = []
for sa_batch in SA_batches:
    step_rmses, step_dates = rmse_fn_sa(model, sa_batch, "10u", var_type="surface")
    rmses_SA_u10.append(step_rmses)
    dates_SA_u10.append(step_dates)

In [None]:
# `alpha` is not passed: plot_rmses does not use it, and the last definition rejects it.
plot_rmses("u10", rmses_world_u10, rmses_SA_u10,
           figsize=(12, 8), fontsize=18,
           date_ranges=dates_world_u10,
           title="Ten-meter eastward wind speed two steps forward prediction: RMSEs")

## Ten-meter northward wind speed in m/s: v10

In [None]:
# 10 m northward wind ("10v") RMSEs on the world data, one list of per-step RMSEs per batch.
rmses_world_v10 = []
dates_world_v10 = []
for world_batch in world_batches:
    step_rmses, step_dates = rmse_fn(model, world_batch, "10v", var_type="surface")
    rmses_world_v10.append(step_rmses)
    dates_world_v10.append(step_dates)

In [70]:
# 10 m northward wind ("10v") RMSEs on the South Africa subset, one list of per-step RMSEs per batch.
rmses_SA_v10 = []
dates_SA_v10 = []
for sa_batch in SA_batches:
    step_rmses, step_dates = rmse_fn_sa(model, sa_batch, "10v", var_type="surface")
    rmses_SA_v10.append(step_rmses)
    dates_SA_v10.append(step_dates)

In [None]:
# `alpha` is not passed: plot_rmses does not use it, and the last definition rejects it.
# The v component of wind points north, so the title says "northward".
plot_rmses("v10", rmses_world_v10, rmses_SA_v10,
           figsize=(12, 8), fontsize=18,
           date_ranges=dates_world_v10,
           title="Ten-meter northward wind speed two steps forward prediction: RMSEs")

## Mean sea-level pressure in Pa: msl

In [None]:
# Mean sea-level pressure ("msl") RMSEs on the world data, one list of per-step RMSEs per batch.
rmses_world_msl = []
dates_world_msl = []
for world_batch in world_batches:
    step_rmses, step_dates = rmse_fn(model, world_batch, "msl", var_type="surface")
    rmses_world_msl.append(step_rmses)
    dates_world_msl.append(step_dates)

In [None]:
# Mean sea-level pressure ("msl") RMSEs on the South Africa subset, one list of per-step RMSEs per batch.
rmses_SA_msl = []
dates_SA_msl = []
for sa_batch in SA_batches:
    step_rmses, step_dates = rmse_fn_sa(model, sa_batch, "msl", var_type="surface")
    rmses_SA_msl.append(step_rmses)
    dates_SA_msl.append(step_dates)

In [None]:
# `alpha` is not passed: plot_rmses does not use it, and the last definition rejects it.
plot_rmses("msl", rmses_world_msl, rmses_SA_msl,
           figsize=(12, 8), fontsize=18,
           date_ranges=dates_world_msl,
           title="Mean sea-level pressure two steps forward prediction: RMSEs")

# Atmosphere

## Temperature in K: t

### 50 hPa

In [None]:
# Temperature ("t") RMSEs at pressure-level index 0 on the world data, one list of
# per-step RMSEs per batch.
rmses_world_t = []
dates_world_t = []
for world_batch in world_batches:
    step_rmses, step_dates = rmse_fn(model, world_batch, "t", var_type="atmosphere", atmos_level_idx=0)
    rmses_world_t.append(step_rmses)
    dates_world_t.append(step_dates)

In [None]:
# Temperature ("t") RMSEs at pressure-level index 0 on the South Africa subset, one list
# of per-step RMSEs per batch.
rmses_SA_t = []
dates_SA_t = []
for sa_batch in SA_batches:
    step_rmses, step_dates = rmse_fn_sa(model, sa_batch, "t", var_type="atmosphere", atmos_level_idx=0)
    rmses_SA_t.append(step_rmses)
    dates_SA_t.append(step_dates)

In [None]:
# `alpha` is not passed: plot_rmses does not use it, and the last definition rejects it.
plot_rmses("t", rmses_world_t, rmses_SA_t,
           figsize=(12, 8), fontsize=18,
           date_ranges=dates_world_t,
           title="Temperature in K two steps forward prediction: RMSEs")

### 100 hPa

### 150 hPa

### 200 hPa

### 250 hPa

### 300 hPa

### 400 hPa

### 500 hPa

### 600 hPa

### 700 hPa

### 850 hPa

### 925 hPa

### 1000 hPa

## Eastward wind speed in m/s: u

### 50 hPa

### 100 hPa

### 150 hPa

### 200 hPa

### 250 hPa

### 300 hPa

### 400 hPa

### 500 hPa

### 600 hPa

### 700 hPa

### 850 hPa

### 925 hPa

### 1000 hPa

## Northward wind speed in m/s: v

### 50 hPa

### 100 hPa

### 150 hPa

### 200 hPa

### 250 hPa

### 300 hPa

### 400 hPa

### 500 hPa

### 600 hPa

### 700 hPa

### 850 hPa

### 925 hPa

### 1000 hPa

### 50 hPa

### 100 hPa

### 150 hPa

### 200 hPa

### 250 hPa

### 300 hPa

### 400 hPa

### 500 hPa

### 600 hPa

### 700 hPa

### 850 hPa

### 925 hPa

### 1000 hPa

## Specific humidity in kg/kg: q

### 50 hPa

### 100 hPa

### 150 hPa

### 200 hPa

### 250 hPa

### 300 hPa

### 400 hPa

### 500 hPa

### 600 hPa

### 700 hPa

### 850 hPa

### 925 hPa

### 1000 hPa

## Geopotential in m²/s²: z

### 50 hPa

### 100 hPa

### 150 hPa

### 200 hPa

### 250 hPa

### 300 hPa

### 400 hPa

### 500 hPa

### 600 hPa

### 700 hPa

### 850 hPa

### 925 hPa

### 1000 hPa