In [1]:


import numpy as np
from datasets import DatasetDict
from datasets import load_metric
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding

dataset = DatasetDict.load_from_disk('/home/pavel/work/active_learning_project/exploded_dataset')

f1 = load_metric('f1')


def compute_metrics(eval_preds):
    metric = load_metric('f1')
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels, average='weighted')



In [2]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [3]:
def preprocess_collator(batch):
    print(batch)
    features = [dict(zip(batch, t)) for t in zip(*batch.values())]
    # pad inputs and labels
    # print(features)
    # data={'input_ids':batch['input_ids'], 'token_type_ids':batch['token_type_ids'], 'attention_mask':batch['attention_mask']}
    batch = tokenizer.pad()
    return batch

In [4]:
# {TypeError}TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]
def preprocess_function(examples):
    # print(examples['dialog'])
    # print(examples)
    result = tokenizer(examples['dialog'], truncation=True, padding=True, max_length=512)
    # print(result)
    result['labels'] = examples['act']
    return result

In [5]:
tokenized_dataset = dataset.map(preprocess_function, batched=True, batch_size=8)

Loading cached processed dataset at /home/pavel/work/active_learning_project/exploded_dataset/train/cache-9f1ef0579996321d.arrow


  0%|          | 0/968 [00:00<?, ?ba/s]

Loading cached processed dataset at /home/pavel/work/active_learning_project/exploded_dataset/validation/cache-1cf20b8a7d425f03.arrow


In [7]:
from torch.utils.data import DataLoader

In [8]:
from transformers import AdamW

In [9]:
from transformers import get_scheduler

In [10]:
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [11]:
from tqdm.auto import tqdm

In [12]:
from collections import defaultdict
import heapq


class DialogStats:
    def __init__(self):
        self.correct_ans = 0
        self.total_ans = 0

    def add_ans(self, correct):
        self.total_ans += 1
        if correct:
            self.correct_ans += 1

    @property
    def ratio(self):
        if self.total_ans == 0:
            return 0
        else:
            return self.correct_ans / self.total_ans

    def __repr__(self):
        return f'{self.correct_ans}/{self.total_ans}'


class DialogPrediction:
    def __init__(self):
        self.answers = None
        self.reset()

    def reset(self):
        self.answers = defaultdict(lambda: DialogStats())

    def add_answer(self, dialog_id, correct):
        self.answers[dialog_id].add_ans(correct)

    def get_bottom_k_percents(self, k):
        answer = []
        result_count = len(self.answers) * k // 100
        result_count = max(result_count, 1)
        for k, v in self.answers.items():
            if len(answer) < result_count:
                heapq.heappush(answer, (-v.ratio, k))
            else:
                prev_ratio, dialog_id = heapq.heappop(answer)
                if prev_ratio > -v.ratio:
                    heapq.heappush(answer, (prev_ratio, dialog_id))
                else:
                    heapq.heappush(answer, (-v.ratio, k))
        return [dialog_id for _, dialog_id in answer]

    def __repr__(self):
        return str(self.answers)

In [13]:
from torch.utils.data import Sampler
from typing import Iterator


class WorstDialogSampler(Sampler):

    def __init__(self, data_source,
                 dialog_predictions: DialogPrediction,
                 bottom_k_percents: int):
        super().__init__(data_source)
        self.data_source = data_source
        self.dialog_prediction = dialog_predictions
        self.full_length = len(data_source)
        self.bottom_k_percents = bottom_k_percents
        self.is_init = False
        self.worst_dialog_ids = None
        self.worst_dataset_indices = None

    def set_init(self, is_init=True):
        self.is_init = is_init

    def choose_worst(self, bottom_k_percents=None):
        if bottom_k_percents is None:
            bottom_k_percents = self.bottom_k_percents
        self.set_init(True)
        self.worst_dialog_ids = set(self.dialog_prediction.get_bottom_k_percents(bottom_k_percents))
        self.worst_dataset_indices = []

        for i in range(len(self.data_source)):
            d = self.data_source[i]
            if d['dialog_num'] in self.worst_dialog_ids:
                self.worst_dataset_indices.append(i)

    def __iter__(self) -> Iterator[int]:
        if not self.is_init:
            return iter(range(len(self.data_source)))
        else:
            return iter(self.worst_dataset_indices)

    def __len__(self):
        if not self.is_init:
            return self.full_length
        else:
            return len(self.worst_dataset_indices)

In [15]:
num_epochs = 10
batch_size = 32
bottom_percents = 10

model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=5)
model.to(device)

dp = DialogPrediction()

train_worst_sampler = WorstDialogSampler(dataset['train'], dp, bottom_percents)
train_dataloader = DataLoader(dataset['train'], batch_size=batch_size, sampler=train_worst_sampler)
eval_dataloader = DataLoader(dataset['validation'], batch_size=batch_size)

optimizer = AdamW(model.parameters(), lr=5e-5)

num_training_steps = (len(dataset['train']) + len(dataset)*bottom_percents //10 * (num_epochs-1)) // batch_size
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_epochs):
    model.train()
    batches = 0
    for batch in train_dataloader:
        batch_dict = {k: v for k, v in batch.items()}
        data = tokenizer(batch_dict['dialog'], truncation=True, padding=True, max_length=512, return_tensors='pt')
        data['labels'] = batch_dict['act']
        batch = {k: v.to(device) for k, v in data.items()}

        outputs = model(**batch)
        if not train_worst_sampler.is_init:
            predictions = torch.argmax(outputs.logits, dim=-1)
            for i in range(len(data['labels'])):
                dp.add_answer(int(batch_dict['dialog_num'][i]), predictions[i] == data['labels'][i])
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        batches += 1

    if not train_worst_sampler.is_init:
        train_worst_sampler.choose_worst()

    model.eval()
    for batch in eval_dataloader:
        batch_dict = {k: v for k, v in batch.items()}
        data = tokenizer(batch_dict['dialog'], truncation=True, padding=True, max_length=512, return_tensors='pt')
        data['labels'] = batch_dict['act']
        batch = {k: v.to(device) for k, v in data.items()}

        outputs = model(**batch)
        predictions = torch.argmax(outputs.logits, dim=-1)
        f1.add_batch(predictions=predictions, references=data['labels'])
    print(f1.compute(average='weighted'))
    print(batches)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

  0%|          | 0/2724 [00:00<?, ?it/s]

{'f1': 0.8012698891828767}
2725
{'f1': 0.8012698891828767}
206
{'f1': 0.8012698891828767}
206
{'f1': 0.8012698891828767}
206
{'f1': 0.8012698891828767}
206
{'f1': 0.8012698891828767}
206
{'f1': 0.8012698891828767}
206
{'f1': 0.8012698891828767}
206
{'f1': 0.8012698891828767}
206
{'f1': 0.8012698891828767}
206
