In [1]:
# %load test_radnet_2.py
%matplotlib notebook
from radnet import RadNet
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from netCDF4 import Dataset
import fnmatch
import os
import sys
import operator
from sympl import (
    DataArray, AdamsBashforth, PlotFunctionMonitor)
from climt import RRTMGLongwave, get_default_state
from datetime import timedelta
from metpy import calc
from metpy.units import units as unit_reg
from copy import deepcopy
from scipy.interpolate import splev, splrep
from datetime import datetime


# Longwave radiation component used to compute the reference tendencies.
radiation = RRTMGLongwave()
time_stepper = AdamsBashforth([radiation])
timestep = timedelta(hours=4)

# Coordinate descriptions for 60 mid levels and 61 interface levels.
mid_values = {'label': 'mid_levels', 'values': np.arange(60), 'units': ''}
int_values = {'label': 'interface_levels',
              'values': np.arange(61), 'units': ''}


# Shared model state; calculate_radiation_onsite() overwrites its profiles
# in place before every radiation call.
state = get_default_state([radiation])

# Total number of samples init_for_plots() distributes over the sample files.
sample_size = 20
# Root directory holding the frozen graphs and the sample datasets.
path = "/Users/yingliu/Projects/radnet/"

radnet = RadNet(path + 'graph-frozen-radnet_v5.0_equal_pressure_with_shifted_v3.pb')
#radnet = RadNet(path + 'graph-frozen-radnet_v7.0_500_levels_update_climt.pb')
#radnet = RadNet(path + 'graph-frozen-radnet_v7.1_with_shifted_samples.pb')
#filename = 'shifted_new/1.0_20/'
filename = 'historical_dataset'
#filename = 'test_dir'

# Module-level sample stores filled by init_for_plots(). They are plain
# module globals; a `global` statement is only meaningful inside a function,
# so none is needed here.
input_dic = []
label = []
data = []


def find_files(directory, pattern='*.csv'):
    """Collect every file below ``directory`` whose name matches ``pattern``.

    :param directory: root directory that is walked recursively
    :param pattern: shell-style wildcard (as understood by ``fnmatch``)
        applied to the bare file names
    :return: list of matching paths, each joined onto the directory it was
        found in
    """
    return [
        os.path.join(parent, name)
        for parent, _subdirs, names in os.walk(directory)
        for name in fnmatch.filter(names, pattern)
    ]


def load_data_samples(filename, size):
    """Draw up to ``size`` random rows from the ``radiation_data`` variable.

    Each row is split by column range into model inputs and the stored
    radiation label:

    * column 0, column 1 -- two scalars (used as CO2 and surface temperature
      by ``init_for_plots``)
    * columns 2:62 and 62:122 -- two 60-value profiles (air temperature and
      humidity)
    * columns 122:182 -- the 60-value label
    * columns 182:263 -- the remaining columns (pressure); the slice is
      simply truncated when the row is shorter

    :param filename: path of the netCDF sample file
    :param size: number of samples requested
    :return: ``(rand_data, rand_label, loaded_size)`` where ``rand_data`` is
        a list of five-element lists, ``rand_label`` a list of one-element
        lists and ``loaded_size`` the number of rows actually drawn
    """
    # The context manager closes the file even when reading the variable fails.
    with Dataset(filename, mode='r') as f:
        v = f.variables['radiation_data'][:]

    # Never ask for more rows than the file holds; rows are drawn without
    # replacement so no sample is returned twice.
    loaded_size = min(size, v.shape[0])
    rand_index = np.random.choice(v.shape[0], size=loaded_size, replace=False)

    rand_data = []
    rand_label = []
    # NOTE: rows containing NaN are not filtered out (a check for that was
    # present here but disabled).
    for row in rand_index:
        rand_data.append([
            v[row, 0],
            v[row, 1],
            v[row, 2:62],
            v[row, 62:122],
            v[row, 182:263],
        ])
        rand_label.append([v[row, 122:182]])

    return rand_data, rand_label, loaded_size


############################

def init_for_plots(filename):
    """Load ``sample_size`` random samples and publish them as module globals.

    The sample files found under ``path + filename`` are visited in random
    order and ``sample_size`` is split across them as evenly as possible.
    The loaded rows are appended to the globals ``data`` and ``label`` and a
    matching input dictionary is appended to ``input_dic`` for each new row.

    NOTE: the globals are appended to, not reset, so calling this function
    more than once accumulates samples.

    :param filename: directory (relative to ``path``) holding sample files
    :raises ValueError: if no sample file is found
    """
    global data
    global label
    global input_dic
    #files = find_files(path + "test_dataset_v2/", '*')
    files = find_files(path + filename, '*')
    if not files:
        # Without this guard the quota computation below fails with an
        # uninformative ZeroDivisionError.
        raise ValueError("no sample files found under %s" % (path + filename))
    indexs = np.random.choice(len(files), size=len(files), replace=False)

    # Every file gets the same base quota; the first `extra` files take one
    # more so that the quotas add up to sample_size.
    base, extra = divmod(sample_size, len(indexs))
    data_to_load = [base + 1 if i < extra else base
                    for i in range(len(indexs))]

    print(np.sum(data_to_load))
    print(sample_size)
    assert np.sum(data_to_load) == sample_size

    # Remember where the new samples start so that only they are converted
    # below; converting all of `data` again would duplicate earlier entries
    # in input_dic and misalign it with data/label on repeated calls.
    first_new = len(data)

    for file_index, quota in zip(indexs, data_to_load):
        if quota != 0:
            tmp_data, tmp_label, _ = load_data_samples(files[file_index], quota)
            data = data + tmp_data
            label = label + tmp_label

    # prepare data samples, tune inputs here
    for one_sample in data[first_new:]:
        one_input_dic = {"surface_temperature": one_sample[1], "CO2": one_sample[0],
                         "air_temperature": one_sample[2], "humidity": one_sample[3],
                         "pressure": one_sample[4]}
        input_dic.append(one_input_dic)


def calculate_prediction(input_dic):
    """Run the RadNet model on one sample and return plain Python values.

    :param input_dic: sample dictionary; the number of layers handed to the
        model is the length of its ``"air_temperature"`` entry
    :return: the model output, reshaped to ``(1, 1, -1)``, squeezed and
        converted with ``tolist()``
    """
    layer_count = len(input_dic["air_temperature"])
    raw_output = radnet.predict(input_dic, layer_count)
    flattened = raw_output.reshape(1, 1, -1).squeeze()
    return flattened.tolist()


def calculate_radiation_onsite(input_dic):
    """Compute the reference tendency for one sample with the radiation scheme.

    The sample is written into the module-level ``state`` in place, the
    ``radiation`` component is called on it, and the elapsed wall-clock time
    of that call is printed in microseconds.

    :param input_dic: sample dictionary with the keys
        ``"surface_temperature"``, ``"air_temperature"``, ``"humidity"``,
        ``"CO2"`` and ``"pressure"`` (interface-level pressures)
    :return: ``tendencies["air_temperature"].values[0, 0, :]`` as returned by
        the radiation component
    """
    # calculate radiation onsite
    state['surface_temperature'].values[0, 0] = input_dic["surface_temperature"]
    state['air_temperature'].values[0, 0, :] = input_dic["air_temperature"]
    state['specific_humidity'].values[0, 0, :] = input_dic["humidity"]
    state['mole_fraction_of_carbon_dioxide_in_air'].values[0, 0, :] = input_dic["CO2"]
    state['air_pressure'].values[0, 0, :] = cal_air_pressure(input_dic["pressure"])
    state['air_pressure_on_interface_levels'].values[0, 0, :] = input_dic["pressure"] + 1.e-9
    a = datetime.now()
    tendencies, diagnostics = radiation(state)
    b = datetime.now()
    print('time for calculation')
    # timedelta.microseconds is only the sub-second component; dividing by a
    # one-microsecond timedelta yields the whole elapsed time as an integer.
    print((b - a) // timedelta(microseconds=1))
    label_on_site = tendencies["air_temperature"].values[0, 0, :]
    return label_on_site

def calculate_saturation_humidity(input_dic):
    """Return the saturation mixing ratio profile of one sample.

    Mid-level pressures are derived from the sample's interface pressures;
    pressure and temperature are tagged with the units recorded in the
    module-level ``state`` before being handed to MetPy.

    :param input_dic: sample dictionary providing ``"pressure"`` and
        ``"air_temperature"``
    :return: saturation mixing ratio converted to ``g/g`` and clipped from
        below at ``1e-7``
    """
    pressure_quantity = (cal_air_pressure(input_dic["pressure"])
                         * unit_reg(state['air_pressure'].units))
    temperature_quantity = (input_dic["air_temperature"]
                            * unit_reg(state['air_temperature'].units))
    mixing_ratio = calc.saturation_mixing_ratio(pressure_quantity,
                                                temperature_quantity)
    in_g_per_g = mixing_ratio.to('g/g')
    # A lower bound keeps the value usable as a divisor in
    # calculate_relative_humidity().
    return in_g_per_g.clip(min=1.e-7)
    
def calculate_relative_humidity(input_dic):
    """Return the sample's humidity divided by its saturation humidity.

    :param input_dic: sample dictionary providing ``"humidity"`` plus the
        entries needed by ``calculate_saturation_humidity``
    :return: ratio of humidity to saturation humidity
    """
    saturation = calculate_saturation_humidity(input_dic)
    return input_dic["humidity"] / saturation

def generate_new_humidity_profile(old_input_dic, new_input_dic):
    """Carry the old sample's relative humidity over to the new sample.

    :param old_input_dic: sample whose relative humidity is preserved
    :param new_input_dic: sample whose saturation humidity is used
    :return: relative humidity of the old sample multiplied by the
        saturation humidity of the new sample
    """
    preserved_ratio = calculate_relative_humidity(old_input_dic)
    return preserved_ratio * calculate_saturation_humidity(new_input_dic)

def generate_new_profile(old_input_dic, slope_param, shift_param):
    """Build a perturbed copy of a sample.

    The air temperature at level ``k`` is increased by
    ``slope_param * k + shift_param``, the surface temperature by
    ``shift_param``, and the humidity is regenerated with
    ``generate_new_humidity_profile``. All other entries are deep-copied
    from the old sample, which itself is left unmodified.

    :param old_input_dic: sample dictionary to start from
    :param slope_param: temperature increment per level index
    :param shift_param: constant temperature offset
    :return: the new sample dictionary
    """
    level_offsets = slope_param * np.arange(len(old_input_dic["air_temperature"]))
    perturbed = deepcopy(old_input_dic)
    perturbed["air_temperature"] = (
        old_input_dic["air_temperature"] + level_offsets + shift_param)
    # The humidity depends on the new temperature, so it is derived only
    # after the perturbed temperature profile has been stored.
    perturbed["humidity"] = generate_new_humidity_profile(old_input_dic, perturbed)
    perturbed["surface_temperature"] = (
        old_input_dic["surface_temperature"] + shift_param)
    return perturbed

def generate_new_samples_to_file(old_sample_file_name, slope_param, shift_param):
    """Write a perturbed copy of a sample file with recomputed radiation.

    Every row of the source file is perturbed with ``generate_new_profile``,
    its radiation is recomputed with ``calculate_radiation_onsite`` and the
    result is written to a file of the same base name inside
    ``path + "<slope>_<shift>"``, where ``<slope>`` and ``<shift>`` are the
    parts of ``str(slope_param)`` / ``str(shift_param)`` before the first
    ``'.'``. An existing output file is replaced.

    Output row layout: CO2, surface temperature, air temperature, humidity,
    radiation, pressure.

    :param old_sample_file_name: source file, relative to ``path``
    :param slope_param: per-level temperature increment
    :param shift_param: constant temperature offset
    """
    # read from old sample file; the context manager closes it on any error
    with Dataset(path + old_sample_file_name, mode='r') as f:
        v = f.variables['radiation_data'][:]

    # create new sample file
    num_levels = 60
    out_dir = path + str(slope_param).split('.')[0] + "_" + str(shift_param).split('.')[0]
    # basename() works for any directory depth, whereas picking the second
    # '/'-separated component is only right for "<dir>/<file>" names.
    out_file = out_dir + "/" + os.path.basename(old_sample_file_name)

    if not os.path.isdir(out_dir):
        os.mkdir(out_dir)
    if os.path.isfile(out_file):
        os.remove(out_file)

    with Dataset(out_file, 'w') as ncfile:
        ncfile.createDimension('radiation', 4 * num_levels + 2 + 1)
        ncfile.createDimension('sample_number', None)

        radiation_nc = ncfile.createVariable(
            "radiation_data", "f4", ("sample_number", "radiation"))

        for i in range(v.shape[0]):
            local_input_dic = {
                "CO2": v[i, 0],
                "surface_temperature": v[i, 1],
                "air_temperature": v[i, 2:62],
                "humidity": v[i, 62:122],
                "pressure": v[i, 182:263],
            }
            new_input_dic = generate_new_profile(local_input_dic, slope_param, shift_param)
            new_radiation = calculate_radiation_onsite(new_input_dic)

            # np.append flattens its arguments, giving one flat row.
            radiation_results_np = np.append(
                [new_input_dic["CO2"], new_input_dic["surface_temperature"]],
                [new_input_dic["air_temperature"], new_input_dic["humidity"], new_radiation])
            radiation_results_np = np.append(
                radiation_results_np, new_input_dic["pressure"])

            radiation_nc[i, :] = radiation_results_np
        

def cal_air_pressure(air_pressure_interface):
    """Return mid-level pressures as the mean of adjacent interface pressures.

    Works for any number of interface levels: ``n`` interface values give
    ``n - 1`` mid-level values (61 interfaces give the usual 60 levels).

    :param air_pressure_interface: one-dimensional sequence or array of
        interface-level pressures
    :return: ``float64`` array holding the mean of each pair of neighbours
    """
    interfaces = np.asarray(air_pressure_interface)
    # Average in the input's own dtype and only then widen, so the values
    # match an element-by-element computation stored into a float64 array.
    mid_levels = (interfaces[:-1] + interfaces[1:]) * 0.5
    return mid_levels.astype(np.float64)

def _label_pressure_axis(xlabel):
    """Apply the axis styling shared by all three profile panels."""
    plt.xlabel(xlabel)
    plt.ylabel('pressure')
    # Reversed limits draw the largest pressure at the bottom of the panel.
    plt.ylim(1050, -50)


def plot_function(input_dic, label, plot_show):
    """Compare the model prediction with the on-site radiation for one sample.

    The prediction and the reference are always computed. When ``plot_show``
    is true a three-panel figure is drawn against mid-level pressure divided
    by 100: predicted versus calculated values, air temperature minus
    273.15, and humidity times 1e3.

    :param input_dic: sample dictionary
    :param label: label stored with the sample; accepted for interface
        compatibility but not used
    :param plot_show: draw and show the figure when true
    """
    predicted_values = calculate_prediction(input_dic)

    # calculate label on site
    label_on_site = calculate_radiation_onsite(input_dic)

    # calculate mse
    mse = mean_squared_error(predicted_values, label_on_site)

    if not plot_show:
        return

    # mid-level pressure, scaled for the vertical axis
    y = cal_air_pressure(input_dic["pressure"])/100.

    plt.figure(figsize=(10,5))
    plt.subplot(1, 3, 1)
    plt.plot(predicted_values, y, 'b-*', label='Predicted Value')
    plt.plot(label_on_site, y, 'r-*', label='Calculated Value')
    print(label_on_site[-1])
    print(label_on_site[0])
    plt.title('Surface T: ' + str(input_dic["surface_temperature"]) + '\nCO2: ' + str(input_dic["CO2"]) +
              '\nPrediction MSE: ' + str(mse))
    _label_pressure_axis('Heating Rate (K/day)')

    air_temperature = input_dic["air_temperature"]-273.15

    plt.subplot(1, 3, 2)
    plt.plot(air_temperature, y, 'r-*')
    _label_pressure_axis('Temperature')

    humidity = input_dic["humidity"]*1.e3

    plt.subplot(1, 3, 3)
    plt.plot(humidity, y, 'r-*')
    _label_pressure_axis('Humidity (g/g * 1e3)')

    plt.tight_layout()
    plt.draw()
    plt.show()

def sample_plots(filename):
    """Plot every loaded sample, worst prediction first, then print the mean MSE.

    :param filename: sample directory (relative to ``path``) passed on to
        ``init_for_plots``
    """
    init_for_plots(filename)

    # MSE between prediction and on-site radiation, keyed by sample index.
    rank_mse = {}
    for index, sample in enumerate(input_dic):
        prediction = calculate_prediction(sample)
        reference = calculate_radiation_onsite(sample)
        rank_mse[index] = mean_squared_error(prediction, reference)

    worst_first = sorted(rank_mse.items(), key=operator.itemgetter(1), reverse=True)

    #plot for radiations
    for index, _mse in worst_first:
        plot_function(input_dic[index], label[index], True)

    #mean mse
    print(np.mean(list(rank_mse.values())))
    
    
def prediction_mse(filename):
    """Time a linear spline interpolation and print the mean prediction MSE.

    The air temperature of the first loaded sample is interpolated linearly
    (``k=1``) onto a fixed 60-point pressure grid; the elapsed time is
    printed in microseconds and both profiles are plotted. Afterwards the
    mean MSE between prediction and on-site radiation over all loaded
    samples is printed.

    :param filename: sample directory (relative to ``path``) passed on to
        ``init_for_plots``
    """
    init_for_plots(filename)

    x = input_dic[0]['air_temperature']
    # First 60 pressure values in reversed order -- presumably so that the
    # abscissa handed to splrep is increasing; TODO confirm against the data.
    y = input_dic[0]['pressure'][:60][::-1]

    # Target grid: 10 + 25 + 25 points over three pressure ranges.
    x2 = np.linspace(0,10000,10)
    x2 = np.append(x2, np.linspace(11000, 80000, 25))
    x2 = np.append(x2, np.linspace(82760, 103000, 25))
    print(x2.tolist())

    a = datetime.now()
    spl = splrep(y, x, k =1)
    y2 = splev(x2, spl)
    b = datetime.now()
    print('time for interpolation')
    # timedelta.microseconds is only the sub-second component; dividing by a
    # one-microsecond timedelta yields the whole elapsed time as an integer.
    print((b - a) // timedelta(microseconds=1))

    plt.xlim((0,105000))
    plt.plot(y, x, 'o', x2, y2, '*')
    plt.show()

    mse = [mean_squared_error(calculate_prediction(sample),
                              calculate_radiation_onsite(sample))
           for sample in input_dic]

    #mean mse
    print(np.mean(mse))
    
def scatter_plots(filename):
    """Scatter air temperature and humidity of all loaded samples against pressure.

    Two figures are shown: air temperature versus the first 60 pressure
    values divided by 100, and humidity times 1e3 versus the same pressures.

    :param filename: sample directory (relative to ``path``) passed on to
        ``init_for_plots``
    """
    init_for_plots(filename)
    air_temp_hist = [sample["air_temperature"] for sample in input_dic]
    humidity_hist = [sample["humidity"] for sample in input_dic]
    pressure_hist = [sample["pressure"] for sample in input_dic]

    # scatter plot for air temperature
    plt.figure()
    for temperature, pressure in zip(air_temp_hist, pressure_hist):
        plt.scatter(temperature, pressure[:60]/100)
    #plt.ylim(1050,0)
    plt.show()

    #scatter plot for humidity
    plt.figure()
    for humidity, pressure in zip(humidity_hist, pressure_hist):
        plt.scatter(humidity[:60] * 1.e3, pressure[:60]/100)
    #plt.xlim(0,200)
    #plt.ylim(1050,0)
    plt.show()

def pdf_plots(filename):
    """Show the distribution of prediction MSE over all loaded samples.

    The MSE values are binned into 100 equal-width bins and each bin is
    drawn as a bar whose height is its share of the samples.

    :param filename: sample directory (relative to ``path``) passed on to
        ``init_for_plots``
    """
    init_for_plots(filename)
    mse_list = [
        mean_squared_error(calculate_prediction(sample),
                           calculate_radiation_onsite(sample))
        for sample in input_dic
    ]

    #pdf plot for mse
    plt.figure()
    counts, edges = np.histogram(mse_list, 100)
    probabilities = counts/float(counts.sum())
    centers = (edges[1:]+edges[:-1])/2.
    plt.bar(centers, probabilities, width=edges[1]-edges[0])
    #plt.xlim(0,0.02)
    plt.show()
    
    
def generate_folders_of_samples(sample_dir, slope_param, shift_param, pattern='*'):
    """Generate a perturbed copy of every matching sample file in a directory.

    :param sample_dir: directory (relative to ``path``) that is searched
        recursively
    :param slope_param: per-level temperature increment
    :param shift_param: constant temperature offset
    :param pattern: wildcard selecting the file names to process
    """
    prefix_length = len(path)
    for sample_path in find_files(path + sample_dir, pattern):
        # Strip the root so the name is relative to `path` again.
        relative_name = sample_path[prefix_length:]
        generate_new_samples_to_file(relative_name, slope_param, shift_param)
        print("generated for %s" % relative_name)

        

# Driver for this cell: exactly one of the calls below is active; the others
# are kept commented out as alternative analyses of the same dataset.
#prediction_mse(filename)
sample_plots(filename)
#scatter_plots(filename)
#pdf_plots(filename)
#generate_folders_of_samples("test_dataset_v2/", 0., 10., pattern='*m0[1-2]_*')
#print("finish")

  from ._conv import register_converters as _register_converters


20
20


ValueError: cannot reshape array of size 241 into shape (60,4,1)