# Bayes Model

In [None]:
%pip install numpy scikit-learn pandas boto3 matplotlib seaborn python-dotenv

In [None]:
import logging
import os
import sys
import time
import uuid
from datetime import datetime, timezone
from io import BytesIO
from pathlib import Path

import boto3
import dotenv
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering

In [None]:
# Optionally pre-populate os.environ from a .env file. DOTENV_PATH itself can
# only come from the real environment, because it is read before the file is
# loaded.
DOTENV_PATH = os.environ.get('DOTENV_PATH', './../.env')

# A falsy result means nothing was loaded; the notebook keeps going and falls
# back on the defaults of the configuration cell. `print` is used because the
# logging module has not been configured yet at this point.
if not dotenv.load_dotenv(dotenv_path=DOTENV_PATH):
    print(f'no environment have been loaded from .env path \"{DOTENV_PATH}\"')

In [None]:
# Runtime configuration, read from the environment. An empty-string default
# means "not set" and is tested for by the cells below.

# Logging level name (e.g. DEBUG, INFO, WARNING); upper-cased so that values
# such as "debug" are accepted. Defaults to INFO as before.
LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO').upper()

# Dataset sources: a local CSV path, or the key of a dump inside S3_BUCKET_NAME.
LOCAL_DATASET_PATH = os.environ.get('LOCAL_DATASET_PATH', '')
IMPORTED_DATASET_S3_KEY = os.environ.get('IMPORTED_DATASET_S3_KEY', '')

# When set, a previously pushed model is downloaded instead of fitting one.
IMPORTED_HIERARCHICAL_CLUSTERING_S3_KEY = os.environ.get('IMPORTED_HIERARCHICAL_CLUSTERING_S3_KEY', '')

# Only the literal "true" (any letter case) enables the push; any other value disables it.
PUSH_MODEL_DUMP_TO_S3_ENABLED = os.environ.get('PUSH_MODEL_DUMP_TO_S3_ENABLED', 'true').lower() == 'true'

# Scratch directory for local copies of model dumps.
TMP_DIR = os.environ.get('TMP_DIR', '/tmp/pink-twins')

# NOTE: the environment variable names below do not follow the constant names
# (BUCKET_NAME, S3_MODELS_BUCKET_FOLDER); they are kept as-is for compatibility
# with existing .env files.
S3_BUCKET_NAME = os.environ.get('BUCKET_NAME', 'pink-twins-bucket')
S3_BUCKET_FOLDER = os.environ.get('S3_MODELS_BUCKET_FOLDER', '')
S3_ACCESS_KEY_ID = os.environ.get('S3_ACCESS_KEY_ID', '')
S3_SECRET_ACCESS_KEY = os.environ.get('S3_SECRET_ACCESS_KEY', '')

# Recorded as the 'author' metadata of pushed model dumps.
AUTHOR = os.environ.get('AUTHOR', 'undefined')

In [None]:
# Make sure the scratch directory exists, creating any missing parent
# directories; an already existing directory is left untouched.
os.makedirs(TMP_DIR, exist_ok=True)

In [None]:
# Configure the root logger to write "LEVEL | timestamp | message" lines to
# stdout.
#
# The date format ends with a literal "Z" (the ISO 8601 UTC designator), so the
# timestamp must really be rendered in UTC. logging formats %(asctime)s in
# local time by default, which made the "Z" suffix wrong on non-UTC hosts.
# Hence the formatter's converter is switched to time.gmtime.
#
# The former `encoding='utf-8'` argument is dropped: basicConfig only uses it
# together with `filename=`, so it had no effect on a stream handler.
_log_formatter = logging.Formatter(
    fmt="%(levelname)s | %(asctime)s | %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
)
_log_formatter.converter = time.gmtime

_log_handler = logging.StreamHandler(sys.stdout)
_log_handler.setFormatter(_log_formatter)

logging.basicConfig(
    level=LOG_LEVEL.upper(),
    handlers=[_log_handler],
    # Without force=True, basicConfig silently does nothing when the root
    # logger already has handlers (e.g. when this cell is re-run, or when the
    # kernel installed its own); force replaces them.
    force=True,
)

In [None]:
# Load the training dataset into `df` from exactly one source, in priority order:
#   1. LOCAL_DATASET_PATH      -> a CSV file read with pandas
#   2. IMPORTED_DATASET_S3_KEY -> a joblib dump stored in S3_BUCKET_NAME
#
# Every failure is logged and then re-raised. logging.fatal only emits a
# CRITICAL record and does not stop execution, so previously a failed load fell
# through to the next cell and surfaced as an unrelated NameError on `df`.
if LOCAL_DATASET_PATH:
    try:
        df = pd.read_csv(LOCAL_DATASET_PATH)
    except Exception as err:
        logging.fatal(f'failed to load dataset at path {LOCAL_DATASET_PATH}: {err}')
        raise
elif IMPORTED_DATASET_S3_KEY:
    try:
        s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY_ID, aws_secret_access_key=S3_SECRET_ACCESS_KEY)

        # Download the whole dump into memory.
        response = s3.get_object(Bucket=S3_BUCKET_NAME, Key=IMPORTED_DATASET_S3_KEY)
        df_bytes = BytesIO(response['Body'].read())

        # SECURITY: joblib.load unpickles its input and can execute arbitrary
        # code; only load dumps from a bucket whose writers are trusted.
        df = joblib.load(df_bytes)
    except Exception as err:
        logging.fatal(f'failed to load dataset {IMPORTED_DATASET_S3_KEY} from S3 bucket: {err}')
        raise
else:
    logging.fatal('no source dataset have been defined')
    raise RuntimeError('no source dataset have been defined')

In [None]:
# Drop the columns judged irrelevant for clustering. 'Accident_Severity' is
# removed as well: the aim is for the model to form severity-like clusters on
# its own, so the label itself must not be an input feature.
df = df.drop(columns=[
    'Location_Easting_OSGR', 'Location_Northing_OSGR',
    'Number_of_Vehicles', 'Number_of_Casualties',
    '1st_Road_Class', '1st_Road_Number',
    'Junction_Control', '2nd_Road_Class',
    'Pedestrian_Crossing-Human_Control',
    'Pedestrian_Crossing-Physical_Facilities',
    'Special_Conditions_at_Site', 'Carriageway_Hazards',
    'Accident_Severity',
])

# Sub-sample the rows to keep the clustering's memory footprint manageable
# (none of us has an 8 TB RAM machine). Both knobs are environment-configurable.
# The defaults (1% of the rows, no fixed seed) reproduce the previous behaviour.
# Set DATASET_SAMPLE_SEED to an integer to make the sample, and therefore the
# fitted model, reproducible across runs.
DATASET_SAMPLE_FRACTION = float(os.environ.get('DATASET_SAMPLE_FRACTION', '0.01'))
_sample_seed = os.environ.get('DATASET_SAMPLE_SEED', '').strip()
DATASET_SAMPLE_SEED = int(_sample_seed) if _sample_seed else None

df = df.sample(frac=DATASET_SAMPLE_FRACTION, random_state=DATASET_SAMPLE_SEED)

# NOTE(review): AgglomerativeClustering needs an all-numeric, NaN-free matrix;
# this assumes the remaining columns already satisfy that. Confirm against the dataset.
X = df.to_numpy()

In [None]:
# Obtain the hierarchical clustering `model`. If
# IMPORTED_HIERARCHICAL_CLUSTERING_S3_KEY is set, download that joblib dump and
# load it. Otherwise, fit a new AgglomerativeClustering model on X.
# Load failures are logged and re-raised (logging.fatal alone does not stop
# execution).
if IMPORTED_HIERARCHICAL_CLUSTERING_S3_KEY:
    try:
        s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY_ID, aws_secret_access_key=S3_SECRET_ACCESS_KEY)

        # The local copy keeps the dump's file name, i.e. the last key segment.
        imported_model_id = IMPORTED_HIERARCHICAL_CLUSTERING_S3_KEY.split('/')[-1]
        if not imported_model_id:
            raise ValueError('S3 key does not end with a file name')
        imported_model_file = str(Path(TMP_DIR) / imported_model_id)

        # download_file writes straight to disk and returns None.
        s3.download_file(Bucket=S3_BUCKET_NAME, Key=IMPORTED_HIERARCHICAL_CLUSTERING_S3_KEY,
                         Filename=imported_model_file)

        # SECURITY: joblib.load unpickles its input and can execute arbitrary
        # code; only load dumps from a bucket whose writers are trusted.
        model = joblib.load(imported_model_file)
    except Exception as err:
        logging.fatal(f'failed to load model {IMPORTED_HIERARCHICAL_CLUSTERING_S3_KEY} from S3 bucket: {err}')
        raise
else:
    # Setting distance_threshold=0 ensures we compute the full tree.
    model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)
    model = model.fit(X)

In [None]:
def plot_dendrogram(model, **kwargs):
    """Draw a scipy dendrogram for a fitted agglomerative clustering model.

    Builds a scipy linkage matrix from the model's merge tree and hands it to
    ``scipy.cluster.hierarchy.dendrogram``. Each row of the matrix holds
    ``[child_a, child_b, merge_distance, samples_under_node]``.

    Args:
        model: fitted model exposing ``children_``, ``distances_`` and
            ``labels_`` (sklearn only fills ``distances_`` when
            ``distance_threshold`` is set or ``compute_distances=True``).
        **kwargs: forwarded unchanged to ``dendrogram``.
    """
    merges = model.children_
    leaf_total = len(model.labels_)

    # Indices below leaf_total are original samples (size 1). Any other index
    # refers to merge number (index - leaf_total), whose size is already known
    # because merges only ever reference earlier merges.
    subtree_sizes = np.zeros(merges.shape[0])
    for node, pair in enumerate(merges):
        subtree_sizes[node] = sum(
            1 if child < leaf_total else subtree_sizes[child - leaf_total]
            for child in pair
        )

    linkage = np.column_stack([merges, model.distances_, subtree_sizes]).astype(float)
    dendrogram(linkage, **kwargs)

In [None]:
# Show only the top three levels of the merge tree (truncate_mode="level",
# p=3) and hide the per-leaf sample counts.
truncation = {"truncate_mode": "level", "p": 3, "show_leaf_counts": False}
plt.title("Hierarchical Clustering Dendrogram")
plot_dendrogram(model, **truncation)
plt.show()

In [None]:
# Push a freshly fitted model to S3 as a joblib dump. This is skipped when the
# model was imported or when pushing is disabled.
if IMPORTED_HIERARCHICAL_CLUSTERING_S3_KEY == '' and PUSH_MODEL_DUMP_TO_S3_ENABLED:
    # A single id names both the local dump and the S3 object; previously a
    # second uuid4() call made the two names differ.
    model_id = uuid.uuid4()
    model_file_name = f'{model_id}.joblib'

    # Join only the non-empty parts, so an unset S3_BUCKET_FOLDER (the default)
    # does not produce a key starting with "/". Surrounding slashes on the
    # folder are stripped to avoid "//" in the key.
    key = '/'.join(
        part
        for part in (S3_BUCKET_FOLDER.strip('/'), 'hierarchical-clustering', model_file_name)
        if part
    )

    try:
        tmp_file = str(Path(TMP_DIR) / model_file_name)
        joblib.dump(model, tmp_file)

        s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY_ID, aws_secret_access_key=S3_SECRET_ACCESS_KEY)
        s3.upload_file(
            Bucket=S3_BUCKET_NAME,
            Key=key,
            Filename=tmp_file,
            ExtraArgs={
                'Metadata': {
                    'author': AUTHOR,
                    # Recorded in UTC. The former trailing .astimezone() call
                    # converted this timestamp back to the host's local zone.
                    'date': datetime.now(timezone.utc).isoformat(),
                    'training_dataset_key': IMPORTED_DATASET_S3_KEY,
                },
            },
        )

        logging.info(f'successfully pushed model as: {key}')
    except Exception as err:
        # logging.fatal only logs; re-raise so a failed push is not mistaken
        # for success.
        logging.fatal(f'failed to push model {key}: {err}')
        raise
