In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
def create_mixing_matrix(n_strata, prop_mixing_same_stratum, proportions):
    """
    Build a square mixing matrix from a within-stratum mixing proportion and the
    population proportions of the strata.

    Every diagonal entry equals ``prop_mixing_same_stratum``. The remaining
    ``1 - prop_mixing_same_stratum`` of row ``i`` is shared among the other strata
    in proportion to their population share, i.e. entry ``(i, j)`` with ``j != i`` is
    ``(1 - prop_mixing_same_stratum) * proportions[j] / sum(proportions[k] for k != i)``.
    With more than one stratum, each row therefore sums to 1 (up to float rounding).
    With a single stratum the matrix is just ``[[prop_mixing_same_stratum]]``.

    Parameters:
    - n_strata (int): The number of strata.
    - prop_mixing_same_stratum (float): The within-stratum mixing proportion (0 to 1).
    - proportions (list): Population proportion of each stratum, one entry per stratum,
      in stratum order. Should sum to 1 (the sum itself is not checked here).

    Returns:
    - mixing_matrix (np.array): Array of shape (n_strata, n_strata).

    Raises:
    - ValueError: If ``proportions`` does not contain exactly ``n_strata`` entries, or
      if ``prop_mixing_same_stratum`` lies outside [0, 1].
    - AssertionError: If, for some stratum, the proportions of all other strata do not
      add up to a positive number (the between-strata share could not be distributed).
    """
    # Extra entries would be silently ignored and missing ones would surface as a bare
    # IndexError deep in the loop, so reject a length mismatch up front.
    if len(proportions) != n_strata:
        raise ValueError(
            f"proportions must contain exactly n_strata={n_strata} entries, "
            f"got {len(proportions)}."
        )
    # Written as `not (0 <= x <= 1)` so that NaN is rejected as well.
    if not (0 <= prop_mixing_same_stratum <= 1):
        raise ValueError(
            f"prop_mixing_same_stratum must be between 0 and 1, got {prop_mixing_same_stratum}."
        )

    # Initialize an n_strata x n_strata zero matrix
    mixing_matrix = np.zeros((n_strata, n_strata))

    for i in range(n_strata):
        # Within-stratum mixing on the diagonal
        mixing_matrix[i, i] = prop_mixing_same_stratum

        # A single stratum has no "other" strata to distribute mixing to
        if n_strata == 1:
            continue

        # Total population share of every stratum except i. It depends only on i,
        # so compute it once per row rather than once per (i, j) pair.
        prop_pop_non_i = sum(proportions[k] for k in range(n_strata) if k != i)

        # Ensure non-zero sum of proportions for non-i strata
        assert prop_pop_non_i > 0, "Population proportions for non-i strata must be positive."

        # Between-strata mixing: split the remaining share by relative population size
        for j in range(n_strata):
            if j != i:
                mixing_matrix[i, j] = (
                    (1 - prop_mixing_same_stratum) * proportions[j] / prop_pop_non_i
                )

    return mixing_matrix

In [None]:
n_strata = 3  # Number of strata
prop_mixing_same_stratum = 0.9  # Within-stratum mixing proportion (diagonal of the matrix)

# Population proportion of each stratum, in stratum order (must sum to 1)
proportions = [0.06, 0.06, 0.88]

# Fail fast on inconsistent inputs instead of producing a subtly wrong matrix
assert len(proportions) == n_strata, "Need exactly one population proportion per stratum."
assert np.isclose(sum(proportions), 1.0), "Population proportions must sum to 1."

# Generate the mixing matrix
mixing_matrix = create_mixing_matrix(n_strata, prop_mixing_same_stratum, proportions)

# Each row splits one stratum's mixing across all strata, so every row must sum to 1
assert np.allclose(mixing_matrix.sum(axis=1), 1.0), "Each row of the mixing matrix must sum to 1."

# Display the mixing matrix
print("Generated Mixing Matrix:")
print(mixing_matrix)

In [None]:
stratum_names = ['Trial', 'Control', 'Other']

# The names are used as tick labels for both rows and columns, so there must be
# exactly one name per stratum.
assert len(stratum_names) == mixing_matrix.shape[0], "Need exactly one name per stratum."

# Plot the mixing matrix as an annotated heatmap on an explicit Axes
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    mixing_matrix,
    annot=True,
    fmt=".6f",
    cmap="YlGnBu",
    xticklabels=stratum_names,
    yticklabels=stratum_names,
    ax=ax,
)
# Rows index the stratum whose mixing is being split; columns index the stratum
# it mixes with (each row sums to 1), so label the two axes differently.
ax.set_title("Mixing matrix between strata")
ax.set_xlabel("Stratum mixed with (column)")
ax.set_ylabel("Stratum (row)")
plt.show()