# Example optimisation with stochastic nonlocal modifications and gradient descent

## Imports

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
import time
tic = time.time()

In [None]:
from ship_routing.core import Route, WayPoint
from ship_routing.data import (
    load_currents,
    load_winds,
    load_waves,
)
from ship_routing.convenience import (
    create_route, stochastic_search, gradient_descent, Logs, LogsRoute
)
from ship_routing.algorithms import (
    crossover_routes_minimal_cost,
    crossover_routes_random,
)

In [None]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from copy import deepcopy
import tqdm
from random import choice
import xarray as xr
import shapely

## Parameters

In [None]:
# Number of candidate routes in the population (the population is resampled
# to this size at the end of every generation).
population_size = 512

# reproducibility: seed passed to np.random.seed below
random_seed = 345

# data sources (opened below with engine="zarr")
current_data_store = "/gxfs_work/geomar/smomw122/2024_ship_routing/ship_routing_data/data/cmems_mod_glo_phy_my_0.083deg_P1D-m_time_2021_lat_+10_+65_lon_-100_+010_uo-vo.zarr/"
wave_data_store = "/gxfs_work/geomar/smomw122/2024_ship_routing/ship_routing_data/data/cmems_mod_glo_wav_my_0.2deg_PT3H-i_time_2021_lat_+10_+65_lon_-100_+010_VHM0-VMDR.zarr/"
wind_data_store = "/gxfs_work/geomar/smomw122/2024_ship_routing/ship_routing_data/data/cmems_obs-wind_glo_phy_my_l4_0.125deg_PT1H_time_2021_lat_+10_+65_lon_-100_+010_eastward_wind-northward_wind.zarr/"

# initial route (all of these are handed to create_route below)
lon_waypoints = [-80.5, -12.0]
lat_waypoints = [30.0, 45.0]
time_start = "2021-08-01T12:00"
time_end = None
speed_knots = 12.0  # give either one of the two times plus a speed, or both times
time_resolution_hours = 12.0

# stochastic search parameters
# The generation loop below runs `stoch_num_generations + 1` times.
stoch_num_generations = 4
# NOTE(review): not referenced in the code below, which hard-codes
# number_of_iterations=1 -- confirm whether it should be wired in.
stoch_number_of_iterations = 1
# NOTE(review): not referenced in the code below; stochastic_search_delayed
# always passes acceptance_rate_target=0 -- confirm whether it should be wired in.
stoch_acceptance_rate_target = 0.01
# Passed as `acceptance_rate_for_increase_cost` in the warm-up of the first generation.
stoch_warmup_acceptance_rate = 0.3

# experiment id (prefix of the intermediate NetCDF file names)
experiment_id = 2

# Dask: if `scheduler_file` is None a client with `dask_n_workers` workers is
# created, otherwise the client connects via the given scheduler file.
scheduler_file =  None
dask_n_workers = 1

In [None]:
# Seed NumPy's global RNG so that random draws made in this process are repeatable.
# NOTE(review): the np.random calls inside the Dask-delayed functions may run in
# separate worker processes, which this seed presumably does not reach -- confirm.
np.random.seed(random_seed)

## Define a route

In [None]:
# Build the initial route from the waypoint, timing and speed parameters above.
route_0 = create_route(
    lon_waypoints=lon_waypoints,
    lat_waypoints=lat_waypoints,
    time_start=time_start,
    time_end=time_end,
    speed_knots=speed_knots,
    time_resolution_hours=time_resolution_hours,
)

# Last expression of the cell: display the route.
route_0

In [None]:
# Report the mean of the per-leg speeds (m/s) of the initial route.
print("speed (m/s)", np.mean([leg.speed_ms for leg in route_0.legs]))

## Load and plot currents, winds, waves

In [None]:
# Open the currents store (zarr engine, automatic chunking).
currents = load_currents(
    data_file=current_data_store,
    engine="zarr",
    chunks="auto",
)
# Add a "speed" variable: Euclidean norm over all data variables of the
# dataset, masked wherever `uo` is missing.
currents["speed"] = ((currents.to_array() ** 2).sum("variable") ** 0.5).where(
    ~currents.uo.isnull()
)

# Last expression of the cell: display the dataset.
currents

In [None]:
# Open the winds store (zarr engine, automatic chunking).
winds = load_winds(
    data_file=wind_data_store,
    engine="zarr",
    chunks="auto",
)
# Add a "speed" variable: Euclidean norm over all data variables of the
# dataset, masked wherever `uw` is missing.
winds["speed"] = ((winds.to_array() ** 2).sum("variable") ** 0.5).where(
    ~winds.uw.isnull()
)

In [None]:
# Open the waves store (zarr engine, automatic chunking).
waves = load_waves(
    data_file=wave_data_store,
    engine="zarr",
    chunks="auto",
)

## Subset for the route

In [None]:
# Restrict each dataset to the time span between the first and the last way
# point of route_0 and compute the result (`%time` reports the duration of each line).
%time currents = currents.sel(time=slice(route_0.way_points[0].time, route_0.way_points[-1].time)).compute()
%time winds = winds.sel(time=slice(route_0.way_points[0].time, route_0.way_points[-1].time)).compute()
%time waves = waves.sel(time=slice(route_0.way_points[0].time, route_0.way_points[-1].time)).compute()

In [None]:
# Three panels: time-mean current speed, time-mean wind speed and the time
# mean of the wave variable `wh` (presumably wave height -- confirm).
fig, ax = plt.subplots(1, 3, figsize=(12, 3))

fig.set_dpi(200)

currents.speed.mean("time").plot(ax=ax[0])
winds.speed.mean("time").plot(ax=ax[1])
waves.wh.mean("time").plot(ax=ax[2])

# Overlay the initial route on every panel.
for _ax in ax.flatten():
    route_0.data_frame.plot.line(
        x="lon", y="lat", marker=".", ax=_ax, color="violet", label="route_0"
    )

fig.tight_layout();

In [None]:
# Write the subsets to NetCDF files prefixed with the experiment id; the same
# file names are read again through `load_data` further down.
# The encoding is dropped before writing (presumably so that settings from the
# zarr sources are not carried into the NetCDF files -- confirm).
currents.compute().drop_encoding().to_netcdf(f"{experiment_id}_currents.nc")
winds.compute().drop_encoding().to_netcdf(f"{experiment_id}_winds.nc")
waves.compute().drop_encoding().to_netcdf(f"{experiment_id}_waves.nc")

## Concurrency

In [None]:
from dask import delayed

In [None]:
from ship_routing.data import HashableDataset

@delayed
def load_data(filename):
    """Read a dataset file fully into memory and wrap it in a HashableDataset.

    Decorated with ``dask.delayed``: calling it only builds a task, the file
    is read when that task is computed.

    Parameters
    ----------
    filename :
        Path of the file handed to ``xr.open_dataset``.

    Returns
    -------
    HashableDataset
        Wrapper around the computed dataset.
    """
    # The context manager closes the file handle as soon as the data has been
    # computed, instead of leaving it open for the lifetime of the process.
    with xr.open_dataset(filename) as ds:
        return HashableDataset(ds.compute())

In [None]:
@delayed
def stochastic_search_delayed(logs_routes, mod_width=None, max_move_meters=None, **kwargs):
    """Run ``stochastic_search`` on one population member as a Dask-delayed task.

    Parameters
    ----------
    logs_routes :
        Population member; its ``.route`` attribute is the starting route.
    mod_width :
        Passed on to ``stochastic_search``. If None, a uniformly random
        fraction of the route length (``route.length_meters``) is drawn.
    max_move_meters :
        Passed on to ``stochastic_search``. If None, 0.75 times ``mod_width``.
    **kwargs :
        Forwarded unchanged to ``stochastic_search``.

    Returns
    -------
    The last element of the second value returned by ``stochastic_search``.
    """
    start_route = logs_routes.route

    # The random draw only happens when no width was given.
    width = (
        np.random.uniform(0, 1) * start_route.length_meters
        if mod_width is None
        else mod_width
    )
    max_move = 0.75 * width if max_move_meters is None else max_move_meters

    # acceptance_rate_target is pinned to 0 here; the rest comes from kwargs.
    _, history = stochastic_search(
        route=start_route,
        mod_width=width,
        max_move_meters=max_move,
        acceptance_rate_target=0,
        **kwargs,
    )
    return history[-1]

In [None]:
@delayed
def select_cheaper(logs_routes_0, logs_routes_1, chance=0.2):
    """Pick one of two population members as a Dask-delayed task.

    The first member is returned when its logged cost is strictly lower than
    that of the second, or otherwise with probability ``chance``. In all
    other cases the second member is returned.
    """
    first_is_cheaper = logs_routes_0.logs.cost < logs_routes_1.logs.cost
    # `or` short-circuits, so the random number is drawn only when the first
    # member is not already cheaper.
    if first_is_cheaper or np.random.uniform(0, 1) < chance:
        return logs_routes_0
    return logs_routes_1

In [None]:
@delayed
def crossover_routes_random_del(r0, r1, currents=None, winds=None, waves=None):
    """Cross two population members at random and cost the resulting route.

    Decorated with ``dask.delayed``, so calling it only builds a task.

    Parameters
    ----------
    r0, r1 :
        Population members; their ``.route`` attributes are crossed with
        ``crossover_routes_random``.
    currents, winds, waves :
        Data sets handed to ``cost_through`` of the crossed route.

    Returns
    -------
    A new ``LogsRoute`` holding the crossed route and its cost (method
    ``"crossover_random"``). If the crossover or the cost evaluation raises,
    one of the two parents is returned instead, each with equal probability.
    """
    try:
        rc = crossover_routes_random(r0.route, r1.route)
        cost = rc.cost_through(
            current_data_set=currents,
            wind_data_set=winds,
            wave_data_set=waves,
        )
        return LogsRoute(route=rc, logs=Logs(cost=cost, method="crossover_random"))
    except Exception:
        # Deliberate best effort: a failed crossover falls back to a parent.
        # `Exception` rather than a bare `except` keeps KeyboardInterrupt and
        # SystemExit propagating, so a run can still be interrupted.
        if np.random.rand() > 0.5:
            return r0
        else:
            return r1

## Create Dask Cluster

In [None]:
from dask.distributed import Client, Scheduler

In [None]:
# Connect to Dask: without a scheduler file create a client with
# `dask_n_workers` single-threaded workers, otherwise attach to the scheduler
# described by `scheduler_file`.
# NOTE(review): ip="0.0.0.0" presumably makes the scheduler listen on all
# network interfaces -- confirm this is intended on shared machines.
if scheduler_file is None:
    client = Client(threads_per_worker=1, n_workers=dask_n_workers, ip="0.0.0.0")
else:
    client = Client(scheduler_file=scheduler_file)
display(client)

## Define delayed loaders for currents, winds and waves

In [None]:
# Delayed tasks that read the NetCDF subsets written above; they are passed
# as arguments to the delayed search and crossover tasks in the generation loop.
_currents = load_data(f"{experiment_id}_currents.nc")
_winds = load_data(f"{experiment_id}_winds.nc")
_waves = load_data(f"{experiment_id}_waves.nc")

## Create population

In [None]:
# Initial population: `population_size` entries that all refer to the same
# initial route, each with a default-constructed Logs object.
population = [LogsRoute(logs=Logs(), route=route_0) for _ in range(population_size)]

In [None]:
# Length of the initial route; used as the modification width in the
# generation loop, where it is divided by 1.5 after every generation.
len_0 = route_0.length_meters

## Run generations of stochastic search

In [None]:
%%time

# History of populations: entry 0 is the initial population, one further
# entry is appended per generation.
generations = [population]
for ngen in tqdm.tqdm(range(stoch_num_generations + 1)):

    # delay population
    # (wrap every concrete member so it can be fed into the delayed tasks below)
    population = [delayed(p) for p in population]

    # cleanup dask cluster
    # (restart the client between generations, with a pause before and after;
    # NOTE(review): presumably a workaround for state piling up on the
    # workers -- confirm)
    time.sleep(2.0)
    client.restart()
    time.sleep(2.0)

    # warmup if first generation
    # (Passes `stoch_warmup_acceptance_rate` as `acceptance_rate_for_increase_cost`,
    # i.e. presumably a fraction of the cost-increasing modifications is accepted
    # if they lead to a valid route -- confirm against stochastic_search.)
    if ngen == 0:
        population = [
            stochastic_search_delayed(
                logs_routes=lr,
                number_of_iterations=1,
                acceptance_rate_for_increase_cost=stoch_warmup_acceptance_rate,
                mod_width=len_0,
                max_move_meters=0.75 * len_0,
                current_data_set=_currents,
                wave_data_set=_waves,
                wind_data_set=_winds,
            )
            for lr in population
        ]

    # stochastic search
    # (one iteration per member with acceptance_rate_for_increase_cost=0.0;
    # width and maximum move follow the current value of len_0)
    population = [
        stochastic_search_delayed(
            logs_routes=lr,
            number_of_iterations=1,
            refinement_factor=0.7,
            acceptance_rate_for_increase_cost=0.0,
            mod_width=len_0,
            max_move_meters=0.75 * len_0,
            current_data_set=_currents,
            wave_data_set=_waves,
            wind_data_set=_winds,
        )
        for lr in population
    ]

    # random crossover: every member is crossed with a partner drawn uniformly
    # at random from the population (the partner can be the member itself)
    population = [
        crossover_routes_random_del(
            population[n0], population[n1],
            currents=_currents,
            winds=_winds,
            waves=_waves,
        )
        for n0, n1 in zip(
            range(len(population)),
            np.random.randint(0, len(population), size=(len(population), )),
        )
    ]

    # compute and filter
    # (persist() is called on all tasks before the results are fetched one by
    # one; then only members at or below the 20 % cost quantile are kept and
    # resampled back to population_size)
    # NOTE(review): if any cost is NaN the quantile is NaN and p20 ends up
    # empty, which makes np.random.choice fail -- confirm costs are always finite.
    population = [pc.compute() for pc in [pp.persist() for pp in population]]
    cost = [p.logs.cost for p in population]
    c20 = np.quantile(q=0.2, a=cost)
    p20 = [p for p in population if p.logs.cost <= c20]
    population = list(np.random.choice(p20, size=(population_size, )))
    
    generations.append(population)

    # shrink the modification width (and with it the maximum move) for the
    # next generation
    len_0 /= 1.5

In [None]:
# Logged cost and route length of every member of the final population.
cost = [member.logs.cost for member in population]
lengths = [member.route.length_meters for member in population]

In [None]:
# Histogram of the population costs, leaving out everything at or above the
# 99 % quantile.
plt.hist(
    pd.Series(cost).pipe(lambda s: s.where(s < s.quantile(0.99)).dropna())
)

In [None]:
# Cost of the initial route, used as the reference in the ratios below.
route_0.cost_through(currents, winds, waves)

In [None]:
# Lowest logged cost in the final population relative to the cost of the
# initial route (a value below 1 means the best member is cheaper).
np.min(cost) / route_0.cost_through(currents, winds, waves)

In [None]:
# Index of the population member with the lowest logged cost.
imin = np.argmin(cost)
imin

In [None]:
# Same ratio as above, but with the cost of the selected route recomputed via
# cost_through instead of being read from its logs.
population[imin].route.cost_through(currents, winds, waves) / route_0.cost_through(currents, winds, waves)

In [None]:
# Cheapest population entry; the route itself is available as `r_min.route`.
r_min = population[imin]

In [None]:
# Size of the cheapest route as reported by len()
# (presumably its number of way points -- confirm against Route.__len__).
len(r_min.route)

In [None]:
# Refine the cheapest route with three iterations of gradient descent on the
# same currents, winds and waves. The result is indexed below; element 0 is
# the one whose cost gets printed.
r_min_grad = gradient_descent(
    route=r_min.route,
    num_iterations=3,
    current_data_set=currents,
    wind_data_set=winds,
    wave_data_set=waves,
)

In [None]:
# Cost before gradient descent, followed by the cost of the first element of
# the gradient-descent result (presumably the refined route -- confirm).
print(r_min.route.cost_through(currents, winds, waves))
print(r_min_grad[0].cost_through(currents, winds, waves))

In [None]:
# Draw every route of the final population as a faint black line, so that
# overlapping routes show up darker.
for lr in population:
    plt.plot(*lr.route.line_string.xy, color="k", alpha=0.1);