In [1]:
from superduperdb.mongodb.client import SuperDuperClient

# Default SuperDuperClient; `docs` is presumably the `documents` collection of the
# `yondo` database (Mongo-style attribute access) — TODO confirm.
c = SuperDuperClient()
docs = c.yondo.documents

In [None]:
# Inspection only: display the collection object.
docs

In [2]:
# Peek at one document from the `_meta` collection of the same database.
docs.database['_meta'].find_one()

{'_id': ObjectId('63bc0aa9456b1bb8fc0a78dc'),
 'key': 'html_template',
 'value': '\n<div><b>{{ r[\'brand\'] }} - {{ r[\'title\'] }}</b></div>\n<img src="{{ r[\'img\'][\'_content\'][\'url\'] }}" />\n'}

In [None]:
import io
import numpy
from PIL import Image
import torch


class PILImage:
    """Encoder/decoder that stores PIL images as PNG-encoded bytes."""

    @staticmethod
    def encode(x):
        """Serialise the PIL image ``x`` to PNG ``bytes``."""
        buffer = io.BytesIO()
        x.save(buffer, format='png')
        return buffer.getvalue()

    @staticmethod
    def decode(bytes_):
        """Deserialise image bytes back into a PIL image.

        PIL is imported locally on purpose. A later cell in this notebook defines
        ``class Image(torch.nn.Module)``, which rebinds the global name ``Image``.
        Resolving ``Image.open`` through the notebook globals would then fail.
        """
        from PIL import Image as PILImageModule

        return PILImageModule.open(io.BytesIO(bytes_))


class FloatTensor:
    """Encoder/decoder that stores float32 torch tensors as raw bytes.

    Only the values are stored, not the shape, so ``decode`` always returns a
    1-D tensor.
    """
    types = (torch.FloatTensor, torch.Tensor)

    @staticmethod
    def encode(x):
        """Serialise a float32 tensor to its raw (C-order) bytes.

        ``detach().cpu()`` lets tensors that require grad, or that live on another
        device, be encoded. A plain ``Tensor.numpy()`` refuses both.
        """
        array = x.detach().cpu().numpy()
        assert array.dtype == numpy.float32, f'expected float32, got {array.dtype}'
        return array.tobytes()

    @staticmethod
    def decode(bytes_):
        """Rebuild a 1-D float32 tensor from bytes produced by ``encode``."""
        # ``frombuffer`` returns a read-only view of the immutable ``bytes`` object.
        # Copy it so the tensor owns writable memory; torch warns on, and must not
        # write into, non-writable arrays.
        array = numpy.frombuffer(bytes_, dtype=numpy.float32).copy()
        return torch.from_numpy(array)

In [None]:
# Register the encoders defined above under type names, serialised with dill.
# NOTE(review): `FloatTensor` is passed as a class but `PILImage()` as an instance.
# Both only define staticmethods, so `.encode`/`.decode` work either way.
# Confirm create_type does not treat the two differently.
docs.create_type('float_tensor', FloatTensor, serializer='dill')
docs.create_type('image', PILImage(), serializer='dill')

In [4]:
import torch
from clip import load, tokenize


class Image(torch.nn.Module):
    """Thin module exposing a CLIP model's image tower plus its preprocessing.

    NOTE: this rebinds the name ``Image`` that an earlier cell imported with
    ``from PIL import Image``. Code that resolves ``Image.open`` globally after
    this cell runs will hit this class instead.
    """

    def __init__(self, model, preprocess):
        super().__init__()
        # The shared CLIP model and its image transform are kept side by side.
        self.model, self.preprocess = model, preprocess

    def forward(self, x):
        """Return ``self.model.encode_image(x)``."""
        encode = self.model.encode_image
        return encode(x)


class Text(torch.nn.Module):
    """Thin module exposing a CLIP model's text tower plus tokenisation."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return ``self.model.encode_text(x)``."""
        encode = self.model.encode_text
        return encode(x)

    def preprocess(self, x):
        """Tokenise the string ``x`` and keep the first row of ``tokenize``'s output."""
        batch = tokenize(x)
        return batch[0]
    


class CLIP(torch.nn.Module):
    """Joint CLIP encoder for documents with optional text and/or image fields.

    Text (``brand`` + ``title``) and image (``img``) inputs are each embedded,
    scaled to unit L2 norm, and averaged into one float32 vector per row.
    """

    def __init__(self, name):
        super().__init__()
        model, preprocess = load(name)
        self.image = Image(model, preprocess)
        self.text = Text(model)

    def preprocess(self, r):
        """Turn a document ``r`` into model inputs.

        Returns a dict holding a ``"text"`` entry when ``r`` has a ``brand`` or
        ``title``, and an ``"image"`` entry when ``r`` has an ``img``. At least
        one must be present.
        """
        out = {}
        if "brand" in r or "title" in r:
            out["text"] = self.text.preprocess(f'{r.get("brand", "")} {r.get("title", "")}')
        if "img" in r:
            out["image"] = self.image.preprocess(r['img'])
        assert out
        return out

    @staticmethod
    def _normalise(embedding):
        """Cast to float32 and scale each row to unit L2 norm."""
        embedding = embedding.float()
        return embedding / embedding.norm(dim=1, keepdim=True)

    def forward(self, r):
        """Average the unit-normalised image and/or text embeddings of a batch.

        The embedding width comes from the model output instead of being
        hard-coded, so checkpoints other than RN50 (1024-d) work too. The result
        is float32, as the original zero-initialised accumulator was.

        Raises:
            ValueError: if ``r`` has neither an ``'image'`` nor a ``'text'``
                entry. Previously this produced 0/0 = NaN silently.
        """
        assert r
        parts = []
        if 'image' in r:
            parts.append(self._normalise(self.image.forward(r['image'])))
        if 'text' in r:
            parts.append(self._normalise(self.text.forward(r['text'])))
        if not parts:
            raise ValueError(f"expected an 'image' and/or 'text' entry, got keys {sorted(r)}")
        return torch.stack(parts).mean(dim=0)

In [None]:
# Register the joint text/image CLIP encoder (RN50 checkpoint) as model 'clip', serialised with dill.
docs.create_model('clip', CLIP('RN50'), serializer='dill')

In [None]:
# Inspection only: stored metadata for the 'clip' model.
docs.database.get_object_info('clip', 'model')

In [None]:
# Inspection only: one document from the `_objects` collection.
docs['_objects'].find_one()

In [None]:
def dot(x, y):
    """Pairwise dot products: entry (i, j) is row i of ``x`` dotted with row j of ``y``."""
    return x @ y.T


def css(x, y):
    """Cosine similarity between every row of ``x`` and every row of ``y``.

    Assumes 2-D inputs of shape (n, d) and (m, d); returns an (n, m) matrix.
    Rows are scaled with ``torch.nn.functional.normalize``. Its epsilon clamp
    maps an all-zero row to zeros instead of the NaN that dividing by a zero
    norm would produce. The product is computed inline, so the registered
    measure does not depend on another notebook global.
    """
    x_unit = torch.nn.functional.normalize(x, dim=1)
    y_unit = torch.nn.functional.normalize(y, dim=1)
    return x_unit.matmul(y_unit.T)

# Register the cosine-similarity function as measure 'css', serialised with dill.
docs.create_measure('css', css, serializer='dill')

In [None]:
# Build semantic index 'clip'. The positional arguments presumably are
# (index name, models, keys, measure) — TODO confirm against create_semantic_index.
# num_workers=0 loads batches in the main process.
docs.create_semantic_index(
    'clip', ['clip'], ['_base'], 'css', loader_kwargs={'batch_size': 10, 'num_workers': 0},
    verbose=True,
)

# lots of output - takes a while...

In [None]:
# Inspection only: registered models.
docs.list_models()

In [None]:
# Inspection only: registered watchers.
docs.list_watchers()

In [None]:
# Inspection only: one document, to check what the index/watchers have added.
docs.find_one()

In [None]:
import torch
from clip import load, tokenize


class ClassifierSimple(torch.nn.Module):
    """Zero-shot classifier over stored CLIP embeddings.

    Each category name is embedded once, at construction, with CLIP's text
    tower. ``forward`` returns softmax probabilities over ``categories``, and
    ``postprocess`` maps a probability row to the winning category name.
    """

    def __init__(self, categories, name):
        super().__init__()
        self.categories = categories
        model, _ = load(name)
        # The category vectors are constants. Compute them without autograd so
        # the buffers neither require grad nor keep the model's graph alive.
        # Cast them to float32 to match the float32 embeddings fed to forward.
        with torch.no_grad():
            tokens = torch.cat([tokenize(x) for x in categories], 0)
            category_vectors = model.encode_text(tokens).float()
            category_vectors = category_vectors / category_vectors.norm(dim=1, keepdim=True)
            logit_scale = model.logit_scale.exp().float()
        self.register_buffer('category_vectors', category_vectors)
        self.register_buffer('logit_scale', logit_scale)

    @property
    def device(self):
        """Device holding the category vectors."""
        return self.category_vectors.device

    def preprocess(self, x):
        """Take a document's stored clip output (``_outputs._base.clip``) or a raw tensor."""
        if isinstance(x, dict):
            x = x['_outputs']['_base']['clip']
        else:
            assert isinstance(x, torch.Tensor)
        return x

    def forward(self, x):
        """Softmax over scaled cosine similarities between rows of ``x`` and each category."""
        x = x / x.norm(dim=1, keepdim=True)
        logits_per_image = self.logit_scale * x @ self.category_vectors.t()
        out = logits_per_image.softmax(dim=-1)
        return out

    def postprocess(self, x):
        """Return the category with the highest probability in the 1-D row ``x``."""
        pos = x.topk(1)[1].item()
        return self.categories[pos]

In [None]:
# Register a zero-shot garment classifier as model 'silhouettes'. It compares
# stored clip embeddings against the text embeddings of these 12 category names
# (RN50 checkpoint).
docs.create_model(
    'silhouettes', 
    ClassifierSimple(
        name='RN50',
        categories=[
            'accessory',
            'blouse',
            'coat',
            'dress',
            'hat',
            'hoodie',
            'jacket',
            'pullover',
            'shoes',
            'skirt',
            't-shirt',
            'trousers',
        ]
    ),
    serializer='dill',
)

In [None]:
# Remove a previously created 'silhouettes' watcher so the next cell can recreate it.
# NOTE(review): on Restart & Run All this runs before the watcher is ever created
# (next cell). Confirm delete_watcher tolerates a missing watcher, or skip this cell.
docs.delete_watcher('silhouettes')

In [None]:
# Run model 'silhouettes' over the collection as a watcher. The
# features={'_base': 'clip'} mapping presumably makes the stored clip output
# available under `_outputs._base.clip`, which ClassifierSimple.preprocess
# reads — TODO confirm.
docs.create_watcher(
    'silhouettes',
    'silhouettes',
    features={'_base': 'clip'},
    loader_kwargs={'batch_size': 10, 'num_workers': 0},
    verbose=True,
)

In [None]:
# Semantic search by text: fetch the image of the best 'clip' match for a title query.
docs.find_one(like={'title': 'leopard print t-shirt'}, semantic_index='clip')['img']

In [None]:
# Semantic search by image: the `like` document references an image by URL.
# NOTE(review): `docs.remote = False` mutates client state for all later cells;
# presumably it makes the query run in-process instead of via a server, and
# download=True fetches the URL content before encoding — TODO confirm both.
docs.remote = False
docs.find_one(like={'img': {'_content': {'url': 'https://thumblr.uniid.it/product/238107/09ef5396fac2.jpg', 'type': 'image'}}},
              semantic_index='clip', download=True)['img']

In [None]:
import torch
# Sanity check of the registered 'float_tensor' type: encode a random 32-element tensor.
# If this is the FloatTensor encoder defined above, the result is 32 * 4 = 128 raw bytes.
c.yondo.types['float_tensor'].encode(torch.randn(32))