<a href="https://colab.research.google.com/github/wpocl/topological-data-analysis/blob/main/projects/mnist/notebooks/topological_model1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install giotto-tda

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score
from gtda.homology import CubicalPersistence
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from google.colab import drive

In [3]:
plt.style.use("ggplot")
drive.mount("/content/drive")

Mounted at /content/drive


In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
train = pd.read_csv("/content/drive/MyDrive/mnist/train.csv")
test = pd.read_csv("/content/drive/MyDrive/mnist/test.csv")

In [6]:
train_images = train.drop(columns="label").values
train_labels = train["label"].values

In [7]:
# Normalise the train images
train_images = train_images / 255

In [8]:
# Reshape to image format
train_images = train_images.reshape(-1, 1, 28, 28)

In [9]:
cubical = CubicalPersistence(homology_dimensions=[0, 1])
train_diagrams = cubical.fit_transform(train_images)

In [10]:
train_diagrams.shape[0]

42000

In [11]:
shape = (42000, -1, 3)

# Create persistence diagrams for 0-th homology dimension
train_diagrams0 = train_diagrams[train_diagrams[:, :, 2] == 0]
train_diagrams0 = train_diagrams0.reshape(shape)[:, :, :2]

# Create persistence diagrams for 1-st homology dimension
train_diagrams1 = train_diagrams[train_diagrams[:, :, 2] == 1]
train_diagrams1 = train_diagrams1.reshape(shape)[:, :, :2]

In [12]:
#@title
import torch
import numpy as np
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch import Tensor

def prepare_batch(batch, point_dim):
    """
    This method 'vectorizes' the multiset in order to take advances of GPU
    processing. The policy is to embed all multisets in batch to the highest
    dimensionality occurring in batch, i.e., ``max(t.size()[0]`` for ``t`` in batch).

    Args:
        batch:
            The input batch to process as a list of tensors.

        point_dim:
            The dimension of the points the inputs consist of.

    Returns:
        A four-tuple consisting of (1) the constructed ``batch``, i.e., a
        tensor with size
        ``batch_size`` x ``n_max_points`` x ``point_dim``; (2) a tensor
        ``not_dummy`` of size ``batch_size`` x ``n_max_points``, where
        ``1`` at position (i,j) indicates if the point is a dummy point,
        whereas ``0`` indicates a dummy point used for padding; (3)
        the max. number of points and (4) the batch size.

    Example::

        >>> from torchph.nn.slayer import prepare_batch
        >>> import torch
        >>> x = [torch.rand(10,2), torch.rand(20,2)]
        >>> batch, not_dummy, max_pts, batch_size = prepare_batch(x)
    """
    if point_dim is None:
        point_dim = batch[0].size(1)
    assert (all(x.size(1) == point_dim for x in batch if len(x) != 0))

    batch_size = len(batch)
    batch_max_points = max([t.size(0) for t in batch])
    input_device = batch[0].device

    if batch_max_points == 0:
        # if we are here, batch consists only of empty diagrams.
        batch_max_points = 1

    # This will later be used to set the dummy points to zero in the output.
    not_dummy_points = torch.zeros(
        batch_size,
        batch_max_points,
        device=input_device)

    prepared_batch = []

    for i, multi_set in enumerate(batch):
        n_points = multi_set.size(0)

        prepared_dgm = torch.zeros(
            batch_max_points,
            point_dim,
            device=input_device)

        if n_points > 0:
            index_selection = torch.tensor(range(n_points),
                                           device=input_device)

            prepared_dgm.index_add_(0, index_selection, multi_set)

            not_dummy_points[i, :n_points] = 1

        prepared_batch.append(prepared_dgm)

    prepared_batch = torch.stack(prepared_batch)

    return prepared_batch, not_dummy_points, batch_max_points, batch_size


def is_prepared_batch(input):
    if not (isinstance(input, tuple) and len(input) == 4):
        return False
    else:
        batch, not_dummy_points, max_points, batch_size = input
        return isinstance(batch, Tensor) and isinstance(not_dummy_points, Tensor) and max_points > 0 and batch_size > 0


def is_list_of_tensors(input):
    try:
        return all([isinstance(x, Tensor) for x in input])

    except TypeError:
        return False


def prepare_batch_if_necessary(input, point_dimension=None):
    batch, not_dummy_points, max_points, batch_size = None, None, None, None

    if is_prepared_batch(input):
        batch, not_dummy_points, max_points, batch_size = input
    elif is_list_of_tensors(input):
        if point_dimension is None:
            point_dimension = input[0].size(1)

        batch, not_dummy_points, max_points, batch_size = prepare_batch(
            input,
            point_dimension)

    else:
        raise ValueError(
            'SLayer does not recognize input format! Expecting [Tensor] or \
             prepared batch. Not {}'.format(input))

    return batch, not_dummy_points, max_points, batch_size


def parameter_init_from_arg(arg, size, default, scalar_is_valid=False):
    if isinstance(arg, (int, float)):
        if not scalar_is_valid:
            raise ValueError('Scalar initialization values are not valid. \
                              Got {} expected Tensor of size {}.'
                             .format(arg, size))
        return torch.Tensor(*size).fill_(arg)
    elif isinstance(arg, torch.Tensor):
        assert(arg.size() == size)
        return arg
    elif arg is None:
        if default in [torch.rand, torch.randn, torch.ones, torch.ones_like]:
            return default(*size)
        else:
            return default(size)
    else:
        raise ValueError('Cannot handle parameter initialization. \
                          Got "{}" '.format(arg))


class SLayerExponential(Module):
    """
    Proposed input layer for multisets [1].
    """
    def __init__(self, n_elements: int,
                 point_dimension: int=2,
                 centers_init: Tensor=None,
                 sharpness_init: Tensor=None):
        """
        Args:
            n_elements:
                Number of structure elements used.

            point_dimension: D
                Dimensionality of the points of which the
                input multi set consists of.

            centers_init:
                The initialization for the centers of the structure elements.

            sharpness_init:
                Initialization for the sharpness of the structure elements.
        """
        super().__init__()

        self.n_elements = n_elements
        self.point_dimension = point_dimension

        expected_init_size = (self.n_elements, self.point_dimension)

        centers_init = parameter_init_from_arg(
            centers_init,
            expected_init_size,
            torch.rand, scalar_is_valid=False)
        sharpness_init = parameter_init_from_arg(
            sharpness_init,
            expected_init_size,
            lambda size: torch.ones(*size)*3)

        self.centers = Parameter(centers_init)
        self.sharpness = Parameter(sharpness_init)

    def forward(self, input)->Tensor:
        batch, not_dummy_points, max_points, batch_size = prepare_batch_if_necessary(
            input,
            point_dimension=self.point_dimension)

        batch = torch.cat([batch] * self.n_elements, 1)

        not_dummy_points = torch.cat([not_dummy_points] * self.n_elements, 1)

        centers = torch.cat([self.centers] * max_points, 1)
        centers = centers.view(-1, self.point_dimension)
        centers = torch.stack([centers] * batch_size, 0)

        sharpness = torch.pow(self.sharpness, 2)
        sharpness = torch.cat([sharpness] * max_points, 1)
        sharpness = sharpness.view(-1, self.point_dimension)
        sharpness = torch.stack([sharpness] * batch_size, 0)

        x = centers - batch
        x = x.pow(2)
        x = torch.mul(x, sharpness)
        x = torch.sum(x, 2)
        x = torch.exp(-x)
        x = torch.mul(x, not_dummy_points)
        x = x.view(batch_size, self.n_elements, -1)
        x = torch.sum(x, 2)
        x = x.squeeze()

        return x

    def __repr__(self):
        return 'SLayerExponential (... -> {} )'.format(self.n_elements)

In [13]:
#@title
class TopologicalModel(nn.Module):
  def __init__(self):
    super(TopologicalModel, self).__init__()

    # Image input layers
    self.image_conv1 = nn.Conv2d(1, 8, kernel_size=5, padding=2)
    self.image_relu1 = nn.ReLU()
    self.image_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
    self.image_dropout1 = nn.Dropout(0.25)

    self.image_conv2 = nn.Conv2d(8, 16, kernel_size=3,padding=1)
    self.image_relu2 = nn.ReLU()
    self.image_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
    self.image_dropout2 = nn.Dropout(0.25)

    self.image_flatten = nn.Flatten()

    # diagram0 input layers
    self.diagram0_slayer = SLayerExponential(32, 2)

    self.diagram0_dropout1 = nn.Dropout(0.2)
    self.diagram0_norm1 = nn.BatchNorm1d(32)
    self.diagram0_fc1 = nn.Linear(32, 64)
    self.diagram0_relu1 = nn.ReLU()

    self.diagram0_dropout2 = nn.Dropout(0.1)
    self.diagram0_norm2 = nn.BatchNorm1d(64)
    self.diagram0_fc2 = nn.Linear(64, 64)
    self.diagram0_relu2 = nn.ReLU()

    self.diagram0_dropout3 = nn.Dropout(0.1)
    self.diagram0_norm3 = nn.BatchNorm1d(64)
    self.diagram0_fc3 = nn.Linear(64, 32)
    self.diagram0_relu3 = nn.ReLU()

    # diagram1 input layers
    self.diagram1_slayer = SLayerExponential(32, 2)

    self.diagram1_dropout1 = nn.Dropout(0.2)
    self.diagram1_norm1 = nn.BatchNorm1d(32)
    self.diagram1_fc1 = nn.Linear(32, 64)
    self.diagram1_relu1 = nn.ReLU()

    self.diagram1_dropout2 = nn.Dropout(0.1)
    self.diagram1_norm2 = nn.BatchNorm1d(64)
    self.diagram1_fc2 = nn.Linear(64, 64)
    self.diagram1_relu2 = nn.ReLU()

    self.diagram1_dropout3 = nn.Dropout(0.1)
    self.diagram1_norm3 = nn.BatchNorm1d(64)
    self.diagram1_fc3 = nn.Linear(64, 32)
    self.diagram1_relu3 = nn.ReLU()

    # Concatenated input layers
    self.fc1 = nn.Linear(16 * 7 * 7 + 32 + 32, 512)
    self.relu1 = nn.ReLU()
    self.dropout1 = nn.Dropout(0.5)

    self.fc2 = nn.Linear(512, 10)
    self.softmax = nn.Softmax(dim=1)

  def forward(self, image_input, diagram0_input, diagram1_input):

    # Image input layers
    x_image = self.image_conv1(image_input)
    x_image = self.image_relu1(x_image)
    x_image = self.image_pool1(x_image)
    x_image = self.image_dropout1(x_image)

    x_image = self.image_conv2(x_image)
    x_image = self.image_relu2(x_image)
    x_image = self.image_pool2(x_image)
    x_image = self.image_dropout2(x_image)

    x_image = self.image_flatten(x_image)

    # diagram0 input layers
    x_diagram0 = self.diagram0_slayer(diagram0_input)
    x_diagram0 = self.diagram0_dropout1(x_diagram0)
    x_diagram0 = self.diagram0_norm1(x_diagram0)
    x_diagram0 = self.diagram0_fc1(x_diagram0)
    x_diagram0 = self.diagram0_relu1(x_diagram0)

    x_diagram0 = self.diagram0_dropout2(x_diagram0)
    x_diagram0 = self.diagram0_norm2(x_diagram0)
    x_diagram0 = self.diagram0_fc2(x_diagram0)
    x_diagram0 = self.diagram0_relu2(x_diagram0)

    x_diagram0 = self.diagram0_dropout3(x_diagram0)
    x_diagram0 = self.diagram0_norm3(x_diagram0)
    x_diagram0 = self.diagram0_fc3(x_diagram0)
    x_diagram0 = self.diagram0_relu3(x_diagram0)

    # diagram1 input layers
    x_diagram1 = self.diagram1_slayer(diagram1_input)
    x_diagram1 = self.diagram1_dropout1(x_diagram1)
    x_diagram1 = self.diagram1_norm1(x_diagram1)
    x_diagram1 = self.diagram1_fc1(x_diagram1)
    x_diagram1 = self.diagram1_relu1(x_diagram1)

    x_diagram1 = self.diagram1_dropout2(x_diagram1)
    x_diagram1 = self.diagram1_norm2(x_diagram1)
    x_diagram1 = self.diagram1_fc2(x_diagram1)
    x_diagram1 = self.diagram1_relu2(x_diagram1)

    x_diagram1 = self.diagram1_dropout3(x_diagram1)
    x_diagram1 = self.diagram1_norm3(x_diagram1)
    x_diagram1 = self.diagram1_fc3(x_diagram1)
    x_diagram1 = self.diagram1_relu3(x_diagram1)

    # Combine inputs
    combined_input = torch.cat((x_image, x_diagram0, x_diagram1), dim=1)

    # Concatenated input layers
    x = self.fc1(combined_input)
    x = self.relu1(x)
    x = self.dropout1(x)

    # Output layer
    output = self.fc2(x)
    output = self.softmax(output)

    return output

In [14]:
# Convert inputs and outputs to tensors on GPU
train_images = torch.Tensor(train_images).to(device)
train_diagrams0 = torch.Tensor(train_diagrams0).to(device)
train_diagrams1 = torch.Tensor(train_diagrams1).to(device)

train_labels_oh = torch.eye(10)[train_labels].to(device)

In [15]:
# Create model and put on GPU
model = TopologicalModel()
model = model.to(device)

In [16]:
learning_rate = 0.001
num_epochs = 200
batch_size = 500

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [17]:
num_samples = train.shape[0]

In [None]:
for epoch in range(num_epochs):

    correct = 0
    total = 0

    for start_index in range(0, num_samples, batch_size):
        end_index = max(start_index + batch_size, num_samples)

        # Create batched inputs
        images_batch = train_images[start_index: end_index]
        diagrams0_batch = train_diagrams0[start_index: end_index]
        diagrams1_batch = train_diagrams1[start_index: end_index]

        # Create batched output
        labels_batch = train_labels_oh[start_index: end_index]

        # Forward pass
        output_batch = model(images_batch, diagrams0_batch, diagrams1_batch)
        loss = criterion(output_batch, labels_batch)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Calculate accuracy
        _, predicted = torch.max(output_batch, 1)
        predicted = predicted.cpu().numpy()
        total += labels_batch.size(0)
        correct += (predicted == train_labels[start_index: end_index]).sum().item()

    # Calculate accuracy on the training set
    accuracy = correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Accuracy: {accuracy:.4f}")

Epoch [1/200], Train Accuracy: 0.6483
Epoch [2/200], Train Accuracy: 0.8525
Epoch [3/200], Train Accuracy: 0.9330
Epoch [4/200], Train Accuracy: 0.9673
Epoch [5/200], Train Accuracy: 0.9748
Epoch [6/200], Train Accuracy: 0.9791


In [None]:
# model = model.to("cpu")
# train_images = train_images.to("cpu")
# train_diagrams0 = train_diagrams0.to("cpu")
# train_diagrams1 = train_diagrams1.to("cpu")

In [None]:
output = model(train_images, train_diagrams0, train_diagrams1)
train_labels_ = torch.argmax(output, dim=1).cpu().numpy()

accuracy = accuracy_score(train_labels_, train_labels)

print(f"Training accuracy: {accuracy}")

In [None]:
disp = ConfusionMatrixDisplay(confusion_matrix(train_labels, train_labels_))
disp.plot()