In [1]:
import reproducibility
import datasets
import numpy
import models
import datasets
import joblib
import matplotlib.pyplot as plt

from geochem_table import GeochemTable, GeochemResultSet
from data import plot_confusion_matrix
from analysis import evaluate_performance

# Per-element configuration: one random seed and one "50x" threshold per element.
# GetHistogramModel renders the threshold with a "ppb" unit for Au and "ppm" for
# every other element.
# NOTE(review): the meaning of the "50x" suffix is not visible in this notebook -- confirm.

# Lithium
seed_li = 3617
threshold_li_50x = 1000

# Antimony
seed_sb = 6244
threshold_sb_50x = 10

# Zinc
seed_zn = 9659
threshold_zn_50x = 2500

# Silver
seed_ag = 8951
threshold_ag_50x = 5

# Gold
seed_au = 2701
threshold_au_50x = 150

# Tin
seed_sn = 1627
threshold_sn_50x = 100

# Copper
seed_cu = 213
threshold_cu_50x = 2000

# Lead
seed_pb = 76
threshold_pb_50x = 500

In [2]:
def GetHistogramModel(element, seed, threshold, dataset, date, country=None) -> "dict[str, GeochemResultSet]":
    """Evaluate the histogram gradient boosting model for one element and export its outputs.

    The model, the data splits and the analysis tables all come from
    ``models.hist_gradient_boosting_regressor_geochem``. This function then:

    * scores the model on the ``(x_test, y_test)`` split and on every row of the
      measured-analysis table using ``evaluate_performance``;
    * dumps the model to ``Models/<element>_HistogramGradientBoostingRegressor.joblib``;
    * adds an ``<element>_Predicted`` column to the analysis table and writes it to
      ``Output/<element> predictions Histogram <date>.csv``;
    * when ``country`` is None, plots the test-split confusion matrix.

    NOTE(review): neither output path contains ``country``, so a run with a
    country overwrites the files written by the run without one for the same
    element and date -- confirm this is intended.
    NOTE(review): the ``Models/`` and ``Output/`` directories are not created
    here; presumably they must already exist -- confirm.

    Args:
        element: Target column name such as ``"Li_ppm"``. The text before the
            first underscore is reported as the element in the result metadata.
        seed: Passed to ``reproducibility.seed_random`` and to the model builder.
        threshold: Cut-off forwarded to ``evaluate_performance`` and to the plot.
        dataset: Forwarded unchanged to the model builder.
        date: Dataset tag used in the CSV file name, the plot and the metadata.
        country: Optional country forwarded to the model builder. When given,
            the test-split result and the confusion-matrix plot are skipped and
            the full-data result is keyed ``"<Country> Analyzed"``.

    Returns:
        A dict mapping a result-set name (``"Test Set"``, ``"Analyzed"`` or
        ``"<Country> Analyzed"``) to the result set produced by
        ``evaluate_performance``, each updated with element, threshold, dataset,
        seed and sample-universe metadata.
    """
    reproducibility.seed_random(seed)
    (
        _x_train, _y_train, x_test, y_test, x_labels, y_labels,
        geochemical_analysis, geochemical_analysis_measured, model,
    ) = models.hist_gradient_boosting_regressor_geochem(element, seed, dataset, country)

    geochem_result_sets = {}

    # Score the test split; its comparison table also feeds the confusion-matrix plot below.
    y_test_predicted = model.predict(x_test)
    comparison, _, test_result_set = evaluate_performance(y_test, y_test_predicted, threshold, verbose=False)
    joblib.dump(model, "Models/" + element + "_HistogramGradientBoostingRegressor.joblib")
    if country is None:
        geochem_result_sets["Test Set"] = test_result_set

    # Score every row of the measured-analysis table.
    x_all = geochemical_analysis_measured[x_labels].to_numpy(numpy.float32)
    y_all = geochemical_analysis_measured[y_labels].to_numpy(numpy.float32).ravel()
    y_all_predicted = model.predict(x_all)
    _, _, analyzed_result_set = evaluate_performance(y_all, y_all_predicted, threshold, verbose=False)
    analyzed_key = "Analyzed" if country is None else country.capitalize() + " Analyzed"
    geochem_result_sets[analyzed_key] = analyzed_result_set

    # Export predictions for the full analysis table.
    geochemical_analysis[element + "_Predicted"] = model.predict(geochemical_analysis[x_labels].to_numpy(numpy.float32))
    geochemical_analysis.to_csv("Output/" + element + " predictions Histogram " + date + ".csv", sep=",", index=False)

    # Gold columns are reported in ppb, everything else in ppm.
    unit_ppm = "b" if "AU" in element.upper() else "m"
    if country is None:
        plot_confusion_matrix(comparison, threshold, element, date, unit=unit_ppm)
        plt.show()

    # A fresh metadata dict is built per result set so the nested dicts are never shared.
    for result_set in geochem_result_sets.values():
        result_set.update(**{
            "element": {"value": element.split("_")[0], "label": "Element"},
            "threshold": {"value": str(threshold) + "pp" + unit_ppm, "label": "Threshold"},
            "dataset": {"value": date, "label": "Dataset"},
            # Fixed: this entry was previously labelled "Dataset" (copy-paste slip).
            "seed": {"value": seed, "label": "Seed"},
            "sample_universe": {"value": "", "label": "Sample Universe"}
        })

    return geochem_result_sets

In [3]:
# 2023_12_19 dataset
# One entry per element: the target column plus the seed and threshold configured for it.
element_list = [
    {"element": element_name, "seed": element_seed, "threshold": element_threshold}
    for element_name, element_seed, element_threshold in (
        ("Li_ppm", seed_li, threshold_li_50x),
        ("Sb_ppm", seed_sb, threshold_sb_50x),
        ("Zn_ppm", seed_zn, threshold_zn_50x),
        ("Ag_ppm", seed_ag, threshold_ag_50x),
        ("Au_ppb", seed_au, threshold_au_50x),
        ("Sn_ppm", seed_sn, threshold_sn_50x),
        ("Cu_ppm", seed_cu, threshold_cu_50x),
        ("Pb_ppm", seed_pb, threshold_pb_50x),
    )
]

# For every element, run the model twice, in this order: first without a country,
# then with country="PERU". Each call contributes one dict of result sets.
geochem_result_sets = []
for element_item in element_list:
    geochem_result_sets.extend(
        GetHistogramModel(
            element_item["element"],
            element_item["seed"],
            element_item["threshold"],
            datasets.geochemical_analysis_2023_12_19,
            "2023_12_19",
            **country_kwargs,
        )
        for country_kwargs in ({}, {"country": "PERU"})
    )

# Build the summary table; the last expression is what the cell displays.
table = GeochemTable(geochem_result_sets)
table.generate_html()

In [None]:
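# # 2023_07_16 dataset -- this whole cell is commented out; it repeats the 2023_12_19 run on datasets.geochemical_analysis_2023_07_16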
# element_list = [
#     {
#         "element": "Li_ppm",
#         "seed": seed_li,
#         "threshold": threshold_li_50x
#     },
#     {
#         "element": "Sb_ppm",
#         "seed": seed_sb,
#         "threshold": threshold_sb_50x
#     },
#     {
#         "element": "Zn_ppm",
#         "seed": seed_zn,
#         "threshold": threshold_zn_50x
#     },
#     {
#         "element": "Ag_ppm",
#         "seed": seed_ag,
#         "threshold": threshold_ag_50x
#     },
#     {
#         "element": "Au_ppb",
#         "seed": seed_au,
#         "threshold": threshold_au_50x
#     },
#     {
#         "element": "Sn_ppm",
#         "seed": seed_sn,
#         "threshold": threshold_sn_50x
#     },
#     {
#         "element": "Cu_ppm",
#         "seed": seed_cu,
#         "threshold": threshold_cu_50x
#     },
#     {
#         "element": "Pb_ppm",
#         "seed": seed_pb,
#         "threshold": threshold_pb_50x
#     }
# ]

# geochem_result_sets = []
# for element_item in element_list:
#     geochem_result_sets.append(GetHistogramModel(element_item["element"], element_item['seed'], element_item['threshold'], datasets.geochemical_analysis_2023_07_16, "2023_07_16"))
#     geochem_result_sets.append(GetHistogramModel(element_item["element"], element_item['seed'], element_item['threshold'], datasets.geochemical_analysis_2023_07_16, "2023_07_16", country="PERU"))

# table = GeochemTable(geochem_result_sets)
# table.generate_html()