In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import logging
from tqdm import tqdm
import pandas as pd
from toolkit import simplex, link_prediction
from toolkit.utils import hyperedges_to_edges

logging.basicConfig(level=logging.INFO)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def read_zabka_csv(csv_path):
    df = pd.read_csv(csv_path)
    nodes = []
    hyper_edges = []
    edges = []
    logging.info(f' Reading baskets from zabka dataset')
    for order_id in tqdm(df["basket"].drop_duplicates()):
        basket = df[df["basket"] == order_id]["product_id"].values.tolist()
        hyper_edges.append(basket)
        nodes += basket
        edges += [(x, y) for x in basket for y in basket if x < y]
    nodes = list(set(nodes))
    edges = list(set(edges))
    return len(nodes), edges, hyper_edges

In [4]:
#(nodes_num, edges, hyperedges) = read_zabka_csv('~/Downloads/zabka.csv')

In [5]:
%store -r hyperedges

In [7]:
def run(hyperedges, max_simplex_dim, threshold, embedding_depths: list):

        s = simplex.SimplicalComplex(hyperedges, max_simplex_dim, threshold)

        edges = hyperedges_to_edges(s.hypergraph)

        edges_remapped = [
            tuple([s.node_persistor[simplex.Label([v])] for v in edge])
                for edge in edges
        ]

        logging.info(f' Number of edges left in the graph after truncating: {len(edges_remapped)}')

        (g, link_prediction_metric) = link_prediction.link_prediction_setup(
            len(s.node_persistor),
            edges_remapped
        )

        embedding = simplex.generate_embedding(s, 10)
        score = link_prediction_metric(embedding)

        logging.info(f' Score: {score}')
        return score

In [8]:
def experiment_driver(max_simplex_dims: list, tresholds: list, embedding_depths: list):

    hstb = simplex.HashTensorBuilder(2, 2)

    results = simplex.HashMatrixBuilder()

    for max_simplex_dim in max_simplex_dims:
        scores = []
        for treshold in tresholds:

            score = run(hyperedges, max_simplex_dim, treshold, embedding_depths)
            hstb[(max_simplex_dim, treshold)] = score
            
    hst = hstb.collapse()

    hst.pretty_print()
   

In [9]:
def exp1():
    data = experiment_driver(
        [5 + i for i in range(3)],
        [75 + 5 * i for i in range(3)],
        [3 + i for i in range(3)]
    )

__DATA TABLE__ 
row: 
cols: 
	75	80	85
5	0.87	0.86	0.85
6	0.83	0.88	0.79
7	0.7	0.77	0.78

__DATA TABLE__ 
row: 
cols: 
	75	80	85
5	0.86	0.85	0.85
6	0.82	0.88	0.78
7	0.7	0.77	0.77



In [10]:
logging.basicConfig(level=logging.WARNING)

exp1()

INFO:root: Creating Simplical Complex with 
             	 max_simplex_dim : 5 
             	 threshold : 75             
100%|██████████| 210408/210408 [00:15<00:00, 13781.03it/s]
  sparse_csr = torch.sparse_csr_tensor(
INFO:root: Number of edges left in the graph after truncating: 377993
INFO:root: Generating embedding with 
         	 depth: 3 
         	 max_dimenson_per_simplex: 128        
INFO:root: Creating dense matrix of size: rows: 113, cols: 113.
INFO:root: shape of T: torch.Size([113, 125])
INFO:root: Creating dense matrix of size: rows: 125, cols: 125.
INFO:root: shape of T: torch.Size([125, 20])
INFO:root: Creating dense matrix of size: rows: 125, cols: 125.
INFO:root: It took 0.011640787124633789 to compute this product (1).
INFO:root: shape of t: torch.Size([113, 125])
INFO:root: Generated embedding with dimesnions: torch.Size([4168, 78])
100%|██████████| 100/100 [00:11<00:00,  8.82it/s]
INFO:root: Score: [0.8673284091119738, 0.86]
INFO:root: Creating Simplical Comple

AttributeError: 'NoneType' object has no attribute 'pretty_print'