# In [2]:
import os
import functools
import os
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import List, Union
import cartopy.crs as ccrs
import cartopy
import cartopy.feature as cfeature
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import xarray as xr
import trajan as ta
import numpy as np
from tqdm import tqdm
import pandas as pd
from matplotlib.dates import DateFormatter
import ACTEA_common_tools_drift

def plot_timeseries(prop_ds, filename, color, prop=None, units=None):
    """Plot the ensemble mean +/- one standard deviation of a variable over time.

    The band ``mean - std`` .. ``mean + std`` is filled with ``color`` and the
    mean is drawn on top as a black line. The figure is saved to ``filename``
    and then displayed.

    Parameters
    ----------
    prop_ds : xarray.DataArray
        Variable to plot; must have a ``"time"`` coordinate. After ``.T`` the
        first axis is sliced by the number of time stamps and the second axis
        is reduced, i.e. rows are time steps and columns ensemble members
        (presumably stored as ``(trajectory, time)`` -- confirm for new inputs).
    filename : str or path-like
        Output image path. Missing parent directories are created.
    color : str
        Matplotlib colour of the +/- std band.
    prop : str, optional
        Name used in the title and y-label. Defaults to ``prop_ds.name``.
    units : str, optional
        Units shown in the y-label. Defaults to ``prop_ds.attrs["units"]``
        when that attribute exists; without units the label is just ``prop``.
    """
    # Derive labels from the data itself instead of relying on module-level
    # globals (the old code read a global ``prop`` and an undefined ``o``).
    if prop is None:
        prop = prop_ds.name
    if units is None:
        units = prop_ds.attrs.get("units")

    times = prop_ds["time"].values
    # Rows = time steps, columns = ensemble members (see docstring).
    data_ds = prop_ds.T[0 : len(times), :]
    data_mean = np.mean(data_ds, axis=1)
    data_std = np.std(data_ds, axis=1)

    fig = plt.figure()
    ax = fig.gca()
    ax.xaxis.set_major_formatter(DateFormatter("%d %b %Y %H:%M"))
    plt.xticks(rotation="vertical")

    plt.fill_between(
        times, data_mean - data_std, data_mean + data_std, alpha=1.0, color=color
    )
    plt.plot(times, data_mean, color="black")
    plt.title(prop)
    plt.xlabel("Time  [UTC]")
    plt.ylabel("%s  [%s]" % (prop, units) if units else prop)
    plt.subplots_adjust(bottom=0.3)
    plt.grid()

    # savefig does not create directories; make sure the target folder exists.
    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(filename,  bbox_inches='tight', facecolor='w', edgecolor='w', transparent=False, pad_inches=0.1)
    print(f"Saved plot to {filename}")
    plt.show()

def get_map_extent():
    """Return the plotting window as ``[lon_min, lon_max, lat_min, lat_max]``.

    Values are in degrees and are passed to ``set_extent`` together with a
    ``PlateCarree`` CRS. A new list is returned on every call.
    """
    lon_min, lon_max = -170, -154
    lat_min, lat_max = 48, 59
    return [lon_min, lon_max, lat_min, lat_max]

def plot_trajectory(ds, lons, lats, z, traj_index, ax, color):
    """Draw one trajectory, or the ensemble-mean trajectory, on ``ax``.

    ``traj_index == 0`` is special-cased: the mean longitude/latitude over the
    ``"trajectory"`` dimension of ``ds`` is drawn (opaque, linewidth 2) instead
    of trajectory 0. Any other index draws row ``traj_index`` of
    ``lons``/``lats`` faintly (alpha 0.1, linewidth 0.3). When more than 20
    valid points remain, a small red marker is added at the last position.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``"lon"``/``"lat"`` variables and a ``"trajectory"``
        dimension; only used when ``traj_index == 0``.
    lons, lats : numpy.ndarray
        2-D position arrays, one row per trajectory (presumably
        ``[trajectory, time]``). NaN or values >= 1e30 mark invalid points.
    z : numpy.ndarray
        Unused; kept so existing callers keep working.
    traj_index : int
        Row to draw, or 0 for the ensemble mean.
    ax : cartopy GeoAxes
        Target axes; coordinates are given in ``PlateCarree``.
    color : str
        Line colour.
    """
    if traj_index == 0:
        x = ds["lon"].mean(dim="trajectory", skipna=True).values
        y = ds["lat"].mean(dim="trajectory", skipna=True).values
        alpha, linewidth = 1, 2
    else:
        x = lons[traj_index, :]
        y = lats[traj_index, :]
        alpha, linewidth = 0.1, 0.3

    # One joint mask keeps x and y the same length. NaN fails the comparison,
    # so inactive (NaN) points are dropped as well. Boolean indexing returns a
    # copy, so the in-place shift below never modifies the caller's arrays.
    valid = (x < 1e30) & (y < 1e30)
    x = x[valid]
    y = y[valid]
    # Map negative longitudes into 0..360 (presumably to keep tracks
    # continuous across the antimeridian).
    x[x < 0] += 360

    ax.plot(x,
            y,
            c=color,
            alpha=alpha,
            linewidth=linewidth,
            transform=ccrs.PlateCarree())

    if len(x) > 20:
        # Mark the final valid position of the track.
        ax.scatter(x[-1], y[-1], marker=None,
                   facecolors="tab:red",
                   s=0.4,
                   alpha=0.4,
                   linewidth=0.6,
                   zorder=5,
                   transform=ccrs.PlateCarree())

def create_map(ds, filename, color):
    """Plot every trajectory of ``ds`` on a north-polar stereographic map and save it.

    Land is drawn from the GSHHS coastline shapefile at resolution ``"i"``,
    the view is clipped to ``get_map_extent()``, and ``plot_trajectory`` is
    called for every row index of ``ds["lon"]`` (starting at 0) with
    ``color``. In this file ``plot_trajectory`` treats index 0 as the
    ensemble-mean track rather than trajectory 0.

    Parameters
    ----------
    ds : xarray.Dataset
        Drift output with ``"lat"``, ``"lon"`` and ``"z"`` variables whose first
        axis indexes trajectories.
    filename : str or path-like
        Output image path (saved at 300 dpi). Missing parent directories are
        created.
    color : str
        Matplotlib colour used for the trajectories.
    """
    # Land polygons from the GSHHS coastline database.
    shpfile = cartopy.io.shapereader.gshhs('i')
    shp = cartopy.io.shapereader.Reader(shpfile)

    projection = ccrs.NorthPolarStereo(central_longitude=180)
    ax = plt.axes(projection=projection)
    ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False)

    ax.add_geometries(
            shp.geometries(), ccrs.PlateCarree(),
            edgecolor='black', facecolor='lightgrey',
            linewidth=0.2)
    ax.set_extent(get_map_extent(), ccrs.PlateCarree())

    lats = ds["lat"].values
    lons = ds["lon"].values
    z = ds["z"].values

    for traj_index in tqdm(range(lons.shape[0]), colour="red"):
        plot_trajectory(ds, lons, lats, z, traj_index, ax, color)

    # savefig does not create directories; make sure the target folder exists.
    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(filename, dpi=300, bbox_inches='tight', facecolor='w', edgecolor='w', transparent=False, pad_inches=0.1)
    plt.show()

project = "walleye_pollock_eggs"

# (file-name stem, plot colour) for each egg-density / behaviour variant.
# The stems reproduce the postfixes used in the drift output file names.
variants = [
    ("constant_egg_density", "tab:green"),
    ("dynamic_egg_density_light", "tab:orange"),
    ("dynamic_egg_density_dark", "tab:blue"),
]
properties_to_plot = ["z", "density", "sea_water_temperature", "sea_water_salinity"]

for port_townsend_eggs in [True, False]:
    for radius in [0, 5000]:
        for stem, color in variants:
            postfix = f"{stem}_seed_radius_{radius}"
            if port_townsend_eggs:
                postfix = f"{postfix}_port_townsend_eggs"
                print(f"Postfix is {postfix}")

            # Use a context manager so each NetCDF file is closed once its
            # plots are finished instead of staying open until garbage collection.
            with xr.open_dataset(f'output_{project}/walleye_pollock_eggs_drift_{postfix}.nc', decode_coords=False) as raw:
                # Requirement that status>=0 is needed since non-valid points are not masked in OpenDrift output
                ds = raw.where(raw.status >= 0)  # only active particles
                ds = ds.where(ds.lon < 1e30)

                filename = f"{project}/Figures/trajectories_{project}_{postfix}.png"
                create_map(ds, filename, color)

                for prop in properties_to_plot:
                    filename = f"{project}/Figures/timeseries_{prop}_{project}_{postfix}.png"
                    plot_timeseries(ds[prop], filename, color)


