In [1]:
from pinnacledb.mongodb.client import SuperDuperClient
from IPython.display import display, Image as I

# Connect with SuperDuperClient() (no arguments -- presumably default connection
# settings, confirm) and keep the `yondo.documents` collection wrapper that every
# later cell works through.
c = SuperDuperClient()
docs = c.yondo.documents

In [None]:
# Render the detailed architecture diagram. The path is relative, so this only
# works when the notebook's working directory is the folder it lives in.
I('../docs/img/architecture_detailed.png')

In [2]:
# Show which models are currently registered on the collection.
docs.list_models()

['clip']

In [None]:
import io
import numpy
import PIL.Image
import PIL.JpegImagePlugin, PIL.PngImagePlugin
import torch


class PILImage:
    """Codec that stores PIL images in the database as PNG bytes.

    ``types`` lists the image classes this codec is meant for (JPEG and PNG
    files opened by PIL). Every image is written as PNG regardless of its
    source format; ``decode`` returns the result of ``PIL.Image.open``.
    """

    types = (PIL.JpegImagePlugin.JpegImageFile, PIL.PngImagePlugin.PngImageFile)

    @staticmethod
    def encode(x):
        """Serialize image ``x`` to PNG bytes.

        PNG cannot store every PIL mode (JPEG files may be CMYK or YCbCr, for
        example), and ``save`` raises ``OSError`` for those. Such images are
        converted first: to RGBA when the mode carries an alpha band, otherwise
        to RGB.
        """
        if x.mode not in ('1', 'L', 'LA', 'I', 'I;16', 'I;16B', 'P', 'RGB', 'RGBA'):
            x = x.convert('RGBA' if x.mode.upper().endswith('A') else 'RGB')
        buffer = io.BytesIO()
        x.save(buffer, format='png')
        return buffer.getvalue()

    @staticmethod
    def decode(bytes_):
        """Open bytes produced by ``encode`` as a PIL image."""
        return PIL.Image.open(io.BytesIO(bytes_))


class FloatTensor:
    """Codec that stores float32 tensors in the database as raw bytes.

    Only the values are stored (native byte order, C order), not the shape,
    so ``decode`` always returns a 1-D tensor.
    """

    types = (torch.FloatTensor, torch.Tensor)

    @staticmethod
    def encode(x):
        """Serialize tensor ``x`` to bytes.

        ``x`` is detached and moved to the CPU first, so model outputs that live
        on a GPU or still carry autograd history can be stored directly
        (``Tensor.numpy()`` refuses both).

        Raises:
            AssertionError: if ``x`` is not float32.
        """
        array = x.detach().cpu().numpy()
        assert array.dtype == numpy.float32, f'expected float32, got {array.dtype}'
        return array.tobytes()

    @staticmethod
    def decode(bytes_):
        """Rebuild a 1-D float32 tensor from bytes produced by ``encode``."""
        # frombuffer returns a read-only view of ``bytes_``; copy it so the
        # tensor owns writable memory (torch.from_numpy warns about, and does
        # not support writing to, non-writable arrays).
        array = numpy.frombuffer(bytes_, dtype=numpy.float32).copy()
        return torch.from_numpy(array)
    
# Register the codecs under the names used later ('float_tensor' is the output
# type of the 'clip' model), serialized with dill.
# NOTE(review): FloatTensor is registered as a class but PILImage as an instance;
# both expose static encode/decode so presumably either works, but pick one form.
docs.create_type('float_tensor', FloatTensor, serializer='dill')
docs.create_type('image', PILImage(), serializer='dill')

In [6]:
import torch
from clip import load, tokenize


class Image(torch.nn.Module):
    """Image tower of a loaded CLIP model.

    Keeps the CLIP image transform as ``preprocess`` so callers can prepare raw
    inputs, and delegates ``forward`` to the model's ``encode_image``.
    """

    def __init__(self, model, preprocess):
        super().__init__()
        self.model = model
        self.preprocess = preprocess

    def forward(self, x):
        encode = self.model.encode_image
        embedded = encode(x)
        return embedded


class Text(torch.nn.Module):
    """Text tower of a loaded CLIP model."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Embed a batch of token tensors with the model's ``encode_text``."""
        return self.model.encode_text(x)

    def preprocess(self, x):
        """Tokenize a single string and return its row of token ids.

        ``truncate=True`` cuts text that exceeds CLIP's context length instead
        of letting ``tokenize`` raise ``RuntimeError`` on long titles.
        """
        return tokenize(x, truncate=True)[0]
    


class CLIP(torch.nn.Module):
    """Joint CLIP encoder producing one embedding per document.

    Text (``brand``/``title``) and/or image (``img``) inputs are embedded by
    the matching CLIP tower. Each embedding is L2-normalised, and when both
    are present the two are averaged.

    :param name: CLIP model name passed to ``load`` (e.g. ``'RN50'``).
    """

    def __init__(self, name):
        super().__init__()
        model, preprocess = load(name)
        self.image = Image(model, preprocess)
        self.text = Text(model)

    @staticmethod
    def _normalize(x):
        """Cast ``x`` to float32 and scale each row to unit L2 norm."""
        x = x.float()
        return x / x.norm(dim=1, keepdim=True)

    def preprocess(self, r):
        """Turn a raw document into the model inputs it supports.

        :param r: document dict; may contain ``brand``, ``title`` and ``img``.
        :returns: dict with ``'text'`` and/or ``'image'`` entries.
        :raises AssertionError: if ``r`` has none of the supported fields.
        """
        out = {}
        if "brand" in r or "title" in r:
            # ``or ""`` stops a stored None from being embedded as the word
            # "None"; strip() drops the separator when one field is missing.
            brand = r.get("brand") or ""
            title = r.get("title") or ""
            out["text"] = self.text.preprocess(f'{brand} {title}'.strip())
        if "img" in r:
            out["image"] = self.image.preprocess(r['img'])
        assert out
        return out

    def forward(self, r):
        """Embed a batch of preprocessed inputs.

        :param r: dict with ``'image'`` and/or ``'text'`` batch tensors.
        :returns: float32 tensor of shape ``(batch, dim)`` -- the mean of the
            unit-normalised embeddings of every modality present.
        """
        embeddings = []
        if 'image' in r:
            embeddings.append(self._normalize(self.image.forward(r['image'])))
        if 'text' in r:
            embeddings.append(self._normalize(self.text.forward(r['text'])))
        assert embeddings, 'expected an "image" and/or "text" batch'
        # Stacking instead of a pre-allocated buffer keeps the embedding size
        # model-dependent (RN50 is 1024-d, ViT-B/32 is 512-d).
        return torch.stack(embeddings).mean(dim=0)

    
# Register a CLIP('RN50') instance as model 'clip' (dill-serialized); its outputs
# are stored with the 'float_tensor' codec.
docs.create_model('clip', CLIP('RN50'), serializer='dill', type='float_tensor')

In [None]:
def dot(x, y):
    """Dot products between every row of ``x`` and every row of ``y`` (2-D inputs)."""
    return x @ y.T


def css(x, y, eps=1e-12):
    """Cosine-similarity matrix between the rows of ``x`` and the rows of ``y``.

    :param x: 2-D tensor, one vector per row.
    :param y: 2-D tensor with the same row length as ``x``.
    :param eps: lower bound on row norms, so an all-zero row scores 0 instead
        of producing NaN from a 0/0 division.
    :returns: ``(len(x), len(y))`` tensor of cosine similarities.
    """
    x = x / x.norm(dim=1, keepdim=True).clamp_min(eps)
    y = y / y.norm(dim=1, keepdim=True).clamp_min(eps)
    return dot(x, y)

# Register the cosine-similarity function as measure 'css' for semantic indexes.
docs.create_measure('css', css, serializer='dill')

In [8]:
# Drop any existing 'clip' semantic index so the next cell can recreate it
# (force=True presumably skips a confirmation prompt -- confirm).
docs.delete_semantic_index('clip', force=True)

unsetting output field _outputs._base.clip


In [9]:
# Build the 'clip' semantic index. The positional arguments presumably are
# (index name, models, keys, measure) -- confirm against the method signature.
# The returned job ids are watched further down.
job_ids = docs.create_semantic_index(
    'clip', ['clip'], ['_base'], 'css', loader_kwargs={'batch_size': 10, 'num_workers': 0},
    verbose=True,
)

In [10]:
# Show all jobs with their status.
docs.list_jobs()

[{'identifier': 'fe52a521-ae91-4148-8e67-c8d15a34ab6b',
  'time': datetime.datetime(2023, 4, 14, 15, 30, 0, 246000),
  'status': 'failed',
  'method': '_process_documents_with_watcher'},
 {'identifier': '5c2afda4-421b-46e4-ae8c-7422f19709f7',
  'time': datetime.datetime(2023, 4, 14, 21, 25, 13, 123000),
  'status': 'running',
  'method': '_process_documents_with_watcher'}]

In [None]:
# Stream progress of the first job launched by create_semantic_index.
docs.watch_job(job_ids[0])

computing chunk (1/7)
finding documents under filter
done.
processing with clip
 47%|####6     | 233/500 [05:54<06:45,  1.52s/it]

In [None]:
# Cancel a specific job by id.
# NOTE(review): this id is hard-coded from an earlier run (it is the 'failed'
# job in the list_jobs output above) and will not exist after Restart & Run All;
# look the id up from docs.list_jobs() instead.
docs.cancel_job('fe52a521-ae91-4148-8e67-c8d15a34ab6b')

In [None]:
# Pick an arbitrary document and show its image as the query anchor.
r = docs.find_one()

print('anchor image:')
display(r['img'])

# Show score and image of each hit for the anchor in the 'clip' semantic index.
# The loop variable reuses `r`, but `r['_id']` is evaluated when the query is
# built, before the loop rebinds the name.
for r in docs.find(like={'_id': r['_id']}, semantic_index='clip'):
    print(r['_score'])
    display(r['img'])

In [None]:
# Text-to-image search: use a title-only pseudo-document as the query.
for r in docs.find(like={'title': 'leopard print t-shirt'}, semantic_index='clip'):
    print(r['_score'])
    display(r['img'])

In [None]:
# Duplicate of the import in the first cell (same alias).
from IPython.display import Image as I

url = 'https://thumblr.uniid.it/product/238107/09ef5396fac2.jpg'
# NOTE(review): this switches the client to remote mode for every later cell in
# the kernel; nothing in the notebook sets it back.
docs.remote = True

display(I(url=url, width=200))

# Query by an image given only as a URL; download=True presumably makes the
# client fetch the URL content before encoding -- confirm.
cur = docs.find(like={'img': {'_content': {'url': url, 'type': 'image'}}}, semantic_index='clip', download=True)
for r in cur:
    display(r['img'])

In [None]:
import torch
from clip import load, tokenize


class ClassifierSimple(torch.nn.Module):
    """Zero-shot classifier over stored CLIP embeddings.

    Each category name is encoded once with CLIP's text encoder. ``forward``
    scores embeddings against those category vectors and returns softmax
    probabilities; ``postprocess`` maps one probability vector to the most
    likely category name.

    :param categories: list of category names, also used as the text prompts.
    :param name: CLIP model name passed to ``load``.
    """

    def __init__(self, categories, name):
        super().__init__()
        self.categories = categories
        model, _ = load(name)
        # The category vectors and logit scale are constants:
        # - no_grad: compute them without autograd, so no text-encoder graph is
        #   kept alive and the buffers are plain leaf tensors.
        # - tokens are moved to the model's device: tokenize builds CPU tensors,
        #   but load may place the model on CUDA.
        # - .float(): CLIP weights may be fp16 on CUDA, which would not matmul
        #   against the float32 embeddings of the 'clip' model.
        with torch.no_grad():
            tokens = torch.cat([tokenize(x) for x in categories], 0)
            tokens = tokens.to(model.logit_scale.device)
            category_vectors = model.encode_text(tokens).float()
            category_vectors = category_vectors / category_vectors.norm(dim=1, keepdim=True)
            logit_scale = model.logit_scale.exp().float()
        self.register_buffer('category_vectors', category_vectors)
        self.register_buffer('logit_scale', logit_scale)

    @property
    def device(self):
        """Device holding the category-vector buffer."""
        return self.category_vectors.device

    def preprocess(self, x):
        """Extract the input embedding.

        :param x: a document dict whose stored 'clip' output is at
            ``x['_outputs']['_base']['clip']``, or that tensor itself.
        """
        if isinstance(x, dict):
            x = x['_outputs']['_base']['clip']
        else:
            assert isinstance(x, torch.Tensor)
        return x

    def forward(self, x):
        """Return softmax probabilities over ``categories`` for each row of ``x``."""
        x = x / x.norm(dim=1, keepdim=True)
        logits_per_image = self.logit_scale * x @ self.category_vectors.t()
        out = logits_per_image.softmax(dim=-1)
        return out

    def postprocess(self, x):
        """Map one probability vector to its highest-scoring category name."""
        pos = x.topk(1)[1].item()
        return self.categories[pos]

In [None]:
# Register a zero-shot garment-type classifier as model 'silhouettes'; the
# category names below double as its text prompts.
docs.create_model(
    'silhouettes', 
    ClassifierSimple(
        name='RN50',
        categories=[
            'accessory',
            'blouse',
            'coat',
            'dress',
            'hat',
            'hoodie',
            'jacket',
            'pullover',
            'shoes',
            'skirt',
            't-shirt',
            'trousers',
        ]
    ),
    serializer='dill',
)

In [None]:
# Spot-check the classifier on one document: show its title and image, then the
# model's prediction (presumably the category name from
# ClassifierSimple.postprocess -- confirm).
r = docs.find_one()
print(r['title'])
display(r['img'])
docs.apply_model('silhouettes', r)

In [None]:
# Show the watchers currently registered.
docs.list_watchers()

In [None]:
# NOTE(review): interactive help lookup left over from development; remove it
# from the final notebook.
docs.create_watcher?

In [None]:
# Attach 'silhouettes' as a watcher. features={'_base': 'clip'} presumably feeds
# it the stored 'clip' output rather than the raw document, which matches
# ClassifierSimple.preprocess reading _outputs._base.clip -- confirm.
job_id = docs.create_watcher(
    'silhouettes',
    features={'_base': 'clip'},
    loader_kwargs={'batch_size': 10, 'num_workers': 0},
    verbose=True,
)

In [None]:
# Stream progress of the watcher's initial processing job.
docs.watch_job(job_id)

In [None]:
# Inspect a few documents to check the new outputs were written.
list(docs.find().limit(3))

In [None]:
# Copy 50 raw documents without _id, model outputs, fold and stored image bytes,
# so re-inserting them exercises the download and watcher pipeline, and tag
# them update=True so they can be found and deleted later.
update = list(docs.find({}, {'_id': 0, '_outputs': 0, '_fold': 0, 'img._content.bytes': 0}, raw=True).limit(50))
for r in update:
    r['update'] = True

In [None]:
# Insert the copies. The second return value maps each triggered job kind (e.g.
# '_download_content', or ('watcher', name) tuples) to its job ids.
_, job_ids = docs.insert_many(update)
job_ids

In [None]:
# Follow the content-download job triggered by the insert.
docs.watch_job(job_ids['_download_content'][0])

In [None]:
# Follow the 'clip/_base' watcher job triggered by the insert.
docs.watch_job(job_ids['watcher', 'clip/_base'][0])

In [None]:
# Follow the 'silhouettes/_base' watcher job triggered by the insert.
docs.watch_job(job_ids['watcher', 'silhouettes/_base'][0])

In [None]:
# Inspect one inserted copy to confirm it was processed.
docs.find_one({'update': True})

In [None]:
# Tear down everything the notebook created: the tagged copies, then the
# measure, semantic index, watchers, models and types.
# NOTE(review): the 'css' measure is deleted before the 'clip' semantic index
# that was created with it; force=True presumably allows that, but deleting
# dependents first (index -> watchers -> models -> types, then the measure)
# would be the safer order -- confirm against the client API.
docs.delete_many({'update': True})

docs.delete_measure('css', force=True)
docs.delete_semantic_index('clip', force=True)

docs.delete_watcher('silhouettes/_base', force=True)
docs.delete_watcher('clip/_base', force=True)

docs.delete_model('silhouettes', force=True)
docs.delete_model('clip', force=True)

docs.delete_type('float_tensor', force=True)
docs.delete_type('image', force=True)

In [None]:
import bson

# Back up every document's 'clip' output (MongoDB projections also return _id)
# to a local BSON file.
h = list(docs.find({}, {'_outputs._base.clip': 1}, raw=True))
# A top-level BSON value must be a document, so key the records '0', '1', ...
h = {f'{i}': hh for i, hh in enumerate(h)}

with open('clip_bak.bson', 'wb') as f:
    f.write(bson.BSON.encode(h))

In [None]:
import bson
# UpdateOne is pymongo's bulk-write request type (bson ships with pymongo); the
# original cell used it without importing it.
from pymongo import UpdateOne

# Restore the 'clip' outputs saved by the backup cell.
with open('clip_bak.bson', 'rb') as f:
    h = bson.BSON.decode(f.read())

# The backup is a document keyed '0', '1', ...; turn it back into an ordered list.
h = [h[f'{i}'] for i in range(len(h))]

# Each record is the projected document {'_id': ..., '_outputs': {'_base': {'clip': ...}}}.
# Write back only the nested clip value, not the whole record, and skip records
# that had no clip output when the backup was taken.
restore_ops = [
    UpdateOne({'_id': rec['_id']}, {'$set': {'_outputs._base.clip': rec['_outputs']['_base']['clip']}})
    for rec in h
    if 'clip' in rec.get('_outputs', {}).get('_base', {})
]
# pymongo's bulk_write raises on an empty request list.
if restore_ops:
    docs.bulk_write(restore_ops)

In [None]:
# Re-apply the restore from the list `h` built in the previous cell ($set is
# idempotent, so writing the same values again is harmless).
# NOTE(review): this cell duplicates the previous one; consider deleting it.
from pymongo import UpdateOne

# Write back only the nested clip value of each backed-up record, skipping
# records that had no clip output.
restore_ops = [
    UpdateOne({'_id': rec['_id']}, {'$set': {'_outputs._base.clip': rec['_outputs']['_base']['clip']}})
    for rec in h
    if 'clip' in rec.get('_outputs', {}).get('_base', {})
]
# pymongo's bulk_write raises on an empty request list.
if restore_ops:
    docs.bulk_write(restore_ops)