In [5]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision import datasets

import medmnist
from medmnist import INFO, Evaluator

from matplotlib import pyplot as plt

from gudhi import CubicalComplex
from gudhi.sklearn.cubical_persistence import CubicalPersistence

from medmnist import PneumoniaMNIST

from torch_cube_perslap_faster import ColumnRearrangeRegular, ColumnRearrangeGeneral

In [525]:
# NOTE(review): commented-out reference copy of ColumnRearrangeGeneral. The
# notebook actually uses the class imported from torch_cube_perslap_faster in
# the first cell; consider deleting this cell once that module is the single
# source of truth.
# Known issues in this copy (check whether the imported version shares them):
#  - dfs: `edge[edge != u]` compares a Python list with an int, which is always
#    True, so it always picks edge[1]. Starting from an edge's second endpoint
#    the walk never reaches edge[0], which can split one component in two
#    (e.g. edges [0, 2] and [1, 2] yield [[0, 2], [1]] instead of [[0, 1, 2]]).
#  - find_vertices: `.squeeze().tolist()` yields a bare int, not a list, when
#    exactly one column is nonzero; vertices_visited then cannot iterate it.
#  - find_edges counts nonzeros per row via abs row-sums, which is only valid
#    when the entries of D are in {-1, 0, 1}.
#  - dfs scans every edge for each visited vertex (O(V * E)) and recurses once
#    per newly visited vertex, so large inputs may hit Python's recursion limit.
# class ColumnRearrangeGeneral:
#     def __init__(self, D):
#         self.edges = dict()
#         self.hyperedges = dict()
#         self.vertices = None
#         self.D = D
#         self.D_nonzero = D.nonzero()
#         self.visited = None
#         self.CC = []
#         # self.count = 0
#         # self.label = torch.zeros(D.shape[1])

#     def find_vertices(self):
#         self.vertices = (torch.nonzero(self.D.abs().sum(axis=0))).squeeze().tolist()

#     def vertices_visited(self):
#         # self.visited = torch.zeros(len(self.vertices), dtype=torch.bool)
#         self.visited = {i: False for i in self.vertices}

    
#     def find_edges(self):
#         v_count = self.D.abs().sum(axis=1)

        
#         if len((v_count == 1).nonzero()) != 1:
#             hyperedge_idx = (v_count == 1).nonzero().squeeze().tolist()
#         elif len((v_count == 1).nonzero()) == 1:
#             hyperedge_idx = [(v_count == 1).nonzero().squeeze().tolist()]

    
#         if len((v_count == 2).nonzero()) != 1:
#             edge_idx = (v_count == 2).nonzero().squeeze().tolist()
#         elif len((v_count == 2).nonzero()) == 1:
#             edge_idx = [(v_count == 2).nonzero().squeeze().tolist()]

#         self.edges = {i: (self.D_nonzero[self.D_nonzero[:, 0] == i][:, 1]).tolist() 
#                       for i in edge_idx}
#         self.hyperedges = {i: (self.D_nonzero[self.D_nonzero[:, 0] == i][:, 1]).tolist() 
#                       for i in hyperedge_idx}
    
#     def dfs(self, u, temp):
#         self.visited[u] = True
        
#         temp.append(u)

#         for edge in self.edges.values():
#             if u in edge:
#                 # NOTE(review): `edge != u` is list-vs-int, always True -> always edge[1]
#                 v = edge[edge != u]
#                 if not self.visited[v]:
#                     temp = self.dfs(v, temp)
#         return temp

#     def connected_components(self):

#         for v in self.vertices:
#             if not self.visited[v]:
#                 temp = []
#                 self.CC.append(self.dfs(v, temp))

#         if len(self.hyperedges) != 0:
#             temp = []
#             if len(list(self.hyperedges.values())) == 1:
#                 hyperedges = list(self.hyperedges.values())[0]
#             elif len(list(self.hyperedges.values())) > 1:
#                 hyperedges = torch.tensor(list(self.hyperedges.values())).squeeze().tolist()
#             for i in range(len(self.CC)):
#                 if len(set(hyperedges) & set(self.CC[i])) == 0:
#                     temp.append(i)
#             self.CC = [self.CC[i] for i in temp]

#     def find_zero_column(self):
#         self.zero_column = (self.D.abs().sum(axis=0) == 0).nonzero().squeeze().tolist()
#         if not isinstance(self.zero_column, list):
#             self.zero_column = [self.zero_column]

        
#     # def select_components(self):
#     #     if len(self.hyperedges) != 0:
#     #         temp = []
#     #         if len(list(self.hyperedges.values())) == 1:
#     #             hyperedges = list(self.hyperedges.values())[0]
#     #         elif len(list(self.hyperedges.values())) > 1:
#     #             hyperedges = torch.tensor(list(self.hyperedges.values())).squeeze().tolist()
#     #         for i in range(len(self.CC)):
#     #             if len(set(hyperedges) & set(self.CC[i])) == 0:
#     #                 temp.append(i)
#     #         self.CC = [self.CC[i] for i in temp]
#     #     else:
#     #         pass

In [6]:
# Hand-built test matrices for ColumnRearrangeGeneral, with entries in {-1, 0, 1}.
# In the run below, a row with two nonzeros is reported as an edge between those
# two columns, and a row with a single nonzero is reported as a "hyperedge" on that
# column.
# All cases live in one dict, so switching cases means changing D_CASE instead of
# toggling commented-out code.
D_CASES = {
    # column pairs (1, 2), (0, 1), (3, 5); 6 columns
    "edges_only_6col": torch.tensor([[0, 1, -1, 0, 0, 0],
                                     [1, -1, 0, 0, 0, 0],
                                     [0, 0, 0, 1, 0, -1]]),
    # same rows plus an extra all-zero last column
    "edges_only_7col": torch.tensor([[0, 1, -1, 0, 0, 0, 0],
                                     [1, -1, 0, 0, 0, 0, 0],
                                     [0, 0, 0, 1, 0, -1, 0]]),
    # row 0 touches only column 3; column pairs (0, 1), (3, 5), (4, 6)
    "hyperedge_col3": torch.tensor([[0, 0, 0, 1, 0, 0, 0, 0],
                                    [1, -1, 0, 0, 0, 0, 0, 0],
                                    [0, 0, 0, 1, 0, -1, 0, 0],
                                    [0, 0, 0, 0, 1, 0, -1, 0]]),
    # row 0 touches only column 1; column pairs (0, 1), (1, 5), (4, 6)
    "hyperedge_col1": torch.tensor([[0, 1, 0, 0, 0, 0, 0],
                                    [1, -1, 0, 0, 0, 0, 0],
                                    [0, 1, 0, 0, 0, -1, 0],
                                    [0, 0, 0, 0, 1, 0, -1]]),
}

D_CASE = "hyperedge_col3"  # the case analysed in the cells below
D = D_CASES[D_CASE]

In [7]:
# Show the selected matrix so its rows can be matched against the edges and
# hyperedges reported below.
D

tensor([[ 0,  0,  0,  1,  0,  0,  0,  0],
        [ 1, -1,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  1,  0, -1,  0,  0],
        [ 0,  0,  0,  0,  1,  0, -1,  0]])

In [8]:
# Build the project helper from D. The following cells call its steps in this
# order: find_vertices -> find_edges -> vertices_visited -> connected_components.
# After each step they inspect the attributes it sets (vertices, edges,
# hyperedges, CC).
# NOTE(review): these steps mutate graph_D, so after changing D re-run from
# this cell rather than re-running individual steps.
graph_D = ColumnRearrangeGeneral(D)

In [9]:
# Find vertices wrt D
graph_D.find_vertices()

graph_D.vertices

[0, 1, 3, 4, 5, 6]

In [10]:
# Find edges wrt D
graph_D.find_edges()

In [11]:
# row index of D -> [the two columns (vertices) that row connects]
graph_D.edges

{1: [0, 1], 2: [3, 5], 3: [4, 6]}

In [12]:
# row index of D -> [the single column (vertex) that row touches]
graph_D.hyperedges

{0: [3]}

In [13]:
# NOTE(review): presumably (re)initialises the per-vertex visited flags that
# connected_components() relies on, as the commented-out reference class does.
# In that class it must run after find_vertices(); confirm against
# torch_cube_perslap_faster.
graph_D.vertices_visited()

In [14]:
# Group vertices into connected components, stored in graph_D.CC. With this D,
# the edges form {0, 1}, {3, 5} and {4, 6}. {3, 5} is missing from the output,
# which is consistent with components containing a hyperedge vertex (3) being
# discarded.
graph_D.connected_components()

In [15]:
# Surviving components, each a list of vertex (column) indices of D.
graph_D.CC

[[0, 1], [4, 6]]

In [16]:
# NOTE(review): hard-coded scratch value with no derivation. For the current D
# it equals the one vertex (5) that is in neither graph_D.CC nor
# graph_D.hyperedges. If that is the intent, compute it from graph_D so it does
# not go stale when D changes; otherwise remove this cell.
[5]

[5]

In [17]:
# Raw hyperedge vertex lists: one inner list per single-nonzero row of D.
[*graph_D.hyperedges.values()]

[[3]]

In [18]:
# Squeezing the (num_hyperedges, 1) tensor collapses it to a bare int when there
# is only one hyperedge (here 3, not [3]). The commented-out reference class
# special-cases a single hyperedge for this reason; .flatten().tolist() would
# always return a list.
torch.squeeze(torch.tensor([*graph_D.hyperedges.values()])).tolist()

3