In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from pathlib import Path
import pandas as pd
import pickle
import numpy as np
import shutil
# from tqdm import tqdm_notebook as tqdm

import torch.utils.data
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, Engine
from ignite.metrics import Accuracy, Loss

from src.models import InsiderClassifier, LSTM_Encoder
from src.params import get_params
from src.dataset import CertDataset, create_data_loaders
from src.cnn_trainer import *

%load_ext autoreload
%autoreload 2

In [2]:
# output_dir = Path(r'C:\Users\Mvideo\Google Drive\Datasets\CERT_output')
# answers_dir = Path(r"C:/Users/Mvideo/Downloads/answers")

output_dir = Path(r'C:\Users\admin\Google Drive\Datasets\CERT_output')
answers_dir = Path(r"C:\Users\admin\Google Drive\Datasets\CERT\answers")
main_answers_file = answers_dir / "insiders.csv"

lstm_checkpoint = output_dir / 'checkpoints/lstm/final2-nll/final_model_3040.pth'
assert(lstm_checkpoint.is_file())

run_name = 'cnn/test3'
log_dir = output_dir / 'logs' / run_name
checkpoint_dir = output_dir / 'checkpoints' / run_name

# assert(not log_dir.is_dir())
# assert(not checkpoint_dir.is_dir())

if log_dir.is_dir():
    shutil.rmtree(log_dir)
if checkpoint_dir.is_dir():
    shutil.rmtree(checkpoint_dir)

In [3]:
actions, targets = CertDataset.prepare_dataset(output_dir / 'aggregated.pkl', main_answers_file, min_length=50, max_length=200)

In [4]:
cert_dataset = CertDataset(actions, targets)
train_loader, val_loader = create_data_loaders(cert_dataset, validation_split=0.3, random_seed=0, batch_size=128)

params = get_params()

device = 'cuda'

In [5]:
model = InsiderClassifier(params['model'], lstm_checkpoint)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters())

train_engine = create_supervised_trainer(
                                        model, optimizer, criterion, device=device,
                                        prepare_batch=prepare_batch,
                                        log_dir=(output_dir / 'logs' / 'cnn' / run_name).as_posix(),
                                        checkpoint_dir=output_dir / 'checkpoints' / 'cnn' / run_name
                                       )

val_engine = create_supervised_evaluator(
        model, device=device,
        prepare_batch=prepare_batch,
        metrics={},
        criterion=criterion,
        log_dir=log_dir.as_posix(),
)

@train_engine.on(Events.STARTED)
def log_training_results(trainer):
    print('Initial validation run:')
    val_engine.train_epoch = 0
    val_engine.run(val_loader)

@train_engine.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    print('Validation run:')
    val_engine.train_epoch = train_engine.state.epoch
    val_engine.run(val_loader)


  from tqdm.autonotebook import tqdm


In [6]:
train_engine.run(train_loader, max_epochs=5)

Initial validation run:


HBox(children=(FloatProgress(value=0.0, max=260.0), HTML(value='')))


Validation Results - Avg loss: 0.775352, Accuracy: 0.003666


HBox(children=(FloatProgress(value=0.0, max=607.0), HTML(value='')))


Validation run:


HBox(children=(FloatProgress(value=0.0, max=260.0), HTML(value='')))


Validation Results - Avg loss: 0.101304, Accuracy: 0.996334


HBox(children=(FloatProgress(value=0.0, max=607.0), HTML(value='')))


Validation run:


HBox(children=(FloatProgress(value=0.0, max=260.0), HTML(value='')))


Validation Results - Avg loss: 0.101304, Accuracy: 0.996334


HBox(children=(FloatProgress(value=0.0, max=607.0), HTML(value='')))


Validation run:


HBox(children=(FloatProgress(value=0.0, max=260.0), HTML(value='')))


Validation Results - Avg loss: 0.101304, Accuracy: 0.996334


HBox(children=(FloatProgress(value=0.0, max=607.0), HTML(value='')))


Validation run:


HBox(children=(FloatProgress(value=0.0, max=260.0), HTML(value='')))


Validation Results - Avg loss: 0.101304, Accuracy: 0.996334


HBox(children=(FloatProgress(value=0.0, max=607.0), HTML(value='')))


Validation run:


HBox(children=(FloatProgress(value=0.0, max=260.0), HTML(value='')))


Validation Results - Avg loss: 0.101304, Accuracy: 0.996334


State:
	iteration: 3035
	epoch: 5
	epoch_length: 607
	max_epochs: 5
	output: <class 'dict'>
	batch: <class 'dict'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: 12

# Prediction exploration

In [14]:
it = iter(val_loader)
for batch in it:
    x, y = prepare_batch(batch)
    if y.sum() > 0:
        break

In [32]:
ind = (y==1).nonzero()[0]

In [33]:
model(x.to(device))[ind]

tensor([[0., 1.]], device='cuda:0', grad_fn=<IndexBackward>)