# PrithviWxC Rollout Inference
If you haven't already, take a look at the example for the PrithviWxC core
model, as we will pass over the points covered there.

Here we will introduce the PrithviWxC model that was trained further for
autoregressive rollout, a common strategy to increase the accuracy and stability of
models when applied to forecasting-type tasks.

In [None]:
# setup google drive for saving checkpoints
# Checkpoints go to Drive so a long rollout survives Colab runtime resets.
from google.colab import drive
from pathlib import Path

drive.mount('/content/drive')
checkpoint_path = Path("/content/drive/MyDrive/Colab Notebooks/PrithviWxC_Checkpoints")
# The rollout loop lists this folder with `iterdir()`, which raises
# FileNotFoundError when it does not exist yet (e.g. on a fresh Drive).
checkpoint_path.mkdir(parents=True, exist_ok=True)

Mounted at /content/drive


In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download, snapshot_download

# Set backend etc.
torch.jit.enable_onednn_fusion(True)
if torch.cuda.is_available():
    # The seeds below are set for reproducibility. cuDNN's autotuner
    # (`benchmark = True`) can pick different algorithms from run to run,
    # which undermines `deterministic = True`. Set `benchmark` back to True
    # if throughput matters more than repeatable reruns.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Set seeds for every RNG in play (Python, NumPy, torch CPU and CUDA).
SEED = 42
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

# Set device: prefer the GPU, fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set variables
# Surface variables. The order matters: later cells index channel 12,
# which is "T2M" in this list.
surface_vars = [
    "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB",
    "LWGEM", "LWTUP", "PS", "QV2M", "SLP",
    "SWGNT", "SWTNT", "T2M", "TQI", "TQL",
    "TQV", "TS", "U10M", "V10M", "Z0M",
]
# Static surface fields.
static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
# Variables provided on each of the vertical `levels`.
vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
# Vertical levels to load. These are presumably MERRA-2 model-level indices
# (TODO: confirm against the dataloader).
levels = [
    34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0,
    51.0, 53.0, 56.0, 63.0, 68.0, 71.0, 72.0,
]
# Per-dimension padding passed to `preproc`. The [0, -1] on "lat" presumably
# trims one latitude row (TODO: confirm).
padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}

### Lead time
When performing autoregressive rollout, the intermediate steps require the
static data at those times and---if using `residual=climate`---the intermediate
climatology. We provide a dataloader that extends the MERRA2 loader of the
core model, adding in these additional terms. Further, it returns target data for
the intermediate steps if those are required for loss terms.

The `lead_time` flag still sets the target time for the model; however, it is now
only a single value and must be a positive integer multiple of `-input_time`.

In [None]:
# (start, end) window of MERRA-2 data to use, as ISO-8601 strings; the
# file-name helpers below treat `end` as exclusive.
time_range = ("2020-01-01T00:00:00", "2021-01-01T00:00:00")

In [None]:
# Spacing in hours between the model's input time steps (negative = earlier).
# 24*30 h = 30 days, so each rollout step advances the forecast by 30 days.
input_time = -24*30
# Target lead time in hours; must be a positive integer multiple of -input_time.
# 11 * 720 h here, which gives dataset.nsteps == 11 further below.
lead_time = abs(input_time)*11

In [None]:
from datetime import datetime, timedelta

def get_file_names(time_range, input_time, prefix):
    """Build the daily MERRA-2 file names needed to cover ``time_range``.

    Walks from ``time_range[0]`` (inclusive) to ``time_range[1]`` (exclusive)
    in steps of ``abs(input_time)`` hours. It emits one
    ``{prefix}YYYYMMDD.nc`` name for each distinct day visited.

    Args:
        time_range: ``(start, end)`` ISO-8601 strings; ``end`` is exclusive.
        input_time: Step in hours between sampled times (the sign is ignored).
        prefix: String prepended to each file name (may include a directory).

    Returns:
        List of unique file names in chronological order.

    Raises:
        ValueError: If ``input_time`` is 0, because the walk would never advance.
    """
    if input_time == 0:
        raise ValueError("input_time must be non-zero")

    start_date = datetime.fromisoformat(time_range[0])
    end_date = datetime.fromisoformat(time_range[1])
    step = timedelta(hours=abs(input_time))

    file_names = []
    seen = set()
    current_date = start_date
    while current_date < end_date:
        # Names only encode the day, so sub-daily steps (e.g. -6 h) would
        # otherwise repeat the same file several times.
        file_name = f"{prefix}{current_date.strftime('%Y%m%d')}.nc"
        if file_name not in seen:
            seen.add(file_name)
            file_names.append(file_name)
        current_date += step

    return file_names

In [None]:
# Preview the surface file names for this time range (the download cell below
# adds a "merra-2/" directory prefix).
get_file_names(time_range, input_time, "MERRA2_sfc_")

['MERRA2_sfc_20200101.nc',
 'MERRA2_sfc_20200131.nc',
 'MERRA2_sfc_20200301.nc',
 'MERRA2_sfc_20200331.nc',
 'MERRA2_sfc_20200430.nc',
 'MERRA2_sfc_20200530.nc',
 'MERRA2_sfc_20200629.nc',
 'MERRA2_sfc_20200729.nc',
 'MERRA2_sfc_20200828.nc',
 'MERRA2_sfc_20200927.nc',
 'MERRA2_sfc_20201027.nc',
 'MERRA2_sfc_20201126.nc',
 'MERRA2_sfc_20201226.nc']

### Data file
MERRA-2 data is available from 1980 to the present day,
at 3-hour temporal resolution. The dataloader we have provided
expects the surface data and vertical data to be saved in
separate files, and when provided with the directories, will
search for the relevant data that falls within the provided time range.


In [None]:
# Fetch the surface and pressure-level MERRA-2 files for `time_range` from the
# Prithvi-WxC Hugging Face repo. Both kinds land in ./merra-2, so the same
# directory serves as the surface and the vertical data path.
surf_dir = Path("./merra-2")
vert_dir = Path("./merra-2")
for merra_prefix in ("merra-2/MERRA2_sfc_", "merra-2/MERRA_pres_"):
    snapshot_download(
        repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
        allow_patterns=get_file_names(time_range, input_time, merra_prefix),
        local_dir=".",
    )

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

MERRA2_sfc_20200331.nc:   0%|          | 0.00/101M [00:00<?, ?B/s]

MERRA2_sfc_20200101.nc:   0%|          | 0.00/102M [00:00<?, ?B/s]

MERRA2_sfc_20200729.nc:   0%|          | 0.00/100M [00:00<?, ?B/s]

MERRA2_sfc_20200430.nc:   0%|          | 0.00/101M [00:00<?, ?B/s]

MERRA2_sfc_20200629.nc:   0%|          | 0.00/101M [00:00<?, ?B/s]

MERRA2_sfc_20200131.nc:   0%|          | 0.00/102M [00:00<?, ?B/s]

MERRA2_sfc_20200530.nc:   0%|          | 0.00/101M [00:00<?, ?B/s]

MERRA2_sfc_20200301.nc:   0%|          | 0.00/101M [00:00<?, ?B/s]

MERRA2_sfc_20200828.nc:   0%|          | 0.00/101M [00:00<?, ?B/s]

MERRA2_sfc_20200927.nc:   0%|          | 0.00/101M [00:00<?, ?B/s]

MERRA2_sfc_20201126.nc:   0%|          | 0.00/101M [00:00<?, ?B/s]

MERRA2_sfc_20201027.nc:   0%|          | 0.00/101M [00:00<?, ?B/s]

MERRA2_sfc_20201226.nc:   0%|          | 0.00/101M [00:00<?, ?B/s]

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

MERRA_pres_20200301.nc:   0%|          | 0.00/336M [00:00<?, ?B/s]

MERRA_pres_20200131.nc:   0%|          | 0.00/338M [00:00<?, ?B/s]

MERRA_pres_20200101.nc:   0%|          | 0.00/337M [00:00<?, ?B/s]

MERRA_pres_20200530.nc:   0%|          | 0.00/342M [00:00<?, ?B/s]

MERRA_pres_20200729.nc:   0%|          | 0.00/334M [00:00<?, ?B/s]

MERRA_pres_20200430.nc:   0%|          | 0.00/339M [00:00<?, ?B/s]

MERRA_pres_20200331.nc:   0%|          | 0.00/339M [00:00<?, ?B/s]

MERRA_pres_20200629.nc:   0%|          | 0.00/336M [00:00<?, ?B/s]

MERRA_pres_20200828.nc:   0%|          | 0.00/334M [00:00<?, ?B/s]

MERRA_pres_20200927.nc:   0%|          | 0.00/337M [00:00<?, ?B/s]

MERRA_pres_20201027.nc:   0%|          | 0.00/338M [00:00<?, ?B/s]

MERRA_pres_20201126.nc:   0%|          | 0.00/338M [00:00<?, ?B/s]

MERRA_pres_20201226.nc:   0%|          | 0.00/334M [00:00<?, ?B/s]

'/content'

### Climatology
The PrithviWxC model was trained to calculate the output by
producing a perturbation to the climatology at the target time.
This mode of operation is set via the `residual=climate` option.
It was chosen because climatology is typically a strong prior for
long-range prediction. When using the `residual=climate` option,
we have to provide the dataloader with the paths of the
climatology data.

In [None]:
def get_climate_file_names(time_range, input_time, prefix):
    """Build the climatology file names (day of year + hour) for ``time_range``.

    Walks from ``time_range[0]`` (inclusive) to ``time_range[1]`` (exclusive)
    in steps of ``abs(input_time)`` hours. It emits one
    ``{prefix}DDD_hourHH.nc`` name for each distinct (day-of-year, hour) visited.

    NOTE(review): the day of year comes from ``%j``. In leap years, dates after
    Feb 28 map to a day one higher than in other years, and Dec 31 maps to 366.
    Confirm the climatology repo provides matching files.

    Args:
        time_range: ``(start, end)`` ISO-8601 strings; ``end`` is exclusive.
        input_time: Step in hours between sampled times (the sign is ignored).
        prefix: String prepended to each file name (may include a directory).

    Returns:
        List of unique file names in the order first visited.

    Raises:
        ValueError: If ``input_time`` is 0, because the walk would never advance.
    """
    if input_time == 0:
        raise ValueError("input_time must be non-zero")

    start_date = datetime.fromisoformat(time_range[0])
    end_date = datetime.fromisoformat(time_range[1])
    step = timedelta(hours=abs(input_time))

    file_names = []
    seen = set()
    current_date = start_date
    while current_date < end_date:
        # Multi-year ranges revisit the same day of year; list each file once.
        file_name = f"{prefix}{current_date.strftime('%j')}_hour{current_date.strftime('%H')}.nc"
        if file_name not in seen:
            seen.add(file_name)
            file_names.append(file_name)
        current_date += step

    return file_names

In [None]:
# Preview the surface climatology file names (day of year + hour) for this time range.
get_climate_file_names(time_range, input_time, "climatology/climate_surface_doy")

['climatology/climate_surface_doy001_hour00.nc',
 'climatology/climate_surface_doy031_hour00.nc',
 'climatology/climate_surface_doy061_hour00.nc',
 'climatology/climate_surface_doy091_hour00.nc',
 'climatology/climate_surface_doy121_hour00.nc',
 'climatology/climate_surface_doy151_hour00.nc',
 'climatology/climate_surface_doy181_hour00.nc',
 'climatology/climate_surface_doy211_hour00.nc',
 'climatology/climate_surface_doy241_hour00.nc',
 'climatology/climate_surface_doy271_hour00.nc',
 'climatology/climate_surface_doy301_hour00.nc',
 'climatology/climate_surface_doy331_hour00.nc',
 'climatology/climate_surface_doy361_hour00.nc']

In [None]:
# Fetch the surface and vertical climatology files matching `time_range`.
# Both kinds land in ./climatology, so the same directory serves as both
# climatology paths.
surf_clim_dir = Path("./climatology")
vert_clim_dir = Path("./climatology")
for clim_prefix in ("climatology/climate_surface_doy", "climatology/climate_vertical_doy"):
    snapshot_download(
        repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
        allow_patterns=get_climate_file_names(time_range, input_time, clim_prefix),
        local_dir=".",
    )

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

climate_surface_doy181_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

climate_surface_doy031_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

climate_surface_doy121_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

climate_surface_doy001_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

climate_surface_doy151_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

climate_surface_doy211_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

climate_surface_doy091_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

climate_surface_doy061_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

climate_surface_doy301_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

climate_surface_doy331_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

climate_surface_doy241_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

climate_surface_doy361_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

climate_surface_doy271_hour00.nc:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

climate_vertical_doy001_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

climate_vertical_doy061_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

climate_vertical_doy181_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

climate_vertical_doy121_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

climate_vertical_doy091_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

climate_vertical_doy031_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

climate_vertical_doy151_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

climate_vertical_doy211_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

climate_vertical_doy301_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

climate_vertical_doy271_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

climate_vertical_doy241_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

climate_vertical_doy331_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

climate_vertical_doy361_hour00.nc:   0%|          | 0.00/116M [00:00<?, ?B/s]

'/content'

In [None]:
# Positional-encoding scheme, passed to both the dataset and the model below.
positional_encoding = "fourier"

### Dataloader init
We are now ready to instantiate the dataloader.

In [None]:
# Pinned to the commit this notebook was last run against (see the install log),
# so later upstream changes cannot silently change the results.
# %pip installs into the running kernel's environment.
%pip install git+https://github.com/NASA-IMPACT/Prithvi-WxC@ecfd69b2e94b6505d3accbdecf095b92fd18411e

Collecting git+https://github.com/NASA-IMPACT/Prithvi-WxC
  Cloning https://github.com/NASA-IMPACT/Prithvi-WxC to /tmp/pip-req-build-2m26ap0f
  Running command git clone --filter=blob:none --quiet https://github.com/NASA-IMPACT/Prithvi-WxC /tmp/pip-req-build-2m26ap0f
  Resolved https://github.com/NASA-IMPACT/Prithvi-WxC to commit ecfd69b2e94b6505d3accbdecf095b92fd18411e
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: PrithviWxC
  Building wheel for PrithviWxC (pyproject.toml) ... [?25l[?25hdone
  Created wheel for PrithviWxC: filename=prithviwxc-1.0.0-py3-none-any.whl size=26385 sha256=ec44e2a7e37374bff859dfd55a8766bd2874fbb0c92c7ab9c70826a2e03a7cd6
  Stored in directory: /tmp/pip-ephem-wheel-cache-vus7ro_i/wheels/27/50/62/bdd643949a6f5dbb8cc41159f3d2ee23a470394d10e56908f9
Successfully built PrithviWxC
Installing collected 

In [None]:
from PrithviWxC.dataloaders.merra2_rollout import Merra2RolloutDataset

# Rollout dataset: like the core MERRA-2 loader, but it also supplies the
# per-step statics/climatology (and targets) for the intermediate steps.
dataset = Merra2RolloutDataset(
    # temporal setup
    time_range=time_range,
    input_time=input_time,
    lead_time=lead_time,
    # where the data and its climatology were downloaded to
    data_path_surface=surf_dir,
    data_path_vertical=vert_dir,
    climatology_path_surface=surf_clim_dir,
    climatology_path_vertical=vert_clim_dir,
    # variable / level selection
    surface_vars=surface_vars,
    static_surface_vars=static_surface_vars,
    vertical_vars=vertical_vars,
    levels=levels,
    positional_encoding=positional_encoding,
)
assert len(dataset) > 0, "There doesn't seem to be any valid data."

In [None]:
# Number of rollout steps the dataset provides data for (11 with the settings above).
dataset.nsteps

11

## Model
### Scalers and other hyperparameters
Again, this setup is similar to before.

In [None]:
from PrithviWxC.dataloaders.merra2 import (
    input_scalers,
    output_scalers,
    static_input_scalers,
)

# Normalisation statistics: input mean/sigma ("musigma") and output
# anomaly-variance files, each for surface and vertical variables.
surf_in_scal_path = Path("./climatology/musigma_surface.nc")
vert_in_scal_path = Path("./climatology/musigma_vertical.nc")
surf_out_scal_path = Path("./climatology/anomaly_variance_surface.nc")
vert_out_scal_path = Path("./climatology/anomaly_variance_vertical.nc")

for scaler_file in (surf_in_scal_path, vert_in_scal_path, surf_out_scal_path, vert_out_scal_path):
    hf_hub_download(
        repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
        filename=f"climatology/{scaler_file.name}",
        local_dir=".",
    )

# The architecture config comes from the rollout checkpoint's repo.
hf_hub_download(
    repo_id="Prithvi-WxC/prithvi.wxc.rollout.2300m.v1",
    filename="config.yaml",
    local_dir=".",
)

in_mu, in_sig = input_scalers(
    surface_vars, vertical_vars, levels, surf_in_scal_path, vert_in_scal_path
)
output_sig = output_scalers(
    surface_vars, vertical_vars, levels, surf_out_scal_path, vert_out_scal_path
)
static_mu, static_sig = static_input_scalers(surf_in_scal_path, static_surface_vars)

# Model-mode hyperparameters, consumed by the model constructor below.
residual = "climate"
masking_mode = "local"
decoder_shifting = True
masking_ratio = 0.99

musigma_surface.nc:   0%|          | 0.00/24.7k [00:00<?, ?B/s]

musigma_vertical.nc:   0%|          | 0.00/25.0k [00:00<?, ?B/s]

anomaly_variance_surface.nc:   0%|          | 0.00/11.5k [00:00<?, ?B/s]

anomaly_variance_vertical.nc:   0%|          | 0.00/18.6k [00:00<?, ?B/s]

config.yaml:   0%|          | 0.00/428 [00:00<?, ?B/s]

### Model init
We can now build the model and load the pretrained weights. Note that you should
use the rollout version of the weights.

In [None]:
# Download the rollout checkpoint (about 28 GB) into ./weights. The cell's
# displayed value is the local file path.
weights_path = Path("./weights/prithvi.wxc.rollout.2300m.v1.pt")
hf_hub_download(
    repo_id="Prithvi-WxC/prithvi.wxc.rollout.2300m.v1",
    filename=weights_path.name,
    local_dir="./weights",
)

prithvi.wxc.rollout.2300m.v1.pt:   0%|          | 0.00/28.4G [00:00<?, ?B/s]

'weights/prithvi.wxc.rollout.2300m.v1.pt'

In [None]:
import yaml

from PrithviWxC.model import PrithviWxC

with open("./config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Architecture hyperparameters come from the rollout config; the scalers and
# mode flags come from the previous cells.
model = PrithviWxC(
    in_channels=config["params"]["in_channels"],
    input_size_time=config["params"]["input_size_time"],
    in_channels_static=config["params"]["in_channels_static"],
    input_scalers_mu=in_mu,
    input_scalers_sigma=in_sig,
    input_scalers_epsilon=config["params"]["input_scalers_epsilon"],
    static_input_scalers_mu=static_mu,
    static_input_scalers_sigma=static_sig,
    static_input_scalers_epsilon=config["params"][
        "static_input_scalers_epsilon"
    ],
    # output_sig holds the anomaly variances; the model takes their square root.
    output_scalers=output_sig**0.5,
    n_lats_px=config["params"]["n_lats_px"],
    n_lons_px=config["params"]["n_lons_px"],
    patch_size_px=config["params"]["patch_size_px"],
    mask_unit_size_px=config["params"]["mask_unit_size_px"],
    mask_ratio_inputs=masking_ratio,
    embed_dim=config["params"]["embed_dim"],
    n_blocks_encoder=config["params"]["n_blocks_encoder"],
    n_blocks_decoder=config["params"]["n_blocks_decoder"],
    mlp_multiplier=config["params"]["mlp_multiplier"],
    n_heads=config["params"]["n_heads"],
    dropout=config["params"]["dropout"],
    drop_path=config["params"]["drop_path"],
    parameter_dropout=config["params"]["parameter_dropout"],
    residual=residual,
    masking_mode=masking_mode,
    decoder_shifting=decoder_shifting,
    positional_encoding=positional_encoding,
    checkpoint_encoder=[],
    checkpoint_decoder=[],
)


# Deserialise onto the CPU first. This notebook falls back to a CPU device when
# CUDA is unavailable, and a checkpoint whose tensors were saved on a GPU would
# otherwise fail to load there. The model is moved to `device` below.
# SECURITY: weights_only=False un-pickles arbitrary objects. Only use it with
# checkpoints from a trusted source (here the official Hugging Face repo).
state_dict = torch.load(weights_path, map_location="cpu", weights_only=False)
if "model_state" in state_dict:
    state_dict = state_dict["model_state"]
model.load_state_dict(state_dict, strict=True)
# load_state_dict copies the tensors into the model's own parameters. Drop the
# checkpoint dict so its multi-GB copy does not linger in the kernel.
del state_dict

# Move the model unless it already reports being on `device`.
if getattr(model, "device", None) != device:
    model = model.to(device)

## Rollout
We are now ready to perform the rollout. Again, the data has to be run through a
preprocessor. However, this time we use a preprocessor that can handle the
additional intermediate data. Also, rather than calling the model directly, we
have a convenient wrapper function that performs the iteration. This also
simplifies the model loading when using a sharded checkpoint. If you attempt to
perform training steps upon this function, you should use an aggressive number
of activation checkpoints, as the memory consumption becomes quite high.

In [None]:
# Reference only: the library's built-in rollout (`rollout_iter`), which stops
# after `dataset.nsteps` steps. It is kept commented out; the cells below
# re-implement the same loop so it can run past `nsteps` and checkpoint to Drive.
# original code
# from PrithviWxC.dataloaders.merra2_rollout import preproc
# from PrithviWxC.rollout import rollout_iter

# data = next(iter(dataset))
# batch = preproc([data], padding)

# for k, v in batch.items():
#     if isinstance(v, torch.Tensor):
#         batch[k] = v.to(device)

# rng_state_1 = torch.get_rng_state()
# with torch.no_grad():
#     model.eval()
#     out = rollout_iter(dataset.nsteps, model, batch)

In [None]:
# setup data
from PrithviWxC.dataloaders.merra2_rollout import preproc
from PrithviWxC.rollout import rollout_iter

data = next(iter(dataset))
# Collate the single sample into a batch, then move every tensor entry to
# `device`. Non-tensor entries pass through unchanged.
batch = {
    key: value.to(device) if isinstance(value, torch.Tensor) else value
    for key, value in preproc([data], padding).items()
}

With `lead_time = 12` and `input_time = -6`, `dataset.nsteps` would be 2. Without careful debugging, it appears the calculation is essentially `abs(lead_time // input_time)`. This is consistent with the requirement stated above that `lead_time` "must be a positive integer multiple of the `-input_time`." In this notebook, `input_time = -720` and `lead_time = 7920`, which gives the `dataset.nsteps` of 11 shown above.

With `input_time = -6` we would essentially have primed our model for 6-hour timesteps, and each autoregressive step should be thought of as 6 hours. With the values used here, each step is `abs(input_time)` = 720 hours (30 days). If we want a year, we'd calculate `(24*365)//abs(input_time)` steps.


In [None]:
# Total rollout steps: ten years of abs(input_time)-hour steps. With the settings
# above this is 87600 // 720 = 121 (about 9.9 years, because integer division
# drops the remainder).
nsteps_extended = (24*365*10)//abs(input_time) # ten years
# Steps the dataset actually provides statics/climatology for; the loop below
# wraps around this many.
nsteps = dataset.nsteps

In [None]:
# main autoregression loop
rng_state_1 = torch.get_rng_state()
steps_per_checkpoint = 1
with torch.no_grad():
    model.eval()

    # attempt to load last checkpoint
    checkpoints = [f.name for f in checkpoint_path.iterdir() if f.is_file()]
    print(f"Checkpoints: {checkpoints}")
    if len(checkpoints) > 0:
      get_chkpt_num = lambda x: int(x.split(".")[0].split("_")[-1])
      checkpoints = sorted(checkpoints, key=get_chkpt_num)
      last_checkpoint = checkpoints[-1]
      print(f"Loading checkpoint: {last_checkpoint}")

      batch["x"] = torch.load(checkpoint_path / last_checkpoint).to(device)
      xlast = batch["x"][:, -1] # `out` from the previous run concated on line below
      start_step = get_chkpt_num(last_checkpoint)+1
    else:
      print("No checkpoints found, starting from scratch")
      xlast = batch["x"][:, 1]
      start_step = 0

    batch["lead_time"] = batch["lead_time"][..., 0]

    # Save the masking ratio to be restored later
    mask_ratio_tmp = model.mask_ratio_inputs

    for step in range(start_step, nsteps_extended):
        print(f"Starting step {step}/{nsteps_extended}...")

        # After first step, turn off masking
        if step > 0:
            model.mask_ratio_inputs = 0.0

        # modulo step based on nsteps to cyclically take from
        # available data. normally for loop above would exit
        # at nsteps, but since we're pushing it, we need to
        # wrap around
        batch["static"] = batch["statics"][:, step % nsteps]
        batch["climate"] = batch["climates"][:, step % nsteps]
        batch["y"] = batch["ys"][:, step % nsteps]

        out = model(batch)

        batch["x"] = torch.cat((xlast[:, None], out[:, None]), dim=1)
        xlast = out

        # save checkpoint
        print(f"{step}/{nsteps_extended}")
        if step % steps_per_checkpoint == 0:
          print(f"Saving checkpoint {step}...")
          torch.save(batch["x"], checkpoint_path / f'step_{step}.pt')

    # Restore the masking ratio
    model.mask_ratio_inputs = mask_ratio_tmp

Checkpoints: ['step_0.pt', 'step_1.pt', 'step_2.pt', 'step_3.pt', 'step_4.pt', 'step_5.pt', 'step_6.pt', 'step_7.pt', 'step_8.pt', 'step_9.pt', 'step_10.pt', 'step_11.pt', 'step_12.pt', 'step_13.pt', 'step_14.pt', 'step_15.pt', 'step_16.pt', 'step_17.pt', 'step_18.pt', 'step_19.pt', 'step_20.pt', 'step_21.pt', 'step_22.pt', 'step_23.pt', 'step_24.pt', 'step_25.pt', 'step_26.pt', 'step_27.pt']
Loading checkpoint: step_27.pt


  batch["x"] = torch.load(checkpoint_path / last_checkpoint).to(device)


Starting step 28/121...
28/121
Saving checkpoint 28...
Starting step 29/121...


In [None]:
# Shape of the last prediction produced by the rollout loop above.
out.shape

## Plotting

In [None]:
# Map of channel 12 of the last rollout prediction. Index 12 is "T2M" in
# `surface_vars`. NOTE(review): this assumes the leading output channels follow
# the `surface_vars` order; confirm against the model's channel layout.
t2m = out[0, 12].cpu().numpy()

# Regular lat/lon grid spanning the output's spatial dimensions.
lat = np.linspace(-90, 90, out.shape[-2])
lon = np.linspace(-180, 180, out.shape[-1])
X, Y = np.meshgrid(lon, lat)

fig, ax = plt.subplots()
contours = ax.contourf(X, Y, t2m, 100)
fig.colorbar(contours, ax=ax, shrink=0.6, label="T2M")
ax.set_aspect("equal")
ax.set(xlabel="Longitude (deg)", ylabel="Latitude (deg)", title="T2M, last rollout step")
plt.show()

In [None]:
import concurrent.futures
import re
import torch
from tqdm import tqdm


def process_checkpoint(index, checkpoint_file):
    """Load one rollout checkpoint and return ``(index, mean of channel 12)``.

    Channel 12 (T2M in ``surface_vars`` order) of the newest time slice
    (``[:, -1]``) is averaged over all remaining dimensions.
    NOTE(review): this is a plain grid-point mean, not an area-weighted
    (cos-latitude) global mean, so high-latitude points are over-weighted.
    """
    # Checkpoints are plain tensors. Load them safely and onto the CPU so the
    # worker threads neither allocate GPU memory nor need a GPU.
    checkpoint = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
    global_avg_temp = checkpoint[:, -1, 12].mean().item()
    return index, global_avg_temp


def _checkpoint_step(path):
    """Return N for a checkpoint file named 'step_N.pt'."""
    return int(re.fullmatch(r"step_(\d+)\.pt", path.name).group(1))


# Sort numerically by step. glob() order is filesystem-dependent, and a plain
# string sort would put step_10 before step_2. The timestamps built in a later
# cell assume chronological order.
checkpoint_files = sorted(
    (f for f in checkpoint_path.glob("step_*.pt") if re.fullmatch(r"step_\d+\.pt", f.name)),
    key=_checkpoint_step,
)
global_avg_temps = [None] * len(checkpoint_files)

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [
        executor.submit(process_checkpoint, i, checkpoint_file)
        for i, checkpoint_file in enumerate(checkpoint_files)
    ]
    for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
        # Results arrive in completion order; store each at its checkpoint's index.
        index, global_avg_temp = future.result()
        global_avg_temps[index] = global_avg_temp

In [None]:
# First few per-checkpoint means, in checkpoint order.
global_avg_temps[:5]

In [None]:
# One timestamp per checkpoint. Checkpoints are `steps_per_checkpoint` rollout
# steps apart, and each step advances abs(input_time) hours.
# NOTE(review): the first checkpoint is placed at time_range[0] + abs(input_time);
# confirm this is the valid time of the step-0 prediction for
# Merra2RolloutDataset (it may be one step later).
start_datetime = datetime.strptime(time_range[0], "%Y-%m-%dT%H:%M:%S") + timedelta(hours=abs(input_time))
checkpoint_interval = timedelta(hours=abs(input_time) * steps_per_checkpoint)
# Built from the list length so `timestamps` always has exactly one entry per
# value in `global_avg_temps`, including when there are no checkpoints.
timestamps = [start_datetime + i * checkpoint_interval for i in range(len(global_avg_temps))]


In [None]:
import pandas as pd

# Time-indexed table of the per-checkpoint means.
df = (
    pd.DataFrame({'timestamp': timestamps, 'global_avg_temp': global_avg_temps})
    .assign(timestamp=lambda frame: pd.to_datetime(frame['timestamp']))
    .set_index('timestamp')
)

# A 30-day moving average could be added with:
# df['30_day_moving_avg'] = df['global_avg_temp'].rolling(window='30D').mean()
df.head()

In [None]:
# Time series of the per-checkpoint means, drawn on an explicit figure/axes.
fig, ax = plt.subplots()
ax.plot(timestamps, global_avg_temps)
ax.set_xlabel("Timestamps")
ax.set_ylabel("Global Average Temperatures")
ax.set_title("Global Average Temperature Over Time")
plt.setp(ax.get_xticklabels(), rotation=45)
fig.tight_layout()
plt.show()