In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np

import torch
from torch import nn

from gdeep.models import FFNet

from gdeep.visualisation import  persistence_diagrams_of_activations

from torch.utils.tensorboard import SummaryWriter
from gdeep.data import TorchDataLoader


from gtda.diagrams import BettiCurve

from gtda.plotting import plot_betti_surfaces

from gtda.plotting import plot_diagram

import pandas as pd
import plotly.express as px
from gtda.diagrams import PersistenceEntropy as PE

from gdeep.models import ModelExtractor


In [2]:
# Fix RNG seeds up front so weight initialisation (and, presumably, DataLoader
# shuffling) is reproducible under "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# TensorBoard writer, handed to the Pipeline below.
writer = SummaryWriter()

In [3]:
# Load MNIST through gdeep's loader and build the two DataLoaders
# (presumably train, test) with batches of 32 samples.
dl = TorchDataLoader(name="MNIST")
dl_tr, dl_ts = dl.build_dataloader(batch_size=32)




In [4]:
from gdeep.pipeline import Pipeline

# Layer widths: 28*28 flattened input pixels -> 300 -> 100 -> 10 outputs.
# `arch` is also read as a notebook global by get_extended_persistence.
arch = [28*28,  300, 100, 10]

# NOTE(review): the meaning of FFNet's first positional argument (0) is not visible here — confirm.
model = nn.Sequential(nn.Flatten(), FFNet(0, arch= arch ))

In [5]:
from torch.optim import SGD

print(model)
loss_fn = nn.CrossEntropyLoss()

# Bundle model, loaders, loss and TensorBoard writer. Training itself is run later,
# by diagrams_accross_training, which calls pipe.train between recordings.
pipe = Pipeline(model, (dl_tr, dl_ts), loss_fn, writer)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): FFNet(
    (layer0): Linear(in_features=784, out_features=300, bias=True)
    (layer1): Linear(in_features=300, out_features=100, bias=True)
    (layer2): Linear(in_features=100, out_features=10, bias=True)
  )
)


In [6]:
# Fixed probe batch (inputs only): the activation graphs of these samples are
# tracked throughout training. ModelExtractor is already imported in the first cell.
x = next(iter(dl_tr))[0]


In [7]:
from gudhi import SimplexTree as ST
import networkx as nx
import numpy as np


In [8]:




def get_activation_graph(activations,arch, index_batch = 1):
    """Build the activation graph of one sample as a gudhi ``SimplexTree``.

    The graph has one vertex per neuron (layer widths given by ``arch``, vertices
    numbered consecutively layer by layer) and an edge between every neuron of
    layer ``i`` and every neuron of layer ``i + 1``. Each vertex's filtration
    value is the activation of its neuron for sample ``index_batch``; each edge
    gets the larger of its two endpoint values (lower-star filtration), so the
    filtration stays valid. ``SimplexTree.extend_filtration`` only reads the
    vertex values.

    Parameters
    ----------
    activations : sequence of torch.Tensor
        Per-layer activations; ``activations[layer][index_batch]`` must hold
        ``arch[layer]`` values (it is flattened). Only the first ``len(arch)``
        entries are used.
    arch : sequence of int
        Number of neurons per layer, input layer first.
    index_batch : int, default 1
        Index of the sample within the batch.

    Returns
    -------
    gudhi.SimplexTree
        Vertices ``0 .. sum(arch) - 1`` plus the inter-layer edges.

    Raises
    ------
    ValueError
        If the selected activations do not provide exactly ``sum(arch)`` values.
    """
    n_layer = len(arch)
    # Vertex values, in the same layer-by-layer numbering as the vertices.
    values = [
        value
        for layer in range(n_layer)
        for value in activations[layer][index_batch].detach().reshape(-1).tolist()
    ]
    n_nodes = sum(arch)
    if len(values) != n_nodes:
        raise ValueError(
            f"expected {n_nodes} activations for arch {list(arch)}, got {len(values)}"
        )

    activation_graph = ST()
    # Insert vertices first with their activation. SimplexTree.insert only ever
    # *lowers* the value of a simplex that already exists, so inserting the edges
    # at 0.0 first would leave every positive activation clamped to 0.0.
    for node, value in enumerate(values):
        activation_graph.insert([node], value)
    # Edges take the max of their endpoints, so inserting them never lowers a vertex.
    offset = 0
    for width_in, width_out in zip(arch[:-1], arch[1:]):
        next_offset = offset + width_in
        for u in range(offset, next_offset):
            for v in range(next_offset, next_offset + width_out):
                activation_graph.insert([u, v], max(values[u], values[v]))
        offset = next_offset
    return activation_graph
    
def gudhi_to_giotto(diagrams_gudhi):
    """Convert gudhi persistence diagrams to ``[birth, death, dim]`` point arrays.

    Parameters
    ----------
    diagrams_gudhi : iterable of iterables of (dim, (birth, death))
        One diagram per sample, each a sequence of ``(dim, bar)`` pairs as
        produced by gudhi (e.g. one entry of ``extended_persistence()``).

    Returns
    -------
    list of np.ndarray
        One float array of shape ``(n_points, 3)`` per input diagram, with
        columns ``[birth, death, dim]``. An empty diagram yields shape
        ``(0, 3)`` rather than ``(0,)``, so column slicing keeps working.
    """
    return [
        np.array(
            [[bar[0], bar[1], dim] for dim, bar in diagram], dtype=float
        ).reshape(-1, 3)
        for diagram in diagrams_gudhi
    ]


def reverse_diagrams(diagrams):
    """Swap birth and death of every point in every diagram.

    Each input diagram is an iterable of ``(birth, death, dim)`` triples. The
    result is a new list holding, per diagram, a list of ``[death, birth, dim]``.
    The inputs are not modified.
    """
    return [
        [[death, birth, dim] for birth, death, dim in diagram]
        for diagram in diagrams
    ]

def get_extended_persistence(model, loss_fn, x, architecture=None):
    """Compute one Extended- persistence diagram per sample of ``x``.

    For every sample, the activation graph is built from the model's activations
    on ``x``, its filtration is extended, and extended persistence is computed.

    Parameters
    ----------
    model, loss_fn
        Passed to ``ModelExtractor`` to obtain the activations on ``x``.
    x
        Batch of inputs (presumably a tensor); one diagram is computed per
        element along its first axis.
    architecture : sequence of int, optional
        Layer widths used to build the activation graph. Defaults to the
        notebook-level ``arch`` (the previous, implicit behaviour).

    Returns
    -------
    list
        One gudhi diagram (list of ``(dim, (birth, death))``) per sample.
    """
    if architecture is None:
        architecture = arch  # notebook global, kept as the default for compatibility
    activations = ModelExtractor(model, loss_fn).get_activations(x)
    diagrams = []
    for index_batch in range(len(x)):
        graph = get_activation_graph(activations, architecture, index_batch=index_batch)
        graph.extend_filtration()
        # extended_persistence() returns [Ordinary, Relative, Extended+, Extended-].
        diagrams.append(graph.extended_persistence()[3])
    return diagrams
    
def diagrams_accross_training(model, loss_fn, x, n_epochs, pipeline=None,
                              epochs_per_step=2, lr=0.01, optimizer=None):
    """Alternate between recording activation-graph diagrams and training.

    Each of the ``n_epochs`` rounds first records ``get_extended_persistence``
    for the probe batch ``x``, then trains via ``pipeline.train`` for
    ``epochs_per_step`` epochs. Diagrams are recorded *before* each training
    round, so the last round is trained into the model but never recorded.

    Parameters
    ----------
    model, loss_fn, x
        Forwarded to ``get_extended_persistence``.
    n_epochs : int
        Number of record-then-train rounds (= number of recorded entries).
    pipeline : optional
        Object whose ``train(optimizer, n_epochs, lr=...)`` is called; defaults
        to the notebook-global ``pipe``. It should wrap the same ``model``.
    epochs_per_step : int, default 2
        Training epochs per round.
    lr : float, default 0.01
        Learning rate passed to ``pipeline.train``.
    optimizer : optional
        Optimizer class passed to ``pipeline.train``; defaults to ``SGD``.

    Returns
    -------
    list
        ``n_epochs`` entries, each the per-sample list of diagrams.
    """
    if pipeline is None:
        pipeline = pipe
    if optimizer is None:
        optimizer = SGD
    diagrams = []
    for _ in range(n_epochs):
        diagrams.append(get_extended_persistence(model, loss_fn, x))
        pipeline.train(optimizer, epochs_per_step, lr=lr)
    return diagrams

def get_entropy_accross_training(diagrams_accross_training):
    """Normalized persistence entropy of every recorded diagram.

    Each diagram is converted to ``[birth, death, dim]`` points, flipped so that
    birth and death are swapped, and fed on its own to giotto-tda's
    ``PersistenceEntropy(normalize=True)``; the first output feature is kept.

    Returns
    -------
    np.ndarray
        Shape ``(n_rounds, n_samples)``; the number of samples is taken from
        the first round.
    """
    entropy = PE(normalize = True)
    n_epochs = len(diagrams_accross_training)
    n_batch = len(diagrams_accross_training[0])
    entropies = np.empty((n_epochs, n_batch))
    for epoch, epoch_diagrams in enumerate(diagrams_accross_training):
        flipped = reverse_diagrams(gudhi_to_giotto(epoch_diagrams))
        for batch in range(n_batch):
            # giotto transformers take a batch of diagrams: wrap the single diagram.
            single = np.array([flipped[batch]])
            entropies[epoch, batch] = entropy.fit_transform(single)[0][0]
    return entropies
    
        

## Activation extended persistence

In [9]:
# Record diagrams for the probe batch `x`, then train, n_epochs times. Each round
# trains via pipe.train for 2 epochs by default, so 30 rounds is 60 training epochs.
# NOTE(review): the saved output below ends in KeyboardInterrupt, so in that session
# `diagrams_training` was never assigned and the downstream cells could not run.
n_epochs = 30
diagrams_training = diagrams_accross_training(model, loss_fn, x, n_epochs)


Epoch 1
-------------------------------
Test Error: 97  [57600/60000]
 Accuracy: 14.7%,                 Avg loss: 0.072409 

Epoch 2
-------------------------------
Test Error: 43  [57600/60000]
 Accuracy: 21.3%,                 Avg loss: 0.070347 

Done!
Epoch 1
-------------------------------
Test Error: 98  [57600/60000]
 Accuracy: 25.4%,                 Avg loss: 0.069064 

Epoch 2
-------------------------------
Test Error: 51  [57600/60000]
 Accuracy: 26.4%,                 Avg loss: 0.068764 

Done!
Epoch 1
-------------------------------
Test Error: 01  [57600/60000]
 Accuracy: 26.8%,                 Avg loss: 0.068616 

Epoch 2
-------------------------------
Test Error: 14  [57600/60000]
 Accuracy: 27.3%,                 Avg loss: 0.068473 

Done!
Epoch 1
-------------------------------
Test Error: 48  [57600/60000]
 Accuracy: 27.5%,                 Avg loss: 0.068397 

Epoch 2
-------------------------------
Test Error: 79  [57600/60000]
 Accuracy: 27.6%,                 Avg

KeyboardInterrupt: 

In [None]:
# Persistence entropy of each probe sample's (flipped) diagram at every recorded
# round: rows = rounds, columns = samples of the probe batch.
E = get_entropy_accross_training(diagrams_training)
df_e = pd.DataFrame(E)


In [None]:
# One line per probe sample; the x axis is the recording round (row index of df_e).
fig = px.line(
    df_e,
    labels={
        "index": "recording round",
        "value": "persistence entropy",
        "variable": "sample in probe batch",
    },
    title="Persistence entropy of activation graphs across training",
)
fig.show()

In [None]:
from gtda.diagrams import PairwiseDistance
# giotto-tda transformer computing bottleneck distances between persistence diagrams.
dist = PairwiseDistance(metric = 'bottleneck')


In [None]:
# Pairwise bottleneck distances between the probe-batch diagrams recorded at one round.
# giotto-tda expects a single (n_samples, n_points, 3) array whose points satisfy
# birth <= death. The Extended- points lie below the diagonal, so they are flipped
# first (as for the entropy computation). Then every diagram is padded, per homology
# dimension, with zero-persistence points (0, 0, dim) so all diagrams share one
# shape. Such points sit on the diagonal and add nothing to bottleneck distances.
ROUND_FOR_DISTANCES = 9
diagrams_flipped = [
    np.asarray(diagram, dtype=float).reshape(-1, 3)
    for diagram in reverse_diagrams(gudhi_to_giotto(diagrams_training[ROUND_FOR_DISTANCES]))
]
homology_dims = sorted({dim for diagram in diagrams_flipped for dim in diagram[:, 2].tolist()})
max_points = {
    dim: max(int((diagram[:, 2] == dim).sum()) for diagram in diagrams_flipped)
    for dim in homology_dims
}
padded = []
for diagram in diagrams_flipped:
    parts = []
    for dim in homology_dims:
        points = diagram[diagram[:, 2] == dim]
        filler = np.zeros((max_points[dim] - len(points), 3))
        filler[:, 2] = dim
        parts.append(np.vstack([points, filler]))
    padded.append(np.vstack(parts) if parts else np.zeros((0, 3)))
diagrams = np.stack(padded)
distance_matrix = dist.fit_transform(diagrams)

## Study of overfitting on one batch

In [None]:
from torch import optim

# Repeatedly fit one fixed batch to watch the model overfit it.
# Uses its own step counter so the notebook-global `n_epochs` is not clobbered.
n_overfit_steps = 1
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Draw the batch once: calling next(iter(dl_tr)) inside the loop starts a fresh
# iterator every step and, if the loader shuffles, yields a different batch each time.
image, labels = next(iter(dl_tr))
model.train()
for step in range(n_overfit_steps):
    optimizer.zero_grad()
    output = model(image)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()

## Generate video of barcodes during training

In [None]:
# Long-form columns of all recorded diagram points: one entry per
# (recording round, probe sample, point), in raw gudhi (birth, death) order.
births = []
deaths = []
epochs = []
batchs = []

# Iterate over the diagrams actually recorded, rather than relying on the
# notebook-global n_epochs and a hard-coded batch size.
for epoch, epoch_diagrams in enumerate(diagrams_training):
    for batch, diagram in enumerate(gudhi_to_giotto(epoch_diagrams)):
        for b, d, _dim in diagram:
            births.append(b)
            deaths.append(d)
            epochs.append(epoch)
            batchs.append(batch)
        

In [None]:
# Long-form frame of every recorded diagram point (one row per point);
# column order follows the dict's insertion order.
data = {
    'birth': births,
    'death': deaths,
    'epoch': epochs,
    'batch': batchs,
}
df = pd.DataFrame.from_dict(data)


In [None]:


# One static scatter (birth vs death, coloured by probe sample) per recorded round.
# Iterates over the rounds present in `df` rather than the notebook-global n_epochs.
# Axes are fixed so the images can be compared frame by frame.
for epoch, df_epoch in df.groupby('epoch'):
    fig = px.scatter(df_epoch, x="birth", y="death", color="batch",
                     range_x=[-100, 0], range_y=[-100, 0],
                     title=f"Diagram points, recording round {epoch}")
    fig.write_image('barcode_' + str(epoch) + '.png')

In [None]:
# Animated version of the per-round scatter plots: one frame per recording round,
# with fixed axes. (Saving the last per-round `fig` would give a single static frame.)
fig_video = px.scatter(df, x="birth", y="death", color="batch",
                       animation_frame="epoch",
                       range_x=[-100, 0], range_y=[-100, 0])
fig_video.write_html("video_barcodes.html")

In [None]:
# Points of probe sample 0 only.
# NOTE(review): df2 is not used by any visible cell — confirm intent or remove.
df2 = df[df['batch'] == 0]

In [None]:
# NOTE(review): `fig` here is whatever figure was assigned last (the final per-round
# scatter from the image loop), so this re-saves that single frame; if a plot of
# df2 was intended, it has to be built first.
fig.write_image('barcodes.png')