In [None]:
import numpy as np

In [None]:
def load_dataset(path):
    """Read a sparse dataset from a text file, one data point per line.

    Each line is split on single spaces. The first field is skipped (it is
    never parsed) and every remaining field is expected to look like
    ``index:value`` with a 1-based integer ``index``; only the index is used.
    Empty fields produced by trailing or repeated spaces are ignored, and
    lines that contain nothing but whitespace are skipped entirely.

    Args:
        path: Path of the file to read.

    Returns:
        A tuple ``(points, n_features)``. ``points`` holds, for every
        non-blank line, the list of zero-based feature indices found on it
        (possibly empty). ``n_features`` is the largest zero-based index
        plus one, or 0 when no index was seen at all.

    Raises:
        ValueError: If a field's index part is not an integer.
    """
    points = []
    max_index = -1
    with open(path) as f:
        for line in f:
            stripped = line.rstrip()
            if not stripped:
                # Blank line (e.g. at end of file): not a data point.
                continue
            fields = stripped.split(' ')[1:]
            # Drop empty fields so "a  b" or a trailing space does not reach int('').
            indices = [int(field.split(':')[0]) - 1 for field in fields if field]
            points.append(indices)
            if indices:
                # max() on an empty list would raise for a row without features.
                max_index = max(max_index, max(indices))

    return points, max_index + 1

In [None]:
class Dataset:
    """Plain container bundling a list of points with its dimensions and a name.

    Attributes are set directly from the constructor arguments:
    ``points`` and ``name`` are stored as given, and ``shape`` is the tuple
    ``(len(points), features)``.
    """

    def __init__(self, points, features, name):
        # Keep a reference to the caller's list; no copy is made.
        self.name = name
        self.points = points
        # (number of points, number of features)
        n_points = len(points)
        self.shape = (n_points, features)

In [None]:
def get_column_perm(data, perm, parts):
    """Group feature columns by the row part in which each occurs most often.

    Rows are visited in the order given by ``perm`` and cut into ``parts``
    consecutive chunks of ``data.shape[0] // parts`` rows each (at least one
    row per chunk); any remaining rows go to the last chunk. For every
    feature the number of occurrences per chunk is counted, and the feature
    is assigned to the chunk with the highest count (the lowest-numbered
    chunk on ties). A feature that never occurs is assigned to a chunk drawn
    with ``np.random.randint``.

    Args:
        data: Object with ``shape`` (``shape[0]`` rows, ``shape[1]``
            features) and ``points``, where ``points[r]`` iterates over the
            feature indices present in row ``r``.
        perm: Row order; ``perm[i]`` is the index into ``data.points`` of the
            row placed at position ``i``.
        parts: Number of chunks the rows are divided into.

    Returns:
        A tuple ``(order, owner)``. ``order`` lists every feature index,
        grouped by owning chunk (chunk 0 first); inside a chunk, features are
        sorted by decreasing fraction of their occurrences that fall in the
        owning chunk, with the higher feature index first on ties. ``owner``
        is a dict mapping each feature index to its owning chunk.
    """
    n_rows, n_features = data.shape[0], data.shape[1]
    counts = np.zeros((parts, n_features), dtype=int)
    # With more parts than rows the integer division gives 0, which would
    # make the chunk computation below divide by zero.
    per_part = max(1, n_rows // parts)
    for position in range(n_rows):
        part = min(parts - 1, position // per_part)
        for index in data.points[perm[position]]:
            counts[part, index] += 1

    ranked = []
    for feature in range(n_features):
        column = counts[:, feature]
        total = np.sum(column)
        if total == 0:
            # Unused feature: no meaningful owner, pick one at random.
            share = 0.0
            part = np.random.randint(0, parts)
        else:
            part = np.argmax(column)
            share = np.max(column) / total
        ranked.append((share, part, feature))
    # Descending by (share, part, feature); only the order inside each part
    # matters once the features are bucketed below.
    ranked.sort(reverse=True)

    groups = [[] for _ in range(parts)]
    for _, part, feature in ranked:
        groups[part].append(feature)

    owner = {feature: part for part, members in enumerate(groups) for feature in members}
    order = [feature for members in groups for feature in members]
    return order, owner


def show_dataset(dataset, parts, perm=None, column_perm=None, alpha=0.01, beta=0.01, seed=None):
    """Display a random sub-sample of the dataset as a coloured image.

    A fraction ``alpha`` of the rows and a fraction ``beta`` of the feature
    columns are sampled without replacement. Sampled rows keep their order
    under ``perm``; sampled columns are laid out in the order given by
    ``column_perm[0]``. Every (row, feature) pair present in the data gets a
    pixel whose colour is ``colors[(feature_part + row_part) % 3]`` (white,
    red or green); all other pixels stay black. The image is opened with
    ``PIL.Image.show``.

    Args:
        dataset: Object with ``shape``, ``points`` and ``name`` attributes.
        parts: Number of consecutive row chunks the rows are divided into.
        perm: Row order (``perm[i]`` indexes ``dataset.points``); defaults
            to the identity order.
        column_perm: Pair ``(order, owner)`` as returned by
            ``get_column_perm``; computed from ``dataset``, ``perm`` and
            ``parts`` when omitted.
        alpha: Fraction of rows to sample.
        beta: Fraction of feature columns to sample.
        seed: If given, seeds NumPy's global random generator first.

    Returns:
        None.
    """
    if seed is not None:
        np.random.seed(seed)
    n_rows, n_features = dataset.shape[0], dataset.shape[1]
    if perm is None:
        perm = list(range(n_rows))
    if column_perm is None:
        column_perm = get_column_perm(dataset, perm, parts)
    order, owner = column_perm[0], column_perm[1]

    v_index = np.random.choice(n_rows, int(n_rows * alpha), replace=False)
    v_index.sort()

    h_index = set(np.random.choice(n_features, int(n_features * beta), replace=False))
    # Horizontal pixel position of every sampled feature, following `order`.
    num = {index: i for i, index in enumerate(x for x in order if x in h_index)}

    colors = [[255, 255, 255], [255, 0, 0], [0, 255, 0]]
    # Guard against a zero chunk size when there are more parts than rows.
    per_part = max(1, n_rows // parts)
    pic = np.zeros((len(v_index), len(h_index), 3), dtype=np.uint8)
    for ii, i in enumerate(v_index):
        part = min(parts - 1, i // per_part)
        for index in dataset.points[perm[i]]:
            if index in h_index:
                f_part = owner[index]
                pic[ii, num[index]] = colors[(f_part + part) % len(colors)]

    # Imported here so the rest of the notebook works without Pillow installed.
    from PIL import Image
    img = Image.fromarray(pic, 'RGB')
    # The title reports the requested number of parts, not the chunk index of
    # the last sampled row (which is also undefined when no row was sampled).
    img.show("Split {} into {} parts".format(dataset.name, parts))



In [None]:
from os import path
import os

# Render every stored permutation of every dataset.
# Layout read by this cell: <base_dir>/<dataset>/<parts>.<ext> holds one row
# index per line, and ../data/<dataset> holds the dataset itself.
base_dir = "../permutations"
# os.listdir returns entries in arbitrary order; sort so that the images are
# shown in the same order on every run.
datasets = sorted(f for f in os.listdir(base_dir) if path.isdir(path.join(base_dir, f)))
for dataset in datasets:
    subdir = path.join(base_dir, dataset)
    p, m = load_dataset("../data/{}".format(dataset))
    d = Dataset(p, m, dataset)
    for perm_file in sorted(f for f in os.listdir(subdir) if path.isfile(path.join(subdir, f))):
        # The file name without its extension is the number of parts;
        # splitext avoids assuming the extension is exactly four characters.
        parts = int(path.splitext(perm_file)[0])
        with open(path.join(subdir, perm_file)) as f:
            # Skip blank lines (e.g. a trailing empty line) instead of failing in int().
            permutation = [int(line) for line in f if line.strip()]
        show_dataset(d, parts, permutation, seed=42, alpha=0.01, beta=0.2)