In [20]:
import numpy as np

size=8000
x=np.random.normal(size=size).astype(np.float32)

%timeit np.sort(x,kind='heapsort')

802 µs ± 7.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [21]:
import jax 
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size,))
%timeit jnp.sort(x,kind='heapsort').block_until_ready()

130 µs ± 2.84 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [22]:
x.device()

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

In [23]:
class abstract_graph:
    
    def __init__(self,_edges):
        self.edges=_edges
        self.nodes={u for u,v in self.edges} | {v for u,v in self.edges}
        
    def adjacency_matrix(self):
        pass
    
    def adjacency_list(self):
        pass

    
class simple_graph(abstract_graph):
    
    def __init__(self,_edges):
        tmp=[]
        for (u,v) in _edges:
            tmp.append((u,v))
            if (v,u) not in tmp and v!=u:
                tmp.append((v,u))
        self.edges=tmp
        self.nodes={u for u,v in _edges} | {v for u,v in _edges}
     
    def adjacency_matrix(self):
        # completar
        n=len(self.nodes)
        mat=jnp.zeros((n,n))
        for i,v in enumerate(self.nodes):
            for j,k in enumerate(self.nodes):
                if (v,k) in self.edges:
                    mat=mat.at[i,j].set(1)
        return mat
    
    
    def adjacency_list(self):
        adjacent=lambda n : {v for u,v in self.edges if u==n } 
        return {v:adjacent(v) for v in self.nodes}

  
class weighted_graph(simple_graph):
    
    def __init__(self,_edges):
        tmp=dict()
        for (u,v),w in _edges.items():
            tmp.update({(u,v):w})
            if (v,u) not in tmp.keys() and v!=u:
                tmp.update({(v,u):w})
        self.edges=tmp
        self.nodes={u for u,v in _edges} | {v for u,v in _edges}

    def adjacency_matrix(self):
        # completar
        n=len(self.nodes)
        mat=jnp.zeros((n,n))
        for i,v in enumerate(self.nodes):
            for j,k in enumerate(self.nodes):
                if (v,k) in self.edges:
                    mat=mat.at[i,j].set(self.edges[(v,k)])
        return mat

# Algoritmo Kruskal en JAX

In [24]:
def naive_find(C, u):
    while C.at[u].get() != u:
        u = C.at[u].get().item()
    return u

def naive_union(C, u, v):
    u = naive_find(C, u)
    v = naive_find(C, v)
    return C.at[u].set(v)
        
def jax_kruskal(G):
    E = G.edges
    T=[]
    values=jnp.array([*E.values()])
    items=[*E.keys()]
    C= jnp.arange(len(G.nodes))
    i_sort=jnp.argsort(values)
    for ix in i_sort:
        (u,v),weight=items[ix.item()],values[ix.item()].item()
        if naive_find(C, u) != naive_find(C, v): 
            T.append((u,v,{'weight':weight}))
            C=naive_union(C, u, v)
    return T

In [167]:
E={(1,2):4,(2,3):7,(3,4):8,(3,5):10,(4,5):9}
G=weighted_graph(E)
print('aristas : ',G.edges)

aristas :  {(1, 2): 4, (2, 1): 4, (2, 3): 7, (3, 2): 7, (3, 4): 8, (4, 3): 8, (3, 5): 10, (5, 3): 10, (4, 5): 9, (5, 4): 9}


In [168]:
import networkx as nx

G_nx=nx.Graph()
G_nx.add_weighted_edges_from([(u,v,w) for ((u,v),w) in E.items()])
T_nx=nx.minimum_spanning_tree(G_nx,algorithm='kruskal')

In [169]:
T_nx.edges(data=True)

EdgeDataView([(1, 2, {'weight': 4}), (2, 3, {'weight': 7}), (3, 4, {'weight': 8}), (4, 5, {'weight': 9})])

# Algoritmo Prim en JAX

In [155]:
def min_index(visited,frontier):
    min_val=jnp.inf
    min_index=0
    for index in range(len(frontier)):
        if frontier[index].item()<min_val and not visited[index]:
            min_index=index
            min_val=frontier[index].item()
    return min_val,min_index

def jax_prim(mat,start):
    nodes=mat.shape[0]
    tree=[]
    frontier=jnp.array([jnp.inf]*nodes)
    visited=jnp.array([False]*nodes)
    frontier=frontier.at[start].set(0.0)
    u=start
    for i in range(nodes):
        weight,v=min_index(visited,frontier)
        visited=visited.at[v].set(True)
        frontier=jnp.where(frontier<mat[v,:],frontier,jnp.where(mat[v,:]==0,frontier,mat[v,:]))
        tree.append((u,v,{'weight':weight}))
        u=v
    return tree



In [170]:
jax_prim(G.adjacency_matrix(),0)

[(0, 0, {'weight': 0.0}),
 (0, 1, {'weight': 4.0}),
 (1, 2, {'weight': 7.0}),
 (2, 3, {'weight': 8.0}),
 (3, 4, {'weight': 9.0})]

In [172]:
nodes=len(G.nodes)
frontier=jnp.array([jnp.inf]*nodes)
visited=jnp.array([False]*nodes)
frontier=frontier.at[2].set(0.0)

In [178]:
jnp.argmin(frontier)

Array(2, dtype=int32)

In [173]:
min_index(visited,frontier)

(0.0, 2)

# Algoritmo Kruskal en Grafos Aleatorios

In [135]:
import networkx as nx

def gen_random_graph(n,p):
    not_connected=True
    while not_connected:
        G_nx = nx.erdos_renyi_graph(int(n),p,directed=False)
        not_connected=nx.is_connected(G_nx)
        weights={edge:np.random.randint(1,10) for edge in G_nx.edges}
        nx.set_edge_attributes(G_nx, values = weights, name = 'weight')
        break
    return G_nx

In [157]:
G_nx=gen_random_graph(1e2,0.2)

In [158]:
E={(u,v):k['weight'] for (u,v,k) in G_nx.edges(data=True)}
G=weighted_graph(E)

In [159]:
len(E)

1016

In [160]:
import time 

t1=time.time()
T_nx_kruskal=nx.minimum_spanning_tree(G_nx,algorithm='kruskal')
t2=time.time()
print('Algoritmo Kruskal NX Tiempo : {0:2f}[s], Peso : {1}'.format(t2-t1,sum([k['weight'] for (u,v,k) in T_nx_kruskal.edges(data=True)])))

Algoritmo Kruskal NX Tiempo : 0.003486[s], Peso : 116


In [162]:
t1=time.time()
T_jax_kruskal=jax_kruskal(G)
t2=time.time()
print('Algoritmo Kruskal JAX Tiempo : {0:2f}[s], Peso : {1}'.format(t2-t1,sum([k['weight'] for (u,v,k) in T_jax_kruskal])))

Algoritmo Kruskal JAX Tiempo : 160.381301[s], Peso : 116


In [161]:
t1=time.time()
T_jax_prim=jax_prim(G.adjacency_matrix(),0)
t2=time.time()
print('Algoritmo Prim JAX Tiempo : {0:2f}[s], Peso : {1}'.format(t2-t1,sum([k['weight'] for (u,v,k) in T_jax_prim])))

Algoritmo Prim JAX Tiempo : 15.937573[s], Peso : 116.0


In [143]:
import scipy as sp
from scipy.sparse.csgraph import minimum_spanning_tree

M = nx.adjacency_matrix(G_nx)
t1=time.time()
T_sp=minimum_spanning_tree(M)
t2=time.time()
print('Algoritmo Kruskal SciPy Tiempo : {0:2f}[s], Peso : {1}'.format(t2-t1,np.sum(T_sp)))

Algoritmo Kruskal SciPy Tiempo : 0.000527[s], Peso : 117.0
