In [1]:
## import math
import os
import sys
import pandas as pd
import numpy as np
import time
import tqdm


import matplotlib.pyplot as plt
from numpy.random import seed

import torch
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import StepLR

from sklearn import metrics
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from pathlib import Path

  from pandas.core import (


In [2]:
from data_organize import load_county_data, load_scym_data, data_normalization, torch_dataloader
from solver import Solver
from pathlib import Path

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
def run_all(feature_size: int, num_seed: int, *,
            num_shared_feature: int = 32,
            learning_rate: float = 0.1,
            nepoch: int = 401,
            dp: float = 0.5,
            optimizer_name: str = 'Adagrad',
            weight_decay: float = 0.01):
    """Train a fresh Solver on the notebook's data loaders and evaluate it.

    Seeds numpy's and torch's global RNGs with ``num_seed``, builds a
    ``Solver`` with the given hyperparameters, runs ``Solver.train`` and then
    calls ``Solver.test`` on the source loader and on the target loader.

    NOTE: the data loaders are not parameters. ``train_loader_src``,
    ``train_loader_src_all``, ``train_loader_tar``, ``test_loader_tar``,
    ``train_loader_loc`` and ``train_loader_src_origin`` are read from the
    notebook's global namespace and must be defined before calling. The main
    loop only defines the two ``*_tar`` loaders when ``training == 1``.

    Args:
        feature_size: number of input features passed to ``Solver``.
        num_seed: seed for ``numpy.random.seed`` and ``torch.manual_seed``.
        num_shared_feature: ``Solver``'s second constructor argument
            (presumably the shared-representation width; TODO confirm).
        learning_rate: learning rate passed to ``Solver``.
        nepoch: epoch count passed to ``Solver.train``.
        dp: dropout-like value passed to ``Solver`` (TODO confirm meaning).
        optimizer_name: optimizer name string passed to ``Solver``.
        weight_decay: weight decay passed to ``Solver``.

    Returns:
        ``(solver, RMSE_tar, R2_tar, MAPE_tar, r2_tar, MAE_tar, y_tar, y_tar_pred)``,
        where everything after ``solver`` is the unpacked result of
        ``solver.test(test_loader_tar, True)``.
    """
    seed(num_seed)
    torch.manual_seed(num_seed)

    # `optimizer_name` (not `optim`) avoids shadowing the `torch.optim as optim`
    # module imported at the top of the notebook.
    solver = Solver(feature_size, num_shared_feature, learning_rate, dp,
                    optimizer_name, weight_decay)

    # `t` is passed as 1 alongside the total epoch count; presumably the
    # starting epoch for Solver.train (TODO confirm). The returned losses
    # were never used, so they are not bound.
    t = 1
    solver.train(train_loader_src, train_loader_src_all, train_loader_tar,
                 test_loader_tar, train_loader_loc, t, nepoch)

    # Source-set evaluation is kept for the report that solver.test produces
    # (the `True` flag presumably enables printing); its values are unused.
    solver.test(train_loader_src_origin, True)
    RMSE_tar, R2_tar, MAPE_tar, r2_tar, MAE_tar, y_tar, y_tar_pred = solver.test(test_loader_tar, True)

    return solver, RMSE_tar, R2_tar, MAPE_tar, r2_tar, MAE_tar, y_tar, y_tar_pred

In [4]:
def load_QDANN(feature_size: int, num_seed: int, year: int, *,
               num_shared_feature: int = 32,
               learning_rate: float = 0.1,
               dp: float = 0.5,
               optimizer_name: str = 'Adagrad',
               weight_decay: float = 0.01,
               model_dir: str = 'model'):
    """Rebuild a Solver and load its saved weights for one seed and year.

    The constructor hyperparameters must match those used when the checkpoint
    was trained. The defaults here are the values ``run_all`` uses. Optimizer
    settings are presumably irrelevant once weights are loaded, but ``Solver``
    requires them (TODO confirm).

    Args:
        feature_size: number of input features passed to ``Solver``.
        num_seed: seed for numpy/torch RNGs; also selects the
            ``<model_dir>/seed_<num_seed>/`` checkpoint directory.
        year: passed to ``Solver.load_models`` to pick the checkpoint.
        num_shared_feature, learning_rate, dp, optimizer_name, weight_decay:
            forwarded positionally to ``Solver``.
        model_dir: root directory holding the ``seed_<n>/`` subdirectories.

    Returns:
        The ``Solver`` after ``load_models(year, path)`` has been called.
    """
    seed(num_seed)
    torch.manual_seed(num_seed)

    # `optimizer_name` (not `optim`) avoids shadowing the imported torch.optim module.
    solver = Solver(feature_size, num_shared_feature, learning_rate, dp,
                    optimizer_name, weight_decay)

    # Built as a string with a trailing slash, exactly like the original
    # 'model/seed_<n>/', in case load_models appends file names directly.
    path = f'{model_dir}/seed_{num_seed}/'
    solver.load_models(year, path)

    return solver

In [5]:
def _harmonic_columns(index_name):
    """Harmonic-fit coefficient columns (constant, cos1..3, sin1..3) for one index."""
    terms = ('constant', 'cos1', 'cos2', 'cos3', 'sin1', 'sin2', 'sin3')
    return [f'{index_name}_{term}' for term in terms]


def _monthly_columns(month):
    """Monthly weather columns (pr, tmmn, tmmx, vpd, srad) for one month number."""
    variables = ('pr', 'tmmn', 'tmmx', 'vpd', 'srad')
    return [f'{var}_{month}' for var in variables]


# Candidate feature groups that can be excluded from the model input.
drop_int = ['GCVI_int', 'GCVI_peak', 'GCVI_peak_30']
drop_GCVI = _harmonic_columns('GCVI')
drop_NDVI = _harmonic_columns('NDVI')
drop_NIRv = _harmonic_columns('NIRv')
drop_weather = ['early_ppt', 'growing_ppt', 'growing_tmean', 'growing_sr']
drop_gridmet = [
    'pr_4', 'tmmn_4', 'tmmx_4',
    'pr_5', 'tmmn_5', 'tmmx_5',
    'pr_6', 'tmmn_6', 'tmmx_6', 'vpd_6',
    'vpd_7', 'tmmn_7', 'tmmx_7', 'pr_7',
    'vpd_8', 'tmmn_8', 'tmmx_8',
]

drop_apr = _monthly_columns(4)
drop_may = _monthly_columns(5)

# Groups actually dropped in this experiment (May weather is kept; add
# `*drop_may` to drop it too).
drop_features = [*drop_NDVI, *drop_NIRv, *drop_int, *drop_weather, *drop_apr]

In [6]:
# Target-domain settings; target_state and num_sample are passed to load_scym_data.
target_state, num_sample = None, 0
# Not referenced in the cells shown here (TODO confirm where they are used).
target_state_ID = 19
add_gridmet, add_SIF = 1, 0

In [7]:
# Main experiment loop. For every seed and target year, build the data
# loaders, then either train a new model with run_all (training == 1) or load
# the saved checkpoint with load_QDANN (training == 0) and evaluate it.
training = 0                  # 1: train with run_all; 0: load saved models
seed_used = [1, 2, 3, 4, 5]
target_years = [2018]

# Settings that do not change between runs (hoisted out of the loops).
years = np.arange(2008, 2019, 1)                    # source years for load_county_data
selected_states = [17, 18, 19, 27, 29, 39, 46, 55]  # presumably state FIPS codes (TODO confirm)
normalization = True
norm_type = 'Standard'
batch_size = 512

# Per-run metrics, one entry per (seed, year). `eval_set_all` records which
# loader the metrics came from: 'source' (loaded model evaluated on
# train_loader_src_origin) or 'target' (run_all's test_loader_tar result).
seed_all, year_all, eval_set_all = [], [], []
RMSE_all, R2_all, MAPE_all, r2_all, MAE_all = [], [], [], [], []

for seed_num in seed_used:
    print("seed = ", seed_num)

    for year in target_years:
        print(year)

        # Load the county dataset: all source years, plus the target year alone.
        X_src, y_src, dist_all = load_county_data(selected_states, years, drop_features, year)
        X_loc, y_loc, _ = load_county_data(selected_states, [year], drop_features, year)

        # Normalize; (X_src, y_src) is always the first pair passed, presumably
        # the data the scaler is fit on (TODO confirm in data_normalization).
        X_train_all, y_train_all, X_train_src, y_train_src = data_normalization(
            X_src, y_src, X_src, y_src, normalization, norm_type)
        X_train_all, y_train_all, X_train_loc, y_train_loc = data_normalization(
            X_src, y_src, X_loc, y_loc, normalization, norm_type)

        # Keep the un-augmented source features for evaluation; the training
        # copy gets `dist_all` appended as extra columns.
        X_train_src_o = X_train_src.copy()
        X_train_src = np.concatenate((X_train_src, dist_all), axis=1)

        # These loader names are read as globals by run_all().
        train_loader_src_origin = torch_dataloader(X_train_src_o, y_train_src, X_train_src_o.shape[0])
        train_loader_src = torch_dataloader(X_train_src, y_train_src, batch_size)
        train_loader_src_all = torch_dataloader(X_train_src, y_train_src, X_train_src.shape[0], False)
        train_loader_loc = torch_dataloader(X_train_loc, y_train_loc, X_train_loc.shape[0], False)

        if training == 1:
            # Target-domain data is only loaded when training.
            X_tar, y_tar, y_county_yield, dist_scym = load_scym_data(target_state, year, drop_features, num_sample)
            X_train_all, y_train_all, X_train_tar, y_train_tar = data_normalization(
                X_src, y_src, X_tar, y_tar, normalization, norm_type)
            train_loader_tar = torch_dataloader(X_train_tar, y_train_tar, batch_size)
            test_loader_tar = torch_dataloader(X_train_tar, y_train_tar, X_train_tar.shape[0], False)

        feature_size = X_src.shape[1]

        # NOTE: the unpacking below rebinds y_tar / y_src to values returned by
        # evaluation, replacing the raw labels loaded above (they are reloaded
        # on the next iteration). Names kept so post-loop state is unchanged.
        if training == 1:
            (solver, RMSE_tar, R2_tar, MAPE_tar, r2_tar, MAE_tar,
             y_tar, y_tar_pred) = run_all(feature_size, seed_num)
            eval_set = 'target'
            metrics = (RMSE_tar, R2_tar, MAPE_tar, r2_tar, MAE_tar)
        else:
            solver = load_QDANN(feature_size, seed_num, year)
            print("results under random seed ", seed_num)
            (RMSE_src, R2_src, MAPE_src, r2_src, MAE_src,
             y_src, y_src_pred) = solver.test(train_loader_src_origin, True)
            eval_set = 'source'
            metrics = (RMSE_src, R2_src, MAPE_src, r2_src, MAE_src)

        seed_all.append(seed_num)
        year_all.append(year)
        eval_set_all.append(eval_set)
        for bucket, value in zip((RMSE_all, R2_all, MAPE_all, r2_all, MAE_all), metrics):
            bucket.append(value)

# One row per (seed, year) run, so seed-to-seed spread can be summarized.
results_df = pd.DataFrame({
    'seed': seed_all,
    'year': year_all,
    'eval_set': eval_set_all,
    'RMSE': RMSE_all,
    'R2': R2_all,
    'MAPE': MAPE_all,
    'r2': r2_all,
    'MAE': MAE_all,
})
results_df


seed =  1
2018
results under random seed  1
Test RMSE = 0.8790, R2 = 0.7302, MARE = 0.0629, r2 = 0.7820 

seed =  2
2018
results under random seed  2
Test RMSE = 0.8780, R2 = 0.7309, MARE = 0.0630, r2 = 0.7879 

seed =  3
2018
results under random seed  3
Test RMSE = 0.8459, R2 = 0.7502, MARE = 0.0605, r2 = 0.7890 

seed =  4
2018
results under random seed  4
Test RMSE = 0.8539, R2 = 0.7454, MARE = 0.0613, r2 = 0.7884 

seed =  5
2018
results under random seed  5
Test RMSE = 0.8674, R2 = 0.7373, MARE = 0.0621, r2 = 0.7915 

