In [10]:
%load_ext autoreload
%load_ext line_profiler
%autoreload 2

import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from tqdm.notebook import tqdm

%aimport metrics_pytorch
from metrics_pytorch import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [13]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG the notebook draws from (edge weights, label noise, the training split,
# model initialisation, triplet/pair sampling) so Restart & Run All reproduces results.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

### Let's make a simple symmetric graph: a 3D gridworld whose edges wrap around (a torus), with random integer edge weights

In [14]:
graph_id = 'simple3d'

# Side length of the cube; coordinates wrap around modulo D along every axis.
D = 5
N_STATES = D ** 3

def idx_to_coord(idx):
  """Flat state index -> (x, y, z), row-major with z varying fastest."""
  plane = D * D
  x = idx // plane
  y = (idx % plane) // D
  z = idx % D
  return x, y, z

def coord_to_idx(coord):
  """(x, y, z) -> flat state index; the inverse of idx_to_coord on the grid."""
  x, y, z = coord
  return (x * D + y) * D + z

def neighbors_c(coord):
  """The six axis-aligned neighbours of `coord`, wrapping around each axis.

  Order is +x, -x, +y, -y, +z, -z, which is also the order edges get added to G.
  """
  x, y, z = coord
  result = []
  for axis in range(3):
    for step in (1, -1):
      moved = [x, y, z]
      moved[axis] = (moved[axis] + step) % D
      result.append(tuple(moved))
  return result

def neighbors_i(idx):
  """Flat indices of the six neighbours of flat state index `idx`."""
  return [coord_to_idx(c) for c in neighbors_c(idx_to_coord(idx))]

# Each node is linked to its lattice neighbours. nx.Graph is already undirected, so no
# symmetrisation step is needed. An edge is typically reached from both endpoints; the
# later add_edge call overwrites the weight drawn on the earlier one, leaving a single
# integer weight in [1, 100] per edge.
G = nx.Graph()
for node in range(N_STATES):
  for neighbor in neighbors_i(node):
    G.add_edge(node, neighbor, weight=np.random.randint(1, 101))

# Invariants: the wrap-around lattice is connected, and each node has 6 distinct
# neighbours when D >= 3 (for smaller D the +1/-1 neighbours coincide).
assert nx.is_connected(G)
if D >= 3:
  assert all(deg == 6 for _, deg in G.degree())
G.number_of_nodes(), G.number_of_edges()

(125, 375)

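##### Since there are only 125 nodes, we will use one-hot embeddings, and prepare a dataset of $(X_1, X_2, D(X_1,X_2), D(X_1, X_2) + \epsilon)$ for every ordered pair of nodes, where $D$ is the weighted shortest-path distance and $\epsilon \sim \mathcal{N}(0, 5^2)$ is drawn independently for each pair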

In [192]:
# Weighted shortest-path lengths for all pairs: one Dijkstra run per source instead of a
# separate A* search per (source, target) pair. A* without a heuristic is Dijkstra, and
# both read the 'weight' edge attribute, so the distances are the same.
shortest = dict(nx.all_pairs_dijkstra_path_length(G, weight='weight'))

X1, X2, DX, DXnoisy = [], [], [], []
# Row-major over (node, target): pair (a, b) ends up at flat index a * N_STATES + b.
for node in range(N_STATES):
  for target in range(N_STATES):
    X1.append(node)
    X2.append(target)
    DX.append(shortest[node][target])
    # Independent N(0, 5^2) noise per *ordered* pair, so the noisy targets are in general
    # asymmetric, nonzero on the diagonal, and can be negative.
    DXnoisy.append(DX[-1] + np.random.normal(scale=5))
assert len(DX) == N_STATES ** 2

X1 = torch.LongTensor(X1).to(DEVICE)
X2 = torch.LongTensor(X2).to(DEVICE)
DX = torch.FloatTensor(DX).to(DEVICE)
DXnoisy = torch.FloatTensor(DXnoisy).to(DEVICE)

In [251]:
# Train the metrics on a random 20% of all (source, target) pairs, drawn without replacement.
TRAINING_INDICES = torch.as_tensor(
    np.random.choice(len(X1), size=(len(X1) // 5,), replace=False),
    dtype=torch.long,
    device=DEVICE,
)

### Now let's define some metrics, and check whether they satisfy the triangle inequality

In [384]:
def _pair_lookup(table, x1s, x2s):
  """Fetch table[x1 * N_STATES + x2] for each (x1, x2), i.e. the row-major pair entry."""
  flat = (x1s * N_STATES + x2s).long()
  return table[flat]

def GroundTruthMetric(x1s, x2s):
  """Exact weighted shortest-path distance between state tensors x1s and x2s (from DX)."""
  return _pair_lookup(DX, x1s, x2s)

def NoisyMetric(x1s, x2s):
  """Noisy target distance for x1s, x2s (from DXnoisy); generally not a true metric."""
  return _pair_lookup(DXnoisy, x1s, x2s)

class MetricModel(nn.Module):
  """Learnable distance between integer state indices.

  Each index is one-hot encoded and passed through a bias-free linear layer (in
  effect a learned table of state embeddings); the wrapped `metric` is then
  applied to the two embeddings.

  Args:
    metric: callable (phi_x, phi_y) -> distances. If it is an nn.Module, assigning
      it here registers its parameters so they are trained with the embedding.
    num_states: number of discrete states; defaults to the global N_STATES.
  """
  def __init__(self, metric, num_states=None):
    super().__init__()
    self.metric = metric
    self.num_states = N_STATES if num_states is None else num_states
    self.embedding = nn.Linear(self.num_states, self.num_states, bias=False)

  def _one_hot(self, idx):
    """One-hot encode state indices on the embedding's own device and dtype."""
    weight = self.embedding.weight
    # as_tensor also accepts Python ints/lists; one_hot requires int64 indices, and it
    # rejects out-of-range (including negative) indices instead of silently wrapping.
    idx = torch.as_tensor(idx, device=weight.device).long()
    return nn.functional.one_hot(idx, self.num_states).to(weight.dtype)

  def forward(self, x, y):
    """Distances between states x and y.

    x and y are integer state indices in [0, num_states) -- ints or integer tensors,
    presumably of matching shape (the wrapped metric decides how they combine).
    Returns whatever `metric` returns for the two embedding batches.
    """
    phi_x = self.embedding(self._one_hot(x))
    phi_y = self.embedding(self._one_hot(y))
    return self.metric(phi_x, phi_y)

# Initial instance of each model family. The training cells below re-create widenorm,
# deepnorm, neuralmetric and notametric before training them.
# Presumably Euclidean distance between learned linear embeddings, i.e. a Mahalanobis distance.
mahalanobis = MetricModel(metric=EuclideanMetric()).to(DEVICE)
widenorm = MetricModel(metric=WideNormMetric(N_STATES, 32, 32)).to(DEVICE)
deepnorm = MetricModel(
    metric=DeepNormMetric(
        N_STATES, (128, 128),
        activation=lambda: MaxReLUPairwiseActivation(128)),
).to(DEVICE)
neuralmetric = MetricModel(
    metric=DeepNormMetric(
        N_STATES, (128, 128),
        activation=lambda: MaxReLUPairwiseActivation(128),
        concave_activation_size=10),
).to(DEVICE)
# Non-metric MLP baseline (see its triangle-inequality violations below).
notametric = MetricModel(metric=MLPNonMetric(N_STATES, (128, 128), mode='subtract')).to(DEVICE)

def _fit_pairs(metric, opt, x1, x2, target, epochs):
  """Take `epochs` full-batch optimizer steps minimising MSE(metric(x1, x2), target)."""
  for _ in range(epochs):
    loss = F.mse_loss(metric(x1, x2), target)
    opt.zero_grad()
    loss.backward()
    opt.step()

def train_metric(metric, opt, epochs=500):
  """Fit `metric` to the noisy distances of the TRAINING_INDICES pairs only.

  Args:
    metric: callable (state indices, state indices) -> distances, whose trainable
      parameters are managed by `opt`.
    opt: optimizer over the metric's parameters.
    epochs: number of full-batch gradient steps.
  """
  # The training subset never changes, so gather it once instead of on every step.
  idx = TRAINING_INDICES
  _fit_pairs(metric, opt, X1[idx], X2[idx], DXnoisy[idx], epochs)

def train_on_full_dataset(metric, opt, epochs=500):
  """Fit `metric` to the noisy distances of every (source, target) pair (nothing held out).

  Args are as for train_metric.
  """
  _fit_pairs(metric, opt, X1, X2, DXnoisy, epochs)
      
def test_violations(metric, N=10000, tol=1e-4):
  """Count triangle-inequality violations of `metric` on N random state triplets.

  Triplets (start, middle, end) are drawn uniformly with replacement from
  range(N_STATES). A triplet violates the inequality when
  d(start, end) > d(start, middle) + d(middle, end) + tol.

  Args:
    metric: callable taking two index tensors and returning distances.
    N: number of triplets to sample.
    tol: slack added to the right-hand side to absorb floating-point error.

  Returns:
    0-dim integer tensor holding the number of violating triplets.
  """
  triplets = torch.randint(N_STATES, (N, 3)).to(DEVICE)
  start, middle, end = triplets.unbind(1)
  # Evaluation only: don't build an autograd graph through the forward passes.
  with torch.no_grad():
    direct = metric(start, end)
    detour = metric(start, middle) + metric(middle, end)
  return torch.sum(direct > detour + tol)
  
def l2_error(metric1, metric2, N=10000):
  """Mean squared difference between two metrics on N random state pairs.

  Despite the name, this is the MSE (F.mse_loss, mean reduction), not an L2 norm.
  Pairs are drawn uniformly with replacement from all states, so they are not held
  out: on average about 20% of them fall in train_metric's training subset.

  Args:
    metric1, metric2: callables taking two index tensors and returning distances.
    N: number of pairs to sample.

  Returns:
    0-dim float tensor with the mean squared difference.
  """
  pairs = torch.randint(N_STATES, (N, 2)).to(DEVICE)
  start, end = pairs.unbind(1)
  # Evaluation only: don't build an autograd graph through the forward passes.
  with torch.no_grad():
    return F.mse_loss(metric1(start, end), metric2(start, end))

def evaluate_metric(metric, N=10000):
  """Print triangle-violation count and squared error vs. clean and noisy ground truth."""
  n_violations = test_violations(metric, N)
  err_to_gt = l2_error(metric, GroundTruthMetric, N)
  err_to_noisy = l2_error(metric, NoisyMetric, N)
  print(f'{n_violations} triangle ineq. violations. '
        f'{err_to_gt:.4f} l2 error to GT. '
        f'{err_to_noisy:.4f} l2 error to noisy GT.')
  
def evaluate_and_train_metric(metric, epochs=20, lr=0.1, training_fn=train_metric,
                              momentum=0.9, weight_decay=1e-4):
  """Alternate evaluation and training, printing one evaluation line per round.

  `epochs` counts evaluate -> train *rounds*, not gradient steps: each round calls
  `training_fn(metric, opt)`, which runs its own number of steps (its `epochs`
  default, 500 for train_metric and train_on_full_dataset). A final evaluation
  follows the last round, so epochs + 1 lines are printed.

  Args:
    metric: model whose parameters are optimised with SGD.
    epochs: number of evaluate -> train rounds.
    lr: SGD learning rate.
    training_fn: callable(metric, opt) that performs one round of training.
    momentum: SGD momentum.
    weight_decay: SGD weight decay.
  """
  opt = torch.optim.SGD(metric.parameters(), lr, momentum=momentum, weight_decay=weight_decay)
  for _ in range(epochs):
    evaluate_metric(metric)
    training_fn(metric, opt)
  evaluate_metric(metric)

In [385]:
# Sanity check: exact shortest-path distances satisfy the triangle inequality,
# so this should report 0 violations and 0 error to GT.
evaluate_metric(metric=GroundTruthMetric)

0 triangle ineq. violations. 0.0000 l2 error to GT. 24.9282 l2 error to noisy GT.


In [386]:
# Independent per-pair noise breaks the triangle inequality for some triplets.
evaluate_metric(metric=NoisyMetric)

306 triangle ineq. violations. 24.3428 l2 error to GT. 0.0000 l2 error to noisy GT.


In [387]:
# Re-create the model here (like the other training cells) so re-running this cell
# starts from a fresh initialisation instead of continuing to train an old instance.
# Performance seems to saturate at around 150-160.
mahalanobis = MetricModel(EuclideanMetric()).to(DEVICE)
evaluate_and_train_metric(mahalanobis, lr=1.)

0 triangle ineq. violations. 10311.9072 l2 error to GT. 10341.7471 l2 error to noisy GT.
0 triangle ineq. violations. 170.3797 l2 error to GT. 188.9328 l2 error to noisy GT.
0 triangle ineq. violations. 159.3267 l2 error to GT. 182.2031 l2 error to noisy GT.
0 triangle ineq. violations. 159.4852 l2 error to GT. 181.7425 l2 error to noisy GT.
0 triangle ineq. violations. 152.7919 l2 error to GT. 183.4248 l2 error to noisy GT.
0 triangle ineq. violations. 153.8724 l2 error to GT. 176.5508 l2 error to noisy GT.
0 triangle ineq. violations. 159.1762 l2 error to GT. 184.9554 l2 error to noisy GT.
0 triangle ineq. violations. 153.1673 l2 error to GT. 177.6189 l2 error to noisy GT.
0 triangle ineq. violations. 155.7677 l2 error to GT. 184.5368 l2 error to noisy GT.
0 triangle ineq. violations. 161.6409 l2 error to GT. 177.2016 l2 error to noisy GT.
0 triangle ineq. violations. 152.0458 l2 error to GT. 182.0342 l2 error to noisy GT.
0 triangle ineq. violations. 159.1007 l2 error to GT. 180.298

In [388]:
# WideNorm is much more expressive; performance saturates around 110-120.
widenorm = MetricModel(metric=WideNormMetric(N_STATES, 32, 32)).to(DEVICE)
evaluate_and_train_metric(metric=widenorm)

0 triangle ineq. violations. 10471.7471 l2 error to GT. 10518.0234 l2 error to noisy GT.
0 triangle ineq. violations. 343.7487 l2 error to GT. 361.9023 l2 error to noisy GT.
0 triangle ineq. violations. 130.0676 l2 error to GT. 151.8909 l2 error to noisy GT.
0 triangle ineq. violations. 115.6810 l2 error to GT. 142.0233 l2 error to noisy GT.
0 triangle ineq. violations. 114.6723 l2 error to GT. 134.4137 l2 error to noisy GT.
0 triangle ineq. violations. 111.1650 l2 error to GT. 130.0697 l2 error to noisy GT.
0 triangle ineq. violations. 111.7674 l2 error to GT. 133.4260 l2 error to noisy GT.
0 triangle ineq. violations. 118.1884 l2 error to GT. 133.4871 l2 error to noisy GT.
0 triangle ineq. violations. 115.1974 l2 error to GT. 134.8011 l2 error to noisy GT.
0 triangle ineq. violations. 116.2533 l2 error to GT. 135.4603 l2 error to noisy GT.
0 triangle ineq. violations. 115.7687 l2 error to GT. 133.2284 l2 error to noisy GT.
1 triangle ineq. violations. 115.2697 l2 error to GT. 132.804

In [411]:
# DeepNorm requires a smaller learning rate and trains a bit slower (hence more rounds).
deepnorm = MetricModel(
    metric=DeepNormMetric(
        N_STATES, (128, 128),
        activation=lambda: MaxReLUPairwiseActivation(128)),
).to(DEVICE)
evaluate_and_train_metric(deepnorm, epochs=50, lr=1e-2)

0 triangle ineq. violations. 10400.1523 l2 error to GT. 10573.5264 l2 error to noisy GT.
0 triangle ineq. violations. 168.0085 l2 error to GT. 193.4304 l2 error to noisy GT.
0 triangle ineq. violations. 138.3754 l2 error to GT. 162.1016 l2 error to noisy GT.
0 triangle ineq. violations. 131.3053 l2 error to GT. 158.0584 l2 error to noisy GT.
0 triangle ineq. violations. 131.4579 l2 error to GT. 156.5459 l2 error to noisy GT.
0 triangle ineq. violations. 127.7473 l2 error to GT. 149.6919 l2 error to noisy GT.
0 triangle ineq. violations. 128.4855 l2 error to GT. 144.8129 l2 error to noisy GT.
0 triangle ineq. violations. 122.9368 l2 error to GT. 144.8664 l2 error to noisy GT.
0 triangle ineq. violations. 121.3388 l2 error to GT. 140.4595 l2 error to noisy GT.
0 triangle ineq. violations. 125.2140 l2 error to GT. 135.5015 l2 error to noisy GT.
0 triangle ineq. violations. 124.6440 l2 error to GT. 145.1458 l2 error to noisy GT.
0 triangle ineq. violations. 120.3016 l2 error to GT. 145.574

In [412]:
# Concave activations might help a bit here.
# NOTE(review): this runs 20 rounds vs. 50 for deepnorm above, so the two are not
# compared at an equal training budget.
neuralmetric = MetricModel(
    metric=DeepNormMetric(
        N_STATES, (128, 128),
        activation=lambda: MaxReLUPairwiseActivation(128),
        concave_activation_size=10),
).to(DEVICE)
evaluate_and_train_metric(neuralmetric, epochs=20, lr=1e-2)

0 triangle ineq. violations. 10474.0420 l2 error to GT. 10534.6318 l2 error to noisy GT.
0 triangle ineq. violations. 164.4842 l2 error to GT. 186.1844 l2 error to noisy GT.
0 triangle ineq. violations. 135.1067 l2 error to GT. 154.1147 l2 error to noisy GT.
0 triangle ineq. violations. 135.4845 l2 error to GT. 150.5378 l2 error to noisy GT.
0 triangle ineq. violations. 119.1945 l2 error to GT. 145.5924 l2 error to noisy GT.
0 triangle ineq. violations. 123.3479 l2 error to GT. 142.1933 l2 error to noisy GT.
0 triangle ineq. violations. 124.8420 l2 error to GT. 143.2315 l2 error to noisy GT.
0 triangle ineq. violations. 121.7935 l2 error to GT. 139.9626 l2 error to noisy GT.
0 triangle ineq. violations. 115.8865 l2 error to GT. 142.2808 l2 error to noisy GT.
0 triangle ineq. violations. 117.8065 l2 error to GT. 144.5377 l2 error to noisy GT.
0 triangle ineq. violations. 118.9322 l2 error to GT. 137.1372 l2 error to noisy GT.
0 triangle ineq. violations. 122.7727 l2 error to GT. 140.612

In [413]:
# The MLP needs a very small learning rate to train.
# Without the metric inductive bias it doesn't fit very well,
# and it violates the triangle inequality.
notametric = MetricModel(metric=MLPNonMetric(N_STATES, (128, 128), mode='subtract')).to(DEVICE)
evaluate_and_train_metric(notametric, lr=3e-4)

10000 triangle ineq. violations. 10495.0625 l2 error to GT. 10534.7617 l2 error to noisy GT.
77 triangle ineq. violations. 184.6394 l2 error to GT. 210.0858 l2 error to noisy GT.
148 triangle ineq. violations. 178.7564 l2 error to GT. 193.9753 l2 error to noisy GT.
187 triangle ineq. violations. 181.3610 l2 error to GT. 198.7089 l2 error to noisy GT.
196 triangle ineq. violations. 193.4611 l2 error to GT. 202.2033 l2 error to noisy GT.
189 triangle ineq. violations. 185.0388 l2 error to GT. 208.8876 l2 error to noisy GT.
194 triangle ineq. violations. 183.0283 l2 error to GT. 209.8664 l2 error to noisy GT.
173 triangle ineq. violations. 195.1491 l2 error to GT. 204.6775 l2 error to noisy GT.
203 triangle ineq. violations. 185.0747 l2 error to GT. 207.7710 l2 error to noisy GT.
191 triangle ineq. violations. 194.9169 l2 error to GT. 208.4712 l2 error to noisy GT.
235 triangle ineq. violations. 187.6752 l2 error to GT. 203.4320 l2 error to noisy GT.
194 triangle ineq. violations. 186.542

In [414]:
# Trained on every pair (nothing held out), the MLP reaches a much lower error, though
# it still violates the triangle inequality. Caveat: every evaluated pair is now a
# training pair, and its error to the *noisy* GT drops below its error to the true GT,
# i.e. it is memorising the noise rather than recovering the underlying metric.
notametric = MetricModel(metric=MLPNonMetric(N_STATES, (128, 128), mode='subtract')).to(DEVICE)
evaluate_and_train_metric(notametric, lr=3e-4, training_fn=train_on_full_dataset)

0 triangle ineq. violations. 10539.4316 l2 error to GT. 10582.9150 l2 error to noisy GT.
101 triangle ineq. violations. 83.6959 l2 error to GT. 102.5730 l2 error to noisy GT.
118 triangle ineq. violations. 30.7381 l2 error to GT. 47.8616 l2 error to noisy GT.
131 triangle ineq. violations. 23.0263 l2 error to GT. 36.3405 l2 error to noisy GT.
173 triangle ineq. violations. 17.7455 l2 error to GT. 27.6664 l2 error to noisy GT.
154 triangle ineq. violations. 14.4404 l2 error to GT. 22.5303 l2 error to noisy GT.
154 triangle ineq. violations. 20.2875 l2 error to GT. 24.9125 l2 error to noisy GT.
167 triangle ineq. violations. 13.2997 l2 error to GT. 16.3037 l2 error to noisy GT.
184 triangle ineq. violations. 29.3111 l2 error to GT. 30.5212 l2 error to noisy GT.
191 triangle ineq. violations. 16.2492 l2 error to GT. 14.1748 l2 error to noisy GT.
243 triangle ineq. violations. 24.8715 l2 error to GT. 20.8237 l2 error to noisy GT.
179 triangle ineq. violations. 19.3138 l2 error to GT. 13.86