# About this notebook

In this notebook I'm presenting a simplified version of this competition: how to match an image with a subcrop of itself.

My motivations are explained here: https://www.kaggle.com/competitions/image-matching-challenge-2022/discussion/323403. My final goal is much more complex, but I first need to solve the basics.

# QuadTreeAttention

This notebook shows how to use a pretrained version of QuadTreeAttention: https://arxiv.org/abs/2201.02767

# Updated version (version 4)

Training now works.

# Weird behaviors

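- Pretrained or fine-tuned models are good at detecting crops, but they stop working as soon as a 90° rotation is applied between the two images (in this notebook the "rotation" is implemented as a transpose, i.e. a 90° rotation combined with a mirror flip).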

In [None]:
# Third-party dependencies for kornia/LoFTR and the QuadTree FeatureMatching code.
# NOTE(review): kornia is upgraded without a version pin *before* torch is pinned to
# 1.8.2+cu102; confirm the kornia release that ends up installed still supports torch 1.8.
# NOTE(review): pytorch_lightning and albumentations (imported later) are not installed
# here — presumably preinstalled in the Kaggle image; confirm.
!pip -q install -U kornia
!pip -q install kornia-moons
!pip3 -q install torch==1.8.2+cu102 torchvision==0.9.2+cu102 torchaudio==0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
!pip -q install ninja
!pip -q install loguru
!pip -q install einops
!pip -q install timm

In [None]:
# Copy the QuadTreeAttention sources to a writable location, then pip-install the
# QuadTreeAttention package from that copy.
!cp -r ../input/quadtreeattention/ ../working/ # input folder is read only
! cd ../working/quadtreeattention/QuadTreeAttention-master/QuadTreeAttention/ && pip install .

In [None]:
import cv2
import kornia as K
import kornia.feature as KF
from kornia.feature.loftr import LoFTR
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import glob
import random

import torchvision
import kornia_moons.feature as KMF
from PIL import Image

import sys
# Make the copied QuadTreeAttention repo importable: the repo root (for
# `from FeatureMatching.src...` imports in later cells), its FeatureMatching
# folder and its QuadTreeAttention package folder.
sys.path.append('../working/quadtreeattention/QuadTreeAttention-master/')
sys.path.append('../working/quadtreeattention/QuadTreeAttention-master/FeatureMatching/')
sys.path.append('../working/quadtreeattention/QuadTreeAttention-master/QuadTreeAttention/')

# Utility functions

In [None]:
from FeatureMatching.src.utils.plotting import make_matching_figure
from pathlib import Path
import matplotlib.cm as cm

def get_images_path(path):
    """
    Recursively collect the JPEG images stored in ``images/`` sub-folders.

    Parameters
    ----------
    path (str or Path): root folder to search (e.g. one scene folder of the
        image-matching-challenge-2022 ``train`` directory)

    Returns
    -------
    list of str: matching image paths, sorted so that dataset indices (and thus
    the images plotted / trained on) are reproducible across runs and machines.
    """
    # rglob already recurses into every sub-directory, so only the "images/*.jpg"
    # tail is needed; its yield order is filesystem dependent, hence the sort.
    return sorted(str(p) for p in Path(path).rglob("images/*.jpg"))

def _filter_matches(batch, conf_thresh):
    """Return (mkpts0, mkpts1, mconf, n_raw) for matches whose confidence is > conf_thresh.

    Reads the 'mconf', 'mkpts0_f' and 'mkpts1_f' entries the matcher wrote into ``batch``;
    ``n_raw`` is the number of matches before thresholding.
    """
    mconf = batch['mconf'].cpu().numpy()
    mask_conf = mconf > conf_thresh
    mkpts0 = batch['mkpts0_f'].cpu().numpy()[mask_conf]
    mkpts1 = batch['mkpts1_f'].cpu().numpy()[mask_conf]
    return mkpts0, mkpts1, mconf[mask_conf], len(mask_conf)


def _plot_matches(img0_raw, img1_raw, mkpts0, mkpts1, mconf, text):
    """Draw one matching figure (matches colored by confidence), show it, then close it."""
    fig = make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, cm.jet(mconf), text=text)
    plt.show()
    plt.close(fig)


def match_and_draw_dataset(matcher, dataset, conf_thresh=0, max_img=20, device="cuda", rotate=False,
                           plot_no_match=True, model_name="LoFTR"):
    """
    Match and draw pairs from a dataset for a specific matcher model.

    For each of the first ``max_img`` items, the matcher is run on (image0, image1),
    matches with confidence strictly above ``conf_thresh`` are kept and plotted.
    With ``rotate=True`` a second figure is drawn after transposing image1
    (swapping its height and width axes).

    Parameters
    ----------
    matcher (torch nn module): a matcher model (LOFTR, QUADTREE) that writes
        'mconf', 'mkpts0_f' and 'mkpts1_f' into the batch dict it is called with
    dataset (torch dataset): a dataset with matching pairs; items must contain
        'raw_image0' / 'raw_image1' (i.e. not built with light_mode=True)
    conf_thresh (float): between 0 and 1, confidence of shown matches
    max_img (int): max images to plot
    device (str): device to make inference
    rotate (bool): also match and plot against the transposed image1
    plot_no_match (bool): when False, items with 3 or fewer kept matches are
        skipped (including their transposed pass)
    model_name (str): label drawn on the figures

    Returns
    -------
    None
    """
    matcher.eval()
    matcher.to(device)

    for idx in range(min(len(dataset), max_img)):
        batch = dataset[idx]
        # add the batch dimension expected by the matcher
        batch["image0"] = batch['image0'].unsqueeze(0).to(device)
        batch["image1"] = batch['image1'].unsqueeze(0).to(device)

        img0_raw = K.tensor_to_image(batch["raw_image0"])
        img1_raw = K.tensor_to_image(batch["raw_image1"])

        with torch.no_grad():
            matcher(batch)  # results are written into `batch`
        mkpts0, mkpts1, mconf, _ = _filter_matches(batch, conf_thresh)

        if len(mkpts0) <= 3 and not plot_no_match:
            print("Not enough matches")
            continue
        _plot_matches(img0_raw, img1_raw, mkpts0, mkpts1, mconf,
                      text=[model_name, 'Matches: {}'.format(len(mkpts0))])

        if rotate:
            # Look at transposition: swap h/w of the model input and of the plotted image alike.
            batch["image1"] = torch.transpose(batch['image1'], 2, 3)
            img1_raw_t = img1_raw.transpose(1, 0, 2)

            with torch.no_grad():
                matcher(batch)
            mkpts0, mkpts1, mconf, n_raw = _filter_matches(batch, conf_thresh)

            _plot_matches(img0_raw, img1_raw_t, mkpts0, mkpts1, mconf,
                          text=[
                              model_name + ' transpose',
                              'Init Matches: {}'.format(n_raw),
                              'Thresh Matches: {}'.format(len(mkpts0)),
                          ])

    return

# Synthetic Dataset

The goal here is simply to pick an image, sample a smaller crop from it, and try to match the crop back to the full image.




In [None]:
from torch.utils.data import Dataset

class RotateSyntheticDataset(Dataset):
    """
    Synthetic matching pairs built from single images: (full image, random crop of it).

    Each item:
      1. loads an image and, with probability 0.5, rotates it by 90 degrees;
      2. crops a random sub-window (size ratio drawn from ``crop_range``);
      3. resizes both the full image and the crop to ``img_size``;
      4. optionally augments both (``augment_fn``);
      5. with probability ``rotation_prob`` transposes the crop;
      6. packs grayscale images plus supervision tensors (depths, intrinsics,
         relative poses) in the dict layout consumed by the LoFTR training code.

    TODO : Is rotation correct ? The poses / intrinsics are a hand-made
    approximation with constant unit depth and have not been validated.
    """

    def __init__(self,
                 path_to_dermoscopies,
                 img_size=(320, 320),  # (64*6, 48*6) (640, 480)
                 crop_range=(2,),
                 rotation_prob=0.5,
                 augment_fn=None,
                 light_mode=False):
        """
        Creates artificial dataset.

        Args:
            - path_to_dermoscopies (list): iterable of image paths (str)
            - img_size (int, int): final (h, w) size of both images shown to the model
              (should be divisible by 64?)
            - crop_range (sequence of int): candidate ratios between original image and crop sizes
            - rotation_prob (float): probability of transposing the crop
            - augment_fn (callable, optional): zero-argument factory returning a pipeline
              called as ``pipeline(image=hwc_array)["image"]`` (e.g. an albumentations Compose)
            - light_mode (bool): if True, items omit 'raw_image0' / 'raw_image1'
        """
        super().__init__()
        self.path_to_dermoscopies = path_to_dermoscopies

        self.img_size = img_size
        self.crop_range = crop_range
        self.rotation_prob = rotation_prob
        self.augment_fn = augment_fn
        self.light_mode = light_mode

    def __len__(self):
        return len(self.path_to_dermoscopies)

    def __getitem__(self, idx):
        img_path = self.path_to_dermoscopies[idx]
        img_name = img_path.split('/')[-1]

        image0 = self.load_torch_image(img_path)  # (c, h, w)

        # With probability 0.5, rotate the source image by 90 degrees *before* cropping,
        # so both images of the pair share that orientation.
        if np.random.rand() > 0.5:
            image0 = torch.rot90(image0, k=1, dims=[1, 2])

        h_init, w_init = image0.shape[1:]

        image1, crop_pos = self.basic_crop(image0)
        hc, wc = image1.shape[1:]

        # Resize the full image and the crop to the same model input size.
        resize = torchvision.transforms.Resize(self.img_size)
        image0 = resize(image0)
        image1 = resize(image1)

        if self.augment_fn is not None:
            image0 = self._augment(image0)
            image1 = self._augment(image1)

        scale0 = image0.shape[1:]  # (h, w)
        h0, w0 = scale0

        depth0 = torch.ones(scale0)  # constant unit depth: the pair has no real 3D structure

        # Intrinsics of image0: unit focal length, principal point at the image centre.
        # NOTE(review): the x row uses h0/2 and the y row w0/2, which only coincides with the
        # usual (w/2, h/2) for square img_size — confirm the intended axis order.
        K_0 = torch.tensor([[1, 0, h0 / 2],
                            [0, 1, w0 / 2],
                            [0, 0, 1]])
        R0 = np.diag([1, 1, 1])  # no rotation
        Tv0 = np.array([[w0 / 2, h0 / 2, 0]])  # translation vector, (x, y, z) order
        T0 = np.concatenate((R0, Tv0.T), axis=1)
        T0 = np.concatenate((T0, np.asarray([[0, 0, 0, 1]])), axis=0)

        rotate = np.random.rand() < self.rotation_prob
        if rotate:
            # Swap the crop's spatial axes: a transpose, i.e. a 90 degree rotation plus a mirror flip.
            image1 = torch.transpose(image1, 1, 2)

        # Sized from image1 itself so it still matches the crop after the transpose swapped
        # h and w (torch.ones(scale0) would only be correct for square img_size).
        depth1 = torch.ones(image1.shape[1:])

        # Intrinsics of image1: focal terms are the original-to-crop size ratios.
        K_1 = torch.tensor([[h_init / hc, 0, h0 / 2],
                            [0, w_init / wc, w0 / 2],
                            [0, 0, 1]])
        if rotate:
            # axis swap mirroring the transpose applied to image1
            R1 = np.array([[0, 1, 0],
                           [1, 0, 0],
                           [0, 0, 1]])
        else:
            R1 = np.diag([1, 1, 1])  # no rotation

        # AXES ARE MATH BASED NOT TENSOR BASED: x comes from the crop's column range
        # (crop_pos[1]) and y from its row range (crop_pos[0]).
        Tv1 = np.array([[w0 - np.mean(crop_pos[1]) * w0 / w_init,
                         h0 - np.mean(crop_pos[0]) * h0 / h_init,
                         0]])
        T1 = np.concatenate((R1, Tv1.T), axis=1)
        T1 = np.concatenate((T1, np.asarray([[0, 0, 0, 1]])), axis=0)

        T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)  # (4, 4)
        T_1to0 = T_0to1.inverse()

        data = {
            'image0': image0.float().mean(axis=0, keepdim=True) / 255,  # (1, h, w) grayscale
            'depth0': depth0,  # (h, w)
            'image1': image1.float().mean(axis=0, keepdim=True) / 255,
            'depth1': depth1,
            'T_0to1': T_0to1,  # (4, 4)
            'T_1to0': T_1to0,
            'K0': K_0,  # (3, 3)
            'K1': K_1,
            'dataset_name': 'scannet',
            'scene_id': idx,
            'pair_id': idx,
            'pair_names': (img_name, "macro_"+img_name),
        }

        if not self.light_mode:
            # keep the resized (c, h, w) images for plotting
            data["raw_image0"] = image0
            data["raw_image1"] = image1
        return data

    def _augment(self, image):
        """Apply a freshly built ``self.augment_fn()`` pipeline to a (c, h, w) tensor."""
        array = image.numpy().transpose(1, 2, 0).astype(np.uint8)
        array = self.augment_fn()(image=array)["image"]
        return torch.tensor(array.transpose(2, 0, 1))

    def load_torch_image(self, fname):
        """Read ``fname`` with OpenCV and return it as a (c, h, w) RGB tensor.

        Raises
        ------
        FileNotFoundError: if OpenCV cannot read the file.
        """
        bgr = cv2.imread(fname)
        if bgr is None:
            # cv2.imread reports missing/unreadable files by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {fname}")
        img = K.image_to_tensor(bgr, False)
        return K.color.bgr_to_rgb(img).squeeze()

    def basic_crop(self, img):
        """Return a random crop of a (c, h, w) tensor and its ((row0, row1), (col0, col1)) bounds.

        The crop size is (h, w) divided by a ratio drawn from ``self.crop_range``.
        """
        _, h, w = img.shape
        crop_ratio = np.random.choice(self.crop_range)
        crop_h, crop_w = int(h / crop_ratio), int(w / crop_ratio)

        # Valid start offsets are 0..h-crop_h inclusive (hence +1, so the crop can touch the
        # bottom/right border); max(1, ...) keeps randint valid when the crop is not smaller
        # than the image.
        rand_x = np.random.randint(max(1, h - crop_h + 1))
        rand_y = np.random.randint(max(1, w - crop_w + 1))

        end_x = rand_x + crop_h
        end_y = rand_y + crop_w
        return img[:, rand_x:end_x, rand_y:end_y], ((rand_x, end_x), (rand_y, end_y))

In [None]:
# data module

import pytorch_lightning as pl
from torch.utils.data import DataLoader

class BasicDataModule(pl.LightningDataModule):
    """
    LightningDataModule serving RotateSyntheticDataset loaders (light mode, no raw images).

    All loaders are built from the same image list; only the training loader
    applies ``transforms`` and (by default) shuffles.
    NOTE(review): validation/test/predict therefore run on the training images
    themselves — pass a held-out list if a meaningful validation score is needed.
    """

    def __init__(self, path_to_dermoscopies, transforms=None, img_size=(640, 640),
                 batch_size: int = 4, crop_range=(2, 3, 4), num_workers=4, rotation_prob=0.5,
                 shuffle_train: bool = True):
        """
        Parameters
        ----------
        path_to_dermoscopies (list of str): image paths
        transforms (callable, optional): augmentation factory for the training dataset
        img_size (int, int): (h, w) size images are resized to
        batch_size (int): batch size of every loader
        crop_range (sequence of int): candidate full-image / crop size ratios
        num_workers (int): DataLoader worker processes
        rotation_prob (float): probability that the crop is transposed
        shuffle_train (bool): shuffle the training loader each epoch
        """
        super().__init__()
        self.batch_size = batch_size
        self.path_to_dermoscopies = path_to_dermoscopies
        self.crop_range = crop_range
        self.transforms = transforms
        self.num_workers = num_workers
        self.rotation_prob = rotation_prob
        self.img_size = img_size
        self.shuffle_train = shuffle_train

    def setup(self, stage=None):
        # nothing to prepare: datasets are created in the *_dataloader hooks
        return

    def _make_loader(self, augment_fn, shuffle=False):
        """Build a DataLoader over a fresh light-mode RotateSyntheticDataset."""
        dataset = RotateSyntheticDataset(self.path_to_dermoscopies,
                                         augment_fn=augment_fn,
                                         crop_range=self.crop_range,
                                         img_size=self.img_size,
                                         rotation_prob=self.rotation_prob,
                                         light_mode=True)
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=shuffle)

    def train_dataloader(self):
        # without shuffling every epoch would visit the images in the same order
        return self._make_loader(self.transforms, shuffle=self.shuffle_train)

    def val_dataloader(self):
        return self._make_loader(None)

    def test_dataloader(self):
        return self._make_loader(None)

    def predict_dataloader(self):
        return self._make_loader(None)

In [None]:
# transformations for augmentation
import albumentations as albu
def non_geom_transforms(p=0.2):
    """
    Build the photometric / occlusion augmentation pipeline used for training.

    As the name says, no geometric transform (warp, crop, flip) is included, so
    the crop geometry computed by the dataset is left untouched.

    Parameters
    ----------
    p (float): probability of applying each of the two ``OneOf`` groups
        (blur group, noise/fog group).

    Returns
    -------
    albu.Compose: pipeline called as ``pipeline(image=array)["image"]``.
    """
    photometric = [
        albu.ColorJitter(p=0.5),
        albu.RandomRain(p=0.05),  # random occlusion
        albu.RandomSunFlare(src_radius=50, p=0.1),
        albu.ImageCompression(p=0.25),
        albu.ISONoise(p=0.25),
    ]
    blur_group = albu.OneOf(
        [
            albu.MotionBlur(blur_limit=(3, 5), always_apply=True),
            albu.GaussianBlur(blur_limit=(3, 5), always_apply=True),
        ],
        p=p,
    )
    noise_group = albu.OneOf(
        [
            albu.GaussNoise(var_limit=(1.0, 2.0), always_apply=True),
            albu.RandomFog(fog_coef_lower=0.01, fog_coef_upper=0.25, always_apply=True),
            albu.MultiplicativeNoise(multiplier=(0.75, 1.25), elementwise=True, always_apply=True),  # noqa
        ],
        p=p,
    )
    # always applied: every augmented image receives 2-8 small black holes
    dropout = albu.CoarseDropout(
        max_holes=8,
        max_height=16,
        max_width=32,
        min_holes=2,
        min_height=2,
        min_width=2,
        fill_value=0,
        always_apply=True,
    )
    return albu.Compose(photometric + [blur_group, noise_group, dropout])

# Config for pretrained Quadtree

In [None]:
from FeatureMatching.src.config.default import get_cfg_defaults
config = get_cfg_defaults()
# Notebook-level knobs reused by the data module and the visualisation cells.
CROP_RANGE = [2, 3, 4] # candidate full-image / crop size ratios (was: 2)
TRANSFORMS = non_geom_transforms # augmentation factory for the training loader (None disables it)
NB_EPOCHS = 10 # forwarded to --max_epochs by parse_args


# Indoor loftr_ds_quadtree config overrides.
# NOTE(review): the checkpoint loaded later is outdoor_quadtree.ckpt — confirm these
# overrides match the preset that checkpoint was trained with.
config.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
config.LOFTR.MATCH_COARSE.SPARSE_SPVS = False
config.LOFTR.RESNETFPN.INITIAL_DIM = 128
config.LOFTR.RESNETFPN.BLOCK_DIMS=[128, 196, 256]
config.LOFTR.COARSE.D_MODEL = 256
config.LOFTR.COARSE.BLOCK_TYPE = 'quadtree'
config.LOFTR.COARSE.ATTN_TYPE = 'B'
config.LOFTR.COARSE.TOPKS=[32, 16, 16]
config.LOFTR.FINE.D_MODEL = 128
# Trainer settings for a single-device Kaggle run.
config.TRAINER.WORLD_SIZE = 1 # 8 in the multi-GPU original
config.TRAINER.CANONICAL_BS = 32
# NOTE(review): the DataLoaders come from BasicDataModule (default batch_size=4), not from
# this value — confirm the two are meant to agree.
config.TRAINER.TRUE_BATCH_SIZE = 1
_scaling = 1
config.TRAINER.ENABLE_PLOTTING = False
config.TRAINER.SCALING = _scaling
config.TRAINER.TRUE_LR = 1e-4 # set directly instead of config.TRAINER.CANONICAL_LR * _scaling
config.TRAINER.WARMUP_STEP = 0 # no warm-up (instead of math.floor(config.TRAINER.WARMUP_STEP / _scaling))

In [None]:
# arguments 

import argparse
def parse_args():
    # init a costum parser which will be added into pl.Trainer parser
    # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        'data_cfg_path', type=str, help='data config path')
    parser.add_argument(
        'main_cfg_path', type=str, help='main config path')
    parser.add_argument(
        '--exp_name', type=str, default='default_exp_name')
    parser.add_argument(
        '--batch_size', type=int, default=4, help='batch_size per gpu')
    parser.add_argument(
        '--num_workers', type=int, default=4)
    parser.add_argument(
        '--pin_memory', type=lambda x: bool(strtobool(x)),
        nargs='?', default=True, help='whether loading data to pinned memory or not')
    parser.add_argument(
        '--ckpt_path', type=str, default="../input/kornia-loftr/outdoor_ds.ckpt",
        help='pretrained checkpoint path, helpful for using a pre-trained coarse-only LoFTR')
    parser.add_argument(
        '--disable_ckpt', action='store_true',
        help='disable checkpoint saving (useful for debugging).')
    parser.add_argument(
        '--profiler_name', type=str, default=None,
        help='options: [inference, pytorch], or leave it unset')
    parser.add_argument(
        '--parallel_load_data', action='store_true',
        help='load datasets in with multiple processes.')

    parser = pl.Trainer.add_argparse_args(parser)
    nb_epochs = NB_EPOCHS # 20
    return parser.parse_args(f'../input/loftrutils/LoFTR-master/LoFTR-master/configs/data/megadepth_trainval_640.py ../input/loftrutils/LoFTR-master/LoFTR-master/configs/loftr/outdoor/loftr_ds_dense.py --exp_name test --gpus 0 --num_nodes 0 --accelerator gpu --batch_size 1 --check_val_every_n_epoch 1 --log_every_n_steps 1 --flush_logs_every_n_steps 1 --limit_val_batches 1 --num_sanity_val_steps 10 --benchmark True --max_epochs {nb_epochs}'.split())

from pytorch_lightning.utilities import rank_zero_only
import pytorch_lightning as pl
import pprint
# Parse the notebook's hard-coded training arguments and print them (rank-0 process only).
args = parse_args()
rank_zero_only(pprint.pprint)(vars(args))

In [None]:
train_images = get_images_path('../input/image-matching-challenge-2022/train/brandenburg_gate/')
# Fail fast: an empty list would otherwise surface much later as an empty dataset / DataLoader error.
assert len(train_images) > 0, 'no training images found under brandenburg_gate/'

# Define Trainer

In [None]:
from FeatureMatching.src.utils.profiler import build_profiler
from FeatureMatching.src.lightning.lightning_loftr import PL_LoFTR
from loguru import logger as loguru_logger
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor


# lightning module
# NOTE(review): these two hard-coded values shadow args.disable_ckpt / args.profiler_name
# parsed above; checkpoint saving is therefore always disabled here.
disable_ckpt = True
profiler_name = None # help='options: [inference, pytorch], or leave it unset
profiler = build_profiler(profiler_name)
model = PL_LoFTR(config,
                 pretrained_ckpt= "../input/quadtreecheckpoints/outdoor_quadtree.ckpt", # fine-tunes the outdoor QuadTree checkpoint (args.ckpt_path is not used)
                 profiler=profiler)
loguru_logger.info(f"LoFTR LightningModule initialized!")

# lightning data
# NOTE(review): batch_size is left at the BasicDataModule default (4); args.batch_size is not used.
data_module = BasicDataModule(train_images, transforms=TRANSFORMS, crop_range=CROP_RANGE)
loguru_logger.info(f"LoFTR DataModule initialized!")

# TensorBoard Logger
logger = TensorBoardLogger(save_dir="../working/logs",
                           name="test_kaggle",
                           default_hp_metric=False)
ckpt_dir = Path(logger.log_dir) / 'checkpoints'

# Callbacks
# TODO: update ModelCheckpoint to monitor multiple metrics
# The checkpoint callback is only registered when disable_ckpt is False.
ckpt_callback = ModelCheckpoint(monitor='auc@10', verbose=True, save_top_k=5, mode='max',
                                save_last=True,
                                dirpath=str(ckpt_dir),
                                filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
lr_monitor = LearningRateMonitor(logging_interval='step')
callbacks = [lr_monitor]
if not disable_ckpt:
    callbacks.append(ckpt_callback)

# Lightning Trainer: argparse flags from `args`, overridden/extended by the keyword arguments.
trainer = pl.Trainer.from_argparse_args(
                    args=args,
#                     plugins=DDPPlugin(find_unused_parameters=False,
#                                       num_nodes=num_nodes,
#                                       sync_batchnorm=False, #config.TRAINER.WORLD_SIZE > 0
#                                      ),
                    gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
                    callbacks=callbacks,
                    logger=logger,
#                     sync_batchnorm=False, #config.TRAINER.WORLD_SIZE > 0,
                    replace_sampler_ddp=False,  # use custom sampler (NOTE(review): BasicDataModule passes no sampler)
#                     reload_dataloaders_every_epoch=False,  # avoid repeated samples!
                    weights_summary='full',
                    profiler=profiler)


# BEFORE TRAINING (relying on pretrained weights)

As you can see, the pretrained model is pretty good with basic crops.

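However, even a basic 90° rotation breaks the matching (here the "rotation" is a transpose of the crop, i.e. a 90° rotation combined with a mirror flip).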

In [None]:
# Visualisation settings.
CONF_THRESH = 0. # matches are kept when mconf > CONF_THRESH, so 0. keeps every positive-confidence match
MAX_IMG = 5 # number of pairs plotted for the held-out scene

In [None]:
# Sanity check on the training scene before fine-tuning (pretrained weights only),
# with and without the transposed crop.
dataset = RotateSyntheticDataset(train_images, crop_range=CROP_RANGE, augment_fn=None)
match_and_draw_dataset(model.matcher, dataset,
                       conf_thresh=CONF_THRESH, max_img=2, rotate=True)

In [None]:
# Have a look at validation images: notre_dame_front_facade is not part of the
# training list (which only contains brandenburg_gate).
path_to_imgs = get_images_path("../input/image-matching-challenge-2022/train/notre_dame_front_facade/")
dataset = RotateSyntheticDataset(path_to_imgs, crop_range=CROP_RANGE, augment_fn=None)
match_and_draw_dataset(model.matcher, dataset,
                       conf_thresh=CONF_THRESH, max_img=MAX_IMG, rotate=True)

# Now let's train for only a few epochs

Question: could adding rotations during training improve this behaviour?

In [None]:
# Fine-tune the model (initialised from outdoor_quadtree.ckpt) on synthetic pairs from data_module.
# NOTE(review): `running_sanity_check` is pytorch_lightning-version dependent (newer releases
# expose `sanity_checking`); confirm this assignment still has an effect.
trainer.running_sanity_check = False
loguru_logger.info(f"Trainer initialized!")
loguru_logger.info(f"Start training!")
trainer.fit(model, datamodule=data_module)

In [None]:
# Re-run the training-scene check after fine-tuning, to compare with the "before" figures.
dataset = RotateSyntheticDataset(train_images, crop_range=CROP_RANGE, augment_fn=None)
match_and_draw_dataset(model.matcher, dataset,
                       conf_thresh=CONF_THRESH, max_img=10, rotate=True)