Hi, everyone！
This kernel is based on the work [WBF over TTA](https://www.kaggle.com/shonenkov/wbf-over-tta-single-model-efficientdet) by [Alex Shonenkov](https://www.kaggle.com/shonenkov). Thanks to him for sharing!

I just added one more transform, and it gives me a slight improvement.

Hope you enjoy it！

In [None]:
import albumentations as A

class BaseWheatTTA:
    """Common interface for the test-time-augmentation (TTA) transforms.

    author: @shonenkov

    Every hook defined here only raises ``NotImplementedError``; concrete
    transforms are expected to override all three.
    """

    # Side length that subclasses subtract from when mirroring box coordinates.
    image_size = 512

    def augment(self, image):
        """Transform a single image. Not implemented on the base class."""
        raise NotImplementedError()

    def batch_augment(self, images):
        """Transform a batch of images. Not implemented on the base class."""
        raise NotImplementedError()

    def deaugment_boxes(self, boxes):
        """Undo the transform on predicted boxes. Not implemented on the base class."""
        raise NotImplementedError()
        
class TTAVerticalFlip(BaseWheatTTA):
    """Vertical-flip TTA.

    author: @shonenkov
    """

    def augment(self, image):
        """Return ``image`` flipped along dimension 1.

        Presumably the height axis of a (C, H, W) tensor -- confirm with caller.
        """
        return image.flip(1)

    def batch_augment(self, images):
        """Return ``images`` flipped along dimension 2.

        Presumably the height axis of a (B, C, H, W) tensor -- confirm with caller.
        """
        return images.flip(2)

    def deaugment_boxes(self, boxes):
        """Mirror columns 1 and 3 of ``boxes`` about ``image_size``.

        The result is written into ``boxes`` itself, which is also returned.
        """
        size = self.image_size
        # Compute both mirrored columns before writing either one back; the
        # two columns trade places so their relative order is preserved.
        new_col1 = size - boxes[:, 3]
        new_col3 = size - boxes[:, 1]
        boxes[:, 1] = new_col1
        boxes[:, 3] = new_col3
        return boxes
    
class TTAHorizontalFlip(BaseWheatTTA):
    """Horizontal-flip TTA.

    author: @shonenkov
    """

    def augment(self, image):
        """Return ``image`` flipped along dimension 2.

        Presumably the width axis of a (C, H, W) tensor -- confirm with caller.
        """
        return image.flip(2)

    def batch_augment(self, images):
        """Return ``images`` flipped along dimension 3.

        Presumably the width axis of a (B, C, H, W) tensor -- confirm with caller.
        """
        return images.flip(3)

    def deaugment_boxes(self, boxes):
        """Mirror columns 0 and 2 of ``boxes`` about ``image_size``.

        The result is written into ``boxes`` itself, which is also returned.
        """
        size = self.image_size
        # Compute both mirrored columns before writing either one back; the
        # two columns trade places so their relative order is preserved.
        new_col0 = size - boxes[:, 2]
        new_col2 = size - boxes[:, 0]
        boxes[:, 0] = new_col0
        boxes[:, 2] = new_col2
        return boxes
    
class TTARotate90(BaseWheatTTA):
    """90-degree rotation TTA built on ``torch.rot90``.

    author: @shonenkov
    """

    def augment(self, image):
        """Rotate one image once by 90 degrees in the plane of dims (1, 2)."""
        return torch.rot90(image, 1, (1, 2))

    def batch_augment(self, images):
        """Rotate a batch once by 90 degrees in the plane of dims (2, 3)."""
        return torch.rot90(images, 1, (2, 3))

    def deaugment_boxes(self, boxes):
        """Map boxes predicted on the rotated image back to the original frame.

        Args:
            boxes: 2-D array with at least four columns; it must provide
                ``.copy()``. Presumably an (N, 4) numpy array laid out as
                [x1, y1, x2, y2] -- confirm with caller.

        Returns:
            A new array (the input is left untouched) in which columns 0/2
            hold ``image_size`` minus the input's columns 3/1, and columns
            1/3 hold the input's columns 0/2.
        """
        res_boxes = boxes.copy()
        # New x = image_size - old y. The y columns are read in reverse order
        # (3, 1) so that an input with y1 <= y2 yields column 0 <= column 2.
        res_boxes[:, [0, 2]] = self.image_size - boxes[:, [3, 1]]
        # New y = old x, kept in the same order so column 1 <= column 3
        # whenever the input has x1 <= x2. Together with the line above this
        # keeps the corner ordering intact, matching the flip transforms,
        # instead of returning boxes with both corners swapped.
        res_boxes[:, [1, 3]] = boxes[:, [0, 2]]
        return res_boxes
    
class TTAHSV_or_RBC(BaseWheatTTA):
    """Colour TTA: HueSaturationValue (HSV) or RandomBrightnessContrast (RBC).

    The transform does not move pixels, so ``deaugment_boxes`` returns the
    boxes unchanged.

    NOTE(review): ``ToTensorV2`` is not imported in the visible part of this
    file; presumably it comes from ``albumentations.pytorch`` in another
    cell -- confirm.
    """

    # Shared by all instances. Wraps the two colour transforms in ``A.OneOf``;
    # see the albumentations docs for how ``p`` is interpreted there.
    transform = A.OneOf([
                A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit= 0.2, 
                                     val_shift_limit=0.2,p=0.5),
                A.RandomBrightnessContrast(brightness_limit=0.2, 
                                           contrast_limit=0.2, p=0.5),
            ],p=1)

    def augment(self, image):
        """Apply the colour transform to a single image tensor.

        Args:
            image: tensor whose dims are permuted with (1, 2, 0) before being
                handed to albumentations. Presumably (C, H, W) -- confirm.

        Returns:
            The transformed image as a tensor on the same device as ``image``.
        """
        # Remember where the input lives: the numpy round-trip below goes
        # through the CPU, and the sibling transforms in this file all return
        # tensors on the input's device.
        source_device = image.device
        array = image.permute(1, 2, 0).cpu().numpy()
        array = self.transform(image=array)['image']
        tensor = ToTensorV2()(image=array)['image']
        return tensor.to(source_device)

    def batch_augment(self, images):
        """Apply ``augment`` to every item of ``images`` in place.

        Each item is overwritten with its augmented version and the same
        ``images`` object is returned.
        """
        for index in range(len(images)):
            images[index] = self.augment(images[index])
        return images

    def deaugment_boxes(self, boxes):
        """Return ``boxes`` unchanged."""
        return boxes
    
class TTACompose(BaseWheatTTA):
    """Chain several TTA transforms into a single one.

    author: @shonenkov
    """

    def __init__(self, transforms):
        """Store ``transforms``, the ordered sequence of TTA objects to chain."""
        self.transforms = transforms

    def augment(self, image):
        """Run ``image`` through every transform's ``augment`` in order."""
        current = image
        for step in self.transforms:
            current = step.augment(current)
        return current

    def batch_augment(self, images):
        """Run ``images`` through every transform's ``batch_augment`` in order."""
        current = images
        for step in self.transforms:
            current = step.batch_augment(current)
        return current

    def prepare_boxes(self, boxes):
        """Return a copy of ``boxes`` with each corner pair sorted.

        Columns 0/2 become the row-wise min/max of the input's columns 0 and
        2; columns 1/3 become the row-wise min/max of columns 1 and 3.
        """
        first_pair = boxes[:, [0, 2]]
        second_pair = boxes[:, [1, 3]]
        ordered = boxes.copy()
        ordered[:, 0] = np.min(first_pair, axis=1)
        ordered[:, 2] = np.max(first_pair, axis=1)
        ordered[:, 1] = np.min(second_pair, axis=1)
        ordered[:, 3] = np.max(second_pair, axis=1)
        return ordered

    def deaugment_boxes(self, boxes):
        """Undo every transform, last applied first, then sort the corners."""
        current = boxes
        for step in reversed(self.transforms):
            current = step.deaugment_boxes(current)
        return self.prepare_boxes(current)