option to start training from checkpoint
tlpss committed Oct 20, 2023
1 parent 2822dc5 commit 66e9d5e
Showing 3 changed files with 36 additions and 8 deletions.
12 changes: 8 additions & 4 deletions README.md
@@ -37,18 +37,22 @@ TODO: add integration example.
To train a keypoint detector, run the `keypoint-detection train` CLI with the appropriate arguments.
To create your own configuration: run `keypoint-detection train -h` to see all parameter options and their documentation.

A starting point could be the bash script `bash test/integration_test.sh` to test on the provided test dataset, which contains 4 images. You should see the loss going down consistently until the detector has completely overfit the train set and the loss is around the entropy of the ground truth heatmaps (if you selected the default BCE loss).
A good starting point could be the bash script `bash test/integration_test.sh` to test on the provided test dataset, which contains 4 images. You should see the loss going down consistently until the detector has completely overfit the train set and the loss is around the entropy of the ground truth heatmaps (if you selected the default BCE loss).

### Wandb sweeps
Alternatively, you can create a sweep on [wandb](https://wandb.ai) and then start one or more wandb agents. This is very useful for running multiple configurations (hyperparameter search, testing on multiple datasets, ...).
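
Below is a minimal, hedged sketch of how such a sweep could be created programmatically; the metric and parameter names are hypothetical placeholders and should be replaced by the `keypoint-detection train` arguments and logged metrics you actually want to sweep over. Sweeps can equally be defined in a YAML file and created with the `wandb sweep` CLI.

```python
# Hedged sketch: create a wandb sweep over (hypothetical) training hyperparameters.
# A command-based sweep additionally needs a "program"/"command" entry so the agent
# knows how to launch `keypoint-detection train` with the sampled arguments.
import wandb

sweep_configuration = {
    "method": "random",  # or "grid" / "bayes"
    "metric": {"name": "validation/epoch_loss", "goal": "minimize"},  # hypothetical metric name
    "parameters": {
        "learning_rate": {"min": 1e-4, "max": 1e-2},  # hypothetical parameter names; they must
        "heatmap_sigma": {"values": [1, 2, 4]},       # match the train CLI arguments
    },
}

sweep_id = wandb.sweep(sweep=sweep_configuration, project="keypoint-detection")
print(f"start agents with: wandb agent <entity>/keypoint-detection/{sweep_id}")
```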

### Loading pretrained weights
If you want to load pretrained keypoint detector weights, you can specify the wandb artifact of the checkpoint in the training parameters: `keypoint-detection train ..... --wandb_checkpoint_artifact <artifact-path>`. This can be used, for example, to finetune on real data after pretraining on synthetic data.
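
Under the hood this relies on `get_model_from_wandb_checkpoint`, which can also be used directly; a minimal sketch (the artifact path is a placeholder):

```python
# Minimal sketch: fetch a trained KeypointDetector from a wandb checkpoint artifact.
from keypoint_detection.utils.load_checkpoints import get_model_from_wandb_checkpoint

model = get_model_from_wandb_checkpoint("<entity>/<project>/model-<run-id>:v0")
print(model.hparams)  # the hyperparameters stored in the checkpoint
```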

## Dataset

This package uses the [COCO format](https://cocodataset.org/#format-data) for keypoint annotation and expects a dataset with the following structure:
```
dataset/
  images/
    ...
  <name>.json : a COCO-formatted keypoint annotation file.
  <name>.json : a COCO-formatted keypoint annotation file with filepaths relative to its parent directory.
```
For an example, see the `test_dataset` at `test/test_dataset`.
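
For reference, a sketch of what a minimal COCO-style annotation file could look like (the category, keypoint names and values are purely illustrative):

```python
# Illustrative sketch of the expected COCO keypoint annotation structure.
import json

annotations = {
    "images": [{"id": 1, "file_name": "images/0001.jpg", "width": 640, "height": 480}],
    "categories": [
        {"id": 1, "name": "tshirt", "keypoints": ["neck", "left_shoulder", "right_shoulder"]}
    ],
    "annotations": [
        {
            "id": 1,
            "image_id": 1,
            "category_id": 1,
            # keypoints are stored as [x1, y1, v1, x2, y2, v2, ...] with v the COCO visibility flag
            "keypoints": [320, 100, 2, 250, 150, 2, 390, 150, 1],
            "num_keypoints": 3,
        }
    ],
}

with open("dataset/annotations.json", "w") as f:
    json.dump(annotations, f)
```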

@@ -66,7 +70,7 @@ TODO
TODO
`scripts/fiftyone_viewer`

## Using a trained model (Inference)
## Using a trained model for Inference
During training, PyTorch Lightning saves checkpoints. See `scripts/checkpoint_inference.py` for a simple example of running inference with a checkpoint.
For benchmarking inference (or training), see `scripts/benchmark.py`.
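
A minimal inference sketch, assuming a local Lightning checkpoint and that the detector consumes float image tensors of shape `(B, 3, H, W)` and returns per-channel keypoint heatmaps (see `scripts/checkpoint_inference.py` for the reference implementation):

```python
# Hedged sketch: run a forward pass with a KeypointDetector restored from a local checkpoint.
import torch

from keypoint_detection.utils.load_checkpoints import load_from_checkpoint

model = load_from_checkpoint("path/to/model.ckpt")
model.eval()

image = torch.rand(1, 3, 256, 256)  # placeholder batch of one random image
with torch.no_grad():
    heatmaps = model(image)
print(heatmaps.shape)
```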

@@ -86,7 +90,7 @@ In general a lower threshold will result in a lower metric. The size of this gap
#TODO: add a figure to illustrate this.


We do not use OKS as in COCO for 2 reasons:
We do not use OKS as in COCO for the following reasons (see the definition below):
1. It requires bbox annotations, which are not always needed for keypoint detection itself and represent additional labeling effort.
2. More importantly, in robotics the size of an object does not always correlate with the required precision. If a large and a small mug stand on a table, they require equally precise localisation of keypoints for a robot to grasp them, even though their apparent sizes differ.
3. It requires estimating the per-keypoint label variance, though you could simply set k=1 and skip this part.
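
For reference, the COCO OKS for a single annotated instance is defined as

$$
\mathrm{OKS} = \frac{\sum_i \exp\left(-\frac{d_i^2}{2 s^2 k_i^2}\right)\,\delta(v_i > 0)}{\sum_i \delta(v_i > 0)}
$$

where $d_i$ is the distance between the predicted and ground-truth keypoint $i$, $v_i$ its visibility flag, $s$ the object scale derived from the annotated area (which is where the bbox/segmentation requirement of reasons 1 and 2 comes in), and $k_i$ a per-keypoint constant estimated from the labeling variance (reason 3).
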
20 changes: 18 additions & 2 deletions keypoint_detection/tasks/train.py
@@ -11,6 +11,7 @@
from keypoint_detection.models.backbones.backbone_factory import BackboneFactory
from keypoint_detection.models.detector import KeypointDetector
from keypoint_detection.tasks.train_utils import create_pl_trainer, parse_channel_configuration
from keypoint_detection.utils.load_checkpoints import get_model_from_wandb_checkpoint
from keypoint_detection.utils.path import get_wandb_log_dir_path


@@ -49,6 +50,12 @@ def add_system_args(parent_parser: ArgumentParser) -> ArgumentParser:
help="do not use deterministic algorithms for pytorch. This can speed up training, but will make it non-reproducible.",
)

    parser.add_argument(
        "--wandb_checkpoint_artifact",
        type=str,
        help="A checkpoint to resume/start training from. Keep in mind that you currently cannot specify hyperparameters other than the LR.",
        required=False,
    )
    parser.set_defaults(deterministic=True)
    return parent_parser

@@ -63,9 +70,18 @@ def train(hparams: dict) -> Tuple[KeypointDetector, pl.Trainer]:

    # use deterministic algorithms for torch to ensure exact reproducibility
    # we have to set it in the trainer! (see create_pl_trainer)
    if "wandb_checkpoint_artifact" in hparams.keys():
        print("Loading checkpoint from wandb")
        # This will create a KeypointDetector model with the associated hyperparameters.
        # Model weights will be loaded.
        # Optimizer and LR scheduler will be initialized from scratch (if you want to truly resume training, you have to pass the checkpoint to the trainer)
        # cf. https://lightning.ai/docs/pytorch/latest/common/checkpointing_basic.html#lightningmodule-from-checkpoint
        model = get_model_from_wandb_checkpoint(hparams["wandb_checkpoint_artifact"])
        # TODO: how can specific hparams be overwritten here? e.g. LR reduction for finetuning?
    else:
        backbone = BackboneFactory.create_backbone(**hparams)
        model = KeypointDetector(backbone=backbone, **hparams)

    backbone = BackboneFactory.create_backbone(**hparams)
    model = KeypointDetector(backbone=backbone, **hparams)
    data_module = KeypointsDataModule(**hparams)
    wandb_logger = WandbLogger(
        project=hparams["wandb_project"],
12 changes: 10 additions & 2 deletions keypoint_detection/utils/load_checkpoints.py
@@ -15,14 +15,17 @@ def get_model_from_wandb_checkpoint(checkpoint_reference: str):
    import wandb

    # download checkpoint locally (if not already cached)
    run = wandb.init(project="inference")
    if wandb.run is None:
        run = wandb.init(project="inference")
    else:
        run = wandb.run
    artifact = run.use_artifact(checkpoint_reference, type="model")
    artifact_dir = artifact.download()
    checkpoint_path = Path(artifact_dir) / "model.ckpt"
    return load_from_checkpoint(checkpoint_path)


def load_from_checkpoint(checkpoint_path: str):
def load_from_checkpoint(checkpoint_path: str, hparams_to_override: dict = None):
"""
function to load a Keypoint Detector model from a local pytorch lightning checkpoint.
@@ -43,3 +46,8 @@ def load_from_checkpoint(checkpoint_path: str):
    backbone = BackboneFactory.create_backbone(**checkpoint["hyper_parameters"])
    model = KeypointDetector.load_from_checkpoint(checkpoint_path, backbone=backbone)
    return model


if __name__ == "__main__":
model = get_model_from_wandb_checkpoint("tlips/synthetic-cloth-keypoints-tshirts/model-4um302zo:v0")
print(model.hparams)
