In [None]:
import pandas as pd
import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import math, random, torch, collections, time, torch.nn.functional as F, networkx as nx, matplotlib.pyplot as plt, numpy as np
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from IPython.display import clear_output
from torch_geometric.utils import to_networkx
from torch_geometric.utils import from_networkx

In [None]:
import sys, os
sys.path.append('../../gnumap/')
from models.train_models import *
from scipy import optimize

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from codecarbon import OfflineEmissionsTracker
from umap_functions import *
from simulation_utils import make_roll

In [None]:
# k of the kNN graph: passed to make_roll (n_neighbours) and to kneighbors_graph.
N_NEIGHBOURS = 5

In [None]:
# Preview of the Swiss-roll generator (called with plot=True).
# make_roll returns (X, t, new_data):
#   X: coordinates for swissroll
#   t: underlying beta samples
#   new_data: graph object holding the node features
# The result is bound to a name so that, with ast_node_interactivity = "all",
# the cell does not echo the full returned arrays/graph (or a bare string literal).
roll_preview = make_roll(c=0.6, v=4, omega=12, n_samples=2000, n_neighbours=30,
                         a=2, b=2, scale=0.5, plot=True, features=None,
                         standardize=True)

## Data generation

In [None]:
import random
# Seed every RNG the pipeline may draw from (Python, NumPy, PyTorch) so the
# generated roll, and everything computed from it, is reproducible on a fresh kernel.
SEED = 12345
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
X, t, new_data = make_roll(n_neighbours=N_NEIGHBOURS, scale=0.1, n_samples=4000, features='coordinates')
# X: coordinates for swissroll; t: underlying beta samples;
# new_data: graph object (node features requested as 'coordinates')

In [None]:
# Split the roll coordinates into the three axes used by the plots below.
x, y, z = X[:, 0], X[:, 1], X[:, 2]

In [None]:
import matplotlib.pyplot as plt
# 3-D view of the generated Swiss roll, coloured by the latent variable t.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
roll_points = ax.scatter3D(x, y, z, c=t, cmap='Spectral')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
ax.set_title('Swiss roll coloured by t')
plt.show()

In [None]:
# import umap.umap_ as umap
import umap
import numpy as np
from sklearn.neighbors import kneighbors_graph

# UMAP baseline on the raw 3-D coordinates.
UMAP_N_NEIGHBORS = 10   # UMAP's own neighbourhood size (independent of N_NEIGHBOURS)
UMAP_MIN_DIST = 0.3
UMAP_SEED = 12345       # UMAP's optimisation is stochastic; fix it for a reproducible figure

X = np.vstack([np.array(x), np.array(y), np.array(z)]).T
# Distance-weighted kNN graph, reused by the graph drawing further down.
A_dist = kneighbors_graph(X, N_NEIGHBOURS, mode='distance', include_self=False)
embedding = umap.UMAP(n_components=2, n_neighbors=UMAP_N_NEIGHBORS, min_dist=UMAP_MIN_DIST,
                      random_state=UMAP_SEED).fit_transform(X)

fig, ax = plt.subplots()
embedding_points = ax.scatter(*embedding.T, s=10, c=t, alpha=0.5, cmap='Spectral')
ax.set_xlabel('UMAP 1')
ax.set_ylabel('UMAP 2')
ax.set_title('UMAP embedding of the Swiss roll (n_neighbors={})'.format(UMAP_N_NEIGHBORS))
plt.show()

The number of neighbours (`n_neighbors`) controls how much local information is reflected in the embedding space: the smaller `n_neighbors`, the more local the embedding (larger values emphasise global structure).

## Draw reconstructed graph

In [None]:
# Draw the kNN graph A_dist with every node placed at its first two feature
# coordinates (the x-y plane), so that the "roll" is visible.
# Very sensitive to wrong edges.
# networkx 3.0 removed from_scipy_sparse_matrix; from_scipy_sparse_array exists
# since networkx 2.7, so prefer it and fall back only for older versions.
from_sparse = getattr(nx, "from_scipy_sparse_array", None) or nx.from_scipy_sparse_matrix
knn_graph = from_sparse(A_dist)
node_xy = new_data.x[:, :2].numpy()
plt.figure()
nx.draw_networkx(knn_graph,
                 pos={i: node_xy[i] for i in range(new_data.num_nodes)},
                 node_color=t, cmap='Spectral', with_labels=False)
plt.show()

With `n_neighbors = 3`: a small k puts more emphasis on the local structure.

# Define GNUMAP model

In [None]:
from graph_utils import *
# Edge weights for new_data's graph under several weighting schemes.
heat_edge, heat_weight = get_weights(new_data, method="heat")
heat_edge2, heat_weight2 = get_weights(new_data, method="heat", beta=0.5)
power_edge, power_weight = get_weights(new_data, method="power")
# Laplacian weighting at three values of alpha.
lap_edge1, lap_weight1 = get_weights(new_data, method="laplacian", alpha=0.1)
lap_edge2, lap_weight2 = get_weights(new_data, method="laplacian", alpha=0.5)
lap_edge3, lap_weight3 = get_weights(new_data, method="laplacian", alpha=0.9)

In [None]:
# Inspect the input graph: connectivity and per-edge weights.
(new_data.edge_index, new_data.edge_weight)

In [None]:
# Distribution of the input graph's edge weights.
plt.hist(x=new_data.edge_weight.detach())

In [None]:
# Distribution of the heat-scheme weights computed with beta = 0.5.
plt.hist(x=heat_weight2.detach())

In [None]:
from models.data_augmentation import *

In [None]:
# Explicit import: without it this cell only works if `Data` leaked in through
# one of the star imports above, or if a cell further down was already run.
from torch_geometric.data import Data

# Target graph: Laplacian weights (neighbours=15, alpha=0.5), then random
# augmentation with feat_drop_rate=0.1 and edge_mask_rate=0.3.
high_graph_index, high_graph_weights = get_weights(new_data, neighbours=15, method='laplacian', alpha=0.5)
data_tmp = Data(x=new_data.x, y=new_data.y,
                edge_index=high_graph_index,
                edge_weight=high_graph_weights)
out, _ = random_aug(data_tmp, feat_drop_rate=0.1, edge_mask_rate=0.3)
target_graph_index, target_graph_weights = out.edge_index, out.edge_weight

In [None]:
# Shape of the augmented target edge weights.
target_graph_weights.size()

In [None]:
import numpy as np
from carbontracker.tracker import CarbonTracker
import cProfile
import os
import scipy
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import remove_self_loops, negative_sampling
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils import to_scipy_sparse_matrix, to_networkx, from_scipy_sparse_matrix
import time
from umap_functions import *
from graph_utils import *
from models.data_augmentation import *

def _umap_membership_strengths(data, neighbours, local_connectivity):
    """UMAP-style membership strengths computed on ``data``'s symmetrised graph.

    The edge weights are treated as kNN distances (NOTE(review): assumes
    ``data.edge_weight`` holds distances rather than similarities; confirm).

    Returns
    -------
    vals : np.ndarray
        One value per symmetrised edge, ``exp(-max(w_ij - rho_i, 0) / sigma_i)``,
        with values below 1e-5 set to 0.
    knn_dists : np.ndarray
        (num_nodes, max_degree) matrix of each node's ascending non-self edge
        weights, zero-padded on the right.
    """
    if torch_geometric.utils.is_undirected(data.edge_index):
        edge_index, edge_attr = data.edge_index, data.edge_weight
    else:
        edge_index, edge_attr = torch_geometric.utils.to_undirected(data.edge_index, data.edge_weight)

    src = edge_index[0].cpu().numpy()
    dst = edge_index[1].cpu().numpy()
    weights = edge_attr.detach().cpu().numpy().astype(np.float64)

    # Group non-self edges by source node with one lexsort (primary key: source,
    # secondary key: weight) instead of a boolean mask per node (O(N * E)).
    not_self = src != dst
    src_ns, w_ns = src[not_self], weights[not_self]
    order = np.lexsort((w_ns, src_ns))
    counts = np.bincount(src_ns, minlength=data.num_nodes)
    per_node = np.split(w_ns[order], np.cumsum(counts)[:-1])
    knn_dists = pd.DataFrame([row.tolist() for row in per_node]).fillna(0).values

    sigmas, rhos = smooth_knn_dist(
        knn_dists,
        float(neighbours),
        local_connectivity=float(local_connectivity),
    )
    sigmas, rhos = np.asarray(sigmas), np.asarray(rhos)

    vals = np.exp(-np.maximum(weights - rhos[src], 0) / sigmas[src])
    vals[vals < 1e-5] = 0
    return vals, knn_dists


def _augmented_target_graph(data, neighbours, method, alpha, beta, power,
                            feat_drop_rate, edge_mask_rate):
    """Build the weighted target graph with ``get_weights`` and augment it with ``random_aug``.

    Returns a ``Data`` whose ``x``/``edge_index``/``edge_weight`` come from the
    augmented graph and whose ``y`` is ``data.y``.
    """
    high_graph_index, high_graph_weights = get_weights(data, neighbours=neighbours, method=method,
                                                       alpha=alpha, beta=beta, power=power)
    data_high = Data(x=data.x, y=data.y,
                     edge_index=high_graph_index,
                     edge_weight=high_graph_weights)
    augmented, _ = random_aug(data_high, feat_drop_rate=feat_drop_rate, edge_mask_rate=edge_mask_rate)
    return Data(x=augmented.x, edge_index=augmented.edge_index,
                y=data.y, edge_weight=augmented.edge_weight)


def _log_q(emb, rows, cols, a, b):
    """Return ``log q_ij = -log(1 + a * d_ij^(2b))`` for the node pairs (rows[k], cols[k]).

    The squared distance is clipped below at 1e-3, presumably so that
    ``d^(2b)`` stays differentiable at zero distance when b < 1.
    """
    sq_dist = torch.clip(torch.sum(torch.square(emb[rows] - emb[cols]), 1), min=1e-3)
    return -torch.log1p(a * sq_dist ** b)


def _checkpoint_path(method, neighbours, dim, name_file):
    """Path of the best-epoch checkpoint, under ``<cwd>/results/``."""
    file_name = ('best_gnumap_' + str(method) + '_neigh' + str(neighbours)
                 + '_dim' + str(dim) + '_' + str(name_file) + '.pkl')
    return os.path.join(os.getcwd(), 'results', file_name)


def train_gnumap(data, hid_dim, dim, n_layers=2, target=None,
                 method = 'laplacian', must_propagate=None,
                 norm='normalize', neighbours=15,
                 patience=20, epochs=200, lr=1e-3, wd=1e-2,
                 min_dist=0.1, name_file="1", subsampling=None,
                 alpha = 0.5, spread = 1.0, lambd_corr=1e-2,
                 beta = 1., gnn_type: str = 'symmetric',
                 power = 3,
                 feat_drop_rate = 0.0, edge_mask_rate = 0.0,
                 repulsion_strength=None,
                 local_connectivity=1,
                 device=None, colours=None,
                 verbose = False, min_epochs=50):
    """Train a GNN encoder with a UMAP-style objective (GNUMAP).

    Steps
    -----
    1. UMAP membership strengths (``vals``) and the zero-padded kNN distance
       matrix (``knn_dists``) are computed from ``data``'s symmetrised graph.
       They are returned for inspection and are not used in the loss.
    2. The target graph is built with ``get_weights(method=...)`` and augmented
       with ``random_aug`` (``feat_drop_rate`` / ``edge_mask_rate``).
    3. ``GNN`` is trained on the augmented features and target edges. The loss
       combines attraction/repulsion on target edges weighted by p_ij,
       repulsion on negative samples, and a penalty on the off-diagonal
       entries of the embedding covariance.
    4. The lowest-loss checkpoint, saved under ``<cwd>/results/``, is reloaded.

    ``data`` itself is not modified.

    Parameters (selection)
    ----------------------
    data : torch_geometric ``Data`` with ``x``, ``y``, ``edge_index``, ``edge_weight``.
    hid_dim, dim, n_layers, must_propagate, norm : forwarded to ``GNN``.
    target : optional; when given, positive edge weights are passed through
        ``fast_intersection`` together with it.
    method, neighbours, alpha, beta, power : forwarded to ``get_weights``.
    min_dist, spread : converted to the curve parameters ``a, b`` by ``find_ab_params``.
    subsampling : None -> 5 negatives per target edge, drawn once;
        int -> that many negatives, redrawn every epoch.
    repulsion_strength : weight of the repulsive terms; default ``1 / sparsity``.
    lambd_corr : weight of the covariance off-diagonal penalty.
    patience, min_epochs : stop once the loss has not improved for ``patience``
        epochs and ``epoch > min_epochs``.
    gnn_type : accepted for backward compatibility; not forwarded to ``GNN``.
    colours : per-node colour values for the verbose scatter plots (None -> default colour).
    device : torch device; defaults to CUDA when available, else CPU.
    verbose : print per-term losses every epoch and plot every 10 epochs.

    Returns
    -------
    (model, target_graph_index, vals, knn_dists)
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    _a, _b = find_ab_params(spread, min_dist)  # spread, min_dist given as hyperparameters

    vals, knn_dists = _umap_membership_strengths(data, neighbours, local_connectivity)

    target_data = _augmented_target_graph(data, neighbours, method, alpha, beta, power,
                                          feat_drop_rate, edge_mask_rate)
    target_graph_index = target_data.edge_index

    EPS = 1e-29
    print("Epsilon is " + str(EPS))
    print("Hyperparameters a = " + str(_a) + " and b = " + str(_b))

    model = GNN(data.num_features, hid_dim, dim, n_layers=n_layers,
                must_propagate=must_propagate,
                norm=norm)
    model = model.to(device)
    model.apply(init_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    n_nodes = target_data.num_nodes
    sparsity = target_data.num_edges / (n_nodes ** 2 - n_nodes)
    if repulsion_strength is None:
        # Unconnected pairs vastly outnumber connected ones, so up-weight repulsion.
        repulsion_strength = 1.0 / sparsity

    # Positive (target) edges; the self-loop mask is applied when computing the loss.
    row_cpu, col_cpu = target_data.edge_index
    index_cpu = row_cpu != col_cpu
    edge_weights_pos = target_data.edge_weight
    if target is not None:
        # NOTE(review): filtered rows/cols are paired with the *unfiltered* weights
        # here, and the result is masked again below; check fast_intersection's
        # contract before relying on this path.
        edge_weights_pos = fast_intersection(row_cpu[index_cpu], col_cpu[index_cpu], edge_weights_pos,
                                             target, unknown_dist=1.0, far_dist=5.0)

    # Train on the augmented features so that feat_drop_rate actually takes effect.
    x_train = target_data.x.float().to(device)
    edge_index = target_data.edge_index.to(device)
    row_pos, col_pos = edge_index
    index = row_pos != col_pos
    weights_pos = torch.as_tensor(edge_weights_pos, device=device)[index]  # p_ij

    neg_per_edge = 5
    if subsampling is None:
        # Fixed set of negative pairs, drawn once before training.
        row_neg, col_neg = negative_sampling(edge_index, num_nodes=n_nodes,
                                             num_neg_samples=neg_per_edge * edge_index.shape[1])
        index_neg = row_neg != col_neg

    ckpt_path = _checkpoint_path(method, neighbours, dim, name_file)
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    best_t = 0
    cnt_wait = 0
    best = float('inf')
    saved = False
    log_sigmoid = torch.nn.LogSigmoid()
    for epoch in range(epochs):
        tic_epoch = time.time()
        model.train()
        optimizer.zero_grad()
        emb = model(x_train, edge_index)

        # With s = q / (1 + q): log_sigmoid(log_q) = log s and
        # log_sigmoid(log_q) - log_q = log(1 - s), i.e. a weighted cross-entropy.
        log_q = _log_q(emb, row_pos[index], col_pos[index], _a, _b)
        loss_pos = (- torch.mean(weights_pos * log_sigmoid(log_q))
                    - torch.mean((1. - weights_pos) * (log_sigmoid(log_q) - log_q) * repulsion_strength))

        if subsampling is not None:
            row_neg, col_neg = negative_sampling(edge_index, num_nodes=n_nodes,
                                                 num_neg_samples=subsampling)
            index_neg = row_neg != col_neg
        log_q_neg = _log_q(emb, row_neg[index_neg], col_neg[index_neg], _a, _b)
        loss_neg = - torch.mean((log_sigmoid(log_q_neg) - log_q_neg) * repulsion_strength)

        # Decorrelate embedding dimensions: penalise the off-diagonal entries of the
        # covariance (torch.diag_embed of the 2-D matrix itself would be d x d x d).
        cov = torch.mm(emb.T, emb) / emb.shape[0]
        loss_dec1 = (cov - torch.diag_embed(torch.diagonal(cov))).pow(2).sum()
        loss = loss_pos + loss_neg + lambd_corr * loss_dec1
        if verbose:
            print("loss before neg", loss_pos)
            print("loss after neg", loss_neg)
            print("loss corr", lambd_corr * loss_dec1)
            print("loss final", loss)
        loss.backward()
        optimizer.step()

        if verbose and epoch % 10 == 0:
            u = emb.detach().cpu().numpy()
            plt.figure()
            plt.scatter(u[:, 0], u[:, 1], c=colours,
                        cmap="Spectral" if colours is not None else None)
            plt.show()
            print(torch.mm(emb.T, emb) / n_nodes)

        # Linear learning-rate decay towards 0 over `epochs`.
        for g in optimizer.param_groups:
            g['lr'] = lr * (1.0 - (float(epoch) / float(epochs)))

        loss_value = loss.item()  # plain float: does not keep the autograd graph alive
        print('Epoch={:03d}, loss={:.4f}, time={:.4f}'.format(epoch, loss_value, time.time() - tic_epoch))
        if loss_value < best:
            best, best_t, cnt_wait = loss_value, epoch, 0
            torch.save(model.state_dict(), ckpt_path)
            saved = True
        else:
            cnt_wait += 1
        # `>=` so that a patience budget exhausted before min_epochs still stops later.
        if cnt_wait >= patience and epoch > min_epochs:
            print('Early stopping at epoch {}!'.format(epoch))
            break

    if saved:
        print('Loading {}th epoch'.format(best_t))
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    else:
        print('No checkpoint was saved in this run; returning the last model state.')
    return (model, target_graph_index, vals, knn_dists)

## Test with various weight augmentations

### (d) Laplacian diffusion weights (`method='laplacian'`, `alpha=0.5`), N_neighbours = 5

In [None]:
model4, target_index, vals4, knn_dists4 = train_gnumap(new_data,
                                                       target=None, hid_dim=256, dim=2,
                                                       n_layers=1, must_propagate=[True, True, True, True, True],
                                                       method='laplacian', alpha=0.5,
                                                       gnn_type='symmetric', repulsion_strength=10.,
                                                       norm='standardize', neighbours=5,
                                                       beta=1, patience=20, epochs=1000,
                                                       lr=1e-2,
                                                       wd=1e-4,
                                                       lambd_corr=1.,
                                                       min_dist=0.001, subsampling=100000)

# Embed with the reloaded best checkpoint: eval mode, no autograd, and inputs on
# the model's device (train_gnumap defaults to CUDA when it is available).
model4 = model4.eval()
model_device = next(model4.parameters()).device
with torch.no_grad():
    out = model4(new_data.x.float().to(model_device), new_data.edge_index.to(model_device))
u = out.cpu().numpy()

plt.figure()
plt.hist(u)
plt.title('Distribution of embedding coordinates')
plt.show()

plt.figure()
plt.scatter(u[:, 0], u[:, 1], c=t, cmap="Spectral")
plt.title('GNUMAP embedding coloured by t')
plt.show()
print(torch.mm(out.T, out) / new_data.num_nodes)

In [None]:
def focal_loss(gamma, alpha, x, p):
    """Binary focal loss (Lin et al., 2017), summed over all elements.

    FL = sum( -alpha       * (1 - p)^gamma * x       * log(p)
              -(1 - alpha) * p^gamma       * (1 - x) * log(1 - p) )

    Parameters
    ----------
    gamma : focusing exponent; gamma = 0 removes the modulating factor.
    alpha : weight of the positive class (the negative class gets 1 - alpha).
    x : binary targets, tensor broadcastable with ``p``.
    p : predicted probability of the positive class, tensor. No clamping is
        applied, so values exactly 0 or 1 produce inf/nan.

    Returns
    -------
    Scalar tensor.
    """
    pos_term = alpha * (1 - p) ** gamma * x * torch.log(p)
    neg_term = (1 - alpha) * p ** gamma * (1 - x) * torch.log(1 - p)
    return -torch.sum(pos_term + neg_term)

In [None]:
## from repo provided by the author
import torch


def align_loss(x, y, alpha=2):
    """Alignment loss: mean over rows of ||x_i - y_i||_2 ** alpha.

    ``x`` and ``y`` are paired embeddings with the same shape; row i of ``x``
    is compared with row i of ``y``.
    """
    row_dists = torch.norm(x - y, p=2, dim=1)
    return torch.mean(row_dists ** alpha)


def uniform_loss(x, t=2):
    """Uniformity loss: log of the mean Gaussian potential exp(-t * ||x_i - x_j||^2).

    The mean is taken over all unordered pairs i < j of rows of ``x`` (via torch.pdist).
    """
    sq_dists = torch.pdist(x, p=2) ** 2
    gaussian_potential = torch.exp(sq_dists * (-t))
    return torch.log(torch.mean(gaussian_potential))