In [1]:
import random

from torch.utils import data
import numpy as np
import torch
from tqdm.notebook import tqdm

from pan20 import auth, util
from pan20.auth.trans import distbert
from pan20.auth import pytorch
from pan20.util.pytorch import anneal, config, metrics, opt, stopping, training

In [5]:
# Seed all RNGs up front so the run is reproducible. Named constant instead of a
# bare literal; NOTE(review): keep this in sync with `seed=` in the TrainConfig.
SEED = 42
util.set_random_seed(SEED)

In [6]:
# Load the train/dev/test splits from `auth.small()` (presumably a reduced
# subset of the corpus — confirm) and wrap each split in a pytorch Dataset.
X_train, y_train, X_dev, y_dev, X_test, y_test = auth.small()
train, dev, test = (
    pytorch.Dataset(features, labels)
    for features, labels in (
        (X_train, y_train),
        (X_dev, y_dev),
        (X_test, y_test),
    )
)

In [8]:
# Experiment configuration: output locations, loss weights, training
# hyperparameters, LR annealing, optimizer, and early stopping.
cfg = config.ExperimentConfig(
    experiment_name='distilbert1',
    ckpt_dir='ckpts/distilbert1',
    results_dir='results/distilbert1',
    # Loss weights forwarded to the network constructor.
    lambda_fd=0.5,
    lambda_grad=1.0,
    train=training.TrainConfig(
        n_epochs=20,
        seed=42,
        train_batch_size=32,
        run_no=0,
        tune_batch_size=64,
        p_drop=0.1,
        dev_metric='acc',
        memory_limit=8,  # NOTE(review): unit not visible here (GB?) — confirm
        weight_decay=0.0,
    ),
    # NOTE(review): annealing patience equals early-stopping patience (3);
    # if both watch the same metric, the LR may be halved right as training
    # stops, so the reduction never takes effect. Confirm this is intended.
    anneal=anneal.ReduceLROnPlateauConfig(
        factor=0.5,
        patience=3,
    ),
    optim=opt.AdamWConfig(
        lr=6e-5,
    ),
    stop=stopping.NoDevImprovementConfig(
        patience=3,
        k=3,
        metric='acc',
    ),
)

In [9]:
# Batch every split with the same collate function. Training data is shuffled;
# dev and test are not, so evaluation order is stable across runs.
collate = distbert.CollateFirstK()
train_loader = data.DataLoader(
    dataset=train,
    batch_size=cfg.train.train_batch_size,
    collate_fn=collate,
    shuffle=True)
dev_loader = data.DataLoader(
    dataset=dev,
    batch_size=cfg.train.tune_batch_size,
    collate_fn=collate,
    shuffle=False)
# Bug fix: the evaluation cell calls `model.evaluate(test_loader)`, but no test
# loader was ever built, so that cell raised NameError on a fresh kernel.
# Mirrors dev_loader: evaluation batch size, no shuffling.
test_loader = data.DataLoader(
    dataset=test,
    batch_size=cfg.train.tune_batch_size,
    collate_fn=collate,
    shuffle=False)

In [10]:
# Build the DistilBERT comparison network from the experiment config, then wrap
# it together with the config in a TrainableModel.
net = distbert.DistilBERTComparisonAdvFd1(
    lambda_fd=cfg.lambda_fd,
    lambda_grad=cfg.lambda_grad,
    p_drop=cfg.train.p_drop,
    weight_decay=cfg.train.weight_decay,
)
model = training.TrainableModel(net, cfg)

In [12]:
# Train on train_loader; dev_loader is presumably used for the per-epoch 'tune'
# pass shown in the progress output (confirm). NOTE(review): the recorded run
# was stopped manually (KeyboardInterrupt), so `model` may be only partly trained.
model.train(train_loader, dev_loader)

HBox(children=(FloatProgress(value=0.0, description='epoch', max=20.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='iter', max=5326.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='tune', max=313.0, style=ProgressStyle(description_width='…

KeyboardInterrupt: 

In [None]:
# Evaluate on the held-out test split. NOTE(review): requires a `test_loader`
# built over `test`; results reflect whatever state training reached before
# it was interrupted.
model.evaluate(test_loader)

In [None]:
# Inspect the document encoder's layer-combination weights.
# NOTE(review): presumably `model.model` is the wrapped `net`; confirm against TrainableModel.
model.model.doc_enc.combine_layers.layer_weights