In [1]:
import torch
import deepsensor
# Imported for its side effect only; no name from it is used directly in this notebook.
# NOTE(review): presumably this selects PyTorch as the DeepSensor backend and has to run
# before any model is built - confirm against the DeepSensor documentation.
import deepsensor.torch
from deepsensor.train import set_gpu_default_device
# Ask DeepSensor to use the GPU as the default device.
# NOTE(review): the recorded output of the model-construction cell further down shows
# "CUDA error: CUDA-capable device(s) is/are busy or unavailable", so this call depends on
# a free GPU being visible to the kernel.
set_gpu_default_device()

In [2]:
from deepsensor.train import Trainer
from deepsensor.model import ConvNP
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from deepsensor.data import DataProcessor, TaskLoader
from tqdm import tqdm
from scipy.special import sph_harm_y

# Location of the CFS forecast Zarr store.
# NOTE(review): absolute, cluster-specific path - consider making it configurable.
data_path = '/nfs/turbo/seas-dannes/urop-2024-bias/cfs-forecasts/cfs_2012-2023.zarr'

# Open the whole Zarr store with xarray, using its consolidated metadata
ds0 = xr.open_zarr(data_path, consolidated=True)

# Display the dataset (this is still the cell's last expression; only comments follow)
ds0
# Disabled experiment: keep only the 00:00 time steps.
# NOTE(review): `ds` is not defined until the next cell, so this would need `ds0`.
#ds = ds.sel(time=ds.time.dt.hour == 0)

# Disabled experiment: restrict the data to 2012-01-01 .. 2015-12-31.
#ds = ds.sel(time=slice("2012-01-01", "2015-12-31"))  # Select only the first three files




In [3]:
# Take the first entry along `lead` (dropping that coordinate) and order the result
# chronologically, as one chained expression.
ds = ds0.isel(lead=0, drop=True).sortby("time")
ds

['2012-01-01T00:00:00.000000000' '2012-01-01T06:00:00.000000000'
 '2012-01-01T12:00:00.000000000' ... '2024-01-01T06:00:00.000000000'
 '2024-01-01T12:00:00.000000000' '2024-01-01T18:00:00.000000000']


['2012-01-01T00:00:00.000000000' '2012-01-01T06:00:00.000000000'
 '2012-01-01T12:00:00.000000000' ... '2015-01-01T06:00:00.000000000'
 '2015-01-01T12:00:00.000000000' '2015-01-01T18:00:00.000000000']


<xarray.Dataset> Size: 3GB
Dimensions:            (time: 4206, latitude: 181, longitude: 360)
Coordinates:
  * latitude           (latitude) float64 1kB -90.0 -89.0 -88.0 ... 89.0 90.0
  * longitude          (longitude) float64 3kB 0.0 1.0 2.0 ... 357.0 358.0 359.0
  * time               (time) datetime64[ns] 34kB 2012-01-01 ... 2015-01-01T1...
Data variables:
    LHTFL_surface      (time, latitude, longitude) float32 1GB dask.array<chunksize=(1, 91, 180), meta=np.ndarray>
    SHTFL_surface      (time, latitude, longitude) float32 1GB dask.array<chunksize=(1, 91, 180), meta=np.ndarray>
    TMP_2maboveground  (time, latitude, longitude) float32 1GB dask.array<chunksize=(1, 91, 180), meta=np.ndarray>


In [7]:
# Synthetic spherical-harmonic bias field (added to the context data further down)

def spherical_harmonic_bias(data, l=2, m=1, amplitude=5):
    """Build a time-invariant bias field shaped like a real spherical harmonic.

    The field is the real part of the degree-``l``, order-``m`` spherical harmonic
    evaluated on a regular latitude/longitude grid, rescaled so that its largest
    absolute value equals ``amplitude``.

    Only ``data.shape`` is read: axis 1 is taken as latitude (south pole to north
    pole, both ends included) and axis 2 as longitude (one full turn, with the
    duplicate 360-degree column excluded because the grid is periodic).

    Parameters
    ----------
    data : array-like with a ``shape`` of at least three entries
        Only used to size the grid as ``(shape[1], shape[2])``.
    l : int
        Degree of the harmonic (``l >= 0``).
    m : int
        Order of the harmonic (``|m| <= l``).
    amplitude : float
        Peak absolute value of the returned field, in the units of whatever the
        field is added to (this notebook adds it to normalised data).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(1, shape[1], shape[2])``; the leading singleton axis lets
        the field broadcast along the first (time) axis of ``data``.

    Raises
    ------
    ValueError
        If ``|m| > l``, or if the harmonic is zero or non-finite everywhere on the
        grid and therefore cannot be rescaled.
    """
    if abs(m) > l:
        raise ValueError(f"Order m={m} must satisfy |m| <= l (l={l}).")

    # Local import so the function does not depend on which SciPy name the notebook's
    # import cell provides. `sph_harm_y` takes (degree, order, polar, azimuth); the
    # older `sph_harm` takes (order, degree, azimuth, polar).
    try:
        from scipy.special import sph_harm_y as _harmonic
    except ImportError:
        from scipy.special import sph_harm as _legacy_sph_harm

        def _harmonic(degree, order, polar, azimuth):
            return _legacy_sph_harm(order, degree, azimuth, polar)

    n_lat, n_lon = data.shape[1], data.shape[2]

    # Periodic longitude axis: 0 and 2*pi are the same meridian, so leave 2*pi out.
    lon = np.linspace(0, 2 * np.pi, n_lon, endpoint=False)
    lat = np.linspace(-np.pi / 2, np.pi / 2, n_lat)
    # SciPy expects the polar angle (colatitude, 0 at the north pole), not latitude.
    colat = np.pi / 2 - lat
    lon_grid, colat_grid = np.meshgrid(lon, colat)

    harmonics = np.real(_harmonic(l, m, colat_grid, lon_grid))

    peak = np.max(np.abs(harmonics))
    if not np.isfinite(peak) or peak == 0:
        raise ValueError(
            f"Spherical harmonic (l={l}, m={m}) cannot be normalised on a "
            f"{n_lat}x{n_lon} grid (peak magnitude is {peak})."
        )
    harmonics = amplitude * harmonics / peak
    return harmonics[None, :, :]



In [8]:
# Print the time steps from 2012-01-01T00 through 2012-01-03T00 as a quick sanity check.
print(
    ds.sel(
        time=slice(
            "2012-01-01T00:00:00.000000000",
            "2012-01-03T00:00:00.000000000",
        )
    )
)

<xarray.Dataset> Size: 7MB
Dimensions:            (time: 9, latitude: 181, longitude: 360)
Coordinates:
  * latitude           (latitude) float64 1kB -90.0 -89.0 -88.0 ... 89.0 90.0
  * longitude          (longitude) float64 3kB 0.0 1.0 2.0 ... 357.0 358.0 359.0
  * time               (time) datetime64[ns] 72B 2012-01-01 ... 2012-01-03
Data variables:
    LHTFL_surface      (time, latitude, longitude) float32 2MB dask.array<chunksize=(1, 91, 180), meta=np.ndarray>
    SHTFL_surface      (time, latitude, longitude) float32 2MB dask.array<chunksize=(1, 91, 180), meta=np.ndarray>
    TMP_2maboveground  (time, latitude, longitude) float32 2MB dask.array<chunksize=(1, 91, 180), meta=np.ndarray>


In [9]:
# Tell the DataProcessor which dataset dimensions play the roles of x1 and x2.
data_processor = DataProcessor(x1_name="latitude", x2_name="longitude")
# First call on the January-2012 subset only; its result is discarded.
# NOTE(review): presumably this call is what fits the normalisation parameters, so the
# mean/std printed below would come from this one month rather than from the whole
# record - confirm this is intended.
_ = data_processor(ds.sel(time=slice("2012-01-01T00:00:00.000000000", "2012-01-31T00:00:00.000000000")))
# Second call processes the full dataset.
ds_processed = data_processor(ds)
# Show the processor's normalisation parameters and coordinate mappings.
print(data_processor)

DataProcessor with normalisation params:
{'LHTFL_surface': {'method': 'mean_std',
                   'params': {'mean': 66.60301208496094,
                              'std': 65.60218048095703}},
 'SHTFL_surface': {'method': 'mean_std',
                   'params': {'mean': 6.118175506591797,
                              'std': 31.839258193969727}},
 'TMP_2maboveground': {'method': 'mean_std',
                       'params': {'mean': 276.7368469238281,
                                  'std': 20.317398071289062}},
 'coords': {'time': {'name': 'time'},
            'x1': {'map': (-90.0, 269.0), 'name': 'latitude'},
            'x2': {'map': (0.0, 359.0), 'name': 'longitude'}}}


In [10]:
# Inspect the processed dataset (the recorded output shows the dimensions renamed to x1/x2).
print(ds_processed)

<xarray.Dataset> Size: 3GB
Dimensions:            (time: 4206, x1: 181, x2: 360)
Coordinates:
  * time               (time) datetime64[ns] 34kB 2012-01-01 ... 2015-01-01T1...
  * x1                 (x1) float64 1kB 0.0 0.002786 0.005571 ... 0.4986 0.5014
  * x2                 (x2) float64 3kB 0.0 0.002786 0.005571 ... 0.9972 1.0
Data variables:
    LHTFL_surface      (time, x1, x2) float32 1GB dask.array<chunksize=(1, 91, 180), meta=np.ndarray>
    SHTFL_surface      (time, x1, x2) float32 1GB dask.array<chunksize=(1, 91, 180), meta=np.ndarray>
    TMP_2maboveground  (time, x1, x2) float32 1GB dask.array<chunksize=(1, 91, 180), meta=np.ndarray>


In [11]:
# (start, end) date strings for the training and validation task dates.
# NOTE(review): one recorded dataset repr above ends at 2015-01-01T18; if that is the data
# actually loaded, almost all of the validation range is missing - confirm the coverage.
train_range = ("2012-01-01", "2014-12-31")
val_range = ("2015-01-01", "2015-12-31")

In [12]:
def add_bias_function(data, bias_function, **kwargs):
    """Return ``data`` offset by the field that ``bias_function`` produces for it.

    ``bias_function`` is called as ``bias_function(data, **kwargs)`` and its result
    is added to ``data`` with ``+``.
    """
    offset = bias_function(data, **kwargs)
    return data + offset

# Context sets: the processed temperature field with a synthetic bias added.
# Only the spherical-harmonic bias is active; the commented-out entries refer to other
# bias functions that are not defined in the visible part of this notebook.
biased_contexts = [
    add_bias_function(ds_processed["TMP_2maboveground"], spherical_harmonic_bias),
    #add_bias_function(ds_processed["TMP_2maboveground"], random_noise_bias),
    #add_bias_function(ds_processed["TMP_2maboveground"], linear_trend_bias),
    #add_bias_function(ds_processed["TMP_2maboveground"], periodic_step_bias),
]

# TaskLoader pairing the biased field (context) with the unbiased field (target).
task_loader = TaskLoader(
    context=biased_contexts,  
    target=ds_processed["TMP_2maboveground"],
)

# Print a summary of the loader's context/target sets.
print(task_loader)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



TaskLoader(1 context sets, 1 target sets)
Context variable IDs: (('TMP_2maboveground',),)
Target variable IDs: (('TMP_2maboveground',),)


IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [13]:
import time
import sys
def gen_tasks(dates, progress=True):
    """Generate one task per date with the notebook-level ``task_loader``.

    Every date is converted to ``np.datetime64`` and passed to ``task_loader`` with
    ``context_sampling="all"`` and ``target_sampling="all"``. Dates for which the
    loader raises ``KeyError`` are reported and skipped.

    Parameters
    ----------
    dates : iterable
        Dates accepted by ``np.datetime64`` (e.g. the output of ``pd.date_range``).
    progress : bool
        Whether to show a tqdm progress bar while generating tasks.

    Returns
    -------
    list
        The tasks that were generated successfully, in input order.
    """
    tasks = []
    # `disable` is what makes progress=False actually hide the bar.
    for date in tqdm(dates, disable=not progress):
        date = np.datetime64(date)  # Ensure consistent datetime format
        try:
            task = task_loader(date, context_sampling="all", target_sampling="all")
            tasks.append(task)
        except KeyError:
            print(f"Skipping date {date} as it is not found in dataset.")
        # Flush so skip messages show up promptly next to the progress bar.
        sys.stdout.flush()
    print(f"Finished generating {len(tasks)} tasks.")
    return tasks


# Define the ConvNP model; the remaining settings are inferred from the TaskLoader,
# as the recorded output below shows.
model = ConvNP(data_processor, task_loader, dim_yc=(1,))

# Build a Trainer for the model. No training happens here, and this trainer is replaced
# by `Trainer(model, lr=5e-5)` in a later cell.
trainer = Trainer(model)

dim_yt inferred from TaskLoader: 1
dim_aux_t inferred from TaskLoader: 0
internal_density inferred from TaskLoader: 359
encoder_scales inferred from TaskLoader: [np.float32(0.0013927576)]
decoder_scale inferred from TaskLoader: 0.002785515320334262


RuntimeError: CUDA error: CUDA-capable device(s) is/are busy or unavailable
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Overwrite the `time` coordinate of `ds` with "YYYY-MM-DD HH:MM:SS" strings.
# NOTE(review): this mutates `ds` in place after `ds_processed` has already been built, so
# re-running earlier cells that select on `time` will see strings instead of datetimes;
# `ds` is also passed to `model.predict` later - confirm this conversion is still needed.
ds["time"] = pd.to_datetime(ds["time"].values).strftime("%Y-%m-%d %H:%M:%S")

In [None]:
# Inspect the `time` coordinate of `ds`.
print(ds.time)

In [None]:
# Per-epoch histories appended to by the training loop.
# Re-running this cell resets them.
losses = []       # mean training loss per epoch
val_rmses = []    # RMSE on the validation tasks per epoch
train_rmses = []  # RMSE on the training tasks per epoch

In [None]:
def compute_val_rmse(model, val_tasks):
    """Root-mean-square error of the model mean over ``val_tasks``.

    Predictions (``model.mean(task)``) and targets (``task["Y_t"][0]``) are passed
    through ``data_processor.map_array(..., unnorm=True)`` for the first variable of
    the first target set before the errors are pooled over all tasks.
    """
    var_id = task_loader.target_var_IDs[0][0]  # assume 1st target set and 1D

    def _unnormalise(values):
        return data_processor.map_array(values, var_id, unnorm=True)

    abs_errors = []
    for task in val_tasks:
        predicted = _unnormalise(model.mean(task))
        observed = _unnormalise(task["Y_t"][0])
        abs_errors.extend(np.abs(predicted - observed))

    squared_errors = np.concatenate(abs_errors) ** 2
    return np.sqrt(np.mean(squared_errors))


In [None]:
def compute_train_rmse(model, train_tasks):
    """Root-mean-square error of the model mean over ``train_tasks``.

    Both quantities compared here are target values: ``model.mean(task)`` is the
    prediction at the target locations and ``task["Y_t"][0]`` is the target itself.
    They are therefore un-normalised with the *target* variable ID. Looking the ID up
    in the context sets only works while context and target share a variable name.

    Parameters
    ----------
    model : the trained model; only ``model.mean(task)`` is called.
    train_tasks : iterable of tasks generated by the notebook-level ``task_loader``.

    Returns
    -------
    float
        RMSE pooled over every target value of every task.
    """
    errors = []
    target_var_ID = task_loader.target_var_IDs[0][0]  # assume 1st target set and 1D
    for task in train_tasks:
        mean = data_processor.map_array(model.mean(task), target_var_ID, unnorm=True)
        true = data_processor.map_array(task["Y_t"][0], target_var_ID, unnorm=True)
        errors.extend(np.abs(mean - true))
    return np.sqrt(np.mean(np.concatenate(errors) ** 2))


In [None]:
# Daily dates spanning the validation range, and one task per date that the loader accepts.
val_dates = pd.date_range(val_range[0], val_range[1])
val_tasks = gen_tasks(val_dates)
# Run the model once on the first validation task; the output is discarded.
# NOTE(review): presumably this forward pass is needed before the parameters can be
# counted - confirm. It raises IndexError if no validation task could be generated.
_ = model(val_tasks[0])
print(f"Model has {deepsensor.backend.nps.num_params(model.model):,} parameters")

In [None]:
# Training setup
# Disabled: folder used by the commented-out `model.save(...)` call in the training loop.
#deepsensor_folder = "/home/whruiray/deepsensor_config/"


# Best RMSE seen so far; start at +inf so the first epoch always improves on it.
val_rmse_best = np.inf
train_rmse_best = np.inf


# Replace the earlier default Trainer with one using an explicit learning rate.
trainer = Trainer(model, lr=5e-5)
# Generate the training tasks once, up front, and reuse them for every epoch.
train_tasks = gen_tasks(pd.date_range(train_range[0], train_range[1]), progress=False)

In [None]:
# Train for 10 epochs, recording the mean loss and the train/validation RMSE of each epoch.
for epoch in tqdm(range(10)):
    #train_tasks = gen_tasks(pd.date_range(train_range[0], train_range[1]), progress=False)

    # Monotonic reference point, so the duration printed below is an elapsed time
    # rather than the absolute Unix timestamp.
    epoch_start = time.perf_counter()

    print(f"Training batch for Epoch {epoch+1} started...")
    sys.stdout.flush()

    batch_losses = trainer(train_tasks)

    losses.append(np.mean(batch_losses))
    print(f"Training batch for Epoch {epoch+1} complete...")
    sys.stdout.flush()

    # RMSE on the training tasks, reported under its own label.
    train_rmses.append(compute_train_rmse(model, train_tasks))
    print(f"Epoch {epoch+1} - Loss: {losses[-1]:.4f}, Training RMSE: {train_rmses[-1]:.4f}")
    sys.stdout.flush()
    if train_rmses[-1] < train_rmse_best:
        train_rmse_best = train_rmses[-1]

    # RMSE on the held-out validation tasks.
    val_rmses.append(compute_val_rmse(model, val_tasks))
    print(f"Epoch {epoch+1} - Loss: {losses[-1]:.4f}, Validation RMSE: {val_rmses[-1]:.4f}")
    # Reported once, after both evaluations, because only then is the epoch complete.
    print(f"Epoch {epoch+1} completed in {time.perf_counter() - epoch_start:.1f} seconds.")
    sys.stdout.flush()
    if val_rmses[-1] < val_rmse_best:
        val_rmse_best = val_rmses[-1]

        # Checkpointing on validation improvement is currently disabled.
        #model.save(deepsensor_folder)

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

# Side-by-side curves: training loss (left) and validation RMSE (right), saved to disk.

# savefig does not create missing directories, so make sure the output folder exists.
os.makedirs("fig", exist_ok=True)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(losses)
axes[1].plot(val_rmses)
_ = axes[0].set_xlabel("Epoch")
_ = axes[1].set_xlabel("Epoch")
_ = axes[0].set_title("Training loss")
_ = axes[1].set_title("Validation RMSE")

plt.tight_layout()
plt.savefig("fig/spherical_bias_loss_and_rmse.png")
# Close the figure so it is not also rendered inline or kept in memory.
plt.close()

In [None]:
import os
import matplotlib.pyplot as plt

# Plot smoothed training loss and validation RMSE on twin y-axes, then save the figure.

# savefig does not create missing directories, so make sure the output folder exists.
os.makedirs("fig", exist_ok=True)

fig, ax1 = plt.subplots(figsize=(10, 6))

# Smooth the loss values with a trailing moving average. The window is clamped to the
# number of recorded epochs, and nothing is convolved when there are no losses yet
# (np.convolve rejects empty input).
window_size = 5
if len(losses) > 0:
    window = min(window_size, len(losses))
    smoothed_losses = np.convolve(losses, np.ones(window) / window, mode='valid')
    # Output i of a 'valid' convolution averages epochs i .. i+window-1, so it is drawn
    # at the last epoch of its window to line up with the RMSE curve.
    smoothed_epochs = np.arange(window - 1, len(losses))
else:
    smoothed_losses = np.array([])
    smoothed_epochs = np.array([])

# Plot training loss
ax1.plot(smoothed_epochs, smoothed_losses, 'b-', label='Training Loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss', color='b')
ax1.tick_params('y', colors='b')

# Create a second y-axis for validation RMSE
ax2 = ax1.twinx()
ax2.plot(range(len(val_rmses)), val_rmses, 'r-', label='Validation RMSE')
ax2.set_ylabel('RMSE', color='r')
ax2.tick_params('y', colors='r')

# Tighten each y-axis to its data range plus a 10% margin. Flat or NaN ranges are left to
# matplotlib's autoscaling, since explicit limits would be degenerate or invalid.
if len(smoothed_losses) > 0:
    loss_min, loss_max = min(smoothed_losses), max(smoothed_losses)
    if loss_max > loss_min:
        ax1.set_ylim(loss_min - 0.1*(loss_max-loss_min), loss_max + 0.1*(loss_max-loss_min))

if len(val_rmses) > 0:
    rmse_min, rmse_max = min(val_rmses), max(val_rmses)
    if rmse_max > rmse_min:
        ax2.set_ylim(rmse_min - 0.1*(rmse_max-rmse_min), rmse_max + 0.1*(rmse_max-rmse_min))

plt.title('Training Loss and Validation RMSE')
fig.tight_layout()
plt.savefig("fig/spherical_bias_metrics.png")
plt.close()

In [None]:
import deepsensor
from deepsensor import plot
import matplotlib.pyplot as plt

# Select a sample date for visualization and build its task with all context/target points
sample_date = "2012-01-01"
sample_task = task_loader(sample_date, context_sampling="all", target_sampling="all")

# Second TaskLoader, built with the same arguments as `task_loader` above and used only
# for plotting.
# NOTE(review): this looks redundant with `task_loader`, and "barmonic" is presumably a
# typo for "harmonic" - confirm before renaming.
task_loader_spherical_barmonic_bias = TaskLoader(
    context=biased_contexts,
    target=ds_processed["TMP_2maboveground"],
)
# Plot context (biased) and target (original) data
fig = deepsensor.plot.task(sample_task, task_loader=task_loader_spherical_barmonic_bias)
# NOTE(review): the file name says "constant" although this notebook applies the
# spherical-harmonic bias; the other figures use a "spherical_bias" name - confirm.
fig.savefig("fig/constant_sample_context_target_plot.png")
plt.close(fig)

In [None]:
# Step 1: Choose the target date for prediction
# NOTE(review): this date lies outside both train_range and val_range, and one recorded
# dataset repr above ends in January 2015 - confirm the loaded data covers this date.
target_date = "2016-01-15"  # Example date

# Step 2: Load the task for the target date (context and target data)
task = task_loader(target_date, context_sampling="all", target_sampling="all")

# Step 3: Run the model to get predictions
# `ds` is passed as X_t.
# NOTE(review): presumably predict only takes the output grid from it - confirm.
pred_val = model.predict(task, X_t=ds)  # Pass the task plus the prediction grid

# Step 4: Extract the predicted variable; 'TMP_2maboveground' is the only target variable
# configured in this notebook's TaskLoader
predxr = pred_val['TMP_2maboveground']

# Step 5: Plot the mean prediction and save it
predxr['mean'].plot(cmap='viridis')  # Plot mean prediction
plt.title(f"Prediction for {target_date}")
plt.savefig(f"fig/prediction_for_spherical_bias_{target_date}_mean.png")
plt.close()


# Step 6: Plot the predictive standard deviation and save it
predxr['std'].plot(cmap='viridis')  # Plot standard deviation
plt.title(f"Standard Deviation for {target_date}")
plt.savefig(f"fig/prediction_for_spherical_bias_{target_date}_std.png")
plt.close()