In [41]:
import numpy as np

size=8000
x=np.random.normal(size=size).astype(np.float32)

%timeit np.sort(x,kind='heapsort')

355 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [40]:
import jax 
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size,))
%timeit jnp.sort(x,kind='heapsort').block_until_ready()

2.97 ms ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [86]:
class abstract_graph:
    """Base class for graphs described by a collection of (u, v) edge pairs.

    Subclasses are expected to override `adjacency_matrix` and `adjacency_list`;
    the base versions raise NotImplementedError instead of silently returning None.
    """

    def __init__(self, _edges):
        self.edges = _edges
        # Every endpoint of every edge is a node.
        self.nodes = {u for u, v in self.edges} | {v for u, v in self.edges}

    def adjacency_matrix(self):
        """Return the adjacency matrix of the graph (implemented by subclasses)."""
        raise NotImplementedError(
            f"{type(self).__name__} does not implement adjacency_matrix()")

    def adjacency_list(self):
        """Return the adjacency list of the graph (implemented by subclasses)."""
        raise NotImplementedError(
            f"{type(self).__name__} does not implement adjacency_list()")

    
class simple_graph(abstract_graph):
    """Undirected graph stored as a symmetric list of directed (u, v) pairs.

    Every input edge (u, v) is stored together with its reverse (v, u). A self-loop
    (u, u) is stored once, and a pair that appears more than once in the input
    (directly or as the mirror of another edge) is stored only once.
    """

    def __init__(self, _edges):
        pairs = []
        seen = set()
        for u, v in _edges:
            # Record the edge and its mirror, skipping pairs already present;
            # this also keeps a self-loop (u, u) from being added twice.
            for pair in ((u, v), (v, u)):
                if pair not in seen:
                    seen.add(pair)
                    pairs.append(pair)
        self.edges = pairs
        # Built from the materialized list so one-shot iterables (e.g. generators)
        # still produce the full node set.
        self.nodes = {u for u, v in pairs} | {v for u, v in pairs}

    def _ordered_nodes(self):
        """Return the nodes in a deterministic order (sorted when possible)."""
        try:
            return sorted(self.nodes)
        except TypeError:
            # Nodes of mutually unorderable types: fall back to a stable textual key.
            return sorted(self.nodes, key=repr)

    def adjacency_matrix(self):
        """Return the n x n 0/1 adjacency matrix as a float NumPy array.

        Row/column i corresponds to the i-th node in sorted order (sorted by `repr`
        when the nodes are not mutually comparable), so the layout is reproducible
        across runs instead of depending on set iteration order.
        """
        order = self._ordered_nodes()
        index = {node: i for i, node in enumerate(order)}
        mat = np.zeros((len(order), len(order)))
        for u, v in self.edges:
            # Only pairs whose endpoints are both known nodes are recorded.
            if u in index and v in index:
                mat[index[u], index[v]] = 1
        return mat

    def adjacency_list(self):
        """Map every node to the set of nodes it has an edge to."""
        adjacency = {node: set() for node in self.nodes}
        for u, v in self.edges:
            if u in adjacency:
                adjacency[u].add(v)
        return adjacency

  
class weighted_graph(simple_graph):
    """Undirected weighted graph built from a {(u, v): weight} mapping.

    Each (u, v) entry is copied into `self.edges`. It is also mirrored as (v, u)
    with the same weight, unless it is a self-loop or the reverse pair was already
    recorded. Adjacency helpers are inherited from simple_graph.
    """

    def __init__(self, _edges):
        symmetric = {}
        for (src, dst), weight in _edges.items():
            symmetric[(src, dst)] = weight
            reverse = (dst, src)
            if reverse in symmetric or src == dst:
                continue
            symmetric[reverse] = weight
        # NOTE(review): if the input contains both (u, v) and (v, u) with different
        # weights, the two directions keep different weights (the mirror is skipped).
        self.edges = symmetric
        self.nodes = {a for a, b in _edges} | {b for a, b in _edges}

1. Algoritmo de Kruskal en JAX

In [84]:
def find(C, u):
    """Return the root of u's set in the disjoint-set parent map C.

    A root is an element that is its own parent. Every element on the path from u
    to the root is re-pointed directly at the root (path compression; C is mutated).
    """
    root = u
    while C[root] != root:
        root = C[root]
    node = u
    while node != root:
        parent = C[node]
        C[node] = root
        node = parent
    return C[u]

def union(C, R, u, v):
    """Merge the disjoint sets containing u and v using union by rank.

    C is the parent map used by `find` and R maps each root to its rank; both are
    mutated in place. If u and v are already in the same set nothing changes.
    """
    root_u, root_v = find(C, u), find(C, v)
    if root_u == root_v:
        # Same set already: linking a root to itself would only inflate its rank.
        return
    if R[root_u] > R[root_v]:
        C[root_v] = root_u
    else:
        C[root_u] = root_v
        # Equal ranks: the new root's tree grew one level taller.
        if R[root_u] == R[root_v]:
            R[root_v] += 1
        
def jax_kruskal(G):
    """Return the edges of a minimum spanning forest of G using Kruskal's algorithm.

    G is expected to expose `edges` as a {(u, v): weight} mapping and `nodes` as the
    set of all endpoints, as `weighted_graph` does. Edges are visited in ascending
    weight order, computed with `jnp.argsort` (stable by default, so ties keep the
    mapping's insertion order). An edge is kept when its endpoints are in different
    components (tracked with find/union).

    Returns:
        list of (u, v) tuples, in the order they were accepted.
    """
    edge_keys = list(G.edges.keys())
    # Unless jax_enable_x64 is set, JAX stores these weights as 32-bit values.
    weights = jnp.array(list(G.edges.values()))
    # A single device->host transfer of the permutation, instead of a second
    # full sort plus per-element .item() calls.
    order = jnp.argsort(weights).tolist()
    C = {u: u for u in G.nodes}
    R = {u: 0 for u in G.nodes}
    T = []
    spanning_size = len(G.nodes) - 1
    for idx in order:
        u, v = edge_keys[idx]
        if find(C, u) != find(C, v):
            T.append((u, v))
            union(C, R, u, v)
            if len(T) == spanning_size:
                # The tree spans every node: any remaining edge would close a cycle.
                break
    return T

In [87]:
# Undirected weighted example graph given as {(u, v): weight}.
E = {
    ('a', 'b'): 4,
    ('b', 'c'): 7,
    ('b', 'd'): 8,
    ('c', 'd'): 10,
    ('a', 'c'): 9,
}
G = weighted_graph(E)
# weighted_graph stores each edge in both directions.
print('aristas : ', G.edges)

aristas :  {('a', 'b'): 4, ('b', 'a'): 4, ('b', 'c'): 7, ('c', 'b'): 7, ('b', 'd'): 8, ('d', 'b'): 8, ('c', 'd'): 10, ('d', 'c'): 10, ('a', 'c'): 9, ('c', 'a'): 9}


In [89]:
# Minimum spanning tree of the example graph, as a list of (u, v) edges.
T = jax_kruskal(G=G)

In [90]:
# Display the list of spanning-tree edges returned by jax_kruskal.
T

[('a', 'b'), ('b', 'c'), ('b', 'd')]