<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Tabnet-as-event-encoder" data-toc-modified-id="Tabnet-as-event-encoder-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Tabnet as event encoder</a></span></li></ul></div>

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Apply seaborn's default theme first, then select the "whitegrid" style
# with the axes grid switched off.
sns.set()
sns.set_style(style="whitegrid", rc={"axes.grid": False})

from tqdm.auto import tqdm

In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell execution and edits to the local project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
# Shown as the cell output: whether PyTorch reports an available CUDA device.
torch.cuda.is_available()

In [None]:
# Use the first CUDA GPU when PyTorch reports one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# On a CUDA machine this prints a CUDA device.
print(device)

In [None]:
from pytorch_metric_learning import losses, miners, distances, reducers, samplers
import torch.optim as optim
from sklearn.model_selection import train_test_split
import torch.nn as nn
import random

In [None]:
from code.dataloader import AgeGroupMLDataset, AgeGroupClfDataset
from code.encoder_tabnet import Encoder
from code.decoder import Decoder
from code.classifier import Classifier
from code.utils import train_ml_model, train_classifier

In [None]:
# --- Batching and sub-sequence sampling ---
BATCH_SIZE = 64  # number of unique persons per batch
NUM_OF_SUBSEQUENCES = 5
SUBSEQUENCE_LENGTH = 90

# --- Data and input features ---
NUM_OBS = 30_000  # passed to the metric-learning dataset as num_observations
num_input_dim = 4  # passed to the models as numerical_input_dim
cat_vocab_sizes = [204]
cat_embedding_dim = 20

# --- Model size and optimisation ---
EMBEDDING_DIM = 256
LR = 0.002
NUM_EPOCHS = 20

In [None]:
# Architecture tag; it is interpolated into the file names of the saved plots.
arch = "tabnet"

In [None]:
# Dataset that feeds the metric-learning DataLoader built further down.
# NOTE(review): num_observations presumably caps how many rows are loaded -- confirm in code.dataloader.
dataset = AgeGroupMLDataset(num_observations=NUM_OBS)

In [None]:
# NOTE(review): presumably loads the client -> row-indices mapping the dataset
# needs before it can be iterated; confirm in code.dataloader.
dataset.load_client_to_indices()

In [None]:
# Dataset used for the classifier stage: its `targets` drive the stratified
# train/test split and it backs both classifier DataLoaders below.
clfdataset = AgeGroupClfDataset()

In [None]:
# NOTE(review): presumably loads the client -> row-indices mapping for the
# classification dataset; confirm in code.dataloader.
clfdataset.load_client_to_indices()

In [None]:
# Seed the three RNGs used in this notebook (Python `random`, PyTorch, NumPy) with 0.
# NOTE(review): the datasets are constructed in earlier cells, so any randomness
# inside their constructors is not covered by these seeds -- confirm whether that matters.
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)  # kept last: it returns None, so the cell shows no output

In [None]:
# --- Metric-learning stage ---
# Loader over the metric-learning dataset; no sampler or shuffle flag is passed,
# and batches are loaded in the main process (num_workers=0).
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=0,
)

# --- Classifier stage ---
# Stratified 70/30 split over the classification dataset's targets, with a
# fixed random_state so the split is reproducible.
targets = clfdataset.targets

train_idx, test_idx = train_test_split(
    np.arange(len(targets)),
    test_size=0.3,
    shuffle=True,
    stratify=targets,
    random_state=228,
)

# Both loaders draw from the same dataset object; the samplers restrict each
# one to its own disjoint index subset.
train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)

trainloader = torch.utils.data.DataLoader(
    clfdataset,
    batch_size=BATCH_SIZE,
    sampler=train_sampler,
)
testloader = torch.utils.data.DataLoader(
    clfdataset,
    batch_size=BATCH_SIZE,
    sampler=test_sampler,
)

## Tabnet as event encoder

In [None]:
LR = 0.002

# Event encoder trained in the metric-learning stage.
encoder = Encoder(
    numerical_input_dim=num_input_dim,
    cat_vocab_sizes=cat_vocab_sizes,
    cat_embedding_dim=cat_embedding_dim,
    embedding_dim=EMBEDDING_DIM,
)
encoder.to(device)
optimizer = optim.Adam(encoder.parameters(), lr=LR)

# Triplet setup on cosine similarity: the loss and the miner share the same
# distance object and the same margin (0.4).
distance = distances.CosineSimilarity()
# ThresholdReducer(low=0) averages only the losses that are greater than 0.
reducer = reducers.ThresholdReducer(low=0)
loss_func = losses.TripletMarginLoss(
    margin=0.4,
    distance=distance,
    reducer=reducer,
)
mining_func = miners.TripletMarginMiner(
    margin=0.4,
    distance=distance,
    type_of_triplets="semihard",
)

In [None]:
# Run the metric-learning training of the encoder and keep the returned
# losses for plotting.
train_losses = train_ml_model(
    encoder,
    NUM_EPOCHS,
    dataloader,
    NUM_OF_SUBSEQUENCES,
    mining_func,
    loss_func,
    optimizer,
)

In [None]:
# Loss curve of the metric-learning stage, drawn through the explicit Axes API
# so the figure being saved is unambiguous.
fig, ax = plt.subplots(figsize=(12, 6))

ax.plot(train_losses, label='train')
ax.set_xlabel('iter')
ax.set_ylabel('loss')
ax.legend()  # the series carries a label, so render it

# savefig does not create missing directories; make sure the target exists
# so the cell also works on a fresh checkout.
os.makedirs('plots', exist_ok=True)
fig.savefig(f'plots/ML_{arch}_{EMBEDDING_DIM}_{NUM_OBS}_{NUM_EPOCHS}.png')

**TODO:** add an interpretability analysis based on TabNet.

In [None]:
# Epoch budget for the classifier stage; from here on it replaces the value
# used for the metric-learning stage (including in plot file names).
NUM_EPOCHS = 50

In [None]:
SCHEDULER_EPOCHS = 2
LR = 0.002

# Build the classifier, swap in the encoder trained above, and call
# freeze_encoder() so that only the decoder's parameters go to the optimizer.
classifier = Classifier(
    embedding_dim=EMBEDDING_DIM,
    numerical_input_dim=num_input_dim,
    cat_vocab_sizes=cat_vocab_sizes,
    cat_embedding_dim=cat_embedding_dim,
)
classifier.encoder = encoder
classifier.freeze_encoder()
classifier.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.decoder.parameters(), lr=LR)
# Learning-rate scheduler with a patience of SCHEDULER_EPOCHS.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=SCHEDULER_EPOCHS)

In [None]:
# Train the classifier. The mode callbacks toggle only the decoder between
# train() and eval(); the encoder's mode is not touched by them.
train_losses, train_accuracy, val_losses, val_accuracy = train_classifier(
    classifier,
    NUM_EPOCHS,
    trainloader,
    testloader,
    optimizer,
    criterion,
    scheduler,
    enable_train_mode=lambda: classifier.decoder.train(),
    enable_test_mode=lambda: classifier.decoder.eval(),
)

In [None]:
# One figure with two side-by-side panels (loss on the left, accuracy on the
# right). Requesting both axes from plt.subplots avoids creating a separate
# full-figure axes that would sit underneath the panels.
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
loss_ax, acc_ax = axs

loss_ax.plot(train_losses, label='train')
loss_ax.plot(val_losses, label='validation')
loss_ax.set_xlabel('iter')
loss_ax.set_ylabel('loss')
loss_ax.legend()

acc_ax.plot(train_accuracy, label='train')
acc_ax.plot(val_accuracy, label='validation')
acc_ax.set_xlabel('iter')
acc_ax.set_ylabel('accuracy')
acc_ax.legend()

# savefig does not create missing directories; make sure the target exists
# so the cell also works on a fresh checkout.
os.makedirs('plots', exist_ok=True)
fig.savefig(f'plots/clfdec_{arch}_{EMBEDDING_DIM}_{NUM_OBS}_{NUM_EPOCHS}.png')