# Data Preparation

In [None]:
# Fetch the example dataset archive (-L follows redirects, -O saves it under the remote file name).
! curl -L https://github.com/towhee-io/examples/releases/download/data/reverse_image_search.zip -O
# Unpack it into the working directory (-q: quiet, -o: overwrite existing files without prompting).
! unzip -q -o reverse_image_search.zip

# Code Preparation

In [None]:
import sys
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision import models

import towhee


@towhee.register
class image_normalize:
    """Towhee operator that normalizes images channel-wise via torchvision.

    Each input is converted to a float32 tensor, its last axis is moved in
    front of the two spatial axes, and ``TF.normalize`` is applied with the
    configured per-channel ``mean`` and ``std``.

    Args:
        mean: Per-channel means subtracted from the input. The defaults are
            the ImageNet statistics used by torchvision's pretrained models.
        std: Per-channel standard deviations the input is divided by.

    NOTE(review): no rescaling to [0, 1] happens here. If the incoming images
    hold 0-255 pixel values, the default statistics will not match the data
    range -- confirm against the upstream decode/resize operators.
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Tuple defaults avoid a mutable default argument shared across
        # instances; copying into a list also decouples the operator from
        # later mutation of a caller-supplied sequence.
        self._mean = list(mean)
        self._std = list(std)

    def __call__(self, x):
        """Normalize a single image.

        Args:
            x: Array-like with three axes; the last axis is treated as the
                channel axis (assumes height x width x channel -- TODO confirm).

        Returns:
            A float32 numpy array with the channel axis first.
        """
        x = torch.tensor(x, dtype=torch.float32).permute(2, 0, 1)
        return TF.normalize(x, mean=self._mean, std=self._std).detach().numpy()

    def __vcall__(self, x):
        """Normalize a batch of images in one call.

        Args:
            x: Array-like with four axes; the first is kept as the batch axis
                and the last is treated as the channel axis.

        Returns:
            A float32 numpy array laid out as (batch, channel, ...spatial).
        """
        x = torch.tensor(x, dtype=torch.float32).permute(0, 3, 1, 2)
        return TF.normalize(x, mean=self._mean, std=self._std).detach().numpy()

@towhee.register
class image_embedding:
    """Towhee operator that embeds images with a pretrained ResNet-18.

    The torchvision ResNet-18 is loaded with pretrained weights and its last
    child module (the final classification layer) is dropped, so the network
    outputs 512 values per image. The model runs on CUDA when available and
    on the CPU otherwise.
    """

    def __init__(self):
        # Remember the device so inputs can be created on it; running a CUDA
        # model on CPU tensors raises a device-mismatch error.
        self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        torch_model = models.resnet18(pretrained=True)
        # Keep every child except the last one to expose the pooled features.
        torch_model = torch.nn.Sequential(*(list(torch_model.children())[:-1]))
        torch_model.to(self._device)
        torch_model.eval()

        self._model = torch_model

    def __call__(self, imgs):
        """Embed a single image.

        Args:
            imgs: Array-like for one image, without a batch axis; a leading
                batch axis of size 1 is added before the forward pass.

        Returns:
            A numpy array of shape (512,).
        """
        imgs = torch.tensor(imgs, device=self._device).unsqueeze(0)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            embedding = self._model(imgs).cpu().numpy()
        return embedding.reshape([512])

    def __vcall__(self, imgs):
        """Embed a batch of images in one forward pass.

        Args:
            imgs: Array-like whose first axis is the batch axis.

        Returns:
            A numpy array of shape (batch, 512).
        """
        imgs = torch.tensor(imgs, device=self._device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            embedding = self._model(imgs).cpu().numpy()
        return embedding.reshape([-1, 512])

# Benchmarks

## Convert to Arrow-based DataFrame

In [None]:
%%time

# Read the CSV, decode and resize every image, then convert the collection
# with to_column() (the step this cell is timing).
dc = (
    towhee.read_csv('reverse_image_search.csv')
    .unstream()
    .runas_op['id', 'id'](func=lambda x: int(x))  # cast the 'id' field to int
    .image_decode['path', 'img']()  # 'path' -> decoded 'img'
    .image_resize['img', 'img'](dsize=[224, 224])  # resize in place to 224x224
    .to_column()
)

## Row-based DataFrame

In [None]:
%%time
# Baseline run: the first 100 records go through normalize + embedding
# without any to_column() / set_chunksize() step.
dc = (
    towhee.read_csv('reverse_image_search.csv')
    .unstream()
    .head(100)
    .runas_op['id', 'id'](func=lambda x: int(x))  # cast the 'id' field to int
    .image_decode['path', 'img']()  # 'path' -> decoded 'img'
    .image_resize['img', 'img'](dsize=[224, 224])  # resize in place to 224x224
    .image_normalize['img', 'rimg']()  # 'img' -> normalized 'rimg'
    .image_embedding['rimg', 'embedding']()  # 'rimg' -> 'embedding'
)

## Col-based DataFrame

In [None]:
%%time
# Same 100-record pipeline, but to_column() is applied before the two
# registered operators run.
dc = (
    towhee.read_csv('reverse_image_search.csv')
    .unstream()
    .head(100)
    .runas_op['id', 'id'](func=lambda x: int(x))  # cast the 'id' field to int
    .image_decode['path', 'img']()  # 'path' -> decoded 'img'
    .image_resize['img', 'img'](dsize=[224, 224])  # resize in place to 224x224
    .to_column()
    .image_normalize['img', 'rimg']()  # 'img' -> normalized 'rimg'
    .image_embedding['rimg', 'embedding']()  # 'rimg' -> 'embedding'
)

## Chunked DataFrame

In [None]:
%%time
# Same 100-record pipeline, but set_chunksize(5) is applied before the two
# registered operators run.
dc = (
    towhee.read_csv('reverse_image_search.csv')
    .unstream()
    .head(100)
    .runas_op['id', 'id'](func=lambda x: int(x))  # cast the 'id' field to int
    .image_decode['path', 'img']()  # 'path' -> decoded 'img'
    .image_resize['img', 'img'](dsize=[224, 224])  # resize in place to 224x224
    .set_chunksize(5)
    .image_normalize['img', 'rimg']()  # 'img' -> normalized 'rimg'
    .image_embedding['rimg', 'embedding']()  # 'rimg' -> 'embedding'
)