In [None]:
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torchvision.transforms
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
import torch
import numpy as np
from keypoint_detection.utils.heatmap import gaussian_heatmap, generate_keypoints_heatmap, overlay_image_with_heatmap, get_keypoints_from_heatmap
from keypoint_detection.models.detector import KeypointDetector
from keypoint_detection.data.unlabeled_dataset import UnlabeledBoxDataset

In [None]:
import wandb
from pathlib import Path
from skimage import io
import torchvision

In [None]:
## Get model checkpoint from wandb

checkpoint_reference = "airo-box-manipulation/box_dataset_04_02/model-k3plhnyf:latest"

# Download the checkpoint artifact locally (wandb reuses its local copy if it is already cached).
run = wandb.init(project="test-project", entity="airo-box-manipulation")
artifact = run.use_artifact(checkpoint_reference, type="model")
artifact_dir = artifact.download()
checkpoint_path = Path(artifact_dir) / "model.ckpt"

# Inspect the hyperparameters stored in the checkpoint.
# map_location="cpu" lets a checkpoint saved from a GPU run be opened on a CPU-only machine;
# without it torch.load tries to restore tensors onto the device they were saved from.
# NOTE(review): on torch>=2.6 torch.load defaults to weights_only=True, which may reject
# Lightning checkpoints whose hyperparameters contain arbitrary objects; pass weights_only=False
# only for checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(checkpoint["hyper_parameters"])

# Rebuild the detector from the checkpoint. The extra backbone kwarg is forwarded by Lightning's
# load_from_checkpoint; presumably it supplies/overrides the backbone hyperparameter — confirm
# against KeypointDetector's constructor.
model = KeypointDetector.load_from_checkpoint(checkpoint_path, backbone='Unet')

In [None]:
# The real-world capture folder holds both the images and the dataset.json index describing them.
IMAGE_DIR = UnlabeledBoxDataset.get_data_dir_path() / "real_data_2021_12_09"
JSON_PATH = IMAGE_DIR / "dataset.json"

dataset = UnlabeledBoxDataset(JSON_PATH, IMAGE_DIR)
print(len(dataset))

# Shuffled batches of 8 images for visual inspection.
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=8)

In [None]:
def imshow(img):
    """
    Plot a channels-first image tensor with matplotlib and show the figure.

    Images are kept in the [0,1] range, although in theory [-1,1] should be used to whiten.

    Args:
        img: 3-D tensor laid out as (C, H, W), e.g. the output of torchvision.utils.make_grid.
             It may live on any device and may require grad; it is detached and copied to the
             CPU before conversion.
    """
    # .numpy() raises for CUDA tensors and for tensors that require grad, so normalise first.
    np_img = img.detach().cpu().numpy()
    # matplotlib expects channels last: (C, H, W) -> (H, W, C)
    plt.imshow(np.transpose(np_img, (1, 2, 0)))
    plt.show()


In [None]:
# Resize inputs (whole tensor batches included) to 256x256 before they are fed to the model.
transform = torchvision.transforms.Resize(size=(256, 256))

In [None]:
# Model output channel holding the heatmap for each keypoint type.
_KEYPOINT_CHANNELS = {"corner": 0, "flap": 1}
# Maximum number of keypoints extracted per image for each type when show_extracted_keypoints=True.
_MAX_KEYPOINTS = {"corner": 4, "flap": 8}
# Blending factor passed to overlay_image_with_heatmap.
_HEATMAP_ALPHA = 0.6


def _overlay_raw_heatmaps(images, heatmaps):
    """Overlay each predicted heatmap directly on its image and stack the results into one tensor."""
    pil_to_torch = torchvision.transforms.ToTensor()
    return torch.stack(
        [
            pil_to_torch(overlay_image_with_heatmap(image, torch.unsqueeze(heatmap, 0), _HEATMAP_ALPHA))
            for image, heatmap in zip(images, heatmaps)
        ]
    )


def _overlay_extracted_keypoints(images, heatmaps, max_keypoints):
    """Extract keypoints from each heatmap, re-render them as Gaussian blobs and overlay on the image."""
    pil_to_torch = torchvision.transforms.ToTensor()
    image_size = images.shape[-2:]
    overlays = []
    for image, heatmap in zip(images, heatmaps):
        # The second positional argument (1) is passed through unchanged from the original code;
        # presumably a minimum peak distance — confirm against get_keypoints_from_heatmap.
        keypoints = get_keypoints_from_heatmap(heatmap, 1, max_keypoints)
        keypoint_heatmap = generate_keypoints_heatmap(image_size, keypoints, sigma=2, device="cpu")
        overlays.append(
            pil_to_torch(overlay_image_with_heatmap(image, torch.unsqueeze(keypoint_heatmap, 0), _HEATMAP_ALPHA))
        )
    return torch.stack(overlays)


def show_results(type="corner", show_extracted_keypoints=True, mode="eval", max_batches=None):
    """
    Show network outputs on the dataset, one image grid per dataloader batch.

    Args:
        type: "corner" (heatmap channel 0, up to 4 keypoints) or "flap" (channel 1, up to 8 keypoints).
        show_extracted_keypoints: if True, overlay Gaussian blobs re-rendered from the keypoints
            extracted from the heatmap; if False, overlay the raw predicted heatmap.
        mode: "eval" runs the model in eval mode; any other value runs it in train mode.
            The model's previous train/eval state is restored afterwards.
        max_batches: stop after this many batches; None (default) walks the whole dataloader.

    Raises:
        ValueError: if `type` is not "corner" or "flap".

    Side effects: sets matplotlib's global figure size to (20, 3) and plots via imshow.
    """
    if type not in _KEYPOINT_CHANNELS:
        raise ValueError(f"unknown keypoint type {type!r}; expected one of {sorted(_KEYPOINT_CHANNELS)}")
    channel = _KEYPOINT_CHANNELS[type]

    plt.rcParams["figure.figsize"] = (20, 3)

    was_training = model.training
    model.train(mode != "eval")
    try:
        for batch_index, batch in enumerate(dataloader):
            if max_batches is not None and batch_index >= max_batches:
                break
            with torch.no_grad():
                images = transform(batch)
                # Keep only the requested heatmap channel, on the CPU for the overlay helpers.
                heatmaps = model(images)[:, channel].cpu()
                if show_extracted_keypoints:
                    overlays = _overlay_extracted_keypoints(images, heatmaps, _MAX_KEYPOINTS[type])
                else:
                    overlays = _overlay_raw_heatmaps(images, heatmaps)
            imshow(torchvision.utils.make_grid(overlays, nrow=8))
    finally:
        # Don't leak the temporary train/eval switch into later cells.
        model.train(was_training)


In [None]:
# Raw corner heatmaps (no keypoint extraction).
show_results("corner", show_extracted_keypoints=False)

In [None]:
# Raw flap heatmaps (no keypoint extraction).
show_results("flap", show_extracted_keypoints=False)