In [1]:
import os
import sys
import logging
import torch
import xarray as xr
import pandas as pd
import numpy as np
from tqdm import tqdm
from pathlib import Path
from sklearn.metrics import r2_score
from torch.utils.data import DataLoader
# Put the parent of the current working directory on the import path so the
# `src.*` imports that follow can be resolved.
project_root = Path.cwd().parent
_root_entry = str(project_root)
if _root_entry not in sys.path:
    sys.path.append(_root_entry)
from src.data_preprocessing.normalization import XarrayNormalizer
from src.data_preprocessing.create_dataloaders import SweDataset
from src.data_preprocessing.split_data import split_by_time
from src.utils.utils import (write_to_netcdf, data_split, unscale_pred, load_scalers)
from src.training.trainer import evaluate_model
from src.models.swe_net import SWE_NET

# DATA SETUP

In [2]:
# setup test set for prediction

# Number of consecutive time steps handed to SweDataset for each sample.
sequence_length = 5

# The prediction window opens `sequence_length` days after 2015-10-01
# (presumably because the first `sequence_length` days only serve as model
# input -- confirm against SweDataset).
# Date arithmetic is used instead of pasting the day number into a string so
# the result is always a valid, zero-padded ISO date, even when the offset
# runs past the end of October.
_first_day = pd.Timestamp('2015-10-01')
start = (_first_day + pd.Timedelta(days=sequence_length)).strftime('%Y-%m-%d')
end = '2016-09-30'

# Reference values: the 'SNOW' variable returned by data_split('2016', '2016'),
# restricted along XTIME to the prediction window [start, end].
prediction_window = slice(start, end)
snow_reference = data_split('2016', '2016')['SNOW']
true_labels = snow_reference.sel(XTIME=prediction_window)

# Displayed as the cell output.
true_labels.shape

(361, 390, 348)

In [3]:
# Directory holding the input NetCDF files and the saved scalers.
base = '/bsuscratch/stanleyakor/swe_emulator/modis'

# Variables kept from the WRF feature file.
wrf_feature_vars = [
    'SNOWNC_CUMSUM', 'PRCP_CUMSUM', 'TMIN', 'TMAX',
    'ELEVATION', 'DAY_SIN', 'DAY_COS',
]

ds1 = xr.open_dataset(f'{base}/wrf_features_june_30_2025.nc')[wrf_feature_vars]
ds2 = xr.open_dataset(f'{base}/snowcover_june_30_2025.nc')

# Only the LAI file is cut to an explicit XTIME window in this cell.
lai_period = slice("2005-10-01", "2016-09-30")
ds3 = xr.open_dataset(f'{base}/lai_june_30_2025.nc').sel(XTIME=lai_period)

# Replace the snow-cover XTIME values with pandas datetimes before merging
# (presumably so they align with XTIME in the other datasets -- confirm).
ds2['XTIME'] = pd.to_datetime(ds2['XTIME'].values)

# Merge the three datasets, then rename `snow_presence` to the channel name
# used by the rest of the notebook.
data = xr.merge([ds1, ds2, ds3]).rename({"snow_presence": "BINARY_SNOW_CLASS"})

# ------------------ Data Splitting ------------------
# `split` is indexed below with the keys 'train' and 'test'.
split = split_by_time(data)

# ------------------ Normalization ------------------
variables = [
    'SNOWNC_CUMSUM', 'PRCP_CUMSUM', 'TMIN', 'TMAX', 'ELEVATION',
    'BINARY_SNOW_CLASS', 'LAI', 'DAY_SIN', 'DAY_COS',
]
feature_scaler_path = f"{base}/scalers/scalers.pkl"

# Fit min-max scaling on the training split and hand fit_transform a path to
# save the scalers to.
# NOTE(review): this presumably rewrites the scaler file on every run of this
# prediction notebook -- confirm that is intended.
normalizer = XarrayNormalizer(split['train'])
train_features_norm = normalizer.fit_transform(
    variables=variables,
    method="minmax",
    save_scaler_path=feature_scaler_path,
)

# Transform the test split using the scalers loaded from that same path.
normalizer_test = XarrayNormalizer(split['test'])
val_features_norm = normalizer_test.transform(
    variables=variables,
    load_scaler_path=feature_scaler_path,
)

# ------------------ Target Normalization ------------------
target_data = xr.open_dataset(f'{base}/wrf_target_june_30_2025.nc')
split_target = split_by_time(target_data)

target_vars = ['SNOW']
target_scaler_path = f"{base}/scalers/target_scalers.pkl"

# Fit min-max scaling for the target on the training split; the scaler is
# saved to `target_scaler_path`.
target_normalizer_train = XarrayNormalizer(split_target['train'])
train_target_norm = target_normalizer_train.fit_transform(
    variables=target_vars,
    method="minmax",
    save_scaler_path=target_scaler_path,
)

# Transform the test-split target with the scaler loaded from the same path.
target_normalizer_test = XarrayNormalizer(split_target['test'])
val_target_norm = target_normalizer_test.transform(
    variables=target_vars,
    load_scaler_path=target_scaler_path,
)



In [4]:
# Channel names handed to SweDataset, in this order.
channel_order = [
    'SNOWNC_CUMSUM',
    'PRCP_CUMSUM',
    'TMIN',
    'TMAX',
    'ELEVATION',
    'BINARY_SNOW_CLASS',
    'LAI',
    'DAY_SIN',
    'DAY_COS',
]

# Dataset built from the normalized test-split features and target.
test_dataset = SweDataset(
    val_features_norm,
    val_target_norm,
    sequence_length,
    channel_order,
)

# One sample per batch, iterated in dataset order (no shuffling).
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)

# Load Saved Model

In [5]:
# Use CUDA when PyTorch reports it as available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Displayed as the cell output.
device

device(type='cuda')

In [6]:
# Checkpoint file, relative to the notebook's working directory.
checkpoint_path = '../saved_models/WRF_MODIS_STATIC_v2.pth'

# NOTE(security): weights_only=False allows torch.load to unpickle arbitrary
# Python objects, so only load checkpoints from a trusted source.
checkpoint = torch.load(
    checkpoint_path,
    weights_only=False,
    map_location=device,  # place the loaded tensors on the selected device
)

In [7]:
# Build the network with the hyperparameters below and move it to `device`.
# input_dim=9 equals the number of entries in `channel_order`; height/width
# equal the last two dimensions of the `true_labels` shape shown earlier.
model = SWE_NET(
    input_dim=9,
    hidden_dim=64,
    kernel_size=(3, 3),
    height=390,
    width=348,
    dropout_rate=0.3,
).to(device)

# Restore the trained weights stored under 'model_state_dict'.
load_status = model.load_state_dict(checkpoint['model_state_dict'])

# This notebook only runs inference, so switch to evaluation mode explicitly:
# the model is built with dropout_rate=0.3, and dropout left in training mode
# would make predictions stochastic. This is a no-op if the prediction helper
# already calls eval().
model.eval()

# Displayed as the cell output (reports whether all keys matched).
load_status

<All keys matched successfully>

# Make Prediction

In [8]:
# Path of the saved target scaler, passed to unscale_pred.
conf_path = f"{base}/scalers/target_scalers.pkl"

# Run the model over the test loader; the result is named `unscaled_preds`
# (presumably predictions mapped back to the original SNOW units via the
# target scaler -- confirm in src.utils.utils.unscale_pred).
unscaled_preds = unscale_pred(model, test_loader, conf_path)

100%|██████████████████████████████████████████████████████████████████████████| 361/361 [00:27<00:00, 13.32it/s]


In [9]:
# WRF output file used for grid coordinates; keep index 0 along `Time`.
static = xr.open_dataset('/bsuscratch/stanleyakor/uppercolorado/static_inputs/wrfout_d02_2000-04-08_00:00:00').isel(Time=0)

# 1-D coordinate vectors: first column of XLAT and first row of XLONG.
# NOTE(review): this assumes latitude varies only along the first axis and
# longitude only along the second (a regular lat/lon grid) -- confirm for
# this WRF domain.
lat = static.XLAT.values[:, 0]
lon = static.XLONG.values[0, :]

# Reuse the window that was used to select `true_labels` instead of rebuilding
# the date strings here, so predictions and labels cannot end up on different
# time axes if the window is edited.
start_date = start
end_date = end

# NOTE(review): "PREDITICTION" looks like a misspelling of "PREDICTION"; the
# name is left unchanged because other code may read this exact path.
output_file = '../data/WRF_MODIS_STATIC_v2_PREDITICTION.nc'

write_to_netcdf(output_file, start_date, end_date, lat, lon, unscaled_preds)