In [3]:
import csv
import random
import logging
import numpy as np
from typing import Optional, Union, List, Dict, Tuple
import collections
from tqdm import tqdm
from argparse import Namespace
from collections import defaultdict
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from datasets import load_dataset, ClassLabel, load_metric
from transformers import (
    AutoTokenizer, 
    AutoModelForTokenClassification, 
    AutoConfig, 
    PreTrainedTokenizerFast, 
    GPT2TokenizerFast, 
    GPT2ForTokenClassification,
    TrainingArguments,
    set_seed,
)
from transformers.trainer import JointTrainer
from run_uncertainty import DataCollatorForJointClassification
from common_functions import (
    entities2dict, 
    merge_ent_dict, 
    common_cal, 
    LockedDropoutMC, 
    WordDropoutMC, 
    DropoutMC,
    activate_mc_dropout,
    convert_dropouts,
    convert_to_mc_dropout,
    freeze_all_dpp_dropouts,
)

# Module-level logger used by the cells below (label-set and inference progress messages).
logger = logging.getLogger(__name__)

# Input JSON files (read with `load_dataset("json", ...)`) and the Trainer output directory.
train_file = "../test_data/new_joint_train_NYT_1over4.json" # 172718
validation_file = "../test_data/new_joint_test_part_NYT.json" # 1680
output_dir = "../tok_cls_result/NYT_gpt2_logic"

# When False, only the first sub-token of each word receives a label during
# tokenization; the remaining sub-tokens are labelled -100.
label_all_tokens = False
# Used to derive the label column name (f"{task_name}_tags") and as `finetuning_task`.
task_name="ner"
# Checkpoint from which the config, tokenizer and model are all loaded.
model_name_or_path = "../tok_cls_result/NYT_uncertainty_prob_variance/checkpoint-7582/"  # "gpt2-medium"
# Forwarded to the model config as `token_classifier_type`.
classifier_type = "crf"
cache_dir = None
model_revision = "main"
use_auth_token = False
# `beta` and `alpha` are forwarded to the model config; the loading cell labels them
# as the attention-loss and logic-loss parameters respectively.
beta = 1.0
alpha = 0.5 
# Forwarded to JointTrainer(boot_start_epoch=...).
boot_start_epoch = 5 
# Cut-off compared against uncertainty scores in the MC-dropout cell
# (that cell assigns `threshold` again before using it).
threshold = 0.5 
use_subtoken_mask = False
pad_to_max_length = False
# Also controls `pad_to_multiple_of` of the data collator (8 when fp16 is enabled).
fp16 = False
# Padding strategy handed to the tokenizer: fixed-length padding or none at tokenization time.
padding = "max_length" if pad_to_max_length else False
preprocessing_num_workers = None
# True => `Dataset.map` is called with load_from_cache_file=False (always re-tokenize).
overwrite_cache = True

# Trainer configuration: evaluate and save once per epoch, and reload the best
# checkpoint at the end according to "eval_f1" (higher is better).
# NOTE(review): compute_metrics returns the key "f1"; presumably the trainer reports
# it as "eval_f1" — confirm against JointTrainer.
training_args = TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=5,
    per_device_eval_batch_size=20,
    num_train_epochs=10,
    weight_decay=0.01,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_f1",
    greater_is_better=True,
    dataloader_num_workers=0,
    fp16=fp16,
    seed=42,
    do_train=True,
    do_eval=True,
)

In [2]:
# Load the train/validation JSON files into a DatasetDict.
datasets = load_dataset(
    "json",
    data_files=dict(train=train_file, validation=validation_file),
)
column_names = datasets["train"].column_names
features = datasets["train"].features

# Prefer the conventional "tokens" column; otherwise fall back to the first column.
if "tokens" in column_names:
    text_column_name = "tokens"
else:
    text_column_name = column_names[0]

# Prefer "<task>_tags" for the labels; otherwise fall back to the second column.
if f"{task_name}_tags" in column_names:
    label_column_name = f"{task_name}_tags"
else:
    label_column_name = column_names[1]

# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the unique labels.
def get_label_list(labels):
    """Return the sorted list of distinct tag strings found in ``labels``.

    Args:
        labels: iterable of per-example tag sequences.

    Returns:
        Sorted list of unique tags. ``'O'`` is always included. When the
        module-level ``label_all_tokens`` flag is set, every ``B-XXX`` tag also
        gets its ``I-XXX`` counterpart (improvement for GPT2+CRF, related to the
        changed ``label_all_tokens`` behaviour). Any tag added this way is logged.
    """
    seen = set()
    for sequence in labels:
        seen.update(sequence)

    missing = set()
    if 'O' not in seen:
        missing.add('O')
    if label_all_tokens:
        # B-XXX -> I-XXX, kept only when the I- tag is not already present.
        inside_tags = ("I" + tag[1:] for tag in seen if tag.startswith("B-"))
        missing.update(tag for tag in inside_tags if tag not in seen)

    if missing:
        seen |= missing
        logger.info(f"Additional labels added: {missing}")

    return sorted(seen)

if not isinstance(features[label_column_name].feature, ClassLabel):
    # Labels are plain strings: derive the label set from the training split.
    label_list = get_label_list(datasets["train"][label_column_name])
    label_to_id = {name: idx for idx, name in enumerate(label_list)}
else:
    # Labels are already integer class ids, so the mapping is the identity.
    label_list = features[label_column_name].feature.names
    label_to_id = {idx: idx for idx, _ in enumerate(label_list)}

num_labels = len(label_list)

Using custom data configuration default-f9a36bd4350fd2bd
Found cached dataset json (/home/dsi/yufli/.cache/huggingface/datasets/json/default-f9a36bd4350fd2bd/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)
100%|██████████| 2/2 [00:00<00:00, 202.84it/s]


In [4]:
# Load pretrained model and tokenizer
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently download model & vocab.
label2id = dict(zip(label_list, range(len(label_list))))
id2label = dict(enumerate(label_list))

# Keyword arguments shared by the config, tokenizer and model loaders.
_hub_kwargs = dict(
    cache_dir=cache_dir,
    revision=model_revision,
    use_auth_token=True if use_auth_token else None,
)

config = AutoConfig.from_pretrained(
    model_name_or_path,
    num_labels=num_labels,
    label2id=label2id, # Workaround for GPT2 w/o predefined labels
    id2label=id2label, # Workaround for GPT2 w/o predefined labels
    token_classifier_o_label_id=label2id['O'], # GPT2TokenClassificaton specific
    token_classifier_type=classifier_type, # GPT2TokenClassificaton specific
    finetuning_task=task_name,
    beta=beta, # attention loss parameter
    alpha=alpha, # logic loss parameter
    use_subtoken_mask=use_subtoken_mask,
    **_hub_kwargs,
)

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=True,
    add_prefix_space=True, # Workaround for GPT2
    **_hub_kwargs,
)
# GPT2 ships without a pad token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForTokenClassification.from_pretrained(
    model_name_or_path,
    from_tf=".ckpt" in model_name_or_path,
    config=config,
    **_hub_kwargs,
)
# torch.save(self.tokenizer, 'examples/tok_cls_result/tokenizer.pt')

# The label alignment below relies on `word_ids()`, which only fast tokenizers provide.
if not isinstance(tokenizer, PreTrainedTokenizerFast):
    raise ValueError(
        "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
        "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this "
        "requirement"
    )

# Tokenize all texts and align the labels with them.
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align word-level labels to sub-tokens.

    Args:
        examples: a batch (dict of columns) as passed by ``Dataset.map(batched=True)``.
            Must contain ``text_column_name`` (lists of words), ``label_column_name``
            (one tag per word) and ``"query_ids"`` (one word index per example).

    Returns:
        The tokenizer output with two extra columns:
        ``"labels"``: per sub-token label ids. The first sub-token of each word gets
            the word's label; remaining sub-tokens get -100, or the I-/O label when
            ``label_all_tokens`` is set. Positions without a word id get -100.
        ``"query_ids"``: a one-element list holding the position of the first
            sub-token of the query word, or ``[0]`` when the query word has no
            sub-token in the (possibly truncated) sequence.
    """
    tokenized_inputs = tokenizer(
        examples[text_column_name],
        padding=padding,
        truncation=True,
        max_length=512, # Workaround for GPU memory consumption
        # We use this argument because the texts in our dataset are lists of words (with a label for each word).
        is_split_into_words=True,
        return_special_tokens_mask=True,
    )
    labels = []
    queryID = []
    for i, label in enumerate(examples[label_column_name]):
        word_ids = tokenized_inputs.word_ids(batch_index=i) # [0, 1, 1, 2, 3, 3, 3, 4, ...]
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                label_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(label_to_id[label[word_idx]])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                # NOTE Changed behavior of label_all_tokens: a continuation sub-token of a
                # B-XXX word is labelled I-XXX instead of repeating B-XXX.
                # NOTE The word_ids trick does not always work, e.g., a file path /usr/bin/bash
                # will be split into subparts with different word_ids.
                if label_all_tokens:
                    if label[word_idx].startswith("B-"):
                        ilb = "I" + label[word_idx][1:]
                        label_ids.append(label_to_id[ilb])
                    else: # label starts with "I-" or "O": directly add it to label_ids
                        label_ids.append(label_to_id[label[word_idx]])
                else:
                    label_ids.append(-100)
            previous_word_idx = word_idx

        labels.append(label_ids)

        # Map the query word index to the position of its first sub-token.
        query_id = examples["query_ids"][i]
        try:
            queryID.append([word_ids.index(query_id)])
        except ValueError:
            # `list.index` found no sub-token for the query word (e.g. it was cut
            # off by truncation); fall back to position 0 as before. Any other
            # exception is a real error and is no longer swallowed.
            queryID.append([0])

    tokenized_inputs["labels"] = labels
    tokenized_inputs["query_ids"] = queryID

    return tokenized_inputs

In [6]:
# Preprocessing the dataset
# Data collator (pads to a multiple of 8 only when fp16 is enabled)
data_collator = DataCollatorForJointClassification(
    tokenizer,
    pad_to_multiple_of=(8 if fp16 else None),
)
# Metrics
metric = load_metric("seqeval")
# Raw (untokenized) splits
train_dataset = datasets["train"]
eval_dataset = datasets["validation"]
# Tokenize both splits with identical settings; train first, then validation.
_map_kwargs = dict(
    batched=True,
    num_proc=preprocessing_num_workers,
    load_from_cache_file=not overwrite_cache,
)
init_dataset = train_dataset.map(tokenize_and_align_labels, **_map_kwargs)
eval_dataset = eval_dataset.map(tokenize_and_align_labels, **_map_kwargs)


def group_sub_entities(entities: List[dict]) -> dict:
    """
    Collapse a run of adjacent token-level entities into one entity group.

    Args:
        entities (:obj:`List[dict]`): non-empty list of entity dicts, each with
            the keys ``"entity"``, ``"word"`` and ``"index"``.

    Returns:
        A dict with ``"entity_group"`` (the type of the first entity with any
        prefix up to the last ``-`` removed), ``"word"`` (the tokens joined by the
        tokenizer) and ``"index"`` (the list of token positions).
    """
    # The group takes its type from the first member.
    group_type = entities[0]["entity"].split("-")[-1]

    words = []
    positions = []
    for member in entities:
        words.append(member["word"])
        positions.append(member["index"])

    return {
        "entity_group": group_type,
        "word": tokenizer.convert_tokens_to_string(words),
        "index": positions,
    }


def group_entities(ignore_subwords, entities: List[dict]) -> List[dict]:
    """
    Find and group together the adjacent tokens with the same entity predicted.

    Args:
        ignore_subwords: when truthy, an entity flagged ``"is_subword"`` is always
            attached to the group currently being built and inherits its type.
        entities (:obj:`List[dict]`): token-level entity dicts in sequence order.

    Returns:
        List of grouped entities as produced by ``group_sub_entities``.
    """
    groups = []
    pending = []  # members of the group currently being assembled
    final_index = entities[-1]["index"] if entities else None

    for ent in entities:
        at_end = ent["index"] == final_index
        subword = ignore_subwords and ent["is_subword"]

        if not pending:
            # A group can never start with a subword; such entities are dropped.
            if not subword:
                pending.append(ent)
            if at_end and pending:
                groups.append(group_sub_entities(pending))
            continue

        previous = pending[-1]
        previous_type = previous["entity"].split("-")[-1]
        tag_parts = ent["entity"].split("-")
        # Same type, not a fresh "B" tag, and directly adjacent to the previous member.
        continues_group = (
            tag_parts[-1] == previous_type
            and tag_parts[0] != "B"
            and ent["index"] == previous["index"] + 1
        )

        if continues_group or subword:
            if subword:
                # A subword takes over the type of the entity it belongs to.
                ent["entity"] = previous_type
            pending.append(ent)
        else:
            # A different entity starts here: close the current group first.
            groups.append(group_sub_entities(pending))
            pending = [ent]

        # Whatever is pending at the last entity is emitted as the final group.
        if at_end:
            groups.append(group_sub_entities(pending))

    return groups


def handling_score(
    labels_idx, input_ids, special_tokens_mask, 
    gen_labels, grouped_entities, ignore_subwords, 
    apply_gpt2_subword_mask, detect_gpt2_leading_space, 
    ignore_labels, is_label=False,
):  
    """Turn one sequence of label ids into entity dicts (optionally grouped).

    Args:
        labels_idx: label id per token position (ground truth or predictions).
        input_ids: token ids of the sequence.
        special_tokens_mask: truthy at positions that must be skipped.
        gen_labels: ground-truth label ids; -100 marks a subword position.
        grouped_entities: when truthy, return the output of ``group_entities``.
        ignore_subwords: forwarded to ``group_entities``; together with
            ``grouped_entities`` it adds an ``"is_subword"`` key to every entity.
        apply_gpt2_subword_mask: flag positions whose ``gen_labels`` entry is -100
            as subwords.
        detect_gpt2_leading_space: additionally flag tokens that do not start with
            'Ġ' as subwords (only correct when the tokenizer uses add_prefix_space=True).
        ignore_labels: label names whose tokens are dropped.
        is_label: True when ``labels_idx`` holds ground-truth labels. The subword
            test then reads ``labels_idx`` itself instead of ``gen_labels``.

    Returns:
        ``group_entities(...)`` output when ``grouped_entities`` is truthy,
        otherwise the flat list of entity dicts.
    """
    id2label = model.config.id2label

    # Step 1: choose the positions to keep. A subword (-100) is kept only when
    # it directly follows the most recently kept position, so a run of subwords
    # stays attached to its word. Other positions are kept unless their label
    # is ignored or they are special tokens.
    kept = []
    last_kept = None
    for pos, label_id in enumerate(labels_idx):
        marker = label_id if is_label else gen_labels[pos]
        if marker == -100: # subword
            if last_kept is not None and pos == last_kept + 1:
                kept.append((pos, label_id))
                last_kept += 1
            continue
        if id2label[label_id] in ignore_labels or special_tokens_mask[pos]:
            continue
        last_kept = pos
        kept.append((pos, label_id))

    # Step 2: build one entity dict per kept position.
    entities = []
    for pos, label_id in kept:
        token_id = int(input_ids[pos])
        word = tokenizer.convert_ids_to_tokens([token_id])[0] # contains "Ġ"

        is_subword = False
        # NOTE Patch for GPT2 subword detection by using word_ids
        if apply_gpt2_subword_mask:
            is_subword = gen_labels[pos] == -100
        # NOTE GPT2 specific subword detection for special words like IP/Email addresses
        if detect_gpt2_leading_space:
            is_subword = is_subword or not word.startswith('Ġ')
        # An unknown token is never treated as a subword.
        if token_id == tokenizer.unk_token_id:
            is_subword = False

        entity = {
            "word": word,
            # Subwords carry a placeholder tag; group_entities rewrites it.
            "entity": "B-X" if is_subword else id2label[label_id],
            "index": pos,
        }
        if grouped_entities and ignore_subwords:
            entity["is_subword"] = is_subword
        entities.append(entity)

    if not grouped_entities:
        return entities
    return group_entities(ignore_subwords, entities)


def preds_to_grouped_entity(
        preds: Union[np.ndarray, Tuple[np.ndarray]] = None,
        is_label: bool = False,
        ignore_labels=["O"],
        grouped_entities: bool = True,
        ignore_subwords: bool = True,  
        detect_gpt2_leading_space: bool = False, 
    ):
    """
    Convert model outputs or gold labels for ``eval_dataset`` into entity dicts.

    Args:
        preds: when ``is_label`` is False, prediction logits of shape N X T X T X V
            (instance, query position, token position, label). When ``is_label``
            is True, label ids of shape N X T. Rows are matched by position with
            the rows of ``eval_dataset``.
        is_label: whether ``preds`` holds ground-truth label ids.
        ignore_labels: label names to drop (the default list is only read, never mutated).
        grouped_entities: group adjacent tokens of the same entity.
        ignore_subwords: attach subword tokens to the preceding word.
        detect_gpt2_leading_space: also use the leading 'Ġ' for subword detection.

    Returns:
        For predictions: one list per instance holding one entity list per query
        position. For labels: one entity list per instance. When there is exactly
        one instance, that single element is returned instead of a one-element list.

    Raises:
        ValueError: if the tokenizer does not satisfy the requirements of the
            selected options.
    """
    if detect_gpt2_leading_space:
        if not isinstance(tokenizer, GPT2TokenizerFast):
            raise ValueError("tokenizer must be a GPT2TokenizerFast")
        if not tokenizer.add_prefix_space:
            raise ValueError("tokenizer.add_prefix_space must be set to True when detect_gpt2_leading_space is True.")

    if ignore_subwords and not tokenizer.is_fast:
        raise ValueError(
            "Slow tokenizers cannot ignore subwords. Please set the `ignore_subwords` option"
            "to `False` or use a fast tokenizer."
        )

    apply_gpt2_subword_mask = bool(
        isinstance(model, GPT2ForTokenClassification) and ignore_subwords
    )

    answers = []
    all_input_ids = eval_dataset["input_ids"] # N X T
    all_special_tokens_mask = eval_dataset["special_tokens_mask"] # N X T
    all_labels = eval_dataset["labels"] # N X T

    if not is_label: # preds is prediction N X T X T X V
        for i, logits in enumerate(preds): # logits: T X T X V (T is after padding)
            input_ids = all_input_ids[i] # T (non-padding)
            special_tokens_mask = all_special_tokens_mask[i] # T (non-padding)
            labels = all_labels[i] # T (non-padding)
            seq_len = len(input_ids)
            logits = logits[:seq_len, :seq_len] # remove the padding part of the logit
            sent_res = []

            for logit in logits: # T X V, one slice per query position
                # Softmax is monotonic, so the arg-max of the logits is the
                # predicted label. Exponentiating first can overflow to
                # inf/NaN for large (or fp16) logits and corrupt the arg-max.
                labels_idx = np.argmax(logit, axis=-1) # T

                seq_entities = handling_score(
                    labels_idx, input_ids, special_tokens_mask, 
                    labels, grouped_entities, ignore_subwords, apply_gpt2_subword_mask, 
                    detect_gpt2_leading_space, ignore_labels, is_label=False,
                ) # List[Dict] 1-D
                sent_res.append(seq_entities) # List[List[Dict]] 2-D

            answers.append(sent_res)

    else: # preds are labels N X T
        for i, label_ids in enumerate(preds): # label_ids: T (T is after padding)
            input_ids = all_input_ids[i] # T (non-padding)
            special_tokens_mask = all_special_tokens_mask[i] # T (non-padding)
            labels_idx = label_ids[:len(input_ids)] # remove the padding part of the labels

            # The labels double as `gen_labels`, i.e. as the subword (-100) markers.
            seq_entities = handling_score(
                labels_idx, input_ids, 
                special_tokens_mask, labels_idx,
                grouped_entities, ignore_subwords, 
                apply_gpt2_subword_mask, detect_gpt2_leading_space, 
                ignore_labels, is_label=True,
            ) # List[Dict] 1-D
            answers.append(seq_entities)

    # Kept for backward compatibility: a single instance is returned unwrapped.
    if len(answers) == 1:
        return answers[0]

    return answers


def extract_triplets(grouped_entities, dataset_name="eval", is_label=True):
    """
    Build one list of triplets per sentence from grouped entities.

    Args:
        grouped_entities: output of ``preds_to_grouped_entity``. With
            ``is_label=True`` it holds one entity list per instance; with
            ``is_label=False`` it holds, per sentence, one entity list per query position.
        dataset_name: ``"eval"`` selects ``eval_dataset``; any other value selects
            ``train_dataset``.
            NOTE(review): ``train_dataset`` is the raw split (the tokenized one is
            ``init_dataset``), so the ``'labels'`` column read in the prediction
            branch may be missing for it — confirm before using a non-"eval" value.
        is_label: selects the label branch or the prediction branch (see below).

    Returns:
        ``label_entities``: a list with one ``sent_ents`` list per sentence, each
        filled by ``merge_ent_dict``.
    """ 
    if dataset_name == "eval":
        dataset = eval_dataset
    else:
        dataset = train_dataset
    sentIDs = dataset['sentID']
    queryIDs = dataset['query_ids']

    if is_label: # extract triplets from grouped_labels (N X T)
        # Instances are consumed in dataset order. A sentence's accumulator
        # (ent_dict / sent_ents) is created when its sentID is first seen and
        # flushed either when the next unseen sentID appears or at the last instance.
        label_entities = []
        ID_set = set()
        for i, entities in enumerate(grouped_entities):
            sentid, queryid = sentIDs[i], queryIDs[i][0] # each instance
            if sentid not in ID_set: # new sentence
                if i != 0: # not the first instance
                    # Flush the previous sentence before starting the new one.
                    merge_ent_dict(ent_dict, sent_ents) # merge all the entities and relations into triplets
                    label_entities.append(sent_ents) # append each sentence triplets to output

                ID_set.add(sentid)
                sent_ents = []
                ent_dict = defaultdict(dict)
            
            entities2dict(entities, queryid, ent_dict) # build entity-relations dict

            if i == len(grouped_entities) - 1: # last instance
                merge_ent_dict(ent_dict, sent_ents) # merge all the entities and relations into triplets
                label_entities.append(sent_ents) # append each sentence triplets to output

    else: # extract triplets from grouped_preds (N X T X T')
        label_entities = []
        all_labels = dataset['labels']
        all_sentIDs = dataset['sentID']
        # Keep the labels of the first instance of every sentence; entry i is
        # paired with grouped_entities[i], so the caller must already have
        # removed repeated sentences from grouped_entities in the same order.
        unique_pair = []
        id_set = set()
        for Id, tag in zip(all_sentIDs, all_labels):
            if Id not in id_set:
                id_set.add(Id)
                unique_pair.append((Id, tag))

        for i, sentence_entities in enumerate(grouped_entities): # every sentence (T X T')
            sent_ents = []
            label = unique_pair[i][1] # corresponding labels
            ent_dict = defaultdict(dict) # record each entities and related entities for each sentence
            for queryid, entities in enumerate(sentence_entities): # every query instance (T')
                if label[queryid] != -100: # we only extract triplets for non-subword positions
                    entities2dict(entities, queryid, ent_dict) # build entity-relations dict  

            merge_ent_dict(ent_dict, sent_ents) # merge all the entities and relations into triplets
            label_entities.append(sent_ents)

    return label_entities



def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _f1_score(precision, recall):
    """Return the harmonic mean of precision and recall, or 0.0 when both are zero."""
    if not (precision or recall):
        return 0.0
    return 2.0 * precision * recall / (precision + recall)


def compute_metrics(p):
    """
    Compute triplet-level precision/recall/F1 and error-rate statistics.

    Args:
        p: pair ``(predictions, labels)``.
            predictions logits (np.ndarray: N X T X T X V): # instances X query dimension X token dimension X label dimension.
            labels (np.ndarray: N X T): ground truth of label_ids (corresponding to query_ids).

    Returns:
        Dict of metrics: precision/recall/f1 without and with tags, the error
        rates derived from the counters returned by ``common_cal``, and the
        average number of predicted / gold triplets per sentence.

    Side effects:
        Writes the predicted and gold triplets to ``pred_triplets(gpt2).csv``
        and ``label_triplets(gpt2).csv`` in the working directory.
    """
    predictions, labels = p # N X T X T X V, N X T
    grouped_preds = preds_to_grouped_entity(preds=predictions)

    # All instances of a sentence share the same prediction grid, so only the
    # first instance of each sentence is kept. A set keeps the membership test
    # below O(1) instead of scanning a list for every row.
    seen_sent_ids = set()
    duplicate_rows = set()
    for i, sent_id in enumerate(eval_dataset["sentID"]):
        if sent_id in seen_sent_ids:
            duplicate_rows.add(i)
        else:
            seen_sent_ids.add(sent_id)

    grouped_preds = [
        preds for i, preds in enumerate(grouped_preds) if i not in duplicate_rows
    ] # N' X T X T'
    grouped_labels = preds_to_grouped_entity(preds=labels, is_label=True) # N X T'

    true_predictions = extract_triplets(grouped_preds, is_label=False) # N X T' (quadratic dict)
    true_labels = extract_triplets(grouped_labels, is_label=True) # N X T' (quadratic dict)

    with open("pred_triplets(gpt2).csv", "w", newline="") as f:
        csv.writer(f).writerows(true_predictions)

    with open("label_triplets(gpt2).csv", "w", newline="") as f:
        csv.writer(f).writerows(true_labels)

    TP_notag, TP_tag, Pos, Neg = 0, 0, 0, 0
    pred_F, ent_mention_F, ent_tag_F = 0, 0, 0
    re_mention_F, re_tag_FN, re_tag_FP, re_tag_F = 0, 0, 0, 0
    # Accumulate the per-sentence counters over the whole evaluation set.
    for hyp, ref in zip(true_predictions, true_labels):
        tp_notag, tp_tag, n_hyp, n_ref, false_tag, ent_mention_f, ent_tag_f, \
            re_mention_f, re_fn, re_fp, re_tag_f = common_cal(hyp, ref)
        TP_notag += tp_notag
        TP_tag += tp_tag
        Pos += n_hyp
        Neg += n_ref
        pred_F += false_tag
        ent_mention_F += ent_mention_f
        ent_tag_F += ent_tag_f
        re_mention_F += re_mention_f
        re_tag_FN += re_fn
        re_tag_FP += re_fp
        re_tag_F += re_tag_f

    pre_notag = _safe_ratio(TP_notag, Pos)
    rec_notag = _safe_ratio(TP_notag, Neg)
    pre_tag = _safe_ratio(TP_tag, Pos)
    rec_tag = _safe_ratio(TP_tag, Neg)

    return {
        "precision": pre_notag,
        "recall": rec_notag,
        "f1": _f1_score(pre_notag, rec_notag), 
        "precision(tag)": pre_tag,
        "recall(tag)": rec_tag,
        "f1(tag)": _f1_score(pre_tag, rec_tag),
        # Error breakdown, each expressed as a share of `pred_F`.
        "ent_mention_fr": _safe_ratio(ent_mention_F, pred_F),
        "ent_tag_fr": _safe_ratio(ent_tag_F, pred_F),
        "re_mention_fr": _safe_ratio(re_mention_F, pred_F),
        "re_fpr": _safe_ratio(re_tag_FP, pred_F),
        "re_fnr": _safe_ratio(re_tag_FN, pred_F),
        "re_tag_fr": _safe_ratio(re_tag_F, pred_F),
        "avg_pred_len": _safe_ratio(sum(len(pred) for pred in true_predictions), len(true_predictions)),
        "avg_true_len": _safe_ratio(sum(len(label) for label in true_labels), len(true_labels)),
    }


# Trainer
trainer = JointTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=init_dataset, # we feed the intial training data to the Trainer
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    bootstrap=None,
    original_dataset=train_dataset,
    use_bootstrap=False,
    boot_start_epoch=boot_start_epoch,
)
# Dataloaders
# Uncertainty estimation runs on the first 1000 tokenized training instances;
# the bound is clamped so that `select` does not raise on a smaller dataset.
selected_dataset = init_dataset.select(range(min(1000, len(init_dataset))))
# train_dataloader = trainer.get_train_dataloader(selected_dataset) # 5
# shuffle=False keeps the batches in dataset order.
train_dataloader = DataLoader(
    selected_dataset,
    batch_size=20,
    shuffle=False,
    collate_fn=data_collator,
    num_workers=1,
    pin_memory=True,
    drop_last=False,
)
eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
# First batch, kept for ad-hoc inspection in later cells.
train_batch0 = next(iter(train_dataloader))

 99%|█████████▉| 173/174 [00:51<00:00,  3.33ba/s]
 50%|█████     | 1/2 [00:00<00:00,  3.11ba/s]


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [10]:
from common_functions import bald, probability_variance, sampled_max_prob
threshold = 0.5
committee_size = 15

dropout = Namespace(
    max_n=100,
    max_frac=0.4,
    mask_name='mc',
    dry_run_dataset='train',
)
args = Namespace(
    dropout_type='MC',
    inference_prob=0.1,
    committee_size=committee_size, # number of forward passes
    dropout_subs='last',
    eval_bs=1000,
    use_cache=True,
    eval_passes=False,
    dropout=dropout,
)

# # Inference (no dropout)
# eval_results = {}
# eval_results["sampled_probabilities"] = []
# logger.info("****************Start runs**************")

# for i in tqdm(range(args.committee_size)):
#     outputs = model(**train_batch0) # loss, logits, position_attentions
#     probs = F.softmax(outputs.logits.float(), dim=-1) # B X T X C
#     preds = torch.argmax(probs, dim=-1) # B X T
#     eval_results["sampled_probabilities"].append(probs.tolist())
# logger.info("Done!!!")

# Stochastic inference: MC Dropout
logger.info("*** Evaluate ***")
logger.info("******Perform stochastic inference...*******")
convert_dropouts(model, args)
activate_mc_dropout(model, activate=True, random=args.inference_prob)
logger.info("****************Start runs**************")
set_seed(42)
random.seed(42)
model.eval()

matched_idx1, matched_idx2, matched_idx3 = [], [], []
mscores1, mscores2, mscores3 = [], [], []

# Number of instances already processed. It turns batch-local positions into
# dataset positions without assuming a particular DataLoader batch size.
n_seen = 0
# Wrapping the dataloader itself (rather than `enumerate(...)`) gives tqdm a total.
for step, inputs in enumerate(tqdm(train_dataloader)):
    # Prepare inputs {"input_ids", "query_ids", "target_att", ...}
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            inputs[k] = v.to(model.device)

    # MC Dropout: `committee_size` stochastic forward passes over the same batch.
    sampled_probabilities = []
    with torch.no_grad():
        for _ in range(args.committee_size):
            outputs = model(**inputs) # loss, logits, position_attentions
            probs = F.softmax(outputs.logits.float(), dim=-1) # B X T X C
            # float64 on the host: same values as the former `.tolist()` round
            # trip, without building nested Python lists.
            sampled_probabilities.append(probs.double().cpu().numpy())

    prob_array = np.stack(sampled_probabilities) # K X B X T X C

    # Uncertainty estimation
    s1 = bald(prob_array) # B
    s2 = sampled_max_prob(prob_array) # B
    s3 = probability_variance(prob_array) # B

    mscores1.extend(s1.tolist())
    mscores2.extend(s2.tolist())
    mscores3.extend(s3.tolist())

    # Dataset positions of the instances whose score is below the threshold
    # (note: strict `<` for s1/s2, `<=` for s3, as before).
    # soft_thre1 = sorted(s1)[int(0.8*len(s1))]
    batch_idx1 = np.where(s1 < threshold)[0] + n_seen
    matched_idx1.extend(batch_idx1.tolist())

    # soft_thre2 = sorted(s2)[int(0.8*len(s2))]
    batch_idx2 = np.where(s2 < threshold)[0] + n_seen
    matched_idx2.extend(batch_idx2.tolist())

    # soft_thre3 = sorted(s3)[int(0.8*len(s3))]
    batch_idx3 = np.where(s3 <= threshold)[0] + n_seen
    matched_idx3.extend(batch_idx3.tolist())

    n_seen += len(s1)

activate_mc_dropout(model, activate=False)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


50it [01:55,  2.31s/it]


In [12]:
# print(mscores1)
# print(mscores2)
# print(mscores3)

In [15]:
# print(len(matched_idx1), matched_idx1)
# print(len(matched_idx2), matched_idx2)
# print(len(matched_idx3), matched_idx3)
# Score sitting at the 70% position of the ascending BALD scores ...
_cut_position = int(0.7*len(mscores1))
model_thre = sorted(mscores1)[_cut_position]
# ... and the dataset positions of every instance scoring at or below it.
matched_idx = np.flatnonzero(np.array(mscores1) <= model_thre).tolist()
print(len(matched_idx), matched_idx)

806 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 50, 51, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 68, 69, 71, 72, 73, 74, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 92, 93, 94, 95, 97, 98, 99, 100, 103, 104, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 127, 128, 129, 130, 132, 133, 135, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 152, 153, 154, 155, 156, 159, 160, 161, 162, 163, 165, 166, 168, 169, 171, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192, 194, 198, 199, 201, 202, 203, 204, 205, 206, 207, 208, 210, 211, 212, 213, 214, 216, 217, 222, 223, 224, 225, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 2

In [None]:
model.eval() # stop dropout
# `train_batch0` comes straight from the DataLoader, so its tensors must be
# moved to the model's device first (the inference loop does the same for its batches).
_batch0_on_device = {
    k: (v.to(model.device) if isinstance(v, torch.Tensor) else v)
    for k, v in train_batch0.items()
}
# Inspection only: no autograd graph is needed.
with torch.no_grad():
    _batch0_outputs = model(**_batch0_on_device)
_batch0_outputs.position_attentions

In [None]:
# Scratch check of the `np.where(mask)[0].tolist()` idiom used in the inference
# loop to turn a boolean mask into a list of positions.
# The draw is unseeded, so the printed values differ between runs.
test = np.random.normal(0, 1, 10)
print(test)
np.where(test > 0.5)[0].tolist()

In [17]:
def data_uncertainty(preds, ue='vanilla', mask=None):
    """
    Sequence-level score computed from per-token class probabilities.

    Input:
        preds: B X T X C tensor of class probabilities.
        ue: 'vanilla' scores each token by its maximum class probability
            (higher = more confident); 'entropy' scores each token by the
            entropy of its class distribution (higher = more uncertain).
        mask: optional B X T tensor, non-zero at the positions to average over
            (e.g. an attention mask, so that padded positions do not dilute the
            sequence score). ``None`` averages over all T positions.
    Output:
        scores: B
    Raises:
        ValueError: if ``ue`` is not one of the supported methods.
    """
    if ue == 'vanilla':
        token_score = preds.max(dim=-1).values # B X T
    elif ue == 'entropy':
        # Clamp before the log so that zero probabilities do not produce NaN.
        token_score = -(preds * torch.log(preds.clamp(1e-8, 1))).sum(dim=-1) # B X T
    else:
        raise ValueError('Unknown uncertainty estimation method.')

    if mask is None:
        return token_score.mean(dim=-1) # B

    weights = mask.to(dtype=token_score.dtype, device=token_score.device)
    # Sequences with an all-zero mask score 0 instead of dividing by zero.
    return (token_score * weights).sum(dim=-1) / weights.sum(dim=-1).clamp(min=1.0) # B

In [19]:
# NOTE: `probs` is not defined in this cell. It is the variable left over from the
# stochastic-inference loop, i.e. the softmax output of the last forward pass on
# the last batch, so that cell must be run first.
data_uncertainty(probs)

tensor([0.9992, 0.9992, 0.9992, 0.9992, 0.9992, 0.9992, 0.9992, 0.9992, 0.9992,
        0.9992, 0.9992, 0.9992, 0.9992, 0.9992, 0.9992, 0.9992, 0.9992, 0.9992,
        0.9992, 0.9992], device='cuda:0')