https://towardsdatascience.com/image-similarity-theory-and-code-2b7bcce96d0a

In [1]:
# Make CUDA kernel launches synchronous so a device-side error is reported at the call
# that caused it. The setting is read from the process environment (a plain Python
# variable named CUDA_LAUNCH_BLOCKING has no effect), and it has to be in place before
# CUDA is initialised, so it is set ahead of the fastai/torch import.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from fastai.vision.all import *

GeForce RTX 3080 Ti with CUDA capability sm_86 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70.
If you want to use the GeForce RTX 3080 Ti GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



In [2]:
# Fetch the dataset referenced by URLs.PETS and list the image files found in its
# "images" sub-folder. `files` is the input for label extraction and the Siamese transform.
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

In [3]:
def label_func(fname):
    """Return the class name encoded in an image file name.

    The name is expected to look like ``<class>_<number>.jpg``; everything before
    the last ``_<number>`` is returned (e.g. ``great_pyrenees_12.jpg`` gives
    ``great_pyrenees``).

    fname: path-like object exposing a ``.name`` attribute (e.g. ``pathlib.Path``).
    Raises ``ValueError`` when the file name does not follow that pattern.
    """
    # The dot is escaped so that only a literal ".jpg" suffix is accepted.
    m = re.match(r'^(.+)_\d+\.jpg$', fname.name)
    if m is None:
        raise ValueError(f"Cannot extract a label from file name {fname.name!r}")
    return m.group(1)

In [4]:
# Distinct class names over all image files. `labels` is read as a module-level
# global by SiameseTransform, so it must exist before that class is instantiated.
labels = L(map(label_func, files)).unique()
len(labels)

37

In [5]:
class SiameseImage(fastuple):
    """Tuple of two images, optionally followed by a similarity value, that can be shown side by side."""

    def show(self, ctx=None, **kwargs):
        """Display both images next to each other with the similarity as title.

        A two-element tuple is titled 'Undetermined'. Remaining keyword arguments
        and `ctx` are forwarded to `show_image`, whose result is returned.
        """
        if len(self) <= 2:
            left, right = self
            caption = 'Undetermined'
        else:
            left, right, caption = self

        if isinstance(left, Tensor):
            # Already tensors: use them as they are.
            t_left, t_right = left, right
        else:
            # Non-tensor images: bring the second one to the size of the first,
            # then convert both and move the last axis to the front
            # (presumably HxWxC -> CxHxW; verify against the image type used).
            if right.size != left.size:
                right = right.resize(left.size)
            t_left, t_right = tensor(left), tensor(right)
            t_left = t_left.permute(2, 0, 1)
            t_right = t_right.permute(2, 0, 1)

        # Zero-filled separator, 10 entries wide along the last dimension.
        gap = t_left.new_zeros(t_left.shape[0], t_left.shape[1], 10)
        side_by_side = torch.cat([t_left, gap, t_right], dim=2)
        return show_image(side_by_side, title=caption, ctx=ctx, **kwargs)

class SiameseTransform(Transform):
    """Turn an image file into a `SiameseImage` of (image, partner image, same-class flag).

    Uses the module-level `labels` collection and `label_func`. Partners for the
    files of the second split are drawn once in `__init__` and reused on every
    call; partners for any other file are drawn afresh each time.
    """

    def __init__(self, files, splits):
        """Group `files` by split and label, then pre-draw the partners of split 1.

        files:  indexable collection of image paths; `files[splits[i]]` must select split `i`.
        splits: two index collections; split 0 is the default draw pool, split 1 the
                one whose pairs are fixed up front.
        """
        # One grouping pass per split, instead of re-scanning (and re-labelling)
        # every file once per label. Key order and per-label file order are kept.
        self.splbl2files = []
        for i in range(2):
            by_label = {l: [] for l in labels}
            for f in files[splits[i]]:
                bucket = by_label.get(label_func(f))
                # Files whose label is not listed in `labels` are left out.
                if bucket is not None: bucket.append(f)
            self.splbl2files.append(by_label)
        self.valid = {f:self._draw(f,1) for f in files[splits[1]]}

    def encodes(self, f):
        """Return `SiameseImage(im1, im2, same)` for file `f`, with `same` as 0 or 1."""
        # Look the file up first: `dict.get(f, self._draw(f))` would evaluate the
        # default eagerly and run a throw-away random draw for every pre-drawn file.
        pair = self.valid.get(f)
        if pair is None: pair = self._draw(f)
        f2,same = pair
        im1,im2 = PILImage.create(f),PILImage.create(f2)
        return SiameseImage(im1,im2,int(same))

    def _draw(self, f, splits=0):
        """Pick a partner file for `f` from split number `splits`.

        With probability 0.5 the partner shares the label of `f` (it may be `f`
        itself); otherwise a different label is chosen at random from `labels`.
        Returns `(partner_file, same)` where `same` is a bool.
        Raises `IndexError` (from `random.choice`) if the chosen label has no
        files in that split.
        """
        same = random.random() <.5
        cls = label_func(f)
        if not same: cls = random.choice([l for l in labels if l != cls])
        return random.choice(self.splbl2files[splits][cls]),same

In [1]:
# splits = RandomSplitter(seed=23)(files)
# tfm = SiameseTransform(files, splits)
# tls = TfmdLists(files, tfm, splits=splits)
# sdls = tls.dataloaders(after_item=[Resize(224), ToTensor], 
#     after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])