In [None]:
import sys
# Relative path that is appended to the module search path and later used as the
# prefix when building the data-file path in this notebook.
# NOTE(review): presumably this is the repository root and is what makes
# `from utils import *` resolvable — confirm against the project layout.
root_dir = '../../../'
sys.path.append(root_dir)

In [1]:
import os
import pickle
from collections import defaultdict
from itertools import combinations, product
from itertools import groupby

import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform, euclidean
from scipy.special import gamma

from utils import *

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
def volume(r, m):
    """Return the volume of an ``m``-dimensional Euclidean ball of radius ``r``.

    Implements ``pi^(m/2) * r^m / Gamma(m/2 + 1)``.

    Parameters
    ----------
    r : float
        Ball radius.
    m : int or float
        Dimension of the space.

    Returns
    -------
    float
        The ball volume; ``0`` when ``r`` is ``0`` and ``m`` is positive.
    """
    half_dim = m / 2
    numerator = np.pi ** half_dim * r ** m
    return numerator / gamma(half_dim + 1)


def significant(cluster, h, p):
    """Tell whether the density spread inside ``cluster`` reaches ``h``.

    The largest absolute difference between any two values equals
    ``max - min``, so a single pass over the members replaces the original
    scan over all ordered pairs (quadratic in the cluster size).

    Parameters
    ----------
    cluster : iterable of int
        Indices into ``p`` of the points that make up the cluster.
    h : float
        Threshold for the density spread.
    p : sequence of float
        Per-point density values.

    Returns
    -------
    bool
        ``True`` when ``max(p[i]) - min(p[i])`` over the cluster is ``>= h``.
        An empty cluster has no spread and yields ``False``.
    """
    densities = [p[i] for i in cluster]
    if not densities:
        return False
    return max(densities) - min(densities) >= h


def partition(dist, l, r, order):
    """Partition ``order[l..r]`` in place around a pivot distance (Hoare-style).

    ``order`` holds indices into ``dist``; only ``order`` is rearranged. The
    pivot is the distance referenced by the middle slot of the range.

    Parameters
    ----------
    dist : sequence of float
        Distance values, addressed through ``order``.
    l, r : int
        Inclusive bounds of the slice of ``order`` to partition.
    order : list of int
        Index permutation, modified in place.

    Returns
    -------
    int
        A split position ``j``: afterwards every entry referenced by
        ``order[l..j]`` is ``<=`` the pivot and every entry referenced by
        ``order[j+1..r]`` is ``>=`` the pivot. For ``l == r`` this is ``l``.
    """
    if l == r:
        return l

    pivot = dist[order[(l + r) // 2]]
    lo = l - 1
    hi = r + 1
    while True:
        # Move the lower cursor right until it meets an entry >= pivot.
        lo += 1
        while not (dist[order[lo]] >= pivot):
            lo += 1

        # Move the upper cursor left until it meets an entry <= pivot.
        hi -= 1
        while not (dist[order[hi]] <= pivot):
            hi -= 1

        if lo >= hi:
            return hi

        # Both cursors sit on misplaced entries: exchange them and continue.
        order[lo], order[hi] = order[hi], order[lo]


def nth_element(dist, order, k):
    """Rearrange ``order`` so that slot ``k`` references the k-th smallest distance.

    Quickselect built on ``partition``: after the call ``dist[order[k]]`` is
    the value that would sit at position ``k`` if ``order`` were fully sorted
    by ``dist``; earlier slots reference values ``<=`` it and later slots
    values ``>=`` it. ``order`` is modified in place, ``dist`` is not touched.

    Parameters
    ----------
    dist : sequence of float
        Distance values, addressed through ``order``.
    order : list of int
        Index permutation to rearrange in place.
    k : int
        Zero-based rank to select; must satisfy ``0 <= k < len(order)``.

    Raises
    ------
    ValueError
        If ``k`` is outside ``[0, len(order))`` (this includes an empty
        ``order``). Previously such calls either failed with an unrelated
        ``IndexError`` or returned silently with a meaningless arrangement.
    """
    if not 0 <= k < len(order):
        raise ValueError(
            f"k must satisfy 0 <= k < len(order) = {len(order)}, got {k}"
        )

    lo, hi = 0, len(order) - 1
    # Invariant: slot k lies inside order[lo..hi]; the window shrinks each pass.
    while lo < hi:
        split = partition(dist, lo, hi, order)
        if split < k:
            lo = split + 1
        else:
            hi = split


def get_clustering(x, k, h, verbose=True):
    """Assign a cluster label to every point using k-NN density estimates.

    Points are visited in order of increasing ``dk`` (the k-th smallest entry
    of their distance row, self-distance included), i.e. from the densest
    neighbourhoods outwards. Each point looks at the already visited points
    within its own radius ``dk`` and either starts a new cluster, joins an
    existing one, merges clusters, or is labelled ``0``.

    Parameters
    ----------
    x : array-like
        Points to cluster, shape ``(n, m)``. A flat sequence of ``n`` scalars
        is treated as ``n`` one-dimensional points. NumPy arrays, lists of
        lists and other 2-D array-likes are all accepted; the dimension ``m``
        used for the ball volume is always the number of columns.
    k : int
        Neighbour rank used for the radius ``dk``; must satisfy
        ``1 <= k <= n``.
    h : float
        Density-spread threshold passed to ``significant``.
    verbose : bool, optional
        Accepted for backward compatibility; currently ignored.

    Returns
    -------
    numpy.ndarray
        Integer labels of length ``n``. Positive values are cluster ids;
        ``0`` marks points that were not attached to any cluster. Empty input
        yields an empty array.

    Raises
    ------
    ValueError
        If ``k`` is outside ``[1, n]``, or if ``pdist`` rejects the input.
    """
    points = np.asarray(x)
    if points.ndim == 1:
        # Flat input: one scalar per sample; pdist needs a 2-D array.
        points = points[:, np.newaxis]
    n = len(points)
    if n == 0:
        return np.full(0, 0)
    if not 1 <= k <= n:
        raise ValueError(f"k must satisfy 1 <= k <= len(x) = {n}, got {k}")
    # Dimension of the space. The previous isinstance(x[0], list) test was
    # false for NumPy rows, so array input was always treated as 1-D.
    m = points.shape[1]
    dist = squareform(pdist(points))

    # k-th smallest distance per row (row-wise to avoid copying the full matrix).
    dk = [np.partition(row, k - 1)[k - 1] for row in dist]
    # k-NN density estimate: k points inside a ball of radius dk, out of n.
    p = [k / (volume(radius, m) * n) for radius in dk]

    w = np.full(n, 0)
    completed = {0: False}  # label -> cluster is closed for new members
    last = 1                # next unused cluster id
    vertices = set()        # indices already processed
    for _, i in sorted(zip(dk, range(n))):
        # Collect already processed points inside the radius of point i,
        # grouped by their current label.
        # NOTE(review): `clusters` only holds the members that fall inside this
        # neighbourhood, so the significance test and the relabelling below act
        # on those members only, not on whole clusters — confirm this is intended.
        neigh_w = set()
        clusters = defaultdict(list)
        for j in vertices:
            if dist[i][j] <= dk[i]:
                neigh_w.add(w[j])
                clusters[w[j]].append(j)
        vertices.add(i)

        if not neigh_w:
            # Isolated so far: open a new cluster.
            w[i] = last
            completed[last] = False
            last += 1
        elif len(neigh_w) == 1:
            # Exactly one neighbouring label: join it unless it is closed.
            wj = next(iter(neigh_w))
            w[i] = 0 if completed[wj] else wj
        elif all(completed[wj] for wj in neigh_w):
            w[i] = 0
        else:
            significant_clusters = set(
                wj for wj in neigh_w if significant(clusters[wj], h, p)
            )
            if len(significant_clusters) > 1:
                # Point i separates several significant clusters: close them
                # (label 0 is never closed) and dissolve the insignificant ones.
                w[i] = 0
                for wj in neigh_w:
                    if wj in significant_clusters:
                        completed[wj] = (wj != 0)
                    else:
                        for j in clusters[wj]:
                            w[j] = 0
            else:
                # At most one significant cluster: merge everything into it,
                # or into an arbitrary neighbouring label when there is none.
                if len(significant_clusters) == 0:
                    s = next(iter(neigh_w))
                else:
                    s = next(iter(significant_clusters))
                w[i] = s
                for wj in neigh_w:
                    for j in clusters[wj]:
                        w[j] = s
    return w

In [4]:
class WishartClusterization(object):
    """Estimator-style wrapper around ``get_clustering``.

    ``fit`` used to carry a line-for-line copy of ``get_clustering``; it now
    delegates to that function so the two entry points cannot drift apart.

    Parameters
    ----------
    k : int
        Neighbour rank forwarded to ``get_clustering``.
    h : float
        Density-spread threshold forwarded to ``get_clustering``.
    """

    def __init__(self, k, h):
        self.k = k
        self.h = h

    def fit(self, x):
        """Cluster ``x`` and store the labels on the instance.

        Parameters
        ----------
        x : array-like
            Points to cluster, in whatever form ``get_clustering`` accepts.

        Returns
        -------
        WishartClusterization
            ``self``; the label array is available as ``labels_``.
        """
        self.labels_ = get_clustering(x, self.k, self.h)
        return self

In [5]:
# Load the Lorenz series: first column of the header-less CSV under
# <root_dir>/data. os.path.join is used instead of string concatenation, which
# produced a doubled separator ('../../..//data/...') because root_dir already
# ends with '/'.
lorenz = pd.read_csv(os.path.join(root_dir, 'data', 'lorenz.csv'), header=None).iloc[:, 0].values
# Training window: samples 4000 (inclusive) to 24000 (exclusive) of the series.
X_train = lorenz[4_000:24_000]

In [6]:
# Build the set of patterns (size 3, max_distance 10, 100 percent of them) and
# collect the corresponding motifs from the training series. Both helpers are
# not defined in this notebook.
# NOTE(review): presumably get_patterns/get_motifs come from `utils` — confirm.
patterns = get_patterns(pattern_size=3, max_distance=10, patterns_percent=100)
motifs = get_motifs(X_train, patterns)

In [7]:
# Hyper-parameters handed to get_clustering for every pattern.
WISHART_K = 4
WISHART_H = 0.2
# Length of one motif vector, read from the last axis of an arbitrary entry.
motif_size = motifs[next(iter(motifs))].shape[-1]

# One (pattern, motif instances) work item per pattern.
inputs = list(motifs.items())

def task(x):
    """Cluster the motif instances of a single pattern.

    Parameters
    ----------
    x : tuple
        ``(pattern, train)`` pair: the pattern key and the points to cluster.

    Returns
    -------
    tuple
        ``(pattern, labels)`` where ``labels`` is the result of
        ``get_clustering(train, WISHART_K, WISHART_H)``.
    """
    pattern, train = x
    return pattern, get_clustering(train, WISHART_K, WISHART_H)

# Cluster every pattern's motif instances and index the label arrays by pattern.
results = run_concurrently(task, inputs)
ws = dict(results)

# Replace every cluster by its centroid: for each pattern, one row per distinct
# label (ascending label order), each row being the mean of the motif vectors
# carrying that label. Means are accumulated in float64, matching the dtype of
# the previous manual accumulator.
# NOTE(review): label 0 is treated like any other label here, so the points
# that get_clustering left unassigned also contribute a centroid — confirm
# that this is intended.
wishart_motifs = {}
for pattern, w in ws.items():
    labels = np.asarray(w)
    points = np.asarray(motifs[pattern])
    centers = [
        points[labels == label].mean(axis=0, dtype=np.float64)
        for label in np.unique(labels)
    ]
    # A pattern without any labelled point used to raise KeyError; it now maps
    # to an empty (0, motif_size) array.
    wishart_motifs[pattern] = (
        np.array(centers) if centers else np.empty((0, motif_size))
    )

In [8]:
# Persist the per-pattern centroid arrays next to the notebook.
with open('wishart-motifs-lorenz.pickle', 'wb') as out_file:
    pickle.dump(wishart_motifs, out_file)