## This notebook demonstrates some of the package's functionality. For training, please use the train.py file!

In [None]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
import torch 
import numpy as np 
from keypoint_detection.utils.heatmap import gaussian_heatmap, generate_keypoints_heatmap, get_keypoints_from_heatmap
from keypoint_detection.models.detector import KeypointDetector
from keypoint_detection.data.datamodule import RandomSplitDataModule
from keypoint_detection.data.dataset import  KeypointsDataset, KeypointsDatasetPreloaded
from keypoint_detection.models.loss import bce_loss
from keypoint_detection.models.backbones.dilated_cnn import DilatedCnn


In [None]:
# Weights & Biases logger that is later handed to the Lightning trainer.
# Authenticate first by running `wandb login` in your terminal.
wandb_logger = WandbLogger(
    project="test-project",
    entity="airo-box-manipulation",
)

In [None]:
# Run `nvidia-smi` in a shell (IPython `!` escape) to list the GPUs on this machine.
!nvidia-smi

# Ask PyTorch whether CUDA is usable; as the last expression of the cell,
# its value is what the notebook displays as the cell output.
torch.cuda.is_available()

In [None]:
## Demonstration of why max(.) instead of sum should be used to combine keypoint heatmaps
# Summing merges 2 neighbouring keypoints into 1 blob (left plot), whereas the
# element-wise max keeps one peak per keypoint (right plot).
# NOTE(review): the original comment started a second argument against summing
# but was left unfinished.
img = gaussian_heatmap((32, 50), (8, 25), torch.Tensor([4]), "cpu")
img2 = gaussian_heatmap((32, 50), (12, 25), torch.Tensor([4]), "cpu")
print(img.max())  # max (at location of keypoint) should be 1!

f, axarr = plt.subplots(1, 2)
summed_ax, max_ax = axarr
summed_ax.imshow(img + img2)
max_ax.imshow(torch.max(img, img2))

In [None]:
# Dataset location: the image root directory and the JSON file that lives inside it.
IMAGE_DATASET_PATH = "/workspaces/box-manipulation/keypoint_detection/datasets/box_dataset2"
JSON_PATH = f"{IMAGE_DATASET_PATH}/dataset.json"

# Keypoint channel configuration; both values are passed to the dataset
# constructors as strings.
CHANNELS = "corner_keypoints"
CHANNEL_SIZE = "4"

In [None]:
def imshow(img):
    """
    Plot a Tensor as an image.

    Args:
        img: torch Tensor with the channel axis first; its axes are permuted
            with (1, 2, 0) so the channel axis comes last before plotting.

    Images are kept in the [0,1] range, although in theory [-1,1] should be used to whiten..
    """
    # detach + cpu so tensors that require grad or live on the GPU can be plotted as well
    # (Tensor.numpy() raises for those); both calls are no-ops for a plain CPU tensor.
    np_img = img.detach().cpu().numpy()
    # move the channel axis to the end: channel-first -> channel-last
    np_img = np.transpose(np_img, (1, 2, 0))
    plt.imshow(np_img)
    plt.show()


In [None]:
def show_heatmap_overlay(img, heatmap):
    """
    Plot an image Tensor and a heatmap Tensor on the same figure.

    Args:
        img: torch Tensor with the channel axis first; its axes are permuted
            with (1, 2, 0) so the channel axis comes last before plotting.
        heatmap: torch Tensor drawn on top of the image, mostly transparent.
    """
    _, ax = plt.subplots()  # create figure and axes; only the axes are needed
    # detach + cpu so tensors that require grad or live on the GPU can be plotted as well
    # (Tensor.numpy() raises for those); both calls are no-ops for a plain CPU tensor.
    np_img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))
    np_heatmap = heatmap.detach().cpu().numpy()
    # nearly opaque image underneath, faint heatmap on top
    ax.imshow(np_img, alpha=0.9)
    ax.imshow(np_heatmap, alpha=0.2)
    plt.show()

## Dataset strategies for minimizing memory footprint and runtime delay

In [None]:
## test caching influence

# Build the regular and the preloaded dataset variant from the exact same arguments
# so that the two can be compared against each other.
dataset_args = (JSON_PATH, IMAGE_DATASET_PATH, CHANNELS, CHANNEL_SIZE)
dataset = KeypointsDataset(*dataset_args)
preloaded_dataset = KeypointsDatasetPreloaded(*dataset_args)


In [None]:
# Fetch the first 20 samples from `dataset` one by one. The samples are not used
# (only the last one stays bound to `batch`); presumably this cell is run to
# observe how long item access takes for this dataset variant.
for i in range(20):
    batch = dataset[i]

In [None]:
# Fetch the first 20 samples from `preloaded_dataset` one by one. The samples are
# not used (only the last one stays bound to `batch_preloaded`); presumably this
# cell is run to observe how long item access takes for this dataset variant.
for i in range(20):
    batch_preloaded = preloaded_dataset[i]

In [None]:
# Show why the numpy arrays are kept in memory and not the torch Tensors.
import sys

# dtype of a sample as returned by the dataset vs. dtype of the image kept in memory
print(preloaded_dataset[0][0].dtype)
print(preloaded_dataset.preloaded_images[0].dtype)

# memory size of the torch tensor's storage, next to the expected number of bytes
torch_image_size = sys.getsizeof(preloaded_dataset[0][0].storage())
expected_torch_image_size = 256 * 256 * 3 * 4  # float32!
print(f" torch image size = {torch_image_size}")
print(f" expected torch image size = {expected_torch_image_size}")

# memory size of the numpy array, next to the expected number of bytes
numpy_image_size = preloaded_dataset.preloaded_images[0].nbytes  # uint8
print(numpy_image_size)
print(256 * 256 * 3 * 1)

In [None]:
## show output of batch

# The positional arguments after the dataset are presumably batch size, validation
# split ratio and number of workers (going by the keyword names used for
# RandomSplitDataModule elsewhere in this notebook) - confirm against its signature.
module = RandomSplitDataModule(
    KeypointsDataset(JSON_PATH, IMAGE_DATASET_PATH, CHANNELS, CHANNEL_SIZE),
    2,
    0.25,
    2,
)
batch = next(iter(module.train_dataloader()))
# print(batch)

## Layout of a batch: tuple (IMG, Keypoints)
##   - img is a 4-dimensional (B,C,W,H) tensor
##   - keypoints is a List with one entry per channel,
##     where each entry is of shape (B,N,2/3)
print(len(batch[0]))
img, keypoints = batch
print(img.shape)
print(len(keypoints))
first_channel_keypoints = keypoints[0]
print(first_channel_keypoints.shape)

## Show model input 

In [None]:
pl.seed_everything(2021, workers=True)  # deterministic run

# Detector, dataset and data module used for the demonstration below.
model = KeypointDetector(
    heatmap_sigma=2,
    maximal_gt_keypoint_pixel_distances="2",
    minimal_keypoint_extraction_pixel_distance=1,
    learning_rate=3e-4,
    backbone=DilatedCnn(),
    loss_function=bce_loss,
    keypoint_channels=CHANNELS,
    ap_epoch_freq=4,
    ap_epoch_start=10,
)
dataset = KeypointsDataset(JSON_PATH, IMAGE_DATASET_PATH, CHANNELS, CHANNEL_SIZE)
module = RandomSplitDataModule(
    dataset,
    batch_size=4,
    validation_split_ratio=0.1,
    num_workers=2,
)

# Forward pass on a random input to inspect the output shape and the architecture.
dummy_input = torch.rand((1, 3, 180, 180))
output = model(dummy_input)
print(output.shape)
print(model)

# Take one training batch and print the per-image size and the batch shape.
batch = next(iter(module.train_dataloader()))
imgs, keypoints = batch
image_size = imgs[0].shape[1:]
print(image_size)
print(imgs.shape)

# Ground-truth heatmaps for the first keypoint channel, overlaid on the first image.
# NOTE(review): `flap_heatmaps` is computed from the very same inputs as `heatmaps`,
# so both overlays are built from the same keypoints - confirm whether a different
# channel was intended.
heatmaps = model.create_heatmap_batch(image_size, keypoints[0])
flap_heatmaps = model.create_heatmap_batch(image_size, keypoints[0])
print(heatmaps.shape)
first_img = imgs[0]
show_heatmap_overlay(first_img, heatmaps[0])
show_heatmap_overlay(first_img, flap_heatmaps[0])

## Train model 

In [None]:
pl.seed_everything(2021, workers=True)  # deterministic run

# Fresh detector, dataset and data module for the training run.
model = KeypointDetector(
    heatmap_sigma=2,
    maximal_gt_keypoint_pixel_distances="2",
    minimal_keypoint_extraction_pixel_distance=1,
    learning_rate=3e-4,
    backbone=DilatedCnn(),
    loss_function=bce_loss,
    keypoint_channels=CHANNELS,
    ap_epoch_freq=4,
    ap_epoch_start=10,
)
dataset = KeypointsDataset(JSON_PATH, IMAGE_DATASET_PATH, CHANNELS, CHANNEL_SIZE)
module = RandomSplitDataModule(
    dataset,
    batch_size=4,
    validation_split_ratio=0.1,
    num_workers=2,
)

# number of batches in the validation and the training dataloader!
print(len(module.val_dataloader()))
print(len(module.train_dataloader()))

# Single-epoch trainer that logs to Weights & Biases; gpus=0 requests no GPU.
trainer = Trainer(max_epochs=1, logger=wandb_logger, gpus=0)


In [None]:
%%wandb
# The cell magic above has to stay on the first line of the cell.
# Fit `model` on the data provided by `module` using the configured trainer.
trainer.fit(model, module)

## Take a look at the model output

In [None]:
# Take one training batch and compare the ground-truth heatmap with the model's prediction.
batch = next(iter(module.train_dataloader()))
imgs, keypoints = batch

# switch the model to evaluation mode before predicting
model.eval()
with torch.no_grad():  # no gradients are needed for visualisation
    predictions = model(imgs)
    heatmaps = model.create_heatmap_batch(imgs[0].shape[1:], keypoints[0])
    first_img = imgs[0]
    # ground-truth heatmap of the first image ...
    show_heatmap_overlay(first_img, heatmaps[0])
    # ... followed by the first predicted channel for that same image
    show_heatmap_overlay(first_img, predictions[0][0])