In [None]:
import numpy as np
from source import *
from functools import reduce

In [None]:
def same_obj(obj1: ARC_Object, obj2: ARC_Object):
    """Return True when both objects exist and their grids are identical.

    Grid identity is decided by ``np.array_equal`` (same shape, same
    values). If either argument is None the result is False, so two
    missing objects are never reported as "the same".
    """
    both_present = obj1 is not None and obj2 is not None
    return both_present and np.array_equal(obj1.grid, obj2.grid)

class ListProperties:
    """Sorted collection of ARC objects plus aggregates computed over it.

    Attributes:
        objs: the input objects, sorted by their ``top_left`` attribute.
        num_objs: number of objects in ``objs``.
        most_common: result of ``most_common(objs)``.
        majority: result of ``majority(objs)``.
        and_all / or_all / xor_all: left fold of ``and_obj`` / ``or_obj`` /
            ``xor_obj`` over the sorted ``objs`` (a single object folds
            to itself).

    When ``objs`` is empty every aggregate above is ``None`` rather than
    raising; ``same_obj`` treats ``None`` as "not the same object".
    """

    def __init__(self, objs: List[ARC_Object]):
        self.objs = sorted(objs, key=lambda o: o.top_left)
        self.num_objs = len(self.objs)
        if self.objs:
            self.most_common = most_common(self.objs)
            self.majority = majority(self.objs)
            # Pairwise combinators folded left-to-right over the sorted objects.
            self.and_all = reduce(and_obj, self.objs)
            self.or_all = reduce(or_obj, self.objs)
            self.xor_all = reduce(xor_obj, self.objs)
        else:
            # reduce() without an initial value raises TypeError on an empty
            # sequence, and no identity object is available here for the
            # combinators, so represent "no aggregate" as None instead.
            self.most_common = None
            self.majority = None
            self.and_all = None
            self.or_all = None
            self.xor_all = None

class CompareObjects:
    """Flags describing how two ARC objects relate to each other.

    Both objects are kept as ``obj1`` / ``obj2``; every ``same_*`` flag is
    computed once, at construction time.
    """

    def __init__(self, obj1: ARC_Object, obj2: ARC_Object):
        self.obj1, self.obj2 = obj1, obj2
        # Grid equality as decided by same_obj().
        self.same_grid = same_obj(obj1, obj2)
        # Identical top_left coordinates.
        self.same_pos = obj1.top_left == obj2.top_left
        # Same result from dominant_color() for both objects.
        self.same_color = dominant_color(obj1) == dominant_color(obj2)
        # Matching height and width (width is only checked if heights match).
        self.same_size = obj1.height == obj2.height and obj1.width == obj2.width

class CompareLists:
    """Compare two ListProperties summaries, element-wise and in aggregate.

    ``compare_objs`` holds one CompareObjects per position, pairing the two
    ``objs`` lists in their stored order. Because ``zip`` stops at the
    shorter list, surplus objects in the longer list are not compared;
    ``same_len`` reports whether the counts match.
    """

    def __init__(self, lst1: ListProperties, lst2: ListProperties):
        self.lst1, self.lst2 = lst1, lst2
        self.compare_objs = [
            CompareObjects(left, right)
            for left, right in zip(lst1.objs, lst2.objs)
        ]
        self.same_len = lst1.num_objs == lst2.num_objs

        def agree(attr):
            # same_obj() yields False when either side's aggregate is None.
            return same_obj(getattr(lst1, attr), getattr(lst2, attr))

        self.same_most_common = agree('most_common')
        self.same_majority = agree('majority')
        self.same_and = agree('and_all')
        self.same_or = agree('or_all')
        self.same_xor = agree('xor_all')

In [None]:
# Load task 1caeab9d and compare the objects (extracted with method='color')
# of the input and output grids of its first training pair.
train_objs, test_obj = quick_load('1caeab9d', 'training')
first_pair = train_objs[0]
in_grid, out_grid = first_pair['input'], first_pair['output']

in_children = extract_objects(in_grid, method='color')
out_children = extract_objects(out_grid, method='color')
in_grid.plot_grid()
out_grid.plot_grid()

in1_p, out1_p = ListProperties(in_children), ListProperties(out_children)
analyze1 = CompareLists(in1_p, out1_p)

# Per-position comparison: plot both objects, then their anchors and flags.
for obj_cmp in analyze1.compare_objs:
    obj_cmp.obj1.plot_grid()
    obj_cmp.obj2.plot_grid()
    print(obj_cmp.obj1.top_left, obj_cmp.obj2.top_left)
    print(obj_cmp.same_grid, obj_cmp.same_pos, obj_cmp.same_color, obj_cmp.same_size)

# List-level flags: length, most_common, majority, and/or/xor aggregates.
list_flags = (
    analyze1.same_len,
    analyze1.same_most_common,
    analyze1.same_majority,
    analyze1.same_and,
    analyze1.same_or,
    analyze1.same_xor,
)
print(*list_flags)

# Optional debug view: plot every extracted object individually.
# for c in in_children:
#     c.plot_grid()
# print('Out:')
# for c in out_children:
#     c.plot_grid()