In [1]:
import numpy as np

size=8000
x=np.random.normal(size=size).astype(np.float32)

%timeit np.sort(x,kind='heapsort')

778 µs ± 23.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [2]:
import jax 
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size,))
%timeit jnp.sort(x,kind='heapsort').block_until_ready()

2023-06-05 10:38:36.609534: W external/xla/xla/service/gpu/nvptx_compiler.cc:564] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.1.105). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


172 µs ± 23.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [3]:
x.device()

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

In [4]:
class abstract_graph:
    
    def __init__(self,_edges):
        self.edges=_edges
        self.nodes={u for u,v in self.edges} | {v for u,v in self.edges}
        
    def adjacency_matrix(self):
        pass
    
    def adjacency_list(self):
        pass

    
class simple_graph(abstract_graph):
    
    def __init__(self,_edges):
        tmp=[]
        for (u,v) in _edges:
            tmp.append((u,v))
            if (v,u) not in tmp and v!=u:
                tmp.append((v,u))
        self.edges=tmp
        self.nodes={u for u,v in _edges} | {v for u,v in _edges}
     
    def adjacency_matrix(self):
        # completar
        n=len(self.nodes)
        mat=jnp.zeros((n,n))
        for i,v in enumerate(self.nodes):
            for j,k in enumerate(self.nodes):
                if (v,k) in self.edges:
                    mat.at[i,j].set(1)
        return mat
    
    
    def adjacency_list(self):
        adjacent=lambda n : {v for u,v in self.edges if u==n } 
        return {v:adjacent(v) for v in self.nodes}

  
class weighted_graph(simple_graph):
    
    def __init__(self,_edges):
        tmp=dict()
        for (u,v),w in _edges.items():
            tmp.update({(u,v):w})
            if (v,u) not in tmp.keys() and v!=u:
                tmp.update({(v,u):w})
        self.edges=tmp
        self.nodes={u for u,v in _edges} | {v for u,v in _edges}

In [49]:

from functools import partial
from jax.tree_util import tree_structure

#@partial(jax.jit, static_argnames=['u'])
def naive_find(C, u):
    while C[u] != u:
        u = C[u]
    return u

def naive_union(C, u, v):
    u = naive_find(C, u)
    v = naive_find(C, v)
    C[u] = v
        
def jax_kruskal(G):
    E = G.edges
    T=[]
    values=jnp.array([*E.values()])
    items=[*E.keys()]
    C,R = tree_structure({u:u for u in G.nodes}), {u:0 for u in G.nodes}
    sorted_edges=[(items[c.item()],w.item()) for c,w in zip(jnp.argsort(values),jnp.sort(values))] 
    for (u,v),weight in sorted_edges:
        if naive_find(C, u) != naive_find(C, v): 
            T.append((u,v,{'weight':weight}))
            naive_union(C, u, v)
    return T

# Algoritmo Kruskal en JAX

In [50]:
E={('a','b'):4,('b','c'):7,('b','d'):8,('c','d'):10,('a','c'):9}
G=weighted_graph(E)
print('aristas : ',G.edges)

aristas :  {('a', 'b'): 4, ('b', 'a'): 4, ('b', 'c'): 7, ('c', 'b'): 7, ('b', 'd'): 8, ('d', 'b'): 8, ('c', 'd'): 10, ('d', 'c'): 10, ('a', 'c'): 9, ('c', 'a'): 9}


In [51]:
T=jax_kruskal(G)

TypeError: Cannot interpret value of type <class 'jaxlib.xla_extension.pytree.PyTreeDef'> as an abstract array; it does not have a dtype attribute

In [46]:
T

[('a', 'b', {'weight': 4}),
 ('b', 'c', {'weight': 7}),
 ('b', 'd', {'weight': 8})]

In [10]:
type(G.adjacency_matrix())

jaxlib.xla_extension.ArrayImpl

In [12]:
import networkx as nx

G_nx=nx.Graph()
G_nx.add_weighted_edges_from([(u,v,w) for ((u,v),w) in E.items()])
T_nx=nx.minimum_spanning_tree(G_nx,algorithm='kruskal')

In [13]:
T_nx.edges()

EdgeView([('a', 'b'), ('b', 'c'), ('b', 'd')])

# Algoritmo Prim en JAX

In [14]:
def jax_prim(graph,start):
    nodes=list(G.nodes)
    parents={n:None for n in nodes}
    tree=[]
    frontier=jnp.array([jnp.inf]*len(nodes))
    pos=nodes.index(start)
    frontier=frontier.at[pos].set(0.0)
    adjacency=graph.adjacency_list()
    visited=set()
    for _ in range(len(nodes)+1):
        pos=jnp.argmin(frontier).item()
        weight=frontier.at[pos].get().item()
        v=nodes[pos]
        if v in visited:
            continue
        visited.update({v})
        frontier=frontier.at[pos].set(jnp.inf)
        u=parents[v]
        tree.append((u,v,{'weight':weight}))
        for neighbor in adjacency[v]:
            if neighbor not in visited:
                n_pos=nodes.index(neighbor)
                n_weight=frontier.at[n_pos].get().item()
                if graph.edges[(v,neighbor)]<n_weight:
                  parents[neighbor]=v
                  frontier=frontier.at[n_pos].set(graph.edges[(v,neighbor)])
    return tree



In [15]:
jax_prim(G,'a')

[(None, 'a', {'weight': 0.0}),
 ('a', 'b', {'weight': 4.0}),
 ('b', 'c', {'weight': 7.0}),
 ('b', 'd', {'weight': 8.0})]

In [16]:
import networkx as nx

def gen_random_graph(n,p):
    not_connected=True
    while not_connected:
        G_nx = nx.erdos_renyi_graph(int(n),p,directed=False)
        not_connected=nx.is_connected(G_nx)
        weights={edge:np.random.randint(1,10) for edge in G_nx.edges}
        nx.set_edge_attributes(G_nx, values = weights, name = 'weight')
        break
    return G_nx

In [17]:
G_nx=gen_random_graph(1e2,0.8)

In [18]:
E={(u,v):k['weight'] for (u,v,k) in G_nx.edges(data=True)}
G=weighted_graph(E)

In [19]:
%timeit T_nx_kruskal=nx.minimum_spanning_tree(G_nx,algorithm='kruskal')

8.25 ms ± 497 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
%timeit T_jax_kruskal=jax_kruskal(G)

1.63 s ± 61.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
T_nx_kruskal=nx.minimum_spanning_tree(G_nx,algorithm='kruskal')

In [35]:
sum([k['weight'] for (u,v,k) in T_nx_kruskal.edges(data=True)])

99

In [26]:
T_jax_kruskal=jax_kruskal(G)

In [31]:
sum([k['weight'] for (u,v,k) in T_jax_kruskal])

99