# Training and inference



## Model overview


## Running the training

### Using ``pytorch_retrieve``

A simple training recipe for Prithvi-WxC using the ``pytorch_retrieve`` package is provided in the ``model_small`` directory. The directory contains three files, ``model.toml``, ``training.toml``, and ``compute.toml``, which describe the model configuration, the training schedule, and the compute configuration, respectively. To run the training, execute the ``pytorch_retrieve train`` command in the ``model_small`` directory.

````
cd model_small
pytorch_retrieve train
````

### Manual training

For more fine-grained control over training, the ``PrithviWxCRegional`` model can be instantiated directly, yielding a PyTorch module that can be trained using plain PyTorch or Lightning.
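
As a rough illustration of the plain-PyTorch route, the sketch below runs a minimal training loop. To keep it runnable without the Prithvi-WxC weights or data, it uses a small convolutional network as a hypothetical stand-in for ``PrithviWxCRegional`` and random tensors in place of real batches. The tensor shapes, the MSE loss, and the AdamW settings are assumptions chosen for illustration, not values taken from the recipe.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the PrithviWxCRegional module, so the loop
# runs without scaling factors or training data.
model = nn.Sequential(
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
    nn.GELU(),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
)

# Synthetic (input, target) batches; a real run would iterate over a DataLoader.
batches = [
    (torch.randn(2, 4, 16, 32), torch.randn(2, 4, 16, 32)) for _ in range(4)
]

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

losses = []
model.train()
for epoch in range(3):
    epoch_loss = 0.0
    for x, y in batches:
        optimizer.zero_grad()        # clear gradients from the previous step
        loss = loss_fn(model(x), y)  # forward pass and loss
        loss.backward()              # backpropagation
        optimizer.step()             # parameter update
        epoch_loss += loss.item()
    losses.append(epoch_loss / len(batches))
    print(f"epoch {epoch}: mean loss {losses[-1]:.4f}")
```

The real model's batch format may differ from the plain ``(input, target)`` tuples used here, so the forward call and the loss would need to be adapted accordingly.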
    


Training progress can be monitored with TensorBoard by running ``tensorboard --logdir logs``.

````python
from pathlib import Path
from PrithviWxC.dataloaders.merra2 import (
    input_scalers,
    output_scalers,
    static_input_scalers,
)
from pytorch_retrieve.models.prithvi_wxc import PrithviWxCRegional

VERTICAL_VARS = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
STATIC_SURFACE_VARS = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
SURFACE_VARS = [
    "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", "LWTUP", "PS", "QV2M", "SLP",
    "SWGNT", "SWTNT", "T2M", "TQI", "TQL", "TQV", "TS", "U10M", "V10M", "Z0M"
]
LEVELS = [
    34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 51.0, 53.0, 56.0, 63.0, 68.0, 71.0, 72.0
]

# Path containing the scaling factors (adjust to your system).
scaling_factors = Path("/home/simon/data/e3sm/scaling_factors")

# Mean and standard deviation used to normalize the dynamic inputs.
in_mu, in_sig = input_scalers(
    SURFACE_VARS,
    VERTICAL_VARS,
    LEVELS,
    str(scaling_factors / "musigma_surface.nc"),
    str(scaling_factors / "musigma_vertical.nc"),
)
# Anomaly variances used to scale the model outputs.
output_sig = output_scalers(
    SURFACE_VARS,
    VERTICAL_VARS,
    LEVELS,
    str(scaling_factors / "anomaly_variance_surface.nc"),
    str(scaling_factors / "anomaly_variance_vertical.nc"),
)

# Mean and standard deviation used to normalize the static inputs.
static_mu, static_sig = static_input_scalers(
    str(scaling_factors / "musigma_surface.nc"),
    STATIC_SURFACE_VARS,
)

# Parameters are chosen to match the small Prithvi-WxC model.
model = PrithviWxCRegional(
    in_channels=160,
    input_size_time=2,
    in_channels_static=8,
    input_scalers_epsilon=0.0,
    static_input_scalers_epsilon=0.0,
    n_lats_px=180,
    n_lons_px=360,
    patch_size_px=[2, 2],
    mask_unit_size_px=[20, 20],
    embed_dim=1024,
    n_blocks_encoder=8,
    n_blocks_decoder=4,
    mlp_multiplier=4,
    n_heads=16,
    dropout=0.0,
    drop_path=0.0,
    parameter_dropout=0.0,
    positional_encoding="fourier",
    encoder_shifting=True,
    decoder_shifting=False,
    mask_ratio_inputs=0.0,
    residual="climate",
    masking_mode="both",
    # Activate activation checkpointing to reduce the memory footprint.
    checkpoint_encoder=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    checkpoint_decoder=[0, 1, 2, 3, 4, 5, 6, 7, 8],
    input_scalers_mu=in_mu,
    input_scalers_sigma=in_sig,
    static_input_scalers_mu=static_mu,
    static_input_scalers_sigma=static_sig,
    # The output scalers are variances; the square root yields standard deviations.
    output_scalers=output_sig ** 0.5,
    mask_ratio_targets=0.0,
)
````
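
The hard-coded sizes in this configuration can be cross-checked against the variable lists with a few lines of plain Python. The sketch below assumes that every surface variable contributes one input channel and every vertical variable contributes one channel per level, and that the grid must tile evenly into mask units and each mask unit into patches. Names such as ``n_channels`` and ``n_tokens`` are introduced here for illustration only.

```python
VERTICAL_VARS = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
SURFACE_VARS = [
    "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", "LWTUP", "PS", "QV2M", "SLP",
    "SWGNT", "SWTNT", "T2M", "TQI", "TQL", "TQV", "TS", "U10M", "V10M", "Z0M"
]
LEVELS = [
    34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 51.0, 53.0, 56.0, 63.0, 68.0, 71.0, 72.0
]

# Assumed channel layout: one channel per surface variable plus one channel
# per (vertical variable, level) pair.
n_channels = len(SURFACE_VARS) + len(VERTICAL_VARS) * len(LEVELS)
print(n_channels)  # 20 + 10 * 14 = 160, matching in_channels=160

# Grid settings copied from the model configuration.
n_lats_px, n_lons_px = 180, 360
patch_size_px = (2, 2)
mask_unit_size_px = (20, 20)

# Assumed constraints: the grid tiles into mask units, mask units into patches.
assert n_lats_px % mask_unit_size_px[0] == 0
assert n_lons_px % mask_unit_size_px[1] == 0
assert mask_unit_size_px[0] % patch_size_px[0] == 0
assert mask_unit_size_px[1] % patch_size_px[1] == 0

mask_units = (n_lats_px // mask_unit_size_px[0], n_lons_px // mask_unit_size_px[1])
patches_per_unit = (
    mask_unit_size_px[0] // patch_size_px[0],
    mask_unit_size_px[1] // patch_size_px[1],
)
n_tokens = (
    mask_units[0] * mask_units[1] * patches_per_unit[0] * patches_per_unit[1]
)
print(mask_units, patches_per_unit, n_tokens)  # (9, 18) (10, 10) 16200
```

Running a check like this before instantiating the model is a cheap way to catch a mismatch when the variable lists, levels, or grid size are changed for a different domain.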
