# Toplogy of Deep Neural Networks

This notebook will show you how easy it is to use gdeep to reproduce the experiments of the paper *Topology of Deep Neural Networks*, by Naizat et. al.

In [1]:
%reload_ext autoreload
%autoreload 2

# deep learning
import torch
from torch.optim import Adam, SGD
import numpy as np
from torch import nn
from torch import autograd  

#gdeep
from gdeep.data.datasets import DatasetBuilder, DataLoaderBuilder
from gdeep.models import FFNet
from gdeep.visualisation import persistence_diagrams_of_activations
from gdeep.data.preprocessors import ToTensorImage
from gdeep.trainer import Trainer


# plot
import plotly.express as px
import pandas as pd
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

# ML
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import make_blobs
from sklearn.metrics import pairwise_distances

# TDA
from gtda.homology import VietorisRipsPersistence
from gtda.plotting import plot_diagram


No TPUs...


# Initialize the tensorboard writer

In order to analyse the reuslts of your models, you need to start tensorboard.
On the terminal, move inside the `/example` folder. There run the following command:

```
tensorboard --logdir=runs
```

Then go [here](http://localhost:6006/) after the training to see all the visualisation results.


In [28]:
db = DatasetBuilder(name="EntangledTori")
ds_tr, ds_val, ds_ts = db.build()
dl_tr, dl_val, dl_ts = DataLoaderBuilder((ds_tr, ds_val, ds_ts)).build()

# Plot dataset to tensorboard

In [29]:
data_tensor = torch.cat([batch for batch, _ in dl_tr])
label_tensor = torch.cat([batch for _ ,batch in dl_tr])
data_tensor.shape, label_tensor.shape

(torch.Size([1600, 3]), torch.Size([1600]))

In [26]:
# train NN
model = FFNet(arch=[3,10,10,10,10,2])
print(model)
pipe = Trainer(model, (dl_tr, dl_ts), nn.CrossEntropyLoss(), writer)
pipe.train(Adam, 100, False, {"lr":0.01}, {"batch_size":50})

FFNet(
  (linears): ModuleList(
    (0): Linear(in_features=3, out_features=10, bias=True)
    (1): Linear(in_features=10, out_features=10, bias=True)
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): Linear(in_features=10, out_features=10, bias=True)
    (4): Linear(in_features=10, out_features=2, bias=True)
  )
)
Epoch 1
-------------------------------
Epoch training loss: 0.682101 	Epoch training accuracy: 54.28%                                                           
Time taken for this epoch: 0.00s
Learning rate value: 0.01000000
Validation results: 
 accuracy: 58.29%,                 Avg loss: 0.664563 

Epoch 2
-------------------------------
Batch training loss:  0.6711804866790771  	Batch training accuracy:  52.0  	[ 8 / 26 ]                                  



Epoch training loss: 0.647352 	Epoch training accuracy: 56.77%                                                           
Time taken for this epoch: 0.00s
Learning rate value: 0.01000000
Validation results: 
 accuracy: 64.00%,                 Avg loss: 0.608776 

Epoch 3
-------------------------------
Epoch training loss: 0.605375 	Epoch training accuracy: 61.72%                                                           
Time taken for this epoch: 0.00s
Learning rate value: 0.01000000
Validation results: 
 accuracy: 58.00%,                 Avg loss: 0.588700 

Epoch 4
-------------------------------
Epoch training loss: 0.577258 	Epoch training accuracy: 61.92%                                                          
Time taken for this epoch: 0.00s
Learning rate value: 0.01000000
Validation results: 
 accuracy: 67.14%,                 Avg loss: 0.538880 

Epoch 5
-------------------------------
Epoch training loss: 0.551721 	Epoch training accuracy: 63.36%                           

(0.36779201882226126, 81.71428571428571)

In [25]:
from gdeep.analysis.interpretability import Interpreter
from gdeep.visualisation import Visualiser

vs = Visualiser(pipe)
data_tensor = torch.cat([batch for batch, _ in dl_tr])

# the diagrams can be seen on tensorboard!
vs.plot_persistence_diagrams(data_tensor)


ValueError: too many values to unpack (expected 2)