## Deep Learning with Topological Signatures

https://arxiv.org/pdf/1707.04041.pdf

### Evaluation on the Fashion-MNIST Dataset

This notebook follows the same experimental methodology as the original work, with a few differences adapted to this specific problem.

Ref: https://github.com/c-hofer/jmlr_2019


![image.png](attachment:image.png)

In [None]:
import torch
torch.manual_seed(123)
import random
random.seed(123)


import torch.nn as nn
import os
import shutil
import itertools


import core.config as config
from chofer_tda_datasets import Animal
from chofer_tda_datasets.transforms import Hdf5GroupToDict
from core.utils import *

from torchph.nn.slayer import SLayerExponential, \
SLayerRational, \
LinearRationalStretchedBirthLifeTimeCoordinateTransform, \
prepare_batch, SLayerRationalHat
from sklearn.model_selection import ShuffleSplit
from collections import Counter, defaultdict
from torch.utils.data import DataLoader, RandomSampler
from collections import OrderedDict
from torch.autograd import Variable

from sklearn.model_selection import StratifiedShuffleSplit

%matplotlib notebook

# Make only the GPU with index 1 visible to CUDA in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'


class train_env:
    """Hyper-parameters of the training run, grouped as class attributes."""

    # Optimisation schedule.
    n_epochs = 40
    batch_size = 100
    lr_initial = 0.01
    lr_epoch_step = 10  # the learning rate is adapted every `lr_epoch_step` epochs
    momentum = 0.9

    # Handed to the collate function when it is constructed.
    nu = 0.01


# Every second direction index in 0..31: 'dim_0_dir_0', 'dim_0_dir_2', ..., 'dim_0_dir_30'.
used_directions = list(map('dim_0_dir_{}'.format, range(0, 32, 2)))


In [None]:
import pickle

# Pre-computed train/test datasets. Judging by the file name the pickle holds
# persistent homology for 32 directions per sample -- TODO confirm against the
# script that produced it.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load this file if it comes from a trusted source.
with open("fmnist_dir_32.pckl", "rb") as dataset_file:
    # The context manager closes the file handle once loading has finished.
    train_dataset, test_dataset = pickle.load(dataset_file)

In [None]:
class PHTCollate:
    """Collate callable that turns (sample, target) pairs into one batch.

    Samples are indexed by the keys listed in ``used_directions``. The
    per-key collections are prepared with ``prepare_batch`` and, when
    requested, moved to the GPU together with the targets.

    Args:
        nu: Not used by the batching logic; stored on the instance so the
            value supplied by the caller is not silently dropped.
        cuda: If true, move the prepared batch and the targets to the GPU.
        rotation_augmentation: If true, cyclically shift the direction keys
            of every sample by a random offset before batching.
    """

    def __init__(self, nu, cuda=True, rotation_augmentation=False):
        self.nu = nu
        self.cuda = cuda
        self.rotation_augmentation = rotation_augmentation

    def __call__(self, sample_target_iter):
        """Build the batch.

        Args:
            sample_target_iter: Iterable of ``(sample, target)`` pairs.

        Returns:
            Tuple ``(x, y)``: ``x`` maps each key to its prepared batch and
            ``y`` is a ``torch.LongTensor`` built from the targets.
        """
        if self.rotation_augmentation:
            samples, targets = [], []
            for x, y in sample_target_iter:
                # Random cyclic shift: key k receives the data stored under
                # the key that sits `i` positions further in used_directions.
                i = random.randint(0, len(used_directions) - 1)
                shifted_keys = used_directions[i:] + used_directions[:i]

                samples.append({k: x[ki] for k, ki in zip(used_directions, shifted_keys)})
                targets.append(y)

            sample_target_iter = zip(samples, targets)

        x, y = dict_sample_target_iter_concat(sample_target_iter)

        for k in x.keys():
            x[k] = prepare_batch(x[k], 2)

        y = torch.LongTensor(y)

        if self.cuda:
            # Only the first two entries of each prepared tuple are moved to
            # the GPU; the remaining two are passed through unchanged.
            x = {k: collection_cascade(v,
                                       lambda item: isinstance(item, tuple),
                                       lambda item: (item[0].cuda(), item[1].cuda(), item[2], item[3]))
                 for k, v in x.items()}

            y = y.cuda()

        return x, y
    
# Shared collate callable for the data loaders: batches are moved to the GPU.
collate_fn = PHTCollate(nu=train_env.nu, cuda=True)

In [None]:
def Slayer(n_elements):
    """Create the structure layer used per direction.

    Returns an ``SLayerRationalHat`` with ``n_elements`` elements, an initial
    radius of 0.25 and exponent 1.
    """
    layer = SLayerRationalHat(n_elements, radius_init=0.25, exponent=1)
    return layer

def LinearCell(n_in, n_out):
    """Build a Linear -> BatchNorm1d -> ReLU block.

    The returned ``nn.Sequential`` additionally carries an ``out_features``
    attribute copied from its linear layer.
    """
    linear = nn.Linear(n_in, n_out)
    cell = nn.Sequential(linear, nn.BatchNorm1d(n_out), nn.ReLU())
    cell.out_features = linear.out_features
    return cell

# Same architecture as the AnimalModel used for the Animal Shapes experiment,
# kept identical for a fair comparison.
class FMnistModel(nn.Module):
    """Classifier with one structure layer per used direction and an MLP head.

    The outputs of all structure layers are concatenated along dim 1 and fed
    through ``self.cls``, whose last layer emits 20 values per sample.
    NOTE(review): Fashion-MNIST is usually treated as a 10-class problem; the
    20 presumably comes from the Animal model -- confirm it is intended.
    """

    def __init__(self):
        super().__init__()
        self.n_elements = 100

        # One structure layer per direction, each wrapped in a Sequential so
        # that it can be reached as element [0] (see center_init).
        self.slayers = ModuleDict()
        for direction in used_directions:
            self.slayers[direction] = nn.Sequential(Slayer(self.n_elements))

        # The head narrows the concatenated features by a factor of 4 twice.
        cls_in_dim = len(used_directions) * self.n_elements
        hidden_1 = int(cls_in_dim / 4)
        hidden_2 = int(cls_in_dim / 16)
        self.cls = nn.Sequential(
            nn.Dropout(0.3),
            LinearCell(cls_in_dim, hidden_1),
            nn.Dropout(0.2),
            LinearCell(hidden_1, hidden_2),
            nn.Dropout(0.1),
            nn.Linear(hidden_2, 20),
        )

    def forward(self, input):
        """Apply each direction's structure layer to ``input[k]`` and classify."""
        per_direction = [self.slayers[k](input[k]) for k in used_directions]
        features = torch.cat(per_direction, dim=1)
        return self.cls(features)

    def center_init(self, sample_target_iter):
        """Overwrite the structure-layer centers with k-means initial values.

        ``k_means_center_init`` returns a mapping from direction key to the
        center values for that direction's layer.
        """
        initial_centers = k_means_center_init(sample_target_iter, self.n_elements)

        for direction, center_values in initial_centers.items():
            slayer = self.slayers._modules[direction][0]
            slayer.centers.data = center_values

In [None]:
# Training loader: batches are drawn in random order and built by collate_fn.
dl_train = DataLoader(
    train_dataset,
    sampler=RandomSampler(train_dataset),
    batch_size=train_env.batch_size,
    collate_fn=collate_fn,
)

# Test loader: configured the same way, including the random sampling order.
dl_test = DataLoader(
    test_dataset,
    sampler=RandomSampler(test_dataset),
    batch_size=train_env.batch_size,
    collate_fn=collate_fn,
)

In [None]:
# When False, only the classifier head (model.cls) is handed to the optimizer.
train_slayer = True

model = FMnistModel()
# Initialise the structure-layer centers from the complete training set.
model.center_init(list(train_dataset))
model.cuda()

if train_slayer:
    trainable_params = model.parameters()
else:
    trainable_params = model.cls.parameters()

opt = torch.optim.SGD(
    trainable_params,
    lr=train_env.lr_initial,
    momentum=train_env.momentum,
)



In [None]:
import numpy as np

# One statistics dict per run; the loop below performs a single run.
stats_of_runs = []

stats = defaultdict(list)
stats_of_runs.append(stats)

for i_epoch in range(1, train_env.n_epochs + 1):

    # ---- training pass ----
    model.train()
    epoch_loss = 0

    # Every `lr_epoch_step` epochs, apply lr -> lr * 0.5 through adapt_lr.
    if i_epoch % train_env.lr_epoch_step == 0:
        adapt_lr(opt, lambda lr: lr * 0.5)

    for i_batch, (x, y) in enumerate(dl_train, 1):

        # The optimizer calls the closure to compute the loss and gradients;
        # it is invoked right away, so binding x and y by name is safe.
        def closure():
            opt.zero_grad()
            y_hat = model(x)
            loss = nn.functional.cross_entropy(y_hat, y)
            loss.backward()
            return loss

        loss = opt.step(closure)

        epoch_loss += float(loss)
        stats['loss_by_batch'].append(float(loss))
        # Snapshot of the centers of one direction after every batch; these
        # snapshots are what the plotting cell draws as trajectories.
        stats['centers'].append(model.slayers['dim_0_dir_0'][0].centers.data.cpu().numpy())

        print("Epoch {}/{}, Batch {}/{}, Loss {}".format(i_epoch, train_env.n_epochs, i_batch, len(dl_train),float(loss)), end="       \r")

    stats['train_loss_by_epoch'].append(epoch_loss/len(dl_train))

    # ---- evaluation pass ----
    model.eval()
    true_samples = 0
    seen_samples = 0
    epoch_test_loss = 0

    # No gradients are needed for evaluation, so no autograd graph is built.
    with torch.no_grad():
        for x, y in dl_test:
            y_hat = model(x)
            epoch_test_loss += float(nn.functional.cross_entropy(y_hat, y))

            # Index of the largest output per sample is the predicted class.
            predictions = y_hat.max(dim=1)[1]

            true_samples += int((predictions == y).sum())
            seen_samples += y.size(0)

    test_acc = true_samples / seen_samples
    stats['test_accuracy'].append(test_acc)
    stats['test_loss_by_epoch'].append(epoch_test_loss/len(dl_test))
    # The four placeholders are epoch, total epochs, mean test loss and
    # accuracy; `end` belongs to print(), where the trailing blanks overwrite
    # what is left of the longer "\r" progress line.
    print("Epoch {}/{}, Test Loss {} Test Acc. {}".format(i_epoch, train_env.n_epochs, epoch_test_loss/len(dl_test), test_acc), end="       \n")


# Mean test accuracy over the last (up to) 10 epochs.
print('acc.', np.mean(stats['test_accuracy'][-10:]))



In [None]:
# NOTE(review): `experiment` is not defined in the visible cells; presumably it
# returns one statistics dict per run -- confirm where it comes from.
res_learned_slayer = experiment(True)

# Per run: mean test accuracy over the last (up to) 10 recorded epochs.
accs = [np.mean(run_stats['test_accuracy'][-10:]) for run_stats in res_learned_slayer]

# Per-run accuracies, followed by their mean and standard deviation.
print(accs, np.mean(accs), np.std(accs), sep='\n')

In [None]:
# Visualise the most recent run: center trajectories, then loss/accuracy curves.
# NOTE(review): `plt` is not imported in the visible cells; presumably it is
# provided by `from core.utils import *` -- confirm.
stats = res_learned_slayer[-1]
plt.figure()

if 'centers' in stats:
    # First and last recorded snapshot of the centers.
    c_start = stats['centers'][0]
    c_end = stats['centers'][-1]

    plt.plot(c_start[:,0], c_start[:, 1], 'bo', label='center initialization')
    plt.plot(c_end[:,0], c_end[:, 1], 'ro', label='center learned')

    # Stack all snapshots along a new leading axis and draw the path of each
    # center (second axis). `np` is the alias under which this file imports
    # numpy.
    all_centers = np.stack(stats['centers'], axis=0)
    for i in range(all_centers.shape[1]):
        points = all_centers[:,i, :]
        plt.plot(points[:, 0], points[:, 1], '-k', alpha=0.25)

    plt.legend()

# Per-epoch train loss, test loss and test accuracy on a shared axis.
plt.figure()
plt.plot(stats['train_loss_by_epoch'], label='train_loss')
plt.plot(stats['test_loss_by_epoch'], label='test_loss')
plt.plot(stats['test_accuracy'], label='test_accuracy')


plt.legend()
plt.show()
