In [1]:
import xarray as xr
import pandas as pd
import os
import pickle
import imageio
from PIL import Image
import seaborn as sns
import numpy as np
import deepsensor
import cartopy.crs as ccrs
from tqdm.notebook import tqdm
import cartopy.feature as cfeature
import deepsensor.torch
import matplotlib.pyplot as plt
import geopandas as gpd
from shapely.geometry import Point
from tqdm import tqdm_notebook
from deepsensor.model import ConvNP
from tqdm import tqdm_notebook
from matplotlib.animation import FuncAnimation, PillowWriter
from deepsensor.data import DataProcessor, TaskLoader
from deepsensor.train import set_gpu_default_device, Trainer

In [2]:
# deepsensor.train helper (imported above); judging by its name it selects the GPU as the
# default torch device — NOTE(review): confirm its behaviour on CPU-only machines.
set_gpu_default_device()

#### Read the monthly NetCDF files and merge them into a single xarray Dataset

In [3]:
def extract_timestamp(file):
    """Return the first-of-month timestamp encoded at the end of a monthly file name.

    The name must end in ``YYYYMM.nc`` (e.g. ``...199801.nc`` -> 1998-01-01).
    Leading directory components are ignored, so a full path is accepted too.

    Parameters
    ----------
    file : str
        File name (or path) of a monthly NetCDF file.

    Returns
    -------
    numpy.datetime64
        Nanosecond-precision timestamp at 00:00 on the first day of that month.

    Raises
    ------
    ValueError
        If the name does not end in six ASCII digits followed by ``.nc``, or the
        digits do not form a valid year/month.
    """
    stem, ext = os.path.splitext(os.path.basename(file))
    yyyymm = stem[-6:]
    # Validate explicitly instead of relying on fixed character offsets, so an
    # unexpected file (e.g. `.DS_Store`) fails with a clear message.
    if ext != '.nc' or len(yyyymm) != 6 or not (yyyymm.isascii() and yyyymm.isdigit()):
        raise ValueError(f"cannot parse a YYYYMM date from file name {file!r}")
    # numpy itself raises ValueError for an out-of-range month such as 13.
    return np.datetime64(f'{yyyymm[:4]}-{yyyymm[4:]}-01 00:00:00.000000000', 'ns')

# Load every monthly file under Monthly_01/<folder>/, tag it with its month and
# stack everything along a new 'Timestamp' dimension.
dir_path = 'Monthly_01'
combined_ds = []
for folder in sorted(os.listdir(dir_path)):
    folder_path = os.path.join(dir_path, folder)
    if not os.path.isdir(folder_path):  # skip stray files next to the sub-folders
        continue
    for file in sorted(os.listdir(folder_path)):
        if not file.endswith('.nc'):  # skip non-NetCDF entries (e.g. .DS_Store)
            continue
        timestamp = extract_timestamp(file)
        # Load into memory and close the file right away instead of leaving one
        # lazily-backed dataset (and its file handle) alive per month.
        with xr.open_dataset(os.path.join(folder_path, file)) as monthly_ds:
            combined_ds.append(monthly_ds.load().expand_dims({'Timestamp': [timestamp]}))
wustl_monthly_ds = xr.concat(combined_ds, dim='Timestamp').sortby('Timestamp')

# Keep 1998-2019 only, i.e. drop the COVID-affected years (2020 onwards).
wustl_monthly_ds = wustl_monthly_ds.sel(Timestamp=slice("1998", "2019"))

In [4]:
# Work on a shallow copy of the merged dataset (deep=False is Dataset.copy's default).
wustl = wustl_monthly_ds.copy(deep=False)
wustl

In [5]:
# Bounding box (degrees) used to crop the dataset to the Indian region; both ends inclusive.
india_lat_min, india_lat_max = 5, 39
india_lon_min, india_lon_max = 67, 99


def _range_slice(index, lo, hi):
    """Return a label slice covering [lo, hi] on a sorted index, or None if unsorted."""
    if index.is_monotonic_increasing:
        return slice(lo, hi)
    if index.is_monotonic_decreasing:
        return slice(hi, lo)  # label slices on descending coords must run high -> low
    return None


lat_slice = _range_slice(wustl.indexes['lat'], india_lat_min, india_lat_max)
lon_slice = _range_slice(wustl.indexes['lon'], india_lon_min, india_lon_max)

if lat_slice is not None and lon_slice is not None:
    # Cheap label-based crop; works for ascending and descending coordinates alike.
    wustl_india = wustl.sel(lat=lat_slice, lon=lon_slice)
else:
    # Fallback for unsorted coordinates: mask by value and drop the labels outside the box.
    wustl_india = wustl.where(
        (wustl.lat >= india_lat_min) & (wustl.lat <= india_lat_max)
        & (wustl.lon >= india_lon_min) & (wustl.lon <= india_lon_max),
        drop=True,
    )

In [6]:
# Later cells select on `time`, so rename the stacking dimension.
# Guarded so that re-running this cell does not raise once the rename has happened.
if 'Timestamp' in wustl_india.dims:
    wustl_india = wustl_india.rename({'Timestamp': 'time'})
wustl_india

In [15]:
# Write the cropped India dataset to a NetCDF file in the working directory.
wustl_india_path = 'wustl_dataset.nc'
wustl_india.to_netcdf(wustl_india_path)

In [7]:
# Training + validation period: 1998-2010 (year-string label slices are inclusive).
train_val_data = wustl_india.sel(time=slice('1998', '2010'))
# Show the first and last month actually covered by the selection.
time_index = train_val_data.indexes['time']
time_index.min(), time_index.max()

(Timestamp('1998-01-01 00:00:00'), Timestamp('2010-12-01 00:00:00'))

In [8]:
# Chronological split: 1998-2008 for training, 2009-2010 held out for validation.
train_data, val_data = (
    train_val_data.sel(time=slice(start, end))
    for start, end in (('1998', '2008'), ('2009', '2010'))
)
# Number of monthly time steps in each split.
train_data.sizes['time'], val_data.sizes['time']

(132, 24)

In [9]:
# Output folder for the scaled datasets, the DataProcessor config and the pickled validation tasks.
save_dir = "./data"

In [10]:
# Normalise the data with deepsensor's DataProcessor, using lat/lon as the x1/x2 coordinates.
# Call order matters: the first call behaves like fit_transform (on the training split),
# later calls like transform, reusing the parameters fitted on the training data.
# NOTE(review): this cell warns that the inferred x1/x2 maps span different ranges;
# passing explicit x1_map/x2_map of equal extent may avoid that — confirm against the
# DataProcessor docs first, as it changes the coordinate scaling.
data_processor = DataProcessor(x2_name="lon", x1_name="lat")
scaled_train_data = data_processor(train_data)  # fit + transform
scaled_val_data = data_processor(val_data)      # transform only

  f"x1_map={x1_map} and x2_map={x2_map} have different ranges ({float(np.diff(x1_map))} "
  f"and {float(np.diff(x2_map))}, respectively). "


In [None]:
# Persist the scaled splits and the fitted DataProcessor configuration under save_dir.
# The folder must exist first: to_netcdf does not create missing parent directories.
os.makedirs(save_dir, exist_ok=True)
scaled_train_data.to_netcdf(f"{save_dir}/scaled_train_data.nc")
scaled_val_data.to_netcdf(f"{save_dir}/scaled_val_data.nc")
data_processor.save(f"{save_dir}/data_processor_config/")

In [12]:
# Validation TaskLoader: the scaled validation split serves as both the context
# set and the target set.
val_task_loader = TaskLoader(
    target=scaled_val_data,
    context=scaled_val_data,
)
val_task_loader

TaskLoader(1 context sets, 1 target sets)
Context variable IDs: (('PM25',),)
Target variable IDs: (('PM25',),)

Context data dimensions: (1,)
Target data dimensions: (1,)

In [13]:
# Pre-generate one pickled list of validation tasks (one task per validation month)
# for every (n_context, seed) pair, written to {save_dir}/val_tasks/n_context=<n>/seed=<s>.pkl.
n_context_list = [5, 20, 50, 100, 200, 500]
seeds = [0, 1, 2, 3, 4]
np.random.seed(0)  # global numpy seed; the per-task seeds below are passed explicitly
checksum = 0
# The context manager closes the progress bar even if task generation or pickling fails.
with tqdm(total=len(n_context_list) * len(seeds)) as progress_bar:
    for n_context in n_context_list:
        for seed in seeds:
            tasks = [
                val_task_loader(
                    timestamp,
                    context_sampling=n_context,
                    target_sampling="all",
                    # seed + n_context is unique across the two lists above and the
                    # timestamps (ns since epoch) differ by whole months, so every task
                    # gets its own reproducible seed.
                    seed_override=seed + int(timestamp) + n_context,
                )
                for timestamp in scaled_val_data.time.values
            ]
            save_path = f"{save_dir}/val_tasks/{n_context=}/{seed=}.pkl"
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            # Local, self-generated output only — never pickle.load such files from untrusted sources.
            with open(save_path, "wb") as f:
                pickle.dump(tasks, f)
            progress_bar.update(1)
            # Sanity checksum over the first task of each file.
            # NOTE(review): with target_sampling="all" this presumably sums target
            # coordinates, which do not depend on the sampling seed; summing X_c as well
            # would also catch changes in context sampling — confirm before relying on it.
            checksum += tasks[0]['X_t'][0][0].sum()

print(checksum)

  0%|          | 0/30 [00:00<?, ?it/s]

5099.9995
