# Evaluation

In [1]:
%cd ../..

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn, Tensor, optim
import torchvision.transforms as T
from torchvision.transforms.functional import to_pil_image, to_tensor
from torchvision.utils import make_grid

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path
from PIL import Image

from dlbpgm.models import LGU, LGT, LGR
from dlbpgm.models.evaluator import Metrics, Evaluator
from dlbpgm.datasets import StridedPlantDataset, Metadata, MultiplePlantDataset


device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')  # 'cuda:1' is the second GPU; use 'cuda:0' on single-GPU machines
device

## Test Dataset
The models were evaluated with three different step sizes (40, 80, and 150) between the timestamps passed to them.
For simplicity, these step sizes are defined as strides within the PyTorch datasets.
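To make the role of the stride concrete, here is a minimal sketch of one possible convention: each sample consists of `sequence_length` timestamps spaced `stride` indices apart, and a sample may start at any index for which the whole window fits into the series. Whether `StridedPlantDataset` uses exactly this convention (rather than, e.g., non-overlapping windows) is an assumption, and `window_indices` and `num_windows` are hypothetical helpers.

```python
from __future__ import annotations


def window_indices(start: int, sequence_length: int, stride: int) -> list[int]:
    """Indices of the timestamps in one sample (hypothetical convention)."""
    return [start + i * stride for i in range(sequence_length)]


def num_windows(series_length: int, sequence_length: int, stride: int) -> int:
    """Number of start positions whose whole window fits into the series."""
    span = (sequence_length - 1) * stride  # distance from first to last index
    return max(series_length - span, 0)


print(window_indices(0, 9, 40))  # [0, 40, 80, 120, 160, 200, 240, 280, 320]
print(num_windows(400, 9, 40))   # 80
```

Under this convention, a larger stride both widens the temporal gap between images and reduces the number of samples each plant contributes.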

As with all other examples, we begin by loading the metadata of our dataset, which allows us to access the images in the test dataset.

In [6]:
metadata = Metadata.load("dlbpgm/resources/metadata.csv")
metadata = metadata.remove_empty_pots()

To generate a PyTorch dataset from the metadata, we make use of the two classes `StridedPlantDataset` and `MultiplePlantDataset`.

The `StridedPlantDataset` class manages the time series of a single plant and produces samples consisting of images and their corresponding timestamps.
When initializing the class, additional keyword arguments are required: the sequence length, which specifies the number of images in each sample, the stride between the sampled timestamps, and the transforms applied to the data before it is fed into the models.
In this case, we separate the foreground from the background by using the alpha channel of the images as a mask.
Additionally, we scale each image to a resolution of $256^2$ px.
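The masking step relies on broadcasting: `t[:, 3:]` keeps a channel axis of size 1, so the alpha channel multiplies all three RGB channels at once. The sketch below reproduces the same slicing with NumPy on a dummy RGBA sequence (random values and a binary alpha channel, both assumptions for illustration); torch tensors behave the same way under this indexing. It also shows that the lambda expects a 4-D input of shape (sequence, channels, height, width).

```python
import numpy as np

# Dummy sequence of 9 RGBA images, shape (sequence, channels, height, width).
rng = np.random.default_rng(0)
seq = rng.random((9, 4, 8, 8), dtype=np.float32)
seq[:, 3] = seq[:, 3] > 0.5  # binary alpha: 1 = plant, 0 = background

# Same slicing as the transform: RGB channels times the alpha channel.
masked = seq[:, :3] * seq[:, 3:]

print(masked.shape)  # (9, 3, 8, 8)
```

If the alpha channel is not strictly binary, the same expression produces a soft blend towards black instead of a hard cut-out.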

`MultiplePlantDataset` is a wrapper class that creates a `StridedPlantDataset` object for each plant contained in the metadata.
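Such a wrapper has to map a single global sample index to a plant and to a position within that plant's series. The following is a minimal sketch of this bookkeeping, assuming it works like PyTorch's `torch.utils.data.ConcatDataset`, which keeps cumulative lengths and uses a binary search. `ConcatenatedPlants` is a hypothetical stand-in, and plain lists play the role of the per-plant datasets.

```python
from bisect import bisect_right
from itertools import accumulate


class ConcatenatedPlants:
    """Hypothetical wrapper that exposes several per-plant datasets as one."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        # Running totals of the sub-dataset lengths, e.g. [3, 5, 9].
        self.cumulative = list(accumulate(len(d) for d in self.datasets))

    def __len__(self):
        return self.cumulative[-1] if self.cumulative else 0

    def __getitem__(self, idx):
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # The first sub-dataset whose running total exceeds idx holds the sample.
        ds_idx = bisect_right(self.cumulative, idx)
        offset = self.cumulative[ds_idx - 1] if ds_idx > 0 else 0
        return self.datasets[ds_idx][idx - offset]


plants = ConcatenatedPlants([["a0", "a1", "a2"], ["b0", "b1"], ["c0", "c1", "c2", "c3"]])
print(len(plants), plants[3], plants[-1])  # 9 b0 c3
```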

In [8]:
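stride = 40  # models were evaluated with stride 40, 80, and 150

test_ds = MultiplePlantDataset(
    metadata=metadata.test_ds(),        # metadata of the test split
    dataset_type=StridedPlantDataset,   # one strided dataset per plant
    stride=stride,
    sequence_length=9,                  # number of images per sample
    transform=T.Compose([
        # mask out the background: RGB channels times the alpha channel
        T.Lambda(lambda t: t[:, :3] * t[:, 3:]),
        T.Resize((256, 256), antialias=True),
    ])
)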

## Model
In this case, we evaluate the LGU model at a resolution of $256^2$ px.
First, we initialize the model by specifying the desired resolution and the path to the pre-trained model weights.
(You may need to adjust the checkpoint path to match your local setup.)

In [3]:
model = LGU(
    resolution=256, 
    ckpt_path="Arabidopsis/experiments/256x256/LGUnet/lgu256.pt"
)
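Since the checkpoint location depends on the local setup, it can help to check the path before constructing the model and fail with a readable message. A minimal sketch; `find_checkpoint` is a hypothetical helper and not part of `dlbpgm`.

```python
from pathlib import Path


def find_checkpoint(*candidates):
    """Return the first existing checkpoint path (hypothetical helper)."""
    for candidate in candidates:
        path = Path(candidate).expanduser()
        if path.is_file():
            return path
    tried = ", ".join(str(c) for c in candidates)
    raise FileNotFoundError(f"No checkpoint found; tried: {tried}")
```

For example, `find_checkpoint("Arabidopsis/experiments/256x256/LGUnet/lgu256.pt")` returns the path if it exists, and its result can be passed on as `ckpt_path=str(...)`.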

Next, we initialize the `Evaluator` by passing the model, the test dataset, the batch size, and the number of workers to its constructor.
We also move the evaluator to the GPU, if one is available, to speed up the evaluation, and switch it to evaluation mode.

In [4]:
evaluator = Evaluator(
    model=model,
    dataset=test_ds,
    batch_size=8,
    num_workers=2,
).to(device).eval()
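The chained `.to(device).eval()` call works if `Evaluator` behaves like an `nn.Module`, which the call suggests but which is an assumption here. The toy sketch below, unrelated to the actual models, shows what the two calls do for a regular module: both return the module itself, and evaluation mode disables dropout, so repeated forward passes give identical outputs. The `torch.no_grad()` context is a common companion to evaluation mode; whether `Evaluator.evaluate` uses it internally is not stated.

```python
import torch
from torch import nn

# Toy module containing dropout, which behaves differently in train and eval mode.
toy = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

toy_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# .to() and .eval() both return the module itself, so they can be chained.
toy = toy.to(toy_device).eval()

x = torch.randn(2, 4, device=toy_device)
with torch.no_grad():  # no autograd graph is needed for evaluation
    y1, y2 = toy(x), toy(x)

print(toy.training)          # False
print(torch.equal(y1, y2))   # True: dropout is disabled in eval mode
```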

Now we are ready to start the evaluation:

In [None]:
evaluator.evaluate()
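With `batch_size=8`, the last batch is smaller whenever the number of test samples is not a multiple of 8. If metrics are averaged per batch, these averages have to be weighted by batch size to equal the average over all samples. Whether `Evaluator` aggregates its metrics this way is not stated; the sketch below only illustrates the arithmetic with made-up per-batch values, and `weighted_mean` is a hypothetical helper.

```python
def weighted_mean(batch_means, batch_sizes):
    """Combine per-batch averages into the average over all samples."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


# 20 samples with batch_size=8 give batches of 8, 8 and 4 samples.
means, sizes = [0.5, 0.25, 1.0], [8, 8, 4]
print(weighted_mean(means, sizes))  # 0.5
print(sum(means) / len(means))      # ≈ 0.583, over-weights the short last batch
```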

The results are returned as a tuple containing three dictionaries.
The first contains the metrics measured on all predictions, both interpolated and extrapolated.
The second and third dictionaries cover the two cases separately.
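The split between the two cases can be illustrated with a simple rule: a predicted timestamp that lies between the earliest and latest input timestamps is an interpolation, while one outside that range is an extrapolation. This rule, the choice of input timestamps below, and the helper `split_targets` are assumptions for illustration; the notebook does not state which timestamps of a sample the models receive as input.

```python
def split_targets(input_times, target_times):
    """Label each target timestamp relative to the observed inputs.

    Hypothetical rule: a target inside the range spanned by the inputs is
    interpolated; one before the first or after the last is extrapolated.
    """
    lo, hi = min(input_times), max(input_times)
    interpolated = [t for t in target_times if lo < t < hi]
    extrapolated = [t for t in target_times if t < lo or t > hi]
    return interpolated, extrapolated


# Nine timestamps with stride 40; some are given as input, the rest predicted.
times = [0, 40, 80, 120, 160, 200, 240, 280, 320]
inputs = [0, 120, 240]
targets = [t for t in times if t not in inputs]
print(split_targets(inputs, targets))  # ([40, 80, 160, 200], [280, 320])
```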

In [28]:
metrics = evaluator.compute_metrics()
metrics
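For a side-by-side view of the three dictionaries, they can be combined into a single pandas table with one column per case. A minimal sketch, assuming the tuple is ordered as (all predictions, interpolation, extrapolation) and that each dictionary maps a metric name to a scalar; `metrics_to_frame` is a hypothetical helper.

```python
import pandas as pd


def metrics_to_frame(results):
    """Arrange the three metric dicts as one table (hypothetical helper).

    Assumes the order (overall, interpolation, extrapolation) and that each
    dict maps a metric name to a scalar value.
    """
    overall, interpolated, extrapolated = results
    # A dict of dicts becomes a frame with outer keys as columns and
    # inner keys (the metric names) as the row index.
    return pd.DataFrame({
        "overall": overall,
        "interpolation": interpolated,
        "extrapolation": extrapolated,
    })
```

In the notebook, `metrics_to_frame(metrics)` would then display one row per metric.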