# Graph Building & Matching

In [None]:
# %matplotlib widget
# above line is for interactive figures (nice for 3D plots). Install with:
# conda install -c conda-forge ipympl
import numpy as np
import matplotlib.pyplot as plt
import random
import quaternion as qt
#from scipy.spatial.transform import Rotation as R

import habitat_sim
from habitat_sim import AgentConfiguration, CameraSensorSpec

from tbp.monty.simulators.habitat import HabitatSim
from tbp.monty.frameworks.environments.habitat import PanTiltZoomCamera
from tbp.monty.frameworks.environment_utils.transforms import DepthTo3DLocations

## Habitat Experiment Setup

In [None]:
def get_one_obs_sim(obj="mug",obj_position=(0.01, 1.5, -0.08),cam_position=(0.0,0.0,0.0),cam_rotation=(1.0,0.0,0.0,0.0), obj_rotation=(0.0,0.0,0.0,0.0), world_coord=True):
    """
    Create a Habitat simulator with one camera agent and one object, take a
    single observation, add 3D locations to it and shut the simulator down again.

    Arguments
            :param obj: name of the object passed to sim.add_object
            :param obj_position: position passed to sim.add_object
            :param cam_position: position passed to PanTiltZoomCamera
            :param cam_rotation: rotation passed to PanTiltZoomCamera
                    (presumably a quaternion -- TODO confirm component order)
            :param obj_rotation: rotation passed to sim.add_object
            :param world_coord: forwarded to the DepthTo3DLocations transform
    Returns
            (obs, None, state): the transformed observation dict, a placeholder
            None, and the state of agent "agent_01" at observation time.
    """
    camera = PanTiltZoomCamera(semantic=True, resolution=(32, 32),position=cam_position, rotation=cam_rotation)
    sim = HabitatSim(agents={"agent_01": camera.get_spec()})
    try:
        sim.add_object(name=obj, position=obj_position, rotation=obj_rotation)

        state = sim.get_agent("agent_01").get_state()
        states = sim.get_states()
        print(state)
        print(states)

        obs = sim.get_observations()
        # add 3D coordinates
        transform = DepthTo3DLocations(agent_id='agent_01', resolution=obs['agent_01']['depth'].shape, world_coord=world_coord)
        obs = transform(obs, state=states)
    finally:
        # Always release the simulator, even if adding the object or the
        # transform fails; otherwise the next call may find it still open.
        sim.close()
    return obs, None, state


### Get Observations for Different Perspectives

In [None]:
# Forwarded to get_one_obs_sim (and from there to the DepthTo3DLocations transform).
WORLD_COORD = True
# Currently: get one observation and saccade on this
# View 1: default camera pose.
obs, _, state = get_one_obs_sim(world_coord=WORLD_COORD)
# View 2: camera moved and rotated.
obs2, _, state2 = get_one_obs_sim(cam_position=(0.0,0.1,0.0),cam_rotation=(1.0,-0.4,0.0,0.0),world_coord=WORLD_COORD)
# View 3: camera moved only (no rotation change).
obs3, _, state3 = get_one_obs_sim(cam_position=(0.0,0.05,0.0),world_coord=WORLD_COORD)

In [None]:
# 2x3 grid: top row shows rgba / depth / semantic of view 1 (obs),
# bottom row shows the same three channels of view 2 (obs2).
plt.figure()
plt.subplot(2,3,1)
plt.imshow(obs['agent_01']['rgba'])
plt.title('RGBA')
plt.axis('off')
plt.subplot(2,3,2)
plt.imshow(obs['agent_01']['depth'])
plt.title('Depth')
plt.axis('off')
plt.subplot(2,3,3)
plt.imshow(obs['agent_01']['semantic'])
plt.title('Semantic')
plt.axis('off')

plt.subplot(2,3,4)
plt.imshow(obs2['agent_01']['rgba'])
#plt.title('RGBA')
plt.axis('off')
plt.subplot(2,3,5)
plt.imshow(obs2['agent_01']['depth'])
# Only the middle panel of the bottom row carries a title.
plt.title('View 2')
plt.axis('off')
plt.subplot(2,3,6)
plt.imshow(obs2['agent_01']['semantic'])
#plt.title('Semantic')
plt.axis('off')
plt.show()

### Saccade on Environment Observation (will be moved into Habitat)

In [None]:
def fake_saccade(full_obs, gaussian=True):
    """
    Sample one pixel of the 'agent_01' observation and read out what is there.

    Arguments
            :param full_obs: observation dict; full_obs['agent_01'] must contain
                    the image-like arrays 'rgba', 'depth' and 'semantic'.
            :param gaussian: if True, sample the pixel from a normal distribution
                    centered on the image (std = size // 6 per axis), otherwise
                    sample uniformly over the image.
    Returns
            x, y: sampled column and row index (arrays are indexed [y, x])
            z: depth value at the pixel
            obj: True if the semantic value at the pixel is > 0
            feat: rgba value at the pixel
    """
    agent_obs = full_obs['agent_01']
    # shape is (rows, cols, ...): rows bound y, cols bound x.
    height, width = agent_obs['rgba'].shape[:2]
    if gaussian:
        center = np.array([width // 2, height // 2])
        spread = np.array([width // 6, height // 6])
        xy = np.round(np.random.normal(center, spread)).astype(int)
        # Clamp to the actual image size (a fixed limit of 63 overflows
        # images smaller than 64 pixels).
        x, y = np.clip(xy, 0, [width - 1, height - 1])
    else:
        x = random.randint(0, width - 1)
        y = random.randint(0, height - 1)
    z = agent_obs['depth'][y, x]
    obj = agent_obs['semantic'][y, x] > 0
    feat = agent_obs['rgba'][y, x]
    return x, y, z, obj, feat

In [None]:
def collect_fake_saccades(full_obs, num_saccades, gaussian=True):
    """
    Call fake_saccade num_saccades times and gather the results column-wise.

    Returns
            five lists (x, y, z, obj, feat), each of length num_saccades.
    """
    columns = ([], [], [], [], [])
    for _ in range(num_saccades):
        sample = fake_saccade(full_obs, gaussian=gaussian)
        # sample is (x, y, z, obj, feat); file each value into its own list
        for column, value in zip(columns, sample):
            column.append(value)
    all_x, all_y, all_z, all_obj, all_feat = columns
    return all_x, all_y, all_z, all_obj, all_feat

In [None]:
# Sample 20 gaussian-distributed fake saccades on the first view.
xs, ys, zs, objs, feats = collect_fake_saccades(obs,20,gaussian=True)
# Important: When indexing array, x and y are switched!

plt.figure()
plt.imshow(obs['agent_01']['depth'])
# NOTE(review): limits assume a 64-pixel image, but get_one_obs_sim requests
# resolution=(32, 32) -- confirm which size is intended.
plt.xlim([0,63])
plt.ylim([63,0])
# Line connecting the saccades in sampling order (markersize=0 hides the markers).
plt.plot(xs,ys,marker='o',color='lightblue', markerfacecolor='green', markersize=0)
# Dots colored by whether the saccade landed on the object (objs).
plt.scatter(xs,ys,marker='o',c=np.array(objs)*1,cmap='Reds',s=70)
plt.title("Gaussian")

plt.show()

## Collect Fake 3D Saccades
It's still kind of fake since we just get all the points at once without actually moving around in Habitat.

In [None]:
def get_fake_3D_saccades(full_obs, num_saccades=0):
    """
    Return the 3D points stored in full_obs['agent_01']['semantic_3d'], optionally
    subsampled. For real saccades we want to actually move around in habitat.
    No feature is returned for the locations at the moment (feats is all zeros).

    Arguments
            :param full_obs: observation dict that contains 'semantic_3d' for 'agent_01'.
                    Only its first three columns (x, y, z) are used.
            :param num_saccades: How many saccades should be returned. If 0, then all are returned,
                    otherwise they are randomly sampled (without replacement).
    Returns
            xs, ys, zs: sign-flipped coordinates of the points
            objs: array of ones (every point is treated as on the object)
            feats: array of zeros
    """
    points = full_obs['agent_01']['semantic_3d']
    if num_saccades > 0:
        chosen = np.random.choice(len(points), size=num_saccades, replace=False)
        points = points[chosen]
    n_points = len(points)
    # Once we have multiple objects in a scene we want to change this
    # TODO: maybe already update this for naming new graphs
    objs = np.ones(n_points)
    feats = np.zeros(n_points)
    # Flip the sign of x, y, z to make plotting more intuitive. Since we use
    # displacements for everything else this should be fine.
    return -points[:, 0], -points[:, 1], -points[:, 2], objs, feats

## Graph Building

In [None]:
import torch
# conda install pyg -c pyg -c conda-forge
import torch_geometric
from torch_geometric.data import Data
from mpl_toolkits.mplot3d import Axes3D
from sklearn.neighbors import NearestNeighbors

In [None]:
# data.x: Node feature matrix with shape [num_nodes, num_node_features]
#           [num_saccades, 4]. node_features = x, y, z, feat
# data.edge_index: Graph connectivity in COO format with shape [2, num_edges] and type torch.long
#           bidirectional. Each new node links to the N closest nodes in 3D space
# data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
#           [N, 3]. edge_features = 3D displacement
# data.y: Target to train against (may have arbitrary shape), e.g., node-level targets of shape [num_nodes, *] or graph-level targets of shape [1, *]
#           ?
# data.pos: Node position matrix with shape [num_nodes, num_dimensions]
#           [num_saccades, 3]?

### Build Temporal Graph

In [None]:
def build_temporal_graph(xs, ys, zs, objs, feats):
    """
    Build a graph that links each on-object observation to the previous
    on-object observation (in sampling order).

    Arguments
            :param xs, ys, zs: coordinates of the observations
            :param objs: per-observation flag; only truthy entries become nodes
            :param feats: currently unused (node features are not stored yet)
    Returns
            torch_geometric Data with
                x / pos: node positions [num_nodes, 3]
                edge_index: both directions of every temporal link [2, num_edges]
                edge_attr: 3D displacement from edge start to edge end [num_edges, 3]
    """
    edge_starts = []
    edge_ends = []
    edges = []
    positions = []
    prev_pos = None
    for x, y, z, on_object in zip(xs, ys, zs, objs):
        if not on_object:
            # Add node only if saccade landed on the object
            continue
        pos = [x, y, z]
        node_id = len(positions)
        if prev_pos is not None:
            edge_starts.extend([node_id - 1, node_id])
            edge_ends.extend([node_id, node_id - 1])
            # The displacement must be measured against the previous *node*,
            # not the previous sample: samples that missed the object are
            # skipped and would otherwise leak into the edge attributes.
            forward = [cur - prev for cur, prev in zip(pos, prev_pos)]
            backward = [prev - cur for cur, prev in zip(pos, prev_pos)]
            # Displacements depend on direction, so both directions are stored.
            edges.extend([forward, backward])
        positions.append(pos)
        prev_pos = pos
    edge_index = torch.tensor([edge_starts, edge_ends], dtype=torch.long)
    x = torch.tensor(positions, dtype=torch.float)
    edge_attr = torch.tensor(edges, dtype=torch.float)
    pos = torch.tensor(positions, dtype=torch.float)
    graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)
    return graph

In [None]:
def plot_graph(graph, show_nodes = True, show_edges=True, show_trisurf=False, rotation=-80, ax_lim=None):
    """
    Draw a graph in a new 3D matplotlib figure.

    Arguments
            :param graph: object with pos [num_nodes, 3] and edge_index [2, num_edges]
            :param show_nodes: scatter the node positions (colored by pos[:, 2])
            :param show_edges: draw a gray line for every entry of edge_index
            :param show_trisurf: draw a triangulated surface through the nodes
            :param rotation: elevation passed to ax.view_init (azimuth is fixed to 180)
            :param ax_lim: if given, x is limited to [0, ax_lim] and y to [ax_lim, 0]
    Returns
            the created matplotlib figure
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    pos = graph.pos

    # Note: pos[:, 1] is drawn on the first plot axis and pos[:, 0] on the second.
    if show_nodes:
        ax.scatter(pos[:, 1], pos[:, 0], pos[:, 2], c=pos[:, 2])

    if show_edges:
        for start, end in zip(graph.edge_index[0], graph.edge_index[1]):
            ax.plot([pos[start, 1], pos[end, 1]],
                    [pos[start, 0], pos[end, 0]],
                    [pos[start, 2], pos[end, 2]], color="tab:gray")

    if show_trisurf:
        ax.plot_trisurf(pos[:, 1], pos[:, 0], pos[:, 2], alpha=0.7)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    if ax_lim is not None:
        # Set limits on this axes object instead of relying on pyplot's
        # notion of the "current" axes.
        ax.set_xlim([0, ax_lim])
        ax.set_ylim([ax_lim, 0])
    ax.view_init(rotation, 180)
    fig.tight_layout()
    return fig


In [None]:
# xs, ys, zs, objs, feats = collect_fake_saccades(obs,5000)
# large_graph = build_temporal_graph(xs, ys, zs, objs, feats)
# 
# plot_graph(large_graph,show_nodes=True, show_edges=False, ax_lim=63)
# plt.title('nodes')
# plot_graph(large_graph,show_nodes=False, show_edges=True, ax_lim=63)
# plt.title('edges')
# plot_graph(large_graph,show_nodes=False, show_edges=False, show_trisurf=True,rotation=-50, ax_lim=63)
# plt.title('trisurf')
# plt.show()

In [None]:
# Build and show a temporal graph from 100 fake (pixel-space) saccades on view 1.
xs, ys, zs, objs, feats = collect_fake_saccades(obs,100)
small_graph = build_temporal_graph(xs, ys, zs, objs, feats)
plot_graph(small_graph,show_trisurf=True,rotation=-40,ax_lim=63)
plt.show()

### Build Adjacency Graph

In [None]:
def build_adjacency_graph(xs, ys, zs, objs, feats, k_n):
    """
    Build a graph that links every on-object observation to its nearest
    neighbors in 3D space.

    Arguments
            :param xs, ys, zs: coordinates of the observations
            :param objs: per-observation flag; only truthy entries become nodes
            :param feats: currently unused (TODO: nodes don't include features yet)
            :param k_n: number of neighbors each node should link to. If fewer
                    nodes are available, every node links to all others.
    Returns
            torch_geometric Data with
                x: node ids (float tensor)
                pos: node positions [num_nodes, 3]
                edge_index: [2, num_edges], including each node's link to itself
                edge_attr: float32 numpy array [num_edges, 3] holding the
                    displacement from edge start to edge end
    Raises
            ValueError: if no observation is flagged as being on the object
    """
    X = np.array([[x, y, z] for x, y, z, on_object in zip(xs, ys, zs, objs) if on_object])
    if len(X) == 0:
        raise ValueError("Cannot build a graph: no observation landed on the object.")

    # One more than requested since the first nearest neighbor of a node is
    # itself; never ask for more neighbors than there are points.
    n_neighbors = min(k_n + 1, len(X))
    neigh = NearestNeighbors(n_neighbors=n_neighbors)
    neigh.fit(X)

    A = neigh.kneighbors_graph(X)

    edge_index, _ = torch_geometric.utils.from_scipy_sparse_matrix(A)
    # TODO: nodes don't include features yet
    nodes = np.unique(np.array(edge_index[0]))
    x = torch.tensor(nodes, dtype=torch.float)

    # Displacement of every edge, computed for all edges at once.
    edge_starts = np.array(edge_index[0])
    edge_ends = np.array(edge_index[1])
    edge_attr = (X[edge_ends] - X[edge_starts]).astype(np.float32)

    pos = torch.tensor(X, dtype=torch.float)
    n_graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

    return n_graph

In [None]:
# From the same 100 sampled 3D points build a temporal graph and
# nearest-neighbor graphs with 1, 2, 3 and 5 neighbors for comparison.
xs, ys, zs, objs, feats = get_fake_3D_saccades(obs,100)
t_graph = build_temporal_graph(xs, ys, zs, objs, feats)
n1_graph = build_adjacency_graph(xs, ys, zs, objs, feats, 1)
n2_graph = build_adjacency_graph(xs, ys, zs, objs, feats, 2)
n3_graph = build_adjacency_graph(xs, ys, zs, objs, feats, 3)
n5_graph = build_adjacency_graph(xs, ys, zs, objs, feats, 5)

In [None]:
# One row of five 3D panels, one per graph built in the previous cell.
fig = plt.figure(figsize=(25,5))
ax1 = fig.add_subplot(151, projection="3d")
ax2 = fig.add_subplot(152, projection="3d")
ax3 = fig.add_subplot(153, projection="3d")
ax4 = fig.add_subplot(154, projection="3d")
ax5 = fig.add_subplot(155, projection="3d")
axes = [ax1,ax2,ax3,ax4,ax5]
graphs_to_show = [t_graph,n1_graph,n2_graph,n3_graph,n5_graph]
titles = ['temporal', '1-neighbor', '2-neighbors', '3-neighbors', '5-neighbors']

# Per panel: scatter the nodes, then draw one gray line per edge_index entry.
for n,ax in enumerate(axes):
    graph = graphs_to_show[n]
    ax.scatter(graph.pos[:,1],graph.pos[:,0],graph.pos[:,2],c=graph.pos[:,2])
    for i,e in enumerate(graph.edge_index[0]):
                e2 = graph.edge_index[1][i]
                ax.plot([graph.pos[e,1],graph.pos[e2,1]],[graph.pos[e,0],graph.pos[e2,0]],
                [graph.pos[e,2],graph.pos[e2,2]], color="tab:gray")
    ax.set_xlabel("x")
    ax.set_zlabel("z")
    ax.view_init(-40,180)
    ax.title.set_text(titles[n])
    
plt.show()
#plt.savefig('./figures/k-neighbors-comp-3dsac.png',bbox_inches='tight',dpi=200)

In [None]:
# Use all available 3D points of view 1 (num_saccades defaults to 0 = no subsampling).
xs, ys, zs, objs, feats = get_fake_3D_saccades(obs)
large_n5_graph = build_adjacency_graph(xs, ys, zs, objs, feats, 5)

In [None]:
# Three 3D panels for large_n5_graph: nodes only, edges only, triangulated surface.
# The subplot grid is declared as 1x5 but only the first three slots are used.
fig = plt.figure(figsize=(25,7))
ax1 = fig.add_subplot(151, projection="3d")
ax2 = fig.add_subplot(152, projection="3d")
ax3 = fig.add_subplot(153, projection="3d")
axes = [ax1, ax2, ax3]
graph = large_n5_graph

ax1.scatter(graph.pos[:,1],graph.pos[:,0],graph.pos[:,2],c=graph.pos[:,2])

# One gray line per edge_index entry.
for i,e in enumerate(graph.edge_index[0]):
            e2 = graph.edge_index[1][i]
            ax2.plot([graph.pos[e,1],graph.pos[e2,1]],[graph.pos[e,0],graph.pos[e2,0]],
            [graph.pos[e,2],graph.pos[e2,2]], color="tab:gray")

ax3.plot_trisurf(graph.pos[:,1],graph.pos[:,0],graph.pos[:,2],alpha=0.7)

for ax in axes:
    ax.set_xlabel("x")
    ax.set_zlabel("z")
    ax.view_init(-50,180)

ax1.title.set_text("nodes")
ax2.title.set_text("edges")
ax3.title.set_text("trisurface")

plt.show()
#plt.savefig('./figures/many-obs-graph-3dsac.png',bbox_inches='tight',dpi=200)

### Extend Graphs with new Knowledge Becoming Available from Different Perspectives or Changing Occlusion
#### Fake Saccades (approximates x and y) (SKIP, JUST KEPT FOR POSTERITY)

In [None]:
# Pixel-space graphs for the three views.
# NOTE(review): the variables are named n5_* but the graphs are built with k_n=3,
# and view 2 uses 200 saccades while views 1 and 3 use 2000 -- confirm intent.
xs, ys, zs, objs, feats = collect_fake_saccades(obs,2000)
n5_graph_persp1 = build_adjacency_graph(xs, ys, zs, objs, feats, 3)

xs2, ys2, zs2, objs2, feats2 = collect_fake_saccades(obs2,200)
n5_graph_persp2 = build_adjacency_graph(xs2, ys2, zs2, objs2, feats2, 3)

xs3, ys3, zs3, objs3, feats3 = collect_fake_saccades(obs3,2000)
n5_graph_persp3 = build_adjacency_graph(xs3, ys3, zs3, objs3, feats3, 3)

In [None]:
def get_camera_displacement_fake(state1, state2):
    """
    Difference between the 'rgba' sensor poses of two agent states.

    Returns
            [position2 - position1, rotation2 - rotation1]
    NOTE(review): the rotation entry is a plain element-wise difference, not a
    relative rotation -- confirm before using it to rotate anything.
    """
    sensor1 = state1.sensor_states['rgba']
    sensor2 = state2.sensor_states['rgba']
    return [sensor2.position - sensor1.position, sensor2.rotation - sensor1.rotation]

def pixel_to_world(world_vec):
    """
    Map a world-space (x, y) displacement to an approximate pixel displacement.

    Temporary rough fix to compensate for x and y being sampled in pixel space;
    really rough approximation. Only displacements of 0.00 to 0.05 in steps of
    0.01 are supported.

    Arguments
            :param world_vec: sequence whose first two entries are the x and y
                    displacement in world units (any float type)
    Returns
            [x_pixels, y_pixels]
    Raises
            KeyError: if a displacement is outside the supported table
    """
    # Tables are keyed by the displacement in hundredths so that the lookup
    # does not depend on the float type (float32 vs. float64) of the input.
    x_mapping = {0: 0.0, 1: 1, 2: 4, 3: 8, 4: 13, 5: 23}
    y_mapping = {0: 0.0, 1: 6, 2: 12, 3: 18, 4: 20, 5: 22}
    x_key = int(round(float(world_vec[0]) * 100))
    y_key = int(round(float(world_vec[1]) * 100))
    pixel_vec = [x_mapping[x_key], y_mapping[y_key]]
    return pixel_vec

def extend_graphs_fake(graph1, graph2, displacement, k_n):
    """
    Merge the nodes of graph2 into graph1 and rebuild the neighbor graph.

    Arguments
            :param graph1: graph whose node positions are kept as they are
            :param graph2: graph whose node positions are shifted by
                    -displacement[0] before being added
            :param displacement: [location_displacement, rotation_displacement];
                    only the first three location components are used
                    (TODO: include rotation). The argument is not modified.
            :param k_n: number of neighbors each node of the merged graph links to
    Returns
            the merged graph (see build_adjacency_graph)
    """
    positions = np.array(graph1.pos).reshape(-1, 3)
    positions2 = np.array(graph2.pos).reshape(-1, 3)
    # Work on a local copy of the offset so the caller's list stays untouched.
    offset = np.asarray(displacement[0])[:3]
    positions = np.vstack((positions, positions2 - offset))

    feats = [] # TODO: include features in graphs
    objs = np.ones(len(positions))

    extended_graph = build_adjacency_graph(positions[:,0], positions[:,1], positions[:,2], objs, feats, k_n)
    return extended_graph

In [None]:
# Camera displacement between view 1 and view 3.
d = get_camera_displacement_fake(state, state3)
print(d)
# Replace the x and y location displacement by their rough pixel-space equivalents
# (the lookup only knows displacements rounded to two decimals).
pixd = pixel_to_world([np.round(d[0][0],2),np.round(d[0][1],2)])
d[0][0] = pixd[0]
d[0][1] = pixd[1]
print(d)
# Merge the pixel-space graphs of views 1 and 3, linking each node to 5 neighbors.
extended_graph = extend_graphs_fake(n5_graph_persp1, n5_graph_persp3, d, 5)

In [None]:
# Top row: pos[:,0] against pos[:,2]; bottom row: pos[:,1] against pos[:,2].
# Columns: view 1, view 3, merged graph.
plt.figure(figsize=(20,7))
plt.subplot(2,3,1)
plt.scatter(n5_graph_persp1.pos[:,0],n5_graph_persp1.pos[:,2])
plt.xlim([0,63])
plt.ylim([0,0.15])
plt.subplot(2,3,2)
plt.scatter(n5_graph_persp3.pos[:,0],n5_graph_persp3.pos[:,2])
plt.xlim([0,63])
plt.ylim([0,0.15])
plt.subplot(2,3,3)
plt.scatter(extended_graph.pos[:,0],extended_graph.pos[:,2])
plt.xlim([0,63])
plt.ylim([0,0.15])

plt.subplot(2,3,4)
plt.scatter(n5_graph_persp1.pos[:,1],n5_graph_persp1.pos[:,2])
plt.xlim([0,63])
plt.ylim([0,0.15])
plt.subplot(2,3,5)
plt.scatter(n5_graph_persp3.pos[:,1],n5_graph_persp3.pos[:,2])
plt.xlim([0,63])
plt.ylim([0,0.15])
plt.subplot(2,3,6)
# The last panel is left with automatic axis limits.
plt.scatter(extended_graph.pos[:,1],extended_graph.pos[:,2])
#plt.xlim([0,63])
#plt.ylim([0,0.15])
plt.show()

In [None]:
# Panels 1 and 3 are 2D depth images with the sampled saccades on top;
# panels 2, 4 and 5 are 3D graph plots (view 1, view 3, merged).
fig = plt.figure(figsize=(25,5))
ax1 = fig.add_subplot(151)
ax2 = fig.add_subplot(152, projection="3d")
ax3 = fig.add_subplot(153)
ax4 = fig.add_subplot(154, projection="3d")
ax5 = fig.add_subplot(155, projection="3d")
axes = [ax1,ax2,ax3,ax4,ax5]
# empty = Data(x=[0,0], edge_index=[[0,0],[0,0]], edge_attr=[0,0], pos=np.array([[0,0,0],[0,0,0]]))

# None entries are placeholders for the image panels (never indexed as graphs).
graphs_to_show = [None,n5_graph_persp1,None,n5_graph_persp3,extended_graph]
titles = ['perspective 1', 'perspective 1', 'perspective 3', 'perspective 3', 'combined']
for n,ax in enumerate(axes):
    if n==0:
        ax.imshow(obs['agent_01']['depth'])
        ax.scatter(xs,ys,marker='o',c=np.array(objs)*1,cmap='Reds',s=30, alpha=0.7)
    elif n==2:
        ax.imshow(obs3['agent_01']['depth'])
        ax.scatter(xs3,ys3,marker='o',c=np.array(objs3)*1,cmap='Reds',s=30, alpha=0.7)
    else:
        graph = graphs_to_show[n]
        ax.scatter(graph.pos[:,1],graph.pos[:,0],graph.pos[:,2],c=graph.pos[:,2])
        for i,e in enumerate(graph.edge_index[0]):
                    e2 = graph.edge_index[1][i]
                    ax.plot([graph.pos[e,1],graph.pos[e2,1]],[graph.pos[e,0],graph.pos[e2,0]],
                    [graph.pos[e,2],graph.pos[e2,2]], color="tab:gray")
        ax.set_xlabel("x")
        #ax.set_ylabel("y")
        ax.set_zlabel("z")
        ax.set_xlim([0,63])
        ax.set_ylim([63,0])
        ax.set_zlim([0,0.15])
        ax.view_init(-40,180)
        # Titles are only set on the 3D panels; the image panels stay untitled.
        ax.title.set_text(titles[n])
    
#plt.show()
# NOTE(review): unlike the other plotting cells this one writes a file instead of
# showing the figure; presumably the ./figures directory must already exist.
plt.savefig('./figures/extend-graphs.png',bbox_inches='tight',dpi=200)

### Use 3D Observations

In [None]:
# 3D graphs: view 1 is subsampled to 200 points, view 3 uses all of its points.
# This overwrites xs/ys/zs/objs/feats and xs3/... from the pixel-space cells above.
xs, ys, zs, objs, feats = get_fake_3D_saccades(obs, num_saccades=200)
n5_graph3d_persp1 = build_adjacency_graph(xs, ys, zs, objs, feats, 5)

xs3, ys3, zs3, objs3, feats3 = get_fake_3D_saccades(obs3)
n5_graph3d_persp3 = build_adjacency_graph(xs3, ys3, zs3, objs3, feats3, 5)

In [None]:
def get_camera_displacement(state1, state2):
    """
    Compute how the 'rgba' sensor pose changed from state1 to state2.

    Returns
            [location difference, rotation difference], each computed as
            (value in state2) - (value in state1).
    NOTE(review): subtracting rotations is not a relative rotation -- confirm
    before relying on the second entry.
    """
    displacement = []
    for pose_attr in ('position', 'rotation'):
        before = getattr(state1.sensor_states['rgba'], pose_attr)
        after = getattr(state2.sensor_states['rgba'], pose_attr)
        displacement.append(after - before)
    return displacement

def extend_graphs(graph1, graph2, k_n, displacement=None):
    """
    Merge the nodes of graph2 into graph1 and rebuild the neighbor graph.

    Arguments
            :param graph1: graph whose node positions are kept as they are
            :param graph2: graph whose nodes are added
            :param k_n: number of neighbors each node of the merged graph links to
            :param displacement: optional [location_displacement, rotation_displacement].
                    If given, the first three location components are added to the
                    positions of graph2 (TODO: include rotation). Not needed if
                    x, y, z are already in world coordinates (already done in the
                    transform). The argument is not modified.
    Returns
            the merged graph (see build_adjacency_graph)
    """
    positions = np.array(graph1.pos).reshape(-1, 3)
    positions2 = np.array(graph2.pos).reshape(-1, 3)
    # TODO: remove nodes with very similar positions
    if displacement is not None:
        # Work on a local copy of the offset so the caller's list stays untouched.
        offset = np.asarray(displacement[0])[:3]
        positions2 = positions2 + offset
    positions = np.vstack((positions, positions2))

    feats = [] # TODO: include features in graphs
    objs = np.ones(len(positions))

    extended_graph = build_adjacency_graph(positions[:,0], positions[:,1], positions[:,2], objs, feats, k_n)
    return extended_graph

In [None]:
# With world coordinates the two point sets can be merged without a displacement.
if WORLD_COORD:
    d = None
else:
    # Q: do we want to have this option?
    # TODO: doesn't work with camera rotation yet.
    d = get_camera_displacement(state, state3)
    print(d)
# Merge the 3D graphs of views 1 and 3, linking each node to 5 neighbors.
extended_graph_3d = extend_graphs(n5_graph3d_persp1, n5_graph3d_persp3, 5, displacement=d)

In [None]:
# Panels 1 and 3 are 2D depth images (views 1 and 3);
# panels 2, 4 and 5 are 3D node plots (view 1, view 3, merged).
fig = plt.figure(figsize=(25,5))
ax1 = fig.add_subplot(151)
ax2 = fig.add_subplot(152, projection="3d")
ax3 = fig.add_subplot(153)
ax4 = fig.add_subplot(154, projection="3d")
ax5 = fig.add_subplot(155, projection="3d")
axes = [ax1,ax2,ax3,ax4,ax5]
# empty = Data(x=[0,0], edge_index=[[0,0],[0,0]], edge_attr=[0,0], pos=np.array([[0,0,0],[0,0,0]]))

# None entries are placeholders for the image panels (never indexed as graphs).
graphs_to_show = [None,n5_graph3d_persp1,None,n5_graph3d_persp3,extended_graph_3d]
titles = ['perspective 1', 'perspective 1', 'perspective 3', 'perspective 3', 'combined']

for n,ax in enumerate(axes):
    if n==0:
        ax.imshow(obs['agent_01']['depth'])
        #ax.scatter(xs,ys,marker='o',c=np.array(objs)*1,cmap='Reds',s=30, alpha=0.7)
    elif n==2:
        ax.imshow(obs3['agent_01']['depth'])
        #ax.scatter(xs3,ys3,marker='o',c=np.array(objs3)*1,cmap='Reds',s=30, alpha=0.7)
    else:
        graph = graphs_to_show[n]
        ax.scatter(graph.pos[:,1],graph.pos[:,0],graph.pos[:,2],c=graph.pos[:,2])
        # Comment in to draw connections between nodes (takes longer to plot)
        # for i,e in enumerate(graph.edge_index[0]):
        #             e2 = graph.edge_index[1][i]
        #             ax.plot([graph.pos[e,1],graph.pos[e2,1]],[graph.pos[e,0],graph.pos[e2,0]],
        #             [graph.pos[e,2],graph.pos[e2,2]], color="tab:gray")
        ax.set_xlabel("x")
        ax.set_zlabel("z")
        ax.view_init(-40,180)
        ax.title.set_text(titles[n])
    
plt.show()
#plt.savefig('./figures/extend-graphs-3d.png',bbox_inches='tight',dpi=200)

#### With Camera Zoom and Rotation

In [None]:
# 3D graph of view 2 (moved and rotated camera), using all of its points.
xs2, ys2, zs2, objs2, feats2 = get_fake_3D_saccades(obs2)
n5_graph3d_persp2 = build_adjacency_graph(xs2, ys2, zs2, objs2, feats2, 5)

In [None]:
# With world coordinates the two point sets can be merged without a displacement.
if WORLD_COORD:
    d2 = None
else:
    d2 = get_camera_displacement(state, state2)
    print(d2)

# Merge the 3D graphs of views 1 and 2, linking each node to 5 neighbors.
extended_graph_3d_rot = extend_graphs(n5_graph3d_persp1, n5_graph3d_persp2, 5, displacement=d2)


In [None]:
# Panels 1 and 3 are 2D depth images (views 1 and 2);
# panels 2, 4 and 5 are 3D graph plots (view 1, view 2, merged).
fig = plt.figure(figsize=(25,5))
ax1 = fig.add_subplot(151)
ax2 = fig.add_subplot(152, projection="3d")
ax3 = fig.add_subplot(153)
ax4 = fig.add_subplot(154, projection="3d")
ax5 = fig.add_subplot(155, projection="3d")
axes = [ax1,ax2,ax3,ax4,ax5]

# None entries are placeholders for the image panels (never indexed as graphs).
graphs_to_show = [None,n5_graph3d_persp1,None,n5_graph3d_persp2,extended_graph_3d_rot]
titles = ['perspective 1', 'perspective 1', 'perspective 2', 'perspective 2', 'combined']
#plot_graph(t_graph,show_trisurf=False,rotation=-40)
for n,ax in enumerate(axes):
    if n==0:
        ax.imshow(obs['agent_01']['depth'])
        #ax.scatter(xs,ys,marker='o',c=np.array(objs)*1,cmap='Reds',s=30, alpha=0.7)
    elif n==2:
        ax.imshow(obs2['agent_01']['depth'])
        #ax.scatter(xs3,ys3,marker='o',c=np.array(objs3)*1,cmap='Reds',s=30, alpha=0.7)
    else:
        graph = graphs_to_show[n]
        ax.scatter(graph.pos[:,1],graph.pos[:,0],graph.pos[:,2],c=graph.pos[:,2])
        # Draw connections between nodes (takes longer to plot); comment out to skip
        for i,e in enumerate(graph.edge_index[0]):
                    e2 = graph.edge_index[1][i]
                    ax.plot([graph.pos[e,1],graph.pos[e2,1]],[graph.pos[e,0],graph.pos[e2,0]],
                    [graph.pos[e,2],graph.pos[e2,2]], color="tab:gray")
        ax.set_xlabel("x")
        ax.set_zlabel("z")
        ax.view_init(-40,180)
        ax.title.set_text(titles[n])
    
plt.show()
#plt.savefig('./figures/extended-graphs-3d-rot-lessobs.png',bbox_inches='tight',dpi=200)

### Body-Centric to Object/Node-Centric Graph

In [None]:
# Graph of 200 sampled 3D points of view 1, each node linked to 3 neighbors.
# Its positions are the ones returned by get_fake_3D_saccades (the "body" reference frame).
xs, ys, zs, objs, feats = get_fake_3D_saccades(obs,200)
body_ref_graph = build_adjacency_graph(xs, ys, zs, objs, feats, 3)

In [None]:
def to_node_centric_positions(graph, center_node_id):
    """
    Return a copy of the graph whose node positions are expressed relative to
    one of its nodes.

    Arguments
            :param graph: graph with pos [num_nodes, 3]; x, edge_index and
                    edge_attr are passed through unchanged
            :param center_node_id: index of the node that becomes the origin
    Returns
            new Data object whose pos[i] is pos[center_node_id] - pos[i]
            (so the center node ends up at the origin).
    """
    positions = torch.as_tensor(graph.pos, dtype=torch.float)
    # Computed for every row of pos at once, so this does not depend on
    # graph.x holding node ids (it holds positions for temporal graphs).
    node_ref_pos = positions[center_node_id] - positions
    node_ref_graph = Data(x=graph.x, edge_index=graph.edge_index, edge_attr=graph.edge_attr,
                         pos=node_ref_pos)
    return node_ref_graph

In [None]:
# Re-express the graph relative to its first node (node 0).
node_ref_graph = to_node_centric_positions(body_ref_graph, center_node_id=0)

In [None]:
# Side by side: the original graph and the graph re-centered on node 0.
fig = plt.figure(figsize=(10,5))
ax1 = fig.add_subplot(121, projection="3d")
ax2 = fig.add_subplot(122, projection="3d")
axes = [ax1, ax2]
graphs_to_show = [body_ref_graph,node_ref_graph]
titles = ['body-ref-graph', 'node0-ref-graph']

for n,ax in enumerate(axes):
    graph = graphs_to_show[n]
    ax.scatter(graph.pos[:,1],graph.pos[:,0],graph.pos[:,2],c=graph.pos[:,2])
    # Highlight node 0 (the reference node) in red.
    ax.scatter(graph.pos[0,1],graph.pos[0,0],graph.pos[0,2],c='red',s=50)
    for i,e in enumerate(graph.edge_index[0]):
                e2 = graph.edge_index[1][i]
                ax.plot([graph.pos[e,1],graph.pos[e2,1]],[graph.pos[e,0],graph.pos[e2,0]],
                [graph.pos[e,2],graph.pos[e2,2]], color="tab:gray")
    ax.set_xlabel("x")
    ax.set_zlabel("z")
    ax.view_init(-40,180)
    ax.title.set_text(titles[n])

plt.show()
#plt.savefig('./figures/bode-to-node-ref.png',bbox_inches='tight',dpi=200)

### Build Graph Database

In [None]:
# Observation of a second object ('cube') from the default camera pose; the agent state is discarded.
obs_cube, _, _ = get_one_obs_sim(obj='cube', obj_position=(0.01, 1.3, -0.4))

In [None]:
# Build a 5-neighbor graph from all 3D points of the cube observation.
xs, ys, zs, objs, feats = get_fake_3D_saccades(obs_cube)
cube_graph = build_adjacency_graph(xs, ys, zs, objs, feats, 5)

# Left: rgba image of the cube; right: the graph nodes in 3D.
# The subplot grid is declared as 1x3 but only the first two slots are used.
fig = plt.figure(figsize=(15,7))
ax1 = fig.add_subplot(131)
ax2 = fig.add_subplot(132, projection="3d")

ax1.imshow(obs_cube['agent_01']['rgba'])
ax2.scatter(cube_graph.pos[:,1], cube_graph.pos[:,0], cube_graph.pos[:,2], c=cube_graph.pos[:,2])
ax2.view_init(-40,180)
plt.show()

In [None]:
# Rebuild the view-2 graph (same calls as in the "With Camera Zoom and Rotation" section)
# and merge it into the graph that already combines views 1 and 3.
xs2, ys2, zs2, objs2, feats2 = get_fake_3D_saccades(obs2)
n5_graph3d_persp2 = build_adjacency_graph(xs2, ys2, zs2, objs2, feats2, 5)
# No displacement is passed here, so the positions are stacked as they are.
cup_graph_3p = extend_graphs(extended_graph_3d, n5_graph3d_persp2, 5)

plot_graph(cup_graph_3p)
plt.show()

In [None]:
class GraphMemory:
    """
    Graph memory.

    Stores one graph per object id and matches sensed displacements against
    the stored graphs.

    Attributes
        graph_memory: dict graph_id -> graph. Everything that has been learned.
        possible_matches: dict graph_id -> graph. The graphs that are still
                consistent with the current episode. This is a separate dict,
                so ruling out a candidate never deletes it from graph_memory.
        possible_nodes: dict graph_id -> array of node ids the sensor could
                currently be on. Filled by reset_possible_nodes.
    """
    def __init__(self):
        """Create an empty memory."""
        self.graph_memory = {}
        self.possible_matches = {}
        self.possible_nodes = {}

    def _add_graph_to_memory(self, graph, graph_id):
        """add graph
        Arguments
            :param graph: new graph to be added to memory
            :param graph_id: id of graph that should be added

        If graph_id is already in memory, a message is printed and the stored
        graph is left untouched.
        """
        if graph_id in self.graph_memory:
            # TODO: consolidate add and update into one method?
            print(f"Graph with ID {graph_id} is already in memory. Use update_graph instead.")
            return

        self.graph_memory[graph_id] = graph
        # A newly learned graph is a candidate right away.
        self.possible_matches[graph_id] = graph

    def update_graph(self, graph, graph_id, k_n, displacement=None):
        """update graph
        Arguments
            :param graph: graph with new nodes that should be integrated
            :param graph_id: id of graph where new nodes should be integrated
            :param displacement: displacement of new nodes to existing nodes
                            TODO: figure out in which reference frame graph memory
                            should be and what displacement to use
            :param k_n: To how many neighbors should each node link.
                            TODO: should this be an attribute of GraphMemory? Or 
                            should this be object specific?

        If graph_id is not in memory, a message is printed and nothing changes.
        """
        if graph_id not in self.graph_memory:
            print(f"Graph with ID {graph_id} is not in memory. Use _add_graph_to_memory instead.")
            return

        extended = extend_graphs(self.graph_memory[graph_id], graph, k_n, displacement)
        self.graph_memory[graph_id] = extended
        if graph_id in self.possible_matches:
            # Keep the candidate entry pointing at the up-to-date graph.
            self.possible_matches[graph_id] = extended

    def get_graph(self, graph_id):
        """get graph
        Arguments
            :param graph_id: id of graph to retrieve
        Raises
            KeyError if graph_id is not in memory
        """
        return self.graph_memory[graph_id]
    
    def get_memory_ids(self):
        """get list of objects in memory"""
        return list(self.graph_memory.keys())

    def get_possible_nodes(self):
        """return possible nodes for each object (for logging/plotting)"""
        return self.possible_nodes.copy()

    def reset_possible_nodes(self):
        """
        Call this before an episode when starting to explore a new object. 

        Every graph in memory becomes a candidate again and every one of its
        nodes (taken from graph.x) becomes a possible current node.
        """
        self.possible_matches = dict(self.graph_memory)
        for graph_id, graph in self.graph_memory.items():
            self.possible_nodes[graph_id] = np.array(graph.x).astype(int)

    def predict_using_displacements(self, displacement, graph_id, round_to=4):
        """
        Check whether the sensed displacement matches an edge that starts at
        one of the currently possible nodes of the graph.

        Arguments
            :param displacement: sensed displacement vector
            :param graph_id: id of the candidate graph to test
            :param round_to: number of decimals both the sensed and the stored
                            displacements are rounded to before they are
                            compared. None compares them exactly.
        Returns
            1 if at least one edge matches, otherwise 0. As a side effect the
            possible nodes of the graph are replaced by the end nodes of all
            matching edges.
        """
        if round_to is not None:
            displacement = np.round(np.array(displacement), round_to)

        graph = self.possible_matches[graph_id]
        # Converted once instead of once per node.
        edge_starts = np.array(graph.edge_index[0])
        edge_ends = np.array(graph.edge_index[1])
        edge_attr = np.array(graph.edge_attr)

        possible_next_nodes = []
        for node in self.possible_nodes[graph_id]:
            # Skip the first edge of the node (its connection to itself)
            edges_of_node = np.where(edge_starts == node)[0][1:]
            node_displacements = edge_attr[edges_of_node]
            if round_to is not None:
                # displacements don't need to match exactly
                node_displacements = np.round(node_displacements, round_to)

            for i, d in enumerate(node_displacements):
                if np.all(d == displacement):
                    # print("match for node " + str(node))
                    possible_next_nodes.append(int(edge_ends[edges_of_node[i]]))

        self.possible_nodes[graph_id] = np.unique(np.array(possible_next_nodes).flatten())
        print("possible next nodes for " + graph_id + ": " + str(self.possible_nodes[graph_id]))
        return int(len(self.possible_nodes[graph_id]) > 0)

    def make_predictions(self, query, round_to=4):
        """
        Arguments
            :param query: x,y,z coordinates that are queried (where the sensor moved). 
                        For predict_using_displacements this is a displacement vector.
                        For predict_using_location this is a location in space.
            TODO: In which reference frame are these coordinates?
            :param round_to: forwarded to predict_using_displacements
        Returns
            predictions: Binary predictions for each graph in possible_matches
        """
        predictions = {}
        for graph_id in self.possible_matches:
            #prediction = predict(self.possible_matches[graph_id], target_location)
            predictions[graph_id] = self.predict_using_displacements(query, graph_id, round_to)
        return predictions
    
    def get_prediction_error(self, predictions, target):
        """
        Arguments
            :param predictions: A binary prediction on the objects morphology (object there or not) per graph
            :param target: The actual sensation at the new location (also binary)
        Returns
            prediction_error: Binary prediction error for each graph (1 if prediction and target differ; 0 otherwise)
        """
        prediction_error = {}
        for graph_id in predictions:
            # TODO: how we talked about it we would only want a prediction error if
            #       an object was predicted but wasn't sensed, not the other way around
            #       However, since the continuous policy always stays on the object, 
            #       there would never be a prediction error. Figure out how to do this.
            # if predictions[graph_id] == 1 and target == 0:
            #     # if object was predicted but there was none
            #     prediction_error[graph_id] = 1
            # else:
            #     prediction_error[graph_id] = 0
            prediction_error[graph_id] = int(target != predictions[graph_id])
        print(prediction_error)
        return prediction_error

    def update_possible_matches(self, prediction_error, threshold=0):
        """
        Rule out every candidate whose prediction error is above the threshold.

        Only possible_matches is changed; the graphs stay in graph_memory.
        Ids that are no longer candidates are ignored.
        """
        for graph_id in prediction_error:
            if prediction_error[graph_id] > threshold:
                self.possible_matches.pop(graph_id, None)

    def clean_up_graphs(self):
        """
        This could be a function that cleans up graphs in memory to make
        more efficient use of their nodes by spacing them out evenly along 
        the approximated object surface. It could be something that happens
        during sleep. During clean up, similar graphs could also be merged.
        """
        pass



In [None]:
graph_memory = GraphMemory()
# Add cup from one perspective
graph_memory._add_graph_to_memory(n5_graph3d_persp1, "cup")
# Update cup graph with observations from another perspective (each node links to 5 neighbors)
graph_memory.update_graph(n5_graph3d_persp3, "cup", 5)
# Add another object (the cube) to graph memory
graph_memory._add_graph_to_memory(cube_graph, "cube")
# Get stored cube model from memory
cube_graph_mem = graph_memory.get_graph('cube')
# Get list of objects in memory
memory_ids = graph_memory.get_memory_ids()

In [None]:
# Show the cube graph as retrieved from memory.
plot_graph(cube_graph_mem)
plt.show()

In [None]:
# Display the ids of all objects currently stored in memory (notebook cell output).
list(memory_ids)

### Predictions with Continuous Movement Policy and Object North

Using a *continuous movement* along the object's surface instead of jumping to random points on the object makes it easier to match displacements without having to store or calculate all possible displacements in an object's graph (we only need neighboring displacements).

Using an *'object north'* fixes the object's default rotation in the world and makes it easier to match displacements. If an object's orientation deviates from the default orientation, matches can be attempted with rotated versions of the graph, but this requires more cognitive resources and is not as exact (e.g. recognizing an upside-down face or mentally rotating a spatial puzzle).

For now we implicitly assume a graph north by not rotating the objects in space. TODO: make the object north explicit and allow for mentally rotating the graph if we can't find a match.

In [None]:
def sample_continous_movement_policy(object_morphology, previous_node):
    """
    Sample an observation on the object which is adjacent to the previous observation.
    For this we use a detailed graph of the object at the moment. In the future we could
    move along the object surface in Habitat (TODO).

    This policy makes graph prediction/matching easier since we don't need all possible
    displacements but only the neighboring ones. When using this to build a graph it also
    avoids the problem of graphs potentially being disconnected.

        Arguments
            :param object_morphology: Graph with an ``edge_index`` of shape (2, n_edges)
                    (row 0 = start nodes, row 1 = end nodes) and node positions ``pos``.
            :param previous_node: Index of the node that was observed last.
        Returns
            Tuple (x, y, z, obj, feat, next_node). obj is always 1 and feat always 0.
        Raises
            ValueError: If previous_node has no neighbor other than itself.
    """
    edges = np.asarray(object_morphology.edge_index)
    start_nodes, end_nodes = edges[0], edges[1]
    # all nodes that previous_node connects to
    neighbors = end_nodes[start_nodes == previous_node]
    # Remove the connection to itself by value instead of by position: the self
    # connection is not guaranteed to be the first edge (or to exist at all), and
    # dropping a real neighbor could leave the node itself as a candidate.
    neighbors = neighbors[neighbors != previous_node]
    if neighbors.size == 0:
        raise ValueError(
            "Node " + str(previous_node) + " has no neighboring node to move to."
        )
    next_node = int(np.random.choice(neighbors))
    position = object_morphology.pos[next_node]
    return position[0], position[1], position[2], 1, 0, next_node  # x, y, z, obj, feat, node id

In [None]:
# A detailed 3D graph of the object serves as its morphology: the edges of this
# graph define which moves are possible from each node.
xs, ys, zs, objs, feats = get_fake_3D_saccades(obs)
object_morphology = build_adjacency_graph(xs, ys, zs, objs, feats, 5)

# Pick the start node.
# NOTE(review): this draws a *value* from object_morphology.x and uses it as a node
# index -- confirm that x holds node ids and not features.
previous_node = int(np.random.choice(np.array(object_morphology.x)))
start_pos = object_morphology.pos[previous_node]
cxs, cys, czs = np.array(start_pos[0]), np.array(start_pos[1]), np.array(start_pos[2])
cobjs, cfeats = [1], [0]

# Walk n_steps along the edges of the object graph, buffering the observation
# made at every visited node.
n_steps = 200
for _ in range(n_steps):
    # the sampled neighbor becomes the start node of the next step
    cx, cy, cz, cobj, cfeat, previous_node = sample_continous_movement_policy(object_morphology, previous_node)
    cxs, cys, czs = np.append(cxs, cx), np.append(cys, cy), np.append(czs, cz)
    cobjs, cfeats = np.append(cobjs, cobj), np.append(cfeats, cfeat)

# Connect the buffered observations in temporal order to show the path taken
continuous_graph = build_temporal_graph(cxs, cys, czs, cobjs, cfeats)
plot_graph(continuous_graph, rotation=-40)
plt.show()

In [None]:
# One manual prediction step against graph memory.
# Reset first: at the start an observation could correspond to any node on any graph.
graph_memory.reset_possible_nodes()
# Query every graph in memory at a fixed location
predictions = graph_memory.make_predictions(query=(0.0,0.0,0.0))
print("predictions: "+str(predictions))
predictions['cup'] = 0 # TODO: remove (hard-coded override of the cup prediction)
# Compare the predictions with what was observed (target=0)
PE = graph_memory.get_prediction_error(predictions, target=0)
print("prediction errors: "+str(PE))
# removes the cube since it was predicted but wasn't actually there
graph_memory.update_possible_matches(PE)
print(graph_memory.get_memory_ids())

In [None]:
def run_episode(obs, continuous_movement=True, min_steps=5, exploration_len=100, k_n=3, max_steps=1000):
    """Explore one object and match the observations against the graphs in memory.

    Uses (and modifies) the module-level ``graph_memory``. At every step one observation
    is sampled. From the second step on, the displacement to the previous observation is
    used to query the memory, and graphs are removed from the possible matches based on
    their prediction error. The episode ends when
        - no possible match is left: ``exploration_len`` more fake saccades are collected
          and the resulting graph is added to memory as "new_object",
        - exactly one possible match is left after at least ``min_steps`` observations
          (object recognized), or
        - more than ``max_steps`` observations were collected without a decision.

        Arguments
            :param obs: Observation from current perspective, or a torch_geometric graph to
                    saccade along (only used as graph if continuous_movement is True).
                    TODO: replace with saccades in habitat.
            :param continuous_movement: If True, move along the edges of the object graph.
                    Otherwise sample one random fake saccade per step.
            :param min_steps: How many steps do we want to make before we safely recognize an object?
                    TODO: currently needs to be >= k_n. Make this independent of k_n
            :param exploration_len: How many saccades should be made on a new object before adding it into graph memory.
            :param k_n: How many neighbors should each node connect to.
            :param max_steps: Stop the episode once more than this many observations were collected.
        Returns
            Dictionary with the collected observations ('x', 'y', 'z', 'obj', 'feat'), the
            'displacements' between consecutive observations and, per step, the
            'object_matches' and 'node_matches' that were still possible.
    """
    done = False
    # TODO: turn this into one buffer variable
    all_x, all_y, all_z, all_d, all_obj, all_feat = [], [], [], [], [], []
    obj_matches, node_matches = [], []

    if continuous_movement:
        if isinstance(obs, torch_geometric.data.data.Data):
            print("using graph as input")
            # input a graph as observation to saccade along -> make sure to saccade to exact points stored in memory.
            object_morphology = obs
        else:
            print("building graph")
            xs, ys, zs, objs, feats = get_fake_3D_saccades(obs)
            object_morphology = build_adjacency_graph(xs, ys, zs, objs, feats, 5)
        # select a random start node
        previous_node = int(np.random.choice(np.array(object_morphology.x)))

    # At the start of exploration an observation could correspond to any node on any graph in memory
    graph_memory.reset_possible_nodes()
    while not done:
        print('step ' + str(len(all_x)))
        if continuous_movement:
            x, y, z, obj, feat, previous_node = sample_continous_movement_policy(object_morphology, previous_node)
        else:
            x, y, z, obj, feat = get_fake_3D_saccades(obs, num_saccades=1)
        all_x.append(x)
        all_y.append(y)
        all_z.append(z)
        all_obj.append(obj)
        all_feat.append(feat)

        # TODO: figure out in which reference frame target location should be
        # query position on object
        #predictions = graph_memory.make_predictions(query=(x,y,z))
        # query displacements (can only be performed after first action)
        if len(all_x) > 1:
            displacement = np.array([x-all_x[-2], y-all_y[-2], z-all_z[-2]]).flatten()
            all_d.append(displacement)
            predictions = graph_memory.make_predictions(query=displacement, round_to=3)
            PE = graph_memory.get_prediction_error(predictions, target=int(obj))
            graph_memory.update_possible_matches(PE)

        # Fetch the matches of this step before branching. Previously this was only
        # assigned in the final else branch, so an episode that finished without ever
        # entering it failed with an UnboundLocalError when recording the step, and
        # the last recorded step otherwise repeated the matches of the step before.
        current_matches = graph_memory.get_memory_ids()

        # Check if we are done
        if len(graph_memory.possible_matches) == 0:
            # No matches -> explore object more and add graph to memory
            print("No matches -> Adding new object to memory")
            # NOTE(review): if obs is a graph, get_fake_3D_saccades receives the graph
            # here -- confirm that it supports graph input.
            x, y, z, obj, feat = get_fake_3D_saccades(obs, num_saccades=exploration_len)

            all_x = np.hstack((all_x, x))
            all_y = np.hstack((all_y, y))
            all_z = np.hstack((all_z, z))
            all_obj = np.hstack((all_obj, obj))
            all_feat = np.hstack((all_feat, feat))

            body_ref_graph = build_adjacency_graph(all_x, all_y, all_z, all_obj, all_feat, k_n)
            #node_ref_graph = to_node_centric_positions(body_ref_graph, center_node_id=0)
            # TODO: which graph to use?
            graph_memory._add_graph_to_memory(body_ref_graph, "new_object") # TODO: how to name this?
            done = True
        elif len(graph_memory.possible_matches) == 1 and len(all_x) >= min_steps:
            # Only one match in memory and we made min_steps -> recognize object and update graph in memory
            recognized_id = list(graph_memory.possible_matches.keys())[0]
            print('Recognized ' + recognized_id)
            # The graph is built but not used yet since the memory update below is disabled
            body_ref_graph = build_adjacency_graph(all_x, all_y, all_z, all_obj, all_feat, k_n)
            #node_ref_graph = to_node_centric_positions(body_ref_graph, center_node_id=0)
            # TODO: which graph to use?
            #graph_memory.update_graph(body_ref_graph, recognized_id, k_n)
            done = True
        else:
            print('Current possible matches: ' + str(current_matches))
            if len(all_x) > max_steps:
                print("Reached max_steps without finding a match.")
                done = True
        obj_matches.append(current_matches)
        node_matches.append(graph_memory.get_possible_nodes())
    episode_stats = {'x': np.array(all_x), 'y': np.array(all_y), 'z': np.array(all_z), 'displacements': np.array(all_d),
                     'obj': np.array(all_obj), 'feat': np.array(all_feat), 'object_matches': obj_matches,
                     'node_matches': node_matches}
    return episode_stats


In [None]:
# run_episode works on the module-level graph_memory, so start from a fresh
# memory that contains exactly the two known objects.
graph_memory = GraphMemory()
graph_memory._add_graph_to_memory(cup_graph_3p, "cup")
graph_memory._add_graph_to_memory(cube_graph, "cube")

# Alternative inputs: a raw observation or the stored cup graph
#stats = run_episode(obs, max_steps=10)
#stats = run_episode(cup_graph_3p, max_steps=10)
# Saccade along the stored cube graph itself
stats = run_episode(cube_graph, max_steps=10)

In [None]:
# 2x5 grid of 3D axes: the top row (axes) shows the nodes that are still possible at
# each step, the bottom row (axes2) shows the displacements/actions taken so far.
fig = plt.figure(figsize=(20,8))
axes = [fig.add_subplot(2, 5, idx, projection="3d") for idx in range(1, 6)]
axes2 = [fig.add_subplot(2, 5, idx, projection="3d") for idx in range(6, 11)]

# Plot against the first graph that is still a possible match after the episode
match = list(graph_memory.possible_matches)[0]
graph = graph_memory.possible_matches[match]

for step, ax in enumerate(axes):
    disp_ax = axes2[step]

    # All nodes of the graph in grey, nodes that are still possible at this step in
    # green. Column 1 of pos goes on the first plot axis, column 0 on the second.
    ax.scatter(graph.pos[:,1], graph.pos[:,0], graph.pos[:,2], s=3, alpha=0.7, c='grey')
    node_matches = stats['node_matches'][step][match]
    ax.scatter(graph.pos[node_matches,1], graph.pos[node_matches,0], graph.pos[node_matches,2], s=12, alpha=0.7, c='green')

    # Draw every movement made up to this step into the panels of the next step
    if step < 4:
        next_ax, next_disp_ax = axes[step+1], axes2[step+1]
        for i in range(step+1):
            seg_y = [stats['y'][i], stats['y'][i+1]]
            seg_x = [stats['x'][i], stats['x'][i+1]]
            seg_z = [stats['z'][i], stats['z'][i+1]]
            next_ax.plot(seg_y, seg_x, seg_z, c='red', lw=3)
            next_disp_ax.plot(seg_y, seg_x, seg_z, c='red')

    ax.set_title("step " + str(step) + "\npossible nodes")
    disp_ax.set_title('displacement/action')

    # Same axis labels and no ticks on both panels of this step
    for panel in (ax, disp_ax):
        panel.set_xlabel("x", labelpad=-10)
        panel.set_ylabel("y", labelpad=-15)
        panel.set_zlabel("z", labelpad=-15)
        panel.set_xticks([])
        panel.set_yticks([])
        panel.set_zticks([])

    # Comment in to zoom in on saccade locations
    # zoom = 0.01
    # ax.set_xlim([min(stats['y']) - zoom, max(stats['y']) + zoom])
    # ax.set_ylim([min(stats['x']) - zoom, max(stats['x']) + zoom])
    # ax.set_zlim([min(stats['z']) - zoom, max(stats['z']) + zoom])

    # Center the displacement panel on the mean observed location; the half-width
    # of the view is the length of the longest displacement.
    zoom2 = np.max(np.linalg.norm(stats['displacements'], axis=1))# + 0.01
    mean_x, mean_y, mean_z = (np.mean(stats[key]) for key in ('x', 'y', 'z'))
    disp_ax.set_xlim([mean_y - zoom2, mean_y + zoom2])
    disp_ax.set_ylim([mean_x - zoom2, mean_x + zoom2])
    disp_ax.set_zlim([mean_z - zoom2, mean_z + zoom2])

    for panel in (ax, disp_ax):
        panel.view_init(-40,180)

#plt.show()
plt.savefig('./figures/prediction_on_cube.png', bbox_inches='tight', dpi=200)

In [None]:
# Plot the graph that run_episode stored for an unrecognized object.
# Note: "new_object" is only added to memory when an episode ends in the
# "No matches" branch of run_episode.
new_graph = graph_memory.get_graph('new_object')
plot_graph(new_graph)
plt.show()

In [None]:
# New elements:
# Classes:
#       GraphMemory -> pass as argument to learning module
#       extend_graphs(graph1, graph2, displacement, k_n)
# Functions:
#       build_adjacency_graph(xs, ys, zs, objs, feats, k_n)
#       get_camera_displacement(state1,state2) -> learning module or create folder for experiment utils
#       predict(graph, target_location)
#   other: -> dev utils (under projects)
#       to_node_centric_positions(graph, node_id)
#       get_one_obs_sim
#       fake_saccade
#       collect_fake_saccades
#       build_temporal_graph
#       plot_graph
#       pixel_to_world