In [None]:
# Input file locations. They are empty placeholders and must be filled in
# before running the notebook.

# Mined patterns written by GERM ("t" header, "v" node and "e" edge records).
germ_output_path = ""
# CSV with "label" and "idx" columns mapping node labels to category names.
node_label_mapping_path = ""
# CSV of graph nodes (uses at least "id", "idx", "proposal_timestamp").
nodes_path = ""
# CSV of graph edges (uses at least "src_id", "dst_id", "due_date").
edges_path = ""
# Text file listing, per rule, the debtor node indices it covers.
debtors_path = ""
# CSV of insolvency records (uses at least "insolvency_id", "region").
insolvency_data_path = ""

In [None]:
from collections import defaultdict
import os
import sys
import textwrap

import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from dateutil.relativedelta import relativedelta
from sklearn.cluster import SpectralClustering
from math import ceil
from numpy import trapz
from scipy import stats
from tqdm import tqdm
from IPython.display import display

# Use seaborn's "bright" palette; the sns.color_palette()[i] look-ups in the
# plotting helpers index into this palette.
sns.set_palette("bright")
# All file paths used later are resolved relative to the experiments home
# directory (os.environ[...] raises KeyError if EXPERIMENTS_HOMEDIR is unset).
os.chdir(os.environ["EXPERIMENTS_HOMEDIR"])

%matplotlib inline

# Population of each Czech region. It is the baseline for the
# population-based chi-square test and the population_percentage columns
# in the regional statistics.
population_per_region = {
    'Jihomoravský kraj': 1184568,
    'Jihočeský kraj': 637047,
    'Karlovarský kraj': 283210,
    'Kraj Vysočina': 504025,
    'Královéhradecký kraj': 542583,
    'Liberecký kraj': 437570,
    'Moravskoslezský kraj': 1177989,
    'Olomoucký kraj': 622930,
    'Pardubický kraj': 514518,
    'Plzeňský kraj': 578707,
    'Praha': 1275406,
    'Středočeský kraj': 1386824,
    'Zlínský kraj': 572432,
    'Ústecký kraj': 798898
}

# Region names in a fixed order: the dict's insertion order. They are used
# as the per-region columns of the rules table.
region_names = list(population_per_region)

In [None]:
# Map each node label (the "label" column) to a display name taken from the
# "idx" column: title-cased, underscores turned into spaces, and a few names
# replaced by shorter or more descriptive plot labels.
display_name_overrides = {"Debtor": "Other", "Nonbanking": "Lender", "Utilities": "Util."}
node_label_2_name = {}
for node_label, raw_name in pd.read_csv(node_label_mapping_path)[["label", "idx"]].values:
    pretty_name = raw_name.title().replace("_", " ")
    node_label_2_name[node_label] = display_name_overrides.get(pretty_name, pretty_name)
print(node_label_2_name)

In [None]:
def _value_label_to_str(value_label):
    return {
        "0": "0–11%",
        "1": "11–31%",
        "2": "31–52%",
        "3": "52–79%",
        "4": "79–100%",
    }[value_label]

def render_graph(g, node_labels=None, edge_labels=None):
    """Draw a pattern graph with networkx and show the figure.

    The first node in the graph's node order is drawn in palette colour 2;
    all other nodes use palette colour 1.

    Args:
        g: networkx graph. Edges carry "label" and "value_label" attributes
            and nodes carry "label" (used only for the default labels).
        node_labels: optional mapping node -> text. When falsy, each node is
            labelled with ``node_label_2_name[int(node "label")]``.
        edge_labels: optional mapping (u, v) -> text. When falsy, each edge is
            labelled "<label>Y" plus a newline and the value range from
            ``_value_label_to_str``.
    """
    pos = nx.spring_layout(g, scale=0.7)
    palette = sns.color_palette()
    node_colors = [palette[2]] + [palette[1]] * (len(list(g)) - 1)
    nx.draw(g, pos, arrowsize=18, node_color=node_colors, node_size=3500, edge_color="grey", width=3)

    if not edge_labels:
        edge_labels = {
            (src, dst): attrs["label"] + "Y" + "\n" + _value_label_to_str(attrs["value_label"])
            for src, dst, attrs in g.edges(data=True)
        }
    nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels, font_size=15)

    if not node_labels:
        node_labels = {
            node: node_label_2_name[int(attrs["label"])]
            for node, attrs in g.nodes(data=True)
        }
    nx.draw_networkx_labels(g, pos, labels=node_labels, font_size=16)
    plt.show()
    
def calculate_confidence(graph, graphs):
    """Confidence of a mined pattern relative to its "body" pattern.

    The body is the subgraph made of every edge whose integer ``label`` is
    strictly smaller than the largest edge label in ``graph``. The function
    searches ``graphs`` for the first pattern isomorphic to that body. Edges
    must agree on ``label`` and ``value_label``, and nodes on ``label``.

    Args:
        graph: dict with keys "graph" (networkx graph) and "support".
        graphs: list of such dicts to search.

    Returns:
        ``graph["support"] / match["support"]`` for the first match, or 0 when
        no pattern matches.
    """
    all_edges = list(graph["graph"].edges(data=True))
    top_label = max(int(attrs["label"]) for _, _, attrs in all_edges)
    body_edges = [(src, dst) for src, dst, attrs in all_edges if int(attrs["label"]) < top_label]
    body = graph["graph"].edge_subgraph(body_edges)

    def edges_match(first, second):
        return first["label"] == second["label"] and first["value_label"] == second["value_label"]

    def nodes_match(first, second):
        return first["label"] == second["label"]

    for candidate in graphs:
        if nx.is_isomorphic(candidate["graph"], body, edge_match=edges_match, node_match=nodes_match):
            return graph["support"] / candidate["support"]
    return 0

def load_projections(path, max_num_projections=sys.maxsize):
    """Parse a projections file into ``{graph_number: [nx.DiGraph, ...]}``.

    Line records handled:
      * ``t ...``: starts the section of a pattern. The graph number is the
        third space-separated token.
      * ``p ...``: starts a new projection of the current pattern.
      * ``e <from> <to> <label> <value_label> ...``: adds an edge to the
        current projection.

    At most ``max_num_projections`` projections are kept per pattern; later
    ones are skipped.

    Returns:
        ``defaultdict(list)`` mapping graph number to its list of projections.
    """
    projections = defaultdict(list)
    graph_number = None
    current_projection = None

    def _flush(projection):
        # Store a finished projection under the current pattern, respecting the cap.
        if projection is not None and len(projections[graph_number]) < max_num_projections:
            projections[graph_number].append(projection)

    with open(path) as projections_file:
        for line in projections_file:
            if line.startswith("t"):
                # Close the last projection of the previous pattern before switching,
                # otherwise it would be attributed to the next pattern.
                _flush(current_projection)
                current_projection = None
                graph_number = int(line.split(" ")[2])
            elif line.startswith("p"):
                _flush(current_projection)
                current_projection = None
                if len(projections[graph_number]) < max_num_projections:
                    current_projection = nx.DiGraph()
            elif line.startswith("e") and current_projection is not None:
                # split() rather than split(" ") so a trailing newline never ends up in value_label.
                from_, to, label, value_label = line.split()[1:5]
                current_projection.add_edge(int(from_), int(to), label=int(label), value_label=value_label)
    # The last projection in the file has no following "t"/"p" record to close it.
    _flush(current_projection)
    return projections

def _get_diff_months(end_date, min_date):
    """Whole months from ``min_date`` to ``end_date``.

    The years part of the ``relativedelta`` is folded into months; any
    remaining days are ignored.
    """
    diff = relativedelta(end_date, min_date)
    return 12 * diff.years + diff.months

In [None]:
def load_rule_2_debtors(debtors_path):
    """Read the debtor node indices covered by each rule.

    Expected file layout: a header line starting with "t" whose last
    whitespace-separated token is the rule (graph) number, followed by one
    integer debtor index per line, until the next header. Blank lines are
    ignored.

    Args:
        debtors_path: path of the debtors file.

    Returns:
        dict mapping rule number -> list of debtor indices. Rules with no
        debtor lines map to an empty list.

    Raises:
        ValueError: if a debtor line appears before the first header.
    """
    rule_2_debtors = {}
    graph_number = None
    debtors_num_agg = []
    with open(debtors_path) as debtors_file:
        for line in debtors_file:
            stripped = line.strip()
            if not stripped:
                continue
            if stripped.startswith("t"):
                # Store the previous rule even if it had no debtors, so every
                # header yields an entry.
                if graph_number is not None:
                    rule_2_debtors[graph_number] = debtors_num_agg
                debtors_num_agg = []
                graph_number = int(stripped.split()[-1])
            else:
                if graph_number is None:
                    raise ValueError(
                        f"debtor line {stripped!r} appears before any 't' header in {debtors_path}"
                    )
                debtors_num_agg.append(int(stripped))
    if graph_number is not None:
        rule_2_debtors[graph_number] = debtors_num_agg
    return rule_2_debtors

def get_months_to_default(debtors):
    """Months between each debtor's earliest due debt and its insolvency proposal.

    For the given debtor node indices, takes each debtor's outgoing edge with
    the earliest ``due_date`` and computes ``proposal_timestamp - due_date``
    in average-length months, truncated to int. Debtors without edges are
    dropped by the inner merges. Reads the module-level ``nodes_df`` and
    ``edges_df``.

    Args:
        debtors: iterable of debtor node indices (``idx`` in ``nodes_df``).

    Returns:
        Series named "months_to_default", one entry per debtor, in groupby
        (``src_id``) order.
    """
    # Average Gregorian month (365.2425 / 12 days), the length that
    # np.timedelta64(1, "M") stood for. pandas >= 2.0 rejects the ambiguous
    # "M" unit, so the length is spelled out explicitly.
    average_month = pd.Timedelta(days=365.2425 / 12)
    debtors_df = pd.DataFrame(debtors, columns=["idx"])
    selected_nodes_df = nodes_df.merge(debtors_df, on="idx")
    selected_edges_df = edges_df.merge(selected_nodes_df[["idx", "id", "proposal_timestamp"]], left_on="src_id", right_on="id")
    selected_edges_df = selected_edges_df.sort_values(by="due_date")
    first_default_df = selected_edges_df.loc[selected_edges_df.groupby("src_id").due_date.idxmin()]
    # .assign builds a new frame instead of writing into a .loc-derived one
    # (avoids SettingWithCopyWarning).
    first_default_df = first_default_df.assign(
        months_to_default=(
            (first_default_df.proposal_timestamp - first_default_df.due_date) / average_month
        ).astype(int)
    )
    return first_default_df["months_to_default"]

def plot_months_to_default(m2d, pretty):
    """Histogram of months-to-default with the median (MTTB) marked.

    Args:
        m2d: Series of months-to-default values, one per debtor.
        pretty: when True the title is omitted.
    """
    median_months = m2d.median()
    ax = m2d.hist(bins=20, figsize=(8, 3), range=(0, 250))
    if not pretty:
        ax.set_title(f"Time to bankruptcy (MTTB={round(median_months,3)}) months")
    highlight = sns.color_palette()[3]
    plt.axvline(x=median_months, color=highlight, linewidth=4, linestyle="--")

    # Place the MTTB label just right of the median line, a quarter of the way
    # from the third-to-last to the second-to-last y tick.
    yticks = ax.get_yticks()
    label_bottom, label_top = yticks[-3], yticks[-2]
    label_y = label_bottom + (label_top - label_bottom) / 4
    plt.text(median_months + 5, label_y, f"MTTB={int(median_months)}", rotation=0, color=highlight, fontsize=16)

    plt.xticks(fontsize=22)
    plt.yticks(fontsize=22)
    plt.xlabel("Months", fontsize=22, labelpad=6)
    plt.ylabel("# of debtors", fontsize=22, labelpad=5)
    plt.show()
    
def get_debt_accumulation_series(debtors, num_months=130):
    """Per-debtor cumulative debt curves over the first ``num_months`` months.

    Month offsets are counted from each debtor's smallest ``label_monthly``.
    ``value_percentage`` is summed per offset and accumulated. Offsets with
    no new debt are forward-filled from the previous month. Reads the
    module-level ``nodes_df`` and ``edges_df``.

    Args:
        debtors: iterable of debtor node indices (``idx`` in ``nodes_df``).
        num_months: length of the returned horizon (default 130, the value
            previously hard-coded).

    Returns:
        DataFrame indexed by month offset ``0 .. num_months - 1`` with one
        column per debtor ``src_id``.
    """
    debtors_df = pd.DataFrame(debtors, columns=["idx"])
    selected_nodes_df = nodes_df.merge(debtors_df, on="idx")
    selected_edges_df = edges_df.merge(selected_nodes_df[["idx", "id", "proposal_timestamp"]], left_on="src_id", right_on="id")
    selected_edges_df = selected_edges_df.sort_values(by="due_date")

    # "min" (string) instead of the builtin: passing builtins to agg is
    # deprecated in pandas >= 2.1.
    min_quarter_df = selected_edges_df.groupby("src_id").agg(min_quarter=("label_monthly", "min")).reset_index()
    selected_edges_df = selected_edges_df.merge(min_quarter_df, on="src_id")
    # Month offset from the debtor's first month (the column keeps its historical name "quarter").
    selected_edges_df["quarter"] = selected_edges_df["label_monthly"] - selected_edges_df["min_quarter"]

    sums_df = selected_edges_df.groupby(["src_id", "quarter"])[["value_percentage"]].sum()
    cum_sums_df = sums_df.reset_index().pivot_table(values="value_percentage", index="src_id", columns="quarter").cumsum(axis=1)
    # Rows = month offsets. Reindex by label so missing offsets are filled from
    # the previous month and the horizon is exactly 0..num_months-1.
    # Positional slicing would misalign months when offsets have gaps.
    cum_sums_df = cum_sums_df.T.reindex(range(num_months)).ffill()

    assert cum_sums_df.shape[0] == num_months

    return cum_sums_df

def plot_debt_accumulation(stats_df, pretty, mttb):
    """Plot the accumulated-debt curve across debtors.

    Args:
        stats_df: DataFrame with one row per month offset and one column per
            debtor (the shape returned by ``get_debt_accumulation_series``).
        pretty: when True, draw only the median curve and omit the title.
            Otherwise draw the mean (with ``std`` as error bars) plus the
            median.
        mttb: median months to bankruptcy. It is drawn as a labelled vertical
            line when it is a number; ``None`` or NaN skip the line.
    """
    chart_df = pd.DataFrame({
        "Mean": stats_df.mean(axis=1),
        "Median": stats_df.median(axis=1),
        "std": stats_df.std(axis=1),
    })
    if pretty:
        ax = chart_df[["Median"]].plot(figsize=(8, 3), linewidth=4)
    else:
        ax = chart_df[["Mean", "std"]].plot(figsize=(8, 3), yerr="std", capsize=3)
        chart_df[["Median"]].plot(ax=ax, linewidth=4)

    # Explicit check: `if mttb:` would hide a legitimate MTTB of 0, and
    # int(NaN) below would raise for an empty rule.
    if mttb is not None and not pd.isna(mttb):
        highlight = sns.color_palette()[3]
        plt.axvline(x=mttb, color=highlight, linewidth=4, linestyle="--")
        # Put the label left of the line when the line is close to the right edge.
        label_x = mttb - 33 if mttb > ax.get_xlim()[1] - 10 else mttb + 3
        plt.text(label_x, 66, f"MTTB={int(mttb)}", rotation=0, color=highlight, fontsize=18)

    plt.xticks(fontsize=22)
    plt.yticks(fontsize=22)
    if not pretty:
        ax.set_title("Debt accumulation")
    ax.set_ylim(0, 105)
    ax.set_xlabel("Months", fontsize=22, labelpad=6)
    ax.set_ylabel("Accumulated debt (%)", fontsize=18, labelpad=5)
    ax.get_legend().remove()
    plt.show()
    
def get_total_debt(debtors):
    """Sum of ``value`` over all outgoing edges of each given debtor.

    Reads the module-level ``nodes_df`` and ``edges_df``.

    Returns:
        DataFrame with columns ``src_id`` and ``value``.
    """
    wanted_df = pd.DataFrame(debtors, columns=["idx"])
    debtor_nodes_df = nodes_df.merge(wanted_df, on="idx")[["idx", "id", "proposal_timestamp"]]
    debtor_edges_df = edges_df.merge(debtor_nodes_df, left_on="src_id", right_on="id")
    totals_df = debtor_edges_df.groupby("src_id")[["value"]].sum()
    return totals_df.reset_index()
    
def get_nodes_for_rule(gid):
    """Rows of ``nodes_df`` for the debtors covered by rule number ``gid``."""
    return nodes_df.merge(pd.DataFrame(rule_2_debtors[gid], columns=["idx"]), on="idx")

def get_categories_for_rule(gid):
    """Count the outgoing edges of rule ``gid``'s debtors per ``category``.

    Each debtor edge is joined to the node at its ``dst_id``. Rows are then
    counted per ``category`` (presumably the destination node's category;
    confirm against the nodes/edges schema).

    Returns:
        DataFrame indexed by ``category`` with columns ``count`` and
        ``percentage`` (share of all counted edges, in %), rounded to one
        decimal.
    """
    counts_df = edges_df.merge(
        get_nodes_for_rule(gid)[["id"]], left_on="src_id", right_on="id"
    ).merge(
        nodes_df, left_on="dst_id", right_on="id"
    ).groupby("category").count()[["src_id"]]
    stats_df = counts_df.assign(perc=counts_df["src_id"] / counts_df["src_id"].sum() * 100)
    # rename returns a new frame; without the assignment it was a no-op.
    stats_df = stats_df.rename(columns={"src_id": "count", "perc": "percentage"})
    return stats_df.round(1)

def get_regional_stats_for_rule(gid):
    """Per-region breakdown of rule ``gid``'s debtors against two baselines.

    Result columns (one row per region present in both the rule data and
    ``region_num_ins_df``; values rounded to 3 decimals):
      * rule_count: distinct debtors of the rule in the region.
      * rule_percentage: rule_count as % of the rule's total.
      * region_num_ins / region_perc_ins: insolvency baseline from ``region_num_ins_df``.
      * rule_percentage_vs_exp_ratio: rule_percentage / region_perc_ins.
      * rule_percentage_per_region_ins: rule_count / region_num_ins * 100.
      * population_size / population_percentage: population baseline.
      * rule_percentage_vs_pop_exp_ratio: rule_percentage / population_percentage.
    """
    region_stats_df = (
        get_nodes_for_rule(gid)
        .merge(edges_df, left_on="id", right_on="src_id")
        .merge(insolvency_data_df, on="insolvency_id")
        .groupby("region")
        # The joins above produce one row per edge. Count each debtor once so
        # the counts are comparable with per-region insolvency counts (and the
        # chi-square tables built from them).
        .agg(rule_count=("id", "nunique"))
    )
    region_stats_df["rule_percentage"] = region_stats_df["rule_count"] / region_stats_df["rule_count"].sum() * 100
    region_stats_df = region_stats_df.merge(region_num_ins_df, on="region")
    region_stats_df["rule_percentage_vs_exp_ratio"] = region_stats_df["rule_percentage"] / region_stats_df["region_perc_ins"]
    region_stats_df["rule_percentage_per_region_ins"] = (region_stats_df["rule_count"] / region_stats_df["region_num_ins"]) * 100
    population_df = pd.DataFrame(list(population_per_region.items()), columns=["region", "population_size"])
    region_stats_df = region_stats_df.merge(population_df, on="region")
    region_stats_df["population_percentage"] = region_stats_df["population_size"] / region_stats_df["population_size"].sum() * 100
    region_stats_df["rule_percentage_vs_pop_exp_ratio"] = region_stats_df["rule_percentage"] / region_stats_df["population_percentage"]
    return region_stats_df.round(3)

def analyze_rule(gid, pretty=False):
    """Print, plot and tabulate the statistics of one mined rule.

    Args:
        gid: the rule's graph number. This is the "number" field of an entry
            in ``graphs`` and the key of ``rule_2_debtors``, which is what all
            callers pass.
        pretty: forwarded to the plotting helpers (figures without titles).

    Raises:
        KeyError: if no mined rule has number ``gid``.
    """
    # Look the rule up by its number. ``graphs[gid]`` would index by list
    # position and select a different rule whenever the two differ.
    matches = [graph for graph in graphs if graph["number"] == gid]
    if not matches:
        raise KeyError(f"no mined rule with number {gid}")
    g = matches[0]
    print("Number: " + str(g["number"]))
    print("Support: " + str(g["support"]))
    print("Confidence: " + str(round(g["confidence"]*100, 2)) + " %")
    debtors = rule_2_debtors[g["number"]]
    months_to_default = get_months_to_default(debtors)
    accumulation_series_df = get_debt_accumulation_series(debtors)

    # Area under each debtor's accumulation curve. The zip below assumes both
    # series are in src_id order (groupby / pivot_table sort keys); TODO confirm
    # if those helpers change.
    individual_aucs = pd.Series([trapz(curve) for _, curve in accumulation_series_df.T.iterrows()])
    corr_df = pd.DataFrame(zip(months_to_default, individual_aucs), columns=["months_to_default", "auc"])
    corr = round(corr_df['months_to_default'].corr(corr_df['auc']), 5)
    print(f"Correlation: {corr}")

    render_graph(g["graph"])
    plot_months_to_default(months_to_default, pretty)
    plot_debt_accumulation(accumulation_series_df, pretty, months_to_default.median())
    regional_stats_df = get_regional_stats_for_rule(g["number"])
    print("Regional stats")
    _, p_value_ins, _, _ = stats.chi2_contingency(regional_stats_df[["rule_count", "region_num_ins"]])
    print(f"Insolvency count dependence test: chi_square_indepdence={p_value_ins > 0.05}, p-value={p_value_ins}")
    _, p_value_pop, _, _ = stats.chi2_contingency(regional_stats_df[["rule_count", "population_size"]])
    print(f"Population count dependence test: chi_square_indepdence={p_value_pop > 0.05}, p-value={p_value_pop}")
    display(regional_stats_df.transpose())

In [None]:
# Nodes: string ids (so they compare equal to the edge endpoints) and parsed
# insolvency proposal dates.
nodes_df = pd.read_csv(nodes_path)
nodes_df["id"] = nodes_df["id"].astype(str)
nodes_df["proposal_timestamp"] = pd.to_datetime(nodes_df["proposal_timestamp"])

# Edges: parsed due dates and string endpoint ids.
edges_df = pd.read_csv(edges_path)
edges_df["due_date"] = pd.to_datetime(edges_df["due_date"])
edges_df = edges_df.astype({"src_id": str, "dst_id": str})

insolvency_data_df = pd.read_csv(insolvency_data_path)

In [None]:
# Insolvency baseline per region: number of insolvencies whose insolvency_id
# ends in a year from 2020 to 2022 (inclusive), and each region's share of
# them in %.
region_num_ins_df = (
    insolvency_data_df
    .loc[lambda df: df["insolvency_id"].apply(lambda iid: 2020 <= int(iid[-4:]) <= 2022)]
    .groupby("region")
    .agg(region_num_ins=("insolvency_id", "count"))
    .reset_index()
)
region_num_ins_df = region_num_ins_df.assign(
    region_perc_ins=region_num_ins_df["region_num_ins"] / region_num_ins_df["region_num_ins"].sum() * 100
)

In [None]:
# Load the mined patterns from the GERM output. Each pattern starts with a
# "t" header whose last two tokens are its number and support, followed by
# "v <node> <label>" and "e <src> <dst> <label> <value_label>" records.
tmp_g = None
support = None
graph_number = None
graphs = []
with open(germ_output_path) as germ_file:
    for line in germ_file:
        if line.startswith("t"):
            # A new header closes the previous pattern. networkx graphs are
            # falsy when they have no nodes, so empty patterns are skipped.
            if tmp_g:
                graphs.append({"graph": tmp_g, "support": support, "number": graph_number})
            tmp_g = nx.DiGraph()
            support = int(line.split(" ")[-1])
            graph_number = int(line.split(" ")[-2])
            continue
        if line.startswith("v"):
            parsed = line.strip().split(" ")
            tmp_g.add_node(parsed[1], label=parsed[2])
        if line.startswith("e"):
            parsed = line.strip().split(" ")
            tmp_g.add_edge(parsed[1], parsed[2], label=parsed[3], value_label=parsed[4])
# The last pattern has no following header, so store it explicitly;
# previously it was silently dropped.
if tmp_g:
    graphs.append({"graph": tmp_g, "support": support, "number": graph_number})

for g in graphs:
    g["confidence"] = calculate_confidence(g, graphs)

In [None]:
# Debtor node indices covered by each rule, keyed by rule (graph) number.
rule_2_debtors = load_rule_2_debtors(debtors_path)

# Union of the debtors over all rules. The message below is Slovak for
# "number of nodes covered by at least one rule".
all_debtors = set().union(*rule_2_debtors.values())
print(f"Pocet uzlov pokrytych aspon jednym pravidlom: {len(all_debtors)}")

# Rule analysis

In [None]:
# One summary row per mined rule (gid = the rule's graph number).
# Each rule's debtor set is built once here rather than twice per pair inside
# the pairwise overlap loop.
debtor_sets = {graph["number"]: set(rule_2_debtors[graph["number"]]) for graph in graphs}

rules = []
for row in tqdm(graphs):
    debtors = rule_2_debtors[row["number"]]
    months_to_default = get_months_to_default(debtors)
    accumulation_series_df = get_debt_accumulation_series(debtors)
    total_debt_df = get_total_debt(debtors)

    # Mean Jaccard overlap with every *other* rule. The rule itself (always 1)
    # is excluded so the value matches its name. Rankings are unchanged: the
    # old value was a monotone function of the same sum.
    debtors_set = debtor_sets[row["number"]]
    other_sets = [debtor_sets[other["number"]] for other in graphs if other is not row]
    sum_jaccard = sum(len(debtors_set & other_set) / len(debtors_set | other_set) for other_set in other_sets)
    avg_overlap_with_others = sum_jaccard / len(other_sets) if other_sets else 0.0

    max_rule_label = max(int(attrs["label"]) for _, _, attrs in row["graph"].edges(data=True))

    # Per-debtor area under the debt-accumulation curve and its correlation with months to default.
    individual_aucs = pd.Series([trapz(curve) for _, curve in accumulation_series_df.T.iterrows()])
    corr_df = pd.DataFrame(zip(months_to_default, individual_aucs), columns=["months_to_default", "auc"])
    corr = round(corr_df['months_to_default'].corr(corr_df['auc']), 5)

    regional_stats_df = get_regional_stats_for_rule(row["number"])
    _, p_value_ins, _, _ = stats.chi2_contingency(regional_stats_df[["rule_count", "region_num_ins"]])
    _, p_value_pop, _, _ = stats.chi2_contingency(regional_stats_df[["rule_count", "population_size"]])

    regional_stats = dict(regional_stats_df.set_index("region").rule_percentage_vs_exp_ratio.items())
    rules.append({
        "gid": row['number'],
        "support": row["support"],
        "confidence": row["confidence"],
        "months_to_default_mean": months_to_default.mean(),
        "months_to_default_median": months_to_default.median(),
        "aucs_mean": individual_aucs.mean(),
        "aucs_median": individual_aucs.median(),
        "months_to_default_and_auc_corr": corr,
        "total_debt_mean": total_debt_df.value.mean(),
        "total_debt_median": total_debt_df.value.median(),
        "avg_overlap_with_others": avg_overlap_with_others,
        "max_rule_label": max_rule_label,
        "p_value_ins_independence": p_value_ins,
        "p_value_pop_independence": p_value_pop,
        # Regions missing from the rule's regional stats get 0.
        **{region: regional_stats.get(region, 0) for region in region_names},
    })

rules_df = pd.DataFrame(rules)
rules_df["aucs_median_normalized"] = rules_df["aucs_median"] / rules_df["aucs_median"].max()
rules_df["aucs_mean_normalized"] = rules_df["aucs_mean"] / rules_df["aucs_mean"].max()

rules_df = rules_df[[
    'gid', 'support', 'confidence', 'months_to_default_mean',
    'months_to_default_median', 'aucs_mean', 'aucs_median',
    'aucs_median_normalized', 'aucs_mean_normalized',
    'months_to_default_and_auc_corr', 'total_debt_mean',
    'total_debt_median', 'avg_overlap_with_others', 'max_rule_label',
    'p_value_ins_independence', 'p_value_pop_independence'
] + region_names]

## Rule histograms

### Support

In [None]:
# Distribution of rule support; only rules with support below 1000 are shown.
support_below_1000 = rules_df.support[rules_df.support < 1000]
ax = support_below_1000.hist(bins=40, figsize=(8, 3))

plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
ax.set_xlabel("Support", fontsize=18, labelpad=6)
ax.set_ylabel("# rules", fontsize=18, labelpad=5)
plt.show()

### Confidence

In [None]:
# Distribution of rule confidence, in percent.
confidence_pct = rules_df.confidence * 100
ax = confidence_pct.hist(bins=40, figsize=(8, 3))
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
ax.set_xlabel("Confidence (%)", fontsize=18, labelpad=6)
ax.set_ylabel("# rules", fontsize=18, labelpad=5)
plt.show()

### MTTB

In [None]:
# Distribution of the per-rule median time to bankruptcy (MTTB).
mttb_per_rule = rules_df.months_to_default_median
ax = mttb_per_rule.hist(bins=40, figsize=(8, 3))
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
ax.set_xlabel("Median time to bankruptcy (MTTB) in months", fontsize=18, labelpad=6)
ax.set_ylabel("# rules", fontsize=18, labelpad=5)
plt.show()

### MDEBT

In [None]:
# Distribution of the per-rule median total debt (MDEBT), in millions.
# The factor 0.046 converts the stored debt values into the USD shown on the
# axis (presumably a CZK -> USD exchange rate; TODO confirm its date/source).
mdebt_usd_millions = rules_df.total_debt_median * 0.046 / 1e6
ax = mdebt_usd_millions.hist(bins=40, figsize=(8, 3))
ax.xaxis.set_major_formatter('{x:.2g}M')
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
ax.set_xlabel("Median total debt (MDEBT) in USD", fontsize=18, labelpad=6)
ax.set_ylabel("# rules", fontsize=18, labelpad=5)
plt.show()

# Rules sorted by different metrics

### Rules based on "months to default"

In [None]:
# Rules whose largest edge label is above 0, ordered by median months to default (ascending).
rules_view_df = rules_df.loc[lambda df: df["max_rule_label"] > 0].sort_values(by="months_to_default_median")
rules_view_df

In [None]:
# top rules: the first five rows of rules_view_df
for gid in rules_view_df.gid.head(5):
    analyze_rule(gid)

In [None]:
# bottom rules: the last five rows of rules_view_df
for gid in rules_view_df.gid.tail(5):
    analyze_rule(gid)

### Rules based on Debt AUC

In [None]:
# Rules whose largest edge label is above 0, ordered by normalised median debt AUC (ascending).
rules_view_df = rules_df.loc[lambda df: df["max_rule_label"] > 0].sort_values(by="aucs_median_normalized")
rules_view_df

In [None]:
# top rules: the first five rows of rules_view_df
for gid in rules_view_df.gid.head(5):
    analyze_rule(gid)

In [None]:
# bottom rules: the last five rows of rules_view_df
for gid in rules_view_df.gid.tail(5):
    analyze_rule(gid)

### Rules based on the correlation between debt AUC and months to default

In [None]:
# Rules whose largest edge label is above 0, ordered by the months-to-default / debt-AUC correlation (ascending).
rules_view_df = rules_df.loc[lambda df: df["max_rule_label"] > 0].sort_values(by="months_to_default_and_auc_corr")
rules_view_df

In [None]:
# top rules: the first five rows of rules_view_df
for gid in rules_view_df.gid.head(5):
    analyze_rule(gid)

In [None]:
# bottom rules: the last five rows of rules_view_df
for gid in rules_view_df.gid.tail(5):
    analyze_rule(gid)

### Rules based on MDEBT

In [None]:
# Rules whose largest edge label is above 0, ordered by median total debt (ascending).
rules_view_df = rules_df.loc[lambda df: df["max_rule_label"] > 0].sort_values(by="total_debt_median")
rules_view_df

In [None]:
# top rules: the first five rows of rules_view_df
for gid in rules_view_df.gid.head(5):
    analyze_rule(gid)

In [None]:
# bottom rules: the last five rows of rules_view_df
for gid in rules_view_df.gid.tail(5):
    analyze_rule(gid)

### Rules based on uniqueness

In [None]:
# Rules whose largest edge label is above 0, ordered by average debtor overlap with other rules (ascending).
rules_view_df = rules_df.loc[lambda df: df["max_rule_label"] > 0].sort_values(by="avg_overlap_with_others")
rules_view_df

In [None]:
# top rules: the first five rows of rules_view_df
for gid in rules_view_df.gid.head(5):
    analyze_rule(gid)

In [None]:
# bottom rules: the last five rows of rules_view_df
for gid in rules_view_df.gid.tail(5):
    analyze_rule(gid)

### Praha

In [None]:
# Rules whose largest edge label is above 0, ordered (descending) by the
# "Praha" column: the rule's share of debtors in Praha divided by Praha's share of insolvencies.
rules_view_df = rules_df.loc[lambda df: df["max_rule_label"] > 0].sort_values(by="Praha", ascending=False)
rules_view_df

In [None]:
# top rules: the first five rows of rules_view_df
for gid in rules_view_df.gid.head(5):
    analyze_rule(gid)

### Ústecký kraj

In [None]:
# Rules whose largest edge label is above 0, ordered (descending) by the
# "Ústecký kraj" column: the rule's share of debtors there divided by the region's share of insolvencies.
rules_view_df = rules_df.loc[lambda df: df["max_rule_label"] > 0].sort_values(by="Ústecký kraj", ascending=False)
rules_view_df

In [None]:
# top rules: the first five rows of rules_view_df
for gid in rules_view_df.gid.head(5):
    analyze_rule(gid)

### Overlap analysis of positively and negatively correlated rules

In [None]:
# Jaccard overlap of debtor sets between the 10 rules with the most negative
# and the 10 with the most positive months-to-default / debt-AUC correlation.
# Rows: positively correlated rules; columns: negatively correlated rules.
# The previous version labelled rows with the positive gids while filling them
# with the negative rule's overlaps, so cells were attributed to the wrong pair.
# Rules with NaN correlation are dropped: they cannot be ordered.
corr_ranked_df = rules_df.dropna(subset=["months_to_default_and_auc_corr"]).sort_values(
    "months_to_default_and_auc_corr", kind="stable"
)
negative_gids = [int(gid) for gid in corr_ranked_df.gid.head(10)]
positive_gids = [int(gid) for gid in corr_ranked_df.gid.tail(10)]

overlap_matrix = [
    [
        len(set(rule_2_debtors[pos_gid]).intersection(rule_2_debtors[neg_gid]))
        / len(set(rule_2_debtors[pos_gid]).union(rule_2_debtors[neg_gid]))
        for neg_gid in negative_gids
    ]
    for pos_gid in positive_gids
]
overlap_matrix_df = pd.DataFrame(overlap_matrix, index=positive_gids, columns=negative_gids)
overlap_matrix_df.style.background_gradient(cmap='Blues', axis=None)

### Rules based on confidence

In [None]:
# Analyse the five rules with the highest confidence (ties keep mining order).
# The removed `> 10` break could never trigger with only five iterations, and
# the commented-out filters were dead code.
for row in sorted(graphs, key=lambda d: d["confidence"], reverse=True)[:5]:
    print("******************************************************")
    analyze_rule(row['number'])

# Rule matrix

In [None]:
# Pairwise Jaccard overlap of the debtor sets of all rules. The matrix is
# square and symmetric; rows and columns are rule numbers in rules_df order.
# It serves as the precomputed affinity for the spectral clustering below.
# Sets are built once per rule instead of twice per pair, and gids are kept as
# ints (the old `.values` path turned them into float labels).
rule_gids = [int(gid) for gid in rules_df["gid"]]
debtor_sets = {gid: set(rule_2_debtors[gid]) for gid in rule_gids}
overlap_matrix = [
    [
        len(debtor_sets[gid_1] & debtor_sets[gid_2]) / len(debtor_sets[gid_1] | debtor_sets[gid_2])
        for gid_2 in rule_gids
    ]
    for gid_1 in rule_gids
]
overlap_matrix_df = pd.DataFrame(overlap_matrix, index=rule_gids, columns=rule_gids)

In [None]:
# Cluster the rules by debtor overlap. The Jaccard matrix is used directly as
# the affinity matrix; random_state is fixed for reproducibility.
n_clusters = 5
spectral_model = SpectralClustering(
    n_clusters=n_clusters,
    affinity="precomputed",
    assign_labels='discretize',
    random_state=0,
)
clustering = spectral_model.fit(overlap_matrix_df)

In [None]:
# Attach each rule's cluster label (rows are in the same order as overlap_matrix_df).
rules_df = rules_df.assign(cluster=clustering.labels_)

In [None]:
# Median of every rules_df column per cluster, plus the number of rules in each cluster.
pd.set_option("display.precision", 8)
cluster_sizes_df = rules_df.groupby("cluster").agg(cluster_size=("gid", "count"))
clustering_stats_df = rules_df.groupby("cluster").median().merge(
    cluster_sizes_df, left_index=True, right_index=True
)
clustering_stats_df.style.background_gradient(cmap='Blues')

In [None]:
# Mean Jaccard overlap between clusters. Entry [i][j] averages the overlaps
# between rules of cluster i (matrix columns) and rules of cluster j (matrix rows).
gids_by_cluster = {
    cluster_id: list(rules_df[rules_df["cluster"] == cluster_id].gid)
    for cluster_id in range(n_clusters)
}
mean_overlaps = [
    [
        overlap_matrix_df[gids_by_cluster[cluster_1]].loc[gids_by_cluster[cluster_2]].stack().mean()
        for cluster_2 in range(n_clusters)
    ]
    for cluster_1 in range(n_clusters)
]
mean_overlaps_df = pd.DataFrame(mean_overlaps)
mean_overlaps_df.style.background_gradient(cmap='Blues', axis=None)