In [231]:
import sys
sys.path.insert(0, "../")

from IPython.display import clear_output

import warnings
warnings.filterwarnings('ignore')

import torch
import transformers
import wandb
from datasets import load_dataset
from tqdm.notebook import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_scheduler
from transformers import DistilBertForSequenceClassification
from transformers import AutoTokenizer

import numpy as np

from util.dialogue_manager import  DialogueManagerModel

In [222]:
DEVICE = 'cuda:4'
BATCH_SIZE = 32
SEED = 42

In [223]:
dataset = load_dataset("bavard/personachat_truecased")
persona_qualities = list(set([sent 
                      for item in dataset['train']['personality'] 
                          for sent in item]))
persona_quality_to_id = {quality: i for i, quality in enumerate(persona_qualities)}
print(f'{len(persona_qualities)} persona qualities')
clear_output()

In [224]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

In [225]:
def tokenize(example):
    outputs = {
        "input_ids": tokenizer("\n-----\n".join(example["history"]), padding='max_length', truncation=True)["input_ids"]
    }
    outputs["labels"] = outputs["input_ids"]
    return outputs

def get_classes(example):
    outputs = {
        "classes": tuple(persona_quality_to_id[item] for item in example["personality"] 
                          if item in persona_quality_to_id.keys())
    }
    return outputs

tokenized_datasets = (
    dataset
        .map(get_classes, num_proc=6)
        .map(tokenize, num_proc=6)
)
clear_output()


In [227]:
dialogue_manager = DialogueManagerModel(n_classes=len(persona_qualities), device=DEVICE, )

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.bias', 'pre_classifi

In [228]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['personality', 'candidates', 'history', 'conv_id', 'utterance_idx', 'classes', 'input_ids', 'labels'],
        num_rows: 131438
    })
    validation: Dataset({
        features: ['personality', 'candidates', 'history', 'conv_id', 'utterance_idx', 'classes', 'input_ids', 'labels'],
        num_rows: 7801
    })
})

In [236]:
train_dataset = tokenized_datasets["train"].shuffle(seed=SEED)
val_dataset = tokenized_datasets["validation"].shuffle(seed=SEED)

Loading cached shuffled indices for dataset at /home/jagiljazev/.cache/huggingface/datasets/bavard___personachat_truecased/full/1.0.0/73ee8f1a0d9e42255af5a8301877a2f3ac638e55b1cd9cbccca5ab7e23d2b638/cache-cd33c68f172c8d3f.arrow


In [245]:
lr = 1e-4
batch_size = 16
n_classes = len(persona_qualities)
optimizer = torch.optim.Adam(dialogue_manager.model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
n_epochs = 8

In [270]:
wandb.init(
    project="dialogue-manager-model-distilbert-base-uncased",
    config={
        "batch_size": batch_size, 
        "lr": lr, 
        "optimizer": 'adam', 
        'model_name': 'distilbert-base-uncased'
    }
)

0,1
Ap@k,▁
Train Loss,▇█▇▇▇▇▇▆▇▇▆▇█▆▆▇▇▆▆▇▆▇▇▇▆▇▆▆▆▆▆█▇▅▆▄▅▃▂▁
Validation Loss,▁

0,1
Ap@k,0.0
Train Loss,8.54361
Validation Loss,2.56849


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666895578382537, max=1.0)…

In [275]:
def p_k(outputs, y_true, k=10):
    true_num = len(y_true)
    best_ind = np.argpartition(outputs, -k)[-k:]
    precision = len(set(best_ind).intersection(set(y_true))) / true_num
    return precision

def ap_k(outputs, y_trues):
    precisions = [p_k(output, y_true) for output, y_true in zip(outputs, y_trues)]
    return np.mean(precisions)

In [276]:
val_freq = 50
val_size = 1000

In [277]:
def validate(model):
    loss = []
    predictions = []
    ys = []
    batch_size = 64
    with torch.no_grad():
        for i in tqdm(range(val_size // batch_size)):
            batch = val_dataset[i * batch_size: (i + 1) * batch_size]
            X = torch.tensor(batch['input_ids']).to(DEVICE)
            y = torch.zeros((X.shape[0], n_classes))
            for i, col in enumerate(batch['classes']):
                t = len(col)
                for val in col:
                    y[i, val] = 1 / t
            y = y.to(DEVICE)
            outputs = dialogue_manager(X)['logits']
            loss.append(criterinon(outputs, y).detach().cpu().numpy())
            predictions += [item.detach().cpu().numpy() for item in outputs]
            ys += [item.detach().cpu().numpy() for item in y]
    return ap_k(predictions, ys), np.mean(loss)

In [None]:
for epoch in tqdm(range(n_epochs)):
    for j in tqdm(range(len(train_dataset) // batch_size)):
        batch = train_dataset[j * batch_size: (j + 1) * batch_size]
        X = torch.tensor(batch['input_ids']).to(DEVICE)
        y = torch.zeros((X.shape[0], n_classes))
        for i, col in enumerate(batch['classes']):
            t = len(col)
            for val in col:
                y[i, val] = 1 / t
        y = y.to(DEVICE)
        dialogue_manager.train()
        outputs = dialogue_manager(X)['logits']
        loss = criterion(outputs, y)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        wandb.log({"Train Loss": loss.detach().cpu().numpy()})
        if j % val_freq == 1:
            ap_k_val, val_loss = validate(dialogue_manager)
            wandb.log({"Ap@k": ap_k_val, 'Validation Loss': val_loss})


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8214 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]