In [None]:
import sys
sys.path.append("..")
from figutils import *
from tol_colors import tol_cmap, tol_cset

In [None]:
# Loss functions whose landscapes were exported; each name appears in the .npy file names.
loss_names = ["l1", "l2"]

# This notebook's outputs live in a sub-folder of OUTPUT_DIR named after the
# base name of the current working directory.
output_dir = os.path.join(OUTPUT_DIR, os.path.basename(os.getcwd()))
spps = np.load(os.path.join(output_dir, "spps.npy"))
param_landscape = np.load(os.path.join(output_dir, "param_landscape.npy"))


def _load_run(loss_name, spp):
    """Return the (loss, gradient) arrays saved for one loss function at one spp count."""
    loss = np.load(os.path.join(output_dir, f"loss_{loss_name}_{spp:04d}.npy"))
    grad = np.load(os.path.join(output_dir, f"grad_{loss_name}_{spp:04d}.npy"))
    return loss, grad


# Nested lists indexed as losses[i][j] / grads[i][j]: curve for loss_names[i] at spps[j].
losses, grads = [], []
for loss_name in loss_names:
    runs = [_load_run(loss_name, spp) for spp in spps]
    losses.append([loss for loss, _ in runs])
    grads.append([grad for _, grad in runs])



In [None]:
# Figure: loss (left column, log scale) and gradient (right column) landscapes
# over the swept parameter -- one row per loss function, one curve per spp.
n_rows = len(loss_names)
n_cols = 2

# Row labels, in the same order as `loss_names` (i.e. the rows of `losses` / `grads`).
titles = [r"$\mathcal{L}^1$", r"$\mathcal{L}^2$"]
assert len(titles) == n_rows, "need exactly one row title per entry in loss_names"

# x-position of the dashed vertical guide drawn in every gradient panel.
# NOTE(review): presumably the reference/ground-truth parameter value -- TODO confirm.
REFERENCE_PARAM = 5

aspect = 1.75
# clear=True: re-running this cell wipes figure 1 instead of stacking new axes on stale ones.
fig = plt.figure(1, figsize=(TEXT_WIDTH, TEXT_WIDTH / aspect), clear=True)
gs = fig.add_gridspec(n_rows, n_cols, wspace=0.2, hspace=0.05)

# One colour per spp, sampled from the lower part of the rainbow_PuRd colormap.
colors = [tol_cmap("rainbow_PuRd")(0.15 + 0.5 * j / len(spps)) for j in range(len(spps))]

for i, title in enumerate(titles):
    ax_loss = fig.add_subplot(gs[i, 0])
    ax_grad = fig.add_subplot(gs[i, 1])

    # Dashed guides: a vertical line at REFERENCE_PARAM spanning this row's
    # gradient range, and the zero line across the swept parameter range.
    grad_lo = min(g.min() for g in grads[i])
    grad_hi = max(g.max() for g in grads[i])
    ax_grad.plot([REFERENCE_PARAM, REFERENCE_PARAM], [grad_lo, grad_hi], color="grey", linestyle="--")
    ax_grad.plot(param_landscape, np.zeros_like(param_landscape), color="grey", linestyle="--")

    for j, spp in enumerate(spps):
        loss = losses[i][j]
        ax_loss.semilogy(param_landscape, loss, label=str(spp), color=colors[j])
        # Mark the minimiser of this loss curve with an "x".
        ax_loss.scatter(param_landscape[np.argmin(loss)], np.min(loss), color=colors[j], marker="x")
        ax_grad.plot(param_landscape, grads[i][j], label=str(spp), color=colors[j])

    # Hide x ticks/labels on every row except the bottom one.
    if i < n_rows - 1:
        for ax in (ax_grad, ax_loss):
            ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    ax_loss.set_ylabel(title)

# After the loop, ax_loss / ax_grad are the bottom-row axes: the column titles are
# drawn beneath them (y=-0.3) and the spp legend goes on the bottom-right panel.
ax_grad.legend(title="spp")
ax_loss.set_title("Loss", y=-0.3)
ax_grad.set_title("Gradient", y=-0.3)
save_fig("landscapes", pad_inches=0.015)