In [None]:
import os
import torch
import mat73
from tqdm import tqdm
import numpy as np

def load_mat_files(path, order = None):
    """Load ``.mat`` files from a directory with ``mat73.loadmat``.

    Args:
        path: directory that contains the files.
        order: optional iterable of file names to load, in this order. When omitted,
            every entry of ``path`` is considered, sorted by name.

    Returns:
        dict mapping each loaded file name to the dict returned by ``mat73.loadmat``.
        Names that do not end in ``.mat`` are skipped silently.
    """
    if order is None:
        # os.listdir returns entries in arbitrary order; sort so the key order of the
        # returned dict (and everything derived from it) is reproducible.
        order = sorted(os.listdir(path))

    data_dict = {}
    for filename in tqdm(order):
        if filename.endswith('.mat'):
            data_dict[filename] = mat73.loadmat(os.path.join(path, filename))
    return data_dict

# Configuration: raw .mat inputs are read from `raw_data_path`; every processed tensor
# is written under `savepath`.
raw_data_path = './raw'
savepath = "./Processed"
# torch.save does not create missing parent directories, so create the output folder up
# front; otherwise Restart & Run All fails on a fresh checkout at the first save below.
os.makedirs(savepath, exist_ok=True)

# Load the split mask: 1.0 marks train pixels, 2.0 marks test pixels, and NaN marks pixels
# that are excluded; `not_nan_mask` is the pixel selector passed to the processing cells.
mask_file = mat73.loadmat(os.path.join(raw_data_path, "Mask_US.mat"))
mask = mask_file['TrainMask_US'].astype(np.float32)
torch.save(torch.tensor(mask), os.path.join(savepath, 'mask.pt'))
train_mask = mask == 1.0
test_mask = mask == 2.0
not_nan_mask = ~np.isnan(mask)
print(f'train data: {np.sum(train_mask)}, test data: {np.sum(test_mask)}, all data: {np.sum(not_nan_mask)}')

### Drought Indices Processing

In [None]:
# Load the three drought-index files (later saved as the targets) and flatten them.
drought_indices_path = os.path.join(raw_data_path, "Drought_Indices")
drought_indices_list = [
    'ESI_2003_2013.mat',
    'SIF_2003_2013.mat',
    'SMsurface_2003_2013.mat',
]
drought_indices_data = load_mat_files(drought_indices_path, drought_indices_list)
# {file name: {variable name: array}}  ->  {"<file name>_<variable name>": array}
drought_indices_data = {
    f"{file_name}_{var_name}": array
    for file_name, variables in drought_indices_data.items()
    for var_name, array in variables.items()
}
# drought index shape: 585x1386 pixels, 52 weeks x 11 years

In [None]:
def process_drought_indices(drought_indices_data, not_nan_mask, fill_with_mean=True):
    """Turn each drought-index array into a per-pixel float32 tensor, optionally mean-filled.

    Each array is indexed with the 2-D boolean ``not_nan_mask`` over its first two axes,
    and the two remaining axes are then swapped. The input is therefore assumed to be 4-D;
    per the loading cell this is presumably (rows, cols, weeks, years), giving
    (n_pixels, years, weeks). TODO confirm.

    With ``fill_with_mean`` the NaNs are replaced in three stages:
      1. pixels that are NaN everywhere take the global profile, which is the mean over
         axis 1 and then over pixels;
      2. remaining NaNs take that pixel's own mean over axis 1;
      3. positions still NaN (missing for every entry along axis 1) take the global profile.

    NOTE(review): with the notebook's ``not_nan_mask`` (train and test pixels together), the
    global profile mixes test pixels into values imputed for train pixels; confirm intended.

    Args:
        drought_indices_data: dict mapping a name to a numpy array.
        not_nan_mask: boolean numpy array selecting the pixels to keep.
        fill_with_mean: whether to apply the three-stage mean fill.

    Returns:
        (tensors, nan_masks): ``tensors`` maps each key to the selected float32 tensor.
        ``nan_masks`` maps each key to the boolean NaN mask taken *before* filling; it stays
        empty when ``fill_with_mean`` is False.
    """
    def _select(array):
        # Keep only the masked pixels, cast to float32 and swap the two trailing axes.
        return torch.tensor(array[not_nan_mask].astype(np.float32)).permute(0, 2, 1)

    def _select_and_fill(array):
        values = _select(array)
        nan_before = torch.isnan(values)

        # Global profile along the last axis (NaNs ignored at both averaging steps).
        global_profile = torch.nanmean(torch.nanmean(values, dim=1), dim=0)

        # Stage 1: pixels with no data at all.
        empty_pixels = nan_before.all(dim=2).all(dim=1)
        values[empty_pixels] = global_profile.unsqueeze(0)

        # Stage 2: each pixel's own mean over axis 1.
        per_pixel_profile = torch.nanmean(values, dim=1)
        holes = torch.isnan(values)
        values[holes] = per_pixel_profile.unsqueeze(1).expand(-1, values.shape[1], -1)[holes]

        # Stage 3: positions missing for every entry along axis 1 fall back to the global profile.
        holes = torch.isnan(values)
        if holes.any():
            values[holes] = global_profile.expand(values.shape[0], values.shape[1], -1)[holes]

        return values, nan_before

    drought_indices_tensors = {}
    drought_nan_mask = {}
    for key, array in drought_indices_data.items():
        if not fill_with_mean:
            drought_indices_tensors[key] = _select(array)
            continue
        drought_indices_tensors[key], drought_nan_mask[key] = _select_and_fill(array)

    return drought_indices_tensors, drought_nan_mask

# Select the kept pixels and mean-fill gaps. The NaN masks saved alongside mark the
# entries that were NaN before filling.
drought_indices_tensors, drought_nan_mask = process_drought_indices(
    drought_indices_data,
    not_nan_mask,
    fill_with_mean=True,
)
target_file = os.path.join(savepath, "target.pt")
nan_mask_file = os.path.join(savepath, "drought_nan_mask.pt")
torch.save(drought_indices_tensors, target_file)
torch.save(drought_nan_mask, nan_mask_file)

In [None]:
# Load every predictor file, flatten to {"<file>_<variable>": float32 array}, and list shapes.
predictors_path = os.path.join(raw_data_path, "Predictors")
predictors_list = [
    'ESImm.mat', 'ESIstd.mat',
    'height.mat',
    'lai_2003_2013.mat',
    'nlcd.mat',
    'pdsi_2003_2013.mat', 'pet_2003_2013.mat', 'pr_2003_2013.mat', 'rad_2003_2013.mat',
    'SIFmm.mat', 'SIFstd.mat',
    'SMmm.mat',
    'smroot_2003_2013.mat',
    'SMstd.mat',
    'sp_2003_2013.mat', 'tas_2003_2013.mat',
    'topography.mat',
    'vod_2003_2013.mat', 'vpd_2003_2013.mat', 'ws_2003_2013.mat',
]
predictors_data = load_mat_files(predictors_path, predictors_list)
predictors_data = {
    f"{file_name}_{var_name}": array.astype(np.float32)
    for file_name, variables in predictors_data.items()
    for var_name, array in variables.items()
}
for k, v in predictors_data.items():
    print(k, v.shape)

In [None]:

def process_predictors(predictors_data, not_nan_mask, fill_with_mean=True):
    """Select the kept pixels of every predictor and return them as float32 torch tensors.

    Arrays with three or more dimensions are treated as temporal. After indexing with the
    2-D ``not_nan_mask`` their two trailing axes are swapped, so the selection must be 3-D;
    the input is therefore assumed to be 4-D. TODO confirm for every predictor.
    When ``fill_with_mean`` is True, their NaNs are replaced in three stages:
      1. pixels that are NaN everywhere take the global profile, which is the mean over
         axis 1 and then over pixels;
      2. remaining NaNs take that pixel's own mean over axis 1;
      3. positions still NaN (missing for every entry along axis 1) take the global profile.
    Lower-dimensional arrays are treated as static: they are only pixel-selected, and any
    NaNs are kept.

    Every branch casts to float32, matching the fill branch and ``process_drought_indices``.
    This keeps downstream ``nanmean``/``std`` working even for integer-typed inputs.

    Args:
        predictors_data: dict mapping a name to a numpy array.
        not_nan_mask: boolean numpy array selecting the pixels to keep.
        fill_with_mean: whether to apply the three-stage mean fill to temporal arrays.

    Returns:
        dict mapping each key to a float32 tensor.
    """
    def _select(array):
        # Keep only the masked pixels, as float32.
        return torch.tensor(array[not_nan_mask].astype(np.float32))

    def _fill(values):
        # Global profile along the last axis (NaNs ignored at both averaging steps).
        global_profile = torch.nanmean(torch.nanmean(values, dim=1), dim=0)

        # Stage 1: pixels with no data at all.
        empty_pixels = torch.isnan(values).all(dim=2).all(dim=1)
        values[empty_pixels] = global_profile.unsqueeze(0)

        # Stage 2: each pixel's own mean over axis 1.
        per_pixel_profile = torch.nanmean(values, dim=1)
        holes = torch.isnan(values)
        values[holes] = per_pixel_profile.unsqueeze(1).expand(-1, values.shape[1], -1)[holes]

        # Stage 3: positions missing for every entry along axis 1 fall back to the global profile.
        holes = torch.isnan(values)
        if holes.any():
            values[holes] = global_profile.expand(values.shape[0], values.shape[1], -1)[holes]
        return values

    predictors_tensors = {}
    for k, v in predictors_data.items():
        if len(v.shape) >= 3:  # Temporal data
            selected = _select(v).permute(0, 2, 1)
            predictors_tensors[k] = _fill(selected) if fill_with_mean else selected
        else:  # Static data
            predictors_tensors[k] = _select(v)

    return predictors_tensors

# Pixel-select (and mean-fill the temporal) predictors, then save them as the model inputs.
predictors_tensors = process_predictors(
    predictors_data,
    not_nan_mask,
    fill_with_mean=True,
)
inputs_file = os.path.join(savepath, "inputs.pt")
torch.save(predictors_tensors, inputs_file)


In [None]:
# Per-variable summary statistics (mean / std / max / min over non-NaN entries) for every
# target and predictor tensor, saved to stat.pt. Presumably these are used for
# normalisation downstream; TODO confirm against the training code.
# NOTE(review): the tensors cover every pixel in `not_nan_mask` (train AND test) and already
# contain mean-filled values, so these statistics are neither train-only nor observed-only.
def _nan_aware_stats(tensor):
    """Return ``(mean, std, max, min)`` of ``tensor``'s non-NaN entries as 0-d tensors."""
    # Build the NaN-filtered copy once and reuse it for std/max/min.
    observed = tensor[~torch.isnan(tensor)]
    return tensor.nanmean(), observed.std(), observed.max(), observed.min()


mean_std_max_min_dict = {'mean': {}, 'std': {}, 'max': {}, 'min': {}}
# Targets first, then predictors; a repeated key keeps the predictor's values.
for tensors in (drought_indices_tensors, predictors_tensors):
    for k, v in tensors.items():
        mean, std, vmax, vmin = _nan_aware_stats(v)
        mean_std_max_min_dict['mean'][k] = mean
        mean_std_max_min_dict['std'][k] = std
        mean_std_max_min_dict['max'][k] = vmax
        mean_std_max_min_dict['min'][k] = vmin
torch.save(mean_std_max_min_dict, os.path.join(savepath, 'stat.pt'))

