diff --git a/references/detection/train.py b/references/detection/train.py index 4cad35ccb0a..f56ac66881c 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -35,6 +35,11 @@ from transforms import SimpleCopyPaste +def copypaste_collate_fn(batch): + copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR) + return copypaste(*utils.collate_fn(batch)) + + def get_dataset(name, image_set, transform, data_path): paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} p, ds_fn, num_classes = paths[name] @@ -194,11 +199,6 @@ def main(args): if args.data_augmentation != "lsj": raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies") - copypaste = SimpleCopyPaste(resize_interpolation=InterpolationMode.BILINEAR, blending=True) - - def copypaste_collate_fn(batch): - return copypaste(*utils.collate_fn(batch)) - train_collate_fn = copypaste_collate_fn data_loader = torch.utils.data.DataLoader(