# Prithvi-WxC Precipitation Forecast

This notebook produces precipitation forecasts by running the large Prithvi-WxC model forward from GEOS analysis data. A diagnostic neural network then derives surface precipitation from the forecast fields.

In [26]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import xarray as xr

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
from precipfm.utils import get_date

# Root directory containing the GEOS forecast NetCDF files.
# NOTE(review): this root (/data/...) differs from the /gdata1/simon/precipfm/... paths used
# by the other cells — confirm it is the intended location.
GEOS_FORECAST_DIR = Path("/data/precipfm/verification/forecast/")
# Number of leading initialization times that are dropped (reason not stated — TODO document).
N_SKIPPED_INIT_TIMES = 4

# Every *.nc file below the root, in sorted path order.
geos_forecasts = sorted(GEOS_FORECAST_DIR.glob("**/*.nc"))
# Dates are extracted from every file first, then the leading entries are discarded.
initialization_times = list(map(get_date, geos_forecasts))[N_SKIPPED_INIT_TIMES:]

## Load the Prithvi-WxC model

In [2]:
from precipfm.prithviwxc import load_model

# Checkpoint of the Prithvi-WxC model and the directory with its auxiliary climatology files.
PRITHVI_CHECKPOINT = "/gdata1/simon/precipfm/models/prithvi.wxc.rollout.2300m.v1.pt"
PRITHVI_AUXILIARY_DIR = "/gdata1/simon/precipfm/training_data_2019/climatology/"

mdl = load_model(
    checkpoint_path=PRITHVI_CHECKPOINT,
    auxiliary_path=PRITHVI_AUXILIARY_DIR,
    configuration="large",
)

## Input data loader

In [3]:
from precipfm.datasets import GEOSInputData

# GEOS analysis files read by the forecast input loader.
GEOS_ANALYSIS_DIR = "/gdata1/simon/precipfm/verification/analysis/"

input_loader = GEOSInputData(
    GEOS_ANALYSIS_DIR,
    input_times=[-6, 0],  # presumably hours relative to the initialization time — TODO confirm
    lead_times=[6],  # presumably the lead time (hours) of each forecast step — TODO confirm
)


## Diagnostic precipitation model

In [4]:
# Import under an alias: binding this loader to ``load_model`` would shadow
# ``precipfm.prithviwxc.load_model`` (used to load Prithvi-WxC), leaving one global name
# that refers to two different loaders depending on which cell ran last.
from pytorch_retrieve import load_model as load_retrieval_model
from pytorch_retrieve.config import InferenceConfig, RetrievalOutputConfig

precip_mdl = load_retrieval_model("/gdata1/simon/precipfm/model_diagnose/gprof_nn_3d.pt")

# Request the "ExpectedValue" of the model's "surface_precip" output as the retrieval result.
expected_value = RetrievalOutputConfig(precip_mdl.output_config["surface_precip"], "ExpectedValue", {})
retrieval_output = {
    "surface_precip": {
        "surface_precip": expected_value,
    }
}

# Tiled inference settings attached to the diagnostic model.
inference_config = InferenceConfig(
    tile_size=128,
    spatial_overlap=32,
    retrieval_output=retrieval_output,
    batch_size=2,
)
precip_mdl.inference_config = inference_config

# Persist the configured model; the relative path is resolved against the kernel's working directory.
precip_mdl.save("precip_diagnostic.pt")


## Run the forecasts

In [6]:
from precipfm.forecast import Forecaster
# Combine the Prithvi-WxC model (`mdl`) with the GEOS input loader; `fc.run` below executes the forecast.
fc = Forecaster(mdl, input_loader)

In [None]:
# Label used to name the results directory for this run.
# NOTE(review): ``model_name`` was referenced here but never defined anywhere in the notebook,
# so this cell raised NameError on a fresh kernel. The label below is taken from the Prithvi-WxC
# checkpoint file name — confirm it is the intended directory name.
model_name = "prithvi.wxc.rollout.2300m.v1"
output_path = Path("/gdata1/simon/precipfm/results") / model_name
# parents=True so the cell also succeeds when the results root does not exist yet.
output_path.mkdir(parents=True, exist_ok=True)

In [30]:
# Initialization time of the forecast run.
INIT_TIME = np.datetime64("2025-04-07T00:00:00")
# Second positional argument of fc.run; presumably the number of forecast steps — TODO confirm.
N_STEPS = 4

results = fc.run(
    INIT_TIME,
    N_STEPS,
    diagnostics={"surface_precip": precip_mdl},
)

  static = pad(torch.tensor(self.input_loader.load_static_data(step_time)))[None]


Output()

 25%|█████████████████▎                                                   | 1/4 [01:05<03:17, 65.78s/it]

Output()

 50%|██████████████████████████████████▌                                  | 2/4 [02:16<02:17, 68.94s/it]

Output()

 75%|███████████████████████████████████████████████████▊                 | 3/4 [03:15<01:04, 64.04s/it]

Output()

100%|█████████████████████████████████████████████████████████████████████| 4/4 [04:24<00:00, 66.16s/it]
