In [None]:
import sample_network_unmix as snu
import export_rate_optimize as ero
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from multiprocessing import Pool, cpu_count
from matplotlib.colors import LogNorm
import time

from typing import Dict, List, Tuple, Any, Callable


def explore_regularisation_strengths(
    export_strength_tryer: Callable[[float], Dict[str, Any]],
    start: float = -5,
    stop: float = 2,
    number: int = 8,
) -> pd.DataFrame:
    """Sweep regulariser strengths in parallel and plot data misfit against model size.

    ``number`` strengths are taken from ``np.logspace(start, stop, number)`` (i.e.
    log-uniformly between ``10**start`` and ``10**stop``) and each one is passed to
    ``export_strength_tryer`` in a worker process. The results are drawn as a
    scatter of data misfit vs. model size, coloured by strength (log colour scale)
    and annotated with log10 of the strength, and the figure is shown.

    This requires defining a tryer function for each network, which is undesirable
    but acceptable for now.

    Args:
        export_strength_tryer: maps a regulariser strength to a dict with keys
            "regulariser", "data_misfit", "model_size" and "export_rates", where
            ``export_rates.composition`` is a mapping of column name -> value.
            It is sent to a ``multiprocessing.Pool``, so it must be picklable
            (a module-level ``def``, not a lambda or closure).
        start: log10 of the smallest strength to try.
        stop: log10 of the largest strength to try.
        number: how many strengths to try; must be at least 1.

    Returns:
        DataFrame with one row per strength, in increasing strength order, and
        columns "regulariser", "data_misfit", "model_size" followed by one
        column per key of each ``export_rates.composition``.

    Raises:
        ValueError: if ``number`` is less than 1.
    """
    if number < 1:
        raise ValueError(f"number must be at least 1, got {number}")
    export_reg_strengths = list(np.logspace(start, stop, number))
    # Never start more worker processes than there are strengths to try.
    n_workers = min(cpu_count(), len(export_reg_strengths))
    with Pool(n_workers) as pool:
        # Pool.map returns results in the same order as its input.
        outputs = pool.map(export_strength_tryer, export_reg_strengths)
    # One row per strength: scalar diagnostics followed by the export-rate composition.
    results_df = pd.DataFrame(
        [
            {
                "regulariser": o["regulariser"],
                "data_misfit": o["data_misfit"],
                "model_size": o["model_size"],
                **o["export_rates"].composition,
            }
            for o in outputs
        ]
    )
    # Start a fresh figure so the trade-off plot never lands on a leftover axes.
    plt.figure()
    points = plt.scatter(
        results_df["data_misfit"],
        results_df["model_size"],
        c=results_df["regulariser"],
        norm=LogNorm(),
    )
    plt.xlabel("Data misfit")
    plt.ylabel("Model size")
    # Label every point with log10 of its regulariser strength.
    for log_strength, misfit, size in zip(
        np.log10(results_df["regulariser"]),
        results_df["data_misfit"],
        results_df["model_size"],
    ):
        plt.annotate(f"{log_strength:.2f}", (misfit, size))
    cb = plt.colorbar(points)
    cb.set_label("Regulariser strength")
    plt.show()
    return results_df


def plot_export_rates_against_regulariser(results_df: pd.DataFrame, nodes: List[str], title: str):
    """Draw one export-rate curve per node as a function of regulariser strength.

    A new 10x6 figure is opened with a log-scaled x axis and a logit-scaled y axis
    (the logit scale only represents values strictly between 0 and 1), and a
    legend is added. ``plt.show()`` is intentionally not called, so the caller
    can overlay further artists (e.g. the chosen strength) before showing it.

    Args:
        results_df: table with a "regulariser" column and one column per node.
        nodes: names of the ``results_df`` columns to plot.
        title: figure title.
    """
    strengths = results_df["regulariser"]
    plt.figure(figsize=(10, 6))
    plt.title(title)
    plt.xlabel("Regulariser strength")
    plt.ylabel("Export rate")
    plt.xscale("log")
    plt.yscale("logit")
    for node_name in nodes:
        plt.plot(strengths, results_df[node_name], label=node_name)
    plt.legend()
    

### Set up the overall sample network and load in the data

In [None]:
# Build the full sample network from the flow-direction file and the sample data file
sample_network, _ = snu.get_sample_graphs(
    sample_data_filename="data/sample_data.dat",
    flowdirs_filename="data/d8.asc",
)

# Observations: space-delimited table, with the Bi and S columns discarded
obs_data = pd.read_csv("data/sample_data.dat", delimiter=" ").drop(columns=["Bi", "S"])
print("Building problem...")
# NOTE(review): assumes the element columns are all but the first 3 and last 5
# columns of the table — confirm against the data file layout.
elements = obs_data.columns[3:-5]
source_regulariser = 10.0 ** -1

### Split the network into individual networks for each river

In [None]:
# Split the sample network into one sub-network per river: each weakly connected
# component is a separate drainage network, identified by its outlet sample.
ROOT_TO_RIVER = {
    # TODO confirm CG039 is the Spey outlet; later cells look up river_networks["Spey"]
    # and no other outlet is assigned to it.
    "CG039": "Spey",
    "CG088": "Dee",
    "CG090": "Deveron",
    "CG013": "Tay",
    "CG021": "Don",
}
components = [
    sample_network.subgraph(c).copy() for c in nx.weakly_connected_components(sample_network)
]
river_networks = {}
for c in components:
    # The last node in a topological order has no outgoing edges; this assumes each
    # river component has a single such outlet.
    root = list(nx.topological_sort(c))[-1]
    # An unknown outlet keeps its own sample code as its name, so it can never reuse
    # a stale name and overwrite another river's network.
    name = ROOT_TO_RIVER.get(root, root)
    river_networks[name] = c
    plt.figure(figsize=(5, 3))  # Visualise network
    plt.title(name)
    snu.plot_network(c)
    plt.show()


def network_to_problem_multielement(
    river_network: nx.DiGraph,
) -> Tuple[snu.SampleNetworkUnmixer, ero.MultiElementData]:
    """Build the unmixing problem and its observations for a single river network.

    Args:
        river_network: the directed sample network of one river (e.g. one of the
            values of ``river_networks``).

    Returns:
        ``(problem, multielement)``: a ``snu.SampleNetworkUnmixer`` over
        ``river_network``, and the ``ero.MultiElementData`` returned by
        ``ero.get_multielementdata`` for the rows of the module-level
        ``obs_data`` whose "Sample.Code" is a node of ``river_network``,
        together with the module-level ``elements``.
    """
    problem = snu.SampleNetworkUnmixer(sample_network=river_network)
    # Keep only the observations taken at samples that belong to this river.
    samples = list(river_network.nodes)
    obs = obs_data[obs_data["Sample.Code"].isin(samples)]
    multielement = ero.get_multielementdata(obs, elements)
    return problem, multielement

## Deveron

Now we set up the problem for the River Deveron, which is small, so we can use it to assess the run-time.

In [None]:
# Build the Deveron unmixing problem, its observations, and an optimiser around them
deveron_problem, deveron_multielement = network_to_problem_multielement(
    river_network=river_networks["Deveron"]
)
deveron_nodes = list(river_networks["Deveron"])
deveron_optimiser = ero.ExportRateOptimizer(
    source_optimiser=deveron_problem,
    observations=deveron_multielement,
    source_regulariser_strength=source_regulariser,
    export_regulariser_strength=0.0,  # starts at zero; the sweep sets regulariser_strength
)


def try_reg_strength_deveron(
    reg_strength: float,
):
    """Fit the Deveron model at one regulariser strength and report the outcome.

    Meant to be handed to ``explore_regularisation_strengths``. It sets
    ``regulariser_strength`` on the module-level ``deveron_optimiser``, runs
    ``optimise()``, and returns a dict with the keys "regulariser",
    "export_rates", "data_misfit" and "model_size".

    NOTE(review): the optimiser is constructed with ``export_regulariser_strength``
    but this sets ``regulariser_strength``; confirm ``ero.ExportRateOptimizer``
    reads that attribute, otherwise every strength yields the same fit.
    """
    optimiser = deveron_optimiser
    optimiser.regulariser_strength = reg_strength
    optimiser.optimise()
    return dict(
        regulariser=reg_strength,
        export_rates=optimiser.export_rates,
        data_misfit=optimiser.data_misfit,
        model_size=optimiser.model_size,
    )


print("Exploring regulariser strengths for Deveron...")
# perf_counter is monotonic, so elapsed times are unaffected by system clock changes
start_time = time.perf_counter()

number_to_try = 33
deveron_results = explore_regularisation_strengths(try_reg_strength_deveron, number=number_to_try)
end_time = time.perf_counter()
print("#" * 40)
print(f"Analysed results for {number_to_try} regulariser strengths")
print(f"Time taken: {(end_time - start_time) / 60:.2f} minutes")
print("#" * 40)

# Pick the strength (as a log10 value) from the trade-off plot above. Re-prompt on
# anything that is not a finite number rather than failing the cell, which would
# force the expensive sweep above to be re-run.
while True:
    chosen_strength = input("Enter chosen regulariser strength for Deveron (log10 value): ")
    try:
        chosen_log_strength = float(chosen_strength)
    except ValueError:
        chosen_log_strength = float("nan")
    if np.isfinite(chosen_log_strength):
        break
    print(f"{chosen_strength!r} is not a finite number; please try again.")

deveron_optimiser.regulariser_strength = 10 ** chosen_log_strength
print(
    f"Solving Deveron network with a regulariser strength of {deveron_optimiser.regulariser_strength}..."
)
start_time = time.perf_counter()
deveron_optimiser.optimise()
end_time = time.perf_counter()
deveron_export_rates = deveron_optimiser.export_rates.composition
print(f"Time taken: {(end_time - start_time) / 60:.2f} minutes")

# Export-rate paths from the sweep, with the chosen strength marked
plot_export_rates_against_regulariser(deveron_results, deveron_nodes, "Deveron")
plt.vlines(
    x=deveron_optimiser.regulariser_strength, ymin=1e-2, ymax=1 - 1e-2, color="grey", linestyles="dashed"
)
plt.show()

## Spey 

In [None]:
# Build the Spey unmixing problem, its observations, and an optimiser around them
spey_problem, spey_multielement = network_to_problem_multielement(
    river_network=river_networks["Spey"]
)
spey_nodes = list(river_networks["Spey"])
spey_optimiser = ero.ExportRateOptimizer(
    source_optimiser=spey_problem,
    observations=spey_multielement,
    source_regulariser_strength=source_regulariser,
    export_regulariser_strength=0.0,  # starts at zero; the sweep sets regulariser_strength
)

def try_reg_strength_spey(
    reg_strength: float,
):
    """Fit the Spey model at one regulariser strength and report the outcome.

    Meant to be handed to ``explore_regularisation_strengths``. It sets
    ``regulariser_strength`` on the module-level ``spey_optimiser``, runs
    ``optimise()``, and returns a dict with the keys "regulariser",
    "export_rates", "data_misfit" and "model_size".

    NOTE(review): the optimiser is constructed with ``export_regulariser_strength``
    but this sets ``regulariser_strength``; confirm ``ero.ExportRateOptimizer``
    reads that attribute, otherwise every strength yields the same fit.
    """
    optimiser = spey_optimiser
    optimiser.regulariser_strength = reg_strength
    optimiser.optimise()
    return dict(
        regulariser=reg_strength,
        export_rates=optimiser.export_rates,
        data_misfit=optimiser.data_misfit,
        model_size=optimiser.model_size,
    )


print("Exploring regulariser strengths for Spey...")
# perf_counter is monotonic, so elapsed times are unaffected by system clock changes
start_time = time.perf_counter()

number_to_try = 33
spey_results = explore_regularisation_strengths(try_reg_strength_spey, number=number_to_try)
end_time = time.perf_counter()
print("#" * 40)
print(f"Analysed results for {number_to_try} regulariser strengths")
print(f"Time taken: {(end_time - start_time) / 60:.2f} minutes")
print("#" * 40)

# Pick the strength (as a log10 value) from the trade-off plot above. Re-prompt on
# anything that is not a finite number rather than failing the cell, which would
# force the expensive sweep above to be re-run.
while True:
    chosen_strength = input("Enter chosen regulariser strength for Spey (log10 value): ")
    try:
        chosen_log_strength = float(chosen_strength)
    except ValueError:
        chosen_log_strength = float("nan")
    if np.isfinite(chosen_log_strength):
        break
    print(f"{chosen_strength!r} is not a finite number; please try again.")

spey_optimiser.regulariser_strength = 10 ** chosen_log_strength
print(
    f"Solving Spey network with a regulariser strength of {spey_optimiser.regulariser_strength}..."
)
start_time = time.perf_counter()
spey_optimiser.optimise()
end_time = time.perf_counter()
print(f"Time taken: {(end_time - start_time) / 60:.2f} minutes")
spey_export_rates = spey_optimiser.export_rates.composition
# The Spey export-rate paths are plotted in the next cell.

In [None]:
# Export-rate paths for the Spey, with the chosen strength marked as a dashed line
plot_export_rates_against_regulariser(results_df=spey_results, nodes=spey_nodes, title="Spey")
plt.vlines(spey_optimiser.regulariser_strength, 0.001, 0.5, color="grey", linestyles="dashed")
plt.show()

## Don

In [None]:
# Build the Don unmixing problem, its observations, and an optimiser around them
don_problem, don_multielement = network_to_problem_multielement(
    river_network=river_networks["Don"]
)
don_nodes = list(river_networks["Don"])
don_optimiser = ero.ExportRateOptimizer(
    source_optimiser=don_problem,
    observations=don_multielement,
    source_regulariser_strength=source_regulariser,
    export_regulariser_strength=0.0,  # starts at zero; the sweep sets regulariser_strength
)

def try_reg_strength_don(
    reg_strength: float,
):
    """Fit the Don model at one regulariser strength and report the outcome.

    Meant to be handed to ``explore_regularisation_strengths``. It sets
    ``regulariser_strength`` on the module-level ``don_optimiser``, runs
    ``optimise()``, and returns a dict with the keys "regulariser",
    "export_rates", "data_misfit" and "model_size".

    NOTE(review): the optimiser is constructed with ``export_regulariser_strength``
    but this sets ``regulariser_strength``; confirm ``ero.ExportRateOptimizer``
    reads that attribute, otherwise every strength yields the same fit.
    """
    optimiser = don_optimiser
    optimiser.regulariser_strength = reg_strength
    optimiser.optimise()
    return dict(
        regulariser=reg_strength,
        export_rates=optimiser.export_rates,
        data_misfit=optimiser.data_misfit,
        model_size=optimiser.model_size,
    )


print("Exploring regulariser strengths for Don...")
# perf_counter is monotonic, so elapsed times are unaffected by system clock changes
start_time = time.perf_counter()

number_to_try = 17
don_results = explore_regularisation_strengths(try_reg_strength_don, number=number_to_try)
end_time = time.perf_counter()
print("#" * 40)
print(f"Analysed results for {number_to_try} regulariser strengths")
print(f"Time taken: {(end_time - start_time) / 60:.2f} minutes")
print("#" * 40)

# Pick the strength (as a log10 value) from the trade-off plot above. Re-prompt on
# anything that is not a finite number rather than failing the cell, which would
# force the expensive sweep above to be re-run.
while True:
    chosen_strength = input("Enter chosen regulariser strength for Don (log10 value): ")
    try:
        chosen_log_strength = float(chosen_strength)
    except ValueError:
        chosen_log_strength = float("nan")
    if np.isfinite(chosen_log_strength):
        break
    print(f"{chosen_strength!r} is not a finite number; please try again.")

don_optimiser.regulariser_strength = 10 ** chosen_log_strength
print(
    f"Solving Don network with a regulariser strength of {don_optimiser.regulariser_strength}..."
)
start_time = time.perf_counter()
don_optimiser.optimise()
end_time = time.perf_counter()
print(f"Time taken: {(end_time - start_time) / 60:.2f} minutes")
don_export_rates = don_optimiser.export_rates.composition

In [None]:
# Export-rate paths for the Don, with the chosen strength marked as a dashed line
plot_export_rates_against_regulariser(results_df=don_results, nodes=don_nodes, title="Don")
plt.vlines(don_optimiser.regulariser_strength, 0.001, 1 - 0.001, color="grey", linestyles="dashed")
plt.show()

## Tay

In [None]:
# Build the Tay unmixing problem, its observations, and an optimiser around them
tay_problem, tay_multielement = network_to_problem_multielement(
    river_network=river_networks["Tay"]
)
tay_nodes = list(river_networks["Tay"])
tay_optimiser = ero.ExportRateOptimizer(
    source_optimiser=tay_problem,
    observations=tay_multielement,
    source_regulariser_strength=source_regulariser,
    export_regulariser_strength=0.0,  # starts at zero; the sweep sets regulariser_strength
)

def try_reg_strength_tay(
    reg_strength: float,
):
    """Fit the Tay model at one regulariser strength and report the outcome.

    Meant to be handed to ``explore_regularisation_strengths``. It sets
    ``regulariser_strength`` on the module-level ``tay_optimiser``, runs
    ``optimise()``, and returns a dict with the keys "regulariser",
    "export_rates", "data_misfit" and "model_size".

    NOTE(review): the optimiser is constructed with ``export_regulariser_strength``
    but this sets ``regulariser_strength``; confirm ``ero.ExportRateOptimizer``
    reads that attribute, otherwise every strength yields the same fit.
    """
    optimiser = tay_optimiser
    optimiser.regulariser_strength = reg_strength
    optimiser.optimise()
    return dict(
        regulariser=reg_strength,
        export_rates=optimiser.export_rates,
        data_misfit=optimiser.data_misfit,
        model_size=optimiser.model_size,
    )


print("Exploring regulariser strengths for Tay...")
# perf_counter is monotonic, so elapsed times are unaffected by system clock changes
start_time = time.perf_counter()

number_to_try = 8
tay_results = explore_regularisation_strengths(try_reg_strength_tay, number=number_to_try)
end_time = time.perf_counter()
print("#" * 40)
print(f"Analysed results for {number_to_try} regulariser strengths")
print(f"Time taken: {(end_time - start_time) / 60:.2f} minutes")
print("#" * 40)

# Pick the strength (as a log10 value) from the trade-off plot above. Re-prompt on
# anything that is not a finite number rather than failing the cell, which would
# force the expensive sweep above to be re-run.
while True:
    chosen_strength = input("Enter chosen regulariser strength for Tay (log10 value): ")
    try:
        chosen_log_strength = float(chosen_strength)
    except ValueError:
        chosen_log_strength = float("nan")
    if np.isfinite(chosen_log_strength):
        break
    print(f"{chosen_strength!r} is not a finite number; please try again.")

tay_optimiser.regulariser_strength = 10 ** chosen_log_strength
print(
    f"Solving Tay network with a regulariser strength of {tay_optimiser.regulariser_strength}..."
)
start_time = time.perf_counter()
tay_optimiser.optimise()
end_time = time.perf_counter()
print(f"Time taken: {(end_time - start_time) / 60:.2f} minutes")
tay_export_rates = tay_optimiser.export_rates.composition

In [None]:
# Export-rate paths for the Tay, with the chosen strength marked as a dashed line
plot_export_rates_against_regulariser(results_df=tay_results, nodes=tay_nodes, title="Tay")
plt.vlines(tay_optimiser.regulariser_strength, 1e-2, 1 - 1e-2, color="grey", linestyles="dashed")
plt.show()

## Dee

In [None]:
# Build the Dee unmixing problem, its observations, and an optimiser around them
dee_problem, dee_multielement = network_to_problem_multielement(
    river_network=river_networks["Dee"]
)
dee_nodes = list(river_networks["Dee"])
dee_optimiser = ero.ExportRateOptimizer(
    source_optimiser=dee_problem,
    observations=dee_multielement,
    source_regulariser_strength=source_regulariser,
    export_regulariser_strength=0.0,  # starts at zero; the sweep sets regulariser_strength
)

def try_reg_strength_dee(
    reg_strength: float,
):
    """Fit the Dee model at one regulariser strength and report the outcome.

    Meant to be handed to ``explore_regularisation_strengths``. It sets
    ``regulariser_strength`` on the module-level ``dee_optimiser``, runs
    ``optimise()``, and returns a dict with the keys "regulariser",
    "export_rates", "data_misfit" and "model_size".

    NOTE(review): the optimiser is constructed with ``export_regulariser_strength``
    but this sets ``regulariser_strength``; confirm ``ero.ExportRateOptimizer``
    reads that attribute, otherwise every strength yields the same fit.
    """
    optimiser = dee_optimiser
    optimiser.regulariser_strength = reg_strength
    optimiser.optimise()
    return dict(
        regulariser=reg_strength,
        export_rates=optimiser.export_rates,
        data_misfit=optimiser.data_misfit,
        model_size=optimiser.model_size,
    )


print("Exploring regulariser strengths for Dee...")
# perf_counter is monotonic, so elapsed times are unaffected by system clock changes
start_time = time.perf_counter()

number_to_try = 8
dee_results = explore_regularisation_strengths(try_reg_strength_dee, number=number_to_try)
end_time = time.perf_counter()
print("#" * 40)
print(f"Analysed results for {number_to_try} regulariser strengths")
print(f"Time taken: {(end_time - start_time) / 60:.2f} minutes")
print("#" * 40)

# Pick the strength (as a log10 value) from the trade-off plot above. Re-prompt on
# anything that is not a finite number rather than failing the cell, which would
# force the expensive sweep above to be re-run.
while True:
    chosen_strength = input("Enter chosen regulariser strength for Dee (log10 value): ")
    try:
        chosen_log_strength = float(chosen_strength)
    except ValueError:
        chosen_log_strength = float("nan")
    if np.isfinite(chosen_log_strength):
        break
    print(f"{chosen_strength!r} is not a finite number; please try again.")

dee_optimiser.regulariser_strength = 10 ** chosen_log_strength
print(
    f"Solving Dee network with a regulariser strength of {dee_optimiser.regulariser_strength}..."
)
start_time = time.perf_counter()
dee_optimiser.optimise()
end_time = time.perf_counter()
print(f"Time taken: {(end_time - start_time) / 60:.2f} minutes")
dee_export_rates = dee_optimiser.export_rates.composition

In [None]:
# Export-rate paths for the Dee, with the chosen strength marked as a dashed line
plot_export_rates_against_regulariser(results_df=dee_results, nodes=dee_nodes, title="Dee")
plt.vlines(dee_optimiser.regulariser_strength, 1e-2, 1 - 1e-2, color="grey", linestyles="dashed")
plt.show()

## Exploring all the export rates together

In [None]:
# Multiply each river's export rates by that river's figure in Mt/yr
# (Milliman & Farnsworth). The rates are recomputed from the fitted optimisers
# rather than from the existing *_export_rates dicts, so re-running this cell does
# not apply the factor a second time.
# TODO: the Spey and Tay are not scaled here, so once all rivers are merged their
# values are not in the same units as the Dee, Don and Deveron.
RIVER_LOADS_MT_PER_YR = {"Dee": 0.03, "Don": 0.03, "Deveron": 0.01}
dee_export_rates = {
    k: v * RIVER_LOADS_MT_PER_YR["Dee"] for k, v in dee_optimiser.export_rates.composition.items()
}
don_export_rates = {
    k: v * RIVER_LOADS_MT_PER_YR["Don"] for k, v in don_optimiser.export_rates.composition.items()
}
deveron_export_rates = {
    k: v * RIVER_LOADS_MT_PER_YR["Deveron"]
    for k, v in deveron_optimiser.export_rates.composition.items()
}

In [None]:
# Combine the per-river export rates into a single node -> rate mapping. Each river
# network is a separate weakly connected component of the sample network, so the
# rivers share no nodes and the merge cannot overwrite any entry.
all_export_rates = (
    spey_export_rates | don_export_rates | deveron_export_rates | tay_export_rates | dee_export_rates
)
area_dict = snu.get_unique_upstream_areas(sample_network)
export_map = snu.get_upstream_concentration_map(areas=area_dict, upstream_preds=all_export_rates)
# Blank out zero cells so they are drawn as missing rather than passed to LogNorm
export_map[export_map == 0] = np.nan
plt.imshow(export_map, norm=LogNorm(), cmap="viridis")
cb = plt.colorbar()
cb.set_label("Export rate")
plt.show()