In [1]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import statistics as stat

folder = 'H:\\My Drive\\PROJECTS\\PSI 2022-2025\\XRF fundamentals vs. MVA\\data'

In [2]:
meta = pd.read_csv(f'{folder}\\meta_filter1_stratified_outliers_removed.csv')

In [3]:
# info for plotting
key = {
    'maj_min':{
        'elem':['SiO2','TiO2','Al2O3','Fe2O3','MgO','MnO','CaO','P2O5'],
        'units':'wt%',
        'cols':2,
        'rows':4,
        'fig':(8,10),
        'xs':[0,0,1,1,2,2,3,3],
        'ys':[0,1]*4
    },
    'trace':{
        'elem':['As','Bi','Cr','Cu','Mo','Nb','Ni','Pb','Rb','SO3','Sn','Sr','Ta','Th','U','V','W','Y','Zn','Zr'],
        'units':'ppm',
        'cols':3,
        'rows':7,
        'fig':(10,16),
        'xs':[0,0,0,1,1,1,2,2,2,3,3,3,4,4,4,5,5,5,6,6,6],
        'ys':[0,1,2]*7        
        #'xs':[0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4],
        #'ys':[0,1,2,3]*5     
    }
}

In [5]:
# use fold 2 as test fold
b=50

# also extract N, median, mean, etc.
n_train_list = []
n_test_list = []
med_list = []
mean_list = []
stdev_list = []
elem_list = []
unit_list = []

for kind in key.keys():
    
    elem = key[kind]['elem']
    x=key[kind]['xs']
    y=key[kind]['ys']
    cols=key[kind]['cols']
    rows=key[kind]['rows']
    fig=key[kind]['fig']
    
    fig, ax = plt.subplots(ncols=cols,
                           nrows=rows,
                           figsize=fig)

    for i in np.arange(len(elem)):
        e=elem[i]
        if e == 'SO3':
            units='wt%'
        else:
            units=key[kind]['units']
        
        temp = meta[(~meta[e].isna())&(meta[e+'_Folds']!=-1)][[e,e+'_Folds']].copy()
        median = round(stat.median(temp[e].values),2)
        mean = round(stat.mean(temp[e].values),2)
        std = round(stat.stdev(temp[e].values),2)

        temp_test = temp[temp[e+'_Folds']==2].copy()
        temp_train = temp[temp[e+'_Folds']!=2].copy()
        n_test = len(temp_test)
        n_train = len(temp_train)
        
        # add data to lists
        elem_list.append(e)
        n_train_list.append(n_train)
        n_test_list.append(n_test)
        med_list.append(median)
        mean_list.append(mean)
        stdev_list.append(std)
        unit_list.append(units)

        
        ax[x[i],y[i]].hist(np.array([temp_test[e].values, temp_train[e].values]), 
                           bins=b, 
                           color=['#dead08','#1b6f78'],
                           stacked=True)

        # get ylims
        ymin, ymax = ax[x[i],y[i]].get_ylim()

        ax[x[i],y[i]].vlines(median, ymin, ymax, colors='black', linestyle='solid', label=f'Median: {median} {units}')
        ax[x[i],y[i]].vlines(mean, ymin, ymax, colors='black', linestyle='dashed', label=f'Mean: {mean} {units}')
        ax[x[i],y[i]].set_ylim(ymin, ymax)

        ax[x[i],y[i]].legend(fontsize=8)
        title = f'{e} N Train: {n_train} N Test: {n_test} stdev: {std}'
        ax[x[i],y[i]].set_title(title, fontsize=8)

    plt.tight_layout()
    plt.savefig(f'H:\\My Drive\\PROJECTS\\PSI 2022-2025\\XRF fundamentals vs. MVA\\figures\\histogram_{kind}.eps', dpi=600)
    plt.close()

  ax[x[i],y[i]].hist(np.array([temp_test[e].values, temp_train[e].values]),
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
The PostScript backend does not support transparency

In [6]:
dataset_summary = pd.DataFrame({
    'element':elem_list,
    'n_train':n_train_list,
    'n_test':n_test_list,
    'units':unit_list,
    'median':med_list,
    'mean':mean_list,
    'stdev':stdev_list
})
dataset_summary.to_csv(folder+'\\dataset_summary.csv', index=False)