In [1]:
# basic libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# sklearn
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split, KFold, cross_val_predict, GridSearchCV
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder, StandardScaler
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.svm import SVR
from sklearn.cluster import KMeans

# others
from xgboost import XGBRegressor
import cartopy.crs as ccrs
import cartopy.mpl.ticker as cticker
import time
import xarray as xr
import sherpa
import time
from scipy.spatial import Delaunay
from scipy import interpolate
from copy import deepcopy

# enable autoreload
%load_ext autoreload
%autoreload 2
# Variables from config file
from config import BASE_DIR, FILE_NAMES, LABELS, ATTRIBUTES, BEST_MODEL_COLUMNS, ISLAND_RANGES
from math import pi as PI

In [2]:
# Feature columns: the predictor labels from the config plus the station-level
# attributes. A fresh list is built so the LABELS object from config is not mutated.
columns = list(LABELS) + ["season_wet", "elevation", "lat", "lon"]
print(*columns, end=' ')

# load datasets
# Every split is read with the same column subset: the features plus the
# time/station keys and the 'data_in' column.
csv_columns = columns + ['year', 'month', 'skn', 'data_in']
split_paths = (
    f"{BASE_DIR}/train.csv",
    f"{BASE_DIR}/valid.csv",
    f"{BASE_DIR}/test.csv",
)
df_train, df_valid, df_test = (
    pd.read_csv(path, usecols=csv_columns) for path in split_paths
)
# Stack the three splits into one frame (row labels are kept as read, not renumbered).
df_combined = pd.concat([df_train, df_valid, df_test])

air2m air1000_500 hgt500 hgt1000 omega500 pottemp1000-500 pottemp1000-850 pr_wtr shum-uwnd-700 shum-uwnd-925 shum-vwnd-700 shum-vwnd-950 shum700 shum925 skt slp season_wet elevation lat lon 

In [3]:
# Split the stations by the number of samples available:
# stations ('skn') with fewer than `threshold` rows get class 0, all others class 1.
threshold = 400

# One row per station with its row count in 'n_samples'.
df_split = df_combined.groupby('skn').size().reset_index(name="n_samples")

# Vectorised comparison instead of a row-wise apply; int64 matches the dtype
# the original lambda produced.
df_split['class'] = (df_split['n_samples'] >= threshold).astype('int64')

# Attach the per-station count and class to every observation. Columns left
# behind by a previous run of this cell are dropped first, so re-running the
# cell does not produce suffixed duplicates (n_samples_x / n_samples_y, ...).
df_combined = (
    df_combined
    .drop(columns=['n_samples', 'class'], errors='ignore')
    .merge(right=df_split, on='skn')
)

In [37]:
# Randomly assign every (year, month) group of the class-1 stations to one of
# the three splits with probabilities 0.6 / 0.2 / 0.2. A whole month always
# lands in a single split.
np.random.seed(40)  # fixed seed so the month-to-split assignment is repeatable

buckets = {"train": [], "valid": [], "test": []}
class_one_rows = df_combined.loc[df_combined['class'] == 1]
for _, month_frame in class_one_rows.groupby(by=["year", "month"]):
    # One draw per group; size=1 returns a 1-element array, so take its only entry.
    drawn = np.random.choice(a=["train", "valid", "test"], size=1, replace=True, p=[0.6, 0.2, 0.2])
    buckets[str(drawn[0])].append(month_frame)

# Lists of per-month frames; they are not concatenated here.
df_train, df_valid, df_test = buckets["train"], buckets["valid"], buckets["test"]

# Total number of rows that ended up in each split.
train, valid, test = (
    sum(len(frame) for frame in frames) for frames in (df_train, df_valid, df_test)
)

# Reports the number of (year, month) groups per split, not the row totals.
print(len(df_train), len(df_valid), len(df_test))

462 162 156


In [38]:
# Collapse each list of per-month frames into a single DataFrame.
# reset_index() renumbers the rows and keeps the previous row labels as an
# 'index' column.
df_train, df_valid, df_test = (
    pd.concat(monthly_frames).reset_index()
    for monthly_frames in (df_train, df_valid, df_test)
)
# Show (rows, columns) of each split.
df_train.shape, df_valid.shape, df_test.shape

((461374, 27), (158908, 27), (159178, 27))

In [29]:
# Persist the three splits for the class-1 stations.
# The frames are written as they are: index=False already keeps the row index
# out of the file, and calling reset_index() first would only add a redundant
# column of row numbers to the CSV (named 'level_0' when an 'index' column is
# already present).
for split_frame, csv_path in (
    (df_train, f"{BASE_DIR}/split_on_n_samples/high/train.csv"),
    (df_valid, f"{BASE_DIR}/split_on_n_samples/high/valid.csv"),
    (df_test, f"{BASE_DIR}/split_on_n_samples/high/test.csv"),
):
    split_frame.to_csv(csv_path, index=False)

In [35]:
# Display the training split as the cell output.
# NOTE(review): the recorded output below contains rows with class == 0, so it
# appears to predate the class-1 split — re-run this cell to refresh it.
df_train

Unnamed: 0,index,skn,year,month,data_in,lat,lon,elevation,air2m,air1000_500,...,shum-uwnd-925,shum-vwnd-700,shum-vwnd-950,shum700,shum925,skt,slp,season_wet,n_samples,class
0,301,1.00,1948,1,3.200000,18.916176,-155.674994,35.000000,295.39603,31.299995,...,-25.859348,0.589191,7.106412,2.945999,9.869999,23.385218,1014.08490,1,484,1
1,969,2.00,1948,1,5.950000,19.108660,-155.825545,1750.000000,295.39603,31.299995,...,-25.859348,0.589191,7.106412,2.945999,9.869999,23.385218,1014.08490,1,775,1
2,1709,2.20,1948,1,11.500000,19.164740,-155.682280,4890.000000,295.39603,31.299995,...,-25.859348,0.589191,7.106412,2.945999,9.869999,23.385218,1014.08490,1,720,1
3,2429,2.25,1948,1,5.515941,19.160603,-155.822488,2940.000000,295.39603,31.299995,...,-25.859348,0.589191,7.106412,2.945999,9.869999,23.385218,1014.08490,1,720,1
4,3149,2.26,1948,1,4.310617,19.225323,-155.778876,5680.000000,295.39603,31.299995,...,-25.859348,0.589191,7.106412,2.945999,9.869999,23.385218,1014.08490,1,720,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
512900,865536,221.60,2012,12,3.590551,20.027500,-155.406300,1897.489809,295.92688,30.300003,...,-70.068130,0.526238,4.627120,2.505999,11.017000,23.982779,1016.88490,1,9,0
512901,865544,55.20,2012,12,16.488189,19.433300,-155.228000,3657.134336,295.92688,30.300003,...,-70.068130,0.526238,4.627120,2.505999,11.017000,23.982779,1016.88490,1,8,0
512902,865550,87.30,2012,12,0.940945,19.713600,-155.079100,17.038321,295.92688,30.300003,...,-70.068130,0.526238,4.627120,2.505999,11.017000,23.982779,1016.88490,1,5,0
512903,865554,27.70,2012,12,0.330709,19.459400,-155.893800,778.045519,295.92688,30.300003,...,-70.068130,0.526238,4.627120,2.505999,11.017000,23.982779,1016.88490,1,3,0
