In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
from tqdm.cli import tqdm
import os
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from torchdiffeq import odeint as odeint
import argparse
import torch
from pathlib import Path

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import xarray as xr

In [None]:
torch.manual_seed(42)

# data_path = {'z500':str(cwd) + '/era5_data/geopotential_500/*.nc','t850':str(cwd) + '/era5_data/temperature_850/*.nc'}
SOLVERS = [
    "dopri8",
    "dopri5",
    "bdf",
    "rk4",
    "midpoint",
    "adams",
    "explicit_adams",
    "fixed_adams",
    "adaptive_heun",
    "euler",
]
parser = argparse.ArgumentParser("ClimODE")

parser.add_argument("--solver", type=str, default="euler", choices=SOLVERS)
parser.add_argument("--atol", type=float, default=5e-3)
parser.add_argument("--rtol", type=float, default=5e-3)
parser.add_argument(
    "--step_size", type=float, default=None, help="Optional fixed step size."
)
parser.add_argument("--niters", type=int, default=100)
parser.add_argument("--scale", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=6)
parser.add_argument("--spectral", type=int, default=0, choices=[0, 1])
parser.add_argument("--lr", type=float, default=0.0005)
parser.add_argument("--weight_decay", type=float, default=1e-5)

args = parser.parse_args(
    "--scale 0 --batch_size 6 --spectral 0 --solver euler".split()
)
assert torch.cuda.is_available()
device = torch.device("cuda")
args

train_time_scale = slice("2006", "2016")
val_time_scale = slice("2016", "2016")
test_time_scale = slice("2017", "2018")
paths_to_data = [
    "era5_data/geopotential_500/*.nc",
    "era5_data/temperature_850/*.nc",
    "era5_data/2m_temperature/*.nc",
    "era5_data/10m_u_component_of_wind/*.nc",
    "era5_data/10m_v_component_of_wind/*.nc",
]
const_info_path = ["era5_data/constants/constants/constants_5.625deg.nc"]
levels = ["z", "t", "t2m", "u10", "v10"]

assert len(paths_to_data) == len(
    levels
), "Paths to different type of data must be same as number of types of observations"

Final_train_data = 0
Final_val_data = 0
Final_test_data = 0
max_lev = []
min_lev = []

for idx, data in enumerate(tqdm(paths_to_data, desc="reading data")):
    Train_data, Val_data, Test_data, time_steps, lat, lon, mean, std, time_stamp = (
        get_train_test_data_without_scales_batched(
            data,
            train_time_scale,
            val_time_scale,
            test_time_scale,
            levels[idx],
            args.spectral,
        )
    )
    max_lev.append(mean)
    min_lev.append(std)
    if idx == 0:
        Final_train_data = Train_data
        Final_val_data = Val_data
        Final_test_data = Test_data
    else:
        Final_train_data = torch.cat([Final_train_data, Train_data], dim=2)
        Final_val_data = torch.cat([Final_val_data, Val_data], dim=2)
        Final_test_data = torch.cat([Final_test_data, Test_data], dim=2)

print("train, val, test data shapes:")
print(Final_train_data.shape, Final_test_data.shape, Final_val_data.shape)

const_channels_info, lat_map, lon_map = add_constant_info(const_info_path)
H, W = Train_data.shape[3], Train_data.shape[4]
Train_loader = DataLoader(
    Final_train_data[2:],
    batch_size=args.batch_size,
    shuffle=False,
    pin_memory=False,
)
Val_loader = DataLoader(
    Final_val_data[2:], batch_size=args.batch_size, shuffle=False
)
Test_loader = DataLoader(
    Final_test_data[2:], batch_size=args.batch_size, shuffle=False
)
time_loader = DataLoader(time_steps[2:], batch_size=args.batch_size, shuffle=False)
time_idx_steps = torch.tensor([i for i in range(365 * 4)]).view(-1, 1)
time_idx = DataLoader(time_idx_steps[2:], batch_size=args.batch_size, shuffle=False)
num_years = len(range(2006, 2016))

if not Path("kernel.npy").exists():
    get_gauss_kernel((32, 64), lat, lon)

kernel = torch.from_numpy(np.load("kernel.npy"))
if not Path("test_10year_2day_mm_vel.npy").exists():
    print("Fitting velocity...")
    fit_velocity(
        time_idx,
        time_loader,
        Final_train_data,
        Train_loader,
        device,
        num_years,
        paths_to_data,
        args.scale,
        H,
        W,
        types="train_10year_2day_mm",
        vel_model=Optim_velocity,
        kernel=kernel,
    )
    fit_velocity(
        time_idx,
        time_loader,
        Final_val_data,
        Val_loader,
        device,
        1,
        paths_to_data,
        args.scale,
        H,
        W,
        types="val_10year_2day_mm",
        vel_model=Optim_velocity,
        kernel=kernel,
    )
    fit_velocity(
        time_idx,
        time_loader,
        Final_test_data,
        Test_loader,
        torch.device("cuda"),
        2,
        paths_to_data,
        args.scale,
        H,
        W,
        types="test_10year_2day_mm",
        vel_model=Optim_velocity,
        kernel=kernel,
    )

vel_train, vel_val = load_velocity(["train_10year_2day_mm", "val_10year_2day_mm"])
