## This notebook demonstrates some of the package's functionality. For training, use the `train.py` script instead.

In [1]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
import torch 
import numpy as np 
from keypoint_detection.src.keypoint_utils import gaussian_heatmap, generate_keypoints_heatmap, get_keypoints_from_heatmap
from keypoint_detection import KeypointDetector
from keypoint_detection import BoxKeypointsDataModule, BoxKeypointsDataset, DatasetPreloader


In [None]:

# Weights & Biases logger; handed to the Trainer in a later cell.
wandb_logger = WandbLogger(entity="airo-box-manipulation", project="test-project")

In [None]:
# IPython shell escape (`!`): print the NVIDIA GPU status table via nvidia-smi.
!nvidia-smi

In [None]:
# Can PyTorch see a CUDA device? The bare name on the last line is shown as the cell output.
cuda_available = torch.cuda.is_available()
cuda_available

In [None]:
## Demonstration of why max(.) instead of sum should be used to combine keypoint heatmaps
# Two keypoints whose centres differ by 4 along the first axis, on a (32, 50) map;
# the third argument (4) is presumably the gaussian sigma -- TODO confirm.
# - sum: the overlapping tails add up, so two neighbouring keypoints can merge into one blob
#   (assuming non-negative heatmaps, the combined values can also exceed either single peak).
# - max: each keypoint keeps its own, separate peak.
# NOTE(review): the original comment was cut off ("furthermore it will also reduce the ...");
# confirm the intended second point.
img = gaussian_heatmap((32, 50), (8, 25), torch.Tensor([4]))
img2 = gaussian_heatmap((32, 50), (12, 25), torch.Tensor([4]))

f, axarr = plt.subplots(1, 2)
axarr[0].imshow(img + img2)
axarr[0].set_title("sum")
axarr[1].imshow(torch.max(img, img2))
axarr[1].set_title("max")
plt.show()

In [4]:
# Root folder of the box dataset; the annotation JSON lives directly inside it.
IMAGE_DATASET_PATH = "/workspaces/box-manipulation/datasets/box_dataset2"
JSON_PATH = f"{IMAGE_DATASET_PATH}/dataset.json"

In [None]:
def imshow(img):
    """
    Plot a (C, H, W) image tensor with matplotlib.

    Images are shown as-is (kept in the [0, 1] range), although in theory
    [-1, 1] should be used to whiten them.

    The tensor is detached and copied to the CPU first, so tensors that require
    grad or live on a GPU can be plotted directly. A single-channel image is
    plotted as a 2D (H, W) array, since some matplotlib versions reject
    (H, W, 1) arrays.
    """
    np_img = img.detach().cpu().numpy()
    # bring (C, H, W) to (H, W, C), the channel-last layout matplotlib expects
    np_img = np.transpose(np_img, (1, 2, 0))
    if np_img.shape[-1] == 1:
        np_img = np_img[..., 0]
    plt.imshow(np_img)
    plt.show()


In [None]:
def show_heatmap_overlay(img, heatmap):
    """
    Plot a heatmap tensor semi-transparently on top of an image tensor.

    img is a (C, H, W) image tensor (it is transposed to (H, W, C) for plotting).
    heatmap is presumably a 2D (H, W) map -- TODO confirm for multi-channel outputs.
    Both tensors are detached and copied to the CPU first, so tensors that require
    grad or live on a GPU (e.g. model predictions) can be passed directly.
    """
    fig, ax = plt.subplots()  # create figure and axes
    np_img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))  # (C, H, W) -> (H, W, C)
    ax.imshow(np_img, alpha=0.9)
    ax.imshow(heatmap.detach().cpu().numpy(), alpha=0.2)
    plt.show()

In [9]:
## test caching influence
# Create the dataset and a DatasetPreloader wrapping it; the two cells below
# iterate once over each so the two can be compared.
dataset = BoxKeypointsDataset(JSON_PATH, IMAGE_DATASET_PATH)

# NOTE(review): the meaning of DatasetPreloader's second argument (2) is not visible here -- confirm.
preloaded_dataset = DatasetPreloader(dataset, 2)

In [16]:
import time

# Time one full pass over the plain dataset, for comparison with the preloaded one.
start = time.perf_counter()
for i in range(len(dataset)):
    _ = dataset[i]
print(f"BoxKeypointsDataset: {len(dataset)} items in {time.perf_counter() - start:.3f} s")

In [18]:
import time

# Time one full pass over the preloaded dataset, for comparison with the plain one.
start = time.perf_counter()
for i in range(len(preloaded_dataset)):
    _ = preloaded_dataset[i]
print(f"DatasetPreloader: {len(preloaded_dataset)} items in {time.perf_counter() - start:.3f} s")

In [None]:


# Fetch one training batch and print the shapes of its first two elements.
module = BoxKeypointsDataModule(BoxKeypointsDataset(JSON_PATH, IMAGE_DATASET_PATH), 2)
train_loader = module.train_dataloader()
batch = next(iter(train_loader))
print(batch[0].shape)
print(batch[1].shape)

In [None]:
model = KeypointDetector(heatmap_sigma=8)

# Shape check with a random (1, 3, 180, 180) input; no gradients are needed for this.
dummy_input = torch.rand((1, 3, 180, 180))
with torch.no_grad():
    output = model(dummy_input)
print(output.shape)
print(model)

module = BoxKeypointsDataModule(BoxKeypointsDataset(JSON_PATH, IMAGE_DATASET_PATH), 2)
batch = next(iter(module.train_dataloader()))
imgs, corner_keypoints, flap_keypoints = batch
# Spatial size of the first image (everything after the channel dim, assuming
# (C, H, W)); the target heatmaps are created at this size.
heatmap_shape = imgs[0].shape[1:]
print(heatmap_shape)
print(imgs.shape)

heatmaps = model.create_heatmap_batch(heatmap_shape, corner_keypoints)
flap_heatmaps = model.create_heatmap_batch(heatmap_shape, flap_keypoints)
print(heatmaps.shape)
print(flap_heatmaps.shape)
show_heatmap_overlay(imgs[0], heatmaps[0])
show_heatmap_overlay(imgs[0], flap_heatmaps[0])

In [None]:
# Seed the RNGs (including dataloader workers) for a reproducible run.
# Note: Trainer(deterministic=...) is not set here.
pl.seed_everything(2021, workers=True)
model = KeypointDetector(detect_flap_keypoints=False)
module = BoxKeypointsDataModule(BoxKeypointsDataset(JSON_PATH, IMAGE_DATASET_PATH), 2)

val_batches = len(module.val_dataloader())
print(val_batches)
train_batches = len(module.train_dataloader())
print(train_batches)

# Single-epoch run without GPUs (gpus=0), logging to Weights & Biases.
trainer = pl.Trainer(max_epochs=1, logger=wandb_logger, gpus=0)


In [None]:
# Train `model` on the data module for the epochs configured on the Trainer
# (max_epochs=1), logging through the WandbLogger that was passed to it.
trainer.fit(model, module)

In [None]:
batch = next(iter(module.train_dataloader()))

imgs, corner_keypoints, flap_keypoints = batch

# Inference: switch to eval mode so layers such as BatchNorm/Dropout (if the model
# has any) use inference behaviour, and feed the input on the model's device.
model.eval()
with torch.no_grad():
    predictions = model(imgs.to(model.device))
    # first image of the batch with (presumably) its first predicted heatmap channel;
    # moved to the CPU so it can be converted to numpy for plotting
    show_heatmap_overlay(imgs[0], predictions[0][0].cpu())