In [46]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np

import torch
from torch import nn

from gdeep.models import FFNet

from gdeep.visualisation import  persistence_diagrams_of_activations

from torch.utils.tensorboard import SummaryWriter
from gdeep.data import TorchDataLoader


from gtda.diagrams import BettiCurve

from gtda.plotting import plot_betti_surfaces




In [47]:
writer = SummaryWriter()

In [48]:
dl = TorchDataLoader(name="MNIST")
dl_tr, dl_ts = dl.build_dataloader(batch_size=32)

In [49]:
from gdeep.pipeline import Pipeline

model = nn.Sequential(nn.Flatten(), FFNet(0, arch=[28*28, 200, 32, 10]))

In [50]:
from torch.optim import SGD

print(model)
loss_fn = nn.CrossEntropyLoss()

pipe = Pipeline(model, (dl_tr, dl_ts), loss_fn, writer)

# train the model
pipe.train(SGD, 3, lr=0.1)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): FFNet(
    (layer0): Linear(in_features=784, out_features=200, bias=True)
    (layer1): Linear(in_features=200, out_features=32, bias=True)
    (layer2): Linear(in_features=32, out_features=10, bias=True)
  )
)
Epoch 1
-------------------------------
Test Error: 46  [57600/60000]
 Accuracy: 38.4%,                 Avg loss: 0.064992 

Epoch 2
-------------------------------
Test Error: 21  [57600/60000]
 Accuracy: 45.1%,                 Avg loss: 0.062901 

Epoch 3
-------------------------------
Test Error: 56  [57600/60000]
 Accuracy: 56.0%,                 Avg loss: 0.059483 

Done!


In [51]:
from gdeep.models import ModelExtractor

me = ModelExtractor(model, loss_fn)

x = next(iter(dl_tr))[0]
list_activations = me.get_activations(x)
len(list_activations)


6

In [151]:
from gudhi import SimplexTree as ST
import networkx as nx

t = (1,1)
(t[0]+1,t[1]+2)
        

(2, 3)

In [39]:
arch = [28*28, 200, 32, 10]
G = nx.complete_multipartite_graph(28*28, 200)
list(G.edges())

[(0, 784),
 (0, 785),
 (0, 786),
 (0, 787),
 (0, 788),
 (0, 789),
 (0, 790),
 (0, 791),
 (0, 792),
 (0, 793),
 (0, 794),
 (0, 795),
 (0, 796),
 (0, 797),
 (0, 798),
 (0, 799),
 (0, 800),
 (0, 801),
 (0, 802),
 (0, 803),
 (0, 804),
 (0, 805),
 (0, 806),
 (0, 807),
 (0, 808),
 (0, 809),
 (0, 810),
 (0, 811),
 (0, 812),
 (0, 813),
 (0, 814),
 (0, 815),
 (0, 816),
 (0, 817),
 (0, 818),
 (0, 819),
 (0, 820),
 (0, 821),
 (0, 822),
 (0, 823),
 (0, 824),
 (0, 825),
 (0, 826),
 (0, 827),
 (0, 828),
 (0, 829),
 (0, 830),
 (0, 831),
 (0, 832),
 (0, 833),
 (0, 834),
 (0, 835),
 (0, 836),
 (0, 837),
 (0, 838),
 (0, 839),
 (0, 840),
 (0, 841),
 (0, 842),
 (0, 843),
 (0, 844),
 (0, 845),
 (0, 846),
 (0, 847),
 (0, 848),
 (0, 849),
 (0, 850),
 (0, 851),
 (0, 852),
 (0, 853),
 (0, 854),
 (0, 855),
 (0, 856),
 (0, 857),
 (0, 858),
 (0, 859),
 (0, 860),
 (0, 861),
 (0, 862),
 (0, 863),
 (0, 864),
 (0, 865),
 (0, 866),
 (0, 867),
 (0, 868),
 (0, 869),
 (0, 870),
 (0, 871),
 (0, 872),
 (0, 873),
 (0, 874),

In [152]:
activations = list_activations
arch = [28*28, 200, 32, 10]

def get_activation_graph(activations,arch, index_batch = 1):
    n_layer = len(arch)
    current_node = 0
    activation_graph = ST()
    edge_list = []
    for i in range(n_layer - 1):
        f = lambda x: (x[0] + current_node, x[1] + current_node)
        G = nx.complete_bipartite_graph(arch[i], arch[i+1])
        l = list(G.edges())
        l = map(f,l)
        edge_list.extend(l)
        current_node += arch[i]
    for edge in edge_list:
        activation_graph.insert(list(edge), 0.0)
    activations_flatten = torch.empty(0)
    for layer in range(n_layer):
        activations_flatten = torch.cat((activations_flatten, activations[layer][index_batch]))
    for neuron in range(activations_flatten.size()[0]):
        activation_graph.insert([neuron], float(activations_flatten[neuron]))
    return activation_graph
    
    
    
    
        

In [153]:
L = get_activation_graph(activations,arch, index_batch = 1)

In [159]:
pers = L.extended_persistence()

In [163]:
pers[3]

[(1, (-53.13246536254883, -144.86459350585938)),
 (1, (-51.29071044921875, -144.86459350585938)),
 (1, (-47.754272460937514, -144.86459350585938)),
 (1, (-36.95832443237305, -144.86459350585938)),
 (1, (-33.90981292724611, -144.86459350585938)),
 (1, (-28.573806762695312, -144.86459350585938)),
 (1, (-27.518409729003906, -144.86459350585938)),
 (1, (-26.530456542968736, -144.86459350585938)),
 (1, (-26.22930145263672, -144.86459350585938)),
 (1, (-22.03402328491211, -144.86459350585938)),
 (1, (-19.785245895385728, -144.86459350585938)),
 (1, (-16.406009674072266, -144.86459350585938)),
 (1, (-11.309928894042969, -144.86459350585938)),
 (1, (-9.204490661621065, -144.86459350585938)),
 (1, (-9.18233585357666, -144.86459350585938)),
 (1, (-9.168368339538574, -144.86459350585938)),
 (1, (-9.047283172607393, -144.86459350585938)),
 (1, (-4.628385066986084, -144.86459350585938)),
 (1, (-2.2956182956695272, -144.86459350585938)),
 (1, (-1.6573508977890015, -144.86459350585938)),
 (1, (0.0, -

In [101]:
l = [0,1].append([2,3])
l

In [103]:
print(l)

None
