In [None]:
from pinnacledb.mongodb.client import SuperDuperClient
from IPython.display import display

# Connect using the client's default settings (no connection arguments are passed).
c = SuperDuperClient()
# Handle on `ecommerce.documents` (MongoDB-style database.collection attribute
# access, per the `pinnacledb.mongodb` import). Every later cell works through `docs`.
docs = c.ecommerce.documents

In [None]:
import torch


class Target(torch.nn.Module):
    """Parameter-free pass-through model used to build training targets.

    ``preprocess`` casts one raw value with ``float()`` and wraps it as a
    0-dimensional tensor; ``forward`` returns its input unchanged.
    No ``__init__`` override is needed: ``torch.nn.Module.__init__`` is
    inherited as-is.
    """

    def preprocess(self, value):
        """Cast ``value`` to a Python float and wrap it as a scalar tensor."""
        as_float = float(value)
        return torch.tensor(as_float)

    def forward(self, x):
        """Identity mapping: hand ``x`` straight back."""
        return x
    
    
class PersonDetector(torch.nn.Module):
    """Single linear layer that scores whether a document contains a person.

    ``forward`` returns one raw logit per input row (no sigmoid is applied),
    which is what the ``BCEWithLogitsLoss`` objective it is trained with
    expects; ``postprocess`` turns a single logit into a boolean prediction.
    """

    def __init__(self, in_features: int = 1024):
        """
        Args:
            in_features: width of each input feature vector. Defaults to 1024,
                the width originally hard-coded here (presumably the size of
                the ``'clip'`` features this model is trained on -- confirm).
        """
        super().__init__()
        self.linear = torch.nn.Linear(in_features, 1)

    def preprocess(self, value):
        """Convert ``value`` (e.g. a list of numbers) into a float32 tensor."""
        return torch.tensor(value, dtype=torch.float)

    def forward(self, x):
        """Return logits of shape ``(batch,)``.

        Assumes ``x`` is ``(batch, in_features)``; ``[:, 0]`` drops the
        trailing size-1 output dimension of the linear layer.
        """
        return self.linear(x)[:, 0]

    def postprocess(self, x):
        """Map one logit to ``True`` (person) / ``False`` (no person).

        ``x`` is a logit, so the probability boundary sigmoid(x) = 0.5 lies at
        x = 0. Comparing the raw logit against 0.5 would instead demand
        p > sigmoid(0.5) ~= 0.62 and bias predictions towards ``False``.
        """
        return x.item() > 0.0


# Register both models (dill-serialized) under the names the imputation cell
# refers to: 'target' builds labels, 'person_detector' is the trainable model.
docs.create_model('target', Target(), serializer='dill')

docs.create_model('person_detector', PersonDetector(), serializer='dill')

In [None]:
# BCEWithLogitsLoss scales the loss of positive examples by `pos_weight`.
# PyTorch defines it as (#negative / #positive): e.g. 300 negatives and
# 100 positives -> pos_weight = 3, which counteracts class imbalance.
n_positive = docs.count_documents({'person': True})
n_negative = docs.count_documents({'person': False})
if n_positive == 0:
    raise ValueError("No documents with person=True; cannot compute pos_weight.")
pos_weight = n_negative / n_positive

docs.create_objective('bce', torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight)))

In [None]:
def _accuracy(x, y):
    """Return whether ``x`` equals ``y`` (the comparison is ``x == y``)."""
    return x == y


docs.create_metric('accuracy', _accuracy, serializer='dill')

In [None]:
# Only train/evaluate on documents that actually carry a `person` label.
imputation_filter = {'person': {'$exists': 1}}
# Feed the '_base' key to the model via its 'clip' features; log every 50 steps.
imputation_trainer_kwargs = {'features': {'_base': 'clip'}, 'log_interval': 50}

# Learn to impute `person` from `_base` with 'person_detector', using the
# 'target' model for labels, the 'bce' objective and the 'accuracy' metric.
job_ids = docs.create_imputation(
    'person_imputation', 'person_detector', '_base', 'target', 'person',
    objective='bce',
    metrics=['accuracy'],
    filter_=imputation_filter,
    trainer_kwargs=imputation_trainer_kwargs,
)

In [None]:
# Show the job identifiers returned by `create_imputation` (bare last
# expression, so Jupyter renders it as the cell output).
job_ids

In [None]:
# Follow the first job returned by `create_imputation`.
# NOTE(review): only job_ids[0] is watched; if several jobs are returned,
# confirm whether the remaining ones also need to be watched.
docs.watch_job(job_ids[0])

In [None]:
# Display the models registered on the collection (e.g. to check that
# 'target' and 'person_detector' were created).
docs.list_models()

In [None]:
# Pick one random document that has pre-computed CLIP features, because the
# prediction below reads r['_outputs']['_base']['clip']; sampling without this
# filter can hit a document lacking that key and fail with a KeyError.
sampled = next(docs.aggregate([
    {'$match': {'_outputs._base.clip': {'$exists': True}}},
    {'$sample': {'size': 1}},
    {'$project': {'_id': 1}},
]), None)
if sampled is None:
    raise LookupError("No documents with '_outputs._base.clip' found to sample from.")
_id = sampled['_id']

# Re-fetch the full document with find_one instead of using the aggregate
# result (presumably so fields like 'img' come back in a displayable form --
# confirm against the client).
r = docs.find_one({'_id': _id})

display(r['img'])
docs.apply_model('person_detector', r['_outputs']['_base']['clip'])

In [None]:
# Tear down everything this notebook registered. The imputation is deleted
# first, presumably because it refers to the objective, metric and models
# removed after it. Every call passes force=True (presumably to skip a
# confirmation step -- TODO confirm against the client API).
docs.delete_imputation('person_imputation', force=True)
docs.delete_objective('bce', force=True)
docs.delete_metric('accuracy', force=True)
docs.delete_model('person_detector', force=True)
docs.delete_model('target', force=True)