<a href="https://colab.research.google.com/github/peterrrock2/Computation-of-Infinitesimals/blob/main/notebooks/multi_member_recom_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Quick Example notebook for `MultiMemberReCom`

This is a quick example notebook showing how to use the `MultiMemberReCom` class. The multi-member
methods should have all of the appropriate functionality and logic implemented, but **they have
not been tested beyond eyeball sanity checks**.


Also, not all of the functions in `gerrychain` support `MultiMemberReCom` yet, but the basics
needed to run a chain are in place.

The docstrings are also not yet complete.

In [None]:
# Install the work-in-progress `wip/multi-member` branch of GerryChain (provides MultiMemberReCom).
# NOTE(review): this tracks a branch rather than a pinned commit, so re-running later may install
# different code; consider pinning `@<commit-sha>` for reproducibility.
!uv pip install -q git+https://github.com/mggg/GerryChain.git@wip/multi-member

In [None]:
from gerrychain import (Partition, Graph, MarkovChain,
                        updaters, accept)
from gerrychain.proposals.tree_proposals import MultiMemberReCom
from gerrychain.constraints import contiguous

# Seed Python's global `random` generator up front so that results are reproducible.
import random

RANDOM_SEED = 2024
random.seed(RANDOM_SEED)

## Basic Syntax

In [None]:
# Load the Gerrymandria example graph from the GerryChain docs.
graph = Graph.from_json("https://raw.githubusercontent.com/mggg/GerryChain/refs/heads/main/docs/_static/gerrymandria.json")

my_updaters = {
    "population": updaters.Tally("TOTPOP"),
    "cut_edges": updaters.cut_edges,
}

# District label for each node, listed in `graph.nodes` order (8 labels per line).
# The node counts per label (24/16/16/8) are proportional to the member counts
# used for these labels in `candidates_per_part_dict` later on (3/2/2/1).
initial_labels = [
    1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1,
    4, 4, 4, 4, 4, 4, 4, 4,
    4, 4, 4, 4, 4, 4, 4, 4,
    6, 6, 6, 6, 6, 6, 6, 6,
    6, 6, 6, 6, 6, 6, 6, 6,
    8, 8, 8, 8, 8, 8, 8, 8,
]
# `zip` silently truncates on a length mismatch, which would leave nodes unassigned,
# so fail loudly instead.
assert len(initial_labels) == len(graph.nodes), (
    f"expected {len(graph.nodes)} labels, got {len(initial_labels)}"
)

initial_partition = Partition(
    graph,
    assignment=dict(zip(graph.nodes, initial_labels)),
    updaters=my_updaters,
)

In [None]:
# Print a short text summary of the loaded graph.
graph_summary = str(graph)
print(graph_summary)

In [None]:
# Number of members assigned to each district label of `initial_partition`.
candidates_per_part = {1: 3, 4: 2, 6: 2, 8: 1}
# Derive the total member count from the per-district counts so the two cannot drift apart.
n_members = sum(candidates_per_part.values())
ideal_pop_per_member = sum(initial_partition["population"].values()) / n_members


proposal = MultiMemberReCom(
    pop_col="TOTPOP",
    ideal_pop_per_member=ideal_pop_per_member,
    epsilon=0.01,
    candidates_per_part_dict=candidates_per_part,
)

In [None]:
# Build a 40-step chain using a contiguity constraint and an always-accept rule.
chain_settings = dict(
    initial_state=initial_partition,
    proposal=proposal,
    constraints=[contiguous],
    accept=accept.always_accept,
    total_steps=40,
)
recom_chain = MarkovChain(**chain_settings)

In [None]:
# Keep the assignment of every state the chain yields, in order.
assignment_list = [state.assignment for state in recom_chain.with_progress_bar()]

In [None]:
%matplotlib inline
import matplotlib.cm as mcm
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image
import io
import ipywidgets as widgets
from IPython.display import display, clear_output

frames = []

for i in range(len(assignment_list)):
    fig, ax = plt.subplots(figsize=(8,8))
    pos = {node :(data['x'],data['y']) for node, data in graph.nodes(data=True)}
    node_colors = [mcm.tab20(int(assignment_list[i][node]) % 20) for node in graph.nodes()]
    node_labels = {node: str(assignment_list[i][node]) for node in graph.nodes()}

    nx.draw_networkx_nodes(graph, pos, node_color=node_colors)
    nx.draw_networkx_edges(graph, pos)
    nx.draw_networkx_labels(graph, pos, labels=node_labels, font_color="white", font_weight="bold")
    plt.axis('off')

    buffer = io.BytesIO()
    plt.savefig(buffer, format='png')
    buffer.seek(0)
    image = Image.open(buffer)
    frames.append(image)
    plt.close(fig)

def show_frame(idx):
    clear_output(wait=True)
    display(frames[idx])

slider = widgets.IntSlider(value=0, min=0, max=len(frames)-1, step=1, description='Frame:')
slider.layout.width = '500px'
widgets.interactive(show_frame, idx=slider)

## Region-Aware ReCom for Multi-Member Districting

In [None]:
# Multi-member ReCom with region-aware options passed through `recom_kwargs`.
# NOTE(review): `region_surcharge={"muni": 1.0}` presumably discourages splitting the
# graph's "muni" regions; confirm against the GerryChain region-aware ReCom docs.
candidates_per_part = {1: 3, 4: 2, 6: 2, 8: 1}
# Derive the total member count from the per-district counts instead of hard-coding 8.
n_members = sum(candidates_per_part.values())

proposal = MultiMemberReCom(
    pop_col="TOTPOP",
    ideal_pop_per_member=sum(initial_partition["population"].values()) / n_members,
    epsilon=0.01,
    candidates_per_part_dict=candidates_per_part,
    recom_kwargs=dict(
        node_repeats=2,
        region_surcharge={"muni": 1.0},
    ),
)

recom_chain = MarkovChain(
    proposal=proposal,
    constraints=[contiguous],
    accept=accept.always_accept,
    initial_state=initial_partition,
    total_steps=40,
)

# `with_progress_bar()` already reports progress, so no manual step printing is needed
# (and no dependence on loop variables from other cells).
assignment_list_ra = [state.assignment for state in recom_chain.with_progress_bar()]

In [None]:
# Node positions are fixed across chain steps, so compute them once.
pos = {node: (data['x'], data['y']) for node, data in graph.nodes(data=True)}

frames_ra = []

for i, assignment in enumerate(assignment_list_ra):
    fig, ax = plt.subplots(figsize=(8, 8))
    # Color nodes by district label, wrapping labels around the colormap's number of colors.
    node_colors = [mcm.tab20(int(assignment[node]) % mcm.tab20.N) for node in graph.nodes()]
    node_labels = {node: str(assignment[node]) for node in graph.nodes()}

    # Draw on this figure's axes explicitly rather than pyplot's "current" axes.
    nx.draw_networkx_nodes(graph, pos, node_color=node_colors, ax=ax)
    nx.draw_networkx_edges(graph, pos, ax=ax)
    nx.draw_networkx_labels(graph, pos, labels=node_labels, font_color="white", font_weight="bold", ax=ax)
    ax.set_title(f"Step {i}")
    ax.set_axis_off()

    # Render to an in-memory PNG so the slider below can show it without re-plotting.
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png')
    buffer.seek(0)
    frames_ra.append(Image.open(buffer))
    # Close each figure so open figures don't accumulate over the whole chain.
    plt.close(fig)


def show_frame_ra(idx):
    """Clear the widget's output area and display the region-aware frame at index ``idx``."""
    clear_output(wait=True)
    display(frames_ra[idx])

slider_ra = widgets.IntSlider(value=0, min=0, max=len(frames_ra) - 1, step=1, description='Frame:')
slider_ra.layout.width = '500px'
widgets.interactive(show_frame_ra, idx=slider_ra)

## Starting Multi-Member ReCom from a Random Seed Plan

In [None]:
# Number of members for each randomly generated district (labels 0-3, 8 members in total).
n_members_dict = {0: 4, 1: 2, 2: 1, 3: 1}
n_members_total = sum(n_members_dict.values())
# Ideal population per member, derived from the graph instead of hard-coded.
total_pop = sum(graph.nodes[node]["TOTPOP"] for node in graph.nodes)
ideal_pop_per_member = total_pop / n_members_total


initial_partition = Partition.from_random_assignment(
    graph=graph,
    pop_col="TOTPOP",
    n_parts=len(n_members_dict),
    ideal_pop=ideal_pop_per_member,
    # Near-zero tolerance, so the random seed plan should be essentially exactly balanced.
    epsilon=1e-7,
    method_kwargs={"partition_pop_multiplier_by_part": n_members_dict},
)

proposal = MultiMemberReCom(
    pop_col="TOTPOP",
    ideal_pop_per_member=ideal_pop_per_member,
    epsilon=0.01,
    candidates_per_part_dict=n_members_dict,
)

recom_chain = MarkovChain(
    proposal=proposal,
    constraints=[contiguous],
    accept=accept.always_accept,
    initial_state=initial_partition,
    total_steps=40,
)

assignment_list_random_seed = [state.assignment for state in recom_chain.with_progress_bar()]

In [None]:
# Node positions are fixed across chain steps, so compute them once.
pos = {node: (data['x'], data['y']) for node, data in graph.nodes(data=True)}

frames_random_seed = []

for i, assignment in enumerate(assignment_list_random_seed):
    fig, ax = plt.subplots(figsize=(8, 8))
    # tab10 has 10 colors, so wrap labels modulo its size (the original used % 20,
    # which would push labels >= 10 past the end of the colormap).
    node_colors = [mcm.tab10(int(assignment[node]) % mcm.tab10.N) for node in graph.nodes()]
    node_labels = {node: str(assignment[node]) for node in graph.nodes()}

    # Draw on this figure's axes explicitly rather than pyplot's "current" axes.
    nx.draw_networkx_nodes(graph, pos, node_color=node_colors, ax=ax)
    nx.draw_networkx_edges(graph, pos, ax=ax)
    nx.draw_networkx_labels(graph, pos, labels=node_labels, font_color="white", font_weight="bold", ax=ax)
    ax.set_title(f"Step {i}")
    ax.set_axis_off()

    # Render to an in-memory PNG so the slider below can show it without re-plotting.
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png')
    buffer.seek(0)
    frames_random_seed.append(Image.open(buffer))
    # Close each figure so open figures don't accumulate over the whole chain.
    plt.close(fig)


def show_frame_random_seed(idx):
    """Clear the widget's output area and display the random-seed frame at index ``idx``."""
    clear_output(wait=True)
    display(frames_random_seed[idx])

slider_random_seed = widgets.IntSlider(value=0, min=0, max=len(frames_random_seed) - 1, step=1, description='Frame:')
slider_random_seed.layout.width = '500px'
widgets.interactive(show_frame_random_seed, idx=slider_random_seed)