# K-means++ clustering of city persistence diagrams

Clustering cities based on their persistence diagrams. 

In [None]:
from gerrychain import Graph
import networkx as nx
import numpy as np
import gudhi
import random
import gudhi.hera
from sklearn.manifold import MDS
from sklearn.cluster import DBSCAN
import os
from numpy import inf
import matplotlib.pyplot as plt
import scipy
from scipy.stats import pearsonr
from sklearn import linear_model
import pandas as pd
import importlib
import tdaredistricting
import geopandas as gpd
from matplotlib.pyplot import Circle
from tqdm import tqdm
# Ask matplotlib to typeset all figure text with LaTeX.
plt.rcParams['text.usetex'] = True
# Finite stand-in used in place of an infinite death time (note: 10e5 == 1e6).
INFINITY = 10e5

In [None]:
# Make sure every figure output directory exists before anything is saved.
for figure_dir in (
    'figs/clustering',
    'figs/MDS',
    'figs/pds and maps',
    'figs/time',
    'figs/TP',
):
    os.makedirs(figure_dir, exist_ok=True)

## Globals

Change these to choose the demographic column, the census year, the number of clusters, and whether the (slow) elbow plot is generated.

In [None]:
# Demographic column whose population share drives the filtration.
col = 'BLACK'
# Census year to analyse.
year = 2020
# Set to True to generate the elbow plot (this can be slow).
doelbow = False
# Number of clusters used for the full analysis.
num_clusters = 5

## Functions

In [None]:
def infinity_is_one(PD):
    """Cap essential classes of a persistence diagram at death time 1.

    Every point whose death equals ``inf`` or the ``INFINITY`` placeholder is
    replaced by ``(birth, 1)``.  The diagram is modified in place and the same
    object is returned.
    """
    for idx in range(len(PD)):
        death = PD[idx][1]
        if death == inf or death == INFINITY:
            PD[idx] = (PD[idx][0], 1)
    return PD
    
def pd_from_graph(graph1, column, popthreshold=0):
    """Compute the 0-dimensional persistence diagram of a graph.

    Each vertex gets the filtration value ``1 - column/TOTPOP``.  A vertex
    whose ``TOTPOP`` does not exceed ``popthreshold`` instead borrows the
    largest share among its neighbours that do exceed it, or gets filtration
    1 when it has no such neighbour.

    Returns the list of dimension-0 (birth, death) pairs, with infinite
    deaths replaced by ``INFINITY``.
    """
    def share(n):
        # Fraction of node n's population belonging to `column`.
        return graph1.nodes[n][column]/graph1.nodes[n]['TOTPOP']

    # Build the 1-dimensional complex: one 0-simplex per node, one 1-simplex per edge.
    complex_ = gudhi.SimplexTree()
    for vertex in graph1.nodes:
        complex_.insert([vertex])
    for (u, v) in graph1.edges:
        complex_.insert([u, v])

    # Assign a filtration value to every vertex.
    for simplex, _ in complex_.get_skeleton(0):
        node = simplex[0]
        if graph1.nodes[node]['TOTPOP'] > popthreshold:
            value = 1-share(node)
        else:
            populated_neighbors = [
                share(m)
                for m in graph1.neighbors(node)
                if graph1.nodes[m]['TOTPOP'] > popthreshold
            ]
            if len(populated_neighbors) == 0:
                value = 1
            else:
                value = 1-max(populated_neighbors)
        complex_.assign_filtration(simplex, filtration=value)
    # Edges were inserted before vertex values were set; let gudhi restore a
    # valid (non-decreasing) filtration.
    complex_.make_filtration_non_decreasing()

    # Keep only the dimension-0 intervals and make infinite deaths finite.
    intervals = [pair for dim, pair in complex_.persistence() if dim == 0]
    return [
        (pair[0], INFINITY) if pair[1] == inf else pair
        for pair in intervals
    ]

def wasserstein_between_pds(pd1, pd2, p=1):
    """Return the Wasserstein distance between two persistence diagrams.

    Thin wrapper around ``gudhi.hera.wasserstein_distance``; ``p`` is passed
    as both ``order`` and ``internal_p``.
    """
    return gudhi.hera.wasserstein_distance(pd1, pd2, order = p, internal_p = p)

def bottleneck_between_pds(pd1, pd2):
    """Return ``gudhi.bottleneck_distance`` between two persistence diagrams."""
    return gudhi.bottleneck_distance(pd1, pd2)

  
def total_persistence(pd1, p=1):
    """Return the p-norm of the lifetimes (death - birth) of a diagram."""
    lifetimes = np.array([point[1]-point[0] for point in pd1])
    return np.linalg.norm(lifetimes, ord=p)

def minshare(graph1, column1):
    """Return the share of the whole graph's TOTPOP that belongs to column1."""
    group_total = 0
    for node in graph1.nodes:
        group_total += graph1.nodes[node][column1]
    population_total = 0
    for node in graph1.nodes:
        population_total += graph1.nodes[node]['TOTPOP']
    return group_total/population_total

def moransI(graph, col, pop_col='TOTPOP'):
    """Return Moran's I of the per-node share col/pop_col on the graph.

    The adjacency matrix is divided by its column sums to form the weights,
    and the shares are centred on their mean before the quadratic form is
    evaluated.  A tiny constant is added to each denominator so that nodes
    with zero population do not divide by zero.
    """
    adjacency = nx.adjacency_matrix(graph).toarray()
    weights = adjacency/adjacency.sum(axis=0)
    shares = np.array([
        graph.nodes[n][col]/(graph.nodes[n][pop_col]+1e-9)
        for n in graph.nodes
    ])
    centered = shares-np.mean(shares)
    numerator = centered.dot(weights).dot(centered)
    denominator = centered.dot(centered)
    return numerator/denominator

def DI(graph, col):
    """Return the dissimilarity index of `col` versus the rest of TOTPOP.

    This is half the sum over nodes of the absolute difference between the
    node's share of the group population and its share of the remaining
    population.
    """
    node_ids = list(graph.nodes)
    group_counts = np.array([graph.nodes[n][col] for n in node_ids])
    total_counts = np.array([graph.nodes[n]['TOTPOP'] for n in node_ids])
    group_share = group_counts/group_counts.sum()
    rest_share = (total_counts - group_counts)/(total_counts.sum() - group_counts.sum())
    return 0.5*sum(np.abs(group_share - rest_share))

def outliers(distance_matrix):
    """Return the indices of points flagged as outliers by Local Outlier Factor.

    The negative outlier factor is computed from the precomputed
    ``distance_matrix`` for every ``n_neighbors`` in 10..19 and averaged;
    points whose mean score is <= -2 are reported.

    Parameters
    ----------
    distance_matrix : square array of pairwise distances.

    Returns
    -------
    list of int
        Row indices of the outlying points.
    """
    # LocalOutlierFactor is not imported at file level; import it here so the
    # function does not fail with a NameError.
    from sklearn.neighbors import LocalOutlierFactor

    mean_scores = np.mean(
        [
            LocalOutlierFactor(
                metric='precomputed',
                n_neighbors=k,
            ).fit(distance_matrix).negative_outlier_factor_
            for k in range(10, 20)
        ],
        axis=0,
    )
    return [i for i, score in enumerate(mean_scores) if score <= -2]
    

In [None]:
def chooseplusplus(M, k):
    """Pick k distinct seed indices with k-means++ style sampling.

    The first seed is drawn uniformly.  Each later seed is drawn with
    probability proportional to the squared distance to the *nearest* seed
    chosen so far.  (Previously only the distance to the most recent seed was
    used, so an earlier seed could be drawn again and yield duplicate
    cluster centres.)

    Parameters
    ----------
    M : square array of pairwise distances between the points.
    k : number of seeds to choose.

    Returns
    -------
    list
        Indices into M of the chosen seeds, in the order they were drawn.

    Raises
    ------
    ValueError
        If more seeds are requested than there are points.

    Note: the global NumPy RNG is re-seeded with 2023 on every call so the
    selection is reproducible.
    """
    M = np.asarray(M, dtype=float)
    n_points = len(M)
    if k > n_points:
        raise ValueError(
            'cannot choose {} distinct seeds from {} points'.format(k, n_points)
        )
    np.random.seed(2023)
    seeds = [np.random.choice(range(n_points))]
    # Squared distance from every point to its nearest chosen seed.
    nearest_sq = M[seeds[0]]**2
    nearest_sq[seeds[0]] = 0.0
    while len(seeds) < k:
        total = nearest_sq.sum()
        if total > 0:
            probabilities = nearest_sq/total
        else:
            # Every remaining point coincides with a chosen seed: fall back to
            # a uniform draw among the indices not chosen yet.
            probabilities = np.ones(n_points)
            probabilities[seeds] = 0.0
            probabilities /= probabilities.sum()
        x = np.random.choice(range(n_points), p=probabilities)
        seeds.append(x)
        nearest_sq = np.minimum(nearest_sq, M[x]**2)
        # Guarantee a chosen index can never be drawn again.
        nearest_sq[x] = 0.0
    return seeds


def kmeansplusplus(PDs, k, MAXITER=100, p=2, eps=1e-5, verbose=False, seeds=None):
    """Cluster persistence diagrams with k-means under the Wasserstein distance.

    Parameters
    ----------
    PDs : list of persistence diagrams.
    k : number of clusters.
    MAXITER : maximum number of match/average iterations.
    p : order used for every Wasserstein distance, including the seeding
        distances (these were previously hard-coded to 2).
    eps : stop once every mean moves by less than this between iterations.
    verbose : print the seeds, per-iteration changes and distortion.
    seeds : optional list of indices into PDs to use as initial means; when
        omitted they are chosen by ``chooseplusplus`` from the symmetrised
        pairwise distance matrix.

    Returns
    -------
    (means, matches, distortion)
        ``means`` are the cluster centres used for the last assignment,
        ``matches[j]`` is the cluster index of ``PDs[j]``, and ``distortion``
        is the sum of distances from each diagram to the freshly averaged
        mean of its cluster (``None`` if no iteration ran).
    """
    if seeds is None:
        distances_pairwise = np.array([
            [wasserstein_between_pds(pd1, pd2, p=p) for pd1 in PDs]
            for pd2 in PDs
        ])
        # Average with the transpose so the matrix is exactly symmetric.
        M = (distances_pairwise + distances_pairwise.T)/2
        seeds = chooseplusplus(M, k)
    if verbose:
        print('seeds:', seeds)
    means = [PDs[i] for i in seeds]
    newmeans = means.copy()
    matches = [0]*len(PDs)
    # Defined up front so the final return works even when MAXITER <= 0.
    distortion = None
    print('\n')
    for _ in range(MAXITER):
        print('.', end='')
        # match: assign every diagram to its nearest current mean
        for j, PD in enumerate(PDs):
            D = [wasserstein_between_pds(PD, m, p=p) for m in means]
            matches[j] = np.argmin(D)
        # average: recompute each mean from its members
        for j in range(k):
            members = [PDs[n] for n in range(len(PDs)) if matches[n] == j]
            if members:
                newmeans[j] = tdaredistricting.Frechet_mean(members, seed=means[j])
            else:
                # Frechet_mean's behaviour on an empty list is not visible
                # here; keep the previous mean for an empty cluster.
                newmeans[j] = means[j]
        # convergence: how far did each mean move?
        deltas = [
            wasserstein_between_pds(newmeans[j], means[j], p=p) for j in range(k)
        ]
        # track distortion
        distortion = sum([
            wasserstein_between_pds(PDs[j], newmeans[matches[j]], p=p)
            for j in range(len(PDs))
        ])
        if verbose:
            print('Changes: ', deltas)
            print('Distortion: ', distortion)
        if max(deltas) < eps:
            return means, matches, distortion
        means = newmeans.copy()
    print('DID NOT CONVERGE')
    return means, matches, distortion

## Load data

In [None]:
# Number of cities to analyse (the first n rows of the CSV).
n = 100
list_of_cities_pd = pd.read_csv('./City_Names_And_Populations.csv')
_name_state_pairs = list(zip(list_of_cities_pd.NAME, list_of_cities_pd.ST))[:n]
# File-name form ("NameST") and display form ("Name ST") of each city.
list_of_cities = [name+state for name, state in _name_state_pairs]
city_names = [name+' ' + state for name, state in _name_state_pairs]
# Seven evenly spaced colours from the tab20 colormap, one per cluster.
cmap = plt.get_cmap('tab20')
colors = [cmap(step/6) for step in range(7)]

In [None]:
# One list of city graphs per census year, in the same order as list_of_cities.
list_of_graphs = {'2010': [], '2020': []}
# Use a dedicated loop variable: iterating with `year` overwrote the global
# `year` setting, so the analysis always ran on '2020' whatever was selected.
for data_year in list_of_graphs:
    # Iterate the cities directly so a CSV with fewer than n rows cannot
    # cause an IndexError.
    for city in list_of_cities[:n]:
        graph1 = Graph.from_json('./cities{}data/{}.json'.format(data_year, city))
        # Optional: restrict each graph to its largest connected component.
        # cc = sorted(nx.connected_components(graph1), key=len, reverse=True)
        # print(city, (len(graph1) - len(cc[0]))/len(graph1))
        # graph1 = graph1.subgraph(cc[0])
        list_of_graphs[data_year].append(graph1)
# Persistence diagram of every city, per year, with essential classes capped at 1.
PDs = {
    y: [
        infinity_is_one(pd_from_graph(graph, col, popthreshold=10))
        for graph in list_of_graphs[y]
    ]
    for y in list_of_graphs
}
# The dictionaries above are keyed by strings; normalise the selected year so
# that PDs[year] and list_of_graphs[year] work.
year = str(year)

In [None]:
# Record one representative coordinate per city: the C_X / C_Y attributes of
# the first node of its 2010 graph.
x = []
y = []
for graph in list_of_graphs['2010']:
    node = next(iter(graph.nodes))
    x.append(graph.nodes[node]['C_X'])
    y.append(graph.nodes[node]['C_Y'])

coords = pd.DataFrame({'name': list_of_cities[:n], 'x': x, 'y': y})
coords = coords.set_index('name')
# `index` is a boolean flag; the column header belongs in `index_label`.
coords.to_csv('city_coordinates.csv', index=True, index_label='name')

## Do k-means

In [None]:
# Pairwise 2-Wasserstein distances between all diagrams of the selected year.
_diagrams = PDs[year]
_raw_distances = np.array([
    [wasserstein_between_pds(pd1, pd2, p=2) for pd1 in _diagrams]
    for pd2 in _diagrams
])
# Average with the transpose so the matrix is exactly symmetric.
distances_pairwise = 0.5*(_raw_distances + _raw_distances.T)

In [None]:
if doelbow:
    # Run the clustering for k = 2..10 and record the final distortion of each.
    elbow = []
    for k in range(2, 11):
        # Reset both RNGs so every run starts from the same random state.
        np.random.seed(2023)
        random.seed(2023)
        means, matches, distortion = kmeansplusplus(PDs[year], k, verbose=False)
        elbow.append(distortion)
    print(elbow)
    ks_plotted = range(2, 2+len(elbow))
    plt.plot(ks_plotted, elbow)
    plt.scatter(ks_plotted, elbow)
    plt.xticks(range(2+len(elbow)))
    plt.xlabel('k')
    plt.ylabel('distortion')
    elbow_path = 'figs/clustering/{}_{}_elbow_plot.png'.format(col, year)
    plt.savefig(elbow_path, bbox_inches='tight', dpi=150)
    plt.show()

In [None]:
# Fix both RNGs so the clustering is reproducible.
np.random.seed(2023)
random.seed(2023)
# Cluster the diagrams of the selected year (this was hard-coded to '2020',
# unlike every other cell, which indexes by `year`).
means, matches, distortion = kmeansplusplus(PDs[year], num_clusters, verbose=False)

In [None]:
# Relabel the clusters so that cluster 0 has the mean diagram with the largest
# total persistence, cluster 1 the next largest, and so on.
negated_total_persistence = [
    -sum(point[1]-point[0] for point in mean_diagram)
    for mean_diagram in means
]
index_by_TP = np.argsort(negated_total_persistence)
# Map each old cluster label to its rank in the new ordering.
reindex = dict(zip(index_by_TP, range(len(index_by_TP))))
matches = [reindex[old_label] for old_label in matches]
means = [means[old_label] for old_label in index_by_TP]

In [None]:
# State FIPS codes of the 50 states plus DC.
fipslist = ['01', '02', '04', '05', '06', '08', '09', '10', '11', '12', '13',
            '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25',
            '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36',
            '37', '38', '39', '40', '41', '42', '44', '45', '46', '47', '48',
            '49', '50', '51', '53', '54', '55', '56']
# Contiguous US: everything except Alaska ('02') and Hawaii ('15').
contiglist = [x for x in fipslist if x not in ['02', '15']]
wholeUS = gpd.read_file(
    './cb_2018_us_state_500k/cb_2018_us_state_500k.shp'
)
wholeUS = wholeUS.to_crs('ESRI:102003')
contigUS = wholeUS[wholeUS.STATEFP.isin(contiglist)]
# FIPS '02' is Alaska and '15' is Hawaii (the two were swapped before).
alaska = wholeUS[wholeUS.STATEFP == '02']
hawaii = wholeUS[wholeUS.STATEFP == '15']

# Plot every city in the contiguous US as a dot coloured by its cluster.
fig, ax = plt.subplots(figsize=(15,15))
contigUS.plot(ax=ax, edgecolor='black', facecolor='white', linewidth=0.2)
for i, graph in enumerate(list_of_graphs[year]):
    if list_of_cities[i] not in ['HonoluluHI', 'AnchorageAK']:
        # The first node's coordinates stand in for the city's location.
        node = list(graph.nodes)[0]
        ax.scatter(
            [graph.nodes[node]['C_X']],
            [graph.nodes[node]['C_Y']],
            s=50,
            c=[colors[matches[i]]],
        )
plt.axis('off')
plt.savefig('figs/clustering/{}_{}_US.png'.format(col, year), bbox_inches='tight', dpi=150)

In [None]:
# Draw the mean persistence diagram of every cluster in its cluster colour.
for i, mean in enumerate(means):
    fig, ax = plt.subplots(figsize=(2.5,2.5))
    # Cap essential classes at 1 once (in place) and reuse the result.
    points = infinity_is_one(mean)
    ax.scatter(
        [pt[0] for pt in points],
        [pt[1] for pt in points],
        color=colors[i],
        s=10
    )
    ax.set_xlim(0,1)
    ax.set_xticks([0.2, 0.4, 0.6, 0.8])
    ax.set_yticks([0.2, 0.4, 0.6, 0.8])
    ax.set_ylim(0,1.1)
    ax.set_aspect(1)
    # Diagonal birth == death.
    ax.plot([0,1], [0,1], linestyle='dashed', c='grey')
    # Raw string: '\i' is an invalid escape sequence in a normal literal.
    ax.annotate(r'$\infty$', (0,1))
    plt.savefig('figs/clustering/{}_{}_mean{}.png'.format(col, year, i), bbox_inches='tight', dpi=150)
    plt.show()
    plt.close()
    

## Cluster stats

Formatted for $\LaTeX$.

In [None]:
def pop_of_city(i):
    """Return the summed TOTPOP over all nodes of the i-th city graph of `year`."""
    city_graph = list_of_graphs[year][i]
    node_populations = [city_graph.nodes[node]['TOTPOP'] for node in city_graph.nodes]
    return sum(node_populations)

# Print one LaTeX table row per cluster. All backslashes are escaped
# explicitly ('\%' and '\ ' are invalid escape sequences); the printed text is
# unchanged.
print('cluster & \\#cities & ave. pop & ave. \\# tracts & ave. \\% {} \\\\ '.format(col))
for k in range(len(means)):
    # Cities assigned to cluster k.
    indexesofPDs = [j for j in range(len(PDs[year])) if matches[j] == k]
    thesePDs = [PDs[year][j] for j in indexesofPDs]
    print('{} \\textcolor {{mycolor{}}}{{$\\blacksquare$}}'.format(k+1, k+1), end=' & ')
    print(len(thesePDs), end=' & ')
    # Mean population, mean number of graph nodes, mean group share (percent).
    print('{:.2f} '.format(np.mean([pop_of_city(i) for i in indexesofPDs])), end=' & ')
    print('{:.2f} '.format(np.mean([len(list_of_graphs[year][i]) for i in indexesofPDs])), end=' & ')
    print('{:.2f} \\% \\\\ '.format(100*np.mean([minshare(list_of_graphs[year][i], col) for i in indexesofPDs])))