In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import torch
import os
import datasets
import warnings
from tqdm.cli import tqdm
import os
from torch.utils.data import DataLoader
import torch.nn.functional as Fin
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from torchdiffeq import odeint as odeint
import matplotlib
import argparse
import torch
from pathlib import Path

torch.manual_seed(42)

cwd = os.getcwd()
# data_path = {'z500':str(cwd) + '/era5_data/geopotential_500/*.nc','t850':str(cwd) + '/era5_data/temperature_850/*.nc'}
SOLVERS = [
    "dopri8",
    "dopri5",
    "bdf",
    "rk4",
    "midpoint",
    "adams",
    "explicit_adams",
    "fixed_adams",
    "adaptive_heun",
    "euler",
]
parser = argparse.ArgumentParser("ClimODE")

parser.add_argument("--solver", type=str, default="euler", choices=SOLVERS)
parser.add_argument("--atol", type=float, default=5e-3)
parser.add_argument("--rtol", type=float, default=5e-3)
parser.add_argument(
    "--step_size", type=float, default=None, help="Optional fixed step size."
)
parser.add_argument("--niters", type=int, default=300)
parser.add_argument("--scale", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=6)
parser.add_argument("--spectral", type=int, default=0, choices=[0, 1])
parser.add_argument("--lr", type=float, default=0.0005)
parser.add_argument("--weight_decay", type=float, default=1e-5)


args = parser.parse_args("--scale 0 --batch_size 6 --spectral 0 --solver euler".split())
assert torch.cuda.is_available()
device = torch.device("cuda")
args

  from .autonotebook import tqdm as notebook_tqdm


Namespace(solver='euler', atol=0.005, rtol=0.005, step_size=None, niters=300, scale=0, batch_size=6, spectral=0, lr=0.0005, weight_decay=1e-05)

In [3]:
train_time_scale = slice("2006", "2016")
val_time_scale = slice("2016", "2016")
test_time_scale = slice("2017", "2018")
paths_to_data = [
    "era5_data/geopotential_500/*.nc",
    "era5_data/temperature_850/*.nc",
    "era5_data/2m_temperature/*.nc",
    "era5_data/10m_u_component_of_wind/*.nc",
    "era5_data/10m_v_component_of_wind/*.nc",
]
const_info_path = ["era5_data/constants/constants/constants_5.625deg.nc"]
levels = ["z", "t", "t2m", "u10", "v10"]

assert len(paths_to_data) == len(
    levels
), "Paths to different type of data must be same as number of types of observations"

In [9]:
import xarray as xr


def get_batched(train_times, data_train_final, lev):
    for idx, year in enumerate(train_times):
        data_per_year = data_train_final.sel(time=slice(str(year), str(year))).load()
        data_values = data_per_year[lev].values
        if idx == 0:
            train_data = torch.from_numpy(data_values).reshape(
                -1, 1, 1, data_values.shape[-2], data_values.shape[-1]
            )
            if year % 4 == 0:
                train_data = torch.cat(
                    (train_data[:236], train_data[240:])
                )  # skipping 29 feb in leap year
        else:
            mid_data = torch.from_numpy(data_values).reshape(
                -1, 1, 1, data_values.shape[-2], data_values.shape[-1]
            )
            if year % 4 == 0:
                mid_data = torch.cat(
                    (mid_data[:236], mid_data[240:])
                )  # skipping 29 feb in leap year
            train_data = torch.cat([train_data, mid_data], dim=1)

    return train_data


def get_train_test_data_without_scales_batched(
    data_path, train_time_scale, val_time_scale, test_time_scale, lev, spectral
):
    data = xr.open_mfdataset(data_path, combine="by_coords")
    # data = data.isel(lat=slice(None, None, -1))
    if lev in ["v", "u", "r", "q", "tisr"]:
        data = data.sel(level=500)
    data = data.resample(time="6h").nearest(
        tolerance="1h"
    )  # Setting data to be 6-hour cycles
    data_train = data.sel(time=train_time_scale).load()
    data_val = data.sel(time=val_time_scale).load()
    data_test = data.sel(time=test_time_scale).load()
    data_global = data.sel(time=slice("2006", "2018")).load()

    max_val = data_global.max()[lev].values.tolist()
    min_val = data_global.min()[lev].values.tolist()

    data_train_final = (data_train - min_val) / (max_val - min_val)
    data_val_final = (data_val - min_val) / (max_val - min_val)
    data_test_final = (data_test - min_val) / (max_val - min_val)

    time_vals = data_test_final.time.values
    train_times = [i for i in range(2006, 2016)]
    test_times = [2017, 2018]
    val_times = [2016]

    train_data = get_batched(train_times, data_train_final, lev)
    test_data = get_batched(test_times, data_test_final, lev)
    val_data = get_batched(val_times, data_val_final, lev)

    t = [i for i in range(365 * 4)]
    time_steps = torch.tensor(t).view(-1, 1)
    return (
        train_data,
        val_data,
        test_data,
        time_steps,
        data.lat.values,
        data.lon.values,
        max_val,
        min_val,
        time_vals,
    )


Final_train_data = 0
Final_val_data = 0
Final_test_data = 0
max_lev = []
min_lev = []

for idx, data in enumerate(tqdm(paths_to_data, desc="reading data")):
    Train_data, Val_data, Test_data, time_steps, lat, lon, mean, std, time_stamp = (
        get_train_test_data_without_scales_batched(
            data,
            train_time_scale,
            val_time_scale,
            test_time_scale,
            levels[idx],
            args.spectral,
        )
    )
    max_lev.append(mean)
    min_lev.append(std)
    if idx == 0:
        Final_train_data = Train_data
        Final_val_data = Val_data
        Final_test_data = Test_data
    else:
        Final_train_data = torch.cat([Final_train_data, Train_data], dim=2)
        Final_val_data = torch.cat([Final_val_data, Val_data], dim=2)
        Final_test_data = torch.cat([Final_test_data, Test_data], dim=2)

print("train, val, test data shapes:")
print(Final_train_data.shape, Final_test_data.shape, Final_val_data.shape)

reading data:   0%|          | 0/5 [00:00<?, ?it/s]

reading data: 100%|██████████| 5/5 [00:20<00:00,  4.04s/it]

train, val, test data shapes:
torch.Size([1460, 10, 5, 32, 64]) torch.Size([1460, 2, 5, 32, 64]) torch.Size([1460, 1, 5, 32, 64])





In [10]:
def add_constant_info(path):
    data = xr.open_mfdataset(path, combine="by_coords")
    for idx, var in enumerate(["orography", "lsm"]):
        var_value = torch.from_numpy(data[var].values).view(1, 1, 32, 64)
        if idx == 0:
            final_var = var_value
        else:
            final_var = torch.cat([final_var, var_value], dim=1)

    return (
        final_var,
        torch.from_numpy(data["lat2d"].values),
        torch.from_numpy(data["lon2d"].values),
    )


const_channels_info, lat_map, lon_map = add_constant_info(const_info_path)
H, W = Train_data.shape[3], Train_data.shape[4]
Train_loader = DataLoader(
    Final_train_data[2:], batch_size=args.batch_size, shuffle=False, pin_memory=False
)
Val_loader = DataLoader(
    Final_val_data[2:], batch_size=args.batch_size, shuffle=False, pin_memory=False
)
Test_loader = DataLoader(
    Final_test_data[2:], batch_size=args.batch_size, shuffle=False, pin_memory=False
)
time_loader = DataLoader(
    time_steps[2:], batch_size=args.batch_size, shuffle=False, pin_memory=False
)
time_idx_steps = torch.tensor([i for i in range(365 * 4)]).view(-1, 1)
time_idx = DataLoader(
    time_idx_steps[2:], batch_size=args.batch_size, shuffle=False, pin_memory=False
)

# Model declaration
num_years = len(range(2006, 2016))

In [8]:
model = Climate_encoder_free_uncertain(
    len(paths_to_data),
    2,
    out_types=len(paths_to_data),
    method=args.solver,
    use_att=True,
    use_err=True,
    use_pos=False,
).to(device)

(torch.Size([1460, 10, 5, 32, 64]),
 torch.Size([1460, 2, 5, 32, 64]),
 torch.Size([1460, 1, 5, 32, 64]))