In [1]:
import numpy as np

size=8000
x=np.random.normal(size=size).astype(np.float32)

%timeit np.sort(x,kind='heapsort')

376 µs ± 5.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [4]:
import jax 
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size,))
%timeit jnp.sort(x,kind='heapsort').block_until_ready()

2.96 ms ± 13.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
x.device()

CpuDevice(id=0)

In [7]:
class abstract_graph:
    
    def __init__(self,_edges):
        self.edges=_edges
        self.nodes={u for u,v in self.edges} | {v for u,v in self.edges}
        
    def adjacency_matrix(self):
        pass
    
    def adjacency_list(self):
        pass

    
class simple_graph(abstract_graph):
    
    def __init__(self,_edges):
        tmp=[]
        for (u,v) in _edges:
            tmp.append((u,v))
            if (v,u) not in tmp and v!=u:
                tmp.append((v,u))
        self.edges=tmp
        self.nodes={u for u,v in _edges} | {v for u,v in _edges}
     
    def adjacency_matrix(self):
        # completar
        n=len(self.nodes)
        mat=np.zeros((n,n))
        for i,v in enumerate(self.nodes):
            for j,k in enumerate(self.nodes):
                if (v,k) in self.edges:
                    mat[i,j]=1
        return mat
    
    
    def adjacency_list(self):
        adjacent=lambda n : {v for u,v in self.edges if u==n } 
        return {v:adjacent(v) for v in self.nodes}

  
class weighted_graph(simple_graph):
    
    def __init__(self,_edges):
        tmp=dict()
        for (u,v),w in _edges.items():
            tmp.update({(u,v):w})
            if (v,u) not in tmp.keys() and v!=u:
                tmp.update({(v,u):w})
        self.edges=tmp
        self.nodes={u for u,v in _edges} | {v for u,v in _edges}

In [8]:
def find(C, u): 
    if C[u] != u:
        C[u] = find(C, C[u])
    return C[u]

def union(C, R, u, v):
    u, v = find(C, u), find(C, v)
    if R[u] > R[v]:
        C[v] = u 
    else:
        C[u] = v
    if R[u] == R[v]:
        R[v] += 1
        
def jax_kruskal(G):
    E = G.edges
    T=[]
    values=jnp.array([*E.values()])
    items=[*E.keys()]
    C,R = {u:u for u in G.nodes}, {u:0 for u in G.nodes}
    sorted_edges=[(items[c.item()],w.item()) for c,w in zip(jnp.argsort(values),jnp.sort(values))] 
    for (u,v),weight in sorted_edges:
        if find(C, u) != find(C, v): 
            T.append((u,v,{'weight':weight}))
            union(C, R, u, v)
    return T

# Algoritmo Kruskal en JAX

In [9]:
E={('a','b'):4,('b','c'):7,('b','d'):8,('c','d'):10,('a','c'):9}
G=weighted_graph(E)
print('aristas : ',G.edges)

aristas :  {('a', 'b'): 4, ('b', 'a'): 4, ('b', 'c'): 7, ('c', 'b'): 7, ('b', 'd'): 8, ('d', 'b'): 8, ('c', 'd'): 10, ('d', 'c'): 10, ('a', 'c'): 9, ('c', 'a'): 9}


In [10]:
T=jax_kruskal(G)

In [11]:
T

[('a', 'b', {'weight': 4}),
 ('b', 'c', {'weight': 7}),
 ('b', 'd', {'weight': 8})]

In [13]:
import networkx as nx

G_nx=nx.Graph()
G_nx.add_weighted_edges_from([(u,v,w) for ((u,v),w) in E.items()])
T_nx=nx.minimum_spanning_tree(G_nx,algorithm='kruskal')

In [14]:
T_nx.edges()

EdgeView([('a', 'b'), ('b', 'c'), ('b', 'd')])

# Algoritmo Prim en JAX

In [49]:
def jax_prim(graph,start):
    nodes=list(G.nodes)
    parents={n:None for n in nodes}
    tree=[]
    frontier=jnp.array([jnp.inf]*len(nodes))
    pos=nodes.index(start)
    frontier=frontier.at[pos].set(0.0)
    adjacency=graph.adjacency_list()
    visited=set()
    for _ in range(len(nodes)+1):
        pos=jnp.argmin(frontier[jnp.where(frontier >= 0)]).item()
        weight=frontier.at[pos].get().item()
        v=nodes[pos]
        print(v,visited,frontier)
        if v in visited:
            continue
        visited.update({v})
        frontier=frontier.at[pos].set(-jnp.inf)
        u=parents[v]
        tree.append((u,v,{'weight':weight}))
        for neighbor in adjacency[v]:
            print(parents)
            print(pos,neighbor)
            if neighbor not in visited:
                n_pos=nodes.index(neighbor)
                n_weight=frontier.at[n_pos].get().item()
                if graph.edges[(v,neighbor)]<n_weight:
                  parents[neighbor]=v
                  frontier=frontier.at[n_pos].set(graph.edges[(v,neighbor)])
    return tree



In [35]:
frontier=jnp.array([jnp.inf]*10)
frontier=frontier.at[3].set(0.0)

In [36]:
frontier

Array([inf, inf, inf,  0., inf, inf, inf, inf, inf, inf], dtype=float32)

In [37]:
jnp.argmin(frontier[jnp.where(frontier>0)]).item()

0

In [38]:
frontier=frontier.at[5].set(3.0)

In [32]:
G.edges

{('a', 'b'): 4,
 ('b', 'a'): 4,
 ('b', 'c'): 7,
 ('c', 'b'): 7,
 ('b', 'd'): 8,
 ('d', 'b'): 8,
 ('c', 'd'): 10,
 ('d', 'c'): 10,
 ('a', 'c'): 9,
 ('c', 'a'): 9}

In [21]:
G.nodes

{'a', 'b', 'c', 'd'}

In [23]:
G.adjacency_list()

{'d': {'b', 'c'}, 'b': {'a', 'c', 'd'}, 'c': {'a', 'b', 'd'}, 'a': {'b', 'c'}}

In [50]:
jax_prim(G,'a')

a set()
{'c': None, 'b': None, 'd': None, 'a': None}
3 c
{'c': 'a', 'b': None, 'd': None, 'a': None}
3 b
b {'a'}
{'c': 'a', 'b': 'a', 'd': None, 'a': None}
1 c
{'c': 'b', 'b': 'a', 'd': None, 'a': None}
1 a
{'c': 'b', 'b': 'a', 'd': None, 'a': None}
1 d
c {'b', 'a'}
{'c': 'b', 'b': 'a', 'd': 'b', 'a': None}
0 b
{'c': 'b', 'b': 'a', 'd': 'b', 'a': None}
0 a
{'c': 'b', 'b': 'a', 'd': 'b', 'a': None}
0 d
c {'c', 'b', 'a'}
c {'c', 'b', 'a'}


[(None, 'a', {'weight': 0.0}),
 ('a', 'b', {'weight': 4.0}),
 ('b', 'c', {'weight': 7.0})]