In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import torch
import gudhi as gd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import networkx as nx
import scipy.sparse
from nn_homology import nn_graph

# Seed PyTorch's global random number generator so that runs of this notebook
# start from the same torch RNG state. Only torch is seeded here; NumPy and
# Python's `random` module are not.
torch.manual_seed(2)

<torch._C.Generator at 0x7f48b63719f0>

In [2]:
# --- Persistence-diagram settings -------------------------------------------
# `card`, `hom` and `ml` are read as module-level globals inside `train`.
# Maximum number of points kept in each persistence diagram.
card = 20
# Homological dimension of the diagram that is computed.
hom = 0
# Maximum distance handed to the Rips construction.
ml = 10.0

# --- Optimisation settings ---------------------------------------------------
# Learning rate. NOTE(review): not referenced by the visible cells; the
# optimizer in the training cell is created with a literal lr=0.001.
lr = 0.1
# Weight of the topological term in the combined loss (1 - lbda goes to NLL).
lbda = 0.95
# Percentile used to derive the threshold passed to the parameter graph.
percentile = 95

In [3]:
def compute_persistence(st, card, hom_dim):
    """Collect up to ``card`` finite persistence pairs of one dimension.

    Parameters
    ----------
    st : object exposing ``persistence()``, ``persistence_pairs()`` and
        ``filtration(simplex)`` (a Gudhi simplex tree in this notebook).
    card : int
        Number of rows allocated in every returned array; pairs beyond this
        count are ignored.
    hom_dim : int
        Only pairs whose first simplex has ``hom_dim + 1`` vertices are kept.

    Returns
    -------
    list
        ``[dgm, spl, mask]`` where
        * ``dgm`` is ``(card, 2)`` float32: filtration value of the first and
          of the second simplex of each kept pair,
        * ``spl`` is ``(card, 2*hom_dim + 3)`` int32: the vertices of the first
          simplex followed by the vertices of the second simplex,
        * ``mask`` is ``(card,)`` int32 with 1 on rows that were filled.
        Unused rows stay zero.
    """
    # Persistence has to be computed before the pairs can be queried.
    st.persistence()
    all_pairs = st.persistence_pairs()

    diagram = np.zeros((card, 2), dtype=np.float32)
    simplices = np.zeros((card, 2 * hom_dim + 3), dtype=np.int32)
    used = np.zeros((card,), dtype=np.int32)

    split = hom_dim + 1  # number of vertices in a hom_dim-simplex
    row = 0
    for first, second in all_pairs:
        # Skip pairs of another dimension, pairs whose second simplex is empty
        # (presumably classes that never die -- TODO confirm against Gudhi),
        # and everything once the output arrays are full.
        if len(first) != split or row >= card or len(second) == 0:
            continue
        diagram[row] = (st.filtration(first), st.filtration(second))
        simplices[row, :split] = np.asarray(first)
        simplices[row, split:] = np.asarray(second)
        used[row] = 1
        row += 1

    return [diagram, simplices, used]

def compute_rips(card, hom_dim, D, max_length):
    """Build a Rips complex from a distance matrix and extract its persistence.

    Parameters
    ----------
    card : int
        Maximum number of diagram points to return.
    hom_dim : int
        Homological dimension of interest; the simplex tree is expanded up to
        dimension ``hom_dim + 1``.
    D : array-like supporting ``.copy()`` and boolean-mask assignment
        Pairwise distance matrix. It is NOT modified by this function.
    max_length : float or None
        Maximum edge length of the Rips complex. ``None`` means no limit.

    Returns
    -------
    list
        ``[dgm, spl, mask]`` as produced by ``compute_persistence``.
    """
    # Work on a private copy: the caller keeps using the same matrix afterwards
    # (it is stored for the backward pass), so it must not be altered here.
    D = D.copy()
    # Exact zeros are replaced by machine epsilon before the matrix is handed
    # to Gudhi.
    # NOTE(review): this also turns every zero off-diagonal entry into an
    # epsilon-length edge; if zeros encode "no edge" upstream, that makes all
    # such pairs merge at ~2e-16 -- confirm this is intended.
    D[D == 0] = np.finfo(float).eps
    # Honour the requested maximum edge length (previously accepted but unused).
    max_edge = float('inf') if max_length is None else max_length
    rc = gd.RipsComplex(distance_matrix=D, max_edge_length=max_edge)
    st = rc.create_simplex_tree(max_dimension=hom_dim+1)
    return compute_persistence(st, card, hom_dim)

def _leading_int(value):
    """Return ``value`` as a Python int, using element 0 if it is array-like."""
    if torch.is_tensor(value):
        value = value.detach().cpu()
    return int(np.asarray(value).reshape(-1)[0])


def compute_rips_grad(grad_dgm, dgm, spl, mask, c, h, params, adj, idx_vec, tol=1e-6):
    """Push a gradient on the persistence diagram back onto the parameters.

    Parameters
    ----------
    grad_dgm : array indexed as ``[i, 0]`` / ``[i, 1]``
        Gradient of the loss with respect to each diagram coordinate.
    dgm, spl, mask : arrays
        Diagram values, simplex vertices and validity mask, laid out as
        returned by ``compute_persistence``.
    c, h : int or one-element sequence
        Diagram cardinality and homological dimension. Both a plain int and a
        sequence whose first element is the value are accepted.
    params : torch.Tensor (or array with ``.shape``)
        Flat parameter vector the gradient is accumulated for.
    adj : 2-D array
        Matrix the filtration values were computed from; indexed by vertex.
    idx_vec : indexable
        Maps a vertex id to an index into ``params``.
    tol : float
        Tolerance used to locate a filtration value inside ``adj``.

    Returns
    -------
    torch.Tensor
        Gradient with the same shape as ``params`` (and, when ``params`` is a
        tensor, the same dtype and device).

    Raises
    ------
    IndexError
        If no entry of the relevant sub-matrix of ``adj`` matches a filtration
        value within ``tol``.
    """
    card = _leading_int(c)
    hom_dim = _leading_int(h)

    # Allocate next to ``params`` so the in-place accumulation below cannot mix
    # devices or dtypes.
    if torch.is_tensor(params):
        grad_x = torch.zeros_like(params)
    else:
        grad_x = torch.zeros(params.shape)

    for i in range(card):
        if mask[i] != 1:
            continue

        val0, val1 = dgm[i, 0], dgm[i, 1]
        splx0, splx1 = spl[i, :hom_dim + 1], spl[i, hom_dim + 1:]
        # Restrict the matrix to the vertices of each simplex.
        D0, D1 = adj[splx0, :][:, splx0], adj[splx1, :][:, splx1]

        # First (row, col) position whose entry equals the filtration value:
        # this names the two vertices whose entry realises that value.
        [v0a, v0b] = list(splx0[np.argwhere(np.abs(D0 - val0) <= tol)[0, :]])
        [v1a, v1b] = list(splx1[np.argwhere(np.abs(D1 - val1) <= tol)[0, :]])

        # Translate vertex ids into positions inside ``params``.
        p0a, p0b = int(idx_vec[v0a]), int(idx_vec[v0b])
        p1a, p1b = int(idx_vec[v1a]), int(idx_vec[v1b])

        # The first coordinate only contributes above dimension 0 (in
        # dimension 0 the first simplex is a single vertex).
        if hom_dim > 0:
            g0, d0 = float(grad_dgm[i, 0]), float(val0)
            grad_x[p0a] += g0 * (params[p0a] - params[p0b]) / d0
            grad_x[p0b] += g0 * (params[p0b] - params[p0a]) / d0
        g1, d1 = float(grad_dgm[i, 1]), float(val1)
        grad_x[p1a] += g1 * (params[p1a] - params[p1b]) / d1
        grad_x[p1b] += g1 * (params[p1b] - params[p1a]) / d1
    return grad_x

class Rips(torch.autograd.Function):
    """Autograd wrapper around the Rips persistence computation.

    ``Rips.apply(params, adj, idx_vec, card, hom_dim, max_length)`` returns the
    persistence diagram, the paired simplices and a validity mask as tensors.
    Only ``params`` receives a gradient.
    """

    @staticmethod
    def forward(ctx, params, adj, idx_vec, card, hom_dim, max_length):
        """Compute the diagram and stash what the backward pass needs.

        Parameters
        ----------
        params : flat parameter vector the gradient is taken with respect to.
        adj : matrix handed to ``compute_rips`` as the distance matrix.
        idx_vec : maps a vertex of the complex to an index into ``params``.
        card, hom_dim, max_length : forwarded to ``compute_rips``.

        Returns
        -------
        tuple of torch.Tensor
            ``(dgm, spl, mask)`` converted from the arrays returned by
            ``compute_rips``.
        """
        dgm, spl, mask = compute_rips(card, hom_dim, adj, max_length)
        # Everything is kept as plain attributes because several of these
        # values are not tensors.
        ctx.dgm = dgm
        ctx.spl = spl
        ctx.mask = mask  # required by backward to skip unused diagram rows
        ctx.card = card
        ctx.hom_dim = hom_dim
        ctx.adj = adj
        ctx.idx_vec = idx_vec
        ctx.params = params
        return torch.tensor(dgm), torch.tensor(spl), torch.tensor(mask)

    @staticmethod
    def backward(ctx, dgm_grad, spl_grad, mask_grad):
        """Map the gradient on the diagram back to a gradient on ``params``.

        ``spl_grad`` and ``mask_grad`` are ignored: those outputs are integer
        index arrays. The return value has one entry per ``forward`` input, in
        the same order; only the first (``params``) is non-None.
        """
        # compute_rips_grad indexes card / hom_dim with [0]; wrap scalars so
        # that both plain ints and one-element sequences are accepted.
        c = np.atleast_1d(ctx.card)
        h = np.atleast_1d(ctx.hom_dim)
        grad_x = compute_rips_grad(
            dgm_grad.detach().cpu().numpy(),
            ctx.dgm, ctx.spl, ctx.mask,
            c, h,
            ctx.params, ctx.adj, ctx.idx_vec,
        )
        # Order matches forward(params, adj, idx_vec, card, hom_dim, max_length).
        return grad_x, None, None, None, None, None

In [4]:
class Net(nn.Module):
    """Small CNN: two 3x3 convolutions followed by two linear layers.

    ``fc1`` expects 6912 input features, i.e. the flattened output of
    ``conv2`` must have 6912 elements per sample. ``param_info`` is a
    per-layer description list (layer type, and for convolutions kernel size,
    stride and padding) that is consumed outside this class.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in forward order.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=3)
        self.fc1 = nn.Linear(in_features=6912, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

        # One entry per weight layer, in the same order as above.
        conv1_info = {'layer_type': 'Conv2d', 'kernel_size': (3, 3), 'stride': 1, 'padding': 0, 'name': 'Conv1'}
        conv2_info = {'layer_type': 'Conv2d', 'kernel_size': (3, 3), 'stride': 1, 'padding': 0, 'name': 'Conv2'}
        fc1_info = {'layer_type': 'Linear', 'name': 'Linear1'}
        fc2_info = {'layer_type': 'Linear', 'name': 'Linear2'}
        self.param_info = [conv1_info, conv2_info, fc1_info, fc2_info]

    def forward(self, x):
        """Return per-class log-probabilities for a batch ``x``."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        flat = torch.flatten(features, 1)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
    
def sum_diag_loss(output, dgm, target, lbda=0.3):
    """Blend the classification loss with a persistence-diagram term.

    Parameters
    ----------
    output : log-probabilities, as expected by ``F.nll_loss``.
    dgm : tensor whose column 0 / column 1 hold the two diagram coordinates.
    target : class labels for ``output``.
    lbda : weight of the diagram term; the NLL term gets ``1 - lbda``.

    Returns
    -------
    torch.Tensor
        ``(1 - lbda) * nll + lbda * diagram_term`` where ``diagram_term`` is
        minus the sum of squared differences between the two columns.

    Both terms are printed on every call.
    """
    lifetimes = dgm[:, 1] - dgm[:, 0]
    # Negative sign: minimising the loss increases the squared differences.
    diag_term = -torch.sum(lifetimes ** 2)
    nll_term = F.nll_loss(output, target)
    print(diag_term.detach().cpu().numpy(), nll_term.detach().cpu().numpy())
    nll_weight = 1 - lbda
    return nll_weight * nll_term + lbda * diag_term

def flatten_params_torch(param_info, device):
    """Concatenate all Conv2d / Linear weights into one flat tensor.

    Parameters
    ----------
    param_info : iterable of dict
        Each dict has a ``'layer_type'`` key; entries of type ``'Conv2d'`` or
        ``'Linear'`` must also carry the weight tensor under ``'param'``.
        Entries of any other type are skipped.
    device : device the leading placeholder element is moved to.

    Returns
    -------
    torch.Tensor
        1-D tensor: a single leading zero followed by the flattened weights in
        the order of ``param_info``.
    """
    # How each supported layer type is turned into a 1-D tensor.
    flatten_by_type = {
        'Conv2d': lambda w: w.reshape(w.shape[0], -1).flatten(),
        'Linear': lambda w: w.flatten(),
    }

    # Slot 0 is a placeholder zero (its value is irrelevant downstream
    # according to the original author's note).
    pieces = [torch.zeros(1).to(device)]
    for entry in param_info:
        flatten = flatten_by_type.get(entry['layer_type'])
        if flatten is not None:
            pieces.append(flatten(entry['param']))

    return torch.cat(pieces)

def train_regular(model, device, train_loader, optimizer, epoch, log_interval=100):
    """Run one epoch of plain NLL training (no topological term).

    Parameters
    ----------
    model : module returning log-probabilities.
    device : device each batch is moved to.
    train_loader : iterable of ``(data, target)`` batches that also exposes
        ``len()`` and a ``.dataset`` with ``len()`` (used for logging only).
    optimizer : optimizer stepping the parameters of ``model``.
    epoch : epoch number, used only in the log line.
    log_interval : a progress line is printed every this many batches.
    """
    model.train()
    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()

        if step % log_interval != 0:
            continue
        seen = step * len(inputs)
        percent_done = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), percent_done, loss.item()))
    
def train(model, param_info, G, device, train_loader, optimizer, epoch, lbda, p, log_interval=100, update_every=10):
    """Run one epoch of training with the topological regulariser.

    Parameters
    ----------
    model : network to train; must expose ``param_info``.
    param_info : layer description list passed to ``G.parameter_graph``.
    G : graph object providing ``parameter_graph``, ``update_adjacency``,
        ``get_adjacency`` and ``graph_idx_vec``.
    device : device batches (and the flattened parameters) are placed on.
    train_loader : iterable of ``(data, target)`` batches exposing ``len()``
        and ``.dataset``.
    optimizer : optimizer stepping the parameters of ``model``.
    epoch : epoch number, used only for logging.
    lbda : weight of the diagram term in ``sum_diag_loss``.
    p : percentile used to derive the graph threshold.
    log_interval : print a progress line every this many batches.
    update_every : rebuild the parameter graph every this many batches; in
        between only its adjacency is refreshed.

    Notes
    -----
    Reads the module-level settings ``card``, ``hom`` and ``ml``.
    The graph is built for a fixed input shape of ``(1, 1, 28, 28)``.
    """
    rips = Rips.apply
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        up = nn_graph.append_params(model.param_info, nn_graph.get_weights(model, tensors=True))
        # Keep ``params`` a tensor on every batch: this is what Rips receives
        # and what a gradient would flow back to.
        params = flatten_params_torch(up, device)

        if batch_idx % update_every == 0:
            # The NumPy view is only needed to compute the threshold; it must
            # not replace the tensor handed to Rips below.
            params_np = params.detach().cpu().numpy()
            thresh = np.percentile(1./(1.+np.abs(params_np)), p)
            G.parameter_graph(model, param_info, (1,1,28,28), update_indices=True, threshold=thresh)
        else:
            G.update_adjacency(model)

        output = model(data)
        dgm, spl, msk = rips(params, G.get_adjacency(), G.graph_idx_vec, card, hom, ml)
        loss = sum_diag_loss(output, dgm, target, lbda=lbda)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [5]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Options shared by both loaders.
kwargs = dict(num_workers=1, pin_memory=True)

# Training split. The data must already be present under ../../data because
# download=False.
# NOTE(review): (0.1307,) / (0.3081,) look like the commonly quoted MNIST
# mean / std rather than Fashion-MNIST statistics -- confirm this is intended.
train_set = datasets.FashionMNIST(
    '../../data',
    train=True,
    download=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, **kwargs)

# Test split, preprocessed exactly like the training split.
test_set = datasets.FashionMNIST(
    '../../data',
    train=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=True, **kwargs)

In [6]:
# Graph object that tracks the network's parameter graph during training.
NNG = nn_graph.NNGraph()
# CPU-side instance; only its `param_info` description is used below.
model = Net()

# The network that is actually trained, placed on the selected device.
modeldev = Net().to(device)
optimizer = torch.optim.Adam(modeldev.parameters(), lr=0.001)

# Single epoch of topologically regularised training followed by evaluation.
n_epochs = 1
for epoch in range(1, n_epochs + 1):
    train(
        modeldev,
        model.param_info,
        NNG,
        device,
        train_loader,
        optimizer,
        epoch,
        lbda,
        percentile,
        update_every=100,
    )
    test(modeldev, device, test_loader)

Layer: Conv1
Layer: Conv2
Layer: Linear1
Layer: Linear2
-9.860761e-31 2.304923
-9.860761e-31 2.2989836
-9.860761e-31 2.2035003
-9.860761e-31 2.1120586
-9.860761e-31 1.8236247
-9.860761e-31 1.7109694
-9.860761e-31 1.5530121
-9.860761e-31 1.5143003


Traceback (most recent call last):
  File "/home/schraterlab/anaconda3/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/home/schraterlab/anaconda3/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/schraterlab/anaconda3/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/home/schraterlab/anaconda3/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/schraterlab/gebhart/projects/ripsreg/env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-6-a4842d042867>", line 8, in <module>
    train(modeldev, model.param_info, NNG, device, train_loader, optimizer, epoch, lbda, percentile, update_every=100)
  File "<ipython-input-4-00d1ebd80173>", line 79, in train
    dgm, spl, msk = rips(params, G.get_adjacency(), G.graph_idx_vec, card, hom, ml)
  File "<ipython-input-3-038ae775592d>", line 65, in forward
    dgm, spl, mask = compute_rips(card, hom_dim, adj, max_length)
  File "<ipython-input-3-038ae775592d>", line 24, in compute_rips
    return compute_persistence(st, card, hom_dim)
  File "<ipython-input-3-038ae775592d>", line 2, in compute_persistence
    st.persistence()
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent 

KeyboardInterrupt: 

In [None]:
# modeldev = Net().to(device)
# optimizer = torch.optim.Adam(modeldev.parameters(), lr=0.001)

# for epoch in range(1, 2):
#     train_regular(modeldev, device, train_loader, optimizer, epoch)
#     test(modeldev, device, test_loader)