In [1]:
import bentley_ottmann
from utils import *
from modules import *
from sklearn import preprocessing

In [2]:
G_list, data_list = load_processed_data(G_list_file='G_list.pickle', 
                                        data_list_file='data_list.pickle')

In [3]:
pos = data_list[0].x.detach().cpu().numpy()

In [173]:
edges = np.array([[0, 1],
                  [0, 2],
                  [0, 3],
                  [0, 4],
                  [0, 5],
                  [0, 6],
                  [0, 7],
                  [0, 8],
                  [0, 9],
                  [0, 10],
                  [0, 11],
                  [0, 12],
                  [0, 13],
                  [0, 14],
                  [0, 15]])

In [6]:
def get_real_edges(batch):
    data_list = batch.to_data_list() if type(batch) is Batch else [batch]

    offset = 0
    neighbor_mask_, edge_list_ = [], []
    for data in data_list:
        size = data.num_nodes
        edge_list_.append(data.edge_index.T.cpu().numpy() + offset)
        l = data.edge_attr[:, 0].detach().cpu().numpy()
        neighbor_mask_.append(l == l.min())
        offset += size
    neighbor_mask = np.concatenate(neighbor_mask_)
    edge_list = np.concatenate(edge_list_)

    return edge_list[neighbor_mask]

In [163]:
def counter_clockwise_node_pairs(edges, pos):
    u, v = edges[:, 0], edges[:, 1]
    diff = pos[v] - pos[u]
    diff_normalized = preprocessing.normalize(diff, norm='l2')
    # get cosine angle between uv and y-axis
    cos = diff_normalized @ np.array([[1],[0]])
    # get radian between uv and y-axis
    radian = np.arccos(cos) * np.expand_dims(np.sign(diff[:, 1]), axis=1)
    # for each u, sort edges based on the position of v
    sorted_idx = sorted(np.arange(len(edges)), key=lambda node: (u[node], radian[node]))
    sorted_v = v[sorted_idx]
    # get start index for each u
    idx = np.unique(u, return_index=True)[1]
    roll_idx = np.arange(1, len(u) + 1)
    roll_idx[np.roll(idx - 1, -1)] = idx
    rolled_v = sorted_v[roll_idx]
    return np.stack([u, sorted_v, rolled_v]).T[sorted_v != rolled_v]

In [8]:
def calculate_l(edges, pos):
    pairs = counter_clockwise_node_pairs(edges, pos)
    p0, p1 = pos[pairs[:, 0]], pos[pairs[:, 1]]
    e0 = p0 - pos[edges[:, 0]]
    e1 = p1 - pos[edges[:, 0]]
    du = len(edges)
    e0norm, e1norm = np.linalg.norm(e0, axis=1), np.linalg.norm(e1, axis=1)
    return np.sqrt(np.abs(e0norm ** 2 + e1norm ** 2 - 2 * e0norm * e1norm * np.cos(2 * np.pi / du)))

In [9]:
def cross_spring(edges, pos):
    pairs = counter_clockwise_node_pairs(edges, pos)
    p0, p1 = pos[pairs[:, 0]], pos[pairs[:, 1]]
    diff = p1 - p0
    term1 = np.log(np.linalg.norm(diff, axis=1) / calculate_l(edges, pos))
    term2 = preprocessing.normalize(diff, norm='l2')
    return np.expand_dims(term1, axis=1) * term2

In [10]:
data = data_list[0]
data

Data(edge_attr=[240, 2], edge_index=[2, 240], x=[16, 2])

In [59]:
edges = get_real_edges(data)
pos = data.x
node, idx, counts = np.unique(edges[:, 0], return_index=True, return_counts=True)
roll_idx = np.arange(1, 43)
roll_idx[np.roll(idx - 1, -1)] = idx

In [199]:
torch.tensor([1.,2.,3.,4.]).pow(-1)

tensor([1.0000, 0.5000, 0.3333, 0.2500])

In [84]:
counter_clockwise_node_pairs(edges, pos)

array([[ 0,  8, 15],
       [ 0, 15,  6],
       [ 0,  6,  5],
       [ 0,  5,  8],
       [ 2, 15,  6],
       [ 2,  6, 15],
       [ 3, 14,  8],
       [ 3,  8, 15],
       [ 3, 15, 10],
       [ 3, 10, 14],
       [ 4, 11,  8],
       [ 4,  8, 11],
       [ 6,  0,  2],
       [ 6,  2,  0],
       [ 8,  4,  0],
       [ 8,  0,  3],
       [ 8,  3,  4],
       [ 9, 15, 12],
       [ 9, 12, 11],
       [ 9, 11, 15],
       [10, 13,  3],
       [10,  3, 13],
       [11,  9, 13],
       [11, 13,  4],
       [11,  4,  9],
       [12,  9, 13],
       [12, 13,  1],
       [12,  1,  9],
       [13, 11, 10],
       [13, 10, 12],
       [13, 12, 11],
       [14, 15,  7],
       [14,  7,  3],
       [14,  3, 15],
       [15,  9, 14],
       [15, 14,  2],
       [15,  2,  3],
       [15,  3,  0],
       [15,  0,  9]])

In [212]:
cross_spring(edges, pos).shape

(15, 2)

In [215]:
pairs

array([[14,  8],
       [ 8,  9],
       [ 9, 15],
       [15,  6],
       [ 6, 10],
       [10, 13],
       [13,  4],
       [ 4,  2],
       [ 2,  1],
       [ 1, 12],
       [12,  7],
       [ 7,  3],
       [ 3,  5],
       [ 5, 11],
       [11, 14]])

In [145]:
e0u, e1u = np.linalg.norm(e0, axis=1), np.linalg.norm(e1, axis=1)

In [146]:
l = np.sqrt(np.abs(e0u ** 2 + e1u ** 2 - 2 * e0u * e1u * np.cos(2 * np.pi / du)))

array([0.03640628, 0.05352685, 0.2593167 , 0.27827108, 0.2112186 ,
       0.200629  , 0.12601173, 0.16117835, 0.33452034, 0.38255978,
       0.4499724 , 0.21375167, 0.23889041, 0.13126785, 0.0792875 ],
      dtype=float32)

In [None]:
f_ang_spring = 