In [1]:
import numpy as np

In [44]:
class SpikeSheafNode:
    
    def __init__(self, cells, interval):
        self.l = None
        self.r = None
        self.cells = cells
        self.interval = interval

class SpikeSheaf:
    
    def __init__(self, spikedata):
        self.root = None
        self.layers = []
        self.prepare_base(spikedata)
        self.construct_layers()
        
    def prepare_base(self, spikedata):
        '''
        Sets the base layer 
        '''
        layer_0 = []
        for d in spikedata:
            new_node = SpikeSheafNode(d[0], (d[1], d[1]))
            layer_0.append(new_node)
        self.layers.append(layer_0)
    
    def _layer_n(self, layer):
        layer_n = []
        a=0
        while a+1 < len(layer):
            
            node_l = layer[a]
            node_r = layer[a+1]

            new_node = SpikeSheafNode(node_l.cells | node_r.cells, 
                                     (min(node_l.interval[0], node_r.interval[0]), 
                                      max(node_l.interval[1], node_r.interval[1])))
            new_node.l = node_l
            new_node.r = node_r
            layer_n.append(new_node)
            a+=1
        return layer_n
    
    def construct_layers(self):

        while len(self.layers[-1]) > 1:
            self.layers.append(self._layer_n(self.layers[-1]))
        

In [64]:
sp = [(set(), 0), ({10}, 1), ({11}, 2)]
sp = []
for t in range(1024):
    ncells = np.random.randint(1, high=4)
    cells = set(np.random.choice(101, ncells))
    sp.append((cells, t))
sp

[({62, 94}, 0),
 ({51}, 1),
 ({83}, 2),
 ({8, 29}, 3),
 ({69}, 4),
 ({97}, 5),
 ({8, 91, 98}, 6),
 ({20, 75, 95}, 7),
 ({69, 94}, 8),
 ({34, 44}, 9),
 ({16, 97}, 10),
 ({46, 52, 90}, 11),
 ({19}, 12),
 ({84}, 13),
 ({5}, 14),
 ({84}, 15),
 ({6, 41, 55}, 16),
 ({15, 28, 43}, 17),
 ({49, 89, 98}, 18),
 ({41}, 19),
 ({52, 62, 94}, 20),
 ({62, 76}, 21),
 ({11, 14, 96}, 22),
 ({21}, 23),
 ({4, 41, 65}, 24),
 ({1, 47, 64}, 25),
 ({56, 75}, 26),
 ({60, 72}, 27),
 ({36}, 28),
 ({64}, 29),
 ({41, 44, 64}, 30),
 ({54, 62, 100}, 31),
 ({3, 85}, 32),
 ({22, 45, 69}, 33),
 ({81}, 34),
 ({44}, 35),
 ({93}, 36),
 ({12, 61, 70}, 37),
 ({2, 3, 14}, 38),
 ({7}, 39),
 ({24, 81, 82}, 40),
 ({2, 41, 57}, 41),
 ({23}, 42),
 ({20, 91}, 43),
 ({66}, 44),
 ({52, 93, 98}, 45),
 ({24}, 46),
 ({48, 74}, 47),
 ({5, 10, 20}, 48),
 ({64, 75, 100}, 49),
 ({81, 90}, 50),
 ({9, 68, 97}, 51),
 ({7, 16}, 52),
 ({15, 34, 74}, 53),
 ({9, 21}, 54),
 ({78}, 55),
 ({2, 91}, 56),
 ({81, 97}, 57),
 ({5, 89}, 58),
 ({32}, 59),
 

In [59]:
t =SpikeSheaf(sp)

In [60]:
t.layers

KeyboardInterrupt: 

In [61]:
t.layers[2][0].cells

{16, 29, 34, 53, 59, 73}

In [62]:
for layer in range(22):
    print(t.layers[layer][0].cells)

{29}
{73, 34, 53, 29}
{34, 73, 16, 53, 59, 29}
{34, 73, 91, 16, 53, 59, 29}
{34, 73, 59, 79, 16, 53, 91, 29}
{34, 73, 59, 42, 14, 79, 16, 53, 25, 91, 29}
{73, 76, 14, 79, 16, 25, 91, 29, 34, 98, 42, 53, 59}
{73, 76, 14, 79, 16, 25, 91, 29, 34, 98, 42, 53, 59, 63}
{73, 76, 14, 79, 16, 25, 91, 29, 34, 98, 42, 53, 59, 63}
{73, 76, 14, 79, 16, 17, 81, 25, 91, 29, 96, 34, 98, 42, 53, 59, 63}
{72, 73, 76, 14, 79, 16, 17, 81, 25, 91, 29, 96, 34, 98, 42, 53, 59, 63}
{64, 72, 73, 76, 14, 79, 16, 17, 81, 25, 91, 29, 96, 34, 98, 42, 53, 59, 63}
{64, 72, 73, 76, 14, 79, 16, 17, 81, 25, 91, 29, 93, 96, 34, 98, 42, 53, 59, 63}
{64, 72, 73, 76, 14, 79, 16, 17, 81, 25, 91, 29, 93, 96, 34, 98, 42, 53, 54, 59, 61, 63}
{14, 16, 17, 25, 29, 34, 42, 53, 54, 56, 59, 61, 63, 64, 72, 73, 76, 79, 81, 91, 93, 96, 98, 99}
{14, 16, 17, 25, 29, 34, 42, 53, 54, 56, 59, 61, 63, 64, 72, 73, 76, 79, 81, 91, 93, 96, 98, 99}
{14, 16, 17, 25, 29, 34, 42, 53, 54, 56, 59, 61, 63, 64, 72, 73, 75, 76, 79, 81, 91, 93, 96, 98,

In [82]:
def spike_union(sp1, sp2):
    spu = (sp1[0] | sp2[0], (min(sp1[1][0], sp2[1][0]), max(sp1[1][1], sp2[1][1])))
    return spu

def layer(n, j, spikedata):
    if n ==0:
        return ({spikedata[j][0]}, (spikedata[j][1], spikedata[j][1]))
    else:
        return spike_union(layer(n-1, j, spikedata), layer(n-1, j+1, spikedata))

In [83]:
nspike = 128
maxt = 1024
maxcell = 64
spikedatatest = tuple(zip(np.random.randint(1, high=maxcell, size=nspike), sorted(np.random.randint(1, high=maxt, size=nspike))))
spikedatatest

((44, 34),
 (19, 36),
 (40, 39),
 (48, 43),
 (60, 48),
 (31, 72),
 (29, 77),
 (5, 79),
 (61, 87),
 (44, 94),
 (27, 97),
 (6, 100),
 (23, 106),
 (15, 109),
 (16, 124),
 (16, 130),
 (8, 130),
 (10, 134),
 (58, 135),
 (54, 137),
 (1, 162),
 (20, 178),
 (25, 179),
 (11, 183),
 (9, 194),
 (14, 211),
 (41, 219),
 (31, 224),
 (27, 226),
 (2, 227),
 (52, 232),
 (25, 244),
 (35, 253),
 (27, 264),
 (24, 276),
 (26, 281),
 (1, 289),
 (17, 300),
 (12, 305),
 (9, 307),
 (44, 309),
 (35, 322),
 (5, 323),
 (18, 327),
 (33, 339),
 (53, 340),
 (53, 346),
 (43, 349),
 (1, 352),
 (58, 364),
 (58, 371),
 (10, 385),
 (46, 385),
 (43, 395),
 (41, 403),
 (59, 404),
 (24, 407),
 (54, 422),
 (31, 430),
 (37, 470),
 (26, 482),
 (33, 483),
 (10, 489),
 (59, 497),
 (2, 498),
 (32, 500),
 (55, 503),
 (14, 504),
 (54, 531),
 (36, 546),
 (40, 549),
 (38, 580),
 (6, 586),
 (27, 586),
 (5, 588),
 (23, 631),
 (42, 638),
 (1, 640),
 (21, 655),
 (41, 659),
 (26, 664),
 (62, 674),
 (9, 676),
 (56, 676),
 (26, 691),
 (12, 

In [90]:
layer(16, 43, spikedatatest)

({1, 10, 18, 24, 31, 33, 37, 41, 43, 46, 53, 54, 58, 59}, (327, 470))