In [1]:
import torch
import torch.nn
import torch.optim

import GraphX as gx
import ConnectionGraphX as cgx
import numpy as np
import networkx as nx
import scipy as sp

In [2]:
'''
phi,c: |V|d x 1
B: |E|d x |V|d
w: |E| x 1
c: |V|d x 1
'''
def loss_fn(phi, B, w, c, alpha):
    loss0 = -torch.sum(phi*c)
    
    loss1 = torch.matmul(B, phi).reshape((w.shape[0],-1))
    loss1 = torch.linalg.norm(loss1, dim=1)
    loss1 = loss1 - w
    loss1 = torch.nn.ReLU()(loss1)
    loss1 = torch.sum(loss1**2)
    
    loss = loss0 + (0.5/alpha)*loss1
    return loss

In [75]:
def optimize(B, w, c, alpha, learning_rate, n_epochs, phi0 = None, print_freq=10):
    if phi0 is None:
        phi = torch.randn(B.shape[1], 1, requires_grad=True)
    else:
        phi = torch.tensor(phi0, requires_grad=True)
    optimizer = torch.optim.Adam([phi], lr=learning_rate)
    for epoch in range(n_epochs):
        # Compute loss
        loss = loss_fn(phi, B, w, c, alpha)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % print_freq == 0:
            print(f"epoch: {epoch}, loss: {loss:>7f}")
    return phi

# Define B, w and c

In [140]:
NODES = 2
EDGES = 1
seed = 42

a = nx.adjacency_matrix(nx.gnm_random_graph(NODES, EDGES, seed=seed))

# a = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
g = gx.GraphX(sp.sparse.csr_matrix.toarray(a))

DIM_CONNECTION = 2
h = cgx.ConnectionGraphX(sp.sparse.csr_matrix.toarray(a), DIM_CONNECTION)

In [150]:
B = h.connectionIncidenceMatrix.T.astype('float32')
w = np.ones(B.shape[0]//DIM_CONNECTION).astype('float32')

np.random.seed(42)
def rand_prob_mass(n, d):
    mu = np.random.uniform(0, 1, (n, d)).astype('float32')
    mu = mu/(mu.sum(axis=0)[None,:])
    mu = mu.flatten()[:,None]
    return mu

mu = rand_prob_mass(h.nNodes, DIM_CONNECTION)
nu = rand_prob_mass(h.nNodes, DIM_CONNECTION)
c = (mu - nu)

In [151]:
c.reshape((h.nNodes, DIM_CONNECTION))

array([[-0.39023045,  0.46100134],
       [ 0.39023048, -0.46100134]], dtype=float32)

In [152]:
c.reshape((h.nNodes, DIM_CONNECTION)).sum(axis=0)

array([2.9802322e-08, 0.0000000e+00], dtype=float32)

# Check feasibility of B

In [153]:
c_sol, residuals, _, _ = np.linalg.lstsq(B.T, c.flatten())

  c_sol, residuals, _, _ = np.linalg.lstsq(B.T, c.flatten())


In [154]:
residuals

array([4.440892e-16], dtype=float32)

In [155]:
np.linalg.norm(c.flatten() - B.T.dot(c_sol).flatten())

2.9802322e-08

In [162]:
learning_rate = 0.01
alpha = 1e-5
n_epochs = 10000

B = torch.tensor(B)
w = torch.tensor(w)
c = torch.tensor(c)

optimize(B, w, c, alpha, learning_rate, n_epochs)

  B = torch.tensor(B)
  w = torch.tensor(w)
  c = torch.tensor(c)


epoch: 0, loss: 47161.414062
epoch: 10, loss: 29849.808594
epoch: 20, loss: 17027.871094
epoch: 30, loss: 8544.127930
epoch: 40, loss: 3648.445068
epoch: 50, loss: 1260.156494
epoch: 60, loss: 319.896301
epoch: 70, loss: 46.473259
epoch: 80, loss: 1.149476
epoch: 90, loss: 0.457033
epoch: 100, loss: 0.455001
epoch: 110, loss: 0.454158
epoch: 120, loss: 0.453749
epoch: 130, loss: 0.453494
epoch: 140, loss: 0.453292
epoch: 150, loss: 0.453104
epoch: 160, loss: 0.452918
epoch: 170, loss: 0.452728
epoch: 180, loss: 0.452534
epoch: 190, loss: 0.452334
epoch: 200, loss: 0.452128
epoch: 210, loss: 0.451916
epoch: 220, loss: 0.451699
epoch: 230, loss: 0.451476
epoch: 240, loss: 0.451248
epoch: 250, loss: 0.451014
epoch: 260, loss: 0.450775
epoch: 270, loss: 0.450531
epoch: 280, loss: 0.450281
epoch: 290, loss: 0.450027
epoch: 300, loss: 0.449767
epoch: 310, loss: 0.449502
epoch: 320, loss: 0.449232
epoch: 330, loss: 0.448957
epoch: 340, loss: 0.448677
epoch: 350, loss: 0.448393
epoch: 360, los

epoch: 3150, loss: 0.349869
epoch: 3160, loss: 0.349605
epoch: 3170, loss: 0.349341
epoch: 3180, loss: 0.349077
epoch: 3190, loss: 0.348811
epoch: 3200, loss: 0.348546
epoch: 3210, loss: 0.348279
epoch: 3220, loss: 0.348012
epoch: 3230, loss: 0.347745
epoch: 3240, loss: 0.347476
epoch: 3250, loss: 0.347208
epoch: 3260, loss: 0.346938
epoch: 3270, loss: 0.346669
epoch: 3280, loss: 0.346398
epoch: 3290, loss: 0.346128
epoch: 3300, loss: 0.345856
epoch: 3310, loss: 0.345584
epoch: 3320, loss: 0.345311
epoch: 3330, loss: 0.345038
epoch: 3340, loss: 0.344764
epoch: 3350, loss: 0.344490
epoch: 3360, loss: 0.344215
epoch: 3370, loss: 0.343939
epoch: 3380, loss: 0.343663
epoch: 3390, loss: 0.343386
epoch: 3400, loss: 0.343109
epoch: 3410, loss: 0.342831
epoch: 3420, loss: 0.342552
epoch: 3430, loss: 0.342273
epoch: 3440, loss: 0.341994
epoch: 3450, loss: 0.341714
epoch: 3460, loss: 0.341433
epoch: 3470, loss: 0.341151
epoch: 3480, loss: 0.340869
epoch: 3490, loss: 0.340587
epoch: 3500, loss: 0

epoch: 6350, loss: 0.230026
epoch: 6360, loss: 0.262226
epoch: 6370, loss: 0.273990
epoch: 6380, loss: 0.272738
epoch: 6390, loss: 0.266898
epoch: 6400, loss: 0.259424
epoch: 6410, loss: 0.251350
epoch: 6420, loss: 0.243039
epoch: 6430, loss: 0.234618
epoch: 6440, loss: 0.226130
epoch: 6450, loss: 0.254051
epoch: 6460, loss: 0.267526
epoch: 6470, loss: 0.266855
epoch: 6480, loss: 0.261201
epoch: 6490, loss: 0.253772
epoch: 6500, loss: 0.245696
epoch: 6510, loss: 0.237365
epoch: 6520, loss: 0.228918
epoch: 6530, loss: 0.220402
epoch: 6540, loss: 0.257572
epoch: 6550, loss: 0.271852
epoch: 6560, loss: 0.271519
epoch: 6570, loss: 0.266040
epoch: 6580, loss: 0.258730
epoch: 6590, loss: 0.250752
epoch: 6600, loss: 0.242514
epoch: 6610, loss: 0.234157
epoch: 6620, loss: 0.225731
epoch: 6630, loss: 0.217254
epoch: 6640, loss: 0.242811
epoch: 6650, loss: 0.266624
epoch: 6660, loss: 0.269686
epoch: 6670, loss: 0.265451
epoch: 6680, loss: 0.258632
epoch: 6690, loss: 0.250881
epoch: 6700, loss: 0

epoch: 9560, loss: -0.006122
epoch: 9570, loss: -0.013932
epoch: 9580, loss: 0.008430
epoch: 9590, loss: 0.038458
epoch: 9600, loss: 0.044154
epoch: 9610, loss: 0.041300
epoch: 9620, loss: 0.035425
epoch: 9630, loss: 0.028468
epoch: 9640, loss: 0.021107
epoch: 9650, loss: 0.013580
epoch: 9660, loss: 0.005970
epoch: 9670, loss: -0.001694
epoch: 9680, loss: -0.009401
epoch: 9690, loss: -0.017148
epoch: 9700, loss: -0.024935
epoch: 9710, loss: -0.032761
epoch: 9720, loss: 0.043087
epoch: 9730, loss: 0.070658
epoch: 9740, loss: 0.076011
epoch: 9750, loss: 0.073556
epoch: 9760, loss: 0.068343
epoch: 9770, loss: 0.062143
epoch: 9780, loss: 0.055575
epoch: 9790, loss: 0.048856
epoch: 9800, loss: 0.042062
epoch: 9810, loss: 0.035219
epoch: 9820, loss: 0.028338
epoch: 9830, loss: 0.021420
epoch: 9840, loss: 0.014468
epoch: 9850, loss: 0.007480
epoch: 9860, loss: 0.000457
epoch: 9870, loss: -0.006600
epoch: 9880, loss: -0.013693
epoch: 9890, loss: -0.020821
epoch: 9900, loss: -0.027984
epoch: 99

tensor([[ 0.3135],
        [-0.6339],
        [ 0.9110],
        [-0.0839]], requires_grad=True)

# Working example

In [66]:
w = np.ones(1)
d = 2
sigma = np.eye(d)
B = np.block([np.sqrt(w)*np.eye(d), -np.sqrt(w)*sigma])

np.random.seed(42)
mu = np.random.uniform(0, 1, d)
mu = mu/mu.sum()
nu = np.random.uniform(0, 1, d)
nu = nu/nu.sum()
c = mu-nu
optim_val = np.linalg.norm(c)
optim_phi = c/np.linalg.norm(c)
c = np.block([c,-c])
c = c.reshape((2*d,1))

phi0 = np.block([optim_phi,optim_phi*0]).reshape(2*d,1)

print(B)
print(w)
print(c)
print(phi0)
print(optim_val)

B = torch.tensor(B.astype('float32'))
w = torch.tensor(w.astype('float32'))
c = torch.tensor(c.astype('float32'))
phi0 = phi0.astype('float32')

[[ 1.  0. -1. -0.]
 [ 0.  1. -0. -1.]]
[1.]
[[-0.26748401]
 [ 0.26748401]
 [ 0.26748401]
 [-0.26748401]]
[[-0.70710678]
 [ 0.70710678]
 [-0.        ]
 [ 0.        ]]
0.37827952162613965


# Solve Beckmann - connection problem

In [88]:
learning_rate = 0.001
alpha = 1e-1
n_epochs = 10000

In [89]:
optimize(B, w, c, alpha, learning_rate, n_epochs, phi0=phi0)

epoch: 0, loss: -0.378280
epoch: 10, loss: -0.384847
epoch: 20, loss: -0.385245
epoch: 30, loss: -0.385234
epoch: 40, loss: -0.385434
epoch: 50, loss: -0.385405
epoch: 60, loss: -0.385434
epoch: 70, loss: -0.385430
epoch: 80, loss: -0.385434
epoch: 90, loss: -0.385434
epoch: 100, loss: -0.385434
epoch: 110, loss: -0.385434
epoch: 120, loss: -0.385434
epoch: 130, loss: -0.385434
epoch: 140, loss: -0.385434
epoch: 150, loss: -0.385434
epoch: 160, loss: -0.385434
epoch: 170, loss: -0.385434
epoch: 180, loss: -0.385434
epoch: 190, loss: -0.385434
epoch: 200, loss: -0.385434
epoch: 210, loss: -0.385434
epoch: 220, loss: -0.385434
epoch: 230, loss: -0.385434
epoch: 240, loss: -0.385434
epoch: 250, loss: -0.385434
epoch: 260, loss: -0.385434
epoch: 270, loss: -0.385434
epoch: 280, loss: -0.385434
epoch: 290, loss: -0.385434
epoch: 300, loss: -0.385434
epoch: 310, loss: -0.385434
epoch: 320, loss: -0.385434
epoch: 330, loss: -0.385434
epoch: 340, loss: -0.385434
epoch: 350, loss: -0.385434
epo

epoch: 2930, loss: -0.385434
epoch: 2940, loss: -0.385434
epoch: 2950, loss: -0.385434
epoch: 2960, loss: -0.385434
epoch: 2970, loss: -0.385434
epoch: 2980, loss: -0.385434
epoch: 2990, loss: -0.385434
epoch: 3000, loss: -0.385434
epoch: 3010, loss: -0.385434
epoch: 3020, loss: -0.385434
epoch: 3030, loss: -0.385434
epoch: 3040, loss: -0.385434
epoch: 3050, loss: -0.385434
epoch: 3060, loss: -0.385434
epoch: 3070, loss: -0.385434
epoch: 3080, loss: -0.385434
epoch: 3090, loss: -0.385434
epoch: 3100, loss: -0.385434
epoch: 3110, loss: -0.385434
epoch: 3120, loss: -0.385434
epoch: 3130, loss: -0.385434
epoch: 3140, loss: -0.385434
epoch: 3150, loss: -0.385434
epoch: 3160, loss: -0.385434
epoch: 3170, loss: -0.385434
epoch: 3180, loss: -0.385434
epoch: 3190, loss: -0.385434
epoch: 3200, loss: -0.385434
epoch: 3210, loss: -0.385434
epoch: 3220, loss: -0.385434
epoch: 3230, loss: -0.385434
epoch: 3240, loss: -0.385434
epoch: 3250, loss: -0.385434
epoch: 3260, loss: -0.385434
epoch: 3270, l

epoch: 5810, loss: -0.385434
epoch: 5820, loss: -0.385434
epoch: 5830, loss: -0.385434
epoch: 5840, loss: -0.385434
epoch: 5850, loss: -0.385434
epoch: 5860, loss: -0.385434
epoch: 5870, loss: -0.385434
epoch: 5880, loss: -0.385434
epoch: 5890, loss: -0.385434
epoch: 5900, loss: -0.385434
epoch: 5910, loss: -0.385434
epoch: 5920, loss: -0.385434
epoch: 5930, loss: -0.385434
epoch: 5940, loss: -0.385434
epoch: 5950, loss: -0.385434
epoch: 5960, loss: -0.385434
epoch: 5970, loss: -0.385434
epoch: 5980, loss: -0.385434
epoch: 5990, loss: -0.385434
epoch: 6000, loss: -0.385434
epoch: 6010, loss: -0.385434
epoch: 6020, loss: -0.385434
epoch: 6030, loss: -0.385434
epoch: 6040, loss: -0.385434
epoch: 6050, loss: -0.385434
epoch: 6060, loss: -0.385434
epoch: 6070, loss: -0.385434
epoch: 6080, loss: -0.385434
epoch: 6090, loss: -0.385434
epoch: 6100, loss: -0.385434
epoch: 6110, loss: -0.385434
epoch: 6120, loss: -0.385434
epoch: 6130, loss: -0.385434
epoch: 6140, loss: -0.385434
epoch: 6150, l

epoch: 8690, loss: -0.385434
epoch: 8700, loss: -0.385434
epoch: 8710, loss: -0.385434
epoch: 8720, loss: -0.385434
epoch: 8730, loss: -0.385434
epoch: 8740, loss: -0.385434
epoch: 8750, loss: -0.385434
epoch: 8760, loss: -0.385434
epoch: 8770, loss: -0.385434
epoch: 8780, loss: -0.385434
epoch: 8790, loss: -0.385434
epoch: 8800, loss: -0.385434
epoch: 8810, loss: -0.385434
epoch: 8820, loss: -0.385434
epoch: 8830, loss: -0.385434
epoch: 8840, loss: -0.385434
epoch: 8850, loss: -0.385434
epoch: 8860, loss: -0.385434
epoch: 8870, loss: -0.385434
epoch: 8880, loss: -0.385434
epoch: 8890, loss: -0.385434
epoch: 8900, loss: -0.385434
epoch: 8910, loss: -0.385434
epoch: 8920, loss: -0.385434
epoch: 8930, loss: -0.385434
epoch: 8940, loss: -0.385434
epoch: 8950, loss: -0.385434
epoch: 8960, loss: -0.385434
epoch: 8970, loss: -0.385434
epoch: 8980, loss: -0.385434
epoch: 8990, loss: -0.385434
epoch: 9000, loss: -0.385434
epoch: 9010, loss: -0.385434
epoch: 9020, loss: -0.385434
epoch: 9030, l

tensor([[-0.7205],
        [ 0.7205],
        [ 0.0134],
        [-0.0134]], requires_grad=True)