# Training Example

In [1]:
%cd ../..

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn, Tensor, optim
import torchvision.transforms as T
from torchvision.transforms.functional import to_pil_image, to_tensor
from torchvision.utils import make_grid

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path
from PIL import Image

from dlbpgm.datasets import Metadata, PlantDataModule, PlantMetaDataset, PlantDataset, PlantLatentDataset, MultiplePlantDataset
from dlbpgm.models import LGMTrainer, LGUnet, LatentGrowthTransformer, LatentGrowthRegression
from dlbpgm.models.lgm import LatentGrowthVAE, LossDict

# Select the compute device. The original preference for the second GPU (cuda:1)
# is kept, but on a single-GPU machine that index does not exist, so fall back to
# the default GPU, and to the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

## Dataset
We use metadata objects to organize our datasets. 
These objects are an extension of the Pandas DataFrame and contain helpful functions for our use case. 

All metadata is saved in a CSV file and can be easily loaded from the Metadata class.
The CSV files contain all relevant information on the plants in the dataset.
This includes the plant's identifier, file paths to its images and latent variables, timestamps, and whether the plant is present in the image. It also specifies whether the plant belongs to the training, validation, or test dataset.

In [2]:
# Load the dataset description and keep only the rows in which the plant is
# present in the image; then preview the first rows of the filtered table.
metadata = Metadata.load("dlbpgm/resources/metadata.csv").query("plant_exists == True")
metadata.head()

In the following, we split the data into training, validation, and test sets.

To generate PyTorch datasets from the metadata, we make use of the two classes `PlantLatentDataset` and `MultiplePlantDataset`.

The `PlantLatentDataset` class manages a time series of a single plant and produces samples in the form of latent variables. 
When initializing the class, additional keyword arguments are required to determine the number of inputs and targets to sample from the time series. 
We also specify the size of the time window used for sampling and the path to the parent directory of the latent variables.

`MultiplePlantDataset` is a wrapper class that creates a `PlantLatentDataset` object for each plant contained in the metadata.

In [3]:
# Customize this path to the location where your latents are stored.
path_to_latents = Path("Arabidopsis/latents128/")

# Sampling configuration shared by all three splits: each sample draws
# `n_inputs` frames as model input and `n_targets` frames as prediction targets
# from a time window of size `timespan` (units as defined by PlantLatentDataset).
DATASET_KWARGS = dict(
    dataset_type=PlantLatentDataset,
    timespan=400,
    n_inputs=4,
    n_targets=3,
    latents_dir=path_to_latents,
)


def make_split(split_metadata) -> MultiplePlantDataset:
    """Wrap one metadata split in a MultiplePlantDataset of latent time series.

    Args:
        split_metadata: The metadata rows of a single split
            (e.g. the result of ``metadata.train_ds()``).

    Returns:
        A MultiplePlantDataset built with the shared ``DATASET_KWARGS``.
    """
    return MultiplePlantDataset(metadata=split_metadata, **DATASET_KWARGS)


train_ds = make_split(metadata.train_ds())
val_ds = make_split(metadata.val_ds())
test_ds = make_split(metadata.test_ds())

BATCH_SIZE = 64
NUM_WORKERS = 4

# Only the training loader shuffles; validation and test keep a fixed order.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

In [None]:
# Draw a small batch of test samples to visualize predictions after training.
# A seeded generator makes the choice of samples reproducible across runs.
# NOTE(review): if PlantLatentDataset samples time points randomly inside
# __getitem__, that randomness is not controlled by this generator — verify.
EXAMPLE_SEED = 0
example_dl = DataLoader(
    test_ds,
    batch_size=4,
    shuffle=True,
    generator=torch.Generator().manual_seed(EXAMPLE_SEED),
)
z_in, t_in, t_out, z_out = next(iter(example_dl))

# Keyword arguments for the model's forward pass, moved to the compute device.
kwargs = dict(
    z_in=z_in.to(device), z_out=z_out.to(device),
    t_in=t_in.to(device), t_out=t_out.to(device),
)

## Model
In the following, the models `LGT`, `LGR`, and `LGU` are initialized for an image resolution of $128^2$ px.

The model configurations are identical to those used in the referenced work and can be adjusted as needed.

In [18]:
image_size = 128
# Side length of the latent grid. The factor 8 is presumably the spatial
# downsampling factor of the pre-trained VAE — TODO confirm.
latent_size = image_size // 8

# Attention and regularization settings shared by all three models.
shared_model_kwargs = dict(nhead=32, dim_head=16, dropout=0.1, n_positions=1000)

# LGT
lgt = LatentGrowthTransformer(resolution=latent_size, **shared_model_kwargs)

# LGR
lgr = LatentGrowthRegression(resolution=latent_size, **shared_model_kwargs)

# LGU: the leading 4 in input_shape is assumed to be the number of latent
# channels produced by the VAE — verify against LatentGrowthVAE.
lgu = LGUnet(
    input_shape=(4, latent_size, latent_size),
    down_features=[128, 256, 512],
    up_features=[512, 256, 256],
    **shared_model_kwargs,
)

## Training
The training is carried out with the class `LGMTrainer`.
This requires that we specify a path in which model checkpoints can be saved during training.
In addition, we pass the model to be trained and configure the learning rate and the gradient accumulation steps.

In [19]:
# create a folder where to save checkpoints during training
# Directory handed to the trainer for checkpoints; it is created together with
# any missing parents, and an existing directory is reused as-is.
experiment_dir = Path("LGM")
experiment_dir.mkdir(exist_ok=True, parents=True)

In [24]:
# Pick the model to train: lgt, lgr, or lgu.
model_to_train = lgt

trainer = LGMTrainer(
    path=experiment_dir,
    model=model_to_train,
    lr=1e-4,
    gradient_accumulation_steps=1,
)
# Move the trainer (and with it the model) to the selected compute device.
trainer = trainer.to(device)

In [None]:
N_EPOCHS = 100

# Each epoch is a full pass over the training data followed by a validation
# pass; checkpoints are presumably written by LGMTrainer into experiment_dir.
for _ in range(N_EPOCHS):
    trainer.train_epoch(train_dl)
    trainer.validation_epoch(val_dl)

## Inference
Since we are carrying out the training in latent space, we need the decoder of the pre-trained VAE, which transforms the latent variables into images.
This is initialized in the first step.

In the next step, we transform both inputs and predictions into images, which we can then plot.

In [10]:
# pre-trained VAE to decode latent representations to images
# Pre-trained VAE used only to decode latent representations into images.
# It is put in eval mode because it is never trained here, so any
# training-mode layers (e.g. dropout, if present) must not affect decoding.
vae = LatentGrowthVAE().to(device).eval()

In [11]:
trained_model = trainer.model
# Remember the current mode so this cell does not leave the model in eval mode
# for a later re-run of the training cell.
was_training = trained_model.training

with torch.no_grad():
    # Decode the input latents so they can be shown next to the predictions.
    x_in = vae.decode(kwargs['z_in']).cpu()

    # Predict the target latents. The model returns a pair; only the first
    # element (the predicted latents) is used here.
    trained_model.eval()
    try:
        z_pred, _ = trained_model(**kwargs)
    finally:
        trained_model.train(was_training)

    # Decode the predicted latents into images and move them to host memory.
    x_pred = vae.decode(z_pred).cpu()
    del z_pred

# Concatenate inputs and predictions along dim 1 so that each sample's frames
# lie next to each other when flattened into the image grid below.
x = torch.cat([x_in, x_pred], dim=1)

In [15]:
# One grid row per sample: its input frames followed by its predicted frames.
frames_per_sample = x.shape[1]
grid_img = make_grid(x.flatten(0, 1), nrow=frames_per_sample)

fig, ax = plt.subplots(figsize=(16, 10))
# NOTE(review): assumes decoded pixel values lie in [0, 1]; imshow clips float
# images outside that range — confirm the VAE's output range.
ax.imshow(grid_img.permute(1, 2, 0))
ax.set_title("Each row: input frames followed by predicted frames")
ax.axis("off")
plt.show()