# MNIST inference in model cards

In [None]:
import typing
from datetime import datetime
import PIL
import os

import weave
from weave import panels

import wandb
import wandb.apis.reports as wb
# Log in to Weights & Biases up front; the cells below call wandb.init and wandb.Api.
wandb.login()

In [None]:
import pprint

import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets, transforms

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def get_transform():
    """Build the preprocessing pipeline applied to every dataset image.

    Images are converted to tensors and then normalized.
    """
    steps = [
        transforms.ToTensor(),
        # presumably the MNIST pixel mean and standard deviation — TODO confirm
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
    return transforms.Compose(steps)


def build_dataset(batch_size=100, train=True):
    """Return a DataLoader over the MNIST data stored under the current directory.

    Args:
        batch_size: Samples per batch; ``None`` puts the whole split into a
            single batch.
        train: Forwarded to ``datasets.MNIST`` to choose the training or the
            test split.

    The data is not downloaded here (``download=False``), so it must already
    exist on disk.
    """
    mnist = datasets.MNIST(
        ".",
        train=train,
        download=False,
        transform=get_transform(),
    )
    # None means "everything at once": size the batch to the full split.
    samples_per_batch = mnist.data.shape[0] if batch_size is None else batch_size
    return torch.utils.data.DataLoader(mnist, batch_size=samples_per_batch)


def build_network(fc_layer_size, dropout):
    """Create a fully-connected classifier with one hidden layer, placed on DEVICE.

    Args:
        fc_layer_size: Width of the hidden layer.
        dropout: Probability passed to the dropout layer that follows the
            hidden activation.

    Returns:
        An ``nn.Sequential`` that flattens its input to 784 features and
        ends in a log-softmax over 10 outputs.
    """
    layers = [
        nn.Flatten(),
        nn.Linear(784, fc_layer_size),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(fc_layer_size, 10),
        nn.LogSoftmax(dim=1),
    ]
    model = nn.Sequential(*layers)
    return model.to(DEVICE)
        

def build_optimizer(network, optimizer, learning_rate):
    """Create the optimizer named by ``optimizer`` for ``network``'s parameters.

    Args:
        network: Module whose parameters will be optimized.
        optimizer: Optimizer name, matched case-insensitively. ``"adam"`` and
            ``"sgd"`` are recognised.
        learning_rate: Learning rate passed to the optimizer as ``lr``.

    Returns:
        The constructed ``torch.optim`` optimizer.
    """
    import warnings

    factories = {"adam": optim.Adam, "sgd": optim.SGD}
    factory = factories.get(str(optimizer).lower())
    if factory is None:
        # Unknown names fall back to Adam so existing callers keep working,
        # but the fallback is announced instead of happening silently.
        warnings.warn(
            f"Unknown optimizer {optimizer!r}; falling back to Adam.",
            stacklevel=2,
        )
        factory = optim.Adam
    return factory(network.parameters(), lr=learning_rate)


def train_epoch(network, loader, optimizer):
    """Run one optimization pass over ``loader`` and return the mean batch loss.

    Each batch is moved to DEVICE, scored with the negative log-likelihood
    loss, and used for one optimizer step. Every batch loss is also logged to
    W&B under ``"batch loss"``.
    """
    running_loss = 0
    for inputs, labels in loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        optimizer.zero_grad()
        loss = F.nll_loss(network(inputs), labels)
        # Read the scalar once; it feeds both the running sum and the log.
        batch_loss = loss.item()
        running_loss += batch_loss
        loss.backward()
        optimizer.step()
        wandb.log({"batch loss": batch_loss})
    return running_loss / len(loader)


def train(config=None):
    """Train a network inside a W&B run and return it.

    Args:
        config: Hyperparameters handed to ``wandb.init``. The values are then
            read back from ``wandb.config``: batch_size, fc_layer_size,
            dropout, optimizer, learning_rate and epochs.

    Returns:
        The trained network.
    """
    with wandb.init(config=config):
        hparams = wandb.config
        loader = build_dataset(hparams.batch_size)
        network = build_network(hparams.fc_layer_size, hparams.dropout)
        optimizer = build_optimizer(
            network, hparams.optimizer, hparams.learning_rate
        )
        for epoch in range(hparams.epochs):
            epoch_loss = train_epoch(network, loader, optimizer)
            # One summary point per epoch, alongside the per-batch losses.
            wandb.log({"loss": epoch_loss, "epoch": epoch})
    return network

In [None]:
# Workaround to fetch MNIST data: build_dataset loads with download=False, so
# the archive is pulled and unpacked here, but only when ./MNIST is not
# already present.
if not os.path.exists('./MNIST'):
    !wget www.di.ens.fr/~lelarge/MNIST.tar.gz
    !tar -zxvf MNIST.tar.gz

In [None]:
# Hyperparameters for the training run; `train` hands them to wandb.init.
config = {
    "fc_layer_size": 256,    # width of the hidden layer
    "dropout": 0.5,          # probability given to nn.Dropout
    "epochs": 50,            # number of passes over the training loader
    "learning_rate": 0.005,  # lr given to the optimizer
    "batch_size": 128,       # samples per training batch
    "optimizer": "adam",     # optimizer name passed to build_optimizer
}

In [None]:
# Run training with the hyperparameters in `config`; `train` returns the fitted network.
network = train(config)

In [None]:
# Score the trained network on the data from build_dataset(train=False),
# loaded as a single batch (batch_size=None).
test_loader = build_dataset(batch_size=None, train=False)
test_images, test_labels = next(iter(test_loader))

# Switch to eval mode so dropout is disabled during inference; otherwise the
# predictions are randomly degraded. no_grad avoids building an autograd graph
# for the whole test batch.
network.eval()
with torch.no_grad():
    # The network lives on DEVICE: run the forward pass there, then bring the
    # outputs back to the CPU so they can be compared with the CPU labels.
    inferred_labels = network(test_images.to(DEVICE)).cpu()

# The predicted class is the index of the largest log-probability per row.
_, predicted = torch.max(inferred_labels.data, 1)
correct = (predicted == test_labels).sum().item()
accuracy = correct / test_labels.size(0)

print(f'Network accuracy is {100 * accuracy}%')

In [None]:
# Open the results image from disk; get_model_card_args below re-opens the
# same file itself rather than using this module-level `image`.
path = "./fashion_mnist_results.png"
image = PIL.Image.open(path)

def get_model_card_args(entity, project, run_id, artifact_index):
    """Collect model-card fields for one artifact logged by a W&B run.

    Args:
        entity: W&B entity that owns the project.
        project: W&B project containing the run.
        run_id: ID of the run whose logged artifacts are inspected.
        artifact_index: Position of the artifact to describe within
            ``run.logged_artifacts()``.

    Returns:
        A ``(card_args, full_args)`` tuple. ``card_args`` holds the keyword
        arguments for ``ModelCard``; ``full_args`` is a superset that also
        carries the lookup parameters, the run's entity and the PIL image.
    """
    api = wandb.Api()
    run = api.run(f'{entity}/{project}/{run_id}')
    # Honour the caller's index instead of always taking the first artifact.
    artifact = run.logged_artifacts()[artifact_index]
    # NOTE(review): assumes updated_at looks like '2022-01-31T12:00:00' with no
    # fractional seconds or timezone suffix — confirm against the W&B API.
    updated = artifact.updated_at.strip()
    updated_dt = datetime.strptime(updated, '%Y-%m-%dT%H:%M:%S')
    path = "./fashion_mnist_results.png"
    image = PIL.Image.open(path)
    card_args = {
        'model_name': artifact.name,
        'created_by': User(name=run.entity),
        'updated_at': f'{updated_dt:%d/%m/%Y at %H:%M:%S}',
        'model_type': 'CNN',
        'application': 'Classifying clothing images by type',
        'primary_metric': TargetMetric(name='acc', direction='up'),
        'example': weave.save(image),
        'limitations': 'Not useful for realistic examples.'
    }
    full_args = {
        'entity': entity,
        'project': project,
        'run_id': run_id,
        'artifact_index': artifact_index,
        'entity_name': run.entity,
        'pil_image': image
    }
    # Card fields are layered on top of the lookup metadata.
    full_args.update(card_args)
    return card_args, full_args

In [None]:
# This Type should be built-in to Weave, declared in weave.types
# Identifies the person or entity credited on a model card.
@weave.type()
class User:
    # Display name; get_model_card_args fills it from the run's entity.
    name: str

# This Type will be built in to Weave, declared in weave.types
# Names the metric a model is judged by.
@weave.type()
class TargetMetric:
    # Metric identifier, e.g. 'acc'.
    name: str
    # presumably whether higher ('up') or lower ('down') values are better — TODO confirm
    direction: str # typing.Union['up','down']  # (TODO: enum)

# Empty Weave type, presumably a placeholder for markdown text; nothing in the
# visible code uses it yet.
class MarkdownString(weave.types.Type):
    pass

@weave.type()
class ExampleImage:
    # Weave type used for the example picture on a model card.
    # presumably these attributes tell Weave which Python class this type
    # wraps (PNG files opened with PIL) — TODO confirm against weave.types
    instance_class = PIL.PngImagePlugin.PngImageFile
    instance_classes = PIL.PngImagePlugin.PngImageFile
        
# Data shown on a model card; rendered by model_card_panel.
@weave.type()
class ModelCard:
    # Card title.
    model_name: str
    # Creator; the creator's name is shown as the card subtitle.
    created_by: User
    # Pre-formatted "last updated" text.
    updated_at: str  # TODO: timestamp
    # Free-text model family, e.g. 'CNN'.
    model_type: str  # TODO: enum
    # Metric the model is judged by; only its name is displayed.
    primary_metric: TargetMetric
    # Description of what the model is used for.
    application: str

    # TODO: This is not general enough. It should depend on the type of the model
    example: ExampleImage
    # Text for the 'Limitations & Use' tab.
    limitations: str
        
# TODO: this should be an op (`@weave.op()`) so it can be called from the UI,
# but something needs fixing before that works.
def model_card_panel(model_card: ModelCard) -> panels.Card:
    """Lay out a ModelCard as a two-tab card panel.

    The card is titled with the model name and subtitled with the creator's
    name. The 'Overview' tab groups the update time, model type and metric
    name (horizontal layout preferred), followed by the application and the
    example. The 'Limitations & Use' tab shows the limitations text.
    """
    summary_row = panels.Group(
        prefer_horizontal=True,
        items=[
            panels.LabeledItem(item=model_card.updated_at, label='Last updated'),
            panels.LabeledItem(item=model_card.model_type, label='Model type'),
            panels.LabeledItem(item=model_card.primary_metric.name, label='Metric'),
        ],
    )
    overview_tab = panels.CardTab(
        name='Overview',
        content=panels.Group(
            items=[
                summary_row,
                panels.LabeledItem(item=model_card.application, label='Application'),
                panels.LabeledItem(item=model_card.example, label='Example'),
            ]
        ),
    )
    limitations_tab = panels.CardTab(
        name='Limitations & Use',
        content=panels.LabeledItem(item=model_card.limitations, label='Limitations'),
    )
    return panels.Card(
        title=model_card.model_name,
        subtitle=model_card.created_by.name,
        content=[overview_tab, limitations_tab],
    )

In [None]:
# Programmatically generate the model card arguments (this also sets defaults for the user to replace)
# Coordinates of the W&B run whose logged artifact is described on the card.
entity = 'stacey'
project = 'digio'
run_id = '2pw9wdv6'
artifact_index = 0
card_args, full_args = get_model_card_args(entity, project, run_id, artifact_index)
# Last expression of the cell, so the notebook displays the generated arguments.
card_args

In [None]:
# Define the model card from the generated keyword arguments
model_card = ModelCard(**card_args)

# Render it using the model_card_panel!
# (kept as the last expression so the notebook displays the returned panel)
model_card_panel(model_card)