In [98]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np

import torch
from torch import nn

from gdeep.models import FFNet

from gdeep.visualisation import  persistence_diagrams_of_activations

from torch.utils.tensorboard import SummaryWriter
from gdeep.data import TorchDataLoader


from gtda.diagrams import BettiCurve

from gtda.plotting import plot_betti_surfaces




In [99]:
writer = SummaryWriter()

In [100]:
dl = TorchDataLoader(name="MNIST")
dl_tr, dl_ts = dl.build_dataloader(batch_size=32)

In [101]:
from gdeep.pipeline import Pipeline

model = nn.Sequential(nn.Flatten(), FFNet(0, arch=[28*28, 50,20, 10]))

In [102]:
from torch.optim import SGD

print(model)
loss_fn = nn.CrossEntropyLoss()

pipe = Pipeline(model, (dl_tr, dl_ts), loss_fn, writer)

# train the model
#pipe.train(SGD, 3, lr=0.1)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): FFNet(
    (layer0): Linear(in_features=784, out_features=50, bias=True)
    (layer1): Linear(in_features=50, out_features=20, bias=True)
    (layer2): Linear(in_features=20, out_features=10, bias=True)
  )
)


In [103]:
from gdeep.models import ModelExtractor

me = ModelExtractor(model, loss_fn)

x = next(iter(dl_tr))[0]
list_activations = me.get_activations(x)
len(list_activations)


6

In [104]:
from gudhi import SimplexTree as ST
import networkx as nx
import numpy as np


In [126]:

activations = list_activations
arch = [28*28, 50,20,10]

def get_activation_graph(activations,arch, index_batch = 1):
    n_layer = len(arch)
    current_node = 0
    activation_graph = ST()
    edge_list = []
    for i in range(n_layer - 1):
        f = lambda x: (x[0] + current_node, x[1] + current_node)
        G = nx.complete_bipartite_graph(arch[i], arch[i+1])
        l = list(G.edges())
        l = map(f,l)
        edge_list.extend(l)
        current_node += arch[i]
    for edge in edge_list:
        activation_graph.insert(list(edge), 0.0)
    activations_flatten = torch.empty(0)
    for layer in range(n_layer):
        activations_flatten = torch.cat((activations_flatten, activations[layer][index_batch]))
    for neuron in range(activations_flatten.size()[0]):
        activation_graph.insert([neuron], float(activations_flatten[neuron]))
    return activation_graph
    
def gudhi_to_giotto(diagrams_gudhi):
    diagrams_giotto = []
    for diagram in diagrams_gudhi:
        diagram_giotto = []
        for dim, bar in diagram:
            diagram_giotto.append([bar[0],bar[1],dim])
        diagrams_giotto.append(diagram_giotto)
    return diagrams_giotto


def reverse_diagrams(diagrams):
    diagrams_reversed = []
    for diagram in diagrams:
        diagram_reversed = []
        for b,d,q in diagram:
            diagram_reversed.append([d,b,q])
        diagrams_reversed.append(diagram_reversed)       
    return diagrams_reversed

def get_extended_persistence(model, loss_fn, x):
    me = ModelExtractor(model, loss_fn)
    activations = me.get_activations(x)
    diagrams = []
    for i in range(len(x)):
        L = get_activation_graph(activations, arch, index_batch = i)
        L.extend_filtration()
        pers = L.extended_persistence()
        diagrams.append(pers[3])
    return diagrams
    
def diagrams_accross_training(model, loss_fn, x, n_epochs):
    diagrams = []
    for i in range(n_epochs):
        diagrams.append(get_extended_persistence(model,loss_fn,x))
        pipe.train(SGD, 2, lr=0.1)
    return diagrams

def get_entropy_accross_training(diagrams_accross_training):
    entropy = PE(normalize = True)
    n_epochs = len(diagrams_accross_training)
    n_batch = len(diagrams_accross_training[0])
    E = np.empty([n_epochs,n_batch])
    for epoch in range(n_epochs):
        diagrams_giotto_reversed = reverse_diagrams(gudhi_to_giotto(diagrams[epoch]))
        for batch in range(n_batch):
            e = entropy.fit_transform(np.array([diagrams_giotto_reversed[batch]]))[0][0]
            E[epoch][batch] = e
    return E
    
        

In [127]:
diagrams_training = diagrams_accross_training(model, loss_fn, x, 2)


Epoch 1
-------------------------------
Test Error: 00  [57600/60000]
 Accuracy: 87.0%,                 Avg loss: 0.049803 

Epoch 2
-------------------------------
Test Error: 00  [57600/60000]
 Accuracy: 86.9%,                 Avg loss: 0.049839 

Done!
Epoch 1
-------------------------------
Test Error: 00  [57600/60000]
 Accuracy: 86.8%,                 Avg loss: 0.049852 

Epoch 2
-------------------------------
Test Error: 00  [57600/60000]
 Accuracy: 86.8%,                 Avg loss: 0.049855 

Done!


In [128]:
E = get_entropy_accross_training(diagrams_training)

In [129]:
E

array([[0.81333363, 0.80943545, 0.82666803, 0.8168468 , 0.81464089,
        0.79933031, 0.81869068, 0.79040708, 0.82767742, 0.81022885,
        0.80261478, 0.82154047, 0.80982464, 0.80806373, 0.83201313,
        0.80633995, 0.82352765, 0.81347777, 0.87108246, 0.82445659,
        0.79176242, 0.81285016, 0.8385907 , 0.81649729, 0.82424697,
        0.78044384, 0.84579245, 0.78861365, 0.79500831, 0.83150974,
        0.81802856, 0.79884588],
       [0.81569591, 0.81134945, 0.82648976, 0.81897053, 0.81526796,
        0.79942153, 0.81924653, 0.79081971, 0.82929312, 0.81097853,
        0.80328489, 0.82224471, 0.80928084, 0.81066431, 0.83456631,
        0.80919661, 0.82252889, 0.8156613 , 0.87062598, 0.8243306 ,
        0.79187187, 0.81246421, 0.83682435, 0.81768348, 0.82381542,
        0.78022146, 0.8462899 , 0.7910112 , 0.7931555 , 0.83147528,
        0.81646622, 0.79904215]])

In [107]:
from gtda.diagrams import PersistenceEntropy as PE
entropy = PE(normalize = True)
E = np.empty([n_epochs,n_batch])
for epoch in range(n_epochs):
    diagrams_giotto_reversed = reverse_diagrams(gudhi_to_giotto(diagrams[epoch]))
    for batch in range(n_batch):
        e = entropy.fit_transform(np.array([diagrams_giotto_reversed[batch]]))[0][0]
        E[epoch][batch] = e
    


In [108]:
import pandas as pd
import plotly.express as px
df = pd.DataFrame(E)
fig = px.area(df)
fig.show()

## Study of overfitting on one batch

In [116]:
from torch import optim
n_epochs = 1
optimizer = optim.SGD(model.parameters(), lr=0.1)

for epoch in range(n_epochs):
    image , labels = next(iter(dl_tr))
    optimizer.zero_grad()
    output = model(image)
    loss = loss_fn(output,labels)
    loss.backward()
    optimizer.step()

In [90]:
import numpy as np
np.random.seed(seed=42)
from gtda.homology import VietorisRipsPersistence
from sklearn.datasets import make_circles

X = np.asarray([
    make_circles(100, factor=np.random.random())[0]
    for i in range(10)
])
from gtda.plotting import plot_point_cloud
i = 0
plot_point_cloud(X[i])
VR = VietorisRipsPersistence()
Xt = VR.fit_transform(X)
VR.plot(Xt, sample=i)

In [92]:
Xt

array([[[0.        , 0.04703514, 0.        ],
        [0.        , 0.04703514, 0.        ],
        [0.        , 0.04703514, 0.        ],
        ...,
        [0.62545991, 0.63016409, 1.        ],
        [0.12558104, 0.63016409, 1.        ],
        [0.04703514, 0.65642405, 1.        ]],

       [[0.        , 0.02497084, 0.        ],
        [0.        , 0.02497084, 0.        ],
        [0.        , 0.02497084, 0.        ],
        ...,
        [0.80115759, 0.80311227, 1.        ],
        [0.12558104, 0.80311227, 1.        ],
        [0.02497084, 0.34849384, 1.        ]],

       [[0.        , 0.03008443, 0.        ],
        [0.        , 0.03008443, 0.        ],
        [0.        , 0.03008443, 0.        ],
        ...,
        [0.76043808, 0.76291817, 1.        ],
        [0.12558104, 0.76291817, 1.        ],
        [0.03008443, 0.41985938, 1.        ]],

       ...,

       [[0.        , 0.07146584, 0.        ],
        [0.        , 0.07146584, 0.        ],
        [0.        , 0

In [93]:
entropy.fit_transform(np.array([[diagram[batch]]]))[0][0]

0.0

In [94]:
diagram

[[-79.65325927734375, -49.66649627685547, 1],
 [-78.68570709228516, -49.66649627685547, 1],
 [-79.65325927734375, -47.24231338500976, 1],
 [-78.68570709228516, -47.24231338500976, 1],
 [-71.1774673461914, -49.66649627685547, 1],
 [-71.1774673461914, -47.24231338500976, 1],
 [-79.65325927734375, -35.014949798583984, 1],
 [-78.68570709228516, -35.014949798583984, 1],
 [-79.65325927734375, -31.08163833618164, 1],
 [-79.65325927734375, -30.930641174316413, 1],
 [-78.68570709228516, -31.08163833618164, 1],
 [-78.68570709228516, -30.930641174316413, 1],
 [-71.1774673461914, -35.014949798583984, 1],
 [-71.1774673461914, -31.08163833618164, 1],
 [-71.1774673461914, -30.930641174316413, 1],
 [-79.65325927734375, -18.070802688598626, 1],
 [-79.65325927734375, -17.253404617309585, 1],
 [-78.68570709228516, -18.070802688598626, 1],
 [-78.68570709228516, -17.253404617309585, 1],
 [-71.1774673461914, -18.070802688598626, 1],
 [-79.65325927734375, -9.254110336303711, 1],
 [-79.65325927734375, -8.8323

diagram

In [113]:
a,b = next(iter(dl_tr))

In [115]:
b

tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1,
        1, 2, 4, 3, 2, 7, 3, 8])