In [1]:
import pandas as pd
from pathlib import Path
import h3
import folium
from geojson import Feature, Point, FeatureCollection
import json
import matplotlib
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import ast
pd.options.mode.copy_on_write = True 
from functools import reduce
from folium import GeoJsonTooltip



## Create JSON of subspecies

In [4]:
# Load eBird taxonomy
taxonomy = pd.read_csv("eBird_taxonomy_v2024.csv")
taxonomy.head()

Unnamed: 0,TAXON_ORDER,CATEGORY,SPECIES_CODE,TAXON_CONCEPT_ID,PRIMARY_COM_NAME,SCI_NAME,ORDER,FAMILY,SPECIES_GROUP,REPORT_AS
0,2,species,ostric2,,Common Ostrich,Struthio camelus,Struthioniformes,Struthionidae (Ostriches),Ostriches,
1,7,species,ostric3,,Somali Ostrich,Struthio molybdophanes,Struthioniformes,Struthionidae (Ostriches),Ostriches,
2,8,slash,y00934,,Common/Somali Ostrich,Struthio camelus/molybdophanes,Struthioniformes,Struthionidae (Ostriches),Ostriches,
3,10,species,soucas1,,Southern Cassowary,Casuarius casuarius,Casuariiformes,Casuariidae (Cassowaries and Emu),Cassowaries and Emu,
4,11,species,dwacas1,,Dwarf Cassowary,Casuarius bennetti,Casuariiformes,Casuariidae (Cassowaries and Emu),Cassowaries and Emu,


In [523]:
# What categories are there in the taxonomy?
taxonomy.CATEGORY.unique()


array(['species', 'slash', 'issf', 'hybrid', 'spuh', 'domestic', 'form',
       'intergrade'], dtype=object)

In [524]:
# How many species are there?
species = taxonomy[taxonomy.CATEGORY == 'species']
len(species)


11145

In [525]:
# How many infraspecific entries are there?
infrasp_categories = ['issf', 'form', 'intergrade']
infraspp = taxonomy[taxonomy.CATEGORY.isin(infrasp_categories)]
len(infraspp)

3843

In [None]:
# Create a dictionary mapping species to their infraspecies, by category
# Most species have a single category of infraspecies (e.g. either form or subspecies, not both)
# However some species have infraspecies in multiple categories, e.g. Brant (Branta bernicla) has both subspecies and forms
spp_dict = species[['SPECIES_CODE', 'PRIMARY_COM_NAME', 'SCI_NAME']].set_index("SCI_NAME").T.to_dict()
# Add infraspecies to spp_json
for sp in tqdm(spp_dict.keys()):
    # Get infraspecies for this species
    infraspp_for_sp = infraspp[infraspp['SCI_NAME'].apply(lambda x: x[:len(sp)] == sp)]
    infraspp_dict = dict()

    # Add infraspecies to spp_json by category
    for cat in infrasp_categories:
        infrasp_in_category = infraspp_for_sp[infraspp_for_sp.CATEGORY == cat]
        infrasp_cat_dict = infrasp_in_category[
            ['SPECIES_CODE', 'PRIMARY_COM_NAME', 'SCI_NAME']].set_index("SCI_NAME").T.to_dict()
        if len(infrasp_cat_dict.keys()) > 0:
            infraspp_dict[cat] = infrasp_cat_dict
    spp_dict[sp]['infraspecies'] = infraspp_dict
    

In [278]:
with open("infraspecies_ebird.json", 'w') as f:
    f.write(json.dumps(spp_dict, indent=4))

In [2]:
with open("infraspecies_ebird.json") as f:
    spp_dict = json.load(f)

## Prep eBird data

Use the following H3 resolutions:
```
Res	Average Hexagon Area (km2)	Pentagon Area* (km2)	Ratio (P/H)
2	86,801.780398997	44,930.898497879	0.5176
3	12,393.434655088	6,315.472267516	0.5096
4	1,770.347654491	896.582383141	0.5064
5	252.903858182	127.785583023	0.5053
```

They have this many cells:
```
Res	Total number of cells	Number of hexagons	Number of pentagons
2	5,882	5,870	12
3	41,162	41,150	12
4	288,122	288,110	12
5	2,016,842	2,016,830	12
```

## Determine number of sightings of each subspecies per grid cell

### Calculate in batches

In [607]:

resolutions = [2,3,4,5]
def clean_ebd(
    full_df,
    remove_unconfirmed=True, 
    remove_reviewed=False,
    resolutions = resolutions,
):

    # Remove duplicate checklists
    full_df = full_df[full_df['GROUP IDENTIFIER'].isnull() | ~full_df[full_df['GROUP IDENTIFIER'].notnull()].duplicated(subset=["GROUP IDENTIFIER", "SCIENTIFIC NAME"],keep='first')]

    # Removed unconfirmed observations or reviewed observations, if desired
    if remove_unconfirmed:
        full_df = full_df[full_df["APPROVED"] == 1]
    if remove_reviewed:
        full_df = full_df[full_df["REVIEWED"] == 0]

    # Just subset to the needed columns
    needed_columns = [
        'TAXONOMIC ORDER','CATEGORY', 'TAXON CONCEPT ID', 'COMMON NAME', 
        'SCIENTIFIC NAME','SUBSPECIES COMMON NAME', 'SUBSPECIES SCIENTIFIC NAME',
        'SAMPLING EVENT IDENTIFIER',
        'LATITUDE', 'LONGITUDE', 'REVIEWED', 'OBSERVATION DATE']
    full_df = full_df[needed_columns]
    full_df.head()

    # Convert latitude and longitude to an H3 hexagon ID
    for resolution in resolutions:
        full_df[f'hex_id_{resolution}'] = full_df.apply(lambda row:  h3.latlng_to_cell(row.LATITUDE, row.LONGITUDE, resolution), axis=1)
    
    return full_df

def get_grid_cell_species_data(cell_df, sp, subspp, resolution):
    """Get # of checklists containing a species and each subspecies

    Args:
    - cell_df: pd.DataFrame, dataframe of data for a single grid cell (1 row per observation)
    - sp: str, scientific name of species
    - subspp: list of str, scientific names of subspecies for this species

    Returns:
    - cell_data: dict, with keys 'cell_id', species name, and subspecies names
    """
    # Total number of checklists containing the species
    num_checklists = cell_df["SAMPLING EVENT IDENTIFIER"].nunique()

    # Create a dict of # checklists containing sp for all cells
    cell_data = {'cell_id': cell_df[f"hex_id_{resolution}"].iloc[0]}
    cell_data[sp] = num_checklists

    # Add number of checklists containing each subspecies
    for subsp in subspp:
        num_subsp = cell_df[cell_df["SUBSPECIES SCIENTIFIC NAME"] == subsp].shape[0]
        cell_data[subsp] = num_subsp

    return cell_data


def get_species_df(sp, sp_df, subspp, resolution):
    """Make dataframe of species & subspecies data for every cell for a given species

    Args:
    - sp: str, scientific name of species
    - df: pd.DataFrame, dataframe of data for this species
    - subspp: list of str, scientific names of subspecies for this species
    - resolution: int, H3 resolution level
    """

    # Create a dict of # checklists containing sp for all cells
    cell_dicts = []
    for cell in sp_df[f"hex_id_{resolution}"].unique():
        cell_df = sp_df[sp_df[f"hex_id_{resolution}"] == cell]
        cell_data = get_grid_cell_species_data(cell_df, sp, subspp, resolution)
        cell_dicts.append(cell_data)

    sp_cell_df = pd.DataFrame(cell_dicts, index=range(len(cell_dicts)))
    sp_cell_df.set_index("cell_id", inplace=True)

    return sp_cell_df



spp_dict = json.load(open("infraspecies_ebird.json"))

# caja_df = get_species_df(sp, full_df, subspp_dict)
# filepath = Path('batches').joinpath(filename)
# caja_df.to_csv(filepath)

#dataset_filepath = "ebd-sample.txt"
#sp_code = 'rethaw'
use_cols = [
        'TAXONOMIC ORDER','CATEGORY', 'TAXON CONCEPT ID', 'COMMON NAME', 
        'SCIENTIFIC NAME','SUBSPECIES COMMON NAME', 'SUBSPECIES SCIENTIFIC NAME',
        'SAMPLING EVENT IDENTIFIER',
        'LATITUDE', 'LONGITUDE', 'REVIEWED', 'APPROVED', 'GROUP IDENTIFIER', 'OBSERVATION DATE']

sp_codes = [x.name.split('_')[1] for x in list(Path("data/").glob("*.zip"))]
for sp_code in sp_codes:
    print("\n\n\nProcessing", sp_code)
    dataset_filepath = f"data/ebd_{sp_code}_relOct-2024/ebd_{sp_code}_relOct-2024.txt"

    resolution = resolutions[0]

    ssp_batch_directory = Path('batches/')
    ssp_batch_directory.mkdir(exist_ok=True)

    # Read in CSV in batches
    chunk_rows = 100000
    tracker_filepath = f"{sp_code}_tracker_rowsperchunk-{chunk_rows}.csv"

    if Path(tracker_filepath).exists():
        tracker = pd.read_csv(tracker_filepath)
        tracker["spp_to_do"] = tracker["spp_to_do"].apply(ast.literal_eval) 
        tracker["spp_done"] = tracker["spp_done"].apply(ast.literal_eval) 
        start_idx = tracker.index[-1]
        spp_to_do = set(tracker.loc[start_idx].spp_to_do) - set(tracker.loc[start_idx].spp_done)
        if spp_to_do == set():
            skiprows = tracker.loc[start_idx].end_row
            start_idx = start_idx + 1
            spp_to_do = None
        else:
            skiprows = tracker.loc[start_idx].start_row

    else:
        tracker = pd.DataFrame(columns=["start_row", "end_row", "spp_to_do", "spp_done"])
        start_idx = 0
        spp_to_do = None
        skiprows=0

    # TODO: DEAL WITH BUG (BELOW)
    # SWITCH TO DASK TO PARALLELIZE
    for idx, chunk in enumerate(pd.read_csv(dataset_filepath, chunksize=chunk_rows, skiprows=range(1,skiprows), usecols=use_cols, sep="\t")):
        if chunk.shape[0] == 0:
            print("No more data to process, total rows in dataset: ", (start_idx + idx)*chunk_rows)
            break
        if chunk.shape[0] < chunk_rows:
            # Some kind of weird bug/issue with the last chunk 
            # which finds a single row left to process claiming to be in the next 100,000 rows after the last one
            # This only happens after the first time I rerun this cell
            end_row = (start_idx+idx)*chunk_rows + chunk.shape[0]
            print(f"Last chunk, total rows in dataset:", end_row)
        else:
            end_row = (start_idx + idx)*chunk_rows+chunk_rows
        cleaned = clean_ebd(chunk)
        if spp_to_do == None: # Add new row
            spp_to_do = list(set(cleaned["SCIENTIFIC NAME"].unique()))
            tracker.loc[start_idx+idx] = [(start_idx + idx)*chunk_rows, end_row, spp_to_do, []]

        
        for sp in spp_to_do:
            cleaned_sp = cleaned[cleaned["SCIENTIFIC NAME"] == sp]
            if cleaned_sp.shape[0] == 0:
                #print(f"No data for {sp}")
                continue

            # Get list of subspecies
            subspp = []
            for k, val in spp_dict[sp]['infraspecies'].items():
                subspp.extend(val.keys())
            
            # Get data on presence of each subspp for each resolution
            for resolution in resolutions:
                species_df = get_species_df(sp, cleaned_sp, subspp, resolution)
                filename = ssp_batch_directory.joinpath(f'{sp}_row{(start_idx+idx)*chunk_rows}-{end_row}_resolution{resolution}.csv')
                species_df.to_csv(filename)

            tracker.loc[start_idx+idx].spp_done += [sp]
            tracker.to_csv(tracker_filepath, index=False)

        spp_to_do = None
    




Processing strher
No more data to process, total rows in dataset:  600000



Processing easmea
No more data to process, total rows in dataset:  2500000



Processing yerwar
No more data to process, total rows in dataset:  12300000



Processing eurjay1
No more data to process, total rows in dataset:  1300000



Processing brant
No more data to process, total rows in dataset:  900000



Processing whcspa
No more data to process, total rows in dataset:  6500000



Processing cacgoo1
No more data to process, total rows in dataset:  900000



Processing horlar
No more data to process, total rows in dataset:  2000000



Processing coatit2
No more data to process, total rows in dataset:  900000



Processing foxspa
No more data to process, total rows in dataset:  2000000



Processing daejun
No more data to process, total rows in dataset:  14300000



Processing orcwar
Last chunk, total rows in dataset: 2900001



Processing yebcha
Last chunk, total rows in dataset: 1000001



Processing 

  for idx, chunk in enumerate(pd.read_csv(dataset_filepath, chunksize=chunk_rows, skiprows=range(1,skiprows), usecols=use_cols, sep="\t")):
  for idx, chunk in enumerate(pd.read_csv(dataset_filepath, chunksize=chunk_rows, skiprows=range(1,skiprows), usecols=use_cols, sep="\t")):


Last chunk, total rows in dataset: 11778727



Processing comeid


  for idx, chunk in enumerate(pd.read_csv(dataset_filepath, chunksize=chunk_rows, skiprows=range(1,skiprows), usecols=use_cols, sep="\t")):


Last chunk, total rows in dataset: 1072050



Processing perfal


  for idx, chunk in enumerate(pd.read_csv(dataset_filepath, chunksize=chunk_rows, skiprows=range(1,skiprows), usecols=use_cols, sep="\t")):
  for idx, chunk in enumerate(pd.read_csv(dataset_filepath, chunksize=chunk_rows, skiprows=range(1,skiprows), usecols=use_cols, sep="\t")):


Last chunk, total rows in dataset: 1759264



Processing redcro
Last chunk, total rows in dataset: 797057


## Sum up the batches

In [608]:
sp_cell_df_directory = Path('sp_cell_dfs/')
sp_cell_df_directory.mkdir(exist_ok=True)

def parse_batch_files(ssp_batch_directory):
    batch_files = list(ssp_batch_directory.glob("*.csv"))
    file_info = [n.name.split("_") for n in batch_files]
    files = pd.DataFrame(file_info, columns=['SCIENTIFIC NAME', 'ROW RANGE', 'RESOLUTION'])
    files['FILENAME'] = batch_files
    for (species, resolution), species_df in files.groupby(["SCIENTIFIC NAME", 'RESOLUTION']):
        species = species.replace(" ", "-")
        resolution = resolution[:-4]
        all_dataframes = [pd.read_csv(f, index_col=0) for f in species_df.FILENAME] 
        sp_cell_df = reduce(lambda a, b: a.add(b, fill_value=0), all_dataframes)
        filename = sp_cell_df_directory.joinpath(f'{species}_{resolution}.csv')
        sp_cell_df.to_csv(filename)



parse_batch_files(ssp_batch_directory)

# Create maps

In [67]:
is_color_too_similar(150, 90)

False

In [90]:
from shapely.geometry import Polygon, MultiPolygon, mapping
from shapely.ops import split
import geopandas as gpd
import numpy as np

import networkx as nx
import numpy as np
import colorsys
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from sklearn.metrics.pairwise import pairwise_distances
import networkx as nx
from matplotlib.colors import hsv_to_rgb
import matplotlib.pyplot as plt

import colorsys
import networkx as nx
from sklearn.metrics.pairwise import cosine_similarity


def get_infraspecies_relationships(sp, spp_dict=spp_dict):
    data = spp_dict[sp]['infraspecies']
    
    # Get a list of each type of infraspecies
    if "issf" in data.keys():
        issfs = [k.replace(sp+' ', '') for k in data["issf"].keys()] # Recognized ssp or ssp groups
    else:
        issfs = []
    if "form" in data.keys():
        forms = [k.replace(sp+' ', '') for k in data["form"].keys()] # Forms
    else:
        forms = []
    if "intergrade" in data.keys():
        intergrades = [k.replace(sp+' ', '') for k in data["intergrade"].keys()] # Intergrades (between ssp? forms?)
    else:
        intergrades = []

    intergrade_to_parents = dict()
    forms_to_parents = dict()
    top_level_intergrades = []
    top_level_forms = []

    # Find parents of the intergrades, if any are in the eBird taxonomy
    # Also determine which intergrades, if any, have no parents
    for intergrade in intergrades:
        parents = [i.strip() for i in intergrade.split('x')]
        # Check if all are true
        if all([p in issfs+forms for p in parents]):
            intergrade_to_parents[intergrade] = parents
        else:
            top_level_intergrades.append(intergrade)

    # Find parents of the forms, if any are in the eBird taxonomy
    # Also determine which forms, if any, have no parents
    for form in forms:
        # Split the form into its individual components
        form_parts = set(form.split("/"))
        
        parent_issfs = []
        for component in issfs:
            # Split the list component into subparts and check if all are in the form_parts
            component_parts = set(component.split("/"))
            if component_parts <= form_parts:  # Check if component_parts is a subset of form_parts
                parent_issfs.append(component)
        if len(parent_issfs):
            forms_to_parents[form] = parent_issfs
        else:
            top_level_forms.append(form)
        
    return issfs, forms, intergrades, intergrade_to_parents, forms_to_parents, top_level_intergrades, top_level_forms


def rgb_to_hex(rgb):
    """Convert an (R, G, B) tuple to a hex color (#RRGGBB)."""
    return "#{:02x}{:02x}{:02x}".format(*rgb)


def hex_to_rgb(hex_color):
    """Convert hex color (#RRGGBB) to an (R, G, B) tuple."""
    hex_color = hex_color.lstrip('#')
    return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))

def combine_rgb_colors(rgb_colors, fracs):
    """Combine a list of RGB colors proportionally."""
    if sum(fracs) == 0:
        return "#999999"
    else:
        combined_rgb = tuple(
            int(sum(frac * color[channel] for color, frac in zip(rgb_colors, fracs)))
            for channel in range(3)
        )
    return rgb_to_hex(combined_rgb)

def name_to_base_hue(name):
    """Generate a base hue from a name."""
    base_hue = hash(name) % 360
    return base_hue

def average_hues(hues):
    """Average a list of hues on the circular scale."""
    x = np.mean([np.cos(np.radians(h)) for h in hues])
    y = np.mean([np.sin(np.radians(h)) for h in hues])
    avg_hue = np.degrees(np.arctan2(y, x)) % 360
    return avg_hue


from colormath.color_objects import sRGBColor, LabColor
from colormath.color_conversions import convert_color
from colormath import color_diff_matrix
#from colormath.color_diff import delta_e_cie2000 # deprecated and doesn't work anymore, reimplemented below

def delta_e_cie2000(color1, color2, Kl=1, Kc=1, Kh=1):
    """
    Calculates the Delta E (CIE2000) of two colors.
    """
    def _get_lab_color1_vector(color):
        return np.array([color.lab_l, color.lab_a, color.lab_b])
    def _get_lab_color2_matrix(color):
        return np.array([(color.lab_l, color.lab_a, color.lab_b)])

    color1_vector = _get_lab_color1_vector(color1)
    color2_matrix = _get_lab_color2_matrix(color2)
    delta_e = color_diff_matrix.delta_e_cie2000(
        color1_vector, color2_matrix, Kl=Kl, Kc=Kc, Kh=Kh)[0]
    return delta_e

def rgb_to_lab(rgb):
    """Convert an RGB color (0-255) to LAB color space."""
    srgb = sRGBColor(rgb[0] / 255, rgb[1] / 255, rgb[2] / 255, is_upscaled=False)
    return convert_color(srgb, LabColor)

def hsl_to_rgb(hue, saturation, lightness):
    """Convert HSL values to RGB (0-255)."""
    r, g, b = colorsys.hls_to_rgb(hue / 360, lightness, saturation)
    return (int(r * 255), int(g * 255), int(b * 255))


# def is_color_too_similar(hue1, hue2, threshold=40):
#     """Check if two hues are too similar by checking their difference in hue space."""
#     return abs(hue1 - hue2) < threshold or abs(hue1 - hue2) > (360 - threshold)


def is_color_too_similar(hue1, hue2, threshold=15):
    """
    Check if two hues are too similar, accounting for perceptual non-uniformity.
    Compare using the CIEDE2000 formula in the LAB color space.

    Higher threshold ==> colors need to be more different
    """

    # Convert hues to RGB colors using fixed saturation and lightness for comparison
    rgb1 = hsl_to_rgb(hue1, 0.8, 0.5)  # Vivid, medium lightness
    rgb2 = hsl_to_rgb(hue2, 0.8, 0.5)

    # Convert RGB to LAB for perceptual uniformity
    lab1 = rgb_to_lab(rgb1)
    lab2 = rgb_to_lab(rgb2)

    # Calculate perceptual difference using CIEDE2000
    delta_e = delta_e_cie2000(lab1, lab2)

    # if delta_e < threshold:
    #     print("Too similar:", hue1, hue2)
    # else:
    #     print("Not too similar:", hue1, hue2)
    return delta_e < threshold


def create_distribution_adjacency_matrix(data, subspecies_cols, cell_col='cell_id'):
    """
    Create an adjacency matrix based on subspecies distribution similarities.

    Parameters:
    - data: DataFrame with cells as rows and subspecies counts as columns.
    - subspecies_cols: List of column names corresponding to subspecies counts.
    - cell_col: Column name for cell identifiers (optional, for reference).

    Returns:
    - adjacency_matrix: A NumPy array where element [i, j] is the similarity between subspecies distributions.
    - subspecies_list: The order of subspecies corresponding to matrix rows/columns.
    """
    # Subset the subspecies columns
    subspecies_data = data[subspecies_cols]

    # Normalize each cell's counts to proportions
    subspecies_distribution = subspecies_data.div(subspecies_data.sum(axis=1), axis=0).fillna(0)

    # Compute cosine similarity between each pair of subspecies
    adjacency_matrix = cosine_similarity(subspecies_distribution.T)

    # Return the matrix and list of subspecies
    return adjacency_matrix, subspecies_cols


def assign_hues(subspecies, overlap_matrix):
    """Assign colors to subspecies based on overlap relationships."""
    # Step 1: Generate base hues
    #base_hues = {subsp: name_to_base_hue(subsp) for subsp in subspecies}

    # # Step 2: Adjust hues based on overlap using graph coloring
    # G = nx.Graph()
    # for i, sp1 in enumerate(subspecies):
    #     for j, sp2 in enumerate(subspecies):
    #         if overlap_matrix[i][j] > 0.1:  # Threshold for "overlap"
    #             G.add_edge(sp1, sp2)

    # coloring = nx.coloring.greedy_color(G, strategy="largest_first")
    # color_mapping = {}
    # used_hues = []

    # for subsp, color_idx in coloring.items():
    #     # Calculate the hue with an initial base hue + an offset determined by color_idx
    #     hue = (base_hues[subsp] + color_idx * 60) % 360  # Spread hues by 60° to maximize contrast
        
    #     # Ensure the hue is distinct from previously used hues
    #     while any(is_color_too_similar(hue, used_hue) for used_hue in used_hues):
    #         hue = (hue + 30) % 360  # Adjust the hue if it's too similar to previously used hues
        
    #     color_mapping[subsp] = hue
    #     # Add the hue to the list of used hues
    #     used_hues.append(hue)
        
    #     # # Convert the hue to RGB
    #     # saturation, lightness = 0.8, 0.5  # Vivid, medium colors
    #     # r, g, b = colorsys.hls_to_rgb(hue / 360, lightness, saturation)
    #     # color_mapping[subsp] = rgb_to_hex((int(r * 255), int(g * 255), int(b * 255)))
    
    # return color_mapping
    return {subsp: idx*(360/(len(subspecies) + 1)) for idx, subsp in enumerate(subspecies)}

def hue_to_hex_vibrant(hue):
    saturation, lightness = 0.8, 0.5  # Vivid, medium colors
    r, g, b = colorsys.hls_to_rgb(hue / 360, lightness, saturation)
    return rgb_to_hex((int(r * 255), int(g * 255), int(b * 255)))


def style_function(feature, subspp_colors):
    """Style a cell based on the proportion of subspecies."""
    properties = feature['properties']
    subspecies_values = {subsp: properties.get(subsp, 0) for subsp in subspp_colors.keys()}
    
    # Normalize the values to sum up to 1 for proportional allocation
    total = sum(subspecies_values.values())
    if total > 0:
        fracs = [value / total for value in subspecies_values.values()]
    else:
        fracs = [0 for _ in subspecies_values]
    
    # Get RGB colors for each subspecies
    hex_colors = [subspp_colors[subsp] for subsp in subspecies_values]
    rgb_colors = [hex_to_rgb(color) for color in hex_colors]
    
    # Combine colors based on the proportional fractions
    cell_color = combine_rgb_colors(rgb_colors, fracs)
    
    return {
        'fillColor': cell_color,  # Cell color
        'color': cell_color,  # Border color
        'weight': 1,  # Border weight
        'fillOpacity': 0.6,  # Cell fill transparency
    }

def calculate_overlap_intensity(overlap_matrix):
    """
    Calculate the overlap intensity for each subspecies.
    Overlap intensity is defined as the proportion of cells that overlap with others.

    Parameters:
    - overlap_matrix: A square matrix where overlap_matrix[i][j] represents the overlap between subspecies[i] and subspecies[j].

    Returns:
    - A list of overlap intensities for each subspecies.
    """
    total_overlap = np.sum(overlap_matrix, axis=1)  # Total overlap for each subspecies
    max_overlap = np.sum(overlap_matrix, axis=1) - np.diagonal(overlap_matrix)  # Exclude self-overlap
    return total_overlap / max_overlap


def generate_priority_hues(subspecies, overlap_matrix):
    """
    Assign perceptually distinct hues to subspecies, prioritizing those with higher overlap intensity.

    Parameters:
    - subspecies: List of subspecies names.
    - overlap_matrix: A square matrix where overlap_matrix[i][j] represents the overlap between subspecies[i] and subspecies[j].

    Returns:
    - A dictionary mapping each subspecies to a distinct hue (in degrees, 0-360).
    """
    n_subspecies = len(subspecies)

    # Step 1: Generate perceptually distinct hues (0-360 degrees)
    hues = np.linspace(0, 360, n_subspecies, endpoint=False)

    # Step 2: Calculate overlap intensity
    overlap_intensity = calculate_overlap_intensity(overlap_matrix)

    # Step 3: Assign hues based on overlap intensity and relationships
    G = nx.Graph()
    for i, sp1 in enumerate(subspecies):
        for j, sp2 in enumerate(subspecies):
            if overlap_matrix[i][j] > 0.1:  # Add edge if overlap is above threshold
                G.add_edge(sp1, sp2, weight=overlap_matrix[i][j])

    # Sort subspecies by overlap intensity
    subspecies_sorted = sorted(zip(subspecies, overlap_intensity), key=lambda x: x[1], reverse=True)

    # Assign hues sequentially to prioritize highly overlapping subspecies
    assigned_hues = {}
    used_hues = set()
    for subsp, _ in subspecies_sorted:
        # Find the most distinct unused hue
        best_hue = None
        max_dist = -1
        for i, hue in enumerate(hues):
            if i in used_hues:
                continue
            # Check perceptual distance to already assigned hues
            if assigned_hues:
                dist = np.min(
                    [min(abs(hue - assigned_hues[sp]), 360 - abs(hue - assigned_hues[sp])) for sp in assigned_hues]
                )
            else:
                dist = float("inf")
            
            if dist > max_dist:
                max_dist = dist
                best_hue = i

        # Assign the best hue
        assigned_hues[subsp] = hues[best_hue]
        used_hues.add(best_hue)
    
    return assigned_hues

def get_color_mapping(sp_cell_df):
    # Get the relationships between the infraspecies (issfs, forms, intergrades)
    issfs, forms, intergrades, inter_to_p, form_to_p, top_inter, top_form = get_infraspecies_relationships(species, spp_dict)

    # Get colors for the top-level infraspecies based on their geographic overlap
    top_level_infras = [*issfs, *top_inter, *top_form]
    sp_cell_df.columns = sp_cell_df.columns.str.replace(species + ' ', "")
    overlap_matrix, top_level_infras = create_distribution_adjacency_matrix(sp_cell_df, top_level_infras)
    #ssp_hues = assign_hues(top_level_infras, overlap_matrix)
    ssp_hues = generate_priority_hues(top_level_infras, overlap_matrix)

    # Get colors for the remaining species based on overlap with top-level infraspecies
    for intergrade, parents in inter_to_p.items():
        ssp_hues[intergrade] = average_hues([ssp_hues[parent] for parent in parents])
    for form, parents in form_to_p.items():
        ssp_hues[form] = average_hues([ssp_hues[parent] for parent in parents])

    # Convert hues to vibrant colors
    ssp_colors = {ssp: hue_to_hex_vibrant(hue) for ssp, hue in ssp_hues.items()}

    return ssp_colors

def get_bounds(geojson_result):
    """
    Calculate the bounding box of all features in the GeoJSON.

    Args:
    - geojson_result: GeoJSON string with features.

    Returns:
    - Bounds as [[southwest_lat, southwest_lon], [northeast_lat, northeast_lon]].
    """
    import json
    geojson_data = json.loads(geojson_result)
    all_coords = []

    for feature in geojson_data['features']:
        # Extract all coordinates from the polygon or multipolygon
        coords = feature['geometry']['coordinates']
        if feature['geometry']['type'] == "Polygon":
            all_coords.extend(coords[0])  # Add outer ring of the polygon
        elif feature['geometry']['type'] == "MultiPolygon":
            for poly in coords:
                all_coords.extend(poly[0])  # Add outer ring of each polygon

    # Extract longitudes (x) and latitudes (y) correctly
    lons, lats = zip(*all_coords)
    return [[min(lats), min(lons)], [max(lats), max(lons)]]


def choropleth_map(sp_cell_df, common_name, subspp_colors):
    """Creates a choropleth map given species data."""
    
    f = folium.Figure()
    map = folium.Map(location=[47, -122], zoom_start=5, tiles="cartodbpositron")
    f.add_child(map)

    sp = sp_cell_df.columns[0]
    subspp = sp_cell_df.columns[1:]
    
    list_features = []
    for _, row in sp_cell_df.iterrows():
        #percentages = (row[subspp] / row[sp]) # For the previous implementation that colored the map by % of total sightings instead of % of ssp sightings
        percentages = (row[subspp] / sum(row[subspp]))*100
        percentages_dict = percentages.to_dict()
        
        # Precompute tooltip text showing only non-zero percentages
        percentages_dict_ordered = pd.DataFrame(percentages_dict, index=['pct']).T.query('pct > 0')['pct'].sort_values(ascending=False).to_dict()

        tooltip_text = []
        for subsp, percent in percentages_dict_ordered.items():
            tooltip_text.append(f"{subsp}: {percent:.0f}%")
        
        # Add tooltip as a string to the properties
        percentages_dict["tooltip"] = "<br>".join(tooltip_text) if tooltip_text else "No data"

        geometry_for_row = h3.cells_to_geo(cells=[row.name])
        feature = Feature(
            geometry=geometry_for_row,
            id=row.name,
            properties=percentages_dict)
        list_features.append(feature)

    feat_collection = FeatureCollection(list_features)
    geojson_result = json.dumps(feat_collection)
    
    # Add GeoJSON layer to the map
    folium.GeoJson(
        geojson_result,
        style_function=lambda feature: style_function(feature, subspp_colors),
        name=f'{sp} Subspecies Map'
    ).add_to(map)
    
    # Add tooltips
    folium.GeoJson(
        geojson_result,
        style_function=lambda feature: {
            'weight': 0,  # No border weight
            'color': 'transparent',  # No border color
            'fillOpacity': 0.6  # Fill transparency
        },
        tooltip=GeoJsonTooltip(
            #fields=list(subspp),
            #aliases=[subsp[len(sp)+1:] for subsp in subspp], # Removes ssp name
            fields=["tooltip"],
            aliases=["Reported\nSubspecies"],
            localize=True,
            sticky=True,
            labels=True,
            labels_format="{:.2f}%",
            #highlight_function=lambda x: x.update({'text': [f'{k}: {v:.2f}%' for k, v in x['properties'].items() if v > 0]})
        )
    ).add_to(map)


    # Add legend
    legend_html = f"""
    <div style="position: fixed; top: 10px; right: 10px; width: 150px; height: auto; z-index: 9999; background-color: white; box-shadow: 0 0 5px rgba(0, 0, 0, 0.2); border: 1px solid lightgray; border-radius: 5px; padding: 10px; font-size: 10px;">
        <strong>{common_name} infraspecies</strong><br>
    """
    for subsp, color in subspp_colors.items():
        legend_html += f"""
        <div style="margin-top: 10px;">
            <span style="display: inline-block; width: 20px; height: 10px; margin-right: 10px; background-color: {color};"></span>
            {subsp}
        </div>
        """
    legend_html += "</div>"
    legend_element = folium.Element(legend_html)
    map.get_root().html.add_child(legend_element)

    # Calculate bounds and adjust the map's view
    bounds = get_bounds(geojson_result)
    map.fit_bounds(bounds)

    return map

# Example usage:
#for sp_code in sp_codes:
remake_maps = True
for sp_code in sp_codes:
    print("Mapping", sp_code)
    common_name = taxonomy[taxonomy['SPECIES_CODE'] == sp_code].PRIMARY_COM_NAME.values[0]
    for resolution in [2,3,4,5]:
        species = taxonomy[taxonomy['PRIMARY_COM_NAME'] == common_name].SCI_NAME.values[0]
        dataname = f"sp_cell_dfs/{species.replace(' ', '-')}_resolution{resolution}.csv"
        if not Path(dataname).exists():
            print("No data for", species, "at resolution", resolution)
            continue
        map_filename = f"docs/maps/{species.replace(' ', '-')}_{resolution}.html"
        if Path(map_filename).exists() and not remake_maps:
            print("Map already exists for", species, "at resolution", resolution)
            continue
        sp_cell_df = pd.read_csv(dataname, index_col=0)
        sp_cell_df.columns = sp_cell_df.columns.str.replace(species + ' ', "")
        subspecies = sp_cell_df.columns[1:]
        if resolution == 2:
            subspp_colors = get_color_mapping(sp_cell_df)
            
        m = choropleth_map(sp_cell_df, common_name, subspp_colors)
        m.save(f"docs/maps/{species.replace(' ', '-')}_{resolution}.html")


Mapping strher


  return total_overlap / max_overlap


KeyboardInterrupt: 

## Create a CSV of map URLs for the website

In [529]:
df = pd.DataFrame(columns=["common_name", "scientific_name", "resolution", "map_url"])
maps_dir = Path("docs/maps")
for idx, file in enumerate(maps_dir.glob("*.html")):
    resolution = file.stem.split("_")[-1]
    species = file.stem.replace(f"_{resolution}", "")
    common_name = taxonomy[taxonomy['SCI_NAME'] == species.replace('-', ' ')].PRIMARY_COM_NAME.values[0]
    map_url = Path(Path(file).parent.stem).joinpath(Path(file).name)
    print(map_url)
    df.loc[idx] = [common_name, species, resolution, map_url]
df.to_csv("docs/data/map_data.csv", index=False)

maps/Butorides-striata_4.html
maps/Butorides-striata_5.html
maps/Butorides-striata_2.html
maps/Butorides-striata_3.html
maps/Branta-bernicla_3.html
maps/Setophaga-coronata_3.html
maps/Buteo-jamaicensis_4.html
maps/Zonotrichia-leucophrys_3.html
maps/Sturnella-magna_5.html
maps/Loxia-curvirostra_3.html
maps/Garrulus-glandarius_4.html
maps/Garrulus-glandarius_5.html
maps/Loxia-curvirostra_2.html
maps/Sturnella-magna_4.html
maps/Zonotrichia-leucophrys_2.html
maps/Buteo-jamaicensis_5.html
maps/Setophaga-coronata_2.html
maps/Branta-bernicla_2.html
maps/Sturnella-magna_3.html
maps/Loxia-curvirostra_5.html
maps/Garrulus-glandarius_2.html
maps/Branta-bernicla_5.html
maps/Setophaga-coronata_5.html
maps/Buteo-jamaicensis_2.html
maps/Branta-hutchinsii_2.html
maps/Zonotrichia-leucophrys_5.html
maps/Zonotrichia-leucophrys_4.html
maps/Buteo-jamaicensis_3.html
maps/Setophaga-coronata_4.html
maps/Branta-bernicla_4.html
maps/Garrulus-glandarius_3.html
maps/Loxia-curvirostra_4.html
maps/Sturnella-magna_2