In [1]:
%load_ext Cython

In [25]:
%%cython -a
#cython: boundscheck=False, wraparound=False, nonecheck=False

from libc.math cimport exp 
import numpy as np

def rbf_networkC(double[:, :] X,  double[:] beta, double theta):
    '''Evaluate a Gaussian RBF network at every one of its own centres.

    For each row i of X this computes

        Y[i] = sum_j beta[j] * exp(-(theta * ||X[j] - X[i]||)**2)

    where j runs over all N rows (so the j == i term contributes beta[i]).

    Parameters
    ----------
    X : float64 buffer of shape (N, D); row i is both a centre and a query point.
    beta : float64 buffer of length N, one weight per centre.
    theta : inverse length scale of the Gaussian kernel.

    Returns
    -------
    numpy.ndarray of shape (N,) with the network output at each row of X.

    Raises
    ------
    ValueError
        If beta does not have exactly N entries. Bounds checking is disabled
        for this module, so without this check a short beta would be read
        past the end of its buffer.
    '''
    cdef Py_ssize_t N = X.shape[0]
    cdef Py_ssize_t D = X.shape[1]
    cdef Py_ssize_t i, j, d
    cdef double diff, r2, acc
    # (r * theta)**2 == r**2 * theta**2, so the sqrt of the squared distance
    # followed by squaring it again is unnecessary.
    cdef double theta2 = theta * theta
    cdef double[:] Y

    if beta.shape[0] != N:
        raise ValueError(
            "beta must have one entry per row of X (got %d, expected %d)"
            % (beta.shape[0], N))

    # Keep a handle on the ndarray so callers receive a NumPy array rather
    # than a bare Cython memoryview.
    Y_arr = np.zeros(N)
    Y = Y_arr

    for i in range(N):
        acc = 0.0
        for j in range(N):
            r2 = 0.0
            for d in range(D):
                diff = X[j, d] - X[i, d]
                r2 += diff * diff
            acc += beta[j] * exp(-r2 * theta2)
        Y[i] = acc

    return Y_arr

In [24]:
%timeit rbf_networkC(X, beta, theta)

92.1 ms ± 442 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [57]:
%%cython -a

cdef class GraphS:
    '''Undirected graph stored as nested dicts, with an int type per vertex.

    graph[v] maps every neighbour of v to True (adjacency is kept symmetric);
    ty[v] is the type of v (0 for freshly added vertices). Vertex ids are
    handed out sequentially starting at vindex. nedges counts distinct edges.
    '''
    cdef int vindex, nedges
    cdef dict graph, ty

    def __init__(self):
        self.graph = dict()
        self.ty = dict()
        self.vindex = 0
        self.nedges = 0

    cdef add_vertices(self, int amount):
        '''Adds `amount` new isolated vertices of type 0 with consecutive ids.'''
        cdef int i
        if amount < 0:
            # A negative amount would rewind vindex, and later calls would then
            # overwrite existing vertices (and their adjacency) with empty dicts.
            raise ValueError("amount must be non-negative")
        for i in range(self.vindex, self.vindex + amount):
            self.graph[i] = dict()
            self.ty[i] = 0
        self.vindex += amount

    cdef add_vertex(self):
        '''Adds a single isolated vertex of type 0.'''
        self.add_vertices(1)

    cdef add_edges(self, list edges):
        '''Adds each (s, t) pair as an undirected edge.

        An edge that is already present is left as is and is not counted
        again, so nedges stays equal to the number of distinct edges.
        Raises KeyError if an endpoint is not a vertex.
        '''
        cdef int s, t
        cdef dict adj_s, adj_t
        for s, t in edges:
            # Look up both endpoints first so a missing vertex cannot leave a
            # half-inserted edge behind.
            adj_s = self.graph[s]
            adj_t = self.graph[t]
            if t not in adj_s:
                self.nedges += 1
                adj_s[t] = True
                adj_t[s] = True

    cdef add_edge(self, tuple edge):
        '''Adds a single undirected edge given as an (s, t) tuple.'''
        self.add_edges([edge])

    cdef remove_vertices(self, list vertices):
        '''Removes every vertex in `vertices` together with all its edges.

        Raises KeyError if a vertex is not present.
        '''
        cdef int v, v1
        cdef dict adj
        for v in vertices:
            adj = self.graph.pop(v)
            for v1 in adj:
                self.nedges -= 1
                # A self-loop entry lives only in `adj`, which is already gone;
                # deleting it a second time would raise KeyError.
                if v1 != v:
                    del self.graph[v1][v]
            del self.ty[v]

    cdef remove_vertex(self, int vertex):
        '''Removes a single vertex and all its edges.'''
        self.remove_vertices([vertex])

    cdef remove_solo_vertices(self):
        '''Deletes all vertices that are not connected to any other vertex.
        Should be replaced by a faster alternative if available in the backend'''
        self.remove_vertices([v for v in self.vertices() if self.get_vertex_degree(v)==0])

    cdef remove_edges(self, list edges):
        '''Removes each (s, t) undirected edge.

        Raises KeyError if an edge is absent; nedges is only decremented after
        the edge has actually been removed, so the count stays consistent.
        '''
        cdef int s, t
        for s, t in edges:
            del self.graph[s][t]
            if s != t:
                del self.graph[t][s]
            self.nedges -= 1

    cdef remove_edge(self, tuple edge):
        '''Removes a single undirected edge given as an (s, t) tuple.'''
        self.remove_edges([edge])

    cdef int num_vertices(self) except -1:
        '''Number of vertices currently in the graph.'''
        return len(self.graph)

    cdef int num_edges(self):
        '''Number of distinct edges currently in the graph.'''
        return self.nedges

    cdef vertices(self):
        '''Live view of the vertex ids (dict keys view of the adjacency map).'''
        return self.graph.keys()

    cdef list edges(self):
        '''List of (v0, v1) edges with v0 < v1; self-loops are not listed.'''
        cdef int v0, v1
        cdef dict adj
        cdef list output = []
        for v0,adj in self.graph.items():
            for v1 in adj:
                if v1 > v0: output.append((v0,v1))
        return output

    cdef set edge_set(self):
        '''Returns a set of indices of edges. Should be overloaded if the backend
        supplies a cheaper version than this.'''
        return set(self.edges())

    cdef inline tuple edge_st(self, tuple edge):
        '''Returns the (source, target) pair of an edge; edges are already pairs.'''
        return edge

    cdef get_neighbours(self, int vertex):
        '''Live view of the neighbour ids of `vertex`.'''
        return self.graph[vertex].keys()

    cdef int get_vertex_degree(self, int vertex) except -1:
        '''Number of neighbours of `vertex`; raises KeyError if absent.'''
        return len(self.graph[vertex])

    cdef list get_incident_edges(self, int vertex):
        '''Edges touching `vertex`, each written with the smaller id first.'''
        return [(vertex, v1) if v1 > vertex else (v1, vertex) for v1 in self.graph[vertex]]

    cdef bint is_connected(self,int v1,int v2) except -1:
        '''True if v1 and v2 share an edge; raises KeyError if v1 is absent.'''
        return v2 in self.graph[v1]

    cdef int get_type(self, int vertex) except? -1:
        '''Type of `vertex`; raises KeyError if absent.'''
        return self.ty[vertex]

    cdef dict get_types(self):
        '''The live vertex-id -> type mapping (not a copy).'''
        return self.ty

    cdef set_type(self, int vertex, int t):
        '''Sets the type of `vertex` to `t`.'''
        self.ty[vertex] = t

cdef list match_bialg_parallel(GraphS g, int num=-1):
    '''Greedily collects non-overlapping bialgebra matches in `g`.

    Edges are popped one at a time from the candidate set. An edge (v0, v1)
    matches when one endpoint has type 1 and the other type 2, every other
    neighbour of v0 has v1's type and every other neighbour of v1 has v0's
    type. After a match, all edges incident to those neighbours are dropped
    from the candidates so that later matches cannot touch them.

    num: maximum number of matches to collect; -1 (the default) means no limit.
    Returns a list of [v0, v1, v0_neighbours, v1_neighbours] entries.
    '''
    cdef set candidates = g.edge_set()
    cdef dict types = g.get_types()
    cdef int i = 0
    cdef int v0, v1, v0t, v1t, n, v
    cdef tuple c
    cdef list v0n, v1n
    cdef list m = []
    while (num == -1 or i < num) and candidates:
        v0, v1 = g.edge_st(candidates.pop())
        v0t = types[v0]
        v1t = types[v1]
        if not ((v0t == 1 and v1t == 2) or (v0t == 2 and v1t == 1)):
            continue
        v0n = [n for n in g.get_neighbours(v0) if n != v1]
        v1n = [n for n in g.get_neighbours(v1) if n != v0]
        # Generator form lets all() stop at the first neighbour of the wrong
        # type instead of building the full list of comparisons first.
        if all(types[n] == v1t for n in v0n) and all(types[n] == v0t for n in v1n):
            i += 1
            for v in v0n:
                for c in g.get_incident_edges(v):
                    candidates.discard(c)
            for v in v1n:
                for c in g.get_incident_edges(v):
                    candidates.discard(c)
            m.append([v0, v1, v0n, v1n])
    return m


cdef bialg(GraphS g, list matches):
    '''Applies the matches found by match_bialg_parallel to `g` in place.

    For each match [v0, v1, v0n, v1n], vertices v0 and v1 are removed and the
    edge between every pair (i, j) with i in v0n and j in v1n is toggled: an
    existing edge is removed, a missing one is added. Different matches can
    produce the same pair; their toggles are combined by parity, so an even
    number of toggles leaves that pair unchanged. Previously such duplicates
    either deleted the same edge twice (KeyError) or added it twice.
    Finally, vertices left without neighbours are removed.
    '''
    cdef list m
    cdef list del_verts = []
    cdef list add_edges = []
    cdef list del_edges = []
    cdef dict toggles = {}
    cdef tuple e
    cdef int i, j
    for m in matches:
        del_verts.append(m[0])
        del_verts.append(m[1])
        for i in m[2]:
            for j in m[3]:
                # Normalise so (i, j) and (j, i) share one parity entry.
                e = (i, j) if i < j else (j, i)
                toggles[e] = not toggles.get(e, False)

    for e, odd in toggles.items():
        if not odd:
            continue
        if g.is_connected(e[0], e[1]):
            del_edges.append(e)
        else:
            add_edges.append(e)

    g.remove_edges(del_edges)
    g.add_edges(add_edges)
    g.remove_vertices(del_verts)
    g.remove_solo_vertices()

cpdef GraphS zigzag(int sz):
    '''Builds the zigzag test diagram with 2*sz + 4 vertices.

    Vertices 2k and 2k+1 (k = 1..sz) get type (k % 2) + 1; vertices 0, 1,
    2*sz+2 and 2*sz+3 keep type 0. Both "rails" are attached to the ends,
    and every consecutive pair of rungs is fully cross-connected.
    '''
    cdef GraphS diagram = GraphS()
    cdef int k, lo, hi
    diagram.add_vertices(2*sz + 4)

    for k in range(1, sz + 1):
        diagram.set_type(2*k, (k % 2) + 1)
        diagram.set_type(2*k + 1, (k % 2) + 1)

    diagram.add_edges([(0, 2), (1, 3)])
    # Offsets (lo, hi) give the four edge families between rung k and k+1,
    # inserted in the same order as before: even-even, even-odd, odd-even, odd-odd.
    for lo, hi in ((0, 2), (0, 3), (1, 2), (1, 3)):
        diagram.add_edges([(2*k + lo, 2*k + hi) for k in range(1, sz)])
    diagram.add_edges([(2*sz, 2*sz + 2), (2*sz + 1, 2*sz + 3)])
    return diagram

cpdef total_reduce(GraphS g):
    '''Rewrites `g` in place with parallel bialgebra steps until none match.

    Prints a start message, then the number of matching rounds performed
    (the count includes the final round that found nothing).
    '''
    cdef int rounds
    cdef list found
    print("normalising ZX diagram " + str(g))
    found = match_bialg_parallel(g)
    rounds = 1
    while len(found) != 0:
        bialg(g, found)
        found = match_bialg_parallel(g)
        rounds += 1
    print("completed in " + str(rounds) + " iterations")

In [58]:
g = zigzag(1000000)
%time total_reduce(g)

normalising ZX diagram <_cython_magic_7f300cec609a215475ad974fdc518235.GraphS object at 0x7f8b37d1f708>
completed in 18 iterations
CPU times: user 17.2 s, sys: 1.42 s, total: 18.6 s
Wall time: 18.6 s


In [15]:
# Reference implementation: scipy's Rbf over the same points. Rbf's default
# kernel is 'multiquadric', not the Gaussian used by rbf_networkC, so this
# compares cost rather than identical outputs.
import numpy as np
from scipy.interpolate import Rbf

# One coordinate array per dimension, for any D (previously 5 columns were
# hard-coded while Xtuple used D).
Xtuple = tuple(X[:, i] for i in range(D))
rbf = Rbf(*Xtuple, np.random.rand(N))

In [16]:
%timeit rbf(Xtuple)

435 ms ± 8.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
