In [None]:
import random
import typing
import itertools
from PIL import Image
import weave
from weave.legacy.weave import storage
from weave.legacy.weave import weave_internal
#weave.use_frontend_devmode()

In [None]:
# Synthetic stand-in for model predictions; good enough for exercising the panels below.
import string
import hashlib

def simple_hash(n, b):
    """Deterministically map ``n`` to an integer bucket modulo ``b``.

    ``n`` is converted with ``str()``, UTF-8 encoded and hashed with SHA-256;
    the 32-byte digest is read as a little-endian unsigned integer and reduced
    modulo ``b`` (so the result lies in ``range(b)`` for positive ``b``).
    """
    digest = hashlib.sha256(str(n).encode()).digest()
    as_int = int.from_bytes(digest, byteorder="little")
    return as_int % b

def create_data(n_rows, n_extra_cols, images=False):
    """Build a list of synthetic "prediction" rows for exercising weave panels.

    Rows cover every (rotate, shear) pair in ``range(5) x range(5)``. Each pair
    is repeated ``n_rows // 25`` times, so the result has ``25 * (n_rows // 25)``
    rows (``n_rows`` rounded down to a multiple of 25).

    Each row is a dict with keys, in this order:
      - ``'rotate'``, ``'shear'``: ints in ``range(5)``.
      - ``'y'``: a letter in ``'a'..'e'``; ``'x'``: a letter in ``'a'..'k'``.
        Both are derived deterministically from the row index via ``simple_hash``.
      - ``'image'`` (only when ``images`` is true): a 32x32 mode-'L' PIL
        gradient, rotated by ``rotate * 4`` degrees, then affine-transformed
        with x-shear coefficient ``shear / 10``.
      - One extra column per ``n_extra_cols``, named ``'a'``, ``'b'``, ...,
        each holding a letter in ``'a'..'k'``.

    Args:
        n_rows: requested number of rows; rounded down to a multiple of 25.
        n_extra_cols: number of extra single-letter columns (at most 23).
        images: whether to attach a PIL image to each row.

    Returns:
        A list of dicts, one per row.

    Raises:
        ValueError: if ``n_extra_cols`` is large enough that a generated column
            name would overwrite the ``'x'``/``'y'`` columns.
    """
    n_rotations = n_shears = 5
    # Integer floor division avoids the float rounding of int(n_rows / 25) for
    # very large n_rows; results are identical for ordinary inputs.
    repeats = int(n_rows) // (n_rotations * n_shears)
    letters = string.ascii_lowercase

    # chr(ord('a') + 23) == 'x': with more than 23 extra columns the generated
    # names would silently overwrite 'x' and 'y' (and past 26 stop being letters).
    max_extra_cols = ord('x') - ord('a')
    if n_extra_cols > max_extra_cols:
        raise ValueError(
            f"n_extra_cols must be at most {max_extra_cols}, got {n_extra_cols}")
    extra_cols = [chr(ord('a') + k) for k in range(n_extra_cols)]

    # Only build the base gradient when images are actually requested.
    base_im = Image.linear_gradient('L').resize((32, 32)) if images else None

    rows = []
    for i, (rotate, shear, _) in enumerate(
        itertools.product(range(n_rotations), range(n_shears), range(repeats))
    ):
        row_key = i ** 13  # shared hash input for the 'y' and 'x' columns
        row = {
            'rotate': rotate,
            'shear': shear,
            'y': letters[simple_hash(row_key, 5)],
            'x': letters[simple_hash(row_key, 11)],
        }
        if images:
            row['image'] = (
                base_im
                .rotate(rotate * 4)
                .transform((32, 32), Image.AFFINE, (1, shear / 10, 0, 0, 1, 0), Image.BICUBIC)
            )
        for j, col in enumerate(extra_cols):
            # Parsed as i * (13 ** j), so the first extra column (j=0) hashes plain i.
            row[col] = letters[simple_hash(i * 13 ** j, 11)]
        rows.append(row)
    return rows

# 100 rows (4 per (rotate, shear) pair), one extra column, with images attached.
ims = create_data(n_rows=100, n_extra_cols=1, images=True)
# Alternative representations, left disabled:
#ims = storage.to_arrow(ims)
#ims = weave.save(ims)

In [None]:
# Display the synthetic rows in the weave UI.
weave.show(ims)

In [None]:
# Plot the rows with x = each row's 'rotate' value and y = its 'x' letter.
# The tooltip shows the row's image (ims was built with images=True).
plot = weave.legacy.weave.panels.Plot(ims)
plot.set_x(lambda row: row['rotate'])
plot.set_y(lambda row: row['x'])
plot.set_tooltip(lambda row: row['image'])
weave.show(plot)

## Facet: a grid of sub-panels, one per (x, y) value pair


In [None]:
# This is one way to build a multi-confusion matrix.
# The Facet makes one cell per (rotate, shear) value pair. Within a cell, the
# rows are grouped by their 'y' letter and drawn as bars: x = the group's row
# count, y and label = the 'y' letter. The tooltip lists the images in each group.

facet = weave.legacy.weave.panels.Facet(
    input_node=ims,
    x=lambda im: im['rotate'],
    y=lambda im: im['shear'],
    select=lambda cell: weave.legacy.weave.panels.Plot(
        input_node=cell.groupby(lambda row: row['y']),
        x=lambda group: group.count(),
        y=lambda group: group.key(),
        label=lambda group: group.key(),
        tooltip=lambda group: group.map(lambda r: r['image']),
        mark='bar',
        # Axes and legend are hidden, presumably to keep the small per-cell charts uncluttered.
        no_axes=True,
        no_legend=True
    )
)

weave.show(facet)

In [None]:
# Earlier, fuller variant kept for reference: each facet cell was a bar chart of
# that cell's rows grouped by `compare_col`.
# @weave.op()
# def confusion_matrix(inp: typing.Any, guess_col: str, truth_col: str, compare_col: str) -> weave.legacy.weave.panels.Facet:
#     return weave.legacy.weave.panels.Facet(
#         input_node=inp,
#         x=lambda i: i[guess_col],
#         y=lambda i: i[truth_col],
#         select=lambda cell: weave.legacy.weave.panels.Plot(
#             input_node=cell.groupby(
#                 weave.define_fn({'row': cell.type.object_type}, lambda row: row[compare_col])),
#             x=lambda group: group.count(),
#             y=lambda group: group.key(),
#             label=lambda group: group.key(),
#             mark='bar',
#             no_axes=True,
#             no_legend=True
#         )
#     )

# An example of a Panel-returning op. This (sort of) works but still has rough
# edges. NOTE(review): the original comment was cut off here ("there are lots of ...").

@weave.op()
def confusion_matrix(inp: typing.Any, guess_col: str, truth_col: str, compare_col: str) -> weave.legacy.weave.panels.Facet:
    """Return a Facet with one cell per (guess_col, truth_col) value pair.

    Each cell currently shows only ``cell.count()``, i.e. the number of rows
    that fall into it.

    NOTE(review): ``compare_col`` is accepted but not used by this version; the
    commented-out variant above grouped each cell by it. Confirm whether that
    behavior should be restored.
    """
    return weave.legacy.weave.panels.Facet(
        input_node=inp,
        x=lambda i: i[guess_col],
        y=lambda i: i[truth_col],
        select=lambda cell: cell.count()
    )

In [None]:
# Call the `confusion_matrix` op defined in the previous cell. There is no
# `demos` namespace in this notebook, so `demos.confusion_matrix` raised NameError.
# Arguments: guess_col='rotate', truth_col='x', compare_col='y'.
weave.show(confusion_matrix(ims, 'rotate', 'x', 'y'))

In [None]:
# Other explorations


# weave.show(facet)

# # Not working yet, but playing with removing lambdas
# facet = weave.legacy.weave.panels.Facet(
#     input_node=ims,
#     x=ims['rotate'],
#     y=ims['shear'],
#     select=lambda cell: weave.legacy.weave.panels.Plot(
#         input_node=cell.groupby(cell.row['y']),
#         x=lambda group: group.count(),
#         y=lambda group: group.key(),
#         label=lambda group: group.key(),
#         mark='bar',
#         no_axes=True,
#         no_legend=True
#     )
# )

# # Maybe another cool thing
# facet = weave.legacy.weave.panels.Facet(
#     input_node=ims,
#     x=lambda im: im['rotate'],
#     y=lambda im: im['shear'],
#     select=lambda cell_ims: cell_ims.groupby(cell.row['y'])
#         .Plot(
#             x=lambda group: group.count(),
#             y=lambda group: group.key(),
#             label=lambda group: group.key(),
#             mark='bar',
#             no_axes=True,
#             no_legend=True
#     )
# )

# # OK, actually it's more like we just always want to drop the first argument.
# # This is finally minimal, but we can't get autocomplete help.
# facet = ims.Facet(
#     x=select('rotate'),
#     y=select('shear'),
#     select=groupby('y')
#         .Plot(
#             x=count(),
#             y=group_key(),
#             label=group_key(),
#             mark=select('bar'),
#             no_axes=True,
#             no_legend=True
#     )
# )

# # Hmm.
# facet = ims
#     .groupby(
#         facet_x=select('rotate'),
#         facet_y=select('shear'),
#         y=select('run'))
#     .count()
#     .Plot()

# facet = Plot(
#     ims,
#     facet_x=select('rotate'),
#     facet_y=select('shear'),
#     y=select('run'),
#     x=count()
# )

# TODO:
#   - Get rid of lambdas (could ims.row be a variable?)
#     ... or we could just treat ims as the row variable when it's assigned to a plot?
#   - This is why React is a little nicer: component control flow isn't usually hidden
#     away inside other components (framework style). Instead, you decide how you want to
#     lay stuff out...
#   - But we can achieve that here; we just need to make some lower-level components.
#     (Try it!) Instead of PanelFacet, use PanelLayout, PanelGrid, or something similar.
#   - See if we can get rid of the groupby requirement in PanelPlot. You don't need it in Vega.
#       - Also maybe look at the Altair API?
#
# In React you'd do:
# const ConfusionMatrix = (ims) => (
#   <Facet ...>
#     <Facet.Cell ...>
#   </Facet>
# )