In [1]:
import pandas as pd
from pandera import DataFrameModel, Field, DateTime
from pandera.typing import DataFrame, Series

from prefect import flow, task
from prefect.blocks.system import Secret

from catnip.fla_redshift import FLA_Redshift
from catnip.fla_sharepoint import FLA_Sharepoint

from typing import Dict
from datetime import datetime, date

import numpy as np
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
from scipy import optimize
from sklearn.preprocessing import StandardScaler
from concurrent.futures import ThreadPoolExecutor

In [2]:
def get_redshift_credentials() -> Dict:
    """Assemble the keyword arguments used to construct ``FLA_Redshift``.

    Every sensitive value is read from a Prefect ``Secret`` block; the port,
    S3 subdirectory and verbosity flag are fixed literals.

    Returns:
        Dict: connection settings (``dbname``, ``host``, ``port``, ``user``,
        ``password``), S3 settings (``aws_access_key_id``,
        ``aws_secret_access_key``, ``bucket``, ``subdirectory``) and
        ``verbose``.
    """

    def _secret(block_name: str):
        # Load the named Prefect Secret block and return its stored value.
        return Secret.load(block_name).get()

    return {
        # Warehouse connection
        "dbname": _secret("stellar-redshift-db-name"),
        "host": _secret("stellar-redshift-host"),
        "port": 5439,
        "user": _secret("stellar-redshift-user-name"),
        "password": _secret("stellar-redshift-password"),

        # S3 settings
        "aws_access_key_id": _secret("fla-s3-aws-access-key-id-east-1"),
        "aws_secret_access_key": _secret("fla-s3-aws-secret-access-key-east-1"),
        "bucket": _secret("fla-s3-bucket-name-east-1"),
        "subdirectory": "us-east-1",

        "verbose": False,
    }

# Fetch the credentials on a single worker thread and block until they are ready.
# NOTE(review): presumably the extra thread keeps the Prefect secret lookups off
# the notebook's own thread/event loop -- confirm before simplifying to a direct call.
with ThreadPoolExecutor(max_workers=1) as pool:
    rs_creds = pool.submit(get_redshift_credentials).result()

In [9]:
# Exploratory query: every column of the game-description table for season '2023-24'.
# NOTE(review): `q` is reassigned by a later cell with the full clustering query.
q = "select * from custom.cth_game_descriptions where season = '2023-24'"

In [11]:
# Run the SQL currently held in `q` through FLA_Redshift, passing the credential
# dict as keyword arguments. Presumably returns a pandas DataFrame -- confirm
# against the catnip library.
df = FLA_Redshift(**rs_creds).query_warehouse(sql_string = q)

In [38]:
# SQL that builds one row per game date with the features used for clustering.
#
# CTEs:
#   nightly    - count of non-comp 'Singles'/'Flex' tickets per event_date, from
#                cth_historical_all_1718_2223 (season '2020-21' excluded) unioned
#                with cth_v_ticket_2324.
#   atp        - sum(gross_revenue)/count(*) per event_date from the same two
#                sources (no is_comp / ticket_type filter is applied here).
#   attendance - sum(did_attend) per event_date from the historical table, unioned
#                with a per-date row count from cth_v_attendance_2324.
#   agg        - nightly LEFT JOINed to atp and attendance, restricted to event
#                dates earlier than GETDATE() - 1.
#
# The final SELECT joins agg to custom.cth_game_descriptions on the calendar date
# and keeps only season = '2023-24'.
#
# NOTE(review): `season` is not a column of agg, so it presumably comes from
# cth_game_descriptions; filtering on it in WHERE discards agg rows with no
# matching description, so the last LEFT JOIN behaves like an inner join --
# confirm this is intended.
# NOTE(review): the ORDER BY sits inside the agg CTE; do not rely on it for the
# row order of the final result -- confirm if ordering matters downstream.
q = """
        WITH nightly AS (
            SELECT
                event_date,
                count(*) AS nightly_tickets
            FROM
                custom.cth_historical_all_1718_2223
            WHERE
                season != '2020-21'
                AND is_comp = FALSE
                AND ticket_type IN ('Singles', 'Flex')
            GROUP BY
                event_date
            UNION
            SELECT
                date(event_datetime) as event_date,
                count(*) AS nightly_tickets
            FROM
                custom.cth_v_ticket_2324
            WHERE
                is_comp = FALSE
                AND ticket_type IN ('Singles', 'Flex')
            GROUP BY
                event_date
        ),
        atp AS (
            SELECT
                event_date,
                sum(gross_revenue)/count(*) AS atp
            FROM
                custom.cth_historical_all_1718_2223
            WHERE
                season != '2020-21'
            GROUP BY
                event_date
            UNION
            SELECT
                date(event_datetime) as event_date,
                sum(gross_revenue)/count(*) AS atp
            FROM
                custom.cth_v_ticket_2324
            GROUP BY
                event_date
        ),
        attendance AS (
            SELECT
                event_date,
                sum(did_attend) AS attendance
            FROM
                custom.cth_historical_all_1718_2223
            GROUP BY
                event_date
            UNION
            SELECT
                date(event_datetime) as event_date,
                count(*) AS attendance
            FROM
                custom.cth_v_attendance_2324
            GROUP BY
                event_date
        ),
        agg as
            (SELECT
                n.event_date,
                n.nightly_tickets,
                atp.atp,
                att.attendance
            FROM
                nightly n
            LEFT JOIN
                atp ON n.event_date = atp.event_date
            LEFT JOIN
                attendance att ON n.event_date = att.event_date
            WHERE
                n.event_date < (GETDATE() - 1)
            ORDER BY
                n.event_date)
        SELECT
            agg.*, week_day, trimester, original_six_plus_extra, is_dense
        FROM
            agg
        LEFT JOIN
            custom.cth_game_descriptions on date(agg.event_date) = date(cth_game_descriptions.event_date)
        WHERE
            season = '2023-24'
    """

In [45]:
# Execute the clustering-feature query held in `q`; this replaces the `df`
# produced by the earlier exploratory fetch.
df = FLA_Redshift(**rs_creds).query_warehouse(sql_string = q)

In [46]:
# Columns fed to the clustering step: game descriptors plus ticketing metrics.
feature_columns = [
    'week_day',
    'trimester',
    'original_six_plus_extra',
    'is_dense',
    'nightly_tickets',
    'atp',
    'attendance',
]
df_clust = df[feature_columns]

# Plain ndarray copy of the feature frame, then standardised with
# StandardScaler so the distance-based clustering is not driven by raw units.
x = np.array(df_clust)
X = StandardScaler().fit_transform(x)

In [49]:
def get_even_clusters(X: np.ndarray, cluster_size: int) -> np.ndarray:
    """Assign rows of ``X`` to clusters holding at most ``cluster_size`` rows each.

    K-means is fitted with ``ceil(len(X) / cluster_size)`` centres. Each centre
    is then duplicated ``cluster_size`` times to form "slots", and rows are
    matched one-to-one to slots by minimising total distance (linear sum
    assignment). A row's cluster is the centre its slot was copied from.

    Args:
        X: 2-D feature matrix, one row per observation.
        cluster_size: maximum number of rows per cluster; must be >= 1.

    Returns:
        1-D integer array of length ``len(X)`` with the cluster label of each
        row, in the same order as the rows of ``X``.

    Raises:
        ValueError: if ``cluster_size`` is smaller than 1.
    """
    if cluster_size < 1:
        raise ValueError("cluster_size must be a positive integer")

    n_clusters = int(np.ceil(len(X) / cluster_size))

    # n_init is set explicitly: it keeps the value the fit was already using
    # (10) and silences the FutureWarning about the default changing.
    kmeans = KMeans(n_clusters, n_init=10, random_state=1693)
    kmeans.fit(X)

    # One slot per available seat: every centre repeated cluster_size times,
    # copies of the same centre kept adjacent so slot // cluster_size recovers
    # the centre index.
    slots = np.repeat(kmeans.cluster_centers_, cluster_size, axis=0)

    distance_matrix = cdist(X, slots)

    # There are at least as many slots as rows, so every row receives exactly
    # one slot; the column indices come back ordered by row.
    assigned_slots = optimize.linear_sum_assignment(distance_matrix)[1]
    clusters = assigned_slots // cluster_size

    return clusters

  super()._check_params_vs_input(X, default_n_init=10)


In [54]:
# Label each game with a cluster of at most 8 games. X was built from df's rows
# in order, so the returned labels line up positionally with df.
df['cluster'] = get_even_clusters(X, 8)

  super()._check_params_vs_input(X, default_n_init=10)


In [55]:
# Dissolve undersized clusters (fewer than 5 games): each of their games is
# moved to the remaining cluster whose mean ATP is closest to that game's ATP.
for i in df['cluster'].unique():

    # Count with a boolean mask instead of value_counts()[i]: label 0 can be
    # renamed away by the relabel step of an earlier iteration, after which
    # value_counts()[0] would raise KeyError when the loop reaches i == 0.
    cluster_size = int((df['cluster'] == i).sum())
    if cluster_size == 0 or cluster_size >= 5:
        continue

    # Mean ATP of every cluster except the one being dissolved, computed from
    # the current (possibly already modified) labels.
    mean_atp_df = df[['cluster', 'atp']].groupby(['cluster']).mean()
    mean_atp_df = mean_atp_df.loc[~mean_atp_df.index.isin([i])]

    if mean_atp_df.empty:
        # No other cluster exists to absorb these games; leave them in place.
        continue

    for index, row in df[df['cluster'] == i].iterrows():

        # Label of the cluster with the smallest |mean ATP - this game's ATP|.
        df_closest = mean_atp_df.iloc[(mean_atp_df['atp']-row['atp']).abs().argsort()[:1]].index
        df.loc[df['event_date'] == row['event_date'], 'cluster'] = df_closest[0]

    if i != 0:

        # Rename cluster 0 to the label that was just freed.
        # NOTE(review): presumably done to reuse the vacated label so that 0 is
        # no longer used -- confirm the intent.
        df.loc[df['cluster'] == 0, 'cluster'] = i


## select cols
df

Unnamed: 0,event_date,nightly_tickets,atp,attendance,week_day,trimester,original_six_plus_extra,is_dense,cluster
0,2024-01-17,3175,64.037092,16009.0,4,2,0.75,1,4
1,2024-02-06,2161,61.352928,14635.0,3,3,1.0,0,3
2,2023-12-30,1362,137.124192,17906.0,7,2,0.75,0,0
3,2024-02-20,3387,50.582824,16623.0,3,3,0.0,0,2
4,2023-11-12,1670,65.243778,16221.0,1,1,1.0,0,1
5,2023-12-06,1683,46.431952,14984.0,4,2,0.0,0,1
6,2024-01-24,1889,48.669586,13849.0,4,2,0.0,0,1
7,2023-11-10,1301,62.917105,15450.0,6,1,0.0,0,1
8,2023-10-19,1187,89.807953,15400.0,5,1,1.0,0,3
9,2023-12-23,1977,95.145841,15615.0,7,2,0.75,0,0
