In [None]:
from matplotlib.collections import EllipseCollection

from ellipse_rcnn.utils.conics import (
    ellipse_to_conic_matrix,
    ellipse_center,
)

from ellipse_rcnn.utils.data.fddb import FDDB

In [None]:
from ellipse_rcnn import EllipseRCNN
from ellipse_rcnn.core.model import EllipseRCNNLightning
import torch

In [None]:
# Build the network with its default constructor arguments. The same instance
# is handed to the Lightning checkpoint loader in the next cell
# (NOTE(review): presumably so that it ends up holding the trained weights — confirm).
model = EllipseRCNN()

In [None]:
CHECKPOINT_PATH = "../checkpoints/e=08-loss=1.78192.ckpt"

# Restore the Lightning wrapper from disk. `model=model` is passed through to the
# LightningModule constructor, so the restored wrapper wraps the very same
# EllipseRCNN instance that later cells call directly.
# NOTE(review): later cells use `model`, not `pl_model`, relying on the checkpoint
# weights being loaded into this shared instance — confirm.
pl_model = EllipseRCNNLightning.load_from_checkpoint(CHECKPOINT_PATH, model=model)

In [None]:
# Put the network in inference mode and move it to the CPU.
model.eval()
model.cpu()

In [None]:
FDDB_ROOT = "../data/FDDB"

# Two views of the same data: `ds` applies the dataset's default transform (its
# images are what the model is fed), while `ds_raw` swaps in an identity
# transform so samples can be displayed exactly as loaded.
ds = FDDB(FDDB_ROOT)
ds_raw = FDDB(FDDB_ROOT, transform=lambda sample: sample)

In [None]:
from matplotlib.axes import Axes
from matplotlib import pyplot as plt
from ellipse_rcnn.utils.conics import ellipse_axes, ellipse_angle
import numpy as np


def plot_conics(
    A_craters: torch.Tensor,
    resolution: tuple[int, int],
    figsize: tuple[float, float] = (15, 15),
    plot_centers: bool = False,
    ax: Axes | None = None,
    rim_color="r",
    alpha=1.0,
) -> Axes:
    """Draw ellipses given as conic matrices onto a matplotlib Axes.

    Args:
        A_craters: Conic matrices accepted by ``ellipse_axes`` /
            ``ellipse_angle`` / ``ellipse_center`` (presumably shape
            ``(N, 3, 3)`` — TODO confirm).
        resolution: ``(width, height)`` of the image in pixels. Used for the
            axis limits and to decide which centre labels are drawn.
        figsize: Figure size; only used when ``ax`` is None.
        plot_centers: If True, write each ellipse's index at its centre
            (only for centres that fall inside ``resolution``).
        ax: Axes to draw on. A new equal-aspect figure is created if None.
        rim_color: Edge colour of the ellipses.
        alpha: Opacity of the ellipses.

    Returns:
        The Axes that was drawn on (useful when ``ax`` was None).
    """
    a_proj, b_proj = ellipse_axes(A_craters)
    psi_proj = ellipse_angle(A_craters)
    # ellipse_center returns the x and y coordinates as two separate tensors.
    x_pix_proj, y_pix_proj = ellipse_center(A_craters)

    a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj = (
        t.detach().cpu().numpy()
        for t in (a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj)
    )
    # One (x, y) row per ellipse; shared by the collection and the labels.
    centers = np.column_stack((x_pix_proj, y_pix_proj))

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})

    # Pixel-space convention: origin at the top-left, y grows downwards.
    ax.set_xlim(0, resolution[0])
    ax.set_ylim(resolution[1], 0)

    # EllipseCollection takes full axis lengths (it halves them internally),
    # whereas ellipse_axes presumably yields semi-axes — hence the factor 2.
    ec = EllipseCollection(
        2 * a_proj,
        2 * b_proj,
        np.degrees(psi_proj),
        units="xy",
        offsets=centers,
        transOffset=ax.transData,
        facecolors="None",
        edgecolors=rim_color,
        alpha=alpha,
    )
    ax.add_collection(ec)

    if plot_centers:
        # Iterate per ellipse. Enumerating ellipse_center's return value
        # directly would iterate over the x/y coordinate arrays instead.
        for k, (x, y) in enumerate(centers):
            if 0 <= x <= resolution[0] and 0 <= y <= resolution[1]:
                ax.text(x, y, str(k))

    return ax


i = 90  # index of the FDDB sample to visualise

image, target_dict = ds[i]
image_raw, _ = ds_raw[i]

# NOTE(review): assumes `image` is a (C, H, W) tensor, so this is (height, width).
resolution = tuple(image.shape[-2:])
print(resolution)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    pred = model(image.unsqueeze(0))

if len(pred[0]["boxes"]) > 0:
    fig, ax = plt.subplots(1, figsize=(15, 25))
    ax.set_aspect("equal")
    ax.grid(True)
    ax.imshow(np.array(image_raw))

    # Split the per-detection parameter rows into one tensor per parameter
    # (each holding that parameter for every detection). Iterating the tensor
    # directly would walk detections, not parameters.
    # NOTE(review): assumes "ellipse_matrices" holds (a, b, x, y, theta) as an
    # (N, 5) tensor — confirm against the model's output format.
    a, b, x, y, theta = pred[0]["ellipse_matrices"].detach().unbind(-1)

    A_pred = ellipse_to_conic_matrix(a=a, b=b, x=x, y=y, theta=theta)
    # plot_conics expects (width, height): the reverse of the tensor's (H, W).
    height, width = resolution
    plot_conics(A_pred, ax=ax, plot_centers=True, resolution=(width, height))
    plt.show()