In [1]:
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Resize, Normalize

from torch.optim import Adam
from torch.nn import CrossEntropyLoss

import gudhi as gd
import torch.nn as nn

torch.set_printoptions(precision=2, sci_mode=False, linewidth=100)

In [2]:
def persistence_diagram(image):

    h, w = image.shape
    img_flat = image.flatten()

    ccomplex = gd.CubicalComplex(
        dimensions = (h, w), 
        top_dimensional_cells=img_flat
    )
    
    # get pairs of critical simplices
    ccomplex.compute_persistence()
    critical_pairs = ccomplex.cofaces_of_persistence_pairs()

    # get essential critical pixels (never vanish)
    essential_features = critical_pairs[1][0]

    # 0-homology image critical pixels
    try:
        critical_pairs_0 = critical_pairs[0][0]
    except:
        critical_pairs_0 = np.empty((0, 2))
    critical_0_ver_ind = critical_pairs_0 // w
    critical_0_hor_ind = critical_pairs_0 % w
    critical_pixels_0 = np.stack([critical_0_ver_ind, critical_0_hor_ind], axis=2)

    # 0-homology essential pixels (ends with last added pixel)
    last_pixel = torch.argmax(image).item()
    essential_pixels_0 = np.array([[essential_features[0] // w, essential_features[0] % w], [last_pixel // w, last_pixel % 4]])[np.newaxis, ...]
    critical_pixels_0 = np.vstack([critical_pixels_0, essential_pixels_0])

    # 0-homology persistance diagram
    pd0 = image[critical_pixels_0[:, :, 0].flatten(), critical_pixels_0[:, :, 1].flatten()].reshape((critical_pixels_0.shape[0], 2))

    # 1-homology image critical pixels
    try:
        critical_pairs_1 = critical_pairs[0][1]
    except:
        critical_pairs_1 = np.empty((0, 2))
    critical_1_ver_ind = critical_pairs_1 // w
    critical_1_hor_ind = critical_pairs_1 % w
    critical_pixels_1 = np.stack([critical_1_ver_ind, critical_1_hor_ind], axis=2)

    # 1-homology persistance diagram
    pd1 = image[critical_pixels_1[:, :, 0].flatten(), critical_pixels_1[:, :, 1].flatten()].reshape((critical_pixels_1.shape[0], 2))

    return pd0, pd1

In [3]:
def diagram(image, device, sublevel=True):
    # get height and square image
    h = int(np.sqrt(image.shape[0]))
    image_sq = image.reshape((h,h))

    # create complex
    cmplx = gd.CubicalComplex(dimensions=(h, h), top_dimensional_cells=(sublevel*image))

    # get pairs of critical simplices
    cmplx.compute_persistence()
    critical_pairs = cmplx.cofaces_of_persistence_pairs()
    
    # get essential critical pixel
    bpx0_essential = critical_pairs[1][0][0] // h, critical_pairs[1][0][0] % h

    # get critical pixels corresponding to critical simplices
    try:
        bpx0 = [critical_pairs[0][0][i][0] for i in range(len(critical_pairs[0][0]))]
        dpx0 = [critical_pairs[0][0][i][1] for i in range(len(critical_pairs[0][0]))]
    except IndexError:
        bpx0 = []
        dpx0 = []
        
    try:
        bpx1 = [critical_pairs[0][1][i][0] for i in range(len(critical_pairs[0][1]))]
        dpx1 = [critical_pairs[0][1][i][1] for i in range(len(critical_pairs[0][1]))]
    except IndexError:
        bpx1 = []
        dpx1 = []
    

    flat_image = image_sq.flatten()
    pd0_essential = torch.tensor([[image_sq[bpx0_essential], torch.max(image)]])

    if (len(bpx0)!=0):
        pdb0 = flat_image[bpx0][:, None]
        pdd0 = flat_image[dpx0][:, None]
        pd0 = torch.Tensor(torch.hstack([pdb0, pdd0]))
        pd0 = torch.vstack([pd0, pd0_essential.to(device)])
    else:
        pd0 = pd0_essential

    if (len(bpx1)!=0):
        pdb1 = flat_image[bpx1][:, None]
        pdd1 = flat_image[dpx1][:, None]
        pd1 = torch.Tensor(torch.hstack([pdb1, pdd1]))
    else:
        pd1 = torch.zeros((1, 2))
    
    return pd0, pd1

### Data

In [4]:
transform = Compose([Resize((14, 14)), ToTensor(), Normalize(0.0, 1.0)])

In [5]:
data_train = MNIST("./data", train=False, download=True, transform=transform)
dataloader_train = DataLoader(data_train, batch_size=2, shuffle=True)

In [6]:
X, y = next(iter(dataloader_train))
X

tensor([[[[0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
          [0.00, 0.00, 0.00, 0.00, 0.00, 0.04, 0.09, 0.01, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
          [0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.69, 0.16, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
          [0.00, 0.00, 0.00, 0.00, 0.00, 0.28, 0.90, 0.49, 0.01, 0.00, 0.00, 0.00, 0.00, 0.00],
          [0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.91, 0.72, 0.08, 0.00, 0.00, 0.00, 0.00, 0.00],
          [0.00, 0.00, 0.00, 0.00, 0.00, 0.18, 0.84, 0.86, 0.22, 0.00, 0.00, 0.00, 0.00, 0.00],
          [0.00, 0.00, 0.00, 0.00, 0.00, 0.04, 0.59, 0.98, 0.47, 0.00, 0.00, 0.00, 0.00, 0.00],
          [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.34, 0.94, 0.59, 0.03, 0.00, 0.00, 0.00, 0.00],
          [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.24, 0.90, 0.73, 0.08, 0.00, 0.00, 0.00, 0.00],
          [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.15, 0.80, 0.80, 0.15, 0.00, 0.00, 0.00, 0.00],
          [0.00, 0.00, 0.00, 0.00, 0.00,

### Differentiability of persistent homology

In [7]:
class Img2PD(torch.nn.Module):
    
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3))
        
    def forward(self, X):
        X_conv = self.conv(X)
        
        peristence_diagrams = []
        for i, x_conv in enumerate(X_conv):
            pd = persistence_diagram(x_conv[0])
            peristence_diagrams.append(pd)
        
        return peristence_diagrams

In [8]:
model = Img2PD()
model(X)

[(tensor([[-0.52, -0.51],
          [-0.54, -0.25]], grad_fn=<ReshapeAliasBackward0>),
  tensor([[ 0.02,  0.08],
          [ 0.02,  0.09],
          [-0.05,  0.11]], grad_fn=<ReshapeAliasBackward0>)),
 (tensor([[-0.39, -0.38],
          [-0.30, -0.30],
          [-0.29, -0.27],
          [-0.43, -0.26],
          [-0.34, -0.25],
          [-0.44, -0.03]], grad_fn=<ReshapeAliasBackward0>),
  tensor([[-0.25, -0.25],
          [-0.25, -0.25],
          [-0.17, -0.15],
          [-0.10, -0.10],
          [-0.19, -0.09],
          [-0.13, -0.07],
          [-0.18, -0.06],
          [-0.12, -0.04],
          [-0.22, -0.03]], grad_fn=<ReshapeAliasBackward0>))]

### Model

- any differentiable image transform (convolution, directional filter, etc.),
- sublevel filtration persistent homology layer, mapping an image to a set of persistent diagrams of dimensions 0 and 1,
- any layer taking a set as an input -- DeepSets, Transformer, etc.,
- an aggregation layer, aggregating a transformed persistent diagrams into a vector.

In [9]:
class ConvDiagram(nn.Module):
    def __init__(self, device):
        super(ConvDiagram, self).__init__()
        self.device = device
        
    def forward(self, x):
        diagrams = []
        for i in range(x.shape[0]):
            res = diagram(x[i].flatten(), self.device)
            for j in range(len(res)):
                diagrams.append(torch.concatenate([res[j], torch.Tensor([[j, i] for _ in range(res[j].shape[0])]).to(self.device)], axis=1))
        diagrams = torch.concatenate(diagrams)
        return diagrams


class Transformer(torch.nn.Module):
    def __init__(self, n_in, n_hidden, n_out, seq_size=1024, nhead=2, num_layers=2, dim_feedforward=16):
        super(Transformer, self).__init__()
        self.embeddings = nn.Linear(n_in, n_hidden)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=n_hidden, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer=self.encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(seq_size, n_out)

    def forward(self, X):
        X = self.embeddings(X)
        X = self.transformer(X)
        X = X.mean(dim=-1)
        X = self.classifier(X)
        X = X.softmax(dim=-1)
        return X


class TopologicalConvTransformer(nn.Module):
    def __init__(self, n_in, n_conv, max_sequence, n_diag, n_hidden, n_out, nhead=2, num_layers=2, dim_feedforward=16, device='cuda'):
        super(TopologicalConvTransformer, self).__init__()
        
        self.max_sequence = max_sequence
        self.conv = nn.Conv2d(n_in, n_conv, 3)
        self.diagram = ConvDiagram(device)
        self.transformer = Transformer(n_diag, n_hidden, n_out, max_sequence, nhead, num_layers, dim_feedforward)

    def forward(self, xs):
        result = []
        for i in range(xs.shape[0]):
            x = xs[i][None, :, :] / 256
            x = self.conv(x)
            x = self.diagram(x)
            if x.shape[0] > self.max_sequence:
                x = x[:self.max_sequence]
            x = F.pad(x, (0, 0, 0, self.max_sequence - x.shape[0]), "constant", 0)
            x = self.transformer(x)
            result.append(x[None, :])
        result = torch.concatenate(result, axis=0)
        return result

In [10]:
kwargs = {"n_in": 1,
 "n_conv": 1,
 "max_sequence": 64,
 "n_diag": 4,
 "n_hidden": 32, "n_out": 10, "nhead": 2, "num_layers": 2, "dim_feedforward": 16, "device": "cpu"}

In [11]:
model = TopologicalConvTransformer(**kwargs)
model(X)[0]

tensor([0.09, 0.11, 0.11, 0.11, 0.10, 0.09, 0.09, 0.10, 0.10, 0.10], grad_fn=<SelectBackward0>)

In [12]:
%%time
n_repeats = 1
n_epochs = 20
batch_size = 100
lr = 0.001

history = np.zeros((n_repeats, n_epochs, 3))
criterion = CrossEntropyLoss()

for repeat_idx in range(n_repeats):
    
    # data init
    dataloader_train = DataLoader(data_train, batch_size=batch_size, shuffle=True)
    
    # model init
    model = TopologicalConvTransformer(**kwargs)
    optimizer = Adam(model.parameters(), lr=lr)
    
    print("{:3} {:6} {:6}".format(repeat_idx, "Loss", "Acc"))
    
    for epoch_idx in range(n_epochs):
        
        # train
        model.train()
        
        loss_epoch = []
        for batch in dataloader_train:
            loss_batch = criterion(model(batch[0]), batch[1])
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_epoch.append(loss_batch.detach())
        
        loss_epoch_mean = np.array(loss_epoch).mean()
        history[repeat_idx,epoch_idx,0] = loss_epoch_mean
        
        # test
        model.eval()
        
        correct = 0
        for batch in dataloader_train:
            y_hat = model(batch[0]).argmax(dim=1)
            correct += int((y_hat == batch[1]).sum())
        accuracy_train = correct / len(dataloader_train.dataset)
        history[repeat_idx,epoch_idx,1] = accuracy_train
        
        print("{:3} {:.4f} {:.4f}".format(epoch_idx, loss_epoch_mean, accuracy_train))
    print("\r")

  0 Loss   Acc   
  0 2.3009 0.1138
  1 2.2643 0.1863
  2 2.2337 0.1730
  3 2.2264 0.2261
  4 2.2210 0.2285
  5 2.2141 0.2314
  6 2.2086 0.2458
  7 2.2047 0.2532
  8 2.1976 0.2749
  9 2.1915 0.2750
 10 2.1850 0.2817
 11 2.1833 0.2799
 12 2.1783 0.2823
 13 2.1746 0.2770
 14 2.1714 0.2838
 15 2.1698 0.2761
 16 2.1684 0.2550
 17 2.1675 0.2792
 18 2.1654 0.2832
 19 2.1659 0.2836

CPU times: user 9min 48s, sys: 2min 45s, total: 12min 34s
Wall time: 10min 31s
