In [None]:
import numpy as np
import pandas as pd
import os
import json
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
from joblib import Parallel, delayed
from ast import literal_eval

In [2]:
# Dataset dimensions
NUM_EPOCHS = 177    # graph files {0..NUM_EPOCHS-1}.json are read, one per epoch
NUM_POINTS = 4_282  # points summarized per epoch (ids 0..NUM_POINTS-1)

# Input data locations, relative to the notebook's working directory
LABEL_FILE = '../data/Supersense-Role/entities/train.txt'
ACTIVATION_FILE = '../data/Supersense-Role/SS_fine_tuned.zip'

# Directory holding the per-epoch mapper graphs ({epoch}.json); node_purities.csv
# is also written here
GRAPH_OUTPUT_PATH = '../frontend/public/static/mapper_graphs/euclidean_l2_50_50/'

In [4]:
def node_purity(node, label):
    """Count how many members of a mapper node carry ``label``.

    Each entry of ``node['membership']['metadata']`` is one member record whose
    label sits at index 3.

    Returns:
        ``(matching, total)``: number of records whose label equals ``label``,
        and the total number of records in the node.
    """
    members = node['membership']['metadata']
    tally = Counter(member[3] for member in members)
    return tally[label], len(members)

def compute_purities(fileindex, graph_path=None):
    """Collect per-point node purities from one epoch's mapper graph.

    Reads ``<graph_path>/<fileindex>.json``. For every node in
    ``graph['nodes']``, it appends one record per member point.

    Each record is a tuple ``(i, (same_label_count, node_size))``:

    - ``i`` is the point's position in the node's ``membership_ids``.
    - ``same_label_count`` is how many of the node's metadata records share
      that point's label (index 3 of its metadata record).
    - ``node_size`` is the number of metadata records in the node.

    The inner pair is exactly what ``node_purity(node, label)`` returns.

    Args:
        fileindex: index used to build the file name ``f'{fileindex}.json'``.
        graph_path: directory containing the graph files. ``None`` (default)
            uses ``GRAPH_OUTPUT_PATH``, looked up at call time.

    Returns:
        ``defaultdict(list)`` mapping each point id (as stored in the JSON) to
        its list of records, one per node containing that point.
    """
    if graph_path is None:
        graph_path = GRAPH_OUTPUT_PATH

    # JSON text is UTF-8 by spec; don't depend on the platform default encoding.
    with open(os.path.join(graph_path, f'{fileindex}.json'), 'r', encoding='utf-8') as graph_file:
        graph = json.load(graph_file)

    point_node_purities = defaultdict(list)
    for node in graph['nodes']:
        membership = node['membership']
        metadata = membership['metadata']
        # Tally labels once per node. Calling node_purity per member rebuilt
        # this Counter for every point, making each node cost O(size^2).
        label_counts = Counter(record[3] for record in metadata)
        node_size = len(metadata)
        # Assumes metadata[i] describes membership_ids[i] (parallel lists).
        # TODO confirm against the graph writer.
        for i, point_id in enumerate(membership['membership_ids']):
            point_node_purities[point_id].append((i, (label_counts[metadata[i][3]], node_size)))

    return point_node_purities

# `purities` (next cell) holds one compute_purities result per epoch, so
# len(purities) == NUM_EPOCHS; purities[e][point_id] lists that point's
# (position_in_node, (same_label_count, node_size)) records for epoch e.

In [5]:
%%time
# One job per epoch, run on every available core; purities[e] == compute_purities(e).
epoch_jobs = (delayed(compute_purities)(epoch) for epoch in tqdm(range(NUM_EPOCHS)))
purities = Parallel(n_jobs=-1)(epoch_jobs)

  0%|          | 0/177 [00:00<?, ?it/s]

Wall time: 5.24 s


In [20]:
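# Consolidate the per-epoch results: for each point and epoch, reduce the point's
# node records to a single purity value, then save one row per point (one column
# per epoch) to node_purities.csv.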

def purity_summary(purities):
    """Reduce one point's node records to a single pooled purity score.

    ``purities`` is a list of records whose second element is a
    ``(same_label_count, node_size)`` pair. The score is the total
    same-label count divided by the total node size over all records.

    Returns 1 when ``purities`` is None, empty, or the pooled node size is 0.
    """
    if purities is None:
        return 1

    same_label = 0
    members = 0
    for record in purities:
        counts = record[1]
        same_label += counts[0]
        members += counts[1]

    if members == 0:
        return 1
    return same_label / members

# point_purity_tracker[p] is a list of (epoch, purity) tuples, one per epoch, in order.
point_purity_tracker = defaultdict(list)

for point_number in range(NUM_POINTS):
    for iteration in range(NUM_EPOCHS):
        # .get() rather than [] so that a point absent from every node does not
        # get an empty list inserted into the defaultdict compute_purities
        # returned. purity_summary maps None to 1, the same value it gives for [].
        node_records = purities[iteration].get(point_number)
        point_purity_tracker[point_number].append((iteration, purity_summary(node_records)))

# Build the frame in one constructor call. Growing it row by row with .loc
# enlargement copies the frame on every insertion and leaves column dtypes
# undetermined. Here pointNum is an integer column and the purities are numeric.
purity_columns = [f'epoch_{x}_purity' for x in range(NUM_EPOCHS)]
node_purities_df = pd.DataFrame(
    [[point_num] + [purity for _, purity in point_purity_tracker[point_num]]
     for point_num in range(NUM_POINTS)],
    columns=['pointNum'] + purity_columns,
)

node_purities_df.to_csv(os.path.join(GRAPH_OUTPUT_PATH, 'node_purities.csv'), index=False)

In [21]:
# point_data[i] = [loc[0], loc[1], word, label] for line i (0-based) of train.txt.
# Lines are parsed as "<loc>:<word>\t<label>", where <loc> is a Python literal
# read with ast.literal_eval.
# NOTE(review): this path differs from LABEL_FILE in the config cell
# ('../data/Supersense-Role/entities/train.txt') -- confirm which copy is intended.
point_data = {}

with open('../backend/data/Supersense-Role/entities/train.txt', encoding='utf-8') as point_file:
    for idx, line in enumerate(point_file):
        if not line.strip():
            # Skip blank lines (e.g. a trailing empty line). idx still equals the
            # line number, so later point ids are unaffected.
            continue
        start, label = line.strip().split('\t')
        # Split on the first ':' only, so a word that itself contains ':' stays whole.
        loc, word = start.split(':', 1)
        loc = literal_eval(loc)
        point_data[idx] = [loc[0], loc[1], word, label]

In [22]:
# Export point_data for the frontend. JSON object keys are always strings, so
# the integer line indices are written as "0", "1", ...
with open('../frontend/src/assets/data/point_data_labels.json', 'w') as labels_out:
    json.dump(point_data, fp=labels_out)