In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to the project's .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
from zipfile import ZipFile
import pandas as pd
import torch
import random
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm, trange

from HP import PointProcessStorage, DirichletMixtureModel, EM_clustering
from metrics import consistency, purity
from data_utils import load_data

### IPTV Data

In [None]:
# Directory of the IPTV dataset, relative to this notebook.
path = Path('../..') / 'data' / 'IPTV_Data'

# load_data returns three objects, stored as ss, Ts and class2idx
# (presumably the event sequences, their time horizons and a class -> index map;
# confirm against data_utils.load_data).
ss, Ts, class2idx = load_data(
    path,
    nfiles=300,
    type='txt',
    datetime=True,
    maxlen=-1,
)

In [None]:
# Keep only the first 1000 entries of every sequence, replacing each element of ss in place.
# NOTE(review): Ts is not modified by this truncation — confirm that is intended.
for idx, seq in enumerate(ss):
    ss[idx] = seq[:1000]

In [None]:
# Number of sequences currently held in ss.
N = len(ss)

In [None]:
# Number of basis functions.
D = 3
# Gaussian-shaped basis functions exp(-x^2 / (2 (k+1)^2)) for k = 0 .. D-1.
# k is bound as a default argument on purpose: a plain closure looks k up when the
# lambda is *called*, i.e. after the comprehension has finished, so every basis
# function would silently use k = D-1 and all D functions would be identical.
basis_fs = [lambda x, k=k: torch.exp(- x**2 / (2.*(k+1)**2) ) for k in range(D)]

In [None]:
# C is the number of entries in class2idx (presumably the number of event types — confirm).
C = len(class2idx)
# K is passed to DirichletMixtureModel as its first argument
# (presumably the number of mixture components / clusters — confirm).
K = 15

In [None]:
# Repeat the fit on independent random half-splits of the data and record, for each
# trial, the cluster assignment of every sequence and the NLL history of the fit.
ntrials = 10
niter = 5

# labels[t] : argmax cluster index of every sequence in trial t
# nlls[t]   : NLL history returned by learn_hp in trial t (must have niter entries)
labels = torch.zeros(ntrials, len(ss))
nlls = torch.zeros(ntrials, niter)

for i in trange(ntrials):
    # fresh random values for the model arguments in every trial
    Sigma = (torch.rand(C, C)).unsqueeze(-1).unsqueeze(-1).repeat(1,1, D, K) * 10
    B = torch.rand(C, K) * 10
    alpha = 1.

    # random half of the sequences, sampled without replacement and sorted
    train_ids = np.sort(np.random.choice(np.arange(len(ss)), size=len(ss) // 2, replace=False))
    # train_ids is sorted and unique, so direct indexing yields the same fold in the
    # same order as filtering range(len(ss)), without an O(N) membership test per
    # element and without reusing the trial index name `i`.
    train_fold = [ss[j] for j in train_ids]
    train_Ts = Ts[train_ids]

    # learn
    hp = PointProcessStorage(train_fold, train_Ts, basis_fs)
    model = DirichletMixtureModel(K, C, D, alpha, B, Sigma)
    EM = EM_clustering(hp, model, n_inner=5)

    r, nll_history = EM.learn_hp(niter=niter)

    # validate: swap the full dataset into the fitted EM object and clear its g / int_g
    # lists (presumably per-sequence caches — confirm) before re-running the E-step
    EM.hp = PointProcessStorage(ss, Ts, basis_fs)
    EM.N = len(ss)
    EM.int_g = []
    EM.g = []
    r = EM.e_step()

    labels[i] = r.argmax(-1)
    nlls[i] = torch.FloatTensor(nll_history)

    # running consistency over the trials finished so far
    print(f'Consistency of clustering: {consistency(labels[:i+1]).item():.4f}')

In [None]:
# Sanity checks on the fitted parameters: A must be non-negative and mu strictly positive.
# When the notebook is run top to bottom, `model` here is the model of the last trial only.
assert (model.A >= 0).all()
assert (model.mu > 0).all()

In [None]:
# NLL history averaged over trials, with a +/- one standard deviation band.
# The fit in each trial runs on the training half only, so a single normaliser
# (the training-fold size) is applied to both the mean line and the band; mixing
# len(ss) and len(train_ids) would leave the band off-centre from the line.
n_train = len(train_ids)
iters = np.arange(niter) + 1
nll_mean = nlls.mean(0).numpy()
nll_std = nlls.std(0).numpy()

plt.figure(figsize=(9, 5))
plt.grid()
plt.plot(iters, nll_mean / n_train)
plt.fill_between(iters, (nll_mean - nll_std) / n_train, (nll_mean + nll_std) / n_train, alpha=0.2)
plt.title('Mixing DMMHP', fontsize=15)
plt.xlabel(r'$n$ outer iterations', fontsize=15)
plt.ylabel(r'$\sim$ NLL / $N$', fontsize=15)
plt.show()

In [None]:
# Consistency metric computed over the label assignments of all trials.
overall_consistency = consistency(labels).item()
print(f'Consistency of clustering: {overall_consistency:.4f}')

### Synthetic data

In [None]:
# Directory of the simulated Hawkes dataset (folder K3_C5), relative to this notebook.
path = Path('../..') / 'data' / 'simulated_Hawkes' / 'K3_C5'

# Note: ss, Ts and class2idx from the IPTV section are overwritten here.
ss, Ts, class2idx = load_data(
    path,
    nfiles=300,
    maxlen=-1,
    endtime=100,
    type='csv',
)

In [None]:
# Read the 'cluster_id' column of clusters.csv (stored next to the sequences) as a numpy array.
clusters_df = pd.read_csv(Path(path, 'clusters.csv'))
gt_ids = clusters_df['cluster_id'].to_numpy()

In [None]:
# Convert the ground-truth ids to a LongTensor; this rebinds gt_ids.
gt_ids = torch.LongTensor(gt_ids)

In [None]:
# Set up the storage, model and EM driver for the synthetic dataset.
N = len(ss)
D = 5
# Gaussian-shaped basis functions exp(-x^2 / (2 (k+1)^2)) for k = 0 .. D-1.
# k is bound as a default argument on purpose: a plain closure looks k up when the
# lambda is *called*, i.e. after the comprehension has finished, so every basis
# function would silently use k = D-1 and all D functions would be identical.
basis_fs = [lambda x, k=k: torch.exp(- x**2 / (2.*(k+1)**2) ) for k in range(D)]

hp = PointProcessStorage(ss, Ts, basis_fs)

C = len(class2idx)
K = 3

# random values for the model arguments
Sigma = torch.rand(C, C).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, D, K) * 10.
B = torch.rand(C, K) * 10.
alpha = 1.

model = DirichletMixtureModel(K, C, D, alpha, B, Sigma)
EM = EM_clustering(hp, model, n_inner=5)

# uniform weights over the K components (not referenced again in the visible cells)
pi = torch.FloatTensor([1./ K for _ in range(K)])

In [None]:
# Fit the mixture model on the synthetic data with a freshly initialised model.
niter = 5

# Re-draw the model arguments; here Sigma gets an added identity on its first two axes
# and B is shifted by 1, so both are bounded away from zero on those entries.
Sigma = (torch.eye(C, C) + torch.rand(C, C)).unsqueeze(-1).unsqueeze(-1).repeat(1,1, D, K) * 10
B = 1 + torch.rand(C, K) * 10
alpha = 1.

# Replace the model inside the existing EM object (which keeps its hp storage) and fit.
model = DirichletMixtureModel(K, C, D, alpha, B, Sigma)
EM.model = model
r, nll_history = EM.learn_hp(niter=niter)

# labels and nlls are assigned directly from the fit results; zero-filled placeholders
# beforehand would be overwritten immediately and are therefore not created.
labels = r.argmax(-1)
nlls = torch.FloatTensor(nll_history)

In [None]:
# Sanity checks on the fitted parameters: A must be non-negative and mu strictly positive.
assert (model.A >= 0).all()
assert (model.mu > 0).all()

In [None]:
# NLL history of the synthetic-data fit, divided by the number of sequences.
outer_iters = np.arange(niter) + 1
nll_per_seq = nlls.numpy() / len(ss)

fig, ax = plt.subplots(figsize=(9, 5))
ax.grid()
ax.plot(outer_iters, nll_per_seq)
ax.set_title('Mixing of DMMHP', fontsize=15)
ax.set_xlabel(r'$n$ outer iterations', fontsize=15)
ax.set_ylabel(r'$\sim$ NLL / $N$', fontsize=15)
plt.show()

In [None]:
# Purity of the predicted labels against the ground-truth cluster ids.
# NOTE(review): assumes labels and gt_ids refer to the same sequences in the same order — confirm.
pur_val = purity(labels, gt_ids)

In [None]:
# Report the purity value computed in the previous cell.
print(f'Purity: {pur_val}')

In [None]:
# Display the predicted cluster index of every sequence as the cell output.
labels