In [56]:
import networkx as nx
from networkx.algorithms.operators.binary import difference, symmetric_difference
import os
import json
import matplotlib.pyplot as plt
from zipfile import ZipFile
from tqdm.notebook import tqdm

In [59]:
def plot_graph(graph):
        nx.draw(graph, with_labels=True,
                node_color='skyblue', node_size=2200,
                width=3, edge_cmap=plt.cm.OrRd,
                arrowstyle='->',arrowsize=20,
                font_size=10, font_weight="bold",
                pos=nx.random_layout(init, seed=13))
        plt.show()

In [61]:
ctr = 0
predicted_transformations = []
actual_transformations = []
factorials = {}
num_lines = sum(1 for line in open('C:/Users/Admin/Projects/Transformation Driven Visual Learning/data/data.jsonl', 'r'))
with open('C:/Users/Admin/Projects/Transformation Driven Visual Learning/data/data.jsonl', 'r') as f:
    for line in tqdm(f, total=num_lines):
        # print(line)
        entry = json.loads(line)
        data = {
            'init_state': entry['states'][0]['objects'],
            'init_image_file': entry['states'][0]['images']['Camera_Center'],
            'final_state': entry['states'][1]['objects'],
            'final_image_file': entry['states'][1]['images']['Camera_Center'],
            'transformation': entry['transformations']
        }

        init_state = data['init_state']
        final_state = data['final_state']
        n = len(init_state)
        objects = ['obj'+str(i) for i in range(n)]
        init_edges = []
        final_edges = []

        init = nx.DiGraph()
        final = nx.DiGraph()
        init = nx.DiGraph()
        final = nx.DiGraph()
        attributes = {'colors' : [],
        'materials' : [],
        'shapes' : [],
        'sizes' : [],
        'positions': [],
        }

        # next cell
        nodes = []
        for i in range(n):
            nodes.append(init_state[i]['color'])
            nodes.append(init_state[i]['material'])
            nodes.append(tuple(init_state[i]['position']))
            nodes.append(init_state[i]['shape'])
            nodes.append(init_state[i]['size'])
            nodes.append(final_state[i]['color'])
            nodes.append(final_state[i]['material'])
            nodes.append(tuple(final_state[i]['position']))
            nodes.append(final_state[i]['shape'])
            nodes.append(final_state[i]['size'])
        nodes = nodes + objects

        init.add_nodes_from(nodes)
        final.add_nodes_from(nodes)


        # init.add_nodes_from(objects)
        # final.add_nodes_from(objects)

        # next cell
        import itertools
        if n in factorials.keys():
            objectSet = factorials[n]
        else:
            objectSet = [list(i) for i in itertools.permutations(objects)]
            ni = len(objectSet)
            factorials[n] = objectSet
        # next cell
        def edge_creator(state, id=0):
            objects = objectSet[id]
            edges = []
            for i in range(n):
                edges.append((objects[i], state[i]['shape']))
                edges.append((objects[i], state[i]['color']))
                edges.append((objects[i], state[i]['size']))
                edges.append((objects[i], tuple(state[i]['position'])))
                edges.append((objects[i], state[i]['material']))
            return edges


        # next cell
        init_edges = edge_creator(init_state)
        i = 0
        final_edges = edge_creator(final_state)
        init.add_edges_from(init_edges)
        final.add_edges_from(final_edges)
        diff = symmetric_difference(init, final)
        minChanges = len(diff.edges())
        final.remove_edges_from(final_edges)

        bestDiffGraph = diff

        for i in range(1, ni):
            final_edges = edge_creator(final_state, i)
            diff = symmetric_difference(init, final)
            final.remove_edges_from(final_edges)
            changes = len(diff.edges())
            if minChanges > changes:
                minChanges = changes
                bestDiffGraph = diff
            if minChanges <= 8:
                break


        # next cell
        # nx.draw(bestDiffGraph, with_labels=True,
        #         node_color='skyblue', node_size=2200,
        #         width=3, edge_cmap=plt.cm.OrRd,
        #         arrowstyle='->',arrowsize=20,
        #         font_size=10, font_weight="bold",
        #         pos=nx.random_layout(init, seed=13))


        # next cell
        subGraphs = []
        for object in objects:
            subGraph = nx.DiGraph()
            subGraph.add_node(object)
            nodes = bestDiffGraph.neighbors(object)
            for node in nodes:
                subGraph.add_node(node)
                subGraph.add_edge(object, node)
            subGraphs.append(subGraph)
        changes = []
        for i in range(len(subGraphs)):
            nodes = list(subGraphs[i].nodes())
            print(nodes)
            if len(nodes) == 3:
                changes.append([(nodes[0], nodes[1]), (nodes[0], nodes[2])])  
        predicted_transformations.append(changes)
#         init.remove_edges_from(init_edges)
#         init.remove_nodes_from(nodes)
#         final.remove_nodes_from(nodes)
#         print(predicted_transformations)
        break

  0%|          | 0/530000 [00:00<?, ?it/s]

['obj0']
['obj1']
['obj2', 'yellow', (0, -14), 'cyan', (20, -14)]
['obj3']
['obj4', 'blue', 'brown']
['obj5']
['obj6']
['obj7']
['obj8']
['obj9']


In [62]:
print(changes)
print(init.has_edge(*changes[0]))
init.edges()

[[('obj4', 'blue'), ('obj4', 'brown')]]
False


OutEdgeView([('obj0', 'cylinder'), ('obj0', 'gray'), ('obj0', 'small'), ('obj0', (-19, 37)), ('obj0', 'metal'), ('obj1', 'cylinder'), ('obj1', 'purple'), ('obj1', 'large'), ('obj1', (40, 40)), ('obj1', 'metal'), ('obj2', 'cube'), ('obj2', 'yellow'), ('obj2', 'small'), ('obj2', (0, -14)), ('obj2', 'glass'), ('obj3', 'cylinder'), ('obj3', 'blue'), ('obj3', 'large'), ('obj3', (-18, -2)), ('obj3', 'metal'), ('obj4', 'cube'), ('obj4', 'blue'), ('obj4', 'large'), ('obj4', (-1, 17)), ('obj4', 'glass'), ('obj5', 'cylinder'), ('obj5', 'blue'), ('obj5', 'medium'), ('obj5', (20, 24)), ('obj5', 'metal'), ('obj6', 'sphere'), ('obj6', 'blue'), ('obj6', 'large'), ('obj6', (-38, 25)), ('obj6', 'rubber'), ('obj7', 'cube'), ('obj7', 'gray'), ('obj7', 'large'), ('obj7', (24, 2)), ('obj7', 'glass'), ('obj8', 'sphere'), ('obj8', 'yellow'), ('obj8', 'medium'), ('obj8', (37, -40)), ('obj8', 'metal'), ('obj9', 'cube'), ('obj9', 'brown'), ('obj9', 'large'), ('obj9', (31, -28)), ('obj9', 'glass')])

In [53]:
nodes

['obj9']

In [28]:
with open('C:/Users/Admin/Projects/Transformation Driven Visual Learning/predicted_transformations.json', 'w') as fp:
    json.dump(predicted_transformations, fp)