In [1]:
import numpy as np

size=8000
x=np.random.normal(size=size).astype(np.float32)

%timeit np.sort(x,kind='heapsort')

812 µs ± 27.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [2]:
import jax 
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size,))
%timeit jnp.sort(x,kind='heapsort').block_until_ready()

2023-06-09 12:37:13.107955: W external/xla/xla/service/gpu/nvptx_compiler.cc:564] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.1.105). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


252 µs ± 60.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [3]:
x.device()

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

In [4]:
class abstract_graph:
    
    def __init__(self,_edges):
        self.edges=_edges
        self.nodes={u for u,v in self.edges} | {v for u,v in self.edges}
        
    def adjacency_matrix(self):
        pass
    
    def adjacency_list(self):
        pass

    
class simple_graph(abstract_graph):
    
    def __init__(self,_edges):
        tmp=[]
        for (u,v) in _edges:
            tmp.append((u,v))
            if (v,u) not in tmp and v!=u:
                tmp.append((v,u))
        self.edges=tmp
        self.nodes={u for u,v in _edges} | {v for u,v in _edges}
     
    def adjacency_matrix(self):
        # completar
        n=len(self.nodes)
        mat=jnp.zeros((n,n))
        for i,v in enumerate(self.nodes):
            for j,k in enumerate(self.nodes):
                if (v,k) in self.edges:
                    mat=mat.at[i,j].set(1)
        return mat
    
    
    def adjacency_list(self):
        adjacent=lambda n : {v for u,v in self.edges if u==n } 
        return {v:adjacent(v) for v in self.nodes}

  
class weighted_graph(simple_graph):
    
    def __init__(self,_edges):
        tmp=dict()
        for (u,v),w in _edges.items():
            tmp.update({(u,v):w})
            if (v,u) not in tmp.keys() and v!=u:
                tmp.update({(v,u):w})
        self.edges=tmp
        self.nodes={u for u,v in _edges} | {v for u,v in _edges}

    def adjacency_matrix(self):
        # completar
        n=len(self.nodes)
        mat=jnp.zeros((n,n))
        for i,v in enumerate(self.nodes):
            for j,k in enumerate(self.nodes):
                if (v,k) in self.edges:
                    mat=mat.at[i,j].set(self.edges(v,k))
        return mat

# Algoritmo Kruskal en JAX

In [54]:
def naive_find(C, u):
    while C.at[u].get() != u:
        u = C.at[u].get().item()
    return u

def naive_union(C, u, v):
    u = naive_find(C, u)
    v = naive_find(C, v)
    return C.at[u].set(v)
        
def jax_kruskal(G):
    E = G.edges
    T=[]
    values=jnp.array([*E.values()])
    items=[*E.keys()]
    C= jnp.arange(len(G.nodes))
    i_sort=jnp.argsort(values)
    for ix in i_sort:
        (u,v),weight=items[ix.item()],values[ix.item()].item()
        if jitted_naive_find(C, u) != jitted_naive_find(C, v): 
            T.append((u,v,{'weight':weight}))
            C=naive_union(C, u, v)
    return T

In [65]:
from jax.lax import while_loop

#def while_loop(cond_fun, body_fun, init_val)

def cond_fun(C,u): 
    return C.at[u].get() != u

def body_fun(C,u):
    return C.at[u].get().item()

def jitted_naive_find(C, u):
    while cond_fun(C,u):
        u = jax.jit(body_fun,static_argnames=['C'])(C,u)
    return u

C=jnp.arange(10)
naive_find(C,3)

3

In [40]:
key = jax.random.PRNGKey(0)
C=jax.random.permutation(key,jnp.arange(100))
C=C.at[1].set(1)

In [46]:
C[7]

Array(19, dtype=int32)

In [52]:
%timeit naive_find(C,3)

874 µs ± 18.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [53]:
%timeit jitted_naive_find(C,3)

868 µs ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [63]:
E={(1,2):4,(2,3):7,(3,4):8,(3,5):10,(4,5):9}
G=weighted_graph(E)
print('aristas : ',G.edges)

aristas :  {(1, 2): 4, (2, 1): 4, (2, 3): 7, (3, 2): 7, (3, 4): 8, (4, 3): 8, (3, 5): 10, (5, 3): 10, (4, 5): 9, (5, 4): 9}


In [66]:
T=jax_kruskal(G)

ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'body_fun' while trying to hash an object of type <class 'jaxlib.xla_extension.ArrayImpl'>, [0 2 2 3 4]. The error was:
TypeError: unhashable type: 'ArrayImpl'


In [139]:
T

[(1, 2, {'weight': 4}), (2, 3, {'weight': 7}), (3, 4, {'weight': 8})]

In [140]:
import networkx as nx

G_nx=nx.Graph()
G_nx.add_weighted_edges_from([(u,v,w) for ((u,v),w) in E.items()])
T_nx=nx.minimum_spanning_tree(G_nx,algorithm='kruskal')

In [141]:
T_nx.edges(data=True)

EdgeDataView([(1, 2, {'weight': 4}), (2, 3, {'weight': 7}), (3, 4, {'weight': 8}), (4, 5, {'weight': 9})])

# Algoritmo Prim en JAX

In [None]:
def jax_prim(graph,start):
    nodes=list(G.nodes)
    parents={n:None for n in nodes}
    tree=[]
    frontier=jnp.array([jnp.inf]*len(nodes))
    pos=nodes.index(start)
    frontier=frontier.at[pos].set(0.0)
    adjacency=graph.adjacency_list()
    visited=set()
    for _ in range(len(nodes)+1):
        pos=jnp.argmin(frontier).item()
        weight=frontier.at[pos].get().item()
        v=nodes[pos]
        if v in visited:
            continue
        visited.update({v})
        frontier=frontier.at[pos].set(jnp.inf)
        u=parents[v]
        tree.append((u,v,{'weight':weight}))
        for neighbor in adjacency[v]:
            if neighbor not in visited:
                n_pos=nodes.index(neighbor)
                n_weight=frontier.at[n_pos].get().item()
                if graph.edges[(v,neighbor)]<n_weight:
                  parents[neighbor]=v
                  frontier=frontier.at[n_pos].set(graph.edges[(v,neighbor)])
    return tree



In [142]:
jax_prim(G,1)

[(None, 1, {'weight': 0.0}),
 (1, 2, {'weight': 4.0}),
 (2, 3, {'weight': 7.0}),
 (3, 4, {'weight': 8.0}),
 (4, 5, {'weight': 9.0})]

# Algoritmo Kruskal en Grafos Aleatorios

In [56]:
import networkx as nx

def gen_random_graph(n,p):
    not_connected=True
    while not_connected:
        G_nx = nx.erdos_renyi_graph(int(n),p,directed=False)
        not_connected=nx.is_connected(G_nx)
        weights={edge:np.random.randint(1,10) for edge in G_nx.edges}
        nx.set_edge_attributes(G_nx, values = weights, name = 'weight')
        break
    return G_nx

In [57]:
G_nx=gen_random_graph(1e2,0.2)

In [58]:
E={(u,v):k['weight'] for (u,v,k) in G_nx.edges(data=True)}
G=weighted_graph(E)

In [59]:
len(E)

1044

In [60]:
import time 

t1=time.time()
T_nx_kruskal=nx.minimum_spanning_tree(G_nx,algorithm='kruskal')
t2=time.time()
print('Algoritmo Kruskal NX Tiempo : {0:2f}[s], Peso : {1}'.format(t2-t1,sum([k['weight'] for (u,v,k) in T_nx_kruskal.edges(data=True)])))

Algoritmo Kruskal NX Tiempo : 0.006311[s], Peso : 113


In [61]:
t1=time.time()
T_jax_kruskal=jax_kruskal(G)
t2=time.time()
print('Algoritmo Kruskal JAX Tiempo : {0:2f}[s], Peso : {1}'.format(t2-t1,sum([k['weight'] for (u,v,k) in T_jax_kruskal])))

TypeError: Expected a callable value, got 83

In [155]:
import scipy as sp
from scipy.sparse.csgraph import minimum_spanning_tree

M = nx.adjacency_matrix(G_nx)
t1=time.time()
T_sp=minimum_spanning_tree(M)
t2=time.time()
print('Algoritmo Kruskal SciPy Tiempo : {0:2f}[s], Peso : {1}'.format(t2-t1,np.sum(T_sp)))

Algoritmo Kruskal SciPy Tiempo : 0.000951[s], Peso : 108.0
