In [66]:
import pandas as pd
from pathlib import Path
import h3
import folium
from geojson import Feature, FeatureCollection
import _json
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import ast
pd.options.mode.copy_on_write = True 
from functools import reduce
from folium import GeoJsonTooltip
import json

## Create JSON of subspecies

In [67]:
# Load eBird taxonomy
taxonomy = pd.read_csv("eBird_taxonomy_v2024.csv")
taxonomy.head()

Unnamed: 0,TAXON_ORDER,CATEGORY,SPECIES_CODE,TAXON_CONCEPT_ID,PRIMARY_COM_NAME,SCI_NAME,ORDER,FAMILY,SPECIES_GROUP,REPORT_AS
0,2,species,ostric2,,Common Ostrich,Struthio camelus,Struthioniformes,Struthionidae (Ostriches),Ostriches,
1,7,species,ostric3,,Somali Ostrich,Struthio molybdophanes,Struthioniformes,Struthionidae (Ostriches),Ostriches,
2,8,slash,y00934,,Common/Somali Ostrich,Struthio camelus/molybdophanes,Struthioniformes,Struthionidae (Ostriches),Ostriches,
3,10,species,soucas1,,Southern Cassowary,Casuarius casuarius,Casuariiformes,Casuariidae (Cassowaries and Emu),Cassowaries and Emu,
4,11,species,dwacas1,,Dwarf Cassowary,Casuarius bennetti,Casuariiformes,Casuariidae (Cassowaries and Emu),Cassowaries and Emu,


In [68]:
# What categories are there in the taxonomy?
taxonomy.CATEGORY.unique()


array(['species', 'slash', 'issf', 'hybrid', 'spuh', 'domestic', 'form',
       'intergrade'], dtype=object)

In [69]:
# How many species are there?
species = taxonomy[taxonomy.CATEGORY == 'species']
len(species)


11145

In [70]:
# How many infraspecific entries are there?
infrasp_categories = ['issf', 'form', 'intergrade']
infraspp = taxonomy[taxonomy.CATEGORY.isin(infrasp_categories)]
len(infraspp)

3843

In [93]:
start_row=0
end_row=2
for x in pd.read_csv("example.csv", header=0, skiprows=range(0,2), chunksize=2):
    print(x)

   1  cell1 cell1.1
0  2  cell2   cell2
1  3  cell3   cell3
   1  cell1 cell1.1
2  4  cell4   cell4
3  5  cell5   cell5
   1  cell1 cell1.1
4  6  cell6   cell6


In [None]:
# Create a dictionary mapping species to their infraspecies, by category
# Most species have a single category of infraspecies (e.g. either form or subspecies, not both)
# However some species have infraspecies in multiple categories, e.g. Brant (Branta bernicla) has both subspecies and forms
spp_dict = species[['SPECIES_CODE', 'PRIMARY_COM_NAME', 'SCI_NAME']].set_index("SCI_NAME").T.to_dict()
# Add infraspecies to spp_json
for sp in tqdm(spp_dict.keys()):
    # Get infraspecies for this species
    infraspp_for_sp = infraspp[infraspp['SCI_NAME'].apply(lambda x: x[:len(sp)] == sp)]
    infraspp_dict = dict()

    # Add infraspecies to spp_json by category
    for cat in infrasp_categories:
        infrasp_in_category = infraspp_for_sp[infraspp_for_sp.CATEGORY == cat]
        infrasp_cat_dict = infrasp_in_category[
            ['SPECIES_CODE', 'PRIMARY_COM_NAME', 'SCI_NAME']].set_index("SCI_NAME").T.to_dict()
        if len(infrasp_cat_dict.keys()) > 0:
            infraspp_dict[cat] = infrasp_cat_dict
    spp_dict[sp]['infraspecies'] = infraspp_dict
    

In [278]:
with open("infraspecies_ebird.json", 'w') as f:
    f.write(json.dumps(spp_dict, indent=4))

In [8]:
with open("infraspecies_ebird.json") as f:
    spp_dict = json.load(f)

## Prep eBird data

Use the following H3 resolutions:
```
Res	Average Hexagon Area (km2)	Pentagon Area* (km2)	Ratio (P/H)
2	86,801.780398997	44,930.898497879	0.5176
3	12,393.434655088	6,315.472267516	0.5096
4	1,770.347654491	896.582383141	0.5064
```

They have this many cells:
```
Res	Total number of cells	Number of hexagons	Number of pentagons
2	5,882	5,870	12
3	41,162	41,150	12
4	288,122	288,110	12
```

In [9]:
resolutions = [2, 3, 4]

In [52]:
# Get list of species codes based on the data downloaded from eBird
sp_codes = [x.name.split('_')[1] for x in list(Path("data/").glob("*.zip"))]

## Determine number of sightings of each subspecies per grid cell

### Calculate in batches

In [98]:
def clean_ebd(
    full_df,
    remove_unconfirmed=True, 
    remove_reviewed=False,
    resolutions = resolutions,
):

    # Remove duplicate checklists
    full_df = full_df[full_df['GROUP IDENTIFIER'].isnull() | ~full_df[full_df['GROUP IDENTIFIER'].notnull()].duplicated(subset=["GROUP IDENTIFIER", "SCIENTIFIC NAME"],keep='first')]

    # Removed unconfirmed observations or reviewed observations, if desired
    if remove_unconfirmed:
        full_df = full_df[full_df["APPROVED"] == 1]
    if remove_reviewed:
        full_df = full_df[full_df["REVIEWED"] == 0]

    # Convert latitude and longitude to an H3 hexagon ID
    for resolution in resolutions:
        full_df[f'hex_id_{resolution}'] = full_df.apply(lambda row:  h3.latlng_to_cell(row.LATITUDE, row.LONGITUDE, resolution), axis=1)
    
    return full_df

def get_grid_cell_species_data(cell_df, sp, subspp, resolution):
    """Get # of checklists containing a species and each subspecies

    Args:
    - cell_df: pd.DataFrame, dataframe of data for a single grid cell (1 row per observation)
    - sp: str, scientific name of species
    - subspp: list of str, scientific names of subspecies for this species

    Returns:
    - cell_data: dict, with keys 'cell_id', species name, and subspecies names
    """
    # Total number of checklists containing the species
    num_checklists = cell_df["SAMPLING EVENT IDENTIFIER"].nunique()

    # Create a dict of # checklists containing sp for all cells
    cell_data = {'cell_id': cell_df[f"hex_id_{resolution}"].iloc[0]}
    cell_data[sp] = num_checklists

    # Add number of checklists containing each subspecies
    for subsp in subspp:
        num_subsp = cell_df[cell_df["SUBSPECIES SCIENTIFIC NAME"] == subsp].shape[0]
        cell_data[subsp] = num_subsp

    return cell_data


def get_species_df(sp, sp_df, subspp, resolution):
    """Make dataframe of species & subspecies data for every cell for a given species

    Args:
    - sp: str, scientific name of species
    - df: pd.DataFrame, dataframe of data for this species
    - subspp: list of str, scientific names of subspecies for this species
    - resolution: int, H3 resolution level
    """

    # Create a dict of # checklists containing sp for all cells
    cell_dicts = []
    for cell in sp_df[f"hex_id_{resolution}"].unique():
        cell_df = sp_df[sp_df[f"hex_id_{resolution}"] == cell]
        cell_data = get_grid_cell_species_data(cell_df, sp, subspp, resolution)
        cell_dicts.append(cell_data)

    sp_cell_df = pd.DataFrame(cell_dicts, index=range(len(cell_dicts)))
    sp_cell_df.set_index("cell_id", inplace=True)

    return sp_cell_df


# The dictionary we've created of all species and their infraspecies, by category
spp_dict = json.load(open("infraspecies_ebird.json"))

# Only include the columns we're using
use_cols = [
        'SCIENTIFIC NAME', 'SUBSPECIES SCIENTIFIC NAME',
        'SAMPLING EVENT IDENTIFIER', 'GROUP IDENTIFIER',
        'LATITUDE', 'LONGITUDE', 
        'REVIEWED', 'APPROVED', 'OBSERVATION DATE']
# These were used in an earlier version of the code, but are not needed now
# use_cols = [
#         'CATEGORY', 'TAXON CONCEPT ID', 'COMMON NAME', 
#         'SCIENTIFIC NAME','SUBSPECIES COMMON NAME', 'SUBSPECIES SCIENTIFIC NAME',
#         'SAMPLING EVENT IDENTIFIER',
#         'LATITUDE', 'LONGITUDE', 'REVIEWED', 'APPROVED', 'GROUP IDENTIFIER', 'OBSERVATION DATE']

# Process each species's dataset
for sp_code in sp_codes:
    print("\n\n\nProcessing", sp_code)
    dataset_filepath = f"data/ebd_{sp_code}_relOct-2024/ebd_{sp_code}_relOct-2024.txt"

    # Where we will store each batch of data
    ssp_batch_directory = Path('batches/')
    ssp_batch_directory.mkdir(exist_ok=True)

    # Read in CSV in batches
    chunk_rows = 100000 # May want to increase this number after using FireDuck

    # Keep track of the progress made in each batch using a tracker CSV
    tracker_filepath = f"progress_trackers/{sp_code}_tracker_rowsperchunk-{chunk_rows}.csv"

    # If already started processing this species, pick up where we left off 
    if Path(tracker_filepath).exists():
        tracker = pd.read_csv(tracker_filepath)
        tracker["spp_to_do"] = tracker["spp_to_do"].apply(ast.literal_eval) 
        tracker["spp_done"] = tracker["spp_done"].apply(ast.literal_eval) 
        start_idx = tracker.index[-1]
        spp_to_do = set(tracker.loc[start_idx].spp_to_do) - set(tracker.loc[start_idx].spp_done)
        if spp_to_do == set():
            prev_end_row = tracker.loc[start_idx].end_row
            prev_start_row = tracker.loc[start_idx].start_row
            if prev_end_row - prev_start_row < chunk_rows:
                print("No more data to process, total rows in dataset: ", prev_end_row)
                continue
            skiprows = tracker.loc[start_idx].end_row # This is the row after the last one we processed
            start_idx = start_idx + 1
            spp_to_do = None
        else:
            skiprows = tracker.loc[start_idx].start_row

    else:
        tracker = pd.DataFrame(columns=["start_row", "end_row", "spp_to_do", "spp_done"])
        start_idx = 0
        spp_to_do = None
        skiprows=0

    for idx, chunk in enumerate(pd.read_csv(dataset_filepath, chunksize=10, header=0, skiprows=range(1,skiprows), usecols=use_cols, sep="\t")):
        print(f"Processing chunk {idx}...")
        print("Start", (start_idx + idx)*chunk_rows, "End", (start_idx + idx)*chunk_rows+chunk_rows)
        if chunk.shape[0] == 0:
            print("No more data to process, total rows in dataset: ", (start_idx + idx)*chunk_rows)
            break
        if chunk.shape[0] < chunk_rows:
            end_row = (start_idx+idx)*chunk_rows + chunk.shape[0]+1
            print(f"Last chunk, total rows in dataset:", end_row)
        else:
            end_row = (start_idx + idx)*chunk_rows+chunk_rows
        cleaned = clean_ebd(chunk)
        if spp_to_do == None: # Add new row
            spp_to_do = list(set(cleaned["SCIENTIFIC NAME"].unique()))
            tracker.loc[start_idx+idx] = [(start_idx + idx)*chunk_rows, end_row, spp_to_do, []]

        
        for sp in spp_to_do:
            cleaned_sp = cleaned[cleaned["SCIENTIFIC NAME"] == sp]
            if cleaned_sp.shape[0] == 0:
                #print(f"No data for {sp}")
                continue

            # Get list of subspecies
            subspp = []
            for k, val in spp_dict[sp]['infraspecies'].items():
                subspp.extend(val.keys())
            
            # Get data on presence of each subspp for each resolution
            for resolution in resolutions:
                species_df = get_species_df(sp, cleaned_sp, subspp, resolution)
                filename = ssp_batch_directory.joinpath(f'{sp}_row{(start_idx+idx)*chunk_rows}-{end_row}_resolution{resolution}.csv')
                species_df.to_csv(filename)

            tracker.loc[start_idx+idx].spp_done += [sp]
            tracker.to_csv(tracker_filepath, index=False)

        spp_to_do = None
    




Processing strher
No more data to process, total rows in dataset:  499586



Processing easmea
No more data to process, total rows in dataset:  2400001



Processing yerwar
No more data to process, total rows in dataset:  12200001



Processing eurjay1
No more data to process, total rows in dataset:  1200001



Processing brant
No more data to process, total rows in dataset:  800001



Processing whcspa
No more data to process, total rows in dataset:  6400001



Processing cacgoo1
No more data to process, total rows in dataset:  800001



Processing horlar
No more data to process, total rows in dataset:  1900001



Processing coatit2
No more data to process, total rows in dataset:  800001



Processing foxspa
No more data to process, total rows in dataset:  1900001



Processing daejun
No more data to process, total rows in dataset:  14296832



Processing orcwar
No more data to process, total rows in dataset:  2900001



Processing yebcha
No more data to process, total rows in data

## Sum up the batches

In [608]:
sp_cell_df_directory = Path('sp_cell_dfs/')
sp_cell_df_directory.mkdir(exist_ok=True)

def parse_batch_files(ssp_batch_directory):
    batch_files = list(ssp_batch_directory.glob("*.csv"))
    file_info = [n.name.split("_") for n in batch_files]
    files = pd.DataFrame(file_info, columns=['SCIENTIFIC NAME', 'ROW RANGE', 'RESOLUTION'])
    files['FILENAME'] = batch_files
    for (species, resolution), species_df in files.groupby(["SCIENTIFIC NAME", 'RESOLUTION']):
        species = species.replace(" ", "-")
        resolution = resolution[:-4]
        all_dataframes = [pd.read_csv(f, index_col=0) for f in species_df.FILENAME] 
        sp_cell_df = reduce(lambda a, b: a.add(b, fill_value=0), all_dataframes)
        filename = sp_cell_df_directory.joinpath(f'{species}_{resolution}.csv')
        sp_cell_df.to_csv(filename)



parse_batch_files(ssp_batch_directory)

# Create maps

In [31]:
import re

In [None]:
from shapely.geometry import Polygon, LineString, mapping
from shapely.ops import split
import json
import geopandas as gpd
import numpy as np
import networkx as nx
import colorsys
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
from colormath.color_objects import sRGBColor, LabColor
from colormath.color_conversions import convert_color
from colormath import color_diff_matrix
#from colormath.color_diff import delta_e_cie2000 # deprecated and doesn't work anymore, reimplemented below

def get_infraspecies_relationships(sp, spp_dict=spp_dict):
    data = spp_dict[sp]['infraspecies']
    
    # Get a list of each type of infraspecies
    if "issf" in data.keys():
        issfs = [k.replace(sp+' ', '') for k in data["issf"].keys()] # Recognized ssp or ssp groups
    else:
        issfs = []
    if "form" in data.keys():
        forms = [k.replace(sp+' ', '') for k in data["form"].keys()] # Forms
    else:
        forms = []
    if "intergrade" in data.keys():
        intergrades = [k.replace(sp+' ', '') for k in data["intergrade"].keys()] # Intergrades (between ssp? forms?)
    else:
        intergrades = []

    intergrade_to_parents = dict()
    forms_to_parents = dict()
    top_level_intergrades = []
    top_level_forms = []

    # Find parents of the intergrades, if any are in the eBird taxonomy
    # Also determine which intergrades, if any, have no parents
    for intergrade in intergrades:
        parents = [i.strip() for i in intergrade.split('x')]
        # Check if all are true
        if all([p in issfs+forms for p in parents]):
            intergrade_to_parents[intergrade] = parents
        else:
            top_level_intergrades.append(intergrade)

    # Find parents of the forms, if any are in the eBird taxonomy
    # Also determine which forms, if any, have no parents
    for form in forms:
        # Split the form into its individual components
        form_parts = set(form.split("/"))
        
        parent_issfs = []
        for component in issfs:
            # Split the list component into subparts and check if all are in the form_parts
            component_parts = set(component.split("/"))
            if component_parts <= form_parts:  # Check if component_parts is a subset of form_parts
                parent_issfs.append(component)
        if len(parent_issfs):
            forms_to_parents[form] = parent_issfs
        else:
            top_level_forms.append(form)
        
    return issfs, forms, intergrades, intergrade_to_parents, forms_to_parents, top_level_intergrades, top_level_forms


def rgb_to_hex(rgb):
    """Convert an (R, G, B) tuple to a hex color (#RRGGBB)."""
    return "#{:02x}{:02x}{:02x}".format(*rgb)


def hex_to_rgb(hex_color):
    """Convert hex color (#RRGGBB) to an (R, G, B) tuple."""
    hex_color = hex_color.lstrip('#')
    return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))

def combine_rgb_colors(rgb_colors, fracs):
    """Combine a list of RGB colors proportionally."""
    if sum(fracs) == 0:
        return "#999999"
    else:
        combined_rgb = tuple(
            int(sum(frac * color[channel] for color, frac in zip(rgb_colors, fracs)))
            for channel in range(3)
        )
    return rgb_to_hex(combined_rgb)

def name_to_base_hue(name):
    """Generate a base hue from a name."""
    base_hue = hash(name) % 360
    return base_hue

def average_hues(hues):
    """Average a list of hues on the circular scale."""
    x = np.mean([np.cos(np.radians(h)) for h in hues])
    y = np.mean([np.sin(np.radians(h)) for h in hues])
    avg_hue = np.degrees(np.arctan2(y, x)) % 360
    return avg_hue


def delta_e_cie2000(color1, color2, Kl=1, Kc=1, Kh=1):
    """
    Calculates the Delta E (CIE2000) of two colors.
    """
    def _get_lab_color1_vector(color):
        return np.array([color.lab_l, color.lab_a, color.lab_b])
    def _get_lab_color2_matrix(color):
        return np.array([(color.lab_l, color.lab_a, color.lab_b)])

    color1_vector = _get_lab_color1_vector(color1)
    color2_matrix = _get_lab_color2_matrix(color2)
    delta_e = color_diff_matrix.delta_e_cie2000(
        color1_vector, color2_matrix, Kl=Kl, Kc=Kc, Kh=Kh)[0]
    return delta_e

def rgb_to_lab(rgb):
    """Convert an RGB color (0-255) to LAB color space."""
    srgb = sRGBColor(rgb[0] / 255, rgb[1] / 255, rgb[2] / 255, is_upscaled=False)
    return convert_color(srgb, LabColor)

def hsl_to_rgb(hue, saturation, lightness):
    """Convert HSL values to RGB (0-255)."""
    r, g, b = colorsys.hls_to_rgb(hue / 360, lightness, saturation)
    return (int(r * 255), int(g * 255), int(b * 255))


def is_color_too_similar(hue1, hue2, threshold=15):
    """
    Check if two hues are too similar, accounting for perceptual non-uniformity.
    Compare using the CIEDE2000 formula in the LAB color space.

    Higher threshold ==> colors need to be more different
    """

    # Convert hues to RGB colors using fixed saturation and lightness for comparison
    rgb1 = hsl_to_rgb(hue1, 0.8, 0.5)  # Vivid, medium lightness
    rgb2 = hsl_to_rgb(hue2, 0.8, 0.5)

    # Convert RGB to LAB for perceptual uniformity
    lab1 = rgb_to_lab(rgb1)
    lab2 = rgb_to_lab(rgb2)

    # Calculate perceptual difference using CIEDE2000
    delta_e = delta_e_cie2000(lab1, lab2)
    return delta_e < threshold


def create_distribution_adjacency_matrix(data, subspecies_cols, cell_col='cell_id'):
    """
    Create an adjacency matrix based on subspecies distribution similarities.

    Parameters:
    - data: DataFrame with cells as rows and subspecies counts as columns.
    - subspecies_cols: List of column names corresponding to subspecies counts.
    - cell_col: Column name for cell identifiers (optional, for reference).

    Returns:
    - adjacency_matrix: A NumPy array where element [i, j] is the similarity between subspecies distributions.
    - subspecies_list: The order of subspecies corresponding to matrix rows/columns.
    """
    # Subset the subspecies columns
    subspecies_data = data[subspecies_cols]

    # Normalize each cell's counts to proportions
    subspecies_distribution = subspecies_data.div(subspecies_data.sum(axis=1), axis=0).fillna(0)

    # Compute cosine similarity between each pair of subspecies
    adjacency_matrix = cosine_similarity(subspecies_distribution.T)

    # Return the matrix and list of subspecies
    return adjacency_matrix, subspecies_cols


def hue_to_hex_vibrant(hue):
    saturation, lightness = 0.8, 0.5  # Vivid, medium colors
    r, g, b = colorsys.hls_to_rgb(hue / 360, lightness, saturation)
    return rgb_to_hex((int(r * 255), int(g * 255), int(b * 255)))


def style_function(feature, subspp_colors):
    """Style a cell based on the proportion of subspecies."""
    properties = feature['properties']
    subspecies_values = dict()
    for subsp in subspp_colors.keys():
        val = properties.get(subsp, 0)
        if val > 0:
            subspecies_values[subsp] = val
    
    # Normalize the values to sum up to 1 for proportional allocation
    total = sum(subspecies_values.values())
    if total > 0:
        fracs = [value / total for value in subspecies_values.values()]
    else:
        fracs = [0 for _ in subspecies_values]
    
    # Get RGB colors for each subspecies
    hex_colors = [subspp_colors[subsp] for subsp in subspecies_values]
    rgb_colors = [hex_to_rgb(color) for color in hex_colors]
    
    # Combine colors based on the proportional fractions
    cell_color = combine_rgb_colors(rgb_colors, fracs)
    
    return {
        'fillColor': cell_color,  # Cell color
        'color': cell_color,  # Border color
        'weight': 1,  # Border weight
        'fillOpacity': 0.6,  # Cell fill transparency
    }

def calculate_overlap_intensity(overlap_matrix):
    """
    Calculate the overlap intensity for each subspecies.
    Overlap intensity is defined as the proportion of cells that overlap with others.

    Parameters:
    - overlap_matrix: A square matrix where overlap_matrix[i][j] represents the overlap between subspecies[i] and subspecies[j].

    Returns:
    - A list of overlap intensities for each subspecies.
    """
    total_overlap = np.sum(overlap_matrix, axis=1)  # Total overlap for each subspecies
    max_overlap = np.sum(overlap_matrix, axis=1) - np.diagonal(overlap_matrix)  # Exclude self-overlap
    return total_overlap / max_overlap


def generate_priority_hues(subspecies, overlap_matrix):
    """
    Assign perceptually distinct hues to subspecies, prioritizing those with higher overlap intensity.

    Parameters:
    - subspecies: List of subspecies names.
    - overlap_matrix: A square matrix where overlap_matrix[i][j] represents the overlap between subspecies[i] and subspecies[j].

    Returns:
    - A dictionary mapping each subspecies to a distinct hue (in degrees, 0-360).
    """
    n_subspecies = len(subspecies)

    # Step 1: Generate perceptually distinct hues (0-360 degrees)
    hues = np.linspace(0, 360, n_subspecies, endpoint=False)

    # Step 2: Calculate overlap intensity
    overlap_intensity = calculate_overlap_intensity(overlap_matrix)

    # Step 3: Assign hues based on overlap intensity and relationships
    G = nx.Graph()
    for i, sp1 in enumerate(subspecies):
        for j, sp2 in enumerate(subspecies):
            if overlap_matrix[i][j] > 0.1:  # Add edge if overlap is above threshold
                G.add_edge(sp1, sp2, weight=overlap_matrix[i][j])

    # Sort subspecies by overlap intensity
    subspecies_sorted = sorted(zip(subspecies, overlap_intensity), key=lambda x: x[1], reverse=True)

    # Assign hues sequentially to prioritize highly overlapping subspecies
    assigned_hues = {}
    used_hues = set()
    for subsp, _ in subspecies_sorted:
        # Find the most distinct unused hue
        best_hue = None
        max_dist = -1
        for i, hue in enumerate(hues):
            if i in used_hues:
                continue
            # Check perceptual distance to already assigned hues
            if assigned_hues:
                dist = np.min(
                    [min(abs(hue - assigned_hues[sp]), 360 - abs(hue - assigned_hues[sp])) for sp in assigned_hues]
                )
            else:
                dist = float("inf")
            
            if dist > max_dist:
                max_dist = dist
                best_hue = i

        # Assign the best hue
        assigned_hues[subsp] = hues[best_hue]
        used_hues.add(best_hue)
    
    return assigned_hues


def get_bounds(geojson_result):
    """
    Calculate the bounding box of all features in the GeoJSON.

    Args:
    - geojson_result: GeoJSON string with features.

    Returns:
    - Bounds as [[southwest_lat, southwest_lon], [northeast_lat, northeast_lon]].
    """
    import json
    #geojson_data = json.loads(geojson_result)
    geojson_data = geojson_result
    all_coords = []

    for feature in geojson_data['features']:
        # Extract all coordinates from the polygon or multipolygon
        coords = feature['geometry']['coordinates']
        if feature['geometry']['type'] == "Polygon":
            all_coords.extend(coords[0])  # Add outer ring of the polygon
        elif feature['geometry']['type'] == "MultiPolygon":
            for poly in coords:
                all_coords.extend(poly[0])  # Add outer ring of each polygon

    # Extract longitudes (x) and latitudes (y) correctly
    lons, lats = zip(*all_coords)
    return [[min(lats), min(lons)], [max(lats), max(lons)]]


def split_polygon_at_line(polygon, line):
    """
    Splits a GeoJSON polygon at a given GeoJSON line.
    
    Parameters:
        polygon (dict): shapely representation of the polygon.
        line (dict): shapely representation of the line.
    
    Returns:
        list: A list of GeoJSON polygons after splitting.
    """
    
    # Perform the split
    split_result = split(polygon, line)
    
    # Convert Shapely polygons back to GeoJSON
    split_geojsons = [mapping(geom) for geom in split_result.geoms]
    
    return split_geojsons


def split_at_dateline(geojson):
    """
    Splits polygons that cross the International Date Line (180° longitude).
    
    Parameters:
    - geojson: A GeoJSON-like dictionary of geometries.
    
    Returns:
    - A new GeoJSON with adjusted geometries.
    """
    dateline = LineString([(180, 90), (180, -90)])
    adjusted_polygons = []
    
    for feature in geojson['features']:
        geom = feature['geometry']
        polygon = Polygon(geom['coordinates'][0])

        # Check if the polygon crosses the dateline
        if not polygon.is_valid:
            polygon = polygon.buffer(0)  # Fix invalid geometries

        # Check if any of the polygon's coordinates are within 5 degrees of the dateline
        if any(abs(lon) > 170 for lon, _ in polygon.exterior.coords) and any(abs(lon1 - lon2) > 180 for lon1, lon2 in zip(polygon.exterior.xy[0][:-1], polygon.exterior.xy[0][1:])):
            # For all the longitudes that are less than -170, add 360 to them
            new_coords = [
                [(x + 360 if x < -170 else x, y) for x, y in polygon.exterior.coords]
            ]
            adjusted_polygons.append({
                "type": "Feature",
                "geometry": {
                    "type": "Polygon",
                    "coordinates": new_coords
                },
                "properties": feature['properties']
            })
        else:
            adjusted_polygons.append(feature)
    
    return {
        "type": "FeatureCollection",
        "features": adjusted_polygons
    }


def get_range_size(sp_cell_df, ssp):
    """Calculate the range size of a subspecies."""
    return (sp_cell_df[ssp] > 0).astype(int).sum()

def get_color_mapping(sp_cell_df):
    # Get the relationships between the infraspecies (ISSFs, forms, intergrades)
    issfs, forms, intergrades, inter_to_p, form_to_p, top_inter, top_form = get_infraspecies_relationships(species, spp_dict)

    # Get colors for the top-level infraspecies based on their geographic overlap
    top_level_infras = [*issfs, *top_inter, *top_form]
    sp_cell_df.columns = sp_cell_df.columns.str.replace(species + ' ', "")
    overlap_matrix, top_level_infras = create_distribution_adjacency_matrix(sp_cell_df, top_level_infras)
    ssp_hues = generate_priority_hues(top_level_infras, overlap_matrix)

    # Get colors for the lower-level infraspecies by averaging their "parent" infraspecies
    for intergrade, parents in inter_to_p.items():
        ssp_hues[intergrade] = average_hues([ssp_hues[parent] for parent in parents])
    for form, parents in form_to_p.items():
        ssp_hues[form] = average_hues([ssp_hues[parent] for parent in parents])

    # Convert hues to vibrant colors
    ssp_colors = {ssp: hue_to_hex_vibrant(hue) for ssp, hue in ssp_hues.items()}

    # Organize subspecies into categories and calculate their range sizes (from largest to smallest)
    sorted_issfs = sorted(issfs, key=lambda ssp: get_range_size(sp_cell_df, ssp), reverse=True)
    sorted_forms = sorted(forms, key=lambda ssp: get_range_size(sp_cell_df, ssp), reverse=True)
    sorted_intergrades = sorted(intergrades, key=lambda ssp: get_range_size(sp_cell_df, ssp), reverse=True)

    return ssp_colors, sorted_issfs, sorted_forms, sorted_intergrades

def get_ssp_common_name(species, subsp):
    # Look up the PRIMARY_COM_NAME for the subspecies
    subsp_name = taxonomy[taxonomy['SCI_NAME'] == f"{species} {subsp}"].PRIMARY_COM_NAME.values[0]
    # Get the part in the parentheses
    subsp_name = subsp_name.split('(')[-1].split(')')[0] if '(' in subsp_name else subsp
    return subsp_name


def reduce_precision(geojson_data, precision):
    def round_coords(coords):
        return [[round(coord[0], precision), round(coord[1], precision)] for coord in coords]
    
    def process_feature(feature):
        if feature['geometry']['type'] == 'Polygon':
            feature['geometry']['coordinates'] = [round_coords(coords) for coords in feature['geometry']['coordinates']]
        elif feature['geometry']['type'] == 'MultiPolygon':
            feature['geometry']['coordinates'] = [
                [round_coords(coords) for coords in polygon]
                for polygon in feature['geometry']['coordinates']
            ]
        return feature

    geojson_data['features'] = [process_feature(feature) for feature in geojson_data['features']]
    return geojson_data

def choropleth_map(sp_cell_df, common_name, subspp_colors, sorted_issfs, sorted_forms, sorted_intergrades):
    """Creates a choropleth map given species data."""
    
    f = folium.Figure()
    map = folium.Map(location=[47, -122], zoom_start=5, tiles="cartodbpositron", control_scale=True)
    f.add_child(map)

    sp = sp_cell_df.columns[0]
    subspp = sp_cell_df.columns[1:]
    
    # Create polygon features with 
    list_features = []
    for _, row in sp_cell_df.iterrows():
        # Calculate the relative percentage of sightings of each subspecies for each cell
        percentages = (row[subspp] / sum(row[subspp])) * 100
        percentages = round(percentages, 0)
        percentages_dict = percentages.to_dict()
        
        # Calculate total reports across all subspecies for each cell
        total_reports = row[subspp].sum()

        # Precompute tooltip text showing only subspecies with nonzero percentages
        percentages_dict_ordered = pd.DataFrame(percentages_dict, index=['pct']).T.query('pct > 0')['pct'].sort_values(ascending=False).to_dict()
        
        # Start tooltip container
        tt1_style = '''.tt1 {min-width: 100px; max-width: 300px; overflow: auto;}'''
        tooltip_text = '<div class="tt1">'

        # Add header and create a flex container for total reports and percentages

        tt2_style = '''.tt2 {display: flex; justify-content: space-between; margin-top: 5px;}'''
        tt3_style = '''.tt3 {margin-right: 10px;}'''
        tt4_style = '''.tt4 {text-align: right; background: ;}'''
        total_reports = row[subspp].sum()
        tooltip_text += f'''<div class="tt2"><div class="tt3"><strong>Reported taxa</strong><br>Total reports: {total_reports:.0f}</div><div class="tt4">'''
        # Add percentages to the right-aligned div
        if total_reports > 0:
            for subsp, percent in percentages_dict_ordered.items():
                if percent > 0:
                    subsp_common_name = get_ssp_common_name(sp, subsp)
                    tooltip_text += f"<div>{subsp_common_name}: {percent:.0f}%</div>" if percent > 0.5 else ""
        else:
            tooltip_text += '''</div>None</div>'''

        # Close all divs
        tooltip_text += '''</div></div></div>'''

        # Add the tooltip to the feature's properties
        percentages_dict_ordered["tooltip"] = tooltip_text

        # Convert the H3 cell into a geometry
        geometry_for_row = h3.cells_to_geo(cells=[row.name])

        # Add a GeoJSON Feature to the list of features
        feature = Feature(
            geometry=geometry_for_row,
            id=row.name,
            properties=percentages_dict_ordered
        )
        list_features.append(feature)

    # Create a GeoJSON FeatureCollection with the list of features
    feat_collection = FeatureCollection(list_features)
    geojson_result = json.dumps(feat_collection)

    # Deal with geometries that cross the International Date Line
    geojson_result = split_at_dateline(json.loads(geojson_result))

    # Reduce precision to 3 digits (111 meters)
    geojson_result = reduce_precision(geojson_result, precision=3)
    
    # Add GeoJSON layer to the map
    folium.GeoJson(
        geojson_result,
        style_function=lambda feature: style_function(feature, subspp_colors),
        name=f'{sp} Subspecies Map'
    ).add_to(map)
    
    # Add a tooltip to the GeoJSON layer
    folium.GeoJson(
        geojson_result,
        style_function=lambda feature: {
            'weight': 0, 
            'color': 'transparent',  
            'fillOpacity': 0.6 
        },
        tooltip=GeoJsonTooltip(
            fields=["tooltip"],
            aliases=[None],  # No additional label prepended
            localize=True,
            sticky=True,
            labels=False,  # Disable default labels
            html=True  # Enable HTML formatting in the tooltip
        )
    ).add_to(map)


    # Add legend with subheaders
    legend_html = f"""
    <div style="position: fixed; top: 10px; right: 10px; max-width: 200px; height: auto; z-index: 9999; background-color: white; box-shadow: 0 0 5px rgba(0, 0, 0, 0.2); border: 1px solid lightgray; border-radius: 5px; padding: 10px; font-size: 10px;">
        <div style="font-size: 12px;"><strong>{common_name}</strong></div>
    """
    
    # Add subheaders and sorted species
    issf_colors = {k: subspp_colors[k] for k in sorted_issfs}
    form_colors = {k: subspp_colors[k] for k in sorted_forms}
    intergrade_colors = {k: subspp_colors[k] for k in sorted_intergrades}
    categories = [("Subspecies", issf_colors), ("Forms", form_colors), ("Intergrades", intergrade_colors)]
    # Remove any categories where the dictionary has no keys
    categories = [(category, ssp_in_category) for category, ssp_in_category in categories if ssp_in_category]

    # Create species and colors part of legend
    for category, sorted_subspecies in categories:
        legend_html += f'<div style="margin-top:10px;"><b>{category}</b></div>'
        for subsp in sorted_subspecies:
            subsp_common_name = get_ssp_common_name(sp, subsp)
            if subsp_common_name == subsp.strip('[]'): # Only display one name if no unique common name
                display_name = subsp
            else:
                display_name = f"{subsp_common_name} ({subsp})"

            display_name = '/<wbr>'.join(display_name.split('/')) # Add available wordbreaks for long slash names
            color = subspp_colors.get(subsp, "#ccc")  # Default color if no color is found

            # Get link to the subspecies's eBird species account
            subsp_code = taxonomy[taxonomy['SCI_NAME'] == f"{sp} {subsp}"].SPECIES_CODE.values[0]
            subsp_link = f"https://ebird.org/species/{subsp_code}"
            

            subsp_display = f"""
            <div style="display: inline-block; max-width: 150px; white-space: normal; overflow-wrap: break-word;">
                <a href="{subsp_link}" target="_blank">{display_name}</a>
            </div>
            """
            # Format subspecies name next to colors
            legend_html += f"""
            <div style="margin-top: 5px;">
                <span style="display: inline-block; width: 20px; height: 10px; margin-right: 5px; background-color: {color};"></span>
                {subsp_display}
            </div>
            """
    legend_html += "</div>"
    
    legend_element = folium.Element(legend_html)
    map.get_root().html.add_child(legend_element)
    map.get_root().html.add_child(folium.Element(f"""
        <style>
            {tt1_style}
            {tt2_style}
            {tt3_style}
            {tt4_style}
        </style>"""))

    # Calculate bounds and adjust the map's view
    bounds = get_bounds(geojson_result)
    map.fit_bounds(bounds)

    # TODO: String manipulations to make HTML smaller?
    string_so_far = map.get_root().render()
    return string_so_far


remake_maps = True
for sp_code in sp_codes:
    subspp_colors = None
    print("\nMapping", sp_code)
    common_name = taxonomy[taxonomy['SPECIES_CODE'] == sp_code].PRIMARY_COM_NAME.values[0]
    for resolution in resolutions: #[2,3,4]
        species = taxonomy[taxonomy['PRIMARY_COM_NAME'] == common_name].SCI_NAME.values[0]
        dataname = f"sp_cell_dfs/{species.replace(' ', '-')}_resolution{resolution}.csv"
        if not Path(dataname).exists():
            print("No data for", species, "at resolution", resolution)
            continue
        map_filename = f"docs/maps/{species.replace(' ', '-')}_{resolution}.html"
        if Path(map_filename).exists() and not remake_maps:
            print("Map already exists for", species, "at resolution", resolution)
            continue
        sp_cell_df = pd.read_csv(dataname, index_col=0)
        sp_cell_df.columns = sp_cell_df.columns.str.replace(species + ' ', "")
        subspecies = sp_cell_df.columns[1:]
        if resolution == 2 or subspp_colors == None:
            subspp_colors, sorted_issfs, sorted_forms, sorted_intergrades = get_color_mapping(sp_cell_df)
            
        m = choropleth_map(sp_cell_df, common_name, subspp_colors, sorted_issfs, sorted_forms, sorted_intergrades)
        map_filename = f"docs/maps/{species.replace(' ', '-')}_{resolution}.html"
        #m.save(map_filename)
        with open(map_filename, 'w') as f:
            f.write(m)

## Create a CSV of map URLs for the website

In [100]:
df = pd.DataFrame(columns=["common_name", "scientific_name", "resolution", "map_url"])
maps_dir = Path("docs/maps")

# Sort maps taxonomically
maps_list = list(maps_dir.glob("*.html"))
species = {p.name.split('_')[0].replace('-', ' '): p for p in maps_list}
ordering = taxonomy[['SCI_NAME', 'TAXON_ORDER']].set_index("SCI_NAME").to_dict()['TAXON_ORDER']
maps_list = sorted(maps_list, key=lambda x: ordering[x.name.split('_')[0].replace('-', ' ')])

# Create table of common name, species, resolution, and map URL
# This is used by the website
for idx, file in enumerate(maps_list):
    resolution = file.stem.split("_")[-1]
    species = file.stem.replace(f"_{resolution}", "")
    common_name = taxonomy[taxonomy['SCI_NAME'] == species.replace('-', ' ')].PRIMARY_COM_NAME.values[0]
    #map_url = Path(Path(file).parent.stem).joinpath(Path(file).name)
    map_url = 'https://subspeciesmapper.netlify.app/' + str(file.relative_to(maps_dir))
    print(map_url)
    df.loc[idx] = [common_name, species, resolution, map_url]
df.to_csv("docs/data/map_data.csv", index=False)

https://subspeciesmapper.netlify.app/Branta-bernicla_3.html
https://subspeciesmapper.netlify.app/Branta-bernicla_2.html
https://subspeciesmapper.netlify.app/Branta-bernicla_4.html
https://subspeciesmapper.netlify.app/Branta-hutchinsii_4.html
https://subspeciesmapper.netlify.app/Branta-hutchinsii_2.html
https://subspeciesmapper.netlify.app/Branta-hutchinsii_3.html
https://subspeciesmapper.netlify.app/Branta-canadensis_4.html
https://subspeciesmapper.netlify.app/Branta-canadensis_3.html
https://subspeciesmapper.netlify.app/Branta-canadensis_2.html
https://subspeciesmapper.netlify.app/Somateria-mollissima_4.html
https://subspeciesmapper.netlify.app/Somateria-mollissima_2.html
https://subspeciesmapper.netlify.app/Somateria-mollissima_3.html
https://subspeciesmapper.netlify.app/Butorides-striata_4.html
https://subspeciesmapper.netlify.app/Butorides-striata_2.html
https://subspeciesmapper.netlify.app/Butorides-striata_3.html
https://subspeciesmapper.netlify.app/Buteo-jamaicensis_4.html
https