In [None]:
import json
import pandas as pd
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import dendrogram, linkage, leaves_list
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Load the precomputed order statistics. The file is read as UTF-8 (the JSON
# default) rather than with the platform's locale encoding, so non-ASCII
# content decodes the same way on every machine.
with open("results/results_xy_obj_verb_pron_order_statistics.json", encoding="utf-8") as f:
    data = json.load(f)

languages = data['languages']

# One row per entry of data['rules']: the rule's own fields (minus the nested
# 'stats' entry) side by side with the fields of that 'stats' entry. Both
# frames are built from the same list, so they share the default positional
# index and line up row by row.
rules_df = pd.DataFrame.from_records(data['rules']).drop(columns="stats")
stats_df = pd.DataFrame.from_records([rule['stats'] for rule in data['rules']])
df = pd.concat([rules_df, stats_df], axis=1)
df

In [None]:
# Choose row_id: the index label of the row in `df` whose 'precisions' and
# 'residuals' entries are plotted by the cells that follow.
row_id = 0

In [None]:
# Hierarchical clustering of the selected row's precision values
# (Ward linkage on Euclidean distances), rendered as a dendrogram.
# Each value becomes one single-feature observation (column vector).
values = np.reshape(np.array(df.loc[row_id, 'precisions']), (-1, 1))
linkage_matrix = linkage(values, metric="euclidean", method='ward')

fig, ax = plt.subplots()
dg = dendrogram(
    linkage_matrix,
    ax=ax,
    labels=languages,
    leaf_font_size=12,
    leaf_rotation=45,
)
plt.show()

In [None]:
# Leaf colours and leaf labels as reported by the dendrogram drawn above
# (not used further in this cell).
leaf_clusters = dg['leaves_color_list']
leaves = dg['ivl']

# Pairwise Euclidean distances between the values, expanded from condensed
# form into a square matrix labelled by language on both axes.
distances = pdist(values, metric="euclidean")
sq_distance_matrix = squareform(distances)

df_clustermap = pd.DataFrame(sq_distance_matrix, index=languages, columns=languages)
col_order = leaves_list(linkage_matrix)

# Rows are put into dendrogram leaf order by hand (row clustering is off);
# columns are ordered by seaborn from the same precomputed linkage.
# NOTE: despite its name, `fig` holds the object returned by sns.clustermap,
# not a matplotlib Figure.
fig = sns.clustermap(
    df_clustermap.iloc[col_order, :],
    col_cluster=True,
    row_cluster=False,
    col_linkage=linkage_matrix,
    annot=True,
    cmap="crest",
    figsize=(12, 10),
    cbar=True
)

# The colour-bar axes can be None, so every access to it stays inside the
# guard; previously only set_visible() was protected and the repositioning
# below would raise AttributeError on a missing colour bar.
if fig.cax is not None:
    fig.cax.set_visible(True)

    heatmap_pos = fig.ax_heatmap.get_position()
    cbar_pos = fig.cax.get_position()

    # New left edge is derived from the heatmap's right edge (1.07 - x1);
    # vertical placement, width and height are left unchanged.
    fig.cax.set_position([
        1.07 - heatmap_pos.x1,
        cbar_pos.y0,
        cbar_pos.width,
        cbar_pos.height
    ])

# Enlarge the tick labels on both heatmap axes; y labels are additionally
# centred on their anchor point.
fig.ax_heatmap.set_xticklabels(
    fig.ax_heatmap.get_xticklabels(),
    fontsize=14
)
fig.ax_heatmap.set_yticklabels(
    fig.ax_heatmap.get_yticklabels(),
    ha='center',
    fontsize=14,
    rotation_mode='anchor',
)

plt.show()

In [None]:

# Bar chart of the selected row's residuals, plotted against `languages`.
residuals = df.loc[row_id, 'residuals']
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(x=languages, y=residuals, ax=ax)
ax.set_title("Standardized residuals")
# Axis labels are intentionally blank; the title and tick labels carry the meaning.
ax.set_xlabel("")
ax.set_ylabel("")
# Same label size on both axes in a single call.
ax.tick_params(axis='both', labelsize=15)
plt.show()