In [1]:
import pickle
from collections import defaultdict

from framework.losses import interpoint_distance

In [2]:
def _load(name):
    """Read the pickle file at ``name`` and return the object stored in it.

    The file is opened in binary mode and closed again before returning.
    """
    # Unpickling can execute arbitrary code, so only point this at trusted files.
    with open(name, "rb") as handle:
        payload = pickle.load(handle)
    return payload

In [3]:
# Load the pickled experiment results; the path is relative to the working directory.
results = _load("experimenttwo.pkl")

In [4]:
# Sanity check on the loaded file: exactly 120 result records are expected.
# NOTE(review): 120 is presumably the size of the full experiment grid
# (sampling x dataset x algorithm x size) — confirm against the generating script.
assert len(results) == 120, "expected 120 result records, got %d" % len(results)

In [5]:
def group(results):
    """Bucket result records by experiment configuration, ordered by size.

    Each record in ``results`` is a mapping that must provide the keys
    ``'sampling'``, ``'dataset'``, ``'algorithm'``, ``'emb_x'``, ``'y'`` and
    ``'size'``.  Records sharing the same ``(sampling, dataset, algorithm)``
    triple land in the same bucket; only ``emb_x``, ``y`` and ``size`` are
    carried over into the bucket entries.

    Returns a plain ``dict`` mapping each triple to a list of
    ``{"emb_x", "y", "size"}`` dicts sorted by ascending ``size``.
    """
    buckets = {}
    for record in results:
        bucket_key = (record['sampling'], record['dataset'], record['algorithm'])
        buckets.setdefault(bucket_key, []).append({
            "emb_x": record['emb_x'],
            "y": record['y'],
            "size": record['size'],
        })

    # sorted() is stable, so records of equal size keep their input order.
    return {
        bucket_key: sorted(members, key=lambda member: member['size'])
        for bucket_key, members in buckets.items()
    }

In [6]:
def groups_to_interpoint(groups):
    """Compute interpoint distances between consecutive embeddings per group.

    ``groups`` maps a key to a sequence of dicts that each carry an
    ``'emb_x'`` entry.  For every key, ``interpoint_distance`` is applied to
    each pair of neighbouring embeddings (first with second, second with
    third, ...), in the order the entries are stored.

    Returns a dict with the same keys; each value is the list of those
    distances, one fewer than the number of entries (empty for groups with
    zero or one entry).
    """

    def _consecutive_distances(members):
        embs = [member['emb_x'] for member in members]
        # Pair every embedding with its immediate successor.
        return [interpoint_distance(left, right) for left, right in zip(embs, embs[1:])]

    return {key: _consecutive_distances(members) for key, members in groups.items()}

In [8]:
# Group the loaded results by (sampling, dataset, algorithm), sort each group by size,
# then compute the interpoint distance between consecutive embeddings within each group.
interpoint_results = groups_to_interpoint(group(results))

In [9]:
# Display the per-group lists of consecutive interpoint distances.
interpoint_results

{('random', 'mnist', 'umap'): [5083.291705131531,
  4914.576036453247,
  3303.4438693523407,
  2253.2539452314377,
  6785.8734956383705,
  6481.827231884003,
  2544.8183591365814,
  1726.6064143199474,
  1437.9420633390546],
 ('random', 'mnist', 'tsne'): [6261.594577617778,
  8809.71045379114,
  15727.571824024872,
  10456.128242940964,
  13695.623971621513,
  4346.810568908448,
  6114.2991616742,
  3604.811768151622,
  2148.5928544382436],
 ('random', 'fmnist', 'umap'): [49498.65532368049,
  272089.322820127,
  19302.23217964638,
  18920.69434143789,
  293641.10985150933,
  19707.126275768504,
  15329.575989204226,
  21154.687600805424,
  9046.29302257998],
 ('random', 'fmnist', 'tsne'): [277305.1110497446,
  187423.9889318488,
  219887.54962398365,
  122513.89461758616,
  130685.49806552648,
  97065.9975633683,
  108372.28874198651,
  88124.22150735593,
  72915.28367245979],
 ('random', 'olivetti', 'umap'): [1209.621063709259,
  1231.1787400245667,
  684.2414488196373,
  531.00749838