from collections.abc import Sequence

import firedrake as fd
from firedrake.adjoint import (
    Control, L2TransformedFunctional, ReducedFunctional, continue_annotation,
    minimize, pause_annotation)
import numpy as np
from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy

import matplotlib.pyplot as plt

# Problem setup: a 5x5 unit-square mesh with "crossed" diagonals, a
# degree-1 Lagrange function space on it, the corresponding test/trial
# functions, and a Dirichlet condition with value 0 imposed "on_boundary".
# x, y are the spatial coordinates, used below to define the reference
# control. All of these are module-level names read by the functions below.
mesh = fd.UnitSquareMesh(5, 5, diagonal="crossed")
x, y = fd.SpatialCoordinate(mesh)
space = fd.FunctionSpace(mesh, "Lagrange", 1)
test = fd.TestFunction(space)
trial = fd.TrialFunction(space)
bc = fd.DirichletBC(space, 0, "on_boundary")


def pre_process(m, bc):
    """Split ``m`` into a BC-applied part and the remainder.

    Returns ``(m_0, m_1)``:

    * ``m_0`` -- a deep copy of ``m`` with ``bc`` applied to it;
    * ``m_1`` -- a new Function named ``"m_1"`` on the same space holding
      ``m - m_0``.

    For a homogeneous condition such as the module-level ``bc`` (value 0),
    ``m_1`` therefore carries only what ``bc.apply`` removed from ``m``.
    """
    with_bc = m.copy(deepcopy=True)
    bc.apply(with_bc)
    remainder = fd.Function(m.function_space(), name="m_1")
    remainder.assign(m - with_bc)
    return with_bc, remainder


def forward(m):
    """Solve the forward problem driven by the control ``m``.

    Finds ``u`` in the module-level ``space``, subject to the module-level
    ``bc``, such that

        (grad u, grad v) = (m_0, v)   for every test function v,

    where ``(m_0, m_1) = pre_process(m, bc)``. Only ``m_0`` enters the
    right-hand side; ``m_1`` is returned untouched.

    Returns
    -------
    tuple
        ``(m_0, m_1, u)``.
    """
    source, remainder = pre_process(m, bc)
    u = fd.Function(space, name="u")
    lhs = fd.inner(fd.grad(trial), fd.grad(test)) * fd.dx
    rhs = fd.inner(source, test) * fd.dx
    fd.solve(lhs == rhs, u, bc)
    return source, remainder, u


def forward_J(m, u_ref, *, alpha=1):
    """Assemble the objective functional for control ``m``.

    J(m) = ||u - u_ref||^2_{L^2(Omega)} + alpha^2 ||m_1||^2_{L^2(dOmega)}

    where ``u`` is the state computed by ``forward(m)`` and ``m_1`` is the
    remainder part of ``m`` returned by ``pre_process`` (see ``forward``).
    ``m_1`` does not enter the forward solve, so the second term is the only
    place it influences ``J``.

    Parameters
    ----------
    m
        Control Function.
    u_ref
        Reference (observed) state to match.
    alpha
        Weight of the ``m_1`` penalty; it enters squared.

    Returns
    -------
    The assembled scalar value of ``J``.
    """
    # The BC-applied part of the control only matters through ``u``.
    _, m_1, u = forward(m)
    misfit = fd.inner(u - u_ref, u - u_ref) * fd.dx
    penalty = fd.Constant(alpha ** 2) * fd.inner(m_1, m_1) * fd.ds
    return fd.assemble(misfit + penalty)


# Reference ("true") control and its state. forward() is called here, before
# continue_annotation() further down, so presumably this solve is not recorded
# on the adjoint tape. m_ref is rebound to the BC-applied part returned by
# forward(); that is the quantity error_norm compares iterates against, and
# u_ref is the target state used in the functional.
m_ref = fd.Function(space, name="m_ref").interpolate(
    fd.exp(x) * fd.sin(fd.pi * x) * fd.sin(fd.pi * y))
m_ref, _, u_ref = forward(m_ref)


# Record the functional evaluation for the initial control guess m_0 (a newly
# created Function, presumably zero-initialised). Annotation is enabled only
# around this evaluation and is never re-enabled afterwards in this script.
continue_annotation()
m_0 = fd.Function(space, name="m_0")
J = forward_J(m_0, u_ref)
pause_annotation()
# Control using the "l2" Riesz map: this is the baseline run labelled "$l_2$"
# in the convergence plot at the end of the script.
c = Control(m_0, riesz_map="l2")


class MinimizeCallback(Sequence):
    """Record an error measure at the initial guess and after every step.

    Instances are passed as ``callback=`` to ``minimize``. Each call receives
    the optimizer's flat array of degrees of freedom, rebuilds it as a
    Function on the control's space, evaluates ``error_norm`` on it, prints
    the result and stores it. The recorded values are exposed as a read-only
    sequence, so entry ``k`` corresponds to iteration ``k`` (entry 0 is the
    starting control).

    Parameters
    ----------
    m_0
        Control Function at the start of the optimisation. Its function
        space is used to rebuild iterates; its current value is recorded
        immediately as entry 0.
    error_norm
        Callable mapping a Function on that space to the value to record.
    """

    def __init__(self, m_0, error_norm):
        self._error_norm = error_norm
        self._space = m_0.function_space()
        self._data = []
        # Record the starting point so indices line up with iteration counts.
        initial_values = m_0._ad_to_list(m_0)
        self(np.asarray(initial_values))

    def __len__(self):
        return len(self._data)

    def __getitem__(self, key):
        return self._data[key]

    def __call__(self, xk):
        # Rebuild the iterate as a Function on the control space.
        iterate = fd.Function(self._space, name="m_k")
        iterate._ad_assign_numpy(iterate, xk, 0)
        k, error_norm = len(self._data), self._error_norm(iterate)
        print(f"{k=} {error_norm=:6g}")
        self._data.append(error_norm)


def error_norm(m_k):
    """Return the L^2 distance between ``m_k`` and ``m_ref``.

    ``m_k`` is first split with ``pre_process``. Only its BC-applied part is
    compared with ``m_ref``; the remainder is discarded.
    """
    interior, _ = pre_process(m_k, bc)
    return fd.norm(interior - m_ref, "L2")


# Baseline optimisation: L-BFGS-B on the plain reduced functional. At this
# point error_norm is still the first definition above (no map_result), and
# min_nopc records the error at the initial guess and at every iteration.
# ftol=0 presumably disables the relative-reduction stopping test, leaving
# gtol (or the optimizer's iteration limit) to terminate the run.
J_hat = ReducedFunctional(J, c)
min_nopc = MinimizeCallback(m_0, error_norm)
_ = minimize(J_hat, method="L-BFGS-B",
             callback=min_nopc,
             options={"ftol": 0,
                      "gtol": 1e-10})


def error_norm(m_k):
    """Return the L^2 error of the control represented by ``m_k``.

    ``m_k`` is an optimizer iterate. It is first passed through
    ``J_hat.map_result``, where ``J_hat`` is the module-level name looked up
    at call time. The result is split with ``pre_process``, and its
    BC-applied part is compared with ``m_ref``.
    """
    mapped = J_hat.map_result(m_k)
    interior = pre_process(mapped, bc)[0]
    return fd.norm(interior - m_ref, "L2")


# Transformed optimisation: the same taped functional wrapped in an
# L2TransformedFunctional (alpha=1e-5). J_hat is rebound here *before* the
# callback is constructed. The error_norm defined just above reads J_hat at
# call time, so every recorded error (including entry 0) uses this
# functional's map_result. The callback is seeded with the transformed
# functional's own control, and the functional is explicitly wrapped in
# ReducedFunctionalNumPy for minimize.
J_hat = L2TransformedFunctional(J, c, alpha=1e-5)
min_pc = MinimizeCallback(J_hat.controls[0].control, error_norm)
_ = minimize(ReducedFunctionalNumPy(J_hat), method="L-BFGS-B",
             callback=min_pc,
             options={"ftol": 0,
                      "gtol": 1e-12})

# Compare convergence histories on a log scale. The y-axis is the recorded L^2
# error of the control against iteration number (entry 0 is the initial
# guess). Black is the "l2" baseline and red is the L2-transformed run. The
# x-range covers the longer of the two histories. The figure is written to
# pc.png and then closed.
fig, ax = plt.subplots(1, constrained_layout=True)
ax.semilogy(tuple(range(len(min_nopc))), tuple(min_nopc), "k-",
            label="$l_2$")
ax.semilogy(tuple(range(len(min_pc))), tuple(min_pc), "r-",
            label="$L^2$ transformed")
ax.set_xlim(0, max(len(min_nopc), len(min_pc)) - 1)
ax.set_xlabel("Iteration", fontsize="x-large")
ax.set_ylabel(r"$L^2$ error norm", fontsize="x-large")
ax.legend(fontsize="large")
fig.savefig("pc.png", dpi=288)
plt.close(fig)
