### climatology
```
variables = [
 '2m_temperature',
 '10m_u_component_of_wind',
 '10m_v_component_of_wind',
 'mean_sea_level_pressure',
 'geopotential',
 'temperature',
 'specific_humidity',
 'u_component_of_wind',
 'v_component_of_wind',
 'dayofyear',
 'hour',
 'latitude',
 'level',
 'longitude'
]

levels = [500, 700, 850]
```
### test
```
levels =[   
    1,    2,    3,    5,    7,   10,   20,   30,   50,   70,  100,
    125,  150,  175,  200,  225,  250,  300,  350,  400,  450,  500,
    550,  600,  650,  700,  750,  775,  800,  825,  850,  875,  900,
    925,  950,  975, 1000   
]
```

In [1]:
import os
os.sys.path.append('/localhome/prateiksinha/atmos-arena/atmos_arena/atmos_utils')
import numpy as np
from pprint import pprint
import netCDF4 as nc
from metrics import *
import torch
import csv
import pandas as pd
import xarray as xr

In [2]:
test_lvls = [   
      1,    2,    3,    5,    7,   10,   20,   30,   50,   70,  100,
    125,  150,  175,  200,  225,  250,  300,  350,  400,  450,  500,
    550,  600,  650,  700,  750,  775,  800,  825,  850,  875,  900,
    925,  950,  975, 1000   
]
test_lvl2idx = {v:i for i,v in enumerate(test_lvls)}

def load_clim(var, lvl=None):
    file_path = '/localhome/data/datasets/climate/climatology.nc'
    dataset = nc.Dataset(file_path, mode='r')
    variables = list(dataset.variables.keys())
    if lvl:
        return dataset[var][:,:,clim_lvl2idx[lvl],:,:]
    else:
        return dataset[var]

def load_test(var, year, lvl=None):
    test_path = f'/localhome/data/datasets/climate/era5_corpenicus/{var}/{year}.nc'
    dataset = nc.Dataset(test_path, mode='r')
    var_abbrv = list(dataset.variables.keys())[-1]
    if lvl:
        return (
            dataset.variables[var_abbrv][:,test_lvl2idx[lvl],:,:], 
            dataset['latitude']
        )
    else:
        return dataset.variables[var_abbrv], dataset['latitude']


In [3]:
device = 'cuda'
log_dir = '/localhome/prateiksinha/atmos-arena/atmos_arena/s2s_stormer/clim_data'

variables = [
    # '2m_temperature',
    'temperature',
    'specific_humidity',
    'geopotential'
]
levels = [
    # None,
    850,
    700,
    500
]
years = ['2020', '2021', '2022', '2023']

for var, lvl in zip(variables, levels):
    
    clim = load_clim(var, lvl)
    clim = np.array(clim).swapaxes(0,1).reshape((1464, 721, 1440))
    clim = torch.tensor(clim).to(device)

    for year in years:
        test, lat = load_test(var, year, lvl)
        test, lat = np.array(test), np.array(lat)
        test = torch.tensor(test).to(device)

        # print(clim.shape)
        # print(test.shape)

        window_size = 14 * 4

        csv_path = f'{log_dir}/{year}_{var}.csv'
        
        file = open(csv_path, mode='w', newline='')
        writer = csv.writer(file)
        writer.writerow(['year', 'variable', 'start', 'mse', 'weight_mse', 'weighted_rmse']) 

        with torch.no_grad():
            for i in range(1464 - window_size):
                clim_window = torch.mean(clim[i:i+window_size,:,:], dim=0)[None,None,:,:]
                test_window = torch.mean(test[i:i+window_size,:,:], dim=0)[None,None,:,:]
                
                mse_ = mse(clim_window, test_window, [var])['loss'].item()
                l_mse = lat_weighted_mse(clim_window, test_window, [var], lat)['loss'].item()
                l_rmse = lat_weighted_rmse(clim_window, test_window, lambda x:x, [var], lat, None)['w_rmse']
                writer.writerow([year, var, i, mse_, l_mse, l_rmse])    
                print(i, end = '\r')
        file.close()

In [5]:
csvs = []
variables = [
    'temperature',
]
years = ['2020', '2021', '2022', '2023']
log_dir = '/localhome/prateiksinha/atmos-arena/atmos_arena/s2s_stormer/clim_data'

for var in variables:
    print(var,':')
    for year in years:
        csvs.append(f'{log_dir}/{year}_{var}.csv')
    all_years = pd.concat([pd.read_csv(x) for x in csvs])
    for metric in ['mse', 'weight_mse', 'weighted_rmse']:
        print(f'{metric}: {all_years[metric].mean()}')
    print()
    csvs = []    

temperature :
mse: 6.6127676906383845
weight_mse: 4.730860434080895
weighted_rmse: 2.001581556030883

