In [30]:
from calendar import leapdays
from collections import defaultdict
import dataclasses
from email.policy import default

import kneed
import numpy as np
import pandas as pd
import zipfile
import ast
import kmapper as km
from sklearn.cluster import DBSCAN
from scipy.spatial.distance import pdist, squareform
from sklearn.neighbors import KDTree
from networkx.readwrite import json_graph
import networkx as nx
import json
import os
import itertools

from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from sklearn.datasets import make_moons
from tqdm import tqdm
from dataclasses import dataclass

import matplotlib.pyplot as plt

In [153]:
class BallMapper:
    def __init__(self) -> None:
        pass

    def map(self, data: pd.DataFrame, epsilon: float):
        """
        Maps the data to a ball graph.

        Parameters
        ---------
        data: pd.DataFrame
            The data to map.
        epsilon: float
            The radius of the ball.
        """
        graph = {}

        coverage = [False] * len(data)
        nodes = dict()
        edges = set()

        # compute a KD-tree for fast radius search
        tree = KDTree(data, leaf_size=5)

        # iterate over all points
        for i in range(len(data)):
            if not coverage[i]:
                # find all points in the ball
                neighbors = tree.query_radius([data.iloc[i]], r=epsilon)[0]

                # add the point to the graph
                nodes[i] = set(neighbors.tolist())

                # mark all points in the ball as covered
                for j in neighbors:
                    coverage[j] = True

                # mark the point as covered
                coverage[i] = True

        # iterate over all nodes
        for i in nodes:
            for j in nodes:
                if i != j and nodes[i].intersection(nodes[j]):
                    edges.add(tuple(sorted((i, j))))

        graph['nodes'] = {k: list(v) for k, v in nodes.items()}
        graph['links'] = {i: [j] for i, j in edges}

        return graph
        
def elbow_eps(data):
    n_neighbors = 4
    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(data)
    distances, indices = nbrs.kneighbors(data)
    distances = np.sort(distances, axis=0)[::-1]
    kneedle = kneed.KneeLocator(distances[:, n_neighbors - 1], np.linspace(0, 1, num=len(distances)), curve='convex', direction='decreasing')
    kneedle.plot_knee()
    eps = kneedle.knee
    return eps

In [154]:
activations = pd.read_csv('../backend/data/ss-role/fine-tuned-bert-base-uncased/train/176/12.txt', delim_whitespace=True, header=None)
labels = pd.read_csv('../backend/data/ss-role/entities/train.txt', delim_whitespace=True, header=None)
X = activations
# plt.scatter(X.iloc[:, 0], X.iloc[:, 1])

bm = BallMapper()
graph = bm.map(X, epsilon=20)

G = nx.Graph()
G.add_nodes_from(graph['nodes'])
G.add_edges_from(graph['links'])

pos = nx.spring_layout(G)
nx.draw(G, pos=pos, node_size=10, node_color='blue', edge_color='black', width=1, alpha=0.5)

TypeError: object of type 'int' has no len()

In [155]:
graph['links']

{293: [552],
 239: [932],
 9: [35],
 13: [24],
 11: [30],
 78: [2849],
 35: [70],
 110: [798],
 2: [962],
 24: [932],
 47: [782],
 151: [932],
 44: [2499],
 21: [110],
 23: [30],
 32: [552],
 135: [3004],
 14: [41],
 472: [2499],
 1: [27],
 4: [30],
 20: [2849],
 36: [239],
 27: [2259],
 77: [3004],
 70: [97],
 33: [35],
 3: [13],
 798: [932],
 30: [35],
 2259: [3004],
 552: [932],
 12: [21],
 1503: [3004],
 62: [239],
 41: [70],
 43: [110],
 128: [1503],
 2499: [3194],
 2849: [3798],
 0: [13],
 2888: [3004],
 1671: [2849],
 97: [147]}

In [151]:
kmapper = km.KeplerMapper(verbose=0)
projected_data = kmapper.fit_transform(activations, projection='l2norm')
mgraph = kmapper.map(projected_data, activations, clusterer=DBSCAN(eps=8))

mgraph['links']

defaultdict(list,
            {'cube6_cluster0': ['cube7_cluster10'],
             'cube7_cluster0': ['cube8_cluster0'],
             'cube7_cluster4': ['cube8_cluster7'],
             'cube7_cluster5': ['cube8_cluster16'],
             'cube7_cluster6': ['cube8_cluster26'],
             'cube7_cluster10': ['cube8_cluster24'],
             'cube7_cluster11': ['cube8_cluster5'],
             'cube7_cluster12': ['cube8_cluster19'],
             'cube7_cluster13': ['cube8_cluster10'],
             'cube7_cluster14': ['cube8_cluster11'],
             'cube7_cluster20': ['cube8_cluster12'],
             'cube7_cluster22': ['cube8_cluster15'],
             'cube7_cluster25': ['cube8_cluster29'],
             'cube7_cluster27': ['cube8_cluster1'],
             'cube7_cluster28': ['cube8_cluster23'],
             'cube7_cluster29': ['cube8_cluster9'],
             'cube7_cluster33': ['cube8_cluster30'],
             'cube7_cluster34': ['cube8_cluster31'],
             'cube7_cluster35': ['cube

In [145]:
np.vstack([np.mean(activations.iloc[graph['nodes'][node_name]], axis=0) for node_name in graph['nodes']]).shape
# [activations.iloc[list(graph['nodes'][node_name])] for node_name in graph['nodes']]

(52, 768)