# Make PDs and match points to map

This notebook contains functions for displaying persistence diagrams and their corresponding choropleths.

**REQUIREMENTS:** To plot the maps, full census tract shapefiles for the US are required, which we obtained from [NHGIS](https://www.nhgis.org/). They are too big to store on GitHub, so please download the required files from [here](https://uncg-my.sharepoint.com/:f:/g/personal/t_weighill_uncg_edu/EhIEORpMq8ZOgu3yjn7ASdsBwMp2_B25gGizQaLhBlLvJw?e=SobxGa) and unzip them so that the `nhgis0016_shape` (2010 tracts) and `nhgis0017_shape` (2020 tracts) folders are at the same directory level as this notebook.

In [None]:
import importlib
import os
import random

import geopandas as gpd
import gudhi
import gudhi.hera
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import scipy
from gerrychain import Graph
from matplotlib.pyplot import Circle
from numpy import inf
from scipy.stats import pearsonr
from sklearn import linear_model
from sklearn.cluster import DBSCAN
from sklearn.manifold import MDS
from sklearn.neighbors import LocalOutlierFactor
from tqdm import tqdm

import tdaredistricting
# Route all figure text through matplotlib's usetex mode
# (presumably needs a working LaTeX installation on the machine -- confirm locally).
plt.rcParams['text.usetex'] = True
# Finite stand-in for an infinite death time: pd_from_graph substitutes it for inf
# and infinity_is_one maps it to 1.  Note that 10e5 == 1_000_000.0 (i.e. 1e6).
INFINITY = 10e5

## Functions

In [None]:
def infinity_is_one(PD):
    """Clamp infinite deaths of a persistence diagram to 1, in place.

    Every point whose death coordinate equals ``inf`` or the module-level
    ``INFINITY`` stand-in is replaced by the tuple ``(birth, 1)``.  The
    sequence passed in is modified and that same object is returned.
    """
    for idx in range(len(PD)):
        death = PD[idx][1]
        if death == inf or death == INFINITY:
            PD[idx] = (PD[idx][0], 1)
    return PD
    
def pd_from_graph(graph1, column, popthreshold=0):
    """Compute the dimension-0 persistence diagram of a node-weighted graph.

    A gudhi ``SimplexTree`` is built with one vertex per graph node and one
    edge per graph edge.  Each vertex receives the filtration value

    * ``1 - column/TOTPOP`` when the node's ``TOTPOP`` exceeds ``popthreshold``;
    * otherwise ``1 - max(share)`` over its neighbors whose ``TOTPOP`` exceeds
      ``popthreshold``, or ``1`` when it has no such neighbor.

    ``make_filtration_non_decreasing`` is then called before computing
    persistence (per the gudhi documentation this raises simplex values so no
    simplex precedes its faces).

    graph1: graph whose nodes carry ``column`` and ``'TOTPOP'`` attributes
    column: name of the node attribute holding the group count
    popthreshold: nodes at or below this population borrow a neighbor's share

    Returns a list of ``(birth, death)`` tuples for dimension 0, with an
    infinite death replaced by the module-level ``INFINITY``.
    """
    def share(node_id):
        attrs = graph1.nodes[node_id]
        return attrs[column]/attrs['TOTPOP']

    scomplex1 = gudhi.SimplexTree()
    # One 0-simplex per node, then one 1-simplex per edge.
    for node_id in graph1.nodes:
        scomplex1.insert([node_id])
    for u, v in graph1.edges:
        scomplex1.insert([u, v])

    for simplex, _ in scomplex1.get_skeleton(0):
        node = simplex[0]
        if graph1.nodes[node]['TOTPOP'] > popthreshold:
            value = 1-share(node)
        else:
            # Low-population node: fall back to the largest share among
            # sufficiently populated neighbors, if there are any.
            neighbor_values = [
                share(m)
                for m in graph1.neighbors(node)
                if graph1.nodes[m]['TOTPOP'] > popthreshold
            ]
            value = 1-max(neighbor_values) if neighbor_values else 1
        scomplex1.assign_filtration(simplex, value)
    scomplex1.make_filtration_non_decreasing()

    diagram = []
    for dim, interval in scomplex1.persistence():
        if dim != 0:
            continue
        if interval[1] == inf:
            interval = (interval[0], INFINITY)
        diagram.append(interval)
    return diagram

def wasserstein_between_pds(pd1, pd2, p=1):
    """Return gudhi.hera's Wasserstein distance between two diagrams.

    ``p`` is passed both as the ``order`` of the distance and as the
    ``internal_p`` ground norm.
    """
    distance = gudhi.hera.wasserstein_distance(
        pd1,
        pd2,
        order=p,
        internal_p=p,
    )
    return distance

def bottleneck_between_pds(pd1, pd2):
    """Return gudhi's bottleneck distance between two persistence diagrams."""
    distance = gudhi.bottleneck_distance(pd1, pd2)
    return distance

  
def total_persistence(pd1, p=1):
    """Return the ``p``-norm of the lifetimes (death - birth) of the points in ``pd1``."""
    lifetimes = []
    for point in pd1:
        lifetimes.append(point[1]-point[0])
    return np.linalg.norm(np.array(lifetimes), ord=p)

def minshare(graph1, column1):
    """Return the group's overall share: sum of ``column1`` over all nodes divided by the sum of ``'TOTPOP'``."""
    group_total = 0
    population_total = 0
    for n in graph1.nodes:
        attrs = graph1.nodes[n]
        group_total += attrs[column1]
        population_total += attrs['TOTPOP']
    return group_total/population_total

def moransI(graph, col, pop_col='TOTPOP'):
    """Return Moran's I of the per-node share ``col/pop_col`` over the graph.

    The adjacency matrix is divided column-wise by its column sums to form the
    weights; the statistic is ``v.W.v / v.v`` for the mean-centered shares
    ``v``.  A tiny constant is added to each population to avoid dividing by
    zero.  Note that a node without neighbors has a zero column sum, so the
    column normalization divides by zero for it.
    """
    adjacency = nx.adjacency_matrix(graph).toarray()
    weights = adjacency/adjacency.sum(axis=0)
    shares = np.array([
        graph.nodes[n][col]/(graph.nodes[n][pop_col]+1e-9)
        for n in graph.nodes
    ])
    centered = shares-np.mean(shares)
    numerator = np.dot(np.dot(centered, weights), centered)
    denominator = np.dot(centered, centered)
    return numerator/denominator

def DI(graph, col):
    """Return the dissimilarity index of group ``col`` against everyone else.

    For every node the group's share of the total group population is compared
    with the node's share of the remaining (``'TOTPOP'`` minus group)
    population; the result is half the sum of the absolute differences.
    """
    node_ids = list(graph.nodes)
    group = np.array([graph.nodes[n][col] for n in node_ids])
    total = np.array([graph.nodes[n]['TOTPOP'] for n in node_ids])
    group_dist = group/group.sum()
    rest_dist = (total - group)/(total.sum() - group.sum())
    gaps = np.abs(group_dist - rest_dist)
    return 0.5*sum(gaps)

def outliers(distance_matrix, threshold=-2, n_neighbors_range=range(10, 20)):
    """Return the indices flagged as outliers by an averaged Local Outlier Factor.

    A ``LocalOutlierFactor`` with ``metric='precomputed'`` is fitted once per
    neighborhood size in ``n_neighbors_range``; the resulting
    ``negative_outlier_factor_`` vectors are averaged and every index whose
    mean score is at or below ``threshold`` is reported.

    distance_matrix: square matrix of pairwise distances
    threshold: mean score at or below which a sample counts as an outlier
    n_neighbors_range: iterable of ``n_neighbors`` values to average over

    Returns a list of integer row indices.

    ``LocalOutlierFactor`` comes from ``sklearn.neighbors`` and must be
    imported in the notebook's import cell.
    """
    mean_scores = np.mean(
        [
            LocalOutlierFactor(metric='precomputed', n_neighbors=k)
            .fit(distance_matrix)
            .negative_outlier_factor_
            for k in n_neighbors_range
        ],
        axis=0,
    )
    # Use a distinct local name so the function itself is not shadowed.
    return [i for i, score in enumerate(mean_scores) if score <= threshold]

# For each supported census year: the tract shapefile to read, the graph node
# attribute that stores the tract ID, and the shapefile column it is matched to.
_TRACT_SOURCES = {
    2010: (
        './nhgis0016_shape/nhgis0016_shapefile_tl2010_us_tract_2010/US_tract_2010.shp',
        'GEOID10',
        'GEOID10',
    ),
    2020: (
        './nhgis0017_shape/nhgis0017_shapefile_tl2020_us_tract_2020/US_tract_2020.shp',
        'GEOIDGIS',
        'GEOID',
    ),
}


def _city_tract_shapes(graph, year, col):
    '''Return the tract geometries of the graph's nodes, indexed by tract ID, with a 'perc' share column.'''
    shapefile, node_key, shape_key = _TRACT_SOURCES[year]
    tractshapes = gpd.read_file(shapefile)
    # Share of the group per tract ID; the small constant avoids dividing by zero.
    perc_by_geoid = {
        graph.nodes[n][node_key]: graph.nodes[n][col]/(graph.nodes[n]['TOTPOP']+1e-7)
        for n in graph.nodes
    }
    cityshapes = tractshapes[tractshapes[shape_key].isin(list(perc_by_geoid))]
    cityshapes = cityshapes.set_index(shape_key)
    cityshapes['perc'] = [perc_by_geoid[geoid] for geoid in cityshapes.index]
    return cityshapes


def map_and_pd(year, city, col, npoints=5, verbose=False, popthreshold=10, reorder=None):
    '''
    Draw a persistence diagram next to a choropleth of the same city and circle,
    on the map, the tracts that give birth to the most persistent points.

    year: the year, either 2010 or 2020
    city: the name of the city
    col: the demographic, either BLACK or HISP
    npoints: the number of points to color and match, counting from the most persistent
    verbose: if True, print the node attributes of every matched tract
    popthreshold: population threshold passed to pd_from_graph; tracts at or
        below it are never used as a match on the map
    reorder: optional sequence of color slots; point x is drawn with the tab20
        colormap evaluated at reorder[x]/6 (defaults to 0, 1, 2, ...)

    Raises ValueError if year is not 2010 or 2020.  Nothing is returned; the
    figure is left as the current matplotlib figure.
    '''
    if year not in _TRACT_SOURCES:
        raise ValueError(
            'year must be one of {}, got {!r}'.format(sorted(_TRACT_SOURCES), year)
        )

    #PD
    graph = Graph.from_json('./cities{}data/{}.json'.format(year, city))
    print(nx.is_connected(graph))
    PD = pd_from_graph(graph, col, popthreshold=popthreshold)
    # Infinite deaths are clamped to 1 once here, so PD can be used directly below.
    PD = sorted(infinity_is_one(PD), key = lambda x: x[1]-x[0], reverse=True)
    # Never try to highlight more points than the diagram contains.
    n_marked = min(npoints, len(PD))

    cmap = plt.get_cmap('tab20')
    if reorder is None:
        reorder = range(n_marked)
    colors = [cmap(reorder[x]/(7-1)) for x in range(n_marked)]
    fig, (ax, mapax) = plt.subplots(1,2, figsize=(6,6))

    # Points that are not highlighted are drawn small and black.
    ax.scatter(
        [x[0] for x in PD[n_marked:]],
        [x[1] for x in PD[n_marked:]],
        color='black',
        s=10
    )
    vals_to_flag = []
    for i in range(n_marked):
        print(PD[i])
        # The birth value is 1 - share, so this recovers the share of the birth tract.
        vals_to_flag.append(1-PD[i][0])
        ax.scatter(
            [PD[i][0]],
            [PD[i][1]],
            color=colors[i],
            s=40
        )
    ax.set_xlim(0,1)
    ax.set_xticks([0.2, 0.4, 0.6, 0.8])
    ax.set_yticks([0.2, 0.4, 0.6, 0.8])
    ax.set_ylim(0,1.1)
    ax.set_aspect(1)
    ax.plot([0,1], [0,1], linestyle='dashed', c='grey')
    # Raw string: the backslash belongs to the LaTeX command, not a Python escape.
    ax.annotate(r'$\infty$', (0,1))

    #map
    cityshapes = _city_tract_shapes(graph, year, col)
    cityshapes['dummy'] = [1]*len(cityshapes)
    # Dissolving on a constant column merges all tracts into one city outline.
    wholecity = cityshapes.dissolve(by='dummy')
    coloring = 'Blues' if npoints == 0 else 'binary'
    cityshapes.plot(
        column='perc',
        cmap=coloring,
        vmin=0, vmax=1,
        ax=mapax,
        legend=True,
        legend_kwds={'shrink': 0.3}
    )
    wholecity.boundary.plot(
        edgecolor='black',
        ax=mapax,
        linewidth=0.2
    )
    mapax.axis('off')
    for i, val in enumerate(vals_to_flag):
        # First sufficiently populated tract whose share equals the flagged value.
        node = next(
            (
                n for n in graph.nodes
                if graph.nodes[n]['TOTPOP'] > popthreshold
                and np.abs(
                    graph.nodes[n][col]/(graph.nodes[n]['TOTPOP']+1e-7) - val
                ) < 1e-7
            ),
            None
        )
        if node is None:
            print('No tract with share {} found for point {}; skipping its marker'.format(val, i))
            continue
        if verbose:
            print(graph.nodes[node], '\n')
        x,y = graph.nodes[node]['C_X'], graph.nodes[node]['C_Y']
        extent = mapax.get_ylim()[1]-mapax.get_ylim()[0]
        patch = plt.Circle((x, y), extent*0.05, color=colors[i], fill=False, linewidth=2)
        mapax.add_patch(patch)

# Plot making

Here's one example:

In [None]:
year, city, col = 2020, 'ChicagoIL', 'HISP'
map_and_pd(year, city, col, npoints=4, verbose=True)
# savefig does not create missing folders, so make sure the output directory exists.
os.makedirs('figs/pds and maps', exist_ok=True)
# NOTE(review): the file name ends in "nopoints" although npoints=4 here -- confirm the intended name.
plt.savefig('figs/pds and maps/pd_and_map_{}_{}_{}_nopoints.png'.format(year, city, col), dpi=300, bbox_inches='tight')