In [1]:
import sys
sys.path.insert(0, '../pyLDLE2/')

import torch
import torch.nn
import torch.optim

import GraphX as gx
import ConnectionGraphX as cgx
import numpy as np
import networkx as nx
import scipy as sp

import ConnectionNetworkX as cnx

import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm


matplotlib.get_backend() =  module://matplotlib_inline.backend_inline


In [2]:
'''
phi,c: |V|d x 1
B: |E|d x |V|d      Connection incidence matrix
w: |E| x 1          Edge weights
c: |V|d x 1         c = alpha-beta; i.e. difference of densities
'''
def loss_fn(phi, B, w, c):
    loss0 = -torch.sum(phi*c)
    
    loss1 = torch.matmul(B, phi).reshape((w.shape[0],-1))
    loss1 = torch.linalg.norm(loss1, dim=1)
    loss1 = loss1 - w
    loss1 = torch.nn.ReLU()(loss1)
    loss1 = torch.sum(loss1**2)
    
    return loss0, loss1


def active_edges(phi, B, w, c):
    loss1 = torch.matmul(B, phi).reshape((w.shape[0],-1))
    loss1 = torch.linalg.norm(loss1, dim=1)
    loss1 = loss1 - w
    return loss1

In [3]:
def optimize(B, w, c, alpha, learning_rate, n_epochs, phi0 = None, print_freq=10):
    if phi0 is None:
        phi = torch.randn(B.shape[1], 1, requires_grad=True)
    else:
        phi = torch.tensor(phi0, requires_grad=True)
    optimizer = torch.optim.Adam([phi], lr=learning_rate)
    for epoch in range(n_epochs):
        # Compute loss
        loss0, loss1 = loss_fn(phi, B, w, c)
        loss = loss0 + (0.5/alpha)*loss1

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % print_freq == 0:
            print(f"epoch: {epoch}, loss: {loss:>7f}, loss0: {loss0:>7f}, loss1: {loss1:>7f}")
    return phi

# Define B, w and c

In [4]:
NODES = 2
EDGES = 1
seed = 42

a = nx.adjacency_matrix(nx.gnm_random_graph(NODES, EDGES, seed=seed))

# a = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
g = gx.GraphX(sp.sparse.csr_matrix.toarray(a))

DIM_CONNECTION = 2
h = cgx.ConnectionGraphX(sp.sparse.csr_matrix.toarray(a), DIM_CONNECTION)

In [36]:
B = h.connectionIncidenceMatrix.T.astype('float32')
w = np.ones(B.shape[0]//DIM_CONNECTION).astype('float32')

np.random.seed(42)
def rand_prob_mass(n, d):
    mu = np.random.uniform(0, 1, (n, d)).astype('float32')
    mu = mu/(mu.sum(axis=0)[None,:])
    mu = mu.flatten()[:,None]
    return mu

mu = rand_prob_mass(h.nNodes, DIM_CONNECTION)
nu = rand_prob_mass(h.nNodes, DIM_CONNECTION)
c = (mu - nu)

In [37]:
c.reshape((h.nNodes, DIM_CONNECTION))

array([[-0.39023045,  0.46100134],
       [ 0.39023048, -0.46100134]], dtype=float32)

In [38]:
c.reshape((h.nNodes, DIM_CONNECTION)).sum(axis=0)

array([2.9802322e-08, 0.0000000e+00], dtype=float32)

# Check feasibility of B

In [39]:
c_sol, residuals, _, _ = np.linalg.lstsq(B.T, c.flatten())

  c_sol, residuals, _, _ = np.linalg.lstsq(B.T, c.flatten())


In [11]:
residuals

array([4.440892e-16], dtype=float32)

In [12]:
np.linalg.norm(c.flatten() - B.T.dot(c_sol).flatten())

2.9802322e-08

In [52]:
learning_rate = 0.1
alpha = 1e-5
n_epochs = 10000

B = torch.tensor(B)
w = torch.tensor(w)
c = torch.tensor(c)

optimize(B, w, c, alpha, learning_rate, n_epochs)

  B = torch.tensor(B)
  w = torch.tensor(w)
  c = torch.tensor(c)


epoch: 0, loss: 132426576.000000
epoch: 10, loss: 102855832.000000
epoch: 20, loss: 79530856.000000
epoch: 30, loss: 61616044.000000
epoch: 40, loss: 47996096.000000
epoch: 50, loss: 37631480.000000
epoch: 60, loss: 29709116.000000
epoch: 70, loss: 23617958.000000
epoch: 80, loss: 18903358.000000
epoch: 90, loss: 15226124.000000
epoch: 100, loss: 12333168.000000
epoch: 110, loss: 10041921.000000
epoch: 120, loss: 8215495.000000
epoch: 130, loss: 6750782.500000
epoch: 140, loss: 5570462.000000
epoch: 150, loss: 4615551.500000
epoch: 160, loss: 3840102.250000
epoch: 170, loss: 3208350.750000
epoch: 180, loss: 2692356.750000
epoch: 190, loss: 2269549.500000
epoch: 200, loss: 1921701.500000
epoch: 210, loss: 1633956.375000
epoch: 220, loss: 1394844.000000
epoch: 230, loss: 1195149.250000
epoch: 240, loss: 1027512.125000
epoch: 250, loss: 886152.687500
epoch: 260, loss: 766512.625000
epoch: 270, loss: 664860.562500
epoch: 280, loss: 578149.562500
epoch: 290, loss: 503862.156250
epoch: 300, 

epoch: 2960, loss: -1.691564
epoch: 2970, loss: -1.691720
epoch: 2980, loss: -1.691876
epoch: 2990, loss: -1.692033
epoch: 3000, loss: -1.692191
epoch: 3010, loss: -1.692350
epoch: 3020, loss: -1.692509
epoch: 3030, loss: -1.692670
epoch: 3040, loss: -1.692831
epoch: 3050, loss: -1.692993
epoch: 3060, loss: -1.693156
epoch: 3070, loss: -1.693320
epoch: 3080, loss: -1.693484
epoch: 3090, loss: -1.693650
epoch: 3100, loss: -1.693816
epoch: 3110, loss: -1.693983
epoch: 3120, loss: -1.694151
epoch: 3130, loss: -1.694320
epoch: 3140, loss: -1.694490
epoch: 3150, loss: -1.694661
epoch: 3160, loss: -1.694832
epoch: 3170, loss: -1.695005
epoch: 3180, loss: -1.695178
epoch: 3190, loss: -1.695353
epoch: 3200, loss: -1.695528
epoch: 3210, loss: -1.695704
epoch: 3220, loss: -1.695881
epoch: 3230, loss: -1.696059
epoch: 3240, loss: -1.696238
epoch: 3250, loss: -1.696418
epoch: 3260, loss: -1.696598
epoch: 3270, loss: -1.696780
epoch: 3280, loss: -1.696963
epoch: 3290, loss: -1.697146
epoch: 3300, l

epoch: 5960, loss: -1.801894
epoch: 5970, loss: -1.802610
epoch: 5980, loss: -1.803328
epoch: 5990, loss: -1.804051
epoch: 6000, loss: -1.804777
epoch: 6010, loss: -1.805507
epoch: 6020, loss: -1.806240
epoch: 6030, loss: -1.806977
epoch: 6040, loss: -1.807718
epoch: 6050, loss: -1.808462
epoch: 6060, loss: -1.809210
epoch: 6070, loss: -1.809962
epoch: 6080, loss: -1.810717
epoch: 6090, loss: -1.811477
epoch: 6100, loss: -1.812240
epoch: 6110, loss: -1.813007
epoch: 6120, loss: -1.813778
epoch: 6130, loss: -1.814552
epoch: 6140, loss: -1.815331
epoch: 6150, loss: -1.816113
epoch: 6160, loss: -1.816900
epoch: 6170, loss: -1.817691
epoch: 6180, loss: -1.818486
epoch: 6190, loss: -1.819285
epoch: 6200, loss: -1.820088
epoch: 6210, loss: -1.820894
epoch: 6220, loss: -1.821705
epoch: 6230, loss: -1.822520
epoch: 6240, loss: -1.823340
epoch: 6250, loss: -1.824163
epoch: 6260, loss: -1.824991
epoch: 6270, loss: -1.825822
epoch: 6280, loss: -1.826658
epoch: 6290, loss: -1.827498
epoch: 6300, l

epoch: 8940, loss: -2.130183
epoch: 8950, loss: -2.130707
epoch: 8960, loss: -2.131232
epoch: 8970, loss: -2.131760
epoch: 8980, loss: -2.132290
epoch: 8990, loss: -2.132821
epoch: 9000, loss: -2.133354
epoch: 9010, loss: -2.133890
epoch: 9020, loss: -2.134427
epoch: 9030, loss: -2.134967
epoch: 9040, loss: -2.135509
epoch: 9050, loss: -2.136052
epoch: 9060, loss: -2.136598
epoch: 9070, loss: -2.137145
epoch: 9080, loss: -2.137695
epoch: 9090, loss: -2.138246
epoch: 9100, loss: -2.138800
epoch: 9110, loss: -2.139355
epoch: 9120, loss: -2.139913
epoch: 9130, loss: -2.140473
epoch: 9140, loss: -2.141034
epoch: 9150, loss: -2.141598
epoch: 9160, loss: -2.142164
epoch: 9170, loss: -2.142732
epoch: 9180, loss: -2.143302
epoch: 9190, loss: -2.143874
epoch: 9200, loss: -2.144448
epoch: 9210, loss: -2.145024
epoch: 9220, loss: -2.145603
epoch: 9230, loss: -2.146183
epoch: 9240, loss: -2.146766
epoch: 9250, loss: -2.147350
epoch: 9260, loss: -2.147936
epoch: 9270, loss: -2.148525
epoch: 9280, l

epoch: 11870, loss: -2.349428
epoch: 11880, loss: -2.361071
epoch: 11890, loss: -2.373480
epoch: 11900, loss: -2.375175
epoch: 11910, loss: -2.379894
epoch: 11920, loss: -2.367612
epoch: 11930, loss: -2.370335
epoch: 11940, loss: -2.379749
epoch: 11950, loss: -2.375687
epoch: 11960, loss: -2.374114
epoch: 11970, loss: -2.382146
epoch: 11980, loss: -2.376537
epoch: 11990, loss: -2.367945
epoch: 12000, loss: -2.373600
epoch: 12010, loss: -2.384202
epoch: 12020, loss: -2.383232
epoch: 12030, loss: -2.358673
epoch: 12040, loss: -2.358825
epoch: 12050, loss: -2.367618
epoch: 12060, loss: -2.379633
epoch: 12070, loss: -2.392444
epoch: 12080, loss: -2.398501
epoch: 12090, loss: -2.394556
epoch: 12100, loss: -2.387462
epoch: 12110, loss: -2.391257
epoch: 12120, loss: -2.401548
epoch: 12130, loss: -2.379933
epoch: 12140, loss: -2.361239
epoch: 12150, loss: -2.363900
epoch: 12160, loss: -2.373947
epoch: 12170, loss: -2.386675
epoch: 12180, loss: -2.400399
epoch: 12190, loss: -2.414297
epoch: 122

epoch: 14770, loss: -2.753651
epoch: 14780, loss: -2.745789
epoch: 14790, loss: -2.714396
epoch: 14800, loss: -2.718951
epoch: 14810, loss: -2.736310
epoch: 14820, loss: -2.758279
epoch: 14830, loss: -2.781752
epoch: 14840, loss: -2.805290
epoch: 14850, loss: -2.829524
epoch: 14860, loss: -2.854052
epoch: 14870, loss: -2.878946
epoch: 14880, loss: -2.873134
epoch: 14890, loss: -2.859673
epoch: 14900, loss: -2.870502
epoch: 14910, loss: -2.890566
epoch: 14920, loss: -2.913946
epoch: 14930, loss: -2.827794
epoch: 14940, loss: -2.769784
epoch: 14950, loss: -2.764678
epoch: 14960, loss: -2.778461
epoch: 14970, loss: -2.798655
epoch: 14980, loss: -2.821402
epoch: 14990, loss: -2.844861
epoch: 15000, loss: -2.868742
epoch: 15010, loss: -2.893151
epoch: 15020, loss: -2.917172
epoch: 15030, loss: -2.941745
epoch: 15040, loss: -2.966082
epoch: 15050, loss: -2.901336
epoch: 15060, loss: -2.859350
epoch: 15070, loss: -2.859960
epoch: 15080, loss: -2.875133
epoch: 15090, loss: -2.896378
epoch: 151

epoch: 17690, loss: -3.478828
epoch: 17700, loss: -3.487969
epoch: 17710, loss: -3.504972
epoch: 17720, loss: -3.525954
epoch: 17730, loss: -3.547286
epoch: 17740, loss: -3.569840
epoch: 17750, loss: -3.557301
epoch: 17760, loss: -3.564352
epoch: 17770, loss: -3.572204
epoch: 17780, loss: -3.533027
epoch: 17790, loss: -3.533446
epoch: 17800, loss: -3.548445
epoch: 17810, loss: -3.567597
epoch: 17820, loss: -3.590048
epoch: 17830, loss: -3.605275
epoch: 17840, loss: -3.597266
epoch: 17850, loss: -3.606355
epoch: 17860, loss: -3.611268
epoch: 17870, loss: -3.592162
epoch: 17880, loss: -3.596319
epoch: 17890, loss: -3.613768
epoch: 17900, loss: -3.637637
epoch: 17910, loss: -3.621092
epoch: 17920, loss: -3.625272
epoch: 17930, loss: -3.630187
epoch: 17940, loss: -3.619396
epoch: 17950, loss: -3.580527
epoch: 17960, loss: -3.538099
epoch: 17970, loss: -3.540755
epoch: 17980, loss: -3.559744
epoch: 17990, loss: -3.584547
epoch: 18000, loss: -3.607560
epoch: 18010, loss: -3.631634
epoch: 180

tensor([[ 6.7340e-01],
        [ 1.2780e+00],
        [-1.0491e+00],
        [-8.5379e-01],
        [ 5.0705e-01],
        [-7.0313e-01],
        [ 1.1674e-01],
        [-1.3119e-01],
        [ 1.8228e-01],
        [ 6.4240e-01],
        [-2.4638e-01],
        [-5.9769e-02],
        [ 8.2849e-01],
        [-1.2071e-01],
        [-1.6568e-01],
        [ 7.9109e-01],
        [-1.6643e-01],
        [-3.9150e-01],
        [-2.4814e-01],
        [-2.0349e-01],
        [-9.6865e-01],
        [ 4.4054e-01],
        [ 4.9716e-01],
        [ 4.1691e-01],
        [-2.3037e-02],
        [-1.6836e-01],
        [ 1.0982e-01],
        [-3.6439e-01],
        [ 1.6225e-01],
        [ 1.9481e-01],
        [-3.5027e-01],
        [ 2.5473e-01],
        [ 1.7783e-01],
        [ 2.9683e-01],
        [ 5.9190e-01],
        [ 4.7400e-01],
        [-2.3670e-01],
        [ 8.0121e-01],
        [-4.7974e-01],
        [ 1.1167e+00],
        [ 1.5239e-01],
        [ 4.3944e-01],
        [ 1.6281e-01],
        [ 8

In [53]:
w

tensor([1., 1., 1.,  ..., 1., 1., 1.])

# Working example

In [14]:
w = np.ones(1)
d = 2
sigma = np.eye(d)
B = np.block([np.sqrt(w)*np.eye(d), -np.sqrt(w)*sigma])

np.random.seed(42)
mu = np.random.uniform(0, 1, d)
mu = mu/mu.sum()
nu = np.random.uniform(0, 1, d)
nu = nu/nu.sum()
c = mu-nu
optim_val = np.linalg.norm(c)
optim_phi = c/np.linalg.norm(c)
c = np.block([c,-c])
c = c.reshape((2*d,1))

phi0 = np.block([optim_phi,optim_phi*0]).reshape(2*d,1)

print(B)
print(w)
print(c)
print(phi0)
print(optim_val)

B = torch.tensor(B.astype('float32'))
w = torch.tensor(w.astype('float32'))
c = torch.tensor(c.astype('float32'))
phi0 = phi0.astype('float32')

[[ 1.  0. -1. -0.]
 [ 0.  1. -0. -1.]]
[1.]
[[-0.26748401]
 [ 0.26748401]
 [ 0.26748401]
 [-0.26748401]]
[[-0.70710678]
 [ 0.70710678]
 [-0.        ]
 [ 0.        ]]
0.37827952162613965


# Solve Beckmann - connection problem

In [15]:
learning_rate = 0.001
alpha = 1e-1
n_epochs = 10000

In [16]:
optimize(B, w, c, alpha, learning_rate, n_epochs, phi0=phi0)

epoch: 0, loss: -0.378280
epoch: 10, loss: -0.384847
epoch: 20, loss: -0.385245
epoch: 30, loss: -0.385234
epoch: 40, loss: -0.385434
epoch: 50, loss: -0.385405
epoch: 60, loss: -0.385434
epoch: 70, loss: -0.385430
epoch: 80, loss: -0.385434
epoch: 90, loss: -0.385434
epoch: 100, loss: -0.385434
epoch: 110, loss: -0.385434
epoch: 120, loss: -0.385434
epoch: 130, loss: -0.385434
epoch: 140, loss: -0.385434
epoch: 150, loss: -0.385434
epoch: 160, loss: -0.385434
epoch: 170, loss: -0.385434
epoch: 180, loss: -0.385434
epoch: 190, loss: -0.385434
epoch: 200, loss: -0.385434
epoch: 210, loss: -0.385434
epoch: 220, loss: -0.385434
epoch: 230, loss: -0.385434
epoch: 240, loss: -0.385434
epoch: 250, loss: -0.385434
epoch: 260, loss: -0.385434
epoch: 270, loss: -0.385434
epoch: 280, loss: -0.385434
epoch: 290, loss: -0.385434
epoch: 300, loss: -0.385434
epoch: 310, loss: -0.385434
epoch: 320, loss: -0.385434
epoch: 330, loss: -0.385434
epoch: 340, loss: -0.385434
epoch: 350, loss: -0.385434
epo

tensor([[-0.7205],
        [ 0.7205],
        [ 0.0134],
        [-0.0134]], requires_grad=True)

# Attempt Beckmann on Puppet Connection Graph

Initialize the puppet connection graph

In [4]:
p = "/data2/dhruv/PuppetsData/"

N_PUPPET_IMAGES = 500
NEAREST_NEIGHBORS = 13
INTRINSIC_DIMENSION = 2

puppetGraph = cnx.cnxFromImageDirectory(p, INTRINSIC_DIMENSION, k=NEAREST_NEIGHBORS, nImages=N_PUPPET_IMAGES)


X.shape =  (8100, 100)
local_opts['k_nn0'] = 49 is created.
Options provided:
local_opts:
{
    "Atilde_method": "LDLE_1",
    "N": 100,
    "U_method": "k_nn",
    "algo": "LPCA",
    "alpha": 1,
    "debug": true,
    "delta": 0.9,
    "gl_type": "unnorm",
    "k": 13,
    "k_nn": 49,
    "k_nn0": 49,
    "k_tune": 7,
    "lambda1_decay": 0.75,
    "lambda1_init": 8,
    "lambda1_min": 0.001,
    "max_iter": 300,
    "max_sparsity": 0.9,
    "metric": "euclidean",
    "n_proc": 32,
    "p": 0.99,
    "power": 5,
    "pp_n_thresh": 32,
    "radius": 0.5,
    "reg": 0.0,
    "scale_by": "gamma",
    "tau": 50,
    "to_postprocess": true,
    "tuning": "self",
    "verbose": true
}
intermed_opts:
{
    "algo": "best",
    "debug": true,
    "eta_max": 1,
    "eta_min": 5,
    "len_S_thresh": 256,
    "local_algo": "LPCA",
    "n_proc": 32,
    "n_times": 4,
    "verbose": true
}
global_opts:
{
    "add_dim": false,
    "align_transform": "rigid",
    "align_w_parent_only": true,
    "al

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:11<00:00, 43.85it/s]


Proportion of edges which were removed due to remoteness:  0.012661712083677402
[3.06901529 2.9604963  2.83141477 2.82813369 2.80719606 2.51536304
 2.43105356 2.0961851  2.05515014 1.82172534]
MOST LIKELY INCONSISTENT: |lambda_min| >= 1e-8. 


Setting up B, c, w for the puppet graph

In [5]:
puppetB = puppetGraph.connectionIncidenceMatrix.T.astype('float32')
puppetW = np.ones(puppetB.shape[0]//INTRINSIC_DIMENSION).astype('float32')

np.random.seed(42)

#puppetMu = rand_prob_mass(puppetGraph.nNodes, INTRINSIC_DIMENSION)
#puppetNu = rand_prob_mass(puppetGraph.nNodes, INTRINSIC_DIMENSION)

puppetMu = np.zeros((puppetGraph.nNodes * INTRINSIC_DIMENSION,1))
puppetMu[0:2, :] = [[1],[1]]
puppetNu = np.zeros((puppetGraph.nNodes * INTRINSIC_DIMENSION,1))
puppetNu[2:4, :] = [[1],[1]]

puppetC = (puppetMu - puppetNu)

puppetC.reshape((puppetGraph.nNodes, INTRINSIC_DIMENSION))
puppetC.reshape((puppetGraph.nNodes, INTRINSIC_DIMENSION)).sum(axis=0)


array([0., 0.])

Checking Feasibility - looks good

In [6]:
print(type(puppetC))
print(type(puppetB))
c_sol, residuals, _, _ = np.linalg.lstsq(sp.sparse.lil_matrix.toarray(puppetB.T), puppetC)

# I didn't rename the dummy variables here- caution.

np.linalg.norm(puppetC.flatten() - puppetB.T.dot(c_sol).flatten())
#np.linalg.norm(c_sol)


<class 'numpy.ndarray'>
<class 'scipy.sparse._lil.lil_matrix'>


  c_sol, residuals, _, _ = np.linalg.lstsq(sp.sparse.lil_matrix.toarray(puppetB.T), puppetC)


7.196364329407029e-15

Attempt to plug into the optimization methods

In [7]:
learning_rate = 0.1
alpha = 1
n_epochs = 10000

B = torch.tensor(sp.sparse.lil_matrix.toarray(puppetB))
w = torch.tensor(puppetW)
c = torch.tensor(puppetC.astype('float32'))
phi0 = -np.matmul(np.linalg.pinv(np.matmul(B.T, B)), c)

print('Initial loss:', loss_fn(phi0, B, w, c))

puppetPhi = optimize(B, w, c, alpha, learning_rate, n_epochs, phi0 = phi0)

Initial loss: (tensor(0.2598), tensor(0.))
epoch: 0, loss: 0.259824, loss0: 0.259824, loss1: 0.000000
epoch: 10, loss: -3.213919, loss0: -3.623785, loss1: 0.819732
epoch: 20, loss: -4.239977, loss0: -5.169390, loss1: 1.858826
epoch: 30, loss: -4.446993, loss0: -4.594340, loss1: 0.294693


  phi = torch.tensor(phi0, requires_grad=True)


epoch: 40, loss: -4.537025, loss0: -4.869762, loss1: 0.665475
epoch: 50, loss: -4.560396, loss0: -4.773962, loss1: 0.427132
epoch: 60, loss: -4.570121, loss0: -4.817598, loss1: 0.494954
epoch: 70, loss: -4.571414, loss0: -4.826947, loss1: 0.511064
epoch: 80, loss: -4.572386, loss0: -4.820227, loss1: 0.495681
epoch: 90, loss: -4.572742, loss0: -4.816817, loss1: 0.488151
epoch: 100, loss: -4.572875, loss0: -4.820088, loss1: 0.494426
epoch: 110, loss: -4.572927, loss0: -4.820162, loss1: 0.494469
epoch: 120, loss: -4.572935, loss0: -4.820153, loss1: 0.494435
epoch: 130, loss: -4.572943, loss0: -4.820110, loss1: 0.494334
epoch: 140, loss: -4.572946, loss0: -4.820390, loss1: 0.494888
epoch: 150, loss: -4.572947, loss0: -4.820309, loss1: 0.494724
epoch: 160, loss: -4.572948, loss0: -4.820203, loss1: 0.494511
epoch: 170, loss: -4.572948, loss0: -4.820285, loss1: 0.494675
epoch: 180, loss: -4.572947, loss0: -4.820263, loss1: 0.494633
epoch: 190, loss: -4.572948, loss0: -4.820268, loss1: 0.49464

epoch: 1380, loss: -4.571670, loss0: -4.829666, loss1: 0.515993
epoch: 1390, loss: -4.571312, loss0: -4.814532, loss1: 0.486440
epoch: 1400, loss: -4.571578, loss0: -4.817657, loss1: 0.492159
epoch: 1410, loss: -4.572156, loss0: -4.819576, loss1: 0.494840
epoch: 1420, loss: -4.572096, loss0: -4.824182, loss1: 0.504171
epoch: 1430, loss: -4.571859, loss0: -4.813593, loss1: 0.483469
epoch: 1440, loss: -4.571412, loss0: -4.809130, loss1: 0.475436
epoch: 1450, loss: -4.572330, loss0: -4.817333, loss1: 0.490007
epoch: 1460, loss: -4.571836, loss0: -4.815790, loss1: 0.487909
epoch: 1470, loss: -4.571675, loss0: -4.814696, loss1: 0.486042
epoch: 1480, loss: -4.572580, loss0: -4.819160, loss1: 0.493158
epoch: 1490, loss: -4.570213, loss0: -4.814601, loss1: 0.488776
epoch: 1500, loss: -4.570585, loss0: -4.829897, loss1: 0.518625
epoch: 1510, loss: -4.571499, loss0: -4.815258, loss1: 0.487517
epoch: 1520, loss: -4.571613, loss0: -4.827754, loss1: 0.512282
epoch: 1530, loss: -4.572207, loss0: -4.

epoch: 2740, loss: -4.572414, loss0: -4.820496, loss1: 0.496162
epoch: 2750, loss: -4.571550, loss0: -4.822238, loss1: 0.501377
epoch: 2760, loss: -4.571870, loss0: -4.826629, loss1: 0.509517
epoch: 2770, loss: -4.571528, loss0: -4.829626, loss1: 0.516194
epoch: 2780, loss: -4.570375, loss0: -4.829403, loss1: 0.518056
epoch: 2790, loss: -4.572064, loss0: -4.812550, loss1: 0.480972
epoch: 2800, loss: -4.570445, loss0: -4.824446, loss1: 0.508003
epoch: 2810, loss: -4.571569, loss0: -4.827630, loss1: 0.512122
epoch: 2820, loss: -4.571497, loss0: -4.831796, loss1: 0.520599
epoch: 2830, loss: -4.570612, loss0: -4.803189, loss1: 0.465155
epoch: 2840, loss: -4.571908, loss0: -4.829395, loss1: 0.514974
epoch: 2850, loss: -4.571646, loss0: -4.827154, loss1: 0.511017
epoch: 2860, loss: -4.571819, loss0: -4.811970, loss1: 0.480301
epoch: 2870, loss: -4.571195, loss0: -4.821983, loss1: 0.501575
epoch: 2880, loss: -4.571595, loss0: -4.812212, loss1: 0.481235
epoch: 2890, loss: -4.571354, loss0: -4.

epoch: 4080, loss: -4.571589, loss0: -4.817102, loss1: 0.491026
epoch: 4090, loss: -4.572037, loss0: -4.820096, loss1: 0.496116
epoch: 4100, loss: -4.571870, loss0: -4.819670, loss1: 0.495599
epoch: 4110, loss: -4.572518, loss0: -4.826854, loss1: 0.508671
epoch: 4120, loss: -4.570982, loss0: -4.835422, loss1: 0.528880
epoch: 4130, loss: -4.572374, loss0: -4.828042, loss1: 0.511337
epoch: 4140, loss: -4.571621, loss0: -4.832963, loss1: 0.522684
epoch: 4150, loss: -4.572286, loss0: -4.828488, loss1: 0.512404
epoch: 4160, loss: -4.572376, loss0: -4.828560, loss1: 0.512368
epoch: 4170, loss: -4.570720, loss0: -4.835192, loss1: 0.528944
epoch: 4180, loss: -4.572392, loss0: -4.827261, loss1: 0.509737
epoch: 4190, loss: -4.571955, loss0: -4.820251, loss1: 0.496592
epoch: 4200, loss: -4.572191, loss0: -4.821743, loss1: 0.499103
epoch: 4210, loss: -4.571513, loss0: -4.832340, loss1: 0.521654
epoch: 4220, loss: -4.572401, loss0: -4.828672, loss1: 0.512543
epoch: 4230, loss: -4.571321, loss0: -4.

epoch: 5420, loss: -4.570763, loss0: -4.829795, loss1: 0.518066
epoch: 5430, loss: -4.571168, loss0: -4.825376, loss1: 0.508416
epoch: 5440, loss: -4.571281, loss0: -4.814620, loss1: 0.486678
epoch: 5450, loss: -4.570543, loss0: -4.822822, loss1: 0.504556
epoch: 5460, loss: -4.571811, loss0: -4.817170, loss1: 0.490719
epoch: 5470, loss: -4.572544, loss0: -4.826480, loss1: 0.507873
epoch: 5480, loss: -4.572070, loss0: -4.823105, loss1: 0.502071
epoch: 5490, loss: -4.572507, loss0: -4.820226, loss1: 0.495438
epoch: 5500, loss: -4.571560, loss0: -4.810391, loss1: 0.477663
epoch: 5510, loss: -4.572757, loss0: -4.819549, loss1: 0.493585
epoch: 5520, loss: -4.571939, loss0: -4.823161, loss1: 0.502445
epoch: 5530, loss: -4.572112, loss0: -4.819537, loss1: 0.494851
epoch: 5540, loss: -4.572301, loss0: -4.811941, loss1: 0.479280
epoch: 5550, loss: -4.571194, loss0: -4.806458, loss1: 0.470529
epoch: 5560, loss: -4.571943, loss0: -4.808429, loss1: 0.472972
epoch: 5570, loss: -4.571790, loss0: -4.

epoch: 6770, loss: -4.569736, loss0: -4.817362, loss1: 0.495252
epoch: 6780, loss: -4.572436, loss0: -4.825544, loss1: 0.506216
epoch: 6790, loss: -4.570414, loss0: -4.828591, loss1: 0.516356
epoch: 6800, loss: -4.571870, loss0: -4.808792, loss1: 0.473842
epoch: 6810, loss: -4.571437, loss0: -4.810586, loss1: 0.478297
epoch: 6820, loss: -4.572430, loss0: -4.820901, loss1: 0.496941
epoch: 6830, loss: -4.571066, loss0: -4.831594, loss1: 0.521056
epoch: 6840, loss: -4.570831, loss0: -4.832592, loss1: 0.523522
epoch: 6850, loss: -4.570158, loss0: -4.802537, loss1: 0.464759
epoch: 6860, loss: -4.571514, loss0: -4.813082, loss1: 0.483135
epoch: 6870, loss: -4.570515, loss0: -4.822941, loss1: 0.504853
epoch: 6880, loss: -4.572473, loss0: -4.817166, loss1: 0.489386
epoch: 6890, loss: -4.569692, loss0: -4.800554, loss1: 0.461724
epoch: 6900, loss: -4.572157, loss0: -4.823245, loss1: 0.502176
epoch: 6910, loss: -4.572159, loss0: -4.822421, loss1: 0.500525
epoch: 6920, loss: -4.571695, loss0: -4.

epoch: 8180, loss: -4.570862, loss0: -4.809714, loss1: 0.477705
epoch: 8190, loss: -4.571448, loss0: -4.825225, loss1: 0.507554
epoch: 8200, loss: -4.571481, loss0: -4.829051, loss1: 0.515141
epoch: 8210, loss: -4.570644, loss0: -4.835397, loss1: 0.529506
epoch: 8220, loss: -4.571432, loss0: -4.830142, loss1: 0.517420
epoch: 8230, loss: -4.571689, loss0: -4.827305, loss1: 0.511234
epoch: 8240, loss: -4.570326, loss0: -4.815875, loss1: 0.491099
epoch: 8250, loss: -4.571795, loss0: -4.819530, loss1: 0.495470
epoch: 8260, loss: -4.569540, loss0: -4.834781, loss1: 0.530483
epoch: 8270, loss: -4.571944, loss0: -4.817918, loss1: 0.491948
epoch: 8280, loss: -4.571336, loss0: -4.817665, loss1: 0.492657
epoch: 8290, loss: -4.572244, loss0: -4.825256, loss1: 0.506025
epoch: 8300, loss: -4.570405, loss0: -4.833932, loss1: 0.527054
epoch: 8310, loss: -4.572059, loss0: -4.823902, loss1: 0.503686
epoch: 8320, loss: -4.570505, loss0: -4.816269, loss1: 0.491529
epoch: 8330, loss: -4.570855, loss0: -4.

epoch: 9600, loss: -4.570988, loss0: -4.816545, loss1: 0.491114
epoch: 9610, loss: -4.571458, loss0: -4.818417, loss1: 0.493917
epoch: 9620, loss: -4.571364, loss0: -4.828562, loss1: 0.514394
epoch: 9630, loss: -4.571581, loss0: -4.828359, loss1: 0.513555
epoch: 9640, loss: -4.571986, loss0: -4.826712, loss1: 0.509450
epoch: 9650, loss: -4.571504, loss0: -4.827652, loss1: 0.512296
epoch: 9660, loss: -4.571175, loss0: -4.829878, loss1: 0.517407
epoch: 9670, loss: -4.571416, loss0: -4.823259, loss1: 0.503686
epoch: 9680, loss: -4.570232, loss0: -4.837444, loss1: 0.534424
epoch: 9690, loss: -4.571208, loss0: -4.822999, loss1: 0.503581
epoch: 9700, loss: -4.571742, loss0: -4.822541, loss1: 0.501599
epoch: 9710, loss: -4.571404, loss0: -4.828370, loss1: 0.513931
epoch: 9720, loss: -4.570921, loss0: -4.815585, loss1: 0.489327
epoch: 9730, loss: -4.572309, loss0: -4.821625, loss1: 0.498632
epoch: 9740, loss: -4.571109, loss0: -4.833440, loss1: 0.524662
epoch: 9750, loss: -4.571458, loss0: -4.

epoch: 10960, loss: -4.572948, loss0: -4.820258, loss1: 0.494620
epoch: 10970, loss: -4.572948, loss0: -4.820258, loss1: 0.494621
epoch: 10980, loss: -4.572948, loss0: -4.820260, loss1: 0.494624
epoch: 10990, loss: -4.572948, loss0: -4.820261, loss1: 0.494626
epoch: 11000, loss: -4.572948, loss0: -4.820262, loss1: 0.494629
epoch: 11010, loss: -4.572947, loss0: -4.820273, loss1: 0.494652
epoch: 11020, loss: -4.572947, loss0: -4.820355, loss1: 0.494817
epoch: 11030, loss: -4.572948, loss0: -4.820341, loss1: 0.494787
epoch: 11040, loss: -4.572947, loss0: -4.820297, loss1: 0.494699
epoch: 11050, loss: -4.572945, loss0: -4.819919, loss1: 0.493947
epoch: 11060, loss: -4.572924, loss0: -4.819186, loss1: 0.492524
epoch: 11070, loss: -4.572906, loss0: -4.818599, loss1: 0.491387
epoch: 11080, loss: -4.572890, loss0: -4.818303, loss1: 0.490826
epoch: 11090, loss: -4.572740, loss0: -4.816972, loss1: 0.488465
epoch: 11100, loss: -4.572479, loss0: -4.814791, loss1: 0.484623
epoch: 11110, loss: -4.57

epoch: 12300, loss: -4.571896, loss0: -4.817673, loss1: 0.491553
epoch: 12310, loss: -4.571905, loss0: -4.822191, loss1: 0.500572
epoch: 12320, loss: -4.570595, loss0: -4.835054, loss1: 0.528919
epoch: 12330, loss: -4.572484, loss0: -4.824448, loss1: 0.503929
epoch: 12340, loss: -4.570618, loss0: -4.819889, loss1: 0.498541
epoch: 12350, loss: -4.571266, loss0: -4.830893, loss1: 0.519253
epoch: 12360, loss: -4.570727, loss0: -4.815479, loss1: 0.489505
epoch: 12370, loss: -4.571663, loss0: -4.826257, loss1: 0.509187
epoch: 12380, loss: -4.572004, loss0: -4.824990, loss1: 0.505972
epoch: 12390, loss: -4.567651, loss0: -4.832079, loss1: 0.528857
epoch: 12400, loss: -4.569344, loss0: -4.809068, loss1: 0.479450
epoch: 12410, loss: -4.570820, loss0: -4.830682, loss1: 0.519724
epoch: 12420, loss: -4.567487, loss0: -4.797790, loss1: 0.460605
epoch: 12430, loss: -4.562202, loss0: -4.826719, loss1: 0.529034
epoch: 12440, loss: -4.569029, loss0: -4.811690, loss1: 0.485321
epoch: 12450, loss: -4.57

epoch: 13570, loss: -4.572365, loss0: -4.815966, loss1: 0.487200
epoch: 13580, loss: -4.571438, loss0: -4.818594, loss1: 0.494311
epoch: 13590, loss: -4.572085, loss0: -4.825795, loss1: 0.507420
epoch: 13600, loss: -4.571397, loss0: -4.823299, loss1: 0.503805
epoch: 13610, loss: -4.570370, loss0: -4.815857, loss1: 0.490974
epoch: 13620, loss: -4.569895, loss0: -4.836098, loss1: 0.532406
epoch: 13630, loss: -4.571877, loss0: -4.819139, loss1: 0.494524
epoch: 13640, loss: -4.571572, loss0: -4.811135, loss1: 0.479126
epoch: 13650, loss: -4.569212, loss0: -4.833793, loss1: 0.529162
epoch: 13660, loss: -4.572105, loss0: -4.809546, loss1: 0.474882
epoch: 13670, loss: -4.571311, loss0: -4.806335, loss1: 0.470047
epoch: 13680, loss: -4.572438, loss0: -4.823220, loss1: 0.501564
epoch: 13690, loss: -4.570915, loss0: -4.834812, loss1: 0.527794
epoch: 13700, loss: -4.571915, loss0: -4.818663, loss1: 0.493496
epoch: 13710, loss: -4.571968, loss0: -4.818084, loss1: 0.492233
epoch: 13720, loss: -4.57

epoch: 14950, loss: -4.570126, loss0: -4.831983, loss1: 0.523714
epoch: 14960, loss: -4.570904, loss0: -4.809959, loss1: 0.478110
epoch: 14970, loss: -4.571756, loss0: -4.827884, loss1: 0.512256
epoch: 14980, loss: -4.572062, loss0: -4.820181, loss1: 0.496238
epoch: 14990, loss: -4.570926, loss0: -4.830910, loss1: 0.519968
epoch: 15000, loss: -4.572264, loss0: -4.822556, loss1: 0.500585
epoch: 15010, loss: -4.572503, loss0: -4.818846, loss1: 0.492686
epoch: 15020, loss: -4.570514, loss0: -4.814451, loss1: 0.487873
epoch: 15030, loss: -4.572439, loss0: -4.823391, loss1: 0.501903
epoch: 15040, loss: -4.571467, loss0: -4.826743, loss1: 0.510550
epoch: 15050, loss: -4.571028, loss0: -4.813984, loss1: 0.485913
epoch: 15060, loss: -4.569992, loss0: -4.821453, loss1: 0.502922
epoch: 15070, loss: -4.571657, loss0: -4.818894, loss1: 0.494473
epoch: 15080, loss: -4.571755, loss0: -4.822307, loss1: 0.501104
epoch: 15090, loss: -4.571095, loss0: -4.827165, loss1: 0.512139
epoch: 15100, loss: -4.57

epoch: 16240, loss: -4.570992, loss0: -4.810853, loss1: 0.479722
epoch: 16250, loss: -4.571958, loss0: -4.814957, loss1: 0.485998
epoch: 16260, loss: -4.571084, loss0: -4.831178, loss1: 0.520187
epoch: 16270, loss: -4.571920, loss0: -4.821135, loss1: 0.498430
epoch: 16280, loss: -4.570757, loss0: -4.817874, loss1: 0.494235
epoch: 16290, loss: -4.571507, loss0: -4.832957, loss1: 0.522900
epoch: 16300, loss: -4.571799, loss0: -4.827259, loss1: 0.510920
epoch: 16310, loss: -4.569874, loss0: -4.831115, loss1: 0.522482
epoch: 16320, loss: -4.569438, loss0: -4.802249, loss1: 0.465622
epoch: 16330, loss: -4.569266, loss0: -4.831476, loss1: 0.524420
epoch: 16340, loss: -4.568972, loss0: -4.826165, loss1: 0.514386
epoch: 16350, loss: -4.571668, loss0: -4.811975, loss1: 0.480613
epoch: 16360, loss: -4.572369, loss0: -4.818678, loss1: 0.492618
epoch: 16370, loss: -4.571855, loss0: -4.822022, loss1: 0.500334
epoch: 16380, loss: -4.572109, loss0: -4.817900, loss1: 0.491582
epoch: 16390, loss: -4.57

epoch: 17550, loss: -4.572241, loss0: -4.819452, loss1: 0.494422
epoch: 17560, loss: -4.572650, loss0: -4.822188, loss1: 0.499076
epoch: 17570, loss: -4.572865, loss0: -4.818243, loss1: 0.490755
epoch: 17580, loss: -4.572913, loss0: -4.821538, loss1: 0.497251
epoch: 17590, loss: -4.572936, loss0: -4.820558, loss1: 0.495245
epoch: 17600, loss: -4.572943, loss0: -4.820200, loss1: 0.494515
epoch: 17610, loss: -4.572946, loss0: -4.820204, loss1: 0.494516
epoch: 17620, loss: -4.572947, loss0: -4.820302, loss1: 0.494711
epoch: 17630, loss: -4.572947, loss0: -4.820263, loss1: 0.494631
epoch: 17640, loss: -4.572948, loss0: -4.820232, loss1: 0.494569
epoch: 17650, loss: -4.572948, loss0: -4.820259, loss1: 0.494623
epoch: 17660, loss: -4.572947, loss0: -4.820266, loss1: 0.494638
epoch: 17670, loss: -4.572947, loss0: -4.820257, loss1: 0.494620
epoch: 17680, loss: -4.572948, loss0: -4.820257, loss1: 0.494618
epoch: 17690, loss: -4.572948, loss0: -4.820258, loss1: 0.494622
epoch: 17700, loss: -4.57

epoch: 18830, loss: -4.571598, loss0: -4.827481, loss1: 0.511766
epoch: 18840, loss: -4.571452, loss0: -4.831722, loss1: 0.520540
epoch: 18850, loss: -4.572118, loss0: -4.822144, loss1: 0.500051
epoch: 18860, loss: -4.570347, loss0: -4.814365, loss1: 0.488037
epoch: 18870, loss: -4.571607, loss0: -4.826801, loss1: 0.510389
epoch: 18880, loss: -4.571352, loss0: -4.831895, loss1: 0.521084
epoch: 18890, loss: -4.572202, loss0: -4.822357, loss1: 0.500309
epoch: 18900, loss: -4.570286, loss0: -4.813076, loss1: 0.485581
epoch: 18910, loss: -4.571853, loss0: -4.826824, loss1: 0.509941
epoch: 18920, loss: -4.572015, loss0: -4.823412, loss1: 0.502795
epoch: 18930, loss: -4.571161, loss0: -4.815970, loss1: 0.489618
epoch: 18940, loss: -4.571257, loss0: -4.831597, loss1: 0.520679
epoch: 18950, loss: -4.570945, loss0: -4.824257, loss1: 0.506626
epoch: 18960, loss: -4.570615, loss0: -4.817616, loss1: 0.494000
epoch: 18970, loss: -4.571729, loss0: -4.825095, loss1: 0.506733
epoch: 18980, loss: -4.57

In [57]:
B.shape, c.shape

(torch.Size([3660, 1000]), torch.Size([1000, 1]))

In [None]:
np.linalg.norm(puppetPhi.detach().numpy())

Doing some post-hoc illustration- want to see "active edges"

In [11]:
#print(c_sol.shape)

#for edge in range(1846):
#    if abs(c_sol[edge]) < TOLERANCE:
#        c_sol[edge] = 0

# with np.printoptions(threshold=np.inf):
    # print(c_sol)


In [8]:
c_sol = active_edges(puppetPhi, B, w, c)

In [9]:
c_sol = c_sol.detach().numpy()
print(c_sol)

[ 0.08698285  0.091295    0.13666463 ... -0.9477288  -0.7730732
 -0.67467713]


In [10]:
c_sol.shape

(3633,)

In [15]:
from sklearn.manifold import LocallyLinearEmbedding
ltsa_obj = LocallyLinearEmbedding(n_components=INTRINSIC_DIMENSION+1, n_neighbors=10, method='ltsa')

# from umap import UMAP
# umap_obj = UMAP(n_neighbors=13, min_dist=0.5, n_components=3, random_state=42)
puppetGraph_embedding = ltsa_obj.fit_transform(puppetGraph.imageData)



In [None]:
%matplotlib notebook
from mpl_toolkits.mplot3d.art3d import Line3D

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

# ax.set_xlim((-1.1, 1.1))
# ax.set_ylim((-1.1, 1.1))
# ax.set_zlim((-1.1, 1.1))

nodeData = nx.kamada_kawai_layout(puppetGraph)

TOLERANCE = -1e-1
MAX = np.max(c_sol)
MIN = np.min(c_sol)

for edgeIndex, edge in zip(range(puppetGraph.nEdges), list(puppetGraph.edges())):

    if c_sol[edgeIndex] > TOLERANCE:
        col=(0, 0, 1, (c_sol[edgeIndex]-MIN)/(MAX-MIN))
        z = 3
    else:
        #col="tab:blue"
        col=(1,1,1,0)
        z = 1

    #col = (0, 0, 1, 1)
    #col=(0, 0, 1, (c_sol[edgeIndex]-MIN)/(MAX-MIN))
    fromNode = edge[0]
    toNode = edge[1]
    centerFromNode = puppetGraph_embedding[fromNode, :]
    centerToNode = puppetGraph_embedding[toNode, :]

    ax.plot((centerFromNode[0], centerToNode[0]), (centerFromNode[1], centerToNode[1]), (centerFromNode[2], centerToNode[2]), color=col, lw=3)
    # edgeLabel = Line3D((centerFromNode[0], centerToNode[0]), (centerFromNode[1], centerToNode[1]), (centerFromNode[2], centerToNode[2]), color=col, lw=3)
    # edgeLabel.zorder = z
    # ax.add_line(edgeLabel)

# for node in tqdm(list(range(2, N_PUPPET_IMAGES)) + [0, 1]):
#     center = nodeData[node][0], nodeData[node][1]
#     if node in [0,1]:
#         col="tab:red"
#         zNode=2
#         r=1e-1
#         #ax.text(center[0], center[1], str(node+1))
#     else:
#         col="tab:blue"
#         zNode=1
#         r=2e-2
#     nodeLabel = matplotlib.patches.Circle(center, radius=r, color=col, zorder=zNode)
#     ax.add_patch(nodeLabel)

nodeColor = np.zeros(N_PUPPET_IMAGES)
nodeColor[0:2] = 1

#ax.scatter(*umap_puppetGraph_embedding.T, c='k', marker='o')

ax.scatter(*puppetGraph_embedding[0:2,:].T, c='r', s=100, marker='o')
plt.show()
# Want to color "active" edges red. WIP.
# Noticed that CNX has an issue- not properly removing the remote edges. Zeroing the Connection incidence is not enough.


<IPython.core.display.Javascript object>

In [None]:
print(puppetGraph.nEdges * 2)
print(c_sol.shape)
print(puppetPhi.detach().numpy().shape)



In [None]:
np.amax(c_sol)