In [1]:
import numpy as np

In [44]:
class SpikeSheafNode:
    
    def __init__(self, cells, interval):
        self.l = None
        self.r = None
        self.cells = cells
        self.interval = interval

class SpikeSheaf:
    
    def __init__(self, spikedata):
        self.root = None
        self.layers = []
        self.prepare_base(spikedata)
        self.construct_layers()
        
    def prepare_base(self, spikedata):
        '''
        Sets the base layer 
        '''
        layer_0 = []
        for d in spikedata:
            new_node = SpikeSheafNode(d[0], (d[1], d[1]))
            layer_0.append(new_node)
        self.layers.append(layer_0)
    
    def _layer_n(self, layer):
        layer_n = []
        a=0
        while a+1 < len(layer):
            
            node_l = layer[a]
            node_r = layer[a+1]

            new_node = SpikeSheafNode(node_l.cells | node_r.cells, 
                                     (min(node_l.interval[0], node_r.interval[0]), 
                                      max(node_l.interval[1], node_r.interval[1])))
            new_node.l = node_l
            new_node.r = node_r
            layer_n.append(new_node)
            a+=1
        return layer_n
    
    def construct_layers(self):

        while len(self.layers[-1]) > 1:
            self.layers.append(self._layer_n(self.layers[-1]))
        

In [58]:
sp = [(set(), 0), ({10}, 1), ({11}, 2)]
sp = []
for t in range(1024):
    ncells = np.random.randint(1, high=4)
    cells = set(np.random.choice(101, ncells))
    sp.append((cells, t))
sp

[({29}, 0),
 ({34, 53, 73}, 1),
 ({16, 59}, 2),
 ({91}, 3),
 ({79}, 4),
 ({14, 25, 42}, 5),
 ({16, 76, 98}, 6),
 ({63, 73}, 7),
 ({14}, 8),
 ({17, 81, 96}, 9),
 ({16, 72}, 10),
 ({64}, 11),
 ({93}, 12),
 ({54, 61, 93}, 13),
 ({29, 56, 99}, 14),
 ({54, 91}, 15),
 ({75}, 16),
 ({40}, 17),
 ({48}, 18),
 ({96}, 19),
 ({34, 38, 98}, 20),
 ({19, 89}, 21),
 ({47, 79}, 22),
 ({40, 63}, 23),
 ({42, 48, 80}, 24),
 ({10, 41, 81}, 25),
 ({73}, 26),
 ({38}, 27),
 ({11, 94}, 28),
 ({32, 44, 87}, 29),
 ({7, 47}, 30),
 ({36, 47, 94}, 31),
 ({13, 54, 57}, 32),
 ({67, 83}, 33),
 ({2, 30, 77}, 34),
 ({97}, 35),
 ({2, 96}, 36),
 ({29, 40, 87}, 37),
 ({22, 56, 66}, 38),
 ({26, 42}, 39),
 ({45, 81}, 40),
 ({22, 28, 97}, 41),
 ({22}, 42),
 ({12, 41}, 43),
 ({20, 91}, 44),
 ({11, 87, 100}, 45),
 ({54}, 46),
 ({76, 96, 98}, 47),
 ({1, 51, 89}, 48),
 ({35}, 49),
 ({21, 64}, 50),
 ({27, 47, 70}, 51),
 ({33, 79}, 52),
 ({47, 52, 88}, 53),
 ({3, 6, 33}, 54),
 ({51}, 55),
 ({9, 10, 79}, 56),
 ({18}, 57),
 ({9, 54, 

In [59]:
t =SpikeSheaf(sp)

In [60]:
t.layers

KeyboardInterrupt: 

In [61]:
t.layers[2][0].cells

{16, 29, 34, 53, 59, 73}

In [62]:
for layer in range(22):
    print(t.layers[layer][0].cells)

{29}
{73, 34, 53, 29}
{34, 73, 16, 53, 59, 29}
{34, 73, 91, 16, 53, 59, 29}
{34, 73, 59, 79, 16, 53, 91, 29}
{34, 73, 59, 42, 14, 79, 16, 53, 25, 91, 29}
{73, 76, 14, 79, 16, 25, 91, 29, 34, 98, 42, 53, 59}
{73, 76, 14, 79, 16, 25, 91, 29, 34, 98, 42, 53, 59, 63}
{73, 76, 14, 79, 16, 25, 91, 29, 34, 98, 42, 53, 59, 63}
{73, 76, 14, 79, 16, 17, 81, 25, 91, 29, 96, 34, 98, 42, 53, 59, 63}
{72, 73, 76, 14, 79, 16, 17, 81, 25, 91, 29, 96, 34, 98, 42, 53, 59, 63}
{64, 72, 73, 76, 14, 79, 16, 17, 81, 25, 91, 29, 96, 34, 98, 42, 53, 59, 63}
{64, 72, 73, 76, 14, 79, 16, 17, 81, 25, 91, 29, 93, 96, 34, 98, 42, 53, 59, 63}
{64, 72, 73, 76, 14, 79, 16, 17, 81, 25, 91, 29, 93, 96, 34, 98, 42, 53, 54, 59, 61, 63}
{14, 16, 17, 25, 29, 34, 42, 53, 54, 56, 59, 61, 63, 64, 72, 73, 76, 79, 81, 91, 93, 96, 98, 99}
{14, 16, 17, 25, 29, 34, 42, 53, 54, 56, 59, 61, 63, 64, 72, 73, 76, 79, 81, 91, 93, 96, 98, 99}
{14, 16, 17, 25, 29, 34, 42, 53, 54, 56, 59, 61, 63, 64, 72, 73, 75, 76, 79, 81, 91, 93, 96, 98,

In [None]:
class SpikeTree:
    def __init__(self, spikes):
        pass
    
    