diff --git a/.gitignore b/.gitignore index 4c7b546..36062cb 100644 --- a/.gitignore +++ b/.gitignore @@ -4,8 +4,10 @@ .vscode/** datasets/ +scripts/dummy_dataset/ **wandb/ lightning_logs** **.ckpt **/checkpoints/ +build/** diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8959224..5a57043 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: files: \.py$ - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort name: isort - sort imports diff --git a/README.md b/README.md index 87eed6e..883dd4a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@

Pytorch Keypoint Detection

-This repo contains a Python package for 2D keypoint detection using [Pytorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) and [wandb](https://docs.wandb.ai/). Keypoints are trained using Gaussian Heatmaps, as in [Jakab et Al.](https://proceedings.neurips.cc/paper/2018/hash/1f36c15d6a3d18d52e8d493bc8187cb9-Abstract.html) or [Centernet](https://github.com/xingyizhou/CenterNet). +A Framework for keypoint detection using [Pytorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) and [wandb](https://docs.wandb.ai/). Keypoints are trained with Gaussian Heatmaps, as in [Jakab et Al.](https://proceedings.neurips.cc/paper/2018/hash/1f36c15d6a3d18d52e8d493bc8187cb9-Abstract.html) or [Centernet](https://github.com/xingyizhou/CenterNet). This package is been used for research at the [AI and Robotics](https://airo.ugent.be/projects/computervision/) research group at Ghent University. You can see some applications below: The first image shows how this package is used to detect corners of cardboard boxes, in order to close the box with a robot. The second example shows how it is used to detect a varying number of flowers.
@@ -10,15 +10,16 @@ This package is been used for research at the [AI and Robotics](https://airo.uge ## Main Features +- The detector can deal with an **arbitrary number of keypoint channels**, that can contain **a varying amount of keypoints**. You can easily configure which keypoint types from the COCO dataset should be mapped onto the different channels of the keypoint detector. This flexibility allows to e.g. combine different semantic locations that have symmetries onto the same channel to overcome this ambiguity. +- We use the standard **COCO dataset format**. -- This package contains **different backbones** (Unet-like, dilated CNN, Unet-like with pretrained ConvNeXt encoder). Furthermore you can easily add new backbones or loss functions. The head of the keypoint detector is a single CNN layer. -- The package uses the often-used **COCO dataset format**. -- The detector can deal with an **arbitrary number of keypoint channels**, that can contain **a varying amount of keypoints**. You can easily configure which keypoint types from the COCO dataset should be mapped onto the different channels of the keypoint detector. -- The package contains an implementation of the Average Precision metric for keypoint detection. -- Extensive **logging to wandb is provided**: The loss for each channel is logged, together with the AP metrics for all specified treshold distances. Furthermore, the raw heatmaps, detected keypoints and ground truth heatmaps are logged at every epoch for the first batch to provide insight in the training dynamics and to verify all data processing is as desired. +- **different backbones** can be used (Unet-like, dilated CNN, Unet-like with pretrained encoders). Furthermore you can easily add new backbones or loss functions. The head of the keypoint detector is a single CNN layer. + +- The package contains an implementation of the Average Precision metric for keypoint detection. The threshold distance for classification of detections as FP or TP is based on L2 distance between the keypoints and ground truth keypoints. +- Extensive **logging to wandb is provided**: The train/val loss for each channel is logged, together with the AP metrics for all specified treshold distances and all channels. Furthermore, the raw heatmaps, detected keypoints and ground truth heatmaps are logged to provide insight in the training dynamics and to verify all data processing is as desired. - All **hyperparameters are configurable** using a python argumentparser or wandb sweeps. -note: this is the second version of the package, for the older version that used a custom dataset format, see the github releases. +note: this package is still under development and we make no commitment on backwards compatibility nor reproducibility on the main branch. If you need this, it is best to pin a single commit. TODO: add integration example. @@ -30,6 +31,20 @@ TODO: add integration example. - run `wandb login` to set up your wandb account. - you are now ready to start training. + +## Training + +To train a keypoint detector, run the `keypoint-detection train` CLI with the appropriate arguments. +To create your own configuration: run `keypoint-detection train -h` to see all parameter options and their documentation. + +A good starting point could be the bash script `bash test/integration_test.sh` to test on the provided test dataset, which contains 4 images. You should see the loss going down consistently until the detector has completely overfit the train set and the loss is around the entropy of the ground truth heatmaps (if you selected the default BCE loss). + +### Wandb sweeps +Alternatively, you can create a sweep on [wandb](https://wandb.ai) and to then start a (number of) wandb agent(s). This is very useful for running multiple configurations (hparam search, testing on multiple datasets,..) + +### Loading pretrained weights +If you want to load pretrained keypoint detector weights, you can specify the wandb artifact of the checkpoint in the training parameters: `keypoint-detection train ..... -wandb_checkpoint_artifact `. This can be used for example to finetune on real data after pretraining on synthetic data. + ## Dataset This package used the [COCO format](https://cocodataset.org/#format-data) for keypoint annotation and expects a dataset with the following structure: @@ -37,29 +52,50 @@ This package used the [COCO format](https://cocodataset.org/#format-data) for ke dataset/ images/ ... - .json : a COCO-formatted keypoint annotation file. + .json : a COCO-formatted keypoint annotation file with filepaths relative to its parent directory. ``` For an example, see the `test_dataset` at `test/test_dataset`. ### Labeling -If you want to label data, we provide integration with the [CVAT](https://github.com/opencv/cvat) labeling tool: You can annotate your data and export it in their custom format, which can then be converted to COCO format. Take a look [here](labeling/Readme.md) for more information on this workflow and an example. To visualize a given dataset, you can use the `keypoint_detection/utils/visualization.py` script. +If you want to label data, we use[CVAT](https://github.com/opencv/cvat) labeling tool. The flow and the code to create COCO keypoints datasets is all available in the [airo-dataset-tools](https://github.com/airo-ugent/airo-mono/tree/main) package. -## Training +It is best to label your data with floats that represent the subpixel location of the keypoints. This allows for more precise resizing of the images later on. The keypoint detector cast them to ints before training to obtain the pixel they belong to (it does not support sub-pixel detections). + +## Evaluation +TODO +`keypoint-detection eval --help` -There are 2 ways to train the keypoint detector: +## Fiftyone viewer +TODO +`scripts/fiftyone_viewer` -- The first is to run the `train.py` script with the appropriate arguments. e.g. from the root folder of this repo, you can run the bash script `bash test/integration_test.sh` to test on the provided test dataset, which contains 4 images. You should see the loss going down consistently until the detector has completely overfit the train set and the loss is around the entropy of the ground truth heatmaps (if you selected the default BCE loss). +## Using a trained model for Inference +During training Pytorch Lightning will have saved checkpoints. See `scripts/checkpoint_inference.py` for a simple example to run inference with a checkpoint. +For benchmarking the inference (or training), see `scripts/benchmark.py`. -- The second method is to create a sweep on [wandb](https://wandb.ai) and to then start a wandb agent from the correct relative location. -A minimal sweep example is given in `test/configuration.py`. The same content should be written to a yaml file according to the wandb format. The sweep can be started by running `wandb agent ` from your CLI. +## Metrics + +TO calculate AP, precision or recall, the detections need to be classified into False Positives and False negatives as for object detection or instance segmentation. + +This package simply uses a number of euclidian pixel distance thresholds. You can set the euclidian distances for which you want to calculate the metrics in the hyperparameters. + +Pixel perfect keypoints have a pixel distance of 0, so if you want a metric for pixel-perfect keypoints you should add a threshold distance of 0. + +Usually it is best to calculate the real-world deviations (in cm) that are acceptable and then determine the threshold(s) (in pixels) you are interested in. + +In general a lower threshold will result in a lower metric. The size of this gap is determined by the 'ambiguity' of your dataset and/or the accuracy of your labels. + +#TODO: add a figure to illustrate this. + + +We do not use OKS as in COCO for the following reasons: +1. it requires bbox annotations, which are not always required for keypoint detection itself and represent additional label effort. +2. More importantly, in robotics the size of an object does not always correlate with the required precision. If a large and a small mug stand on a table, they require the same precise localisation of keypoints for a robot to grasp them even though their apparent size is different. +3. (you need to estimate label variance, though you could simply set k=1 and skip this part) -To create your own configuration: run `python train.py -h` to see all parameter options and their documentation. -## Using a trained model (Inference) -During training Pytorch Lightning will have saved checkpoints. See `scripts/checkpoint_inference.py` for a simple example to run inference with a checkpoint. -For benchmarking the inference (or training), see `scripts/benchmark.py`. ## Development info - formatting and linting is done using [pre-commit](https://pre-commit.com/) @@ -67,10 +103,35 @@ For benchmarking the inference (or training), see `scripts/benchmark.py`. ## Note on performance -- Keep in mind that the Average Precision is a very expensive operation, it can easily take as long to calculate the AP of a .1 data split as it takes to train on the remaining 90% of the data. Therefore it makes sense to use the metric sparsely. The AP will always be calculated at the final epoch, so for optimal train performance (w/o intermediate feedback), you can e.g. set the `ap_epoch_start` parameter to your max number of epochs + 1. +- Keep in mind that calculating the Average Precision is expensive operation, it can easily take as long to calculate the AP of a .1 data split as it takes to train on the remaining 90% of the data. Therefore it makes sense to use the metric sparsely, for which hyperparameters are available. The AP will always be calculated at the final epoch. + +## Note on top-down vs. bottom-up keypoint detection. +There are 2 ways to do keypoint detection when multiple instances are present in an image: +1. first do instance detection and then detect keypoints on a crop of the bbox for each instance +2. detect keypoints on the full image. + +Option 1 suffers from compounding errors (if the instance is not detected, no keypoints will be detected) and/or requires you to train (and hence label) an object detector. +Option 2 can have lower performance for the keypoints (more 'noise' in the image that can distract the detector) and if you have multiple keypoints / instance as well as multiple instances per image, you need to do keypoint association. + +This repo is somewhat agnostic to that choice. +For 1: crop your dataset upfront and train the detector on those crops, at inference: chain the object detector and the keypoint detector. +for 2: If you can do the association manually, simply do it after inference. However this repo does not offer learning the associations as in the [Part Affinity Fields]() paper. + ## Rationale: TODO - why this repo? - why not label keypoints as bboxes and use YOLO/Detectron2? - .. + +# Citing this project + +You are invited to cite the following publication if you use this keypoint detector in your research: +``` +@inproceedings{lips2022synthkeypoints, + title={Learning Keypoints from Synthetic Data for Robotic Cloth Folding}, + author={Lips, Thomas and De Gusseme, Victor-Louis and others}, + journal={2nd workshop on Representing and Manipulating Deformable Objects - ICRA}, + year={2022} +} +``` diff --git a/environment.yaml b/environment.yaml index e3fd1f1..91fc4ae 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,11 +1,12 @@ name: keypoint-detection # to update an existing environment: conda env update -n --file channels: - pytorch + - nvidia - conda-forge dependencies: - - cudatoolkit=11.3 - python=3.9 - - pytorch + - pytorch=1.13 + - pytorch-cuda=11.7 - torchvision - pip - pip: diff --git a/keypoint_detection/data/coco_dataset.py b/keypoint_detection/data/coco_dataset.py index 733990b..51ec227 100644 --- a/keypoint_detection/data/coco_dataset.py +++ b/keypoint_detection/data/coco_dataset.py @@ -1,5 +1,6 @@ import argparse import json +import math import typing from collections import defaultdict from pathlib import Path @@ -42,21 +43,23 @@ def add_argparse_args(parent_parser: argparse.ArgumentParser) -> argparse.Argume """ parser = parent_parser.add_argument_group("COCOkeypointsDataset") parser.add_argument( - "--detect_non_visible_keypoints", - default=True, - type=str, - help="detect keypoints with visibility flag = 1? default = True", + "--detect_only_visible_keypoints", + dest="detect_only_visible_keypoints", + default=False, + action="store_true", + help="If set, only keypoints with flag > 1.0 will be used.", ) + return parent_parser def __init__( self, json_dataset_path: str, keypoint_channel_configuration: list[list[str]], - detect_non_visible_keypoints: bool = True, + detect_only_visible_keypoints: bool = True, transform: A.Compose = None, imageloader: ImageLoader = None, - **kwargs + **kwargs, ): super().__init__(imageloader) @@ -65,7 +68,9 @@ def __init__( self.dataset_dir_path = self.dataset_json_path.parent # assume paths in JSON are relative to this directory! self.keypoint_channel_configuration = keypoint_channel_configuration - self.detect_non_visible_keypoints = detect_non_visible_keypoints + self.detect_only_visible_keypoints = detect_only_visible_keypoints + + print(f"{detect_only_visible_keypoints=}") self.random_crop_transform = None self.transform = transform @@ -88,6 +93,9 @@ def __getitem__(self, index) -> Tuple[torch.Tensor, IMG_KEYPOINTS_TYPE]: image_path = self.dataset_dir_path / self.dataset[index][0] image = self.image_loader.get_image(str(image_path), index) + # remove a-channel if needed + if image.shape[2] == 4: + image = image[..., :3] keypoints = self.dataset[index][1] @@ -95,6 +103,17 @@ def __getitem__(self, index) -> Tuple[torch.Tensor, IMG_KEYPOINTS_TYPE]: transformed = self.transform(image=image, keypoints=keypoints) image, keypoints = transformed["image"], transformed["keypoints"] + # convert all keypoints to integers values. + # COCO keypoints can be floats if they specify the exact location of the keypoint (e.g. from CVAT) + # even though COCO format specifies zero-indexed integers (i.e. every keypoint in the [0,1]x [0.1] pixel box becomes (0,0) + # we convert them to ints here, as the heatmap generation will add a 0.5 offset to the keypoint location to center it in the pixel + # the distance metrics also operate on integer values. + + # so basically from here on every keypoint is an int that represents the pixel-box in which the keypoint is located. + keypoints = [ + [[math.floor(keypoint[0]), math.floor(keypoint[1])] for keypoint in channel_keypoints] + for channel_keypoints in keypoints + ] image = self.image_to_tensor_transform(image) return image, keypoints @@ -169,10 +188,12 @@ def is_keypoint_visible(self, keypoint: COCO_KEYPOINT_TYPE) -> bool: Returns: bool: True if current keypoint is considered visible according to the dataset configuration, else False """ - minimal_flag = 0 - if not self.detect_non_visible_keypoints: - minimal_flag = 1 - return keypoint[2] > minimal_flag + if self.detect_only_visible_keypoints: + # filter out occluded keypoints with flag 1.0 + return keypoint[2] > 1.5 + else: + # filter out non-labeled keypoints with flag 0.0 + return keypoint[2] > 0.5 @staticmethod def split_list_in_keypoints(list_to_split: List[COCO_KEYPOINT_TYPE]) -> List[List[COCO_KEYPOINT_TYPE]]: diff --git a/keypoint_detection/data/coco_parser.py b/keypoint_detection/data/coco_parser.py index a35c03e..192809e 100644 --- a/keypoint_detection/data/coco_parser.py +++ b/keypoint_detection/data/coco_parser.py @@ -51,6 +51,8 @@ class CocoKeypointAnnotation(BaseModel): image_id: ImageID num_keypoints: Optional[int] + # COCO keypoints can be floats if they specify the exact location of the keypoint (e.g. from CVAT) + # even though COCO format specifies zero-indexed integers (i.e. every keypoint in the [0,1]x [0.1] pixel box becomes (0,0) keypoints: List[float] # TODO: add checks. diff --git a/keypoint_detection/data/datamodule.py b/keypoint_detection/data/datamodule.py index 27bc1c9..36400b3 100644 --- a/keypoint_detection/data/datamodule.py +++ b/keypoint_detection/data/datamodule.py @@ -5,7 +5,7 @@ import numpy as np import pytorch_lightning as pl import torch -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Subset from keypoint_detection.data.augmentations import MultiChannelKeypointsCompose from keypoint_detection.data.coco_dataset import COCOKeypointsDataset @@ -31,7 +31,7 @@ def add_argparse_args(parent_parser: argparse.ArgumentParser) -> argparse.Argume "--json_validation_dataset_path", type=str, help="Absolute path to the json file that defines the validation dataset according to the COCO format. \ - If not specified, the train dataset will be split to create a validation set.", + If not specified, the train dataset will be split to create a validation set if there is one.", ) parser.add_argument( "--json_test_dataset_path", @@ -47,21 +47,24 @@ def add_argparse_args(parent_parser: argparse.ArgumentParser) -> argparse.Argume def __init__( self, - json_dataset_path: str, keypoint_channel_configuration: list[list[str]], - batch_size: int, - validation_split_ratio: float, - num_workers: int, + batch_size: int = 16, + validation_split_ratio: float = 0.25, + num_workers: int = 2, + json_dataset_path: str = None, json_validation_dataset_path: str = None, json_test_dataset_path=None, augment_train: bool = False, - **kwargs + **kwargs, ): super().__init__() self.batch_size = batch_size self.num_workers = num_workers self.augment_train = augment_train - self.train_dataset = COCOKeypointsDataset(json_dataset_path, keypoint_channel_configuration, **kwargs) + + self.train_dataset = None + if json_dataset_path: + self.train_dataset = COCOKeypointsDataset(json_dataset_path, keypoint_channel_configuration, **kwargs) self.validation_dataset = None self.test_dataset = None @@ -71,65 +74,90 @@ def __init__( json_validation_dataset_path, keypoint_channel_configuration, **kwargs ) else: - self.train_dataset, self.validation_dataset = KeypointsDataModule._split_dataset( - self.train_dataset, validation_split_ratio - ) + if self.train_dataset is not None: + print(f"splitting the train set to create a validation set with ratio {validation_split_ratio} ") + self.train_dataset, self.validation_dataset = KeypointsDataModule._split_dataset( + self.train_dataset, validation_split_ratio + ) if json_test_dataset_path: self.test_dataset = COCOKeypointsDataset(json_test_dataset_path, keypoint_channel_configuration, **kwargs) + # create the transforms if needed and set them to the datasets if augment_train: - img_size = self.train_dataset[0][0].shape[1] # assume rectangular! + print("Augmenting the training dataset!") + img_height, img_width = self.train_dataset[0][0].shape[1], self.train_dataset[0][0].shape[2] + aspect_ratio = img_width / img_height train_transform = MultiChannelKeypointsCompose( [ - A.ColorJitter(), - A.RandomRotate90(), - A.HorizontalFlip(), - A.RandomResizedCrop(img_size, img_size, scale=(0.8, 1.0), ratio=(0.95, 1.0)), + A.ColorJitter(p=0.8), + A.RandomBrightnessContrast(p=0.8), + A.RandomResizedCrop( + img_height, img_width, scale=(0.8, 1.0), ratio=(0.9 * aspect_ratio, 1.1 * aspect_ratio), p=1.0 + ), + A.GaussianBlur(p=0.2, blur_limit=(3, 3)), + A.Sharpen(p=0.2), + A.GaussNoise(), ] ) - self.train_dataset.transform = train_transform + if isinstance(self.train_dataset, COCOKeypointsDataset): + self.train_dataset.transform = train_transform + elif isinstance(self.train_dataset, Subset): + # if the train dataset is a subset, we need to set the transform to the underlying dataset + # otherwise the transform will not be applied.. + assert isinstance(self.train_dataset.dataset, COCOKeypointsDataset) + self.train_dataset.dataset.transform = train_transform @staticmethod def _split_dataset(dataset, validation_split_ratio): validation_size = int(validation_split_ratio * len(dataset)) train_size = len(dataset) - validation_size train_dataset, validation_dataset = torch.utils.data.random_split(dataset, [train_size, validation_size]) + print(f"train size: {len(train_dataset)}") + print(f"validation size: {len(validation_dataset)}") return train_dataset, validation_dataset def train_dataloader(self): + # usually need to seed workers for reproducibility + # cf. https://pytorch.org/docs/stable/notes/randomness.html + # but PL does for us in their seeding function: + # https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility + + if self.train_dataset is None: + return None + dataloader = DataLoader( self.train_dataset, self.batch_size, shuffle=True, num_workers=self.num_workers, collate_fn=COCOKeypointsDataset.collate_fn, - pin_memory=True, + pin_memory=True, # usually a little faster ) return dataloader def val_dataloader(self): - def seed_worker(worker_id): - worker_seed = torch.initial_seed() % 2**32 - np.random.seed(worker_seed) - random.seed(worker_seed) - - g = torch.Generator() - g.manual_seed(0) - # num workers to zero to avoid non-reproducibility bc of random seeds for workers + # usually need to seed workers for reproducibility # cf. https://pytorch.org/docs/stable/notes/randomness.html + # but PL does for us in their seeding function: + # https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility + + if self.validation_dataset is None: + return None + dataloader = DataLoader( self.validation_dataset, self.batch_size, shuffle=False, num_workers=self.num_workers, - worker_init_fn=seed_worker, - generator=g, collate_fn=COCOKeypointsDataset.collate_fn, ) return dataloader def test_dataloader(self): + + if self.test_dataset is None: + return None dataloader = DataLoader( self.test_dataset, min(4, self.batch_size), # 4 as max for better visualization in wandb. @@ -138,3 +166,9 @@ def test_dataloader(self): collate_fn=COCOKeypointsDataset.collate_fn, ) return dataloader + + +def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/keypoint_detection/models/backbones/backbone_factory.py b/keypoint_detection/models/backbones/backbone_factory.py index e959df3..05b60bd 100644 --- a/keypoint_detection/models/backbones/backbone_factory.py +++ b/keypoint_detection/models/backbones/backbone_factory.py @@ -4,14 +4,23 @@ from keypoint_detection.models.backbones.base_backbone import Backbone from keypoint_detection.models.backbones.convnext_unet import ConvNeXtUnet from keypoint_detection.models.backbones.dilated_cnn import DilatedCnn -from keypoint_detection.models.backbones.maxvit_unet import MaxVitUnet +from keypoint_detection.models.backbones.maxvit_unet import MaxVitPicoUnet, MaxVitUnet +from keypoint_detection.models.backbones.mobilenetv3 import MobileNetV3 from keypoint_detection.models.backbones.s3k import S3K from keypoint_detection.models.backbones.unet import Unet class BackboneFactory: # TODO: how to auto-register with __init__subclass over multiple files? - registered_backbone_classes: List[Backbone] = [Unet, ConvNeXtUnet, MaxVitUnet, S3K, DilatedCnn] + registered_backbone_classes: List[Backbone] = [ + Unet, + ConvNeXtUnet, + MaxVitUnet, + MaxVitPicoUnet, + S3K, + DilatedCnn, + MobileNetV3, + ] @staticmethod def create_backbone(backbone_type: str, **kwargs) -> Backbone: diff --git a/keypoint_detection/models/backbones/convnext_unet.py b/keypoint_detection/models/backbones/convnext_unet.py index f9b4fa3..d454961 100644 --- a/keypoint_detection/models/backbones/convnext_unet.py +++ b/keypoint_detection/models/backbones/convnext_unet.py @@ -24,28 +24,28 @@ class UpSamplingBlock(nn.Module): def __init__(self, n_channels_in, n_skip_channels_in, n_channels_out, kernel_size): super().__init__() - self.upsample = nn.UpsamplingBilinear2d(scale_factor=2) - self.conv_reduce = nn.Conv2d( - in_channels=n_channels_in, out_channels=n_skip_channels_in, kernel_size=1, bias=False, padding="same" - ) - self.conv = nn.Conv2d( - in_channels=n_skip_channels_in * 2, + + self.conv1 = nn.Conv2d( + in_channels=n_skip_channels_in + n_channels_in, out_channels=n_channels_out, kernel_size=kernel_size, bias=False, padding="same", ) - self.norm = nn.BatchNorm2d(n_channels_out) - self.relu = nn.ReLU() + + self.norm1 = nn.BatchNorm2d(n_channels_out) + self.relu1 = nn.ReLU() def forward(self, x, x_skip): - x = self.upsample(x) - x = self.conv_reduce(x) + # bilinear is not deterministic, use nearest neighbor instead + x = nn.functional.interpolate(x, scale_factor=2.0) x = torch.cat([x, x_skip], dim=1) - x = self.conv(x) - x = self.norm(x) - x = self.relu(x) + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + # second conv as in original UNet upsampling block decreases performance + # probably because I was using a small dataset that did not have enough data to learn the extra parameters return x @@ -60,12 +60,13 @@ class ConvNeXtUnet(Backbone): nano -> 17M params (but only twice as slow) - (head) - stem final_up (bilinear 4x) - res1 ---> 1/4 decode3 - res2 ---> 1/8 decode2 - res3 ---> 1/16 decode1 - res4 ---1/32----| + input final_conv --- head + stem upsampling + upsamping + res1 ---> 1/4 decode3 + res2 ---> 1/8 decode2 + res3 ---> 1/16 decode1 + res4 ---1/32----| """ def __init__(self, **kwargs): @@ -82,17 +83,18 @@ def __init__(self, **kwargs): block = UpSamplingBlock(channels_in, skip_channels_in, skip_channels_in, 3) self.decoder_blocks.append(block) - self.final_upsampling_block = nn.Sequential( - nn.UpsamplingBilinear2d(scale_factor=4), nn.Conv2d(skip_channels_in, skip_channels_in, 3, padding="same") - ) + self.final_conv = nn.Conv2d(skip_channels_in + 3, skip_channels_in, 3, padding="same") def forward(self, x): + x_orig = torch.clone(x) features = self.encoder(x) x = features.pop() for block in self.decoder_blocks: x = block(x, features.pop()) - x = self.final_upsampling_block(x) + x = nn.functional.interpolate(x, scale_factor=4.0) + x = torch.cat([x, x_orig], dim=1) + x = self.final_conv(x) return x def get_n_channels_out(self): diff --git a/keypoint_detection/models/backbones/maxvit_unet.py b/keypoint_detection/models/backbones/maxvit_unet.py index fe4173d..12d302c 100644 --- a/keypoint_detection/models/backbones/maxvit_unet.py +++ b/keypoint_detection/models/backbones/maxvit_unet.py @@ -28,53 +28,72 @@ class MaxVitUnet(Backbone): For now only 256 is supported so input sizes are restricted to 256,512,... - (head) - stem --- 1/2 --> final_up (bilinear 2x) - stage 1 --- 1/4 --> decode3 - stage 2 --- 1/8 --> decode2 - stage 3 --- 1/16 --> decode1 - stage 4 ---1/32----| + + orig --- 1/1 --> ---> (head) + stem --- 1/2 --> decode4 + stage 1 --- 1/4 --> decode3 + stage 2 --- 1/8 --> decode2 + stage 3 --- 1/16 --> decode1 + stage 4 ---1/32----| """ - # manually gathered for maxvit_nano_rw_256 - feature_config = [ + # 15M params + FEATURE_CONFIG = [ {"down": 2, "channels": 64}, {"down": 4, "channels": 64}, {"down": 8, "channels": 128}, {"down": 16, "channels": 256}, {"down": 32, "channels": 512}, ] + MODEL_NAME = "maxvit_nano_rw_256" feature_layers = ["stem", "stages.0", "stages.1", "stages.2", "stages.3"] def __init__(self, **kwargs) -> None: super().__init__() - self.encoder = timm.create_model("maxvit_nano_rw_256", pretrained=True, num_classes=0) # 15M params + self.encoder = timm.create_model(self.MODEL_NAME, pretrained=True, num_classes=0) self.feature_extractor = create_feature_extractor(self.encoder, self.feature_layers) self.decoder_blocks = nn.ModuleList() - for config_skip, config_in in zip(self.feature_config, self.feature_config[1:]): + for config_skip, config_in in zip(self.FEATURE_CONFIG, self.FEATURE_CONFIG[1:]): block = UpSamplingBlock(config_in["channels"], config_skip["channels"], config_skip["channels"], 3) self.decoder_blocks.append(block) - self.final_upsampling_block = nn.Sequential( - nn.UpsamplingBilinear2d(scale_factor=2), - nn.Conv2d(self.feature_config[0]["channels"], self.feature_config[0]["channels"], 3, padding="same"), + self.final_conv = nn.Conv2d( + self.FEATURE_CONFIG[0]["channels"], self.FEATURE_CONFIG[0]["channels"], 3, padding="same" + ) + self.final_upsampling_block = UpSamplingBlock( + self.FEATURE_CONFIG[0]["channels"], 3, self.FEATURE_CONFIG[0]["channels"], 3 ) def forward(self, x): + orig_x = torch.clone(x) features = list(self.feature_extractor(x).values()) x = features.pop(-1) for block in self.decoder_blocks[::-1]: x = block(x, features.pop(-1)) - x = self.final_upsampling_block(x) + + # x = nn.functional.interpolate(x, scale_factor=2) + # x = self.final_conv(x) + x = self.final_upsampling_block(x, orig_x) return x def get_n_channels_out(self): - return self.feature_config[0]["channels"] + return self.FEATURE_CONFIG[0]["channels"] + + +class MaxVitPicoUnet(MaxVitUnet): + MODEL_NAME = "maxvit_rmlp_pico_rw_256" # 7.5M params. + FEATURE_CONFIG = [ + {"down": 2, "channels": 32}, + {"down": 4, "channels": 32}, + {"down": 8, "channels": 64}, + {"down": 16, "channels": 128}, + {"down": 32, "channels": 256}, + ] if __name__ == "__main__": - # model = timm.create_model("maxvit_rmlp_pico_rw_256") - model = timm.create_model("maxvit_nano_rw_256") + model = timm.create_model("maxvit_rmlp_pico_rw_256") + # model = timm.create_model("maxvit_nano_rw_256") feature_extractor = create_feature_extractor(model, ["stem", "stages.0", "stages.1", "stages.2", "stages.3"]) x = torch.zeros((1, 3, 256, 256)) features = list(feature_extractor(x).values()) @@ -86,3 +105,8 @@ def get_n_channels_out(self): config = {"down": 256 // x.shape[2], "channels": x.shape[1]} feature_config.append(config) print(f"{feature_config=}") + + model = MaxVitPicoUnet() + x = torch.zeros((1, 3, 256, 256)) + y = model(x) + print(f"{y.shape=}") diff --git a/keypoint_detection/models/backbones/mobilenetv3.py b/keypoint_detection/models/backbones/mobilenetv3.py new file mode 100644 index 0000000..fe830ca --- /dev/null +++ b/keypoint_detection/models/backbones/mobilenetv3.py @@ -0,0 +1,42 @@ +"""A MobileNetV3-based backbone. +""" +import timm +import torch +import torch.nn as nn + +from keypoint_detection.models.backbones.base_backbone import Backbone +from keypoint_detection.models.backbones.convnext_unet import UpSamplingBlock + + +class MobileNetV3(Backbone): + """ + Pretrained MobileNetV3 using the large_100 model with 3.4M parameters. + """ + + def __init__(self, **kwargs): + super().__init__() + self.encoder = timm.create_model("mobilenetv3_large_100", pretrained=True, features_only=True) + self.decoder_blocks = nn.ModuleList() + for i in range(1, len(self.encoder.feature_info.info)): + channels_in, skip_channels_in = ( + self.encoder.feature_info.info[-i]["num_chs"], + self.encoder.feature_info.info[-i - 1]["num_chs"], + ) + block = UpSamplingBlock(channels_in, skip_channels_in, skip_channels_in, 3) + self.decoder_blocks.append(block) + + self.final_conv = nn.Conv2d(skip_channels_in, skip_channels_in, 3, padding="same") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + features = self.encoder(x) + + x = features.pop() + for block in self.decoder_blocks: + x = block(x, features.pop()) + x = nn.functional.interpolate(x, scale_factor=2) + x = self.final_conv(x) + + return x + + def get_n_channels_out(self): + return self.encoder.feature_info.info[0]["num_chs"] diff --git a/keypoint_detection/models/backbones/unet.py b/keypoint_detection/models/backbones/unet.py index 48581f8..41b0f1e 100644 --- a/keypoint_detection/models/backbones/unet.py +++ b/keypoint_detection/models/backbones/unet.py @@ -43,7 +43,6 @@ def forward(self, x): class UpSamplingBlock(nn.Module): def __init__(self, n_channels_in, n_channels_out, kernel_size): super().__init__() - self.upsample = nn.UpsamplingBilinear2d(scale_factor=2) self.conv = nn.Conv2d( in_channels=n_channels_in * 2, out_channels=n_channels_out, @@ -55,7 +54,7 @@ def __init__(self, n_channels_in, n_channels_out, kernel_size): self.relu = nn.ReLU() def forward(self, x, x_skip): - x = self.upsample(x) + x = nn.functional.interpolate(x, scale_factor=2) x = torch.cat([x, x_skip], dim=1) x = self.conv(x) x = self.relu(x) diff --git a/keypoint_detection/models/detector.py b/keypoint_detection/models/detector.py index 3b10347..f62b62a 100644 --- a/keypoint_detection/models/detector.py +++ b/keypoint_detection/models/detector.py @@ -9,13 +9,12 @@ from keypoint_detection.models.backbones.base_backbone import Backbone from keypoint_detection.models.metrics import DetectedKeypoint, Keypoint, KeypointAPMetrics -from keypoint_detection.utils.heatmap import ( - BCE_loss, - compute_keypoint_probability, - create_heatmap_batch, - get_keypoints_from_heatmap, +from keypoint_detection.utils.heatmap import BCE_loss, create_heatmap_batch, get_keypoints_from_heatmap_batch_maxpool +from keypoint_detection.utils.visualization import ( + get_logging_label_from_channel_configuration, + visualize_predicted_heatmaps, + visualize_predicted_keypoints, ) -from keypoint_detection.utils.visualization import visualize_predictions class KeypointDetector(pl.LightningModule): @@ -121,10 +120,13 @@ def __init__( # parse the gt pixel distances if isinstance(maximal_gt_keypoint_pixel_distances, str): maximal_gt_keypoint_pixel_distances = [ - float(val) for val in maximal_gt_keypoint_pixel_distances.strip().split(" ") + int(val) for val in maximal_gt_keypoint_pixel_distances.strip().split(" ") ] self.maximal_gt_keypoint_pixel_distances = maximal_gt_keypoint_pixel_distances + self.ap_training_metrics = [ + KeypointAPMetrics(self.maximal_gt_keypoint_pixel_distances) for _ in self.keypoint_channel_configuration + ] self.ap_validation_metrics = [ KeypointAPMetrics(self.maximal_gt_keypoint_pixel_distances) for _ in self.keypoint_channel_configuration ] @@ -256,11 +258,19 @@ def shared_step(self, batch, batch_idx, include_visualization_data_in_result_dic def training_step(self, train_batch, batch_idx): log_images = batch_idx == 0 and self.current_epoch > 0 - result_dict = self.shared_step(train_batch, batch_idx, include_visualization_data_in_result_dict=log_images) + should_log_ap = self.is_ap_epoch() and batch_idx < 20 # limit AP calculation to first 20 batches to save time + include_vis_data = log_images or should_log_ap + + result_dict = self.shared_step( + train_batch, batch_idx, include_visualization_data_in_result_dict=include_vis_data + ) + + if should_log_ap: + self.update_ap_metrics(result_dict, self.ap_training_metrics) if log_images: image_grids = self.visualize_predictions_channels(result_dict) - self.log_image_grids(image_grids, mode="train") + self.log_channel_predictions_grids(image_grids, mode="train") for channel_name in self.keypoint_channel_configuration: self.log(f"train/{channel_name}", result_dict[f"{channel_name}_loss"]) @@ -285,35 +295,37 @@ def visualize_predictions_channels(self, result_dict): image_grids = [] for channel_idx in range(len(self.keypoint_channel_configuration)): - grid = visualize_predictions( + grid = visualize_predicted_heatmaps( input_images, predicted_heatmaps[:, channel_idx, :, :], gt_heatmaps[channel_idx].cpu(), - minimal_keypoint_pixel_distance=6, ) image_grids.append(grid) return image_grids - @staticmethod - def logging_label(channel_configuration, mode: str) -> str: - channel_name = channel_configuration - - if isinstance(channel_configuration, list): - if len(channel_configuration) == 1: - channel_name = channel_configuration[0] - else: - channel_name = f"{channel_configuration[0]}+{channel_configuration[1]}+..." - - channel_name_short = (channel_name[:40] + "...") if len(channel_name) > 40 else channel_name - label = f"{channel_name_short}_{mode}_keypoints" - return label - - def log_image_grids(self, image_grids, mode: str): + def log_channel_predictions_grids(self, image_grids, mode: str): for channel_configuration, grid in zip(self.keypoint_channel_configuration, image_grids): - label = KeypointDetector.logging_label(channel_configuration, mode) - image_caption = "top: predicted heatmaps, middle: predicted keypoints, bottom: gt heatmap" + label = get_logging_label_from_channel_configuration(channel_configuration, mode) + image_caption = "top: predicted heatmaps, bottom: gt heatmaps" self.logger.experiment.log({label: wandb.Image(grid, caption=image_caption)}) + def visualize_predicted_keypoints(self, result_dict): + images = result_dict["input_images"] + predicted_heatmaps = result_dict["predicted_heatmaps"] + # get the keypoints from the heatmaps + predicted_heatmaps = predicted_heatmaps.detach().float() + predicted_keypoints = get_keypoints_from_heatmap_batch_maxpool( + predicted_heatmaps, self.max_keypoints, self.minimal_keypoint_pixel_distance, abs_max_threshold=0.1 + ) + # overlay the images with the keypoints + grid = visualize_predicted_keypoints(images, predicted_keypoints, self.keypoint_channel_configuration) + return grid + + def log_predicted_keypoints(self, grid, mode=str): + label = f"predicted_keypoints_{mode}" + image_caption = "predicted keypoints" + self.logger.experiment.log({label: wandb.Image(grid, caption=image_caption)}) + def validation_step(self, val_batch, batch_idx): # no need to switch model to eval mode, this is handled by pytorch lightning result_dict = self.shared_step(val_batch, batch_idx, include_visualization_data_in_result_dict=True) @@ -321,10 +333,13 @@ def validation_step(self, val_batch, batch_idx): if self.is_ap_epoch(): self.update_ap_metrics(result_dict, self.ap_validation_metrics) - log_images = batch_idx == 0 and self.current_epoch > 0 - if log_images: - image_grids = self.visualize_predictions_channels(result_dict) - self.log_image_grids(image_grids, mode="validation") + log_images = batch_idx == 0 and self.current_epoch > 0 and self.is_ap_epoch() + if log_images and isinstance(self.logger, pl.loggers.wandb.WandbLogger): + channel_grids = self.visualize_predictions_channels(result_dict) + self.log_channel_predictions_grids(channel_grids, mode="validation") + + keypoint_grids = self.visualize_predicted_keypoints(result_dict) + self.log_predicted_keypoints(keypoint_grids, mode="validation") ## log (defaults to on_epoch, which aggregates the logged values over entire validation set) self.log("validation/epoch_loss", result_dict["loss"]) @@ -334,20 +349,53 @@ def test_step(self, test_batch, batch_idx): # no need to switch model to eval mode, this is handled by pytorch lightning result_dict = self.shared_step(test_batch, batch_idx, include_visualization_data_in_result_dict=True) self.update_ap_metrics(result_dict, self.ap_test_metrics) - image_grids = self.visualize_predictions_channels(result_dict) - self.log_image_grids(image_grids, mode="test") + # only log first 10 batches to reduce storage space + if batch_idx < 10 and isinstance(self.logger, pl.loggers.wandb.WandbLogger): + image_grids = self.visualize_predictions_channels(result_dict) + self.log_channel_predictions_grids(image_grids, mode="test") + + keypoint_grids = self.visualize_predicted_keypoints(result_dict) + self.log_predicted_keypoints(keypoint_grids, mode="test") + self.log("test/epoch_loss", result_dict["loss"]) self.log("test/gt_loss", result_dict["gt_loss"]) def log_and_reset_mean_ap(self, mode: str): - mean_ap = 0.0 - metrics = self.ap_test_metrics if mode == "test" else self.ap_validation_metrics - + mean_ap_per_threshold = torch.zeros(len(self.maximal_gt_keypoint_pixel_distances)) + if mode == "train": + metrics = self.ap_training_metrics + elif mode == "validation": + metrics = self.ap_validation_metrics + elif mode == "test": + metrics = self.ap_test_metrics + else: + raise ValueError(f"mode {mode} not recognized") + + # calculate APs for each channel and each threshold distance, and log them + print(f" # {mode} metrics:") for channel_idx, channel_name in enumerate(self.keypoint_channel_configuration): - channel_mean_ap = self.compute_and_log_metrics_for_channel(metrics[channel_idx], channel_name, mode) - mean_ap += channel_mean_ap - mean_ap /= len(self.keypoint_channel_configuration) + channel_aps = self.compute_and_log_metrics_for_channel(metrics[channel_idx], channel_name, mode) + mean_ap_per_threshold += torch.tensor(channel_aps) + + # calculate the mAP over all channels for each threshold distance, and log them + for i, maximal_distance in enumerate(self.maximal_gt_keypoint_pixel_distances): + self.log( + f"{mode}/meanAP/d={float(maximal_distance):.1f}", + mean_ap_per_threshold[i] / len(self.keypoint_channel_configuration), + ) + + # calculate the mAP over all channels and all threshold distances, and log it + mean_ap = mean_ap_per_threshold.mean() / len(self.keypoint_channel_configuration) self.log(f"{mode}/meanAP", mean_ap) + self.log(f"{mode}/meanAP/meanAP", mean_ap) + + def training_epoch_end(self, outputs): + """ + Called on the end of a training epoch. + Used to compute and log the AP metrics. + """ + if self.is_ap_epoch(): + self.log_and_reset_mean_ap("train") def validation_epoch_end(self, outputs): """ @@ -368,46 +416,53 @@ def update_channel_ap_metrics( self, predicted_heatmaps: torch.Tensor, gt_keypoints: List[torch.Tensor], validation_metric: KeypointAPMetrics ): """ - Updates the AP metric for a batch of heatmaps and keypoins of a single channel. + Updates the AP metric for a batch of heatmaps and keypoins of a single channel (!) This is done by extracting the detected keypoints for each heatmap and combining them with the gt keypoints for the same frame, so that the confusion matrix can be determined together with the distance thresholds. - predicted_heatmaps: N x H x W tensor + predicted_heatmaps: N x H x W tensor with the batch of predicted heatmaps for a single channel gt_keypoints: List of size N, containing K_i x 2 tensors with the ground truth keypoints for the channel of that sample """ - # log corner keypoints to AP metrics, frame by frame + # log corner keypoints to AP metrics for all images in this batch formatted_gt_keypoints = [ [Keypoint(int(k[0]), int(k[1])) for k in frame_gt_keypoints] for frame_gt_keypoints in gt_keypoints ] - for i, predicted_heatmap in enumerate(torch.unbind(predicted_heatmaps, 0)): - detected_keypoints = self.extract_detected_keypoints_from_heatmap(predicted_heatmap) - validation_metric.update(detected_keypoints, formatted_gt_keypoints[i]) + batch_detected_channel_keypoints = self.extract_detected_keypoints_from_heatmap( + predicted_heatmaps.unsqueeze(1) + ) + batch_detected_channel_keypoints = [batch_detected_channel_keypoints[i][0] for i in range(len(gt_keypoints))] + for i, detected_channel_keypoints in enumerate(batch_detected_channel_keypoints): + validation_metric.update(detected_channel_keypoints, formatted_gt_keypoints[i]) def compute_and_log_metrics_for_channel( self, metrics: KeypointAPMetrics, channel: str, training_mode: str - ) -> float: + ) -> List[float]: """ - logs AP of predictions of single ChannelĀ² for each threshold distance (as configured) for the categorization of the keypoints into a confusion matrix. - Also resets metric and returns resulting meanAP over all channels. + logs AP of predictions of single Channel for each threshold distance. + Also resets metric and returns resulting AP for all distances. """ - # compute ap's ap_metrics = metrics.compute() - print(f"{ap_metrics=}") + rounded_ap_metrics = {k: round(v, 3) for k, v in ap_metrics.items()} + print(f"{channel} : {rounded_ap_metrics}") for maximal_distance, ap in ap_metrics.items(): - self.log(f"{training_mode}/{channel}_ap/d={maximal_distance}", ap) + self.log(f"{training_mode}/{channel}_ap/d={float(maximal_distance):.1f}", ap) mean_ap = sum(ap_metrics.values()) / len(ap_metrics.values()) + self.log(f"{training_mode}/{channel}_ap/meanAP", mean_ap) # log top level for wandb hyperparam chart. - self.log(f"{training_mode}/{channel}_meanAP", mean_ap) # log top level for wandb hyperparam chart. metrics.reset() - return mean_ap + return list(ap_metrics.values()) def is_ap_epoch(self) -> bool: """Returns True if the AP should be calculated in this epoch.""" - return ( - self.ap_epoch_start <= self.current_epoch and self.current_epoch % self.ap_epoch_freq == 0 - ) or self.current_epoch == self.trainer.max_epochs - 1 + is_epch = self.ap_epoch_start <= self.current_epoch and self.current_epoch % self.ap_epoch_freq == 0 + # always log the AP in the last epoch + is_epch = is_epch or self.current_epoch == self.trainer.max_epochs - 1 + + # if user manually specified a validation frequency, we should always log the AP in that epoch + is_epch = is_epch or (self.current_epoch > 0 and self.trainer.check_val_every_n_epoch > 1) + return is_epch def extract_detected_keypoints_from_heatmap(self, heatmap: torch.Tensor) -> List[DetectedKeypoint]: """ @@ -416,14 +471,27 @@ def extract_detected_keypoints_from_heatmap(self, heatmap: torch.Tensor) -> List Args: heatmap (torch.Tensor) : H x W tensor that represents a heatmap. """ - - detected_keypoints = get_keypoints_from_heatmap( - heatmap, self.minimal_keypoint_pixel_distance, self.max_keypoints + if heatmap.dtype == torch.float16: + # Maxpool_2d not implemented for FP16 apparently + heatmap_to_extract_from = heatmap.float() + else: + heatmap_to_extract_from = heatmap + + keypoints, scores = get_keypoints_from_heatmap_batch_maxpool( + heatmap_to_extract_from, self.max_keypoints, self.minimal_keypoint_pixel_distance, return_scores=True ) - keypoint_probabilities = compute_keypoint_probability(heatmap, detected_keypoints) detected_keypoints = [ - DetectedKeypoint(detected_keypoints[i][0], detected_keypoints[i][1], keypoint_probabilities[i]) - for i in range(len(detected_keypoints)) + [[] for _ in range(heatmap_to_extract_from.shape[1])] for _ in range(heatmap_to_extract_from.shape[0]) ] + for batch_idx in range(len(detected_keypoints)): + for channel_idx in range(len(detected_keypoints[batch_idx])): + for kp_idx in range(len(keypoints[batch_idx][channel_idx])): + detected_keypoints[batch_idx][channel_idx].append( + DetectedKeypoint( + keypoints[batch_idx][channel_idx][kp_idx][0], + keypoints[batch_idx][channel_idx][kp_idx][1], + scores[batch_idx][channel_idx][kp_idx], + ) + ) return detected_keypoints diff --git a/keypoint_detection/models/metrics.py b/keypoint_detection/models/metrics.py index 52a1421..1b4905f 100644 --- a/keypoint_detection/models/metrics.py +++ b/keypoint_detection/models/metrics.py @@ -43,17 +43,17 @@ class ClassifiedKeypoint(DetectedKeypoint): unsafe_hash -> dirty fix to allow for hash w/o explictly telling python the object is immutable. """ - threshold_distance: float + threshold_distance: int true_positive: bool def keypoint_classification( detected_keypoints: List[DetectedKeypoint], ground_truth_keypoints: List[Keypoint], - threshold_distance: float, + threshold_distance: int, ) -> List[ClassifiedKeypoint]: """Classifies keypoints of a **single** frame in True Positives or False Positives by searching for unused gt keypoints in prediction probability order - that are within distance d of the detected keypoint. + that are within distance d of the detected keypoint (greedy matching). Args: detected_keypoints (List[DetectedKeypoint]): The detected keypoints in the frame @@ -73,7 +73,8 @@ def keypoint_classification( matched = False for gt_keypoint in ground_truth_keypoints: distance = detected_keypoint.l2_distance(gt_keypoint) - if distance < threshold_distance: + # add small epsilon to avoid numerical errors + if distance <= threshold_distance + 1e-5: classified_keypoint = ClassifiedKeypoint( detected_keypoint.u, detected_keypoint.v, @@ -131,8 +132,8 @@ def calculate_precision_recall( else: false_positives += 1 - precision.append(true_positives / (true_positives + false_positives)) - recall.append(true_positives / total_ground_truth_keypoints) + precision.append(_zero_aware_division(true_positives, (true_positives + false_positives))) + recall.append(_zero_aware_division(true_positives, total_ground_truth_keypoints)) precision.append(0.0) recall.append(1.0) @@ -209,7 +210,7 @@ class KeypointAPMetrics(Metric): full_state_update = False - def __init__(self, keypoint_threshold_distances: List[float], dist_sync_on_step=False): + def __init__(self, keypoint_threshold_distances: List[int], dist_sync_on_step=False): super().__init__(dist_sync_on_step=dist_sync_on_step) self.ap_metrics = [KeypointAPMetric(dst, dist_sync_on_step) for dst in keypoint_threshold_distances] @@ -229,6 +230,15 @@ def reset(self) -> None: metric.reset() +def _zero_aware_division(num: float, denom: float) -> float: + if num == 0: + return 0 + if denom == 0 and num != 0: + return float("inf") + else: + return num / denom + + if __name__ == "__main__": print( check_forward_full_state_property( diff --git a/labeling/__init__.py b/keypoint_detection/tasks/__init__.py similarity index 100% rename from labeling/__init__.py rename to keypoint_detection/tasks/__init__.py diff --git a/keypoint_detection/tasks/cli.py b/keypoint_detection/tasks/cli.py new file mode 100644 index 0000000..09e7649 --- /dev/null +++ b/keypoint_detection/tasks/cli.py @@ -0,0 +1,28 @@ +"""cli entry point""" +import sys + +from keypoint_detection.tasks.eval import eval_cli +from keypoint_detection.tasks.train import train_cli + +TRAIN_TASK = "train" +EVAL_TASK = "eval" +TASKS = [TRAIN_TASK, EVAL_TASK] + + +def main(): + # read command line args in plain python + + # TODO this is a very hacky approach for combining independent cli scripts + # should redesign this in the future. + + print(sys.argv) + task = sys.argv[1] + sys.argv.pop(1) + + if task == "--help" or task == "-h": + print("Usage: keypoint-detection [task] [task args]") + print(f"Tasks: {TASKS}") + elif task == TRAIN_TASK: + train_cli() + elif task == EVAL_TASK: + eval_cli() diff --git a/keypoint_detection/tasks/eval.py b/keypoint_detection/tasks/eval.py new file mode 100644 index 0000000..bc4cb7c --- /dev/null +++ b/keypoint_detection/tasks/eval.py @@ -0,0 +1,62 @@ +"""run evaluation on a model for the given dataset""" + + +import argparse + +import pytorch_lightning as pl +import torch + +from keypoint_detection.data.datamodule import KeypointsDataModule +from keypoint_detection.models.detector import KeypointDetector +from keypoint_detection.utils.load_checkpoints import get_model_from_wandb_checkpoint + + +def evaluate_model(model: KeypointDetector, datamodule: KeypointsDataModule) -> None: + """evaluate the model on the given datamodule and checkpoint path""" + + device = "cuda" if torch.cuda.is_available() else "cpu" + + model.to(device) + model.eval() + + trainer = pl.Trainer( + gpus=1 if torch.cuda.is_available() else 0, + deterministic=True, + ) + output = trainer.test(model, datamodule) + return output + + +def eval_cli(): + argparser = argparse.ArgumentParser() + argparser.add_argument( + "--wandb_checkpoint", type=str, required=True, help="The wandb checkpoint to load the model from" + ) + argparser.add_argument( + "--test_json_path", + type=str, + required=True, + help="The path to the json file that defines the test dataset according to the COCO format.", + ) + args = argparser.parse_args() + + wandb_checkpoint = args.wandb_checkpoint + test_json_path = args.test_json_path + + model = get_model_from_wandb_checkpoint(wandb_checkpoint) + data_module = KeypointsDataModule( + model.keypoint_channel_configuration, json_test_dataset_path=test_json_path, batch_size=8 + ) + evaluate_model(model, data_module) + + +if __name__ == "__main__": + + wandb_checkpoint = "tlips/synthetic-cloth-keypoints-single-towel/model-gl39yjtf:v0" + test_json_path = "/home/tlips/Documents/synthetic-cloth-data/synthetic-cloth-data/data/datasets/TOWEL/07-purple-towel-on-white/annotations_val.json" + test_json_path = "/storage/users/tlips/aRTFClothes/cloth-on-white/purple-towel-on-white_resized_512x256/purple-towel-on-white.json" + model = get_model_from_wandb_checkpoint(wandb_checkpoint) + data_module = KeypointsDataModule( + model.keypoint_channel_configuration, json_test_dataset_path=test_json_path, batch_size=8 + ) + output = evaluate_model(model, data_module) diff --git a/keypoint_detection/tasks/inference.py b/keypoint_detection/tasks/inference.py new file mode 100644 index 0000000..b6bae63 --- /dev/null +++ b/keypoint_detection/tasks/inference.py @@ -0,0 +1,41 @@ +""" run inference on a provided image and save the result to a file """ + +import numpy as np +import torch +from PIL import Image + +from keypoint_detection.models.detector import KeypointDetector +from keypoint_detection.utils.heatmap import get_keypoints_from_heatmap_batch_maxpool +from keypoint_detection.utils.load_checkpoints import get_model_from_wandb_checkpoint +from keypoint_detection.utils.visualization import draw_keypoints_on_image + + +def run_inference(model: KeypointDetector, image, confidence_threshold: float = 0.1) -> Image: + model.eval() + tensored_image = torch.from_numpy(np.array(image)).float() + tensored_image = tensored_image / 255.0 + tensored_image = tensored_image.permute(2, 0, 1) + tensored_image = tensored_image.unsqueeze(0) + with torch.no_grad(): + heatmaps = model(tensored_image) + + keypoints = get_keypoints_from_heatmap_batch_maxpool(heatmaps, abs_max_threshold=confidence_threshold) + image_keypoints = keypoints[0] + for keypoints, channel_config in zip(image_keypoints, model.keypoint_channel_configuration): + print(f"Keypoints for {channel_config}: {keypoints}") + image = draw_keypoints_on_image(image, image_keypoints, model.keypoint_channel_configuration) + return image + + +if __name__ == "__main__": + wandb_checkpoint = "tlips/synthetic-lego-battery-keypoints/model-tbzd50z8:v0" + image_path = "/home/tlips/Downloads/Lego-battery-real/0.jpg" + # image_path = "/home/tlips/Documents/synthetic-cloth-data/synthetic-cloth-data/data/datasets/LEGO-battery/01/images/0.jpg" + image_size = (256, 256) + + image = Image.open(image_path) + image = image.resize(image_size) + + model = get_model_from_wandb_checkpoint(wandb_checkpoint) + image = run_inference(model, image) + image.save("inference_result.png") diff --git a/keypoint_detection/train/train.py b/keypoint_detection/tasks/train.py similarity index 56% rename from keypoint_detection/train/train.py rename to keypoint_detection/tasks/train.py index e6ab3cc..61cd51a 100644 --- a/keypoint_detection/train/train.py +++ b/keypoint_detection/tasks/train.py @@ -1,5 +1,6 @@ +"""train detector based on argparse configuration""" from argparse import ArgumentParser -from typing import List, Tuple +from typing import Tuple import pytorch_lightning as pl import wandb @@ -9,7 +10,8 @@ from keypoint_detection.data.datamodule import KeypointsDataModule from keypoint_detection.models.backbones.backbone_factory import BackboneFactory from keypoint_detection.models.detector import KeypointDetector -from keypoint_detection.train.utils import create_pl_trainer +from keypoint_detection.tasks.train_utils import create_pl_trainer, parse_channel_configuration +from keypoint_detection.utils.load_checkpoints import get_model_from_wandb_checkpoint from keypoint_detection.utils.path import get_wandb_log_dir_path @@ -28,7 +30,7 @@ def add_system_args(parent_parser: ArgumentParser) -> ArgumentParser: parser.add_argument( "--keypoint_channel_configuration", type=str, - help="A list of the semantic keypoints that you want to learn in each channel. These semantic categories must be defined in the COCO dataset. Seperate the channels with a ; and the categories within a channel with a =", + help="A list of the semantic keypoints that you want to learn in each channel. These semantic categories must be defined in the COCO dataset. Seperate the channels with a : and the categories within a channel with a =", ) parser.add_argument( @@ -37,42 +39,76 @@ def add_system_args(parent_parser: ArgumentParser) -> ArgumentParser: type=float, help="relative threshold for early stopping callback. If validation epoch loss does not increase with at least this fraction compared to the best result so far for 5 consecutive epochs, training is stopped.", ) + # deterministic argument for PL trainer, not exposed in their CLI. + # https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility + # set to True by default, but can be set to False to speed up training. + + parser.add_argument( + "--non-deterministic-pytorch", + action="store_false", + dest="deterministic", + help="do not use deterministic algorithms for pytorch. This can speed up training, but will make it non-reproducible.", + ) + + parser.add_argument( + "--wandb_checkpoint_artifact", + type=str, + help="A checkpoint to resume/start training from. keep in mind that you currently cannot specify hyperparameters other than the LR.", + required=False, + ) + parser.set_defaults(deterministic=True) return parent_parser -def main(hparams: dict) -> Tuple[KeypointDetector, pl.Trainer]: +def train(hparams: dict) -> Tuple[KeypointDetector, pl.Trainer]: """ Initializes the datamodule, model and trainer based on the global hyperparameters. calls trainer.fit(model, module) afterwards and returns both model and trainer. """ + # seed all random number generators on all processes and workers for reproducibility pl.seed_everything(hparams["seed"], workers=True) - backbone = BackboneFactory.create_backbone(**hparams) - model = KeypointDetector(backbone=backbone, **hparams) + # use deterministic algorithms for torch to ensure exact reproducibility + # we have to set it in the trainer! (see create_pl_trainer) + if "wandb_checkpoint_artifact" in hparams.keys(): + print("Loading checkpoint from wandb") + # This will create a KeypointDetector model with the associated hyperparameters. + # Model weights will be loaded. + # Optimizer and LR scheduler will be initiated from scratch" (if you want to really resume training, you have to pass the ckeckpoint to the trainer) + # cf. https://lightning.ai/docs/pytorch/latest/common/checkpointing_basic.html#lightningmodule-from-checkpoint + model = get_model_from_wandb_checkpoint(hparams["wandb_checkpoint_artifact"]) + # TODO: how can specific hparams be overwritten here? e.g. LR reduction for finetuning or something? + else: + backbone = BackboneFactory.create_backbone(**hparams) + model = KeypointDetector(backbone=backbone, **hparams) + data_module = KeypointsDataModule(**hparams) wandb_logger = WandbLogger( project=hparams["wandb_project"], entity=hparams["wandb_entity"], save_dir=get_wandb_log_dir_path(), - log_model="all", # log all checkpoints made by PL, see create_trainer for callback + log_model=True, # only log checkpoints at the end of training, i.e. only log the best checkpoint + # not suitable for expensive training runs where you might want to restart from checkpoint + # but this saves storage and usually keypoint detector training runs are not that expensive anyway ) trainer = create_pl_trainer(hparams, wandb_logger) trainer.fit(model, data_module) if "json_test_dataset_path" in hparams: - trainer.test(model, data_module) - - return model, trainer + # check if we have a best checkpoint, if not, use the current weights but log a warning + # it makes more sense to evaluate on the best checkpoint because, i.e. the best validation score obtained. + # evaluating on the current weights is more noisy and would also result in lower evaluation scores if overfitting happens + # when training longer, even with perfect i.i.d. test/val sets. This is not desired. + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + print("No best checkpoint found, using current weights for test set evaluation") + trainer.test(model, data_module, ckpt_path="best") -def parse_channel_configuration(channel_configuration: str) -> List[List[str]]: - assert isinstance(channel_configuration, str) - channels = channel_configuration.split(";") - channels = [[category.strip() for category in channel.split("=")] for channel in channels] - return channels + return model, trainer -if __name__ == "__main__": +def train_cli(): """ 1. creates argumentparser with Model, Trainer and system paramaters; which can be used to overwrite default parameters when running python train.py -- @@ -112,4 +148,8 @@ def parse_channel_configuration(channel_configuration: str) -> List[List[str]]: print(f" config after wandb init: {hparams}") print("starting training") - main(hparams) + train(hparams) + + +if __name__ == "__main__": + train_cli() diff --git a/keypoint_detection/train/utils.py b/keypoint_detection/tasks/train_utils.py similarity index 78% rename from keypoint_detection/train/utils.py rename to keypoint_detection/tasks/train_utils.py index 22ea20c..94d8027 100644 --- a/keypoint_detection/train/utils.py +++ b/keypoint_detection/tasks/train_utils.py @@ -1,5 +1,5 @@ import inspect -from typing import Optional, Tuple +from typing import List, Optional, Tuple import pytorch_lightning as pl import torch @@ -82,7 +82,21 @@ def create_pl_trainer(hparams: dict, wandb_logger: WandbLogger) -> Trainer: ) # cf https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.loggers.wandb.html - checkpoint_callback = ModelCheckpoint(monitor="validation/epoch_loss", mode="min") + # would be better to use mAP metric for checkpointing, but this is not calculated every epoch because it is rather expensive + # epoch_loss still correlates rather well though + # only store the best checkpoint and only the weights + # so cannot be used to resume training but only for inference + # saves storage though and training the detector is usually cheap enough to retrain it from scratch if you need specific weights etc. + checkpoint_callback = ModelCheckpoint( + monitor="validation/epoch_loss", mode="min", save_weights_only=True, save_top_k=1 + ) trainer = pl.Trainer(**trainer_kwargs, callbacks=[early_stopping, checkpoint_callback]) return trainer + + +def parse_channel_configuration(channel_configuration: str) -> List[List[str]]: + assert isinstance(channel_configuration, str) + channels = channel_configuration.split(":") + channels = [[category.strip() for category in channel.split("=")] for channel in channels] + return channels diff --git a/keypoint_detection/utils/__init__.py b/keypoint_detection/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/keypoint_detection/utils/heatmap.py b/keypoint_detection/utils/heatmap.py index 497fc96..b44c3fc 100644 --- a/keypoint_detection/utils/heatmap.py +++ b/keypoint_detection/utils/heatmap.py @@ -1,4 +1,5 @@ -from typing import List, Tuple +import warnings +from typing import List, Optional, Tuple import numpy as np import torch @@ -51,10 +52,6 @@ def generate_channel_heatmap( Torch.tensor: A Tensor with the combined heatmaps of all keypoints. """ - # cast keypoints (center) to ints to make grid align with pixel raster. - # Otherwise, the AP metric for d = 1 will not result in 1 - # if the gt_heatmaps are used as input. - assert isinstance(keypoints, torch.Tensor) if keypoints.numel() == 0: @@ -77,23 +74,26 @@ def generate_channel_heatmap( return heatmap -def get_keypoints_from_heatmap( +def get_keypoints_from_heatmap_scipy( heatmap: torch.Tensor, min_keypoint_pixel_distance: int, max_keypoints: int = 20 ) -> List[Tuple[int, int]]: """ Extracts at most 20 keypoints from a heatmap, where each keypoint is defined as being a local maximum within a 2D mask [ -min_pixel_distance, + pixel_distance]^2 cf https://scikit-image.org/docs/dev/api/skimage.feature.html#skimage.feature.peak_local_max + THIS IS SLOW! use get_keypoints_from_heatmap_batch_maxpool instead. + + Args: heatmap : heatmap image - min_keypoint_pixel_distance : The size of the local mask + min_keypoint_pixel_distance : The size of the local mask, serves as NMS max_keypoints: the amount of keypoints to determine from the heatmap, -1 to return all points. Defaults to 20 to limit computational burder for models that predict random keypoints in early stage of training. Returns: A list of 2D keypoints """ - + warnings.warn("get_keypoints_from_heatmap_scipy is slow! Use get_keypoints_from_heatmap_batch_maxpool instead.") np_heatmap = heatmap.cpu().numpy().astype(np.float32) # num_peaks and rel_threshold are set to limit computational burden when models do random predictions. @@ -109,6 +109,89 @@ def get_keypoints_from_heatmap( return keypoints[::, ::-1].tolist() # convert to (u,v) aka (col,row) coord frame from (row,col) +def get_keypoints_from_heatmap_batch_maxpool( + heatmap: torch.Tensor, + max_keypoints: int = 20, + min_keypoint_pixel_distance: int = 1, + abs_max_threshold: Optional[float] = None, + rel_max_threshold: Optional[float] = None, + return_scores: bool = False, +) -> List[List[List[Tuple[int, int]]]]: + """Fast extraction of keypoints from a batch of heatmaps using maxpooling. + + Inspired by mmdetection and CenterNet: + https://mmdetection.readthedocs.io/en/v2.13.0/_modules/mmdet/models/utils/gaussian_target.html + + Args: + heatmap (torch.Tensor): NxCxHxW heatmap batch + max_keypoints (int, optional): max number of keypoints to extract, lowering will result in faster execution times. Defaults to 20. + min_keypoint_pixel_distance (int, optional): _description_. Defaults to 1. + + Following thresholds can be used at inference time to select where you want to be on the AP curve. They should ofc. not be used for training + abs_max_threshold (Optional[float], optional): _description_. Defaults to None. + rel_max_threshold (Optional[float], optional): _description_. Defaults to None. + + Returns: + The extracted keypoints for each batch, channel and heatmap; and their scores + """ + + # TODO: maybe separate the thresholding into another function to make sure it is not used during training, where it should not be used? + + # TODO: ugly that the output can change based on a flag.. should always return scores and discard them when I don't need them... + + batch_size, n_channels, _, width = heatmap.shape + + # obtain max_keypoints local maxima for each channel (w/ maxpool) + + kernel = min_keypoint_pixel_distance * 2 + 1 + pad = min_keypoint_pixel_distance + # exclude border keypoints by padding with highest possible value + # bc the borders are more susceptible to noise and could result in false positives + padded_heatmap = torch.nn.functional.pad(heatmap, (pad, pad, pad, pad), mode="constant", value=1.0) + max_pooled_heatmap = torch.nn.functional.max_pool2d(padded_heatmap, kernel, stride=1, padding=0) + # if the value equals the original value, it is the local maximum + local_maxima = max_pooled_heatmap == heatmap + # all values to zero that are not local maxima + heatmap = heatmap * local_maxima + + # extract top-k from heatmap (may include non-local maxima if there are less peaks than max_keypoints) + scores, indices = torch.topk(heatmap.view(batch_size, n_channels, -1), max_keypoints, sorted=True) + indices = torch.stack([torch.div(indices, width, rounding_mode="floor"), indices % width], dim=-1) + # at this point either score > 0.0, in which case the index is a local maximum + # or score is 0.0, in which case topk returned non-maxima, which will be filtered out later. + + # remove top-k that are not local maxima and threshold (if required) + # thresholding shouldn't be done during training + + # moving them to CPU now to avoid multiple GPU-mem accesses! + indices = indices.detach().cpu().numpy() + scores = scores.detach().cpu().numpy() + filtered_indices = [[[] for _ in range(n_channels)] for _ in range(batch_size)] + filtered_scores = [[[] for _ in range(n_channels)] for _ in range(batch_size)] + # determine NMS threshold + threshold = 0.01 # make sure it is > 0 to filter out top-k that are not local maxima + if abs_max_threshold is not None: + threshold = max(threshold, abs_max_threshold) + if rel_max_threshold is not None: + threshold = max(threshold, rel_max_threshold * heatmap.max()) + + # have to do this manually as the number of maxima for each channel can be different + for batch_idx in range(batch_size): + for channel_idx in range(n_channels): + candidates = indices[batch_idx, channel_idx] + for candidate_idx in range(candidates.shape[0]): + + # these are filtered out directly. + if scores[batch_idx, channel_idx, candidate_idx] > threshold: + # convert to (u,v) + filtered_indices[batch_idx][channel_idx].append(candidates[candidate_idx][::-1].tolist()) + filtered_scores[batch_idx][channel_idx].append(scores[batch_idx, channel_idx, candidate_idx]) + if return_scores: + return filtered_indices, filtered_scores + else: + return filtered_indices + + def compute_keypoint_probability(heatmap: torch.Tensor, detected_keypoints: List[Tuple[int, int]]) -> List[float]: """Compute probability measure for each detected keypoint on the heatmap @@ -121,3 +204,17 @@ def compute_keypoint_probability(heatmap: torch.Tensor, detected_keypoints: List """ # note the order! (u,v) is how we write , but the heatmap has to be indexed (v,u) as it is H x W return [heatmap[k[1]][k[0]].item() for k in detected_keypoints] + + +if __name__ == "__main__": + import torch.profiler as profiler + + keypoints = torch.tensor([[150, 134], [64, 153]]).cuda() + heatmap = generate_channel_heatmap((1080, 1920), keypoints, 6, "cuda") + heatmap = heatmap.unsqueeze(0).unsqueeze(0).repeat(1, 1, 1, 1) + # heatmap = torch.stack([heatmap, heatmap], dim=0) + print(heatmap.shape) + with profiler.profile(record_shapes=True) as prof: + with profiler.record_function("get_keypoints_from_heatmap_batch_maxpool"): + print(get_keypoints_from_heatmap_batch_maxpool(heatmap, 50, min_keypoint_pixel_distance=5)) + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) diff --git a/keypoint_detection/utils/load_checkpoints.py b/keypoint_detection/utils/load_checkpoints.py index 51e7473..d43feb9 100644 --- a/keypoint_detection/utils/load_checkpoints.py +++ b/keypoint_detection/utils/load_checkpoints.py @@ -15,14 +15,17 @@ def get_model_from_wandb_checkpoint(checkpoint_reference: str): import wandb # download checkpoint locally (if not already cached) - run = wandb.init(project="inference") + if wandb.run is None: + run = wandb.init(project="inference") + else: + run = wandb.run artifact = run.use_artifact(checkpoint_reference, type="model") artifact_dir = artifact.download() checkpoint_path = Path(artifact_dir) / "model.ckpt" return load_from_checkpoint(checkpoint_path) -def load_from_checkpoint(checkpoint_path: str): +def load_from_checkpoint(checkpoint_path: str, hparams_to_override: dict = None): """ function to load a Keypoint Detector model from a local pytorch lightning checkpoint. @@ -43,3 +46,8 @@ def load_from_checkpoint(checkpoint_path: str): backbone = BackboneFactory.create_backbone(**checkpoint["hyper_parameters"]) model = KeypointDetector.load_from_checkpoint(checkpoint_path, backbone=backbone) return model + + +if __name__ == "__main__": + model = get_model_from_wandb_checkpoint("tlips/synthetic-cloth-keypoints-tshirts/model-4um302zo:v0") + print(model.hparams) diff --git a/keypoint_detection/utils/visualization.py b/keypoint_detection/utils/visualization.py index e2f0d39..b547513 100644 --- a/keypoint_detection/utils/visualization.py +++ b/keypoint_detection/utils/visualization.py @@ -1,11 +1,30 @@ from argparse import ArgumentParser -from typing import List +from typing import List, Tuple +import numpy as np import torch import torchvision from matplotlib import cm +from PIL import Image, ImageDraw, ImageFont -from keypoint_detection.utils.heatmap import generate_channel_heatmap, get_keypoints_from_heatmap +from keypoint_detection.utils.heatmap import generate_channel_heatmap + + +def get_logging_label_from_channel_configuration(channel_configuration: List[List[str]], mode: str) -> str: + channel_name = channel_configuration + + if isinstance(channel_configuration, list): + if len(channel_configuration) == 1: + channel_name = channel_configuration[0] + else: + channel_name = f"{channel_configuration[0]}+{channel_configuration[1]}+..." + + channel_name_short = (channel_name[:40] + "...") if len(channel_name) > 40 else channel_name + if mode != "": + label = f"{channel_name_short}_{mode}" + else: + label = channel_name_short + return label def overlay_image_with_heatmap(images: torch.Tensor, heatmaps: torch.Tensor, alpha=0.5) -> torch.Tensor: @@ -19,7 +38,22 @@ def overlay_image_with_heatmap(images: torch.Tensor, heatmaps: torch.Tensor, alp return overlayed_images -def overlay_image_with_keypoints(images: torch.Tensor, keypoints: List[torch.Tensor], sigma: float) -> torch.Tensor: +def visualize_predicted_heatmaps( + imgs: torch.Tensor, + predicted_heatmaps: torch.Tensor, + gt_heatmaps: torch.Tensor, +): + num_images = min(predicted_heatmaps.shape[0], 6) + + predicted_heatmap_overlays = overlay_image_with_heatmap(imgs[:num_images], predicted_heatmaps[:num_images]) + gt_heatmap_overlays = overlay_image_with_heatmap(imgs[:num_images], gt_heatmaps[:num_images]) + + images = torch.cat([predicted_heatmap_overlays, gt_heatmap_overlays]) + grid = torchvision.utils.make_grid(images, nrow=num_images) + return grid + + +def overlay_images_with_keypoints(images: torch.Tensor, keypoints: List[torch.Tensor], sigma: float) -> torch.Tensor: """ images N x 3 x H x W keypoints list of size N with Tensors C x 2 @@ -49,27 +83,58 @@ def overlay_image_with_keypoints(images: torch.Tensor, keypoints: List[torch.Ten return overlayed_images -def visualize_predictions( - imgs: torch.Tensor, - predicted_heatmaps: torch.Tensor, - gt_heatmaps: torch.Tensor, - minimal_keypoint_pixel_distance: int, -): - num_images = min(predicted_heatmaps.shape[0], 6) - keypoint_sigma = max(1, imgs.shape[2] / 64) - - predicted_heatmap_overlays = overlay_image_with_heatmap(imgs[:num_images], predicted_heatmaps[:num_images]) - gt_heatmap_overlays = overlay_image_with_heatmap(imgs[:num_images], gt_heatmaps[:num_images]) - predicted_keypoints = [ - torch.tensor(get_keypoints_from_heatmap(predicted_heatmaps[i].cpu(), minimal_keypoint_pixel_distance)) - for i in range(predicted_heatmaps.shape[0]) +def draw_keypoints_on_image( + image: Image, image_keypoints: List[List[Tuple[int, int]]], channel_configuration: List[List[str]] +) -> Image: + """adds all keypoints to the PIL image, with different colors for each channel.""" + color_pool = [ + "#FF00FF", # Neon Purple + "#00FF00", # Electric Green + "#FFFF00", # Cyber Yellow + "#0000FF", # Laser Blue + "#FF0000", # Radioactive Red + "#00FFFF", # Galactic Teal + "#FF00AA", # Quantum Pink + "#C0C0C0", # Holographic Silver + "#000000", # Abyssal Black + "#FFA500", # Cosmic Orange ] - predicted_keypoints_overlays = overlay_image_with_keypoints( - imgs[:num_images], predicted_keypoints[:num_images], keypoint_sigma - ) + image_size = image.size + min_size = min(image_size) + scale = 1 + (min_size // 256) - images = torch.cat([predicted_heatmap_overlays, predicted_keypoints_overlays, gt_heatmap_overlays]) - grid = torchvision.utils.make_grid(images, nrow=num_images) + draw = ImageDraw.Draw(image) + for channel_idx, channel_keypoints in enumerate(image_keypoints): + for keypoint_idx, keypoint in enumerate(channel_keypoints): + u, v = keypoint + draw.ellipse((u - scale, v - scale, u + scale, v + scale), fill=color_pool[channel_idx]) + + draw.text( + (10, channel_idx * 10 * scale), + get_logging_label_from_channel_configuration(channel_configuration[channel_idx], ""), + fill=color_pool[channel_idx], + font=ImageFont.truetype("FreeMono.ttf", size=10 * scale), + ) + + return image + + +def visualize_predicted_keypoints( + images: torch.Tensor, keypoints: List[List[List[List[int]]]], channel_configuration: List[List[str]] +): + drawn_images = [] + num_images = min(images.shape[0], 6) + for i in range(num_images): + # PIL expects uint8 images + image = images[i].permute(1, 2, 0).numpy() * 255 + image = image.astype(np.uint8) + image = Image.fromarray(image) + image = draw_keypoints_on_image(image, keypoints[i], channel_configuration) + drawn_images.append(image) + + drawn_images = torch.stack([torch.from_numpy(np.array(image)).permute(2, 0, 1) / 255 for image in drawn_images]) + + grid = torchvision.utils.make_grid(drawn_images, nrow=num_images) return grid @@ -79,7 +144,7 @@ def visualize_predictions( from torch.utils.data import DataLoader from keypoint_detection.data.coco_dataset import COCOKeypointsDataset - from keypoint_detection.train.train import parse_channel_configuration + from keypoint_detection.tasks.train import parse_channel_configuration from keypoint_detection.utils.heatmap import create_heatmap_batch parser = ArgumentParser() @@ -98,7 +163,7 @@ def visualize_predictions( shape = images.shape[2:] heatmaps = create_heatmap_batch(shape, keypoint_channels[0], sigma=6.0, device="cpu") - grid = visualize_predictions(images, heatmaps, heatmaps, 6) + grid = visualize_predicted_heatmaps(images, heatmaps, heatmaps, 6) image_numpy = grid.permute(1, 2, 0).numpy() plt.imshow(image_numpy) diff --git a/labeling/Readme.md b/labeling/Readme.md deleted file mode 100644 index f17ca67..0000000 --- a/labeling/Readme.md +++ /dev/null @@ -1,76 +0,0 @@ -# CVAT to COCO Keypoints - -This readme defines a workflow to label semantic keypoints on images using [CVAT](https://www.cvat.ai/) and to convert them to the [COCO keypoints format](https://cocodataset.org/#format-data). - This package contains parsers for the different dataset formats and code to convert from the CVAT Image 1.1 format to COCO format. - - - -## Labeling use case analysis -- **we want to label semantic keypoints on images**. -- There can be **multiple categories / classes of objects in the images**. Each category can have 0 - N instances in each image. (think about categories/classes as objects that you could draw a bounding box or segmentation mask for). -- Each category has a number of **semantic types** of keypoints that are of interest. E.g. arms, shoulders, head,.. for the person category. -- Each semantic type can contain multiple keypoints (e.g. a human has 2 shoulders, a cardboard box has 4 corners). Although you could label these separately (and for humans this is very natural as humans have a front/back side, unlike boxes for which there is no semantic difference between the corners), this creates a burden as you have to do this in a geometrically consistent way by e.g. always labeling most topleft corner of a box as 'corner1'. This is easily done afterwards using e.g. the quadrant of each corner and asking the labeler to do so only leads to more work and possible inaccuracies. Therefore each semantic type can have 0 - K keypoints. - -**So each image has N_i instances of each of the K categories and each instance has the M semantic types of that category, where each each type has S_i keypoints.** - -We hence need to be able to -- group the keypoints of a single instance of a category together -- and to label multiple keypoints under one semantic type and later separate them for the COCO format (which does not allow for multiple keypoints for a single type). - - -## CVAT configuration -In CVAT we create a **label** for each **semantic type** of each **category** using the naming convention **{category}.{semantic_type}**. You then label all keypoints of a single type in the image. If there are multiple instances of a single category, you group them together using the Grouping Tool in CVAT, if there is only one instance of a category, there is no need to manually group them together. - -After labeling, the annotations XML can be downloaded. - -## Converting CVAT to COCO -- clone this repo and pip install the requirements of this package. -- [set up CVAT](docs/cvat_setup.md) -- create a task and define your labels. -- label your images. -- export the annotations XML in the CVAT images format. -- create a Category configuration that specifies for each category: - - its name - - its desired ID - - its supercategory - - its semantic types and how much keypoints each type has - - This configuration is then used to create the COCO object categories and define all the keypoints for each categorie. - -- then, from this readme's folder, run `python convert_cvat_to_coco.py --cvat_xml_file example/annotations.xml --coco_categories_config_path example/coco_category_configuration.json` to create a `coco.json` annotation file. You should now have a COCO dataset annotation file, that you can use for example with - -## Example -There is an example included for 4 images containing a number of tshirts. -The desired categories configuration is as follows (see `examples/coco_category_configuration.json`) -``` -{ - "categories": [ - { - "name": "tshirt", - "supercategory": "cloth", - "id": 23, - "semantic_types": [ - { - "name": "neck", - "n_keypoints": 1 - }, - { - "name": "shoulder", - "n_keypoints": 2 - } - ] - } - ] -} -``` - -which implies we have 1 object class that has 2 semantic types: 'neck' with 1 keypoint and 'shoulder' with 2 keypoints (left/right, which is an artificial example as a tshirt has a front and back side and hence there is no ambiguity) - -To label this configuration, we create 2 labels in cvat: -![alt text](docs/cvat_example_setup.png). - - -One image contains 2 instances of the tshirt, which should hence be grouped using CVATs group_id. One image is not labeled to show this can be dealt with. Another image is partially labeled to simulate partial visibility. -You can find the resulting CVAT annotations in `example/annotations.xml`. - -You can now finally convert the CVAT annotations to COCO annotation format, which results in the `example/coco.json` file. diff --git a/labeling/convert_cvat_to_coco.py b/labeling/convert_cvat_to_coco.py deleted file mode 100644 index 79da86e..0000000 --- a/labeling/convert_cvat_to_coco.py +++ /dev/null @@ -1,232 +0,0 @@ -from __future__ import annotations - -import json -from typing import List - -import tqdm - -from keypoint_detection.data.coco_parser import CocoImage, CocoKeypointAnnotation, CocoKeypointCategory, CocoKeypoints -from labeling.file_loading import get_dict_from_json, get_dict_from_xml -from labeling.parsers.coco_categories_parser import COCOCategoriesConfig, COCOCategoryConfig -from labeling.parsers.cvat_keypoints_parser import CVATKeypointsParser, ImageItem, Point - - -def cvat_image_to_coco( - cvat_xml_path: str, coco_category_configuration_path: str, image_folder: str = "images" -) -> dict: - """Function that converts an annotation XML in the CVAT 1.1 Image format to the COCO keypoints format. - - This function supports: - - multiple categories (box, tshirt); - - multiple semantic types for each category ("corners", "flap_corners") - - multiple keypoints for a single semantic type (a box has 4 corners) to facilitate fast labeling (no need to label each corner with a separate label, which requires geometric consistency) - - occluded or invisible keypoints for each type - - It requires the CVAT dataset to be created by using labels formatted as ., using the group_id to group multiple instances together. - if only a single instance is present, the group id is set to 1 by default so you don't have to do this yourself. - - To map from the CVAT labels to the COCO categories, you need to specify a configuration. - See the readme for more details and an example. - - This function is rather complex unfortunately, but at a high level it performs the following: - # for all categories in the config: - # for all images: - # create COCO Image - # find number of category instances in that images - # for each instance in the image: - # for all semantic types in the category: - # find all keypoints of that type for that instance in the current image - # create a COCO Annotation for the current instance of the category - - Args: - cvat_xml_path (str): _description_ - coco_category_configuration_path (str): _description_ - - Returns: - (dict): a COCO dict that can be dumped to a JSON. - """ - cvat_dict = get_dict_from_xml(cvat_xml_path) - cvat_parsed = CVATKeypointsParser(**cvat_dict) - - category_dict = get_dict_from_json(coco_category_configuration_path) - parsed_category_config = COCOCategoriesConfig(**category_dict) - - # create a COCO Dataset Model - coco_model = CocoKeypoints(images=[], annotations=[], categories=[]) - - annotation_id_counter = 1 # counter for the annotation ID - - print("starting CVAT Image -> COCO conversion") - for category in parsed_category_config.categories: - print(f"converting category {category.name}") - category_name = category.name - category_keypoint_names = get_coco_keypoint_names_from_category_config(category) - coco_model.categories.append( - CocoKeypointCategory( - id=category.id, - name=category.name, - supercategory=category.supercategory, - keypoints=category_keypoint_names, - ) - ) - - for cvat_image in tqdm.tqdm(cvat_parsed.annotations.image): - coco_image = CocoImage( - file_name=f"{image_folder}/{cvat_image.name}", - height=int(cvat_image.height), - width=int(cvat_image.width), - id=int(cvat_image.id) + 1, - ) - coco_model.images.append(coco_image) - n_image_category_instances = get_n_category_instances_in_image(cvat_image, category_name) - for instance_id in range(1, n_image_category_instances + 1): # IDs start with 1 - instance_category_keypoints = [] - for semantic_type in category.semantic_types: - keypoints = get_semantic_type_keypoints_from_instance_in_cvat_image( - cvat_image, semantic_type.name, instance_id - ) - - # pad for invisible keypoints for the given instance of the semantic type. - keypoints.extend([0.0] * (3 * semantic_type.n_keypoints - len(keypoints))) - instance_category_keypoints.extend(keypoints) - - coco_model.annotations.append( - CocoKeypointAnnotation( - category_id=category.id, - id=annotation_id_counter, - image_id=coco_image.id, - keypoints=instance_category_keypoints, - ) - ) - annotation_id_counter += 1 - return coco_model.dict(exclude_none=True) - - -### helper functions - - -def get_n_category_instances_in_image(cvat_image: ImageItem, category_name: str) -> int: - """returns the number of instances for the specified category in the CVAT ImageItem. - - This is done by finding the maximum group_id for all annotations of the image. - - Edge cases include: no Points in the image or only 1 Point in the image. - """ - if cvat_image.points is None: - return 0 - if not isinstance(cvat_image.points, list): - if get_category_from_cvat_label(cvat_image.points.label) == category_name: - return int(cvat_image.points.group_id) - else: - return 0 - max_group_id = 1 - for cvat_point in cvat_image.points: - if get_category_from_cvat_label(cvat_point.label) == category_name: - max_group_id = max(max_group_id, int(cvat_point.group_id)) - return max_group_id - - -def get_category_from_cvat_label(label: str) -> str: - """cvat labels are formatted as . - this function returns the category - """ - split = label.split(".") - assert len(split) == 2, " label was not formatted as category.semantic_type" - return label.split(".")[0] - - -def get_semantic_type_from_cvat_label(label: str) -> str: - """cvat labels are formatted as . - this function returns the semantic type - """ - split = label.split(".") - assert len(split) == 2, " label was not formatted as category.semantic_type" - return label.split(".")[1] - - -def get_coco_keypoint_names_from_category_config(config: COCOCategoryConfig) -> List[str]: - """Helper function that converts a CategoryConfiguration to a list of coco keypoints. - This function duplicates keypoints for types with n_keypoints > 1 by appending an index: - e.g. "corner", n_keypoints = 2 -> ["corner1" ,"corner2"]. - - Args: - config (dict): _description_ - - Returns: - _type_: _description_ - """ - keypoint_names = [] - for semantic_type in config.semantic_types: - if semantic_type.n_keypoints == 1: - keypoint_names.append(semantic_type.name) - else: - for i in range(semantic_type.n_keypoints): - keypoint_names.append(f"{semantic_type.name}{i+1}") - return keypoint_names - - -def get_semantic_type_keypoints_from_instance_in_cvat_image( - cvat_image: ImageItem, semantic_type: str, instance_id: int -) -> List[float]: - """Gather all keypoints of the given semantic type for this in the image. - - Args: - cvat_image (ImageItem): _description_ - semantic_type (str): _description_ - instance_id (int): _description_ - - Returns: - List: _description_ - """ - instance_id = str(instance_id) - if cvat_image.points is None: - return [0.0, 0.0, 0] - if not isinstance(cvat_image.points, list): - if ( - semantic_type == get_semantic_type_from_cvat_label(cvat_image.points.label) - and instance_id == cvat_image.points.group_id - ): - return extract_coco_keypoint_from_cvat_point(cvat_image.points, cvat_image) - else: - return [0.0, 0.0, 0] - keypoints = [] - for cvat_point in cvat_image.points: - if semantic_type == get_semantic_type_from_cvat_label(cvat_point.label) and instance_id == cvat_point.group_id: - keypoints.extend(extract_coco_keypoint_from_cvat_point(cvat_point, cvat_image)) - return keypoints - - -def extract_coco_keypoint_from_cvat_point(cvat_point: Point, cvat_image: ImageItem) -> List: - """extract keypoint in coco format (u,v,f) from cvat annotation point. - Args: - cvat_point (Point): _description_ - cvat_image (ImageItem): _description_ - - Returns: - List: [u,v,f] where u,v are the coords scaled to the image resolution and f is the coco visibility flag. - see the coco dataset format for more details. - """ - u = float(cvat_point.points.split(",")[0]) - v = float(cvat_point.points.split(",")[1]) - f = ( - 1 if cvat_point.occluded == "1" else 2 - ) # occluded = 1 means not visible, which is 1 in COCO; visible in COCO is 2 - return [u, v, f] - - -if __name__ == "__main__": - """ - example usage: - - python convert_cvat_to_coco.py --cvat_xml_file example/annotations.xml --coco_categories_config_path example/coco_category_configuration.json - """ - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--cvat_xml_file", type=str, required=True) - parser.add_argument("--coco_categories_config_path", type=str, required=True) - - args = parser.parse_args() - coco = cvat_image_to_coco(args.cvat_xml_file, args.coco_categories_config_path) - with open("coco.json", "w") as file: - json.dump(coco, file) diff --git a/labeling/docs/cvat_example_setup.png b/labeling/docs/cvat_example_setup.png deleted file mode 100644 index e838893..0000000 Binary files a/labeling/docs/cvat_example_setup.png and /dev/null differ diff --git a/labeling/docs/cvat_setup.md b/labeling/docs/cvat_setup.md deleted file mode 100644 index 3c680d7..0000000 --- a/labeling/docs/cvat_setup.md +++ /dev/null @@ -1,4 +0,0 @@ -CVAT provides a docker compose file to set up everyting on a private machine, see instructions [here](https://cvat-ai.github.io/cvat/docs/administration/basics/installation/#ubuntu-1804-x86_64amd64). At the time of writing however, you need to change the container tag to `dev`, cf [this issue](https://github.com/opencv/cvat/issues/4816). - - - You can also do this setup on a remote machine, in which case you can either make the client reachable over the web or forward the tcp conncetion to your local machine using ssh: `ssh -L :: @` diff --git a/labeling/example/annotations.json b/labeling/example/annotations.json deleted file mode 100644 index 1c7d36f..0000000 --- a/labeling/example/annotations.json +++ /dev/null @@ -1,109 +0,0 @@ -{ - "info": null, - "licenses": null, - "images": [ - { - "license": null, - "file_name": "1.jpeg", - "height": 256, - "width": 256, - "id": 1 - }, - { - "license": null, - "file_name": "2.jpeg", - "height": 256, - "width": 256, - "id": 2 - }, - { - "license": null, - "file_name": "3.jpeg", - "height": 256, - "width": 256, - "id": 3 - }, - { - "license": null, - "file_name": "4.jpeg", - "height": 256, - "width": 256, - "id": 4 - } - ], - "categories": [ - { - "supercategory": "onion", - "id": 0, - "name": "onion", - "keypoints": [ - "head", - "tail" - ], - "skeleton": [ - [ - 1, - 0 - ] - ] - } - ], - "annotations": [ - { - "category_id": 0, - "id": 1, - "image_id": 1, - "num_keypoints": null, - "keypoints": [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0 - ] - }, - { - "category_id": 0, - "id": 2, - "image_id": 2, - "num_keypoints": null, - "keypoints": [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0 - ] - }, - { - "category_id": 0, - "id": 3, - "image_id": 3, - "num_keypoints": null, - "keypoints": [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0 - ] - }, - { - "category_id": 0, - "id": 4, - "image_id": 4, - "num_keypoints": null, - "keypoints": [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0 - ] - } - ] -} diff --git a/labeling/example/annotations.xml b/labeling/example/annotations.xml deleted file mode 100644 index 82e17dd..0000000 --- a/labeling/example/annotations.xml +++ /dev/null @@ -1,76 +0,0 @@ - - 1.1 - - - 3 - example-keypoints-task - 4 - annotation - 0 - - 2022-08-24 11:46:30.553731+00:00 - 2022-08-24 12:51:45.689497+00:00 - Train - 0 - 3 - - - - 3 - 0 - 3 - http://localhost:8080/?id=3 - - - - tlips - thomas.lips@ugent.be - - - - - - - - 2022-08-24 12:51:55.830700+00:00 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/labeling/example/coco.json b/labeling/example/coco.json deleted file mode 100644 index a045df1..0000000 --- a/labeling/example/coco.json +++ /dev/null @@ -1,106 +0,0 @@ -{ - "images": [ - { - "file_name": "images/1.jpeg", - "height": 256, - "width": 256, - "id": 1 - }, - { - "file_name": "images/2.jpeg", - "height": 256, - "width": 256, - "id": 2 - }, - { - "file_name": "images/3.jpeg", - "height": 256, - "width": 256, - "id": 3 - }, - { - "file_name": "images/4.jpeg", - "height": 256, - "width": 256, - "id": 4 - } - ], - "categories": [ - { - "supercategory": "cloth", - "id": 23, - "name": "tshirt", - "keypoints": [ - "neck", - "shoulder1", - "shoulder2" - ] - } - ], - "annotations": [ - { - "category_id": 23, - "id": 1, - "image_id": 1, - "keypoints": [ - 126.0, - 26.6, - 2.0, - 64.1, - 31.8, - 2.0, - 181.8, - 28.9, - 2.0 - ] - }, - { - "category_id": 23, - "id": 2, - "image_id": 2, - "keypoints": [ - 127.68, - 61.3, - 2.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0 - ] - }, - { - "category_id": 23, - "id": 3, - "image_id": 3, - "keypoints": [ - 71.96, - 41.64, - 2.0, - 38.52, - 41.31, - 2.0, - 102.44, - 40.0, - 2.0 - ] - }, - { - "category_id": 23, - "id": 4, - "image_id": 3, - "keypoints": [ - 187.34, - 40.33, - 1.0, - 152.27, - 45.9, - 2.0, - 221.76, - 44.59, - 2.0 - ] - } - ] -} diff --git a/labeling/example/coco_category_configuration.json b/labeling/example/coco_category_configuration.json deleted file mode 100644 index e448053..0000000 --- a/labeling/example/coco_category_configuration.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "categories": [ - { - "name": "tshirt", - "supercategory": "cloth", - "id": 23, - "semantic_types": [ - { - "name": "neck", - "n_keypoints": 1 - }, - { - "name": "shoulder", - "n_keypoints": 2 - } - ] - } - ] -} diff --git a/labeling/example/images/1.jpeg b/labeling/example/images/1.jpeg deleted file mode 100644 index f3f2e14..0000000 Binary files a/labeling/example/images/1.jpeg and /dev/null differ diff --git a/labeling/example/images/2.jpeg b/labeling/example/images/2.jpeg deleted file mode 100644 index ea58071..0000000 Binary files a/labeling/example/images/2.jpeg and /dev/null differ diff --git a/labeling/example/images/3.jpeg b/labeling/example/images/3.jpeg deleted file mode 100644 index 4e9e7d5..0000000 Binary files a/labeling/example/images/3.jpeg and /dev/null differ diff --git a/labeling/example/images/4.jpeg b/labeling/example/images/4.jpeg deleted file mode 100644 index cba5f15..0000000 Binary files a/labeling/example/images/4.jpeg and /dev/null differ diff --git a/labeling/file_loading.py b/labeling/file_loading.py deleted file mode 100644 index 823084e..0000000 --- a/labeling/file_loading.py +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/python - -import json - -import xmltodict - - -def get_dict_from_xml(xml_path): - with open(xml_path, "r") as file: - # prefixes @/_ lead to issues with pydantic parsing! so simply use no prefix - xml_dict = xmltodict.parse(file.read(), attr_prefix="") - return xml_dict - - -def get_dict_from_json(path): - with open(path, "r") as file: - return json.load(file) - - -if __name__ == "__main__": - """convert XML to JSON""" - import sys - - assert len(sys.argv) == 2 - xml_path = sys.argv[1] - json_path = xml_path[:-3] + "json" - - print(f"converting {xml_path} to {json_path}") - - xml_dict = get_dict_from_xml(xml_path) - with open(json_path, "w") as outfile: - json.dump(xml_dict, outfile) diff --git a/labeling/parsers/coco_categories_parser.py b/labeling/parsers/coco_categories_parser.py deleted file mode 100644 index c496e77..0000000 --- a/labeling/parsers/coco_categories_parser.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Parser for configuration of COCO categories to convert CVAT XML to COCO Keypoints -""" -from typing import List - -from pydantic import BaseModel - - -class COCOSemanticTypeConfig(BaseModel): - name: str - n_keypoints: int - - -class COCOCategoryConfig(BaseModel): - supercategory: str - id: int - name: str - semantic_types: List[COCOSemanticTypeConfig] - - -class COCOCategoriesConfig(BaseModel): - categories: List[COCOCategoryConfig] diff --git a/labeling/parsers/cvat_keypoints_parser.py b/labeling/parsers/cvat_keypoints_parser.py deleted file mode 100644 index e48a8c1..0000000 --- a/labeling/parsers/cvat_keypoints_parser.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Parser for CVAT Images 1.1 keypoint annotations - -see CVAT Documentation for format information - -This file was created by labeling the example images using the flow that is described in the repo. -Then this xml was converted to JSON using xmltodict -Then an initial version of the parser was created using https://pydantic-docs.helpmanual.io/datamodel_code_generator/: -`datamodel-codegen --input annotations.json --input-file-type json --output cvat_keypoints_parser.py --class-name CVATKeypointsParser` - -And the parser was then further finetuned. - -Note that the attributes cannot have _ as prefix for Pydantic. -""" - -# generated by datamodel-codegen: -# filename: annotations.json -# timestamp: 2022-08-24T12:15:52+00:00 -# then further changed by @tlips - -from __future__ import annotations - -from typing import Any, List, Optional, Union - -from pydantic import BaseModel - - -class Segment(BaseModel): - id: str - start: str - stop: str - url: str - - -class Segments(BaseModel): - segment: Segment - - -class Owner(BaseModel): - username: str - email: str - - -class LabelItem(BaseModel): - name: str - color: str - type: str - attributes: Any - - -class Labels(BaseModel): - label: Union[List[LabelItem], LabelItem] - - -class Task(BaseModel): - id: str - name: str - size: str - mode: str - overlap: str - bugtracker: Any - created: str - updated: str - subset: str - start_frame: str - stop_frame: str - frame_filter: Any - segments: Segments - owner: Owner - assignee: Any - labels: Labels - - -class Meta(BaseModel): - task: Task - dumped: str - - -class Point(BaseModel): - label: str - occluded: str - source: str - points: str - z_order: str - group_id: Optional[str] = "1" # set default group id to 1. - - -class ImageItem(BaseModel): - id: str - name: str - width: str - height: str - points: Optional[Union[List[Point], Point]] = None - - -class Annotations(BaseModel): - version: str - meta: Meta - image: List[ImageItem] - - -class CVATKeypointsParser(BaseModel): - annotations: Annotations diff --git a/labeling/requirements.txt b/labeling/requirements.txt deleted file mode 100644 index efed842..0000000 --- a/labeling/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -xmltodict -pydantic -tqdm diff --git a/labeling/scripts/crop_coco_dataset.py b/labeling/scripts/crop_coco_dataset.py deleted file mode 100644 index 71dc4b9..0000000 --- a/labeling/scripts/crop_coco_dataset.py +++ /dev/null @@ -1,113 +0,0 @@ -import json -import os -from argparse import ArgumentParser -from collections import defaultdict - -import albumentations as A -import cv2 - -from keypoint_detection.data.coco_dataset import COCOKeypointsDataset -from keypoint_detection.data.coco_parser import CocoKeypoints -from keypoint_detection.data.imageloader import ImageLoader, IOSafeImageLoaderDecorator - - -def save_cropped_image_and_edit_annotations( - i, image_info, image_annotations, height_new, width_new, image_loader, input_dataset_path, output_dataset_path -): - input_image_path = os.path.join(input_dataset_path, image_info.file_name) - image = image_loader.get_image(input_image_path, i) - - min_size = min(image.shape[0], image.shape[1]) - transform = A.Compose( - [ - A.CenterCrop(min_size, min_size), - A.Resize(height_new, width_new), - ], - keypoint_params=A.KeypointParams(format="xy", remove_invisible=False), - ) - - # Extract keypoints to the format albumentations wants. - image_keypoints = [] - for annotation in image_annotations: - annotation_keypoints = COCOKeypointsDataset.split_list_in_keypoints(annotation.keypoints) - for keypoint in annotation_keypoints: - image_keypoints.append(keypoint[:2]) - keypoints_xy = [keypoint[:2] for keypoint in image_keypoints] - - # Transform image and keypoints - transformed = transform(image=image, keypoints=keypoints_xy) - transformed_image = transformed["image"] - transformed_keypoints = transformed["keypoints"] - - # Edit the original keypoints. - index = 0 - for annotation in image_annotations: - for i in range(len(annotation.keypoints) // 3): - annotation.keypoints[3 * i : 3 * i + 2] = transformed_keypoints[index] - index += 1 - - # Save transformed image to disk - output_image_path = os.path.join(output_dataset_path, image_info.file_name) - image_directory = os.path.dirname(output_image_path) - os.makedirs(image_directory, exist_ok=True) - image_bgr = cv2.cvtColor(transformed_image, cv2.COLOR_RGB2BGR) - cv2.imwrite(output_image_path, image_bgr) - - -def create_cropped_dataset(input_json_dataset_path, height_new, width_new): - input_dataset_path = os.path.dirname(input_json_dataset_path) - output_dataset_path = input_dataset_path + f"_{height_new}x{width_new}" - - if os.path.exists(output_dataset_path): - print(f"{output_dataset_path} exists, quiting.") - return - - with open(input_json_dataset_path, "r") as file: - data = json.load(file) - parsed_coco = CocoKeypoints(**data) - - image_loader = IOSafeImageLoaderDecorator(ImageLoader()) - annotations = parsed_coco.annotations - - images_annotations = defaultdict(list) - for annotation in annotations: - print(type(annotation)) - images_annotations[annotation.image_id].append(annotation) - - for i, image_info in enumerate(parsed_coco.images): - image_annotations = images_annotations[image_info.id] - save_cropped_image_and_edit_annotations( - i, - image_info, - image_annotations, - height_new, - width_new, - image_loader, - input_dataset_path, - output_dataset_path, - ) - - annotations_json = os.path.join(output_dataset_path, os.path.basename(input_json_dataset_path)) - with open(annotations_json, "w") as file: - json.dump(parsed_coco.dict(exclude_none=True), file) - - return output_dataset_path - - -if __name__ == "__main__": - """ - example usage: - - python crop_coco_dataset.py datasets/towel_testset_0 256 256 - - This will create a new dataset called towel_testset_0_256x256 in the same directory as the old one. - The old dataset will be unaltered. - Currently only square outputs are supported. - """ - - parser = ArgumentParser() - parser.add_argument("input_json_dataset_path") - parser.add_argument("height_new", type=int) - parser.add_argument("width_new", type=int) - args = parser.parse_args() - create_cropped_dataset(**vars(args)) diff --git a/scripts/benchmark.py b/scripts/benchmark.py index e0e5643..cc8b859 100644 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -39,10 +39,10 @@ def benchmark(f, name=None, iters=500, warmup=20, display=True, profile=False): device = "cuda:0" backbone = "ConvNeXtUnet" - input_size = 256 + input_size = 512 backbone = BackboneFactory.create_backbone(backbone) - model = KeypointDetector(1, "2 4", 3, 3e-4, backbone, [["test"]], 1, 1, 0.0, 20) + model = KeypointDetector(1, "2 4", 3, 3e-4, backbone, [["test1"], ["test2,test3"]], 1, 1, 0.0, 20) # do not forget to set model to eval mode! # this will e.g. use the running statistics for batch norm layers instead of the batch statistics. # this is important as inference batches are typically a lot smaller which would create too much noise. @@ -52,14 +52,14 @@ def benchmark(f, name=None, iters=500, warmup=20, display=True, profile=False): sample_model_input = torch.rand(1, 3, input_size, input_size, device=device, dtype=torch.float32) sample_inference_input = np.random.randint(0, 255, (input_size, input_size, 3), dtype=np.uint8) - benchmark(lambda: model(sample_model_input), "plain model forward pass", profile=True) + benchmark(lambda: model(sample_model_input), "plain model forward pass", profile=False) benchmark( - lambda: local_inference(model, sample_inference_input, device=device), "plain model inference", profile=True + lambda: local_inference(model, sample_inference_input, device=device), "plain model inference", profile=False ) torchscript_model = model.to_torchscript() # JIT compiling with torchscript should improve performance (slightly) - benchmark(lambda: torchscript_model(sample_model_input), "torchscript model forward pass", profile=True) + benchmark(lambda: torchscript_model(sample_model_input), "torchscript model forward pass", profile=False) torch.backends.cudnn.benchmark = True model.half() @@ -67,7 +67,7 @@ def benchmark(f, name=None, iters=500, warmup=20, display=True, profile=False): half_torchscript_model = model.to_torchscript(method="trace", example_inputs=half_input) benchmark( - lambda: half_torchscript_model(half_input), "torchscript model forward pass with half precision", profile=True + lambda: half_torchscript_model(half_input), "torchscript model forward pass with half precision", profile=False ) # note: from the traces it can be seen that a lot of time is spent in 'overhead', i.e. the GPU is idle... diff --git a/scripts/benchmark_heatmap_extraction.py b/scripts/benchmark_heatmap_extraction.py new file mode 100644 index 0000000..c430b8f --- /dev/null +++ b/scripts/benchmark_heatmap_extraction.py @@ -0,0 +1,52 @@ +"""quick and dirty benchmark of the heatmap extraction methods.""" + +import time + +import torch + +from keypoint_detection.utils.heatmap import ( + generate_channel_heatmap, + get_keypoints_from_heatmap_batch_maxpool, + get_keypoints_from_heatmap_scipy, +) + + +def test_method(nb_iters, heatmaps, method, name): + n_keypoints = 20 + torch.cuda.synchronize() + t0 = time.time() + if method == get_keypoints_from_heatmap_scipy: + for i in range(nb_iters): + heatmap = heatmaps[i] + for batch in range(len(heatmap)): + for channel in range(len(heatmap[batch])): + method(heatmap[batch][channel], n_keypoints) + else: + for i in range(nb_iters): + method(heatmaps[i], n_keypoints) + torch.cuda.synchronize() + t1 = time.time() + duration = (t1 - t0) / nb_iters * 1000.0 + print(f"{duration:.3f} ms per iter for {name} method with heatmap size {heatmap_size} ") + + +if __name__ == "__main__": + nb_iters = 20 + n_channels = 2 + batch_size = 1 + n_keypoints_per_channel = 10 + print( + f"benchmarking with batch_size: {batch_size}, {n_channels} channels and {n_keypoints_per_channel} keypoints per channel" + ) + for heatmap_size in [(256, 256), (512, 256), (512, 512), (1920, 1080)]: + heatmaps = [ + generate_channel_heatmap(heatmap_size, torch.randint(0, 255, (6, 2)), 6, "cpu") + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch_size, n_channels, 1, 1) + .cuda() + for _ in range(nb_iters) + ] + + test_method(nb_iters, heatmaps, get_keypoints_from_heatmap_scipy, "scipy") + test_method(nb_iters, heatmaps, get_keypoints_from_heatmap_batch_maxpool, "torch") diff --git a/scripts/checkpoint_inference.py b/scripts/checkpoint_inference.py index 426ccf8..bc10abb 100644 --- a/scripts/checkpoint_inference.py +++ b/scripts/checkpoint_inference.py @@ -4,7 +4,7 @@ import torch from torchvision.transforms.functional import to_tensor -from keypoint_detection.utils.heatmap import get_keypoints_from_heatmap +from keypoint_detection.utils.heatmap import get_keypoints_from_heatmap_batch_maxpool from keypoint_detection.utils.load_checkpoints import get_model_from_wandb_checkpoint @@ -27,9 +27,7 @@ def local_inference(model, image: np.ndarray, device="cuda"): heatmaps = model(image).squeeze(0) # extract keypoints from heatmaps - predicted_keypoints = [ - torch.tensor(get_keypoints_from_heatmap(heatmaps[i].cpu(), 2)) for i in range(heatmaps.shape[0]) - ] + predicted_keypoints = get_keypoints_from_heatmap_batch_maxpool(heatmaps.unsqueeze(0))[0] return predicted_keypoints diff --git a/scripts/fiftyone_viewer.py b/scripts/fiftyone_viewer.py new file mode 100644 index 0000000..0830c47 --- /dev/null +++ b/scripts/fiftyone_viewer.py @@ -0,0 +1,257 @@ +"""use fiftyone to visualize the predictions of trained keypoint detectors on a dataset. Very useful for debugging and understanding the models predictions.""" +import os +from collections import defaultdict +from typing import List, Optional, Tuple + +import fiftyone as fo +import numpy as np +import torch +import tqdm + +from keypoint_detection.data.coco_dataset import COCOKeypointsDataset +from keypoint_detection.models.detector import KeypointDetector +from keypoint_detection.models.metrics import DetectedKeypoint, Keypoint, KeypointAPMetrics +from keypoint_detection.tasks.train_utils import parse_channel_configuration +from keypoint_detection.utils.heatmap import compute_keypoint_probability, get_keypoints_from_heatmap_batch_maxpool +from keypoint_detection.utils.load_checkpoints import get_model_from_wandb_checkpoint + +# TODO: can get channel config from the models! no need to specify manually +# TODO: mAP / image != mAP, maybe it is also not even the best metric to use for ordering samples .Should also log the loss / image. + + +class DetectorFiftyoneViewer: + def __init__( + self, + dataset_path: str, + models: dict[str, KeypointDetector], + channel_config: str, + detect_only_visible_keypoints: bool = False, + n_samples: Optional[int] = None, + ap_threshold_distances: Optional[List[int]] = None, + ): + self.dataset_path = dataset_path + self.models = models + self.channel_config = channel_config + self.detect_only_visible_keypoints = detect_only_visible_keypoints + self.n_samples = n_samples + self.parsed_channel_config = parse_channel_configuration(channel_config) + self.ap_threshold_distances = ap_threshold_distances + if self.ap_threshold_distances is None: + self.ap_threshold_distances = [ + 2, + ] + + self.coco_dataset = COCOKeypointsDataset( + dataset_path, self.parsed_channel_config, detect_only_visible_keypoints=detect_only_visible_keypoints + ) + + # create the AP metrics + self.ap_metrics = { + name: [KeypointAPMetrics(self.ap_threshold_distances) for _ in self.parsed_channel_config] + for name in models.keys() + } + + # set all models to eval mode to be sure. + for model in self.models.values(): + model.eval() + + self.predicted_keypoints = {model_name: [] for model_name in models.keys()} + self.gt_keypoints = [] + # {model: {sample_idx: {channel_idx: [ap_score]}} + self.ap_scores = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + + # create the fiftyone dataset + self.fo_dataset = fo.Dataset.from_dir( + dataset_type=fo.types.COCODetectionDataset, + data_path=os.path.dirname(self.dataset_path), + label_types=[], # do not load the coco annotations + labels_path=self.dataset_path, + ) + self.fo_dataset.add_dynamic_sample_fields() + self.fo_dataset = self.fo_dataset.limit(self.n_samples) + + # order of coco dataset does not necessarily match the order of the fiftyone dataset + # so we create a mapping of image paths to dataset indices + # to match fiftyone samples to coco dataset samples to obtain the GT keypoints. + self.image_path_to_dataset_idx = {} + for idx, entry in enumerate(self.coco_dataset.dataset): + image_path, _ = entry + image_path = str(self.coco_dataset.dataset_dir_path / image_path) + self.image_path_to_dataset_idx[image_path] = idx + + def predict_and_compute_metrics(self): + with torch.no_grad(): + fo_sample_idx = 0 + for fo_sample in tqdm.tqdm(self.fo_dataset): + image_path = fo_sample.filepath + image_idx = self.image_path_to_dataset_idx[image_path] + image, keypoints = self.coco_dataset[image_idx] + image = image.unsqueeze(0) + gt_keypoints = [] + for channel in keypoints: + gt_keypoints.append([[kp[0], kp[1]] for kp in channel]) + self.gt_keypoints.append(gt_keypoints) + + for model_name, model in self.models.items(): + heatmaps = model(image)[0] + # extract keypoints from heatmaps for each channel + predicted_keypoints = get_keypoints_from_heatmap_batch_maxpool(heatmaps.unsqueeze(0))[0] + predicted_keypoint_probabilities = [ + compute_keypoint_probability(heatmaps[i], predicted_keypoints[i]) for i in range(len(heatmaps)) + ] + self.predicted_keypoints[model_name].append( + [predicted_keypoints, predicted_keypoint_probabilities] + ) + + #### METRIC COMPUTATION #### + for metric in self.ap_metrics[model_name]: + metric.reset() + + for channel_idx in range(len(self.parsed_channel_config)): + metric_detected_keypoints = predicted_keypoints[channel_idx] + probabilities = predicted_keypoint_probabilities[channel_idx] + metric_detected_keypoints = [ + DetectedKeypoint(kp[0], kp[1], p) + for kp, p in zip(metric_detected_keypoints, probabilities) + ] + metric_gt_formatted_keypoints = [Keypoint(kp[0], kp[1]) for kp in gt_keypoints[channel_idx]] + self.ap_metrics[model_name][channel_idx].update( + metric_detected_keypoints, metric_gt_formatted_keypoints + ) + + for channel_idx in range(len(self.parsed_channel_config)): + self.ap_scores[model_name][fo_sample_idx].update( + {channel_idx: list(self.ap_metrics[model_name][channel_idx].compute().values())} + ) + + fo_sample_idx += 1 + + def visualize_predictions( + self, + ): + """visualize keypoint detectors on a coco dataset. Requires the coco json, thechannel config and a dict of wandb checkpoints.""" + + # add the ground truth to the dataset + for sample_idx, sample in enumerate(self.fo_dataset): + self._add_instance_keypoints_to_fo_sample( + sample, "ground_truth_keypoints", self.gt_keypoints[sample_idx], None, self.parsed_channel_config + ) + + # add the predictions to the dataset + for model_name, model in self.models.items(): + for sample_idx, sample in enumerate(self.fo_dataset): + keypoints, probabilities = self.predicted_keypoints[model_name][sample_idx] + self._add_instance_keypoints_to_fo_sample( + sample, f"{model_name}_keypoints", keypoints, probabilities, self.parsed_channel_config + ) + model_ap_scores = self.ap_scores[model_name][sample_idx] + + # log map + ap_values = np.zeros((len(self.parsed_channel_config), len(self.ap_threshold_distances))) + for channel_idx in range(len(self.parsed_channel_config)): + for max_dist_idx in range(len(self.ap_threshold_distances)): + ap_values[channel_idx, max_dist_idx] = model_ap_scores[channel_idx][max_dist_idx] + sample[f"{model_name}_keypoints_mAP"] = ap_values.mean() + sample.save() + # could do only one loop instead of two for the predictions usually, but we have to compute the GT keypoints, so we need to loop over the dataset anyway + # https://docs.voxel51.com/user_guide/dataset_creation/index.html#model-predictions + + print(self.fo_dataset) + + session = fo.launch_app(dataset=self.fo_dataset, port=5252) + session = self._configure_session_colors(session) + session.wait() + + def _configure_session_colors(self, session: fo.Session) -> fo.Session: + """ + set colors such that each model has a different color and the mAP labels have the same color as the keypoints. + """ + + # chatgpt color pool + color_pool = [ + "#FF00FF", # Neon Purple + "#00FF00", # Electric Green + "#FFFF00", # Cyber Yellow + "#0000FF", # Laser Blue + "#FF0000", # Radioactive Red + "#00FFFF", # Galactic Teal + "#FF00AA", # Quantum Pink + "#C0C0C0", # Holographic Silver + "#000000", # Abyssal Black + "#FFA500", # Cosmic Orange + ] + color_fields = [] + color_fields.append({"path": "ground_truth_keypoints", "fieldColor": color_pool[-1]}) + for model_idx, model_name in enumerate(self.models.keys()): + color_fields.append({"path": f"{model_name}_keypoints", "fieldColor": color_pool[model_idx]}) + color_fields.append({"path": f"{model_name}_keypoints_mAP", "fieldColor": color_pool[model_idx]}) + session.color_scheme = fo.ColorScheme(color_pool=color_pool, fields=color_fields) + return session + + def _add_instance_keypoints_to_fo_sample( + self, + sample, + predictions_name, + instance_keypoints: List[List[Tuple]], + keypoint_probabilities: List[List[float]], + parsed_channels: List[List[str]], + ) -> fo.Sample: + """adds the detected keypoints to the sample in the fiftyone format""" + assert len(instance_keypoints) == len(parsed_channels) + # assert instance_keypoints[0][0][0] > 1.0 # check if the keypoints are not normalized yet + fo_keypoints = [] + for channel_idx in range(len(instance_keypoints)): + channel_keypoints = instance_keypoints[channel_idx] + # normalize the keypoints to the image size + width = sample["metadata"]["width"] + height = sample["metadata"]["height"] + channel_keypoints = [[kp[0] / width, kp[1] / height] for kp in channel_keypoints] + if keypoint_probabilities is not None: + channel_keypoint_probabilities = keypoint_probabilities[channel_idx] + else: + channel_keypoint_probabilities = None + fo_keypoints.append( + fo.Keypoint( + label="=".join(parsed_channels[channel_idx]), + points=channel_keypoints, + confidence=channel_keypoint_probabilities, + ) + ) + + sample[predictions_name] = fo.Keypoints(keypoints=fo_keypoints) + sample.save() + return sample + + +import cv2 + +cv2.INTER_LINEAR +if __name__ == "__main__": + # TODO: make CLI for this -> hydra config? + checkpoint_dict = { + # "maxvit-256-flat": "tlips/synthetic-cloth-keypoints-quest-for-precision/model-5ogj44k0:v0", + # "maxvit-512-flat": "tlips/synthetic-cloth-keypoints-quest-for-precision/model-1of5e6qs:v0", + # "maxvit-pyflex-20k": "tlips/synthetic-cloth-keypoints/model-qiellxgb:v0" + # "maxvit-pyflex-512x256": "tlips/synthetic-cloth-keypoints/model-8m3z0wyo:v0", + # "maxvit-RTF-512x256" : "tlips/synthetic-cloth-keypoints/model-pzbwimqa:v0", + # "maxvit-sim-longer": "tlips/synthetic-cloth-keypoints/model-nvs1pktv:v0", + # "rtf-cv2":"tlips/synthetic-cloth-keypoints/model-xvkowjqr:v0", + # "rtf-pil":"tlips/synthetic-cloth-keypoints/model-0goi5hc7:v0", + # "sim-new-data":"tlips/synthetic-cloth-keypoints/model-axrqhql1:v0", + # "sim-40k":"tlips/synthetic-cloth-keypoints/model-yillsdva:v0" + # "purple-towel-on-white": "tlips/synthetic-cloth-keypoints-single-towel/model-pw2tsued:v0", + "purple-towel-on-white-separate": "tlips/synthetic-cloth-keypoints-single-towel/model-gl39yjtf:v0" + } + + dataset_path = "/storage/users/tlips/aRTFClothes/towels-test_resized_512x256/towels-test.json" + dataset_path = "/home/tlips/Documents/synthetic-cloth-data/synthetic-cloth-data/data/datasets/TOWEL/05-512x256-40k/annotations_val.json" + dataset_path = "/home/tlips/Documents/synthetic-cloth-data/synthetic-cloth-data/data/datasets/TOWEL/07-purple-towel-on-white/annotations_val.json" + channel_config = "corner0;corner1;corner2;corner3" + detect_only_visible_keypoints = True + n_samples = 200 + models = {key: get_model_from_wandb_checkpoint(value) for key, value in checkpoint_dict.items()} + visualizer = DetectorFiftyoneViewer( + dataset_path, models, channel_config, detect_only_visible_keypoints, n_samples, ap_threshold_distances=[4] + ) + visualizer.predict_and_compute_metrics() + visualizer.visualize_predictions() diff --git a/scripts/generate_dataset.ipynb b/scripts/generate_dataset.ipynb new file mode 100644 index 0000000..9a7b768 --- /dev/null +++ b/scripts/generate_dataset.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate a COCO keypoints dataset of black images with circles on it for integration testing of the keypoint detector. \n" + ] + }, + { + "cell_type": "code", + "execution_count": 192, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: distinctipy in /fast_storage_2/symlinked_homes/tlips/conda/.conda/envs/keypoint-detection/lib/python3.9/site-packages (1.2.2)\n", + "Requirement already satisfied: numpy in /home/tlips/.local/lib/python3.9/site-packages (from distinctipy) (1.25.2)\n" + ] + } + ], + "source": [ + "import cv2\n", + "import numpy as np \n", + "from airo_dataset_tools.data_parsers.coco import CocoKeypointAnnotation, CocoImage, CocoKeypointCategory, CocoKeypointsDataset\n", + "import pathlib\n", + "!pip install distinctipy\n", + "import distinctipy" + ] + }, + { + "cell_type": "code", + "execution_count": 193, + "metadata": {}, + "outputs": [], + "source": [ + "n_images = 500\n", + "n_categories = 2\n", + "max_category_instances_per_image = 2\n", + "\n", + "image_resolution = (128, 128)\n", + "circle_radius = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 194, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "DATA_DIR = pathlib.Path(\"./dummy_dataset\")\n", + "DATA_DIR.mkdir(exist_ok=True)\n", + "IMAGE_DIR = DATA_DIR / \"images\"\n", + "IMAGE_DIR.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 195, + "metadata": {}, + "outputs": [], + "source": [ + "categories = []\n", + "for category_idx in range(n_categories):\n", + " coco_category = CocoKeypointCategory(\n", + " id=category_idx,\n", + " name=f\"dummy{category_idx}\",\n", + " supercategory=f\"dummy{category_idx}\",\n", + " keypoints=[f\"dummy{category_idx}\"]\n", + " )\n", + " categories.append(coco_category)" + ] + }, + { + "cell_type": "code", + "execution_count": 196, + "metadata": {}, + "outputs": [], + "source": [ + "category_colors = distinctipy.get_colors(n_categories)\n", + "category_colors = [tuple([int(c * 255) for c in color]) for color in category_colors]" + ] + }, + { + "cell_type": "code", + "execution_count": 197, + "metadata": {}, + "outputs": [], + "source": [ + "coco_images = []\n", + "cococ_annotations = []\n", + "\n", + "coco_instances_coutner = 0\n", + "for image_idx in range(n_images):\n", + " img = np.zeros((image_resolution[1],image_resolution[0],3), dtype=np.uint8)\n", + " coco_images.append(CocoImage(id=image_idx, file_name=f\"images/img_{image_idx}.png\", height=image_resolution[1], width=image_resolution[0]))\n", + " for category_idx in range(n_categories):\n", + " n_instances = np.random.randint(0, max_category_instances_per_image+1)\n", + " for instance_idx in range(n_instances):\n", + " x = np.random.randint(2, image_resolution[0])\n", + " y = np.random.randint(2, image_resolution[1])\n", + " img = cv2.circle(img, (x, y), circle_radius, category_colors[category_idx], -1)\n", + " cococ_annotations.append(CocoKeypointAnnotation(\n", + " id=coco_instances_coutner,\n", + " image_id=image_idx,\n", + " category_id=category_idx,\n", + " # as in coco datasets: zero-index, INT keypoints.\n", + " # but add some random noise (simulating dataset with the exact pixel location instead of the zero-index int location)\n", + " # to test if the detector can deal with this\n", + " keypoints=[x + np.random.rand(1).item(), y + np.random.rand(1).item(), 1],\n", + " num_keypoints=1,\n", + " ))\n", + " coco_instances_coutner += 1\n", + "\n", + " cv2.imwrite(str(DATA_DIR / \"images\"/f\"img_{image_idx}.png\"), img)\n", + "\n", + "coco_dataset = CocoKeypointsDataset(\n", + " images=coco_images,\n", + " annotations=cococ_annotations,\n", + " categories=categories,\n", + ")\n", + "\n", + "with open(DATA_DIR / \"dummy_dataset.json\", \"w\") as f:\n", + " f.write(coco_dataset.json())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "keypoint-detection", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/setup.py b/setup.py index d58fcab..d6917ba 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup +from setuptools import find_packages, setup setup( name="keypoint_detection", @@ -7,11 +7,11 @@ version="1.0", description="Pytorch Models, Modules etc for keypoint detection", url="https://github.com/tlpss/keypoint-detection", - packages=["keypoint_detection", "labeling"], + packages=find_packages(exclude=("test",)), install_requires=[ "torch>=0.10", "torchvision>=0.11", - "pytorch-lightning>=1.5.10", + "pytorch-lightning>=1.5.10,<=1.9.4", # PL 2.0 has breaking changes that need to be incorporated "torchmetrics>=0.7", "wandb>=0.13.7", # artifact bug https://github.com/wandb/wandb/issues/4500 "timm>=0.6.11", # requires smallsized convnext models @@ -24,5 +24,7 @@ # for labeling package, should be moved in time to separate setup.py "xmltodict", "pydantic", + "fiftyone", ], + entry_points={"console_scripts": ["keypoint-detection = keypoint_detection.tasks.cli:main"]}, ) diff --git a/test/configuration.py b/test/configuration.py index c4f5570..72f267e 100644 --- a/test/configuration.py +++ b/test/configuration.py @@ -11,7 +11,7 @@ DEFAULT_HPARAMS = { "keypoint_channel_configuration": [["box_corner0", "box_corner1", "box_corner2", "box_corner3"], ["flap_corner0"]], - "detect_non_visible_keypoints": True, + "detect_only_visible_keypoints": False, "seed": 102, "wandb_project": "test_project", "wandb_entity": "box-manipulation", diff --git a/test/integration_test.sh b/test/integration_test.sh index 2d520f5..537c9d5 100644 --- a/test/integration_test.sh +++ b/test/integration_test.sh @@ -3,7 +3,7 @@ # Run from the repo's root folder using bash test/integration_test.sh # make sure to remove all trailing spaces from the command, as this would result in an error when using bash. -python keypoint_detection/train/train.py \ ---keypoint_channel_configuration "box_corner0= box_corner1 = box_corner2= box_corner3; flap_corner0 ; flap_corner2" \ ---json_dataset_path "test/test_dataset/coco_dataset.json" --batch_size 2 --wandb_project "keypoint-detector-integration-test" \ ---max_epochs 50 --early_stopping_relative_threshold -1.0 --log_every_n_steps 1 --accelerator="gpu" --devices 1 --precision 16 +python keypoint_detection/tasks/train.py \ +--keypoint_channel_configuration "box_corner0= box_corner1 = box_corner2= box_corner3: flap_corner0:flap_corner2" \ +--json_dataset_path "test/test_dataset/coco_dataset.json" --json_validation_dataset_path "test/test_dataset/coco_dataset.json" --batch_size 2 --wandb_project "keypoint-detector-integration-test" \ +--max_epochs 50 --early_stopping_relative_threshold -1.0 --log_every_n_steps 1 --accelerator="gpu" --devices 1 --precision 16 --augment_train diff --git a/test/test_crop_coco_dataset.py b/test/test_crop_coco_dataset.py deleted file mode 100644 index f1af738..0000000 --- a/test/test_crop_coco_dataset.py +++ /dev/null @@ -1,36 +0,0 @@ -import shutil -import unittest -from pathlib import Path - -import numpy as np - -from keypoint_detection.data.coco_dataset import COCOKeypointsDataset -from labeling.scripts.crop_coco_dataset import create_cropped_dataset - -from .configuration import DEFAULT_HPARAMS - - -class TestCropCocoDataset(unittest.TestCase): - def test_crop_coco_dataset(self): - annotations_filename = "coco_dataset.json" - input_json_dataset_path = Path(__file__).parents[0] / "test_dataset" / annotations_filename - output_dataset_path = create_cropped_dataset(input_json_dataset_path, 32, 32) - print(output_dataset_path) - - output_json_dataset_path = Path(output_dataset_path) / annotations_filename - - # Check whether the new coords are half of the old, because image resolution was halved. - channel_config = DEFAULT_HPARAMS["keypoint_channel_configuration"] - dataset_old = COCOKeypointsDataset(input_json_dataset_path, channel_config) - dataset_new = COCOKeypointsDataset(output_json_dataset_path, channel_config) - - for item_old, item_new in zip(dataset_old, dataset_new): - _, keypoint_channels_old = item_old - _, keypoint_channels_new = item_new - - for channel_old, channel_new in zip(keypoint_channels_old, keypoint_channels_new): - for keypoint_old, keypoint_new in zip(channel_old, channel_new): - print(keypoint_old, keypoint_new) - assert np.allclose(np.array(keypoint_old) / 2.0, np.array(keypoint_new)) - - shutil.rmtree(output_dataset_path) diff --git a/test/test_datamodule.py b/test/test_datamodule.py index c692bf4..8cdcde6 100644 --- a/test/test_datamodule.py +++ b/test/test_datamodule.py @@ -4,6 +4,7 @@ from test.configuration import DEFAULT_HPARAMS, TEST_PARAMS import torch +import torch.utils.data from keypoint_detection.data.datamodule import KeypointsDataModule @@ -51,29 +52,37 @@ def test_batch_format(self): self.assertIsInstance(ch2[0], torch.Tensor) def test_augmentations_result_in_different_image(self): + # get the dataset through the datamodule + # cannot use dataloader directly bc it shuffles the dataset. random.seed(2022) torch.manual_seed(2022) hparams = copy.deepcopy(DEFAULT_HPARAMS) + hparams["augment_train"] = False + module = KeypointsDataModule(**hparams) - train_dataloader = module.train_dataloader() + no_aug_train_dataloader = module.train_dataloader() + no_aug_dataset = no_aug_train_dataloader.dataset - batch = next(iter(train_dataloader)) - img, _ = batch + img, _ = no_aug_dataset[0] - hparams = copy.deepcopy(DEFAULT_HPARAMS) + # reset seeds to obtain the same dataset order + # and get the dataset again but now with augmentations + random.seed(2022) + torch.manual_seed(2022) + hparams = copy.deepcopy(hparams) hparams["augment_train"] = True - module = KeypointsDataModule(**hparams) - train_dataloader = module.train_dataloader() + aug_module = KeypointsDataModule(**hparams) + aug_train_dataloader = aug_module.train_dataloader() + aug_dataset = aug_train_dataloader.dataset - dissimilar_batches = 0 - # iterate over a few batches - # bc none of the augmentations is applied with 100% probability + dissimilar_images = 0 + # iterate a few times over the dataset to check that the augmentations are applied + # bc none of the augmentations is applied with 100% probability so some batches could be equal # and finding a seed that triggers them could change if you change the augmentations for _ in range(5): - batch = next(iter(train_dataloader)) - transformed_img, _ = batch + transformed_img, _ = aug_dataset[0] # check both images are not equal. - dissimilar_batches += 1 * (torch.linalg.norm(img - transformed_img) != 0.0) + dissimilar_images += 1 * (torch.linalg.norm(img - transformed_img) != 0.0) - self.assertTrue(dissimilar_batches > 0) + self.assertTrue(dissimilar_images > 0) diff --git a/test/test_dataset.py b/test/test_dataset.py index 4c641c5..8538f0e 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -41,7 +41,8 @@ def test_dataset(self): self.assertEqual(len(ch2), len(DEFAULT_HPARAMS["keypoint_channel_configuration"][1])) def test_non_visible_dataset(self): - self.hparams.update({"detect_non_visible_keypoints": False}) + self.hparams["json_dataset_path"] = Path(__file__).parent / "test_dataset" / "duplicate_coco_dataset.json" + self.hparams.update({"detect_only_visible_keypoints": True}) dataset = COCOKeypointsDataset(**self.hparams) # has duplicates but they are not visible (flag=1) diff --git a/test/test_detector.py b/test/test_detector.py index 5bde578..8512b0d 100644 --- a/test/test_detector.py +++ b/test/test_detector.py @@ -1,6 +1,7 @@ import os import unittest +import pytest import torch from pytorch_lightning.loggers import WandbLogger from torch import nn @@ -10,7 +11,7 @@ from keypoint_detection.models.backbones.unet import Unet from keypoint_detection.models.detector import KeypointDetector from keypoint_detection.models.metrics import KeypointAPMetric -from keypoint_detection.train.utils import create_pl_trainer +from keypoint_detection.tasks.train_utils import create_pl_trainer from keypoint_detection.utils.heatmap import create_heatmap_batch, generate_channel_heatmap from keypoint_detection.utils.load_checkpoints import load_from_checkpoint from keypoint_detection.utils.path import get_wandb_log_dir_path @@ -100,6 +101,10 @@ def test_model_init_heatmaps(self): self.assertTrue(torch.mean(heatmap).item() < 0.1) self.assertTrue(torch.var(heatmap).item() < 0.1) + # TODO: chcek if we can run it on gh actions as well. + IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" + + @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions atm") def test_checkpoint_loading(self): wandb_logger = WandbLogger(dir=get_wandb_log_dir_path(), mode="offline") diff --git a/test/test_heatmap.py b/test/test_heatmap.py index 585f5ae..903df27 100644 --- a/test/test_heatmap.py +++ b/test/test_heatmap.py @@ -3,7 +3,12 @@ import numpy as np import torch -from keypoint_detection.utils.heatmap import create_heatmap_batch, generate_channel_heatmap, get_keypoints_from_heatmap +from keypoint_detection.utils.heatmap import ( + create_heatmap_batch, + generate_channel_heatmap, + get_keypoints_from_heatmap_batch_maxpool, + get_keypoints_from_heatmap_scipy, +) class TestHeatmapUtils(unittest.TestCase): @@ -16,24 +21,35 @@ def setUp(self): def test_keypoint_generation_and_extraction(self): # test if extract(generate(keypoints)) == keypoints heatmap = generate_channel_heatmap((self.image_height, self.image_width), self.keypoints, self.sigma, "cpu") - extracted_keypoints = get_keypoints_from_heatmap(heatmap, 1) + extracted_keypoints = get_keypoints_from_heatmap_scipy(heatmap, 1) for keypoint in extracted_keypoints: self.assertTrue(keypoint in self.keypoints.tolist()) self.assertEqual((self.image_height, self.image_width), heatmap.shape) self.assertGreater(heatmap[4, 10], 0.5) - def test_extract_all_keypoints_from_heatmap(self): + def test_extract_all_keypoints_from_heatmap_scipy(self): def _test_extract_keypoints_from_heatmap(keypoints, num_keypoints): heatmap = generate_channel_heatmap((self.image_height, self.image_width), keypoints, self.sigma, "cpu") - extracted_keypoints = get_keypoints_from_heatmap(heatmap, 1, max_keypoints=num_keypoints) + extracted_keypoints = get_keypoints_from_heatmap_scipy(heatmap, 1, max_keypoints=num_keypoints) for keypoint in extracted_keypoints: self.assertTrue(keypoint in keypoints.tolist()) - keypoints = torch.randint(0, 15, (500, 2)) - _test_extract_keypoints_from_heatmap(keypoints, num_keypoints=500) + keypoints = torch.randint(0, 15, (5, 2)) + _test_extract_keypoints_from_heatmap(keypoints, num_keypoints=10) _test_extract_keypoints_from_heatmap(keypoints, num_keypoints=-1) _test_extract_keypoints_from_heatmap(keypoints, num_keypoints=np.inf) + def test_extract_keypoints_from_heatmap_maxpool(self): + def _test_extract_keypoints_from_heatmap(keypoints, num_keypoints): + heatmap = generate_channel_heatmap((self.image_height, self.image_width), keypoints, self.sigma, "cpu") + heatmap = heatmap.unsqueeze(0).unsqueeze(0) + extracted_keypoints = get_keypoints_from_heatmap_batch_maxpool(heatmap, max_keypoints=num_keypoints)[0][0] + for keypoint in extracted_keypoints: + self.assertTrue(keypoint in keypoints.tolist()) + + keypoints = torch.randint(0, 15, (5, 2)) + _test_extract_keypoints_from_heatmap(keypoints, num_keypoints=10) + def test_empty_heatmap(self): # test if heatmap for channel w/o keypoints is created correctly heatmap = generate_channel_heatmap(