In [1]:
import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid

from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

import time
from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

import os

# !pip install einops
# !wget -O utils.py https://drive.google.com/uc?id=1O3QWBKrqA7s8nIGzhKMIz-YNK1-jzwml
# import utils
# from utils import plot_loss_and_accuracy

In [2]:
 def make_model():
    model = nn.Sequential(
        nn.Conv1d(in_channels=30, out_channels=50, kernel_size=3, padding=1), # [1, 50, 40]
        nn.BatchNorm1d(50),
        nn.ReLU(),
        nn.MaxPool1d(kernel_size=2),   # [1, 50, 20]

        nn.Flatten(),                  # [1, 1000]
        nn.Linear(1000, 500),          # [1, 500]
        nn.Dropout(0.3),
        nn.Linear(500, 10),            # [1, 10]
        # nn.LogSoftmax(dim=1)
    )
    return model

In [3]:
"""
This file contains implementation of metrics from A Dirichlet Mixture Model 
of Hawkes Processes for Event Sequence Clustering
https://arxiv.org/pdf/1701.09177.pdf
"""
import torch
import numpy as np


def consistency(trials_labels):
    """
    Args:
    - trials_labels - array-like sequence of 1-D tensors. Each tensor is a sequence of labels
    """
    J = len(trials_labels)
    values = torch.zeros(J)
    
    for trial_id, labels in enumerate(trials_labels):
        ks = torch.unique(labels)
        sz_M = 0 # number of pairs within same cluster
        for k in ks:
            mask = labels == k
            sz = mask.sum()
            s = sz * (sz - 1.) / 2.
            sz_M += s

        for trial_id2, labels2 in enumerate(trials_labels):
            if trial_id == trial_id2:
                continue

            for k in ks:
                mask = labels == k
                s2 = 0
                for k2 in labels2[mask].unique():
                    sz = (labels2[mask] == k2).sum()  # same cluster within j trial, same cluster within j' trial
                    s2 += sz * (sz - 1.) / 2.
                #values[trial_id] += (sz_M - s2) / ((J-1) * sz_M)
                values[trial_id] += s2 
        values[trial_id] /= ((J-1) * sz_M)
    
    return torch.min(values)


def purity(learned_ids, gt_ids):
    """
    Args:
    - learned_ids - 1-D tensor of labels obtained from model
    - gt_ids - 1-D tensor of ground truth labels
    """
    assert len(learned_ids) == len(gt_ids)
    pur = 0
    ks = torch.unique(learned_ids)
    js = torch.unique(gt_ids)
    for k in ks:
        inters = []
        for j in js:
            inters.append(((learned_ids == k) * (gt_ids == j)).sum().item())
        pur += 1./len(learned_ids) * max(inters)

    return pur

In [4]:
# def compute_accuracy(logits, y_true, device='cuda:0'):
#     y_pred = torch.argmax(logits, dim=1)
#     return (y_pred == y_true.to(device)).float().mean()

In [5]:
# def accuracy(model, images, labels):
#     logits = model.forward(images)
#     y_pred = torch.argmax(logits, dim=1)
#     return (y_pred == labels).float().mean()

In [6]:
def iterate_minibatches(inputs, batchsize, shuffle=False):
    if shuffle:
        indices = np.random.permutation(len(inputs))
    for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)
        yield inputs[excerpt], excerpt

In [7]:
def sinkhorn(P, r, c, lmbd=0.05, eps=1e-3):
    P = -P
    Q = torch.exp(- lmbd * P).T
    Q /= torch.sum(Q)
    K, N = Q.shape
    u = torch.zeros(K)
    res = []
    res.append(-torch.sum(torch.trace(Q.T @ P.T)))
    while True:
        Q_prev = Q.clone()
        u = torch.sum(Q, dim=1)
        Q *= (c / u).unsqueeze(1)
        Q *= (r / torch.sum(Q, dim=0)).unsqueeze(0)
        res.append(-torch.sum(torch.trace(Q.T @ P.T)))
        if torch.sum(((Q-Q_prev)/Q_prev)**2) <eps:
            break
    return Q.T, res

def opt_sk(model, epoch):
    PS = np.zeros((K, N))
    
    for batch_idx, (data, _selected) in enumerate(iterate_minibatches(X_train, batch_size, shuffle=False)):
        data = data.to(device) # [100, 30, 40]
        p = nn.functional.softmax(model(data), 1) # [100, 10]
        PS[_selected, :] = p.detach().cpu().numpy()
    
    r = torch.ones(K)/K
    c = torch.ones(N)/N
    selflabels, res = sinkhorn(torch.from_numpy(PS), r, c)

    return selflabels

In [8]:
def train(model, num_epochs=2, lr=0.1, weight_decay=0, exp_name='my_network'):
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=7, gamma=0.1)

    train_loss = []
    val_loss = []
    train_accuracy = []
    val_accuracy = []

    for epoch in range(num_epochs):
        print (f'epoch {epoch}')
        start_time = time.time()
        model.train(True)

        train_accuracy_batch = []

        for batch_idx, (inputs, indexes) in enumerate(iterate_minibatches(X_train, batch_size, shuffle=False)):
            inputs = inputs.float().to(device)
            
            with torch.no_grad():
                selflabels = opt_sk(model, epoch) # [10, 1000]. should be [100, 10]
                
            inputs = inputs.to(device)
            opt.zero_grad()

            outputs = model(inputs) # [100, 10]
            print (selflabels.shape, outputs.shape)
            
            loss = nn.CrossEntropyLoss(outputs, selflabels[indexes])
           
            loss.backward()
            opt.step()
                                                    
        # train_accuracy_overall = np.mean(train_accuracy_batch) * 100 # mean train accuracy over 1 epoch in %
        # train_accuracy.append(train_accuracy_overall.item())

    return train_accuracy_overall, 0

In [9]:
N = 1000
K = 10
# selflabels - [KxN]
X_train = torch.Tensor(N, 30, 40) # [samples, n_steps, features]

X_train.shape

torch.Size([1000, 30, 40])

In [10]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device

'cuda:0'

In [11]:
model = make_model()
model = model.to(device)

In [12]:
train_accuracy = -1
val_accuracy = -1

DO_TRAIN = True
if DO_TRAIN:
    %%time
    lr=0.01
    exp_name = datetime.now().isoformat(timespec='seconds') + f'exp{lr}'
    batch_size = 100
    train_accuracy, val_accuracy = train(model, num_epochs=2, weight_decay=0, lr=lr, exp_name=exp_name)

    torch.save(model.state_dict(), "checkpoint.pth")
    print (train_accuracy, val_accuracy)

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 5.48 µs
epoch 0


ValueError: ignored