
Add SimpleCopyPaste augmentation #5825

Merged · 28 commits into pytorch:main · Jun 15, 2022

Conversation

@lezwon (Contributor) commented Apr 18, 2022

This PR is related to #3817. It implements the SimpleCopyPaste augmentation for segmentation tasks.
https://arxiv.org/abs/2012.07177

@datumbox (Contributor)

@lezwon Thanks for kicking this off. Just a quick note: I would recommend moving this transform to detection instead of segmentation. MaskRCNN is actually run through the detection pipeline (I know it's confusing).

@lezwon (Contributor, Author) commented Apr 19, 2022

@lezwon Thanks for kicking this off. Just a quick note: I would recommend moving this transform to detection instead of segmentation. MaskRCNN is actually run through the detection pipeline (I know it's confusing).

Sure will do that 👍

@lezwon (Contributor, Author) commented Apr 22, 2022

Hey @vadimkantorov @datumbox, I've moved the augmentation to the detection module. It is a basic functioning POC right now. It would be really nice to get some early feedback from y'all :) I have attached some samples for your reference.

[Attached samples: image_0_a, image_0_mask, image_1_a, image_1_mask]

@lezwon force-pushed the transforms/simplecopypaste branch 2 times, most recently from 7c58731 to 2782e89 on May 4, 2022
@lezwon (Contributor, Author) commented May 4, 2022

@datumbox are tests/docs necessary for this PR?

@lezwon marked this pull request as ready for review on May 5, 2022 05:27
@datumbox (Contributor) commented May 9, 2022

@lezwon Apologies for the delayed response. I was OOO and trying to catch up. Let me see if I can find someone who could support you on the remaining bit.

No need for docs/tests etc. at this point, especially since it's in the references and the implementation won't appear in main TorchVision (for now). Let's focus on completing the implementation, verifying it works as expected, reviewing the API, and training a model with it.

@vfdev-5 (Collaborator) commented May 12, 2022

@lezwon I was checking the paper, the official TF implementation, and an unofficial PyTorch one.

In our case, let's proceed in the following way. Originally, the copy-paste transform works at the dataset level and mixes the current image with a paste image taken randomly from the dataset. To avoid the complications of dealing with datasets, we can start by copy/pasting data at the batch level. For the detection recipe, images and targets are a tuple of tensors and a tuple of dicts (key -> Tensor). The CopyPaste transform should work at that level instead of on a single image/target pair. Thus, we have to add it to the DataLoader when collating data:

import torch

copypaste = CopyPaste()

def copypaste_collate_fn(batch):
    # utils_collate_fn is the default detection collate function
    # (references/detection/utils.py: collate_fn)
    return copypaste(*utils_collate_fn(batch))

# dataset, train_batch_sampler and args come from the detection train.py script
data_loader = torch.utils.data.DataLoader(
    dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=copypaste_collate_fn
)

Here is an unfinished implementation that we could update in order to get this PR merged:

from torch import nn

def copy_paste(image, target, paste_image, paste_target):
    # implement copy/paste logic
    return paste_image, paste_target

class CopyPaste(nn.Module):
    
    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace
    
    def forward(self, images, targets=None):
        assert targets is not None

        # images is expected to be a tuple of Tensors
        # targets is expected to be a tuple of dicts of key -> Tensor

        if not self.inplace:
            # clone images and targets here
            pass
                
        # images = [t1, t2, ..., tN]
        # Let's define paste_images as shifted list of input images
        # in TF they mix data on the dataset level
        # paste_images = [t2, t3, ..., tN, t1]
        images = list(images)
        targets = list(targets)
        shift = 1
        le = len(images)        
        
        out_images = []
        out_targets = []
        for i in range(le):
            image, target = images[i], targets[i]
            paste_image, paste_target = images[(i + shift) % le], targets[(i + shift) % le]
            image, target = copy_paste(image, target, paste_image, paste_target)
            out_images.append(image)
            out_targets.append(target)
        
        return tuple(out_images), tuple(out_targets)
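
For context, a rough sketch of what the copy_paste stub could do, loosely following the paper (random instance selection, mask composition, annotation merge). The tensor names and the 50% selection ratio are illustrative assumptions, not the final code of this PR:

import torch

def copy_paste(image, target, paste_image, paste_target):
    # Pick a random subset of instances from the paste sample (illustrative: roughly half).
    num_paste = len(paste_target["masks"])
    keep = torch.randperm(num_paste)[: max(1, num_paste // 2)]

    paste_masks = paste_target["masks"][keep].bool()    # (K, H, W)
    paste_boxes = paste_target["boxes"][keep]
    paste_labels = paste_target["labels"][keep]

    # Union of the selected paste masks defines the region that gets overwritten.
    paste_alpha = paste_masks.any(dim=0)                 # (H, W)

    # Paste the selected pixels onto the destination image (broadcast over channels).
    out_image = torch.where(paste_alpha, paste_image, image)

    # Occlude existing instances under the pasted region.
    masks = target["masks"].bool() & ~paste_alpha

    # Merge annotations; a full implementation would also recompute boxes from the
    # occluded masks and drop instances that became empty.
    out_target = {
        "masks": torch.cat([masks, paste_masks]),
        "boxes": torch.cat([target["boxes"], paste_boxes]),
        "labels": torch.cat([target["labels"], paste_labels]),
    }
    return out_image, out_target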

Let me know what you think.

@datumbox (Contributor)

@lezwon Let us know if what @vfdev-5 proposed works out for you or if you need additional help. Thanks! 😃

@lezwon (Contributor, Author) commented May 17, 2022

@datumbox @vfdev-5 Really sorry for the delay, I've been a bit occupied lately. I'll get back to this PR by Thursday, is that okay? :)

@lezwon (Contributor, Author) commented May 20, 2022

Hi @vfdev-5, I have a few doubts and questions:

  1. My current implementation takes the images in as a single batched tensor and the targets as a list of dicts. I based the implementation on [RandomMixup](https://github.com/pytorch/vision/blob/289fce29b3e2392114aadbe7a419df0f2e3ac1be/references/classification/transforms.py#L9) in the classification transforms, as I felt it would be much faster during processing.
  2. Your implementation iterates through all the image-target pairs and passes them to copy_paste, then saves the outputs to out_images and out_targets. Isn't running copy_paste() for each tensor separately slower? Can't we update the entire batch in one operation?
  3. In your implementation the output is returned by copy_paste(). I don't suppose the original images/targets are being changed in place here, right? That would nullify the purpose of self.inplace.
  4. If the changes are done in place within copy_paste(), we would have to clone paste_image and paste_target, as we don't want to pick up an image which has already been modified by copy_paste() (see the sketch after this list).

My current implementation updates the targets individually first and then updates the entire image batch, as I felt it would be more efficient. Please let me know if this is incorrect and needs improvement.
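
To illustrate point 4, a small hypothetical example (fake_copy_paste_inplace is just a stand-in): with shift=1 and in-place pasting, the last sample would otherwise read images[0] after it has already been modified, so the paste sources need to be cloned before any in-place edit:

import torch

images = [torch.zeros(3, 4, 4) for _ in range(3)]

# Snapshot the paste sources before any in-place modification.
paste_sources = [img.clone() for img in images]

def fake_copy_paste_inplace(image, paste_image):
    # Stand-in for a real in-place paste: just marks the image as modified.
    image.add_(paste_image + 1)
    return image

n = len(images)
for i in range(n):
    fake_copy_paste_inplace(images[i], paste_sources[(i + 1) % n])

# Every image received exactly one paste. Without the clone, iteration i=2
# would have pasted the already-augmented images[0] instead of the original one.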

@lezwon (Contributor, Author) commented Jun 10, 2022

@lezwon if you would like to check this version of the code, run it on your examples and comment if something is missing, that would be great. Thanks!

@vfdev-5 just ran it on some sample images. It's looking great so far. Will review it in detail and point out any issues if present. :)

@vfdev-5 (Collaborator) commented Jun 10, 2022

Will review it in detail and point out any issues if present. :)

Thanks @lezwon !

Meanwhile, I changed the previous cropping strategy to resizing the data to paste when the sizes are different. This way the input image keeps the same size after pasting data into it. Previously, a common minimal size was found, which could cause a few problems with cropped targets and could also change the input size.
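
A hedged sketch of that resize idea (using torchvision.transforms.functional; the helper name, the (x1, y1, x2, y2) box layout and the nearest-neighbour mask interpolation are assumptions on my side, not the exact PR code):

import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

def resize_paste_data(paste_image, paste_masks, paste_boxes, dst_size):
    # dst_size = (H, W) of the destination image the data will be pasted into.
    h_dst, w_dst = dst_size
    _, h_src, w_src = paste_image.shape

    paste_image = F.resize(paste_image, [h_dst, w_dst], interpolation=InterpolationMode.BILINEAR)
    # Nearest interpolation keeps the masks binary.
    paste_masks = F.resize(paste_masks, [h_dst, w_dst], interpolation=InterpolationMode.NEAREST)

    # Rescale boxes (x1, y1, x2, y2) to the new resolution.
    scale = paste_boxes.new_tensor([w_dst / w_src, h_dst / h_src, w_dst / w_src, h_dst / h_src])
    paste_boxes = paste_boxes * scale
    return paste_image, paste_masks, paste_boxes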

@vfdev-5 (Collaborator) commented Jun 11, 2022

Right now, I've encountered an issue in the CopyPaste code with some batches during training:

Original Traceback (most recent call last):                                                                                                                                       
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop                                                                      
    data = fetcher.fetch(index)                                                                                                                                                   
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch                                                                               
    return self.collate_fn(data)                                                                                                                                                  
  File "train.py", line 193, in copypaste_collate_fn                                                                                                                              
    return copypaste(*utils.collate_fn(batch))                                                                                                                                    
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl                                                                              
    return forward_call(*input, **kwargs)                                                                                                                                         
  File "/vision/references/detection/transforms.py", line 551, in forward                                                                                                         
    output_image, output_data = _copy_paste(image, target, paste_image, paste_target, self.inplace)                                                                               
  File "/vision/references/detection/transforms.py", line 497, in _copy_paste                                                                                                     
    iscrowd = target["iscrowd"][non_all_zero_masks]                                                                                                                               
IndexError: The shape of the mask [14] at index 0 does not match the shape of the indexed tensor [15] at index 0 

Something is not aligned, and it's rather complicated to reproduce due to randomness. Let's see if we can make it a bit more bulletproof.

@lezwon (Contributor, Author) commented Jun 11, 2022

Right now, I've encountered an issue in the CopyPaste code with some batches during training:

    IndexError: The shape of the mask [14] at index 0 does not match the shape of the indexed tensor [15] at index 0

Ah yes, I've faced this issue. The sizes of iscrowd and area do not match those of masks and boxes for some images. I'm not sure of the reason for it in the COCO dataset.
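
A minimal sketch of the kind of defensive check that helps here (the keys and the length guard are my assumptions, not necessarily the exact fix that landed in the PR):

import torch

def _filter_valid_instances(target):
    # Drop instances whose mask is entirely empty, but only index the per-instance
    # fields whose length actually matches the masks; some COCO samples have
    # iscrowd/area arrays that are inconsistent with masks/boxes.
    non_all_zero_masks = target["masks"].sum(dim=(-1, -2)) > 0
    num_instances = non_all_zero_masks.shape[0]
    for key in ("masks", "boxes", "labels", "iscrowd", "area"):
        if key in target and len(target[key]) == num_instances:
            target[key] = target[key][non_all_zero_masks]
    return target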

@datumbox (Contributor) left a comment

@vfdev-5 looks good. Just minor comments, let me know your thoughts.

@datumbox (Contributor) left a comment

Thanks a lot! Last set of nits, promise.

@datumbox (Contributor) left a comment

@lezwon Thank you very, very much for this piece of work. I'll merge and immediately start investigating whether this strategy improves the accuracy of the FasterRCNN models. I'll get back to you as soon as I have results.

@vfdev-5 thanks a lot for the reviews and final touches on the PR.

@datumbox (Contributor)

BTW the failing tests are unrelated and they are tracked separately on a different issue. Merging.

@datumbox datumbox merged commit bbc1aac into pytorch:main Jun 15, 2022
NicolasHug added a commit to NicolasHug/vision that referenced this pull request Jun 16, 2022
Summary:
* added simple POC

* added jitter and crop options

* added references

* moved simplecopypaste to detection module

* working POC for simple copy paste in detection

* added comments

* remove transforms from class
updated the labels
added gaussian blur

* removed loop for mask calculation

* replaced Gaussian blur with functional api

* added inplace operations

* added changes to accept tuples instead of tensors

* - make copy paste functional
- make only one copy of batch and target

* add inplace support within copy paste functional

* Updated code for copy-paste transform

* Fixed code formatting

* [skip ci] removed manual thresholding

* Replaced cropping by resizing data to paste

* Removed inplace arg (as useless) and put a check on iscrowd target

* code-formatting

* Updated copypaste op to make it torch scriptable
Added fallbacks to support LSJ

* Fixed flake8

* Updates according to the review

Differential Revision: D37212651

fbshipit-source-id: 467b670164150dd5cc424f4d616d436295ce818d

Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
facebook-github-bot pushed a commit that referenced this pull request Jun 17, 2022

Reviewed By: datumbox

Differential Revision: D37212651

fbshipit-source-id: 8bb4eb613d44071d381da19c030f6c63278c3815

Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
@datumbox (Contributor) commented Jun 21, 2022

Due to the way the technique combines images within batches, the mini-batch size is expected to affect performance. I was expecting that a naive config with mini-batch-size=2 would be very detrimental, but actually it doesn't hurt much. I'm currently running larger batch sizes, where the method makes more sense, and I'll update the results below as I get them. I will also try the recipe described in the paper.

Below I summarize the training results of using the specific augmentation.

  • mini-batch-size=2: standard recipe
    • FasterRCNN-v2: box mAP: 46.8 (internal job ids 16740 & 17380)
    • MaskRCNN-v2: box mAP: 47.1, mask mAP: 41.5 (internal job ids 16732 & 17040)
  • mini-batch-size=4: standard recipe
    • FasterRCNN-v2: box mAP: 46.6 (internal job id 19452)
    • MaskRCNN-v2: box mAP: 47.2, mask mAP: 41.7 (internal job id 19454)
  • paper recipe: --ngpus 8 --nodes 4 --dataset coco --epochs 600 --lr-steps 540 570 585 --lr 0.32 --batch-size 8 --weight-decay 0.00004 --sync-bn --data-augmentation lsj --use-copypaste
    • FasterRCNN-v2: box mAP: 46.5 (internal job id 22643)
    • MaskRCNN-v2: box mAP: 47.3, mask mAP: 41.7 (internal job id 22644)

@datumbox (Contributor)

@lezwon @vfdev-5 I've completed the training of the *RCNN models with SimpleCopyPaste (see the updated previous comment). There are some extremely small improvements, which could be the result of running the training for much longer. Given this, I don't think it's worth releasing updated weights for them. It might be worth rechecking once this implementation makes it to the new Transforms API, but for now I think we can conclude the work.
