In [71]:


## custom
import lovasz_losses as L

## third party
from PIL import Image
from natsort import natsorted

## sys
import random
import time
from glob import glob

## numeric
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

## vis
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits import mplot3d
from matplotlib import collections  as mc
from mpl_toolkits.mplot3d.art3d import Line3DCollection

## notebook
from IPython import display
from tqdm import tqdm_notebook as tqdm

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

plt.style.use('ggplot')
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# deprecated the old names (later removed); prefer the new name when present.
plt.style.use('seaborn-v0_8-colorblind' if 'seaborn-v0_8-colorblind' in plt.style.available
              else 'seaborn-colorblind')

## Utils

In [58]:
def colorScale2cmap(domain, range1):
    """Build a colormap that interpolates linearly between RGB color stops.

    Args:
        domain: sequence of stop positions; rescaled so that its minimum maps
            to 0 and its maximum maps to 1.
        range1: sequence of RGB triples on a 0-255 scale, one per stop.

    Returns:
        A LinearSegmentedColormap whose red/green/blue channels pass through
        the given colors at the rescaled stop positions.
    """
    stops = np.asarray(domain)
    lo, hi = stops.min(), stops.max()
    stops = (stops - lo) / (hi - lo)
    rgb = np.asarray(range1) / 255.0

    segmentdata = {}
    for channel_idx, channel in enumerate(('red', 'green', 'blue')):
        # Each entry is (position, value_below, value_above); equal values give
        # a continuous ramp through the stop.
        segmentdata[channel] = tuple(
            (pos, val, val) for pos, val in zip(stops, rgb[:, channel_idx])
        )
    return LinearSegmentedColormap('asdasdas', segmentdata)
    

#https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3
def pairwise_distances(x, y=None, w=None):
    '''
    Squared Euclidean distances between the rows of x and the rows of y.

    Input: x is an Nxd matrix
           y is an optional Mxd matrix; if not given, y = x
           w is an optional per-feature weight broadcast against x and y
             (e.g. shape (d,)); both inputs are multiplied by it before the
             distances are computed
    Output: dist is an NxM matrix where dist[i,j] = ||x[i,:]-y[j,:]||^2
            (on the weighted inputs when w is given), clamped at 0 to remove
            small negatives caused by floating-point cancellation.
    '''
    if w is not None:
        # Weight BEFORE computing the norms so every term of the expansion
        # ||a||^2 + ||b||^2 - 2 a.b is built from the same (weighted) vectors.
        x = x * w
        if y is not None:
            y = y * w

    x_norm = (x**2).sum(1).view(-1, 1)  # N x 1 column
    if y is None:
        y_t = x.t()
        # Must be a 1 x N row so it broadcasts across columns; reusing the
        # N x 1 column would give 2||x_i||^2 - 2 x_i.x_j instead of a distance.
        y_norm = x_norm.view(1, -1)
    else:
        y_t = y.t()
        y_norm = (y**2).sum(1).view(1, -1)  # 1 x M row

    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    return torch.clamp(dist, 0.0, np.inf)

# Quick check: every row of x is (1, 1) and every row of y is (0, 0),
# so each squared distance should be 1 + 1 = 2.
x = torch.ones(5, 2)
y = torch.zeros(3, 2)
pairwise_distances(x, y)

tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])

In [169]:
def file2graph(fn='./facebook/0.edges'):
    """Read a whitespace-separated edge list into an undirected networkx Graph.

    Each non-blank line contributes one edge from its first two tokens, parsed
    as ints; any further tokens on the line are ignored. Blank lines are skipped.

    Args:
        fn: path to the edge-list file.

    Returns:
        nx.Graph containing every endpoint as a node and every parsed edge.

    Raises:
        ValueError: if a non-blank line has fewer than two tokens, or a token
            is not an integer.
    """
    edges = []
    with open(fn) as f:
        for line_no, line in enumerate(f, start=1):
            tokens = line.split()
            if not tokens:
                continue  # tolerate blank lines, e.g. a trailing newline
            if len(tokens) < 2:
                raise ValueError(f'{fn}:{line_no}: expected two node ids, got {line!r}')
            edges.append((int(tokens[0]), int(tokens[1])))

    # One pass over the endpoints. sum(edges, ()) rebuilds the tuple at every
    # step, which is quadratic in the number of edges.
    nodes = {n for edge in edges for n in edge}

    G = nx.Graph()
    G.add_nodes_from(nodes)
    G.add_edges_from(edges)
    return G


def dict2tensor(d, fill=None, target_device=None):
    """Convert a nested dict ``{src: {dst: value}}`` into a dense square tensor.

    Rows and columns follow ``sorted(d.keys())``; pairs missing from the inner
    dicts are left at 0.

    Args:
        d: mapping src -> {dst: value}. Every dst must also be a key of ``d``.
        fill: if not None, store this constant for every present pair instead
            of its value (``fill=1`` turns an adjacency dict into a 0/1 matrix).
        target_device: device for the result; defaults to the module-level
            ``device``.

    Returns:
        (res, k2i): ``res`` is an (n, n) tensor with n = len(d); ``k2i`` maps
        each key of ``d`` to its row/column index.

    Raises:
        KeyError: if some dst is not a key of ``d``.
    """
    k2i = {k: idx for idx, k in enumerate(sorted(d))}
    n = len(k2i)
    res = torch.zeros(n, n, device=device if target_device is None else target_device)
    for src_node, dst_nodes in d.items():
        row = k2i[src_node]
        for dst_node, value in dst_nodes.items():
            # A constant `fill` overrides the stored value (e.g. distances -> 1).
            res[row, k2i[dst_node]] = value if fill is None else fill
    return res, k2i


def draw_graph_3d(ax, x, G, grad=None, k2i=None):
    """Draw nodes, edges and (optionally) negative-gradient arrows on a 3-D axes.

    Args:
        ax: matplotlib axes created with ``projection='3d'``.
        x: array of node positions, one row per node; columns 0-2 are drawn.
        G: networkx graph whose ``G.edges`` are drawn as line segments.
        grad: optional array shaped like ``x``; ``-grad`` is drawn from each node.
        k2i: optional mapping from node label to row of ``x``. Defaults to
            enumerating ``sorted(G.nodes)`` (the ordering ``dict2tensor`` uses),
            which is the identity for graphs labelled 0..n-1.

    Returns:
        The same ``ax``.
    """
    if k2i is None:
        k2i = {node: row for row, node in enumerate(sorted(G.nodes))}

    ax.scatter(x[:, 0], x[:, 1], x[:, 2])
    # ax.view_init(elev=20.0, azim=0)

    # Edges are stored by node label; translate labels to rows so graphs with
    # non-contiguous labels (e.g. loaded from an edge file) are drawn correctly.
    edgeLines = [(x[k2i[u]][:3], x[k2i[v]][:3]) for u, v in G.edges]
    lc = Line3DCollection(edgeLines, linewidths=1)
    ax.add_collection(lc)

    if grad is not None:
        ax.quiver(x[:, 0], x[:, 1], x[:, 2],
                  -grad[:, 0], -grad[:, 1], -grad[:, 2], length=4, colors='C1')
    return ax


def plot(x, pred, G, Adj, lossHistory, jaccardHistory, i, k2i=None, out_dir='fig'):
    """Save a 2x3 diagnostic figure for iteration ``i`` to ``{out_dir}/epoch{i}.png``.

    Panels: the embedded graph (2-D with negative-gradient arrows, otherwise
    3-D), loss history, Jaccard history, ``max - distance`` matrix of the
    embedding, prediction matrix, and ground-truth matrix. Also prints the
    latest recorded loss and the largest absolute gradient entry.

    Args:
        x: embedding tensor, one row per node. Its ``.grad`` is plotted when
            populated; zeros are used otherwise.
        pred: prediction tensor (e.g. the output of ``euclidean_neighbor``).
        G: the networkx graph being embedded.
        Adj: ground-truth matrix tensor shown in the last panel.
        lossHistory, jaccardHistory: sequences of recorded values.
        i: iteration number used in the title and the file name.
        k2i: node label -> row of ``x``. Defaults to enumerating
            ``sorted(G.nodes)``, the ordering ``dict2tensor`` uses.
        out_dir: output directory; created if it does not exist.
    """
    import os

    if k2i is None:
        k2i = {k: row for row, k in enumerate(sorted(G.nodes))}

    # Use the argument, not the global X.
    pos = x.detach().cpu().numpy()
    grad = x.grad.detach().cpu().numpy() if x.grad is not None else np.zeros_like(pos)
    last_loss = lossHistory[-1] if len(lossHistory) > 0 else float('nan')
    print(f'loss: {last_loss}')
    print(f'max grad: {np.abs(grad).max()}')
    fig = plt.figure(figsize=[14, 10])

    ## graph
    if pos.shape[1] == 2:
        plt.subplot(231)
        nx.draw_networkx(G, pos={k: pos[k2i[k], :2] for k in G.nodes}, font_color='white')
        plt.quiver(pos[:, 0], pos[:, 1],
                   -grad[:, 0], -grad[:, 1],
                   units='inches', label=f'neg grad (max={np.linalg.norm(grad, axis=1).max():.2e})')
        plt.axis('equal')
        plt.legend()
    else:
        ax = fig.add_subplot(2, 3, 1, projection='3d')
        draw_graph_3d(ax, pos, G, grad)
    plt.title('epoch: {}'.format(i))

    ## loss
    plt.subplot(232)
    plt.plot(lossHistory)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    ## jaccard
    plt.subplot(233)
    plt.plot(jaccardHistory)
    plt.xlabel('Epoch')
    plt.ylabel('Jaccard')

    ## pred vs truth
    plt.subplot(234)
    x_d = x.detach()
    # Pass y explicitly so the result is a symmetric squared-distance matrix
    # independent of how pairwise_distances treats y=None.
    pdist = pairwise_distances(x_d, x_d).cpu()
    # The diagonal (self-distance) is additionally lowered by 1.
    plt.imshow(pdist.max() - pdist - torch.eye(pdist.shape[0]))
    plt.title('max - distance')
    plt.colorbar()

    plt.subplot(235)
    pred_cpu = pred.detach().cpu()
    cmap = colorScale2cmap([-1, 0, 1], colors)
    plt.imshow(pred_cpu, cmap=cmap, vmin=-0.1, vmax=0.1)
    plt.title('Prediction')
    plt.colorbar()

    plt.subplot(236)
    cmap = colorScale2cmap([0, 0.5, 1], colors)
    plt.imshow(Adj.detach().cpu(), cmap=cmap)
    plt.colorbar()
    plt.title('Ground Truth')

    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, f'epoch{i}.png'))
    plt.close(fig)
    # plt.show()


# Three RGB stops (0-255 scale) consumed by colorScale2cmap: blue -> black -> red.
colors = [[44, 52, 179], [0, 0, 0], [174, 33, 57]]

## Optimization Procedures

In [324]:
## model
def _threshold_scores(pdist, sizes, pool_size):
    """Score candidates as (per-row threshold - distance); positive means "in the neighbourhood".

    For row r with s = sizes[r] and k = pool_size - s, the threshold is the
    mean of the two smallest of the k largest distances in that row, i.e. of
    the (s+1)-th and (s+2)-th smallest distances. So, ties aside, the s+1
    closest candidates score positive. (When k == 1 only one value is
    averaged.)
    """
    scores = pdist.new_zeros(pdist.shape)
    for r, (row, s) in enumerate(zip(pdist, sizes)):
        k = pool_size - int(s)
        thresh = row.topk(k).values[-2:].mean()
        scores[r, :] = thresh - row
    return scores


## model
def euclidean_neighbor(x, samples=None, subgraph=None, neighbor_sizes=None):
    """Predict neighbourhoods from an embedding using per-node distance thresholds.

    At most one of ``samples`` / ``subgraph`` may be given:
      * neither:    every node scored against every node     -> (n, n)
      * samples:    sampled rows scored against all nodes     -> (len(samples), n)
      * subgraph:   subgraph nodes scored against each other  -> (len(subgraph), len(subgraph))

    Entry [r, c] is positive when candidate c falls within row r's threshold
    (see ``_threshold_scores``), so ``result > 0`` is the predicted
    neighbour-or-self indicator.

    Args:
        x: (n, d) embedding tensor.
        samples: optional row indices into ``x`` (list / numpy array / tensor).
        subgraph: optional list of row indices into ``x``.
        neighbor_sizes: per-row neighbour counts indexed by row of ``x``;
            defaults to the module-level ``neighborSizes``.

    Raises:
        ValueError: if both ``samples`` and ``subgraph`` are given.
    """
    if samples is not None and subgraph is not None:
        raise ValueError('pass at most one of samples / subgraph')
    if neighbor_sizes is None:
        neighbor_sizes = neighborSizes

    if subgraph is not None:
        xs = x[subgraph, :]
        # Row r is node subgraph[r], so use that node's size rather than
        # neighbor_sizes[r]. NOTE(review): this is the node's neighbour count
        # in the full graph, not within the subgraph.
        return _threshold_scores(pairwise_distances(xs, xs),
                                 neighbor_sizes[subgraph], len(subgraph))

    if samples is not None:
        # Row r is node samples[r], so use that node's size rather than
        # neighbor_sizes[r].
        return _threshold_scores(pairwise_distances(x[samples, :], x),
                                 neighbor_sizes[samples], x.shape[0])

    # Pass y explicitly so the distance matrix is symmetric regardless of how
    # pairwise_distances handles y=None.
    return _threshold_scores(pairwise_distances(x, x), neighbor_sizes, x.shape[0])


def jaccardIndex(pred, target):
    """Intersection-over-union of two indicator tensors (intended for 0/1 values).

    The intersection is sum(pred * target) and the union is
    sum(clamp(pred + target, 0, 1)). Returns a 0-d tensor; when the
    intersection is empty the result is a fresh ``torch.tensor(0.0)``.
    """
    overlap = (pred * target).sum()
    if overlap == 0:
        return torch.tensor(0.0)
    combined = (pred + target).clamp(0, 1).sum()
    return overlap / combined

## test:
# Compare the Jaccard index, its complement, and the Lovasz hinge on a tiny example.
target = torch.tensor([1.0, 1.0, 0.0])
logits = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)
f = L.lovasz_hinge

jaccard_value = jaccardIndex(logits, target).item()
print('jaccard index:', jaccard_value)
print(' jaccard loss:', 1 - jaccard_value)
print('  lovasz loss:', f(logits, target).item())


## test
# ground_truth = torch.tensor([0.0, 0.0])
# steps = 19
# x,y = np.meshgrid(np.linspace(-2,2,steps).astype('float32'), np.linspace(-2,2,steps).astype('float32'))
# z = []
# for logits in np.c_[x.ravel(), y.ravel()]:
#     logits = torch.tensor(logits, requires_grad=True)
#     loss = f(logits, ground_truth).item()
#     z.append(loss)
# z = np.array(z)

# x = x.reshape([steps,steps])
# y = y.reshape([steps,steps])
# z = z.reshape([steps,steps])

# fig = plt.figure()
# ax = plt.axes(projection='3d')
# ax.plot_surface(x,y,z, cmap='viridis')
# ax.view_init(elev=20.0, azim=210)
# plt.show()

jaccard index: 0.6666666865348816
 jaccard loss: 0.3333333134651184
  lovasz loss: 0.6666666865348816


## generate a graph

In [359]:
%%time

print('generating graph...')
# Alternatives:
#   G = nx.path_graph(10)
#   G = nx.balanced_tree(3,3)
#   G = nx.connected_watts_strogatz_graph(10,5,0.5)
#   G = file2graph('./facebook/0.edges')
G = nx.cycle_graph(10)

print('calculating all pairs shortest path...')
# D: hop-distance matrix; Adj: 0/1 adjacency. dict2tensor orders both by sorted node label.
D, k2i = dict2tensor(dict(nx.all_pairs_shortest_path_length(G)))
Adj, _ = dict2tensor(dict(G.adjacency()), fill=1)

print(len(G.nodes), 'nodes')
print('\n\n')

nodeCount = Adj.shape[0]
# Per-node neighbour count (row sums of the 0/1 adjacency); read by euclidean_neighbor.
neighborSizes = Adj.sum(1).int()
eye = torch.eye(nodeCount, device=device)
# Target indicator: each node's neighbours plus the node itself.
truth = Adj + eye


generating graph...
calculating all pairs shortest path...
10 nodes



CPU times: user 11.8 ms, sys: 3.09 ms, total: 14.9 ms
Wall time: 13 ms


## Optimize via Stochastic Gradient Descent (SGD)

In [346]:
## delete old figures
# !rm -r fig
# !mkdir fig

In [360]:
##init position
# Seed every RNG used here and in the training loop (torch for the initial X,
# numpy for minibatch sampling) so a run can be reproduced.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

niter = 3000
X = torch.rand(len(G.nodes), 3, requires_grad=True, device=device)
lr = 0.1
optimizer = optim.SGD([X], lr=lr)
# optimizer = optim.Adam([X], lr=0.1)

# Set to True to cosine-anneal the learning rate over niter steps.
useScheduler = False
scheduler = (optim.lr_scheduler.CosineAnnealingLR(optimizer, niter, eta_min=0, last_epoch=-1)
             if useScheduler else None)

##functions
sigmoid = nn.Sigmoid()
jaccard_loss = L.lovasz_hinge

lossHistory = []
jaccardHistory = []

In [353]:
isStochastic = True
# isStochastic = False
sampleSize = 5  # rows per minibatch when isStochastic

iterBar = tqdm(range(niter))
for i in iterBar:
    optimizer.zero_grad()

    if not isStochastic:
        pred = euclidean_neighbor(X)
        loss = jaccard_loss(pred.view(-1), truth.view(-1))
    else:
        # Draw distinct rows so no node is counted twice in one minibatch.
        indices = np.random.choice(X.shape[0], size=min(sampleSize, X.shape[0]), replace=False)
        pred = euclidean_neighbor(X, samples=indices)
        truth_i = truth[indices, :]
        loss = jaccard_loss(pred.view(-1), truth_i.view(-1))

    loss.backward()
    optimizer.step()
    # Since PyTorch 1.1 the scheduler must step AFTER optimizer.step();
    # stepping it first skips the first value of the schedule.
    if scheduler is not None:
        scheduler.step()

    ##debug info / vis
    if i % 10 == 0 or i == niter - 1:
        # Evaluation only: don't build an autograd graph.
        with torch.no_grad():
            pred = euclidean_neighbor(X)
        jaccard = jaccardIndex((pred > 0).float(), truth).item()
        iterBar.set_postfix({'loss': loss.item(), 'jaccard': jaccard})

        jaccardHistory.append(jaccard)
        lossHistory.append(loss.item())
        # plot(X, pred, G, Adj + eye, lossHistory, jaccardHistory, i)
        
          

HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))




In [358]:
# Diagnostics for the final state; `i` is the last iteration index left by the training loop.
pred = euclidean_neighbor(X)
plot(X, pred, G, truth, lossHistory, jaccardHistory, i)

loss: 0.0
max grad: 0.0


In [286]:
# Display the gradient currently stored on X.
X.grad

tensor([[-0.2887,  0.2360, -0.1150],
        [ 0.1402, -0.2046, -0.2420],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0952, -0.0192,  0.3311],
        [-0.0695, -0.3389, -0.0120],
        [ 0.0000,  0.0000,  0.0000],
        [-0.0245, -0.3457,  0.0172],
        [ 0.1473,  0.6725,  0.0207],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000]], device='cuda:0')

## animation

In [210]:
# Create a gif from images in fig/

imgs = natsorted(glob('fig/*.png'))
if not imgs:
    raise FileNotFoundError('no PNG frames found in fig/ - run the plotting cells first')

# Image.open is lazy and keeps each file open until the image is closed.
frames = [Image.open(img) for img in imgs]
try:
    # Save into a GIF file that loops forever
    frames[0].save('anim9.gif', format='GIF',
                   append_images=frames[1:],
                   save_all=True,
                   duration=60, loop=0)
finally:
    for frame in frames:
        frame.close()

In [None]:
## notebook animation

# import imageio
# from natsort import natsorted
# from glob import glob

# fig = plt.figure(figsize=[14,10])

# ims = []
# for fn in natsorted(glob('fig/epoch*.png')):
#     im = imageio.imread(fn)
#     im = plt.imshow(im, animated=True)
#     ims.append([im])

# ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True,
#                                 repeat_delay=1000)

# # ani.save('dynamic_images.mp4')

# display.HTML(ani.to_jshtml())
# # plt.show()