In [1]:
from asmt.fasmt import FASMT
from src.smt.input_signal_subsampled import Signal
from src.smt.utils import dec_to_bin_vec
from smt.random_group_testing import decode
from scipy.sparse import csc_array
import numpy as np

In [4]:
indptr = [0]
indices = []
data = []

max_n = 1000
unique_edges =set()

with open("data/mathoverflow-answers/hyperedges-mathoverflow-answers.txt") as file:
    for line in file:
        edge = line.split(",")
        edge = sorted(i for i in edge if int(i) < max_n)
        if len(edge) > 0 and tuple(edge) not in unique_edges:
            for index in edge:
                indices.append(int(index))
                data.append(len(edge))
            indptr.append(len(indices))
            unique_edges.add(tuple(edge))

inc_mat = csc_array((data, indices, indptr), dtype=int)
n, k = inc_mat.shape
print(f"{n} nodes, {k} edges")
degrees = inc_mat.getnnz(axis=0)
def get_subgraph_size(nodes):
    return np.sum(inc_mat[nodes].getnnz(axis=0) == degrees)

1000 nodes, 1008 edges


In [5]:
class GraphSignal(Signal):
    """
    This is a signal object, except it implements the unimplemented 'subsample' function.
    """

    def __init__(self, **kwargs):

        def sampling_function(query_batch):
            return get_subgraph_size(np.where(query_batch))

        self.sampling_function = sampling_function

        super().__init__(**kwargs)

    def subsample(self, query_indices):
        """
        Computes the signal/function values at the queried indicies on the fly
        """
        return self.sampling_function(query_indices)
    
signal = GraphSignal(n=n, q=2)

'''
Create a ASMT instance and perform the transformation
'''
fasmt = FASMT()
transform_args = {"query_method": "group_testing", "decoder": decode, "notebook": True}
result = fasmt.transform(signal, verbosity=5, timing_verbose=True, report=True, sort=True, **transform_args)

0 samples [00:00, ? samples/s]

In [76]:
# %timeit get_subgraph_size(np.arange(n)[::2])

93.5 µs ± 3.02 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
sum(result['transform'].values())

1008