In [None]:
from tbdynamics.inputs import matrix, conmat
from tbdynamics.constants import age_strata
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cdist

In [None]:
def plot_contact_matrix(contact_matrix, age_groups, title, cbar_label="Yearly Contacts"):
    """
    Plot a square matrix as an annotated heatmap labelled by age group.

    Rows are drawn on the y axis ("Age Group From") and columns on the x axis
    ("Age Group To"). The figure is shown immediately and nothing is returned.

    Args:
    - contact_matrix: A 2D array-like of values between age groups (rows = "from"
      group, columns = "to" group). Assumed square with one row/column per entry of
      `age_groups` — TODO confirm for every caller.
    - age_groups: A sequence of labels for the age groups, used on both axes.
    - title: Title shown above the heatmap.
    - cbar_label: Label for the colour bar. Defaults to "Yearly Contacts"; override it
      when the matrix holds something other than contact rates (e.g. distances).
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        contact_matrix,
        annot=True,
        fmt=".2f",
        cmap="YlGnBu",
        cbar_kws={"label": cbar_label},
        xticklabels=age_groups,
        yticklabels=age_groups,
        ax=ax,
    )

    ax.set_xlabel("Age Group To")
    ax.set_ylabel("Age Group From")
    ax.set_title(title)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    # seaborn draws row 0 at the top; flip so the first age group sits at the bottom-left.
    ax.invert_yaxis()
    plt.show()

In [None]:
plot_contact_matrix(contact_matrix=matrix, age_groups=age_strata, title="Yearly contact matrix from survey")

In [None]:
plot_contact_matrix(contact_matrix=conmat, age_groups=age_strata, title="Yearly contact matrix extrapolated from conmat")

In [None]:
# Pairwise Canberra distances between the rows of the survey matrix (`matrix`) and the
# rows of the conmat matrix (`conmat`): entry (i, j) compares survey row i with conmat
# row j. cdist compares every pair of rows, not only corresponding ones — the diagonal
# (i == j) holds the distances between corresponding age-group rows.
distance_matrix = cdist(matrix, conmat, metric='canberra')
assert distance_matrix.shape == (len(age_strata), len(age_strata)), "expected one row per age stratum"
# NOTE: the axis labels ("Age Group From/To") and the colour-bar label are those of
# plot_contact_matrix and do not describe distances here: rows index survey-matrix rows
# and columns index conmat rows.
plot_contact_matrix(distance_matrix, age_strata, "Canberra distance between survey and conmat rows")