In [None]:
import pymongo
import torch
import torchvision

from pinnacledb import pinnacle

In [None]:
# Start from an empty 'documents' database so the notebook can be re-run from scratch.
# The client is only needed for this one call, so close it when done.
with pymongo.MongoClient() as client:
    client.drop_database('documents')

In [None]:
class LeNet5(torch.nn.Module):
    """LeNet-5 style convolutional classifier for single-channel images.

    * ``preprocess`` turns one input image into a normalised 32x32 tensor
      (presumably a grayscale PIL image -- the first conv expects 1 channel).
    * ``forward`` maps a batch of such tensors to ``num_classes`` logits.
    * ``postprocess`` turns one row of logits into an integer class index.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Two conv -> batch-norm -> ReLU -> 2x2 max-pool stages: 1 -> 6 -> 16 channels.
        self.layer1 = self._conv_stage(1, 6)
        self.layer2 = self._conv_stage(6, 16)
        # For a 32x32 input: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5,
        # so the flattened feature vector has 16 * 5 * 5 = 400 entries.
        self.fc = torch.nn.Linear(16 * 5 * 5, 120)
        self.relu = torch.nn.ReLU()
        self.fc1 = torch.nn.Linear(120, 84)
        self.relu1 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(84, num_classes)

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one 5x5 conv / batch-norm / ReLU / 2x2 max-pool stage."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=0),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def preprocess(self, x):
        """Resize ``x`` to 32x32, convert it to a tensor and normalise it.

        The mean/std constants appear to be the commonly used MNIST statistics.
        """
        transforms = torchvision.transforms
        pipeline = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ])
        return pipeline(x)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for the batch ``x``."""
        features = self.layer2(self.layer1(x))
        flat = features.reshape(features.size(0), -1)
        hidden = self.relu(self.fc(flat))
        hidden = self.relu1(self.fc1(hidden))
        return self.fc2(hidden)

    def postprocess(self, x):
        """Return the index of the largest entry of ``x`` as a Python ``int``."""
        _, top_index = x.topk(1)
        return int(top_index.item())
    
# Seed torch before building the network so the random weight initialisation
# (and therefore downstream training results) is reproducible.
torch.manual_seed(0)
model = LeNet5(num_classes=10)  # MNIST: ten digit classes

In [None]:
# Handle to the (freshly dropped) 'documents' database on the default local server.
db = pymongo.MongoClient()['documents']

In [None]:
# Wrap the raw pymongo database; later cells go through the wrapper's
# `execute`, `add` and `load` methods rather than the pymongo API.
db = pinnacle(db)

In [None]:
from pinnacledb.types.pillow.image import pil_image as i
from pinnacledb.core.documents import Document as D
from pinnacledb.queries.mongodb.queries import Collection

import random

# The last N_HOLDOUT shuffled documents are held back and inserted in a later
# cell; SEED makes that split reproducible across runs.
N_HOLDOUT = 1000
SEED = 0

mnist_data = list(torchvision.datasets.MNIST(root='./data', download=True))
# Each sample is an (image, label) pair; the image is wrapped with the
# `pil_image` type so it can be encoded on insert (see `encoders=[i]` below).
data = [D({'img': i(img), 'class': label}) for img, label in mnist_data]
# A private, seeded RNG gives a reproducible shuffle without touching the
# global `random` state.
random.Random(SEED).shuffle(data)

db.execute(
    Collection(name='mnist').insert_many(data[:-N_HOLDOUT], encoders=[i])
)

In [None]:
from pinnacledb.queries.mongodb.queries import Collection

# Fetch a single stored document from 'mnist' for inspection in the next cell.
r = db.execute(Collection(name='mnist').find_one())

In [None]:
# Display the document fetched by the previous cell.
r

In [None]:
# Wrap the torch module with pinnacle; the `predict` and `fit` calls in the
# following cells go through this wrapper.
model = pinnacle(model)

In [None]:
# Smoke-test prediction on the first ten (shuffled) documents' images.
model.predict([doc['img'] for doc in data[:10]])

In [None]:
# model.predict(
#     X='img',
#     db=db, 
#     select=Collection(name='mnist').find(),
#     remote=False
# )

In [None]:
from torch.optim import Adam
from torch.nn.functional import cross_entropy

from pinnacledb.core.metric import Metric
from pinnacledb.metrics.classification import compute_classification_metrics
from pinnacledb.core.dataset import Dataset
from pinnacledb.models.torch.wrapper import TorchTrainerConfiguration


# Documents to train on.
train_select = Collection(name='mnist').find()

# Optimiser, loss and loop settings for the torch trainer.
trainer_configuration = TorchTrainerConfiguration(
    identifier='my_configuration',
    optimizer_cls=Adam,
    objective=cross_entropy,
    loader_kwargs={'batch_size': 10},
    max_iterations=100,
    validation_interval=10,
    compute_metrics=compute_classification_metrics,
)

# 'acc' scores a prediction as correct when it equals the target label.
accuracy_metrics = [Metric(identifier='acc', object=lambda x, y: x == y)]

# Presumably the documents the framework tagged with the 'valid' fold on insert.
validation_datasets = [
    Dataset(
        identifier='my_valid',
        select=Collection(name='mnist').find({'_fold': 'valid'}),
    )
]

job = model.fit(
    X='img',
    y='class',
    db=db,
    select=train_select,
    configuration=trainer_configuration,
    metrics=accuracy_metrics,
    validation_sets=validation_datasets,
    remote=False,  # NOTE(review): presumably runs in-process rather than as a queued job
)

In [None]:
# Number of documents in the 'train' fold. Counting the cursor lazily avoids
# holding every (image-bearing) document in memory just to take its length.
sum(1 for _ in db.execute(Collection(name='mnist').find({'_fold': 'train'})))

In [None]:
# `Watcher` is not imported anywhere above, so this cell raised NameError.
# NOTE(review): module path assumed from the pinnacledb.core.* layout used for
# Document/Metric/Dataset -- confirm against the installed package.
from pinnacledb.core.watcher import Watcher

jobs = db.add(
    Watcher(
        model=model,
        key='img',
        # Keyword names match the earlier working `model.fit` call (`db=`, not `database=`).
        # NOTE(review): unlike that call, no `configuration` is passed here --
        # confirm the framework supplies a default.
        depends_on=model.fit(
            X='img',
            y='class',
            db=db,
            select=Collection(name='mnist').find(),
        ),
    )
)

In [None]:
# Follow the first job returned by `db.add`.
# NOTE(review): presumably the training job from `depends_on` -- confirm ordering.
jobs[0].watch()

In [None]:
from matplotlib import pyplot as plt

# Reload the stored model; its `metric_values` are presumably populated by the fit job.
model = db.load('model', model.identifier)

fig, ax = plt.subplots()
ax.plot(model.metric_values['acc'])
ax.set(
    title="Metric 'acc' recorded during training",
    xlabel='evaluation index',
    ylabel='acc',
)
plt.show()

In [None]:
# Follow the second job returned by `db.add`.
# NOTE(review): presumably the watcher's prediction job -- confirm ordering.
jobs[1].watch()

In [None]:
for r in data[-1000:]:
    r['update'] = True

db.execute(Collection('docs').insert_many(data[-1000:]))

In [None]:
db.execute(Collection('docs').find_one({'update': True}))