In [None]:
%config InlineBackend.figure_formats = ['svg']
import matplotlib.pyplot as plt
from gerrychain import (GeographicPartition, Partition, Graph, MarkovChain,
                        proposals, updaters, constraints, accept, Election)
from gerrychain.proposals import recom
from gerrychain.metrics import mean_median
from functools import partial
import pandas as pd
import matplotlib.pyplot as plt; plt.style.use('ggplot')

In [None]:
# Shapefile of voting districts (presumably Pennsylvania VTDs, per the path).
VTD_SHAPEFILE = "data/PA_VTD/PA_VTD.shp"
graph = Graph.from_file(VTD_SHAPEFILE)

In [None]:
def pop_deviation(partition):
    """
    Return the gap between the most and least populous districts.

    ``partition["population"]`` is the ``Tally`` updater's mapping of
    district label -> total population, so the extremes are taken over its
    *values*. Calling ``max``/``min`` on the mapping itself would compare
    district labels rather than populations.
    """
    district_pops = partition["population"].values()
    return max(district_pops) - min(district_pops)

def abs_mean_median(partition, election="SEN16"):
    """
    Magnitude of the mean-median score of ``election`` on ``partition``.

    ``partition[election]`` is the election-results updater registered under
    that name; the sign of the score is dropped so only the size of the gap
    counts.
    """
    gap = mean_median(partition[election])
    return abs(gap)

# from https://gerrychain.readthedocs.io/en/latest/user/recom.html
my_updaters = {
    "population": updaters.Tally("TOT_POP", alias="population"),
    "cut_edges": updaters.cut_edges,
    # Read the partition's own "cut_edges" updater (as compactness_bound does)
    # instead of invoking updaters.cut_edges a second time.
    "n_cut_edges": lambda p: len(p["cut_edges"]),
    # |mean-median| on abs_mean_median's default election, SEN16.
    "mean_median": abs_mean_median,
    "pop_deviation": pop_deviation,
}
elections = [
    Election("SEN10", {"Democratic": "SEN10D", "Republican": "SEN10R"}),
    Election("SEN12", {"Democratic": "USS12D", "Republican": "USS12R"}),
    Election("SEN16", {"Democratic": "T16SEND", "Republican": "T16SENR"}),
    Election("PRES12", {"Democratic": "PRES12D", "Republican": "PRES12R"}),
    Election("PRES16", {"Democratic": "T16PRESD", "Republican": "T16PRESR"})
]
# Each Election is registered as an updater under its own name (e.g. "SEN16"),
# which is what abs_mean_median looks up.
election_updaters = {election.name: election for election in elections}
my_updaters.update(election_updaters)

In [None]:
# Allowed relative deviation from the ideal district population.
pop_tolerance = 0.02

initial_partition = GeographicPartition(graph, assignment="2011_PLA_1", updaters=my_updaters)
# Ideal (per-district) population: total population over len(initial_partition).
total_population = sum(initial_partition["population"].values())
ideal_population = total_population / len(initial_partition)

# ReCom proposal aimed at the ideal population with tolerance pop_tolerance.
proposal = partial(
    recom,
    pop_col="TOT_POP",
    pop_target=ideal_population,
    epsilon=pop_tolerance,
    node_repeats=2,
)

# Compactness: reject plans with more than twice the starting plan's cut edges.
max_cut_edges = 2 * len(initial_partition["cut_edges"])
compactness_bound = constraints.UpperBound(lambda p: len(p["cut_edges"]), max_cut_edges)

pop_constraint = constraints.within_percent_of_ideal_population(initial_partition, pop_tolerance)

In [None]:
from random import random
import numpy as np

def pareto_driver(partition, metrics, p=0.01):
    """
    Acceptance rule that favours proposals lowering every metric.

    The proposal is accepted with probability 1 when each metric in
    ``metrics`` is strictly smaller on ``partition`` than on
    ``partition.parent``. If even one metric fails to decrease, the
    proposal is accepted only with probability ``p``.
    """
    # Stops at the first metric that did not strictly decrease.
    fails_to_improve = any(partition[metric] >= partition.parent[metric]
                           for metric in metrics)
    return (not fails_to_improve) or random() < p

In [None]:
def score(partition, metrics, alpha):
    """
    Exponentiated weighted sum of ``partition``'s metrics.

    ``metrics[i]`` is weighted by ``alpha[i]``; entries past the end of the
    shorter sequence are ignored (``zip`` semantics). An empty pairing
    yields ``exp(0) == 1``.
    """
    total = 0
    for metric, weight in zip(metrics, alpha):
        total += partition[metric] * weight
    return np.exp(total)
        
def weighted_pareto_driver(partition, metrics, alpha, p=0.01):
    """
    Accept proposals that lower the weighted score of ``metrics``.

    The proposal is accepted with probability 1 when its weighted sum of
    ``metrics`` (weights ``alpha``) is strictly smaller than that of
    ``partition.parent``; otherwise it is accepted with probability ``p``.

    This ranks proposals exactly as comparing ``score(...)`` would, because
    ``score`` is ``exp`` of the same weighted sum and ``exp`` is increasing.
    Comparing the sums directly avoids ``exp`` overflowing to ``inf`` for
    large metrics such as cut-edge counts, where ``inf < inf`` is always
    False.
    """
    def weighted_sum(part):
        # Same pairing as score(): zip truncates to the shorter sequence.
        return sum(part[metric] * weight for metric, weight in zip(metrics, alpha))

    return weighted_sum(partition) < weighted_sum(partition.parent) or random() < p

In [None]:
# Pareto acceptance over cut-edge count and |mean-median| (lower is better).
acceptance = partial(pareto_driver, metrics=['n_cut_edges', 'mean_median'])

chain = MarkovChain(
    proposal=proposal,
    constraints=[pop_constraint, compactness_bound],
    accept=acceptance,
    initial_state=initial_partition,
    total_steps=1000,
)

In [None]:
# Per-step traces of |mean-median| and cut-edge count, across all walks.
mm_data, cut_data = [], []

In [None]:
# Start from empty traces so re-running this cell does not append duplicate
# walks; the plots below split the traces into walks by step index.
cut_data.clear()
mm_data.clear()

n_runs = 2
print_every = 100  # progress line every N steps instead of every step

for run in range(n_runs):
    print('--- Run', run, '---')
    for step, part in enumerate(chain):
        n_cut = part['n_cut_edges']
        mm = part['mean_median']
        cut_data.append(n_cut)
        mm_data.append(mm)
        if step % print_every == 0:
            print(step, n_cut, mm)

In [None]:
import os

# savefig does not create missing directories.
os.makedirs('results/Pareto_driver', exist_ok=True)

fig, ax1 = plt.subplots(figsize=(10, 8))
ax1.set_title('Cut edges and mean-median gap along the walks')
ax1.set_xlabel('Step number')
ax1.set_ylabel('Number of cut edges', color='tab:red')
ax1.plot(range(len(cut_data)), cut_data, color='tab:red')

ax2 = ax1.twinx()
# The "mean_median" updater is abs_mean_median with its default election, SEN16.
ax2.set_ylabel('Absolute mean-median gap (SEN16)', color='tab:blue')
ax2.plot(range(len(mm_data)), mm_data, color='tab:blue')

fig.savefig('results/Pareto_driver/tradeoff_run4.png', dpi=300)
plt.show()

In [None]:
# Colour each point by the walk that produced it; each walk contributes
# chain.total_steps consecutive entries to the traces.
steps_per_walk = chain.total_steps
walk_ids = [idx // steps_per_walk for idx in range(len(mm_data))]

plt.scatter(mm_data, cut_data, c=walk_ids)
cbar = plt.colorbar()
cbar.ax.set_ylabel('Walk')
plt.xlabel('Absolute mean-median score')
plt.ylabel('Number of cut edges')
plt.title('Convergence of front over multiple walks')
plt.savefig('results/Pareto_driver/front_convergence_over_multiple_walks.png')
plt.show()

In [None]:
# Colour each point by its step within its own walk (position modulo the
# walk length, chain.total_steps).
steps_per_walk = chain.total_steps
step_ids = [idx % steps_per_walk for idx in range(len(mm_data))]

plt.scatter(mm_data, cut_data, c=step_ids)
plt.xlabel('Absolute mean-median score')
plt.ylabel('Number of cut edges')
plt.title('Front trajectories over multiple walks')
cbar = plt.colorbar()
cbar.ax.set_ylabel('Step number')
plt.savefig('results/Pareto_driver/front_trajectory_over_multiple_walks.png')
plt.show()