In [None]:
import numpy as np
import torch
from torch import utils
import pandas as pd
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torch import nn 
from torch.nn import functional as F
import pytorch_lightning as pl
from matplotlib import cm
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches
from scipy import signal as sig
import os
from pathlib import Path
import re
from torch.utils import data
import random
import pandas as pd
import numpy as np
from pathlib import Path
from dataloader import LandmarkDataset, SequenceDataset, LandmarkWaveletDataset
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics import normalized_mutual_info_score, confusion_matrix, accuracy_score

# SVG exports keep text as text elements instead of converting glyphs to paths.
plt.rcParams['svg.fonttype'] = 'none'
# NOTE(review): disables pandas' SettingWithCopyWarning for the whole session,
# which can hide genuine chained-assignment bugs.
pd.set_option('mode.chained_assignment', None)

In [None]:
# Seed every RNG used in this notebook (Python, NumPy, PyTorch) so random draws are repeatable.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

In [None]:
# Root directory of the data. Set the K6_DATA_ROOT environment variable to point
# elsewhere instead of editing this machine-specific absolute path.
data_root = Path(os.environ.get("K6_DATA_ROOT", "/home/orel/Storage/Data/K6/"))

# Landmark .h5 files inside each subdirectory's 'Down' folder.
LANDMARK_FILE_RE = re.compile(r"00\d*DeepCut_resnet50_Down2May25shuffle1_1030000\.h5")

landmark_files = []
# os.listdir returns entries in arbitrary order; sort so landmark_files (and every
# slice / zip of it in later cells) is the same on every run.
for subdir in sorted(os.listdir(data_root)):
    down_dir = data_root / subdir / 'Down'
    if not down_dir.is_dir():
        # Stray files or subdirectories without a 'Down' folder hold no landmarks.
        continue
    for file in sorted(os.listdir(down_dir)):
        # fullmatch: reject names that merely start with the pattern (e.g. '*.h5.bak').
        if LANDMARK_FILE_RE.fullmatch(file):
            landmark_files.append(down_dir / file)

In [None]:
from simple_autoencoder import Autoencoder, PLWaveletAutoencoder

# Number of landmark files to train on (a small subset for quick iteration).
# Slicing with None keeps every file.
N_TRAIN_FILES = 5
# n_neurons: layer widths handed to PLWaveletAutoencoder; the last entry is presumably
# the embedding size used for clustering below — TODO confirm.
model = PLWaveletAutoencoder(landmark_files[:N_TRAIN_FILES], n_neurons=[480, 512, 512, 30], lr=1e-3, patience=20)

In [None]:
# Fetch one batch from the training DataLoader as a quick check that data loading works.
train_loader = model.train_dataloader()
bx = next(iter(train_loader))

In [None]:
# Train the autoencoder for a fixed 50 epochs on one GPU, logging to Weights & Biases.
# NOTE(review): `gpus` and `progress_bar_refresh_rate` are older Trainer arguments that
# newer PyTorch Lightning releases no longer accept — pin the PL version used here.
wandb_logger = pl.loggers.WandbLogger("wavelet landmarks autoencoder")
trainer = pl.Trainer(
    gpus=1,
    progress_bar_refresh_rate=10,
    max_epochs=50,
    logger=wandb_logger,
)
trainer.fit(model)

In [None]:
# Cluster the encoder embeddings of every sample into K groups.
K = 30
# NOTE(review): assumes `encode` returns something KMeans accepts (e.g. a NumPy array
# or a CPU tensor without grad) — confirm.
X_encoded = model.model.encode(model.all_ds)
# random_state pins the centroid initialisation so cluster labels are reproducible
# across runs; n_init is explicit because its default differs between scikit-learn versions.
kmeans = KMeans(n_clusters=K, n_init=10, random_state=SEED)
labels = kmeans.fit_predict(X_encoded)

In [None]:
from collections import Counter

# Cluster sizes (largest first), number of distinct clusters used, total number of labels.
cluster_counts = Counter(labels)
cluster_counts.most_common(), len(cluster_counts), len(labels)

In [None]:
# Sanity checks on the clustering output: one label per encoded sample, and a list of
# any cluster ids in 0..K-1 that received no samples (the per-cluster histogram grids
# below draw one panel per id). Expected to display [].
assert len(labels) == len(X_encoded), "expected exactly one cluster label per encoded sample"
unused_clusters = sorted(set(range(K)) - set(int(label) for label in labels))
unused_clusters

In [None]:
from collections import Counter

# (cluster label, number of samples) pairs, largest cluster first.
sorted(Counter(labels).items(), key=lambda pair: pair[1], reverse=True)

In [None]:
def split_labels(labels, stride=4, offset=0):
    """Run-length encode a 1-D sequence of cluster labels into segments.

    Consecutive equal labels form one run. Each run becomes
    ``[label, start, length]`` with ``start = run_start_index * stride + offset``
    and ``length = run_length * stride``.

    Args:
        labels: 1-D array-like of cluster labels.
        stride: scale from label indices to output units (previously a hard-coded 4;
            presumably the step between consecutive encoded samples — TODO confirm).
        offset: value added to every segment start (0 keeps the original behaviour).

    Returns:
        List of ``[label, start, length]`` lists in sequence order; ``[]`` for empty input.
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        return []
    # Indices where the label changes, i.e. the first index of every run after the first.
    split_at = np.flatnonzero(np.diff(labels) != 0) + 1
    # Start index of every run, including the first at 0. Building this explicitly replaces
    # the old split_at[i - 1] lookup, which read split_at[-1] for the first run and raised
    # IndexError when the whole sequence was a single run.
    run_starts = np.concatenate(([0], split_at))
    runs = np.split(labels, split_at)
    return [[run[0], start * stride + offset, len(run) * stride]
            for start, run in zip(run_starts, runs)]

# Run-length segments over the full label sequence: one [cluster, start, length] per run.
segments = split_labels(labels=labels)

In [None]:
# Distinct cluster ids that occur in any segment.
clusters = {seg[0] for seg in segments}
# For each cluster id, the lengths (third field) of all its segments, in sequence order.
segment_lengths_by_cluster = {
    cluster: [length for label, _start, length in segments if label == cluster]
    for cluster in clusters
}


In [None]:
# One panel per cluster: histogram of its segment lengths on linear axes.
ncols = 5
nrows = -(-K // ncols)  # ceil(K / ncols): enough rows for every cluster id 0..K-1
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(24, 24), squeeze=False)
for cluster_id, ax in enumerate(axes.flat):
    if cluster_id >= K:
        ax.set_visible(False)  # spare panel when K is not a multiple of ncols
        continue
    # .get: a cluster with no segments gets an empty panel instead of a KeyError.
    lengths = segment_lengths_by_cluster.get(cluster_id, [])
    ax.set_title(f"cluster {cluster_id}, with {len(lengths)} segments")
    ax.hist(lengths, bins=100, log=False, density=False)
    ax.set_xlabel("segment length")
    ax.set_ylabel("number of segments")

In [None]:
# One panel per cluster: density histogram of log(segment length) with a log-scaled y axis.
ncols = 5
nrows = -(-K // ncols)  # ceil(K / ncols): enough rows for every cluster id 0..K-1
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(24, 24), squeeze=False)
for cluster_id, ax in enumerate(axes.flat):
    if cluster_id >= K:
        ax.set_visible(False)  # spare panel when K is not a multiple of ncols
        continue
    # .get: a cluster with no segments gets an empty panel instead of a KeyError.
    lengths = segment_lengths_by_cluster.get(cluster_id, [])
    ax.set_title(f"cluster {cluster_id}, with {len(lengths)} segments")
    if lengths:
        # Skip empty clusters: a density histogram of no data divides by zero.
        ax.hist(np.log(lengths), bins=100, log=True, density=True)
    ax.set_xlabel("log(segment length)")
    ax.set_ylabel("density (log scale)")

In [None]:
def split_labels(labels, stride=4, offset=None):
    """Run-length encode cluster labels, shifting every segment start by ``offset``.

    This redefines ``split_labels`` for the per-video cells. Consecutive equal labels
    form one run, and each run becomes ``[label, start, length]`` with
    ``start = run_start_index * stride + offset`` and ``length = run_length * stride``.

    Args:
        labels: 1-D array-like of cluster labels.
        stride: scale from label indices to output units (previously a hard-coded 4).
        offset: value added to every segment start. None (default) uses
            ``model.seqlen * 2``, read from the global ``model`` at call time.

    Returns:
        List of ``[label, start, length]`` lists in sequence order; ``[]`` for empty input.
    """
    if offset is None:
        offset = model.seqlen * 2
    labels = np.asarray(labels)
    if labels.size == 0:
        return []
    # Indices where the label changes, i.e. the first index of every run after the first.
    split_at = np.flatnonzero(np.diff(labels) != 0) + 1
    # Start index of every run, including the first at 0. This replaces the old
    # split_at[i - 1] lookup, which raised IndexError for a single-run sequence.
    run_starts = np.concatenate(([0], split_at))
    runs = np.split(labels, split_at)
    return [[run[0], start * stride + offset, len(run) * stride]
            for start, run in zip(run_starts, runs)]

# Split the concatenated arrays back into chunks at `video_change_idxs` (presumably one
# chunk per video) and key each chunk by its landmark file.
# NOTE(review): `video_change_idxs` and `all_data` are not defined in any earlier cell, so
# this cell only works with leftover kernel state — TODO define both above this cell.
labels_per_video = np.split(labels, indices_or_sections=video_change_idxs)
# zip() silently drops chunks beyond the number of files; fail loudly on a mismatch instead.
assert len(labels_per_video) <= len(landmark_files), "more label chunks than landmark files"

labels_dict = dict(zip(landmark_files, labels_per_video))

data_dict = dict(zip(landmark_files,
                     np.split(all_data, indices_or_sections=video_change_idxs)))

X_encoded_dict = dict(zip(landmark_files,
                          np.split(X_encoded, indices_or_sections=video_change_idxs)))

# Reuse the per-video label chunks rather than splitting `labels` a second time.
segment_dict = {
    lfile: split_labels(video_labels)
    for lfile, video_labels in zip(landmark_files, labels_per_video)
}
