<a href="https://colab.research.google.com/github/ronenbendavid/IDC_NLP/blob/master/BERT_EntityLinkingKB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Assignment 3
Training a neural named entity recognition (NER) tagger 

In [228]:
import torch
import torch.nn as nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('version: {}, device: {}'.format(torch.__version__, device))

version: 1.6.0+cu101, device: cuda


In this assignment you are required to build a full training and testing pipeline for a neural sequential tagger for named entities, using an LSTM.

The dataset that you will be working on is called ReCoNLL 2003, which is a corrected version of the CoNLL 2003 dataset: https://www.clips.uantwerpen.be/conll2003/ner/

[Train data](https://drive.google.com/file/d/1hG66e_OoezzeVKho1w7ysyAx4yp0ShDz/view?usp=sharing)

[Dev data](https://drive.google.com/file/d/1EAF-VygYowU1XknZhvzMi2CID65I127L/view?usp=sharing)

[Test data](https://drive.google.com/file/d/16gug5wWnf06JdcBXQbcICOZGZypgr4Iu/view?usp=sharing)

As you can see, the annotated texts are labeled according to the IOB annotation scheme, for 3 entity types: Person, Organization, Location.
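
To make the IOB scheme concrete, here is a minimal sketch that groups a tagged sentence into entity spans. The helper name `iob_to_spans` and the example sentence are illustrative assumptions, not part of the assignment data. It also assumes that a stray `I-` tag with no open span starts a new entity.

```python
def iob_to_spans(words, tags):
    """Group IOB tags into (entity_type, start, end_exclusive, text) spans."""
    spans = []
    start, ent_type = None, None
    for i, tag in enumerate(tags):
        prefix, _, label = tag.partition("-")
        # An open span ends at "O", at a new "B-", or when the entity type changes.
        if start is not None and (prefix != "I" or label != ent_type):
            spans.append((ent_type, start, i, " ".join(words[start:i])))
            start, ent_type = None, None
        # "B-" opens a span; an "I-" with no open span is treated as a start too.
        if prefix in ("B", "I") and start is None:
            start, ent_type = i, label
    # Close a span that runs to the end of the sentence.
    if start is not None:
        spans.append((ent_type, start, len(tags), " ".join(words[start:])))
    return spans


words = ["john", "smith", "works", "at", "acme", "corp", "in", "paris"]
tags = ["B-PER", "I-PER", "O", "O", "B-ORG", "I-ORG", "O", "B-LOC"]
print(iob_to_spans(words, tags))
```

Each span records the entity type, the token range, and the surface text, which is the unit that entity-level evaluation usually compares.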

**Task 1:** Write a function for reading the data from a single file (one of those provided above). The function receives a filepath and encodes every sentence individually as a pair of lists: one list contains the words and the other contains the tags. Each list pair is added to a general list (data), which the function returns.

In [229]:
import requests
import re

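def read_data(filepath):
    """Read a CoNLL-style file and return a list of (words, tags) pairs, one per sentence."""
    data = []

    # Turn a Google Drive "view" link into a direct-download link.
    result = re.match(r".*drive\.google\.com/file/d/([^/]*)/.*", filepath)
    if result:
        filepath = 'https://docs.google.com/uc?export=download&id={}'.format(result.group(1))
    print(filepath)

    response = requests.get(filepath)
    response.raise_for_status()
    words = []
    tags = []

    for line in response.text.split('\n'):
        fields = line.strip().split()
        if not fields:
            # A blank (or whitespace-only) line marks the end of a sentence.
            if len(words) > 0:
                data.append((words, tags))
            words = []
            tags = []
        else:
            words.append(fields[0].lower())
            tags.append(fields[1])

    # Keep the last sentence when the file does not end with a blank line.
    if len(words) > 0:
        data.append((words, tags))

    return data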

train = read_data('https://drive.google.com/file/d/1hG66e_OoezzeVKho1w7ysyAx4yp0ShDz/view?usp=sharing')
dev = read_data('https://drive.google.com/file/d/1EAF-VygYowU1XknZhvzMi2CID65I127L/view?usp=sharing')
test = read_data('https://drive.google.com/file/d/16gug5wWnf06JdcBXQbcICOZGZypgr4Iu/view?usp=sharing')


https://docs.google.com/uc?export=download&id=1hG66e_OoezzeVKho1w7ysyAx4yp0ShDz
https://docs.google.com/uc?export=download&id=1EAF-VygYowU1XknZhvzMi2CID65I127L
https://docs.google.com/uc?export=download&id=16gug5wWnf06JdcBXQbcICOZGZypgr4Iu


The following Vocab class can serve as a dictionary that maps words and tags to ids. The UNK_TOKEN should be used for words that are not part of the training data.

In [230]:
UNK_TOKEN = 0

# class Vocab:
#     def __init__(self):
#         self.word2id = {"__unk__": UNK_TOKEN}
#         self.id2word = {UNK_TOKEN: "__unk__"}
#         self.n_words = 1
        
#         self.tag2id = {"O":0, "B-PER":1, "I-PER": 2, "B-LOC": 3, "I-LOC": 4, "B-ORG": 5, "I-ORG": 6}
#         self.id2tag = {0:"O", 1:"B-PER", 2:"I-PER", 3:"B-LOC", 4:"I-LOC", 5:"B-ORG", 6:"I-ORG"}
        
#     def index_words(self, words):
#       word_indexes = [self.index_word(w) for w in words]
#       return word_indexes

#     def index_tags(self, tags):
#       tag_indexes = [self.tag2id[t] for t in tags]
#       return tag_indexes
    
#     def index_word(self, w):
#         if w not in self.word2id:
#             self.word2id[w] = self.n_words
#             self.id2word[self.n_words] = w
#             self.n_words += 1
#         return self.word2id[w]
            
            

In [231]:
from google.colab import drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


**Task 2:** Write a function prepare_data that takes one of [train, dev, test] and the Vocab instance, and converts each (words, tags) pair to a pair of index lists. Each pair should be added to data_sequences, which the function returns.
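
One way to approach Task 2 is sketched below. It assumes a simple dictionary-based vocabulary in the spirit of the commented-out class above, not the BERT-based `Vocab` defined in the next cell. `SimpleVocab`, `add_sentence`, and `word_id` are hypothetical names, and the toy sentences are made up. Words are registered from the training split only, so unseen dev or test words fall back to `UNK_TOKEN`.

```python
UNK_TOKEN = 0


class SimpleVocab:
    """Hypothetical stand-in vocabulary: word/tag to id maps with an UNK fallback."""

    def __init__(self):
        self.word2id = {"__unk__": UNK_TOKEN}
        self.tag2id = {"O": 0, "B-PER": 1, "I-PER": 2, "B-LOC": 3, "I-LOC": 4, "B-ORG": 5, "I-ORG": 6}

    def add_sentence(self, words):
        # Register new words; call this for training sentences only.
        for w in words:
            self.word2id.setdefault(w, len(self.word2id))

    def word_id(self, w):
        # Unknown words map to UNK_TOKEN instead of growing the vocabulary.
        return self.word2id.get(w, UNK_TOKEN)


def prepare_data(data, vocab):
    """Convert each (words, tags) pair into a (word_ids, tag_ids) pair."""
    data_sequences = []
    for words, tags in data:
        word_ids = [vocab.word_id(w) for w in words]
        tag_ids = [vocab.tag2id[t] for t in tags]
        data_sequences.append((word_ids, tag_ids))
    return data_sequences


toy_train = [(["john", "lives", "in", "paris"], ["B-PER", "O", "O", "B-LOC"])]
toy_dev = [(["mary", "lives", "in", "rome"], ["B-PER", "O", "O", "B-LOC"])]

vocab = SimpleVocab()
for words, _ in toy_train:
    vocab.add_sentence(words)

train_sequences = prepare_data(toy_train, vocab)
dev_sequences = prepare_data(toy_dev, vocab)
print(train_sequences)
print(dev_sequences)
```

The id lists can be wrapped with `torch.tensor(..., dtype=torch.long)` before they are fed to a model.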

In [233]:
!pip install transformers

import pickle

from transformers import BertTokenizer
class Vocab:
    def __init__(self, args=None):
        self.tag2idx = None
        self.idx2tag = None
        self.OUTSIDE_ID = None
        self.PAD_ID = None
        self.SPECIAL_TOKENS = None
        self.tokenizer = None
        if args is not None:
            self.load(args)

    def load(self, args, popular_entity_to_id_dict=None):

        if popular_entity_to_id_dict is None:
            with open(f"/content/drive/My Drive/data/versions/dummy/indexes/popular_entity_to_id_dict.pickle", "rb") as f:
              # with open(f"/content/drive/My Drive/data/versions/dummy/indexes/mention_entity_counter.pickle", "rb") as f:
              
                popular_entity_to_id_dict = pickle.load(f)

        # if args.uncased:
        if True:
            tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
        else:
            tokenizer = BertTokenizer.from_pretrained("bert-base-cased", do_lower_case=False)

        self.tag2idx = popular_entity_to_id_dict

        self.OUTSIDE_ID = len(self.tag2idx)
        self.tag2idx["|||O|||"] = self.OUTSIDE_ID

        self.PAD_ID = len(self.tag2idx)
        self.tag2idx["|||PAD|||"] = self.PAD_ID

        self.SPECIAL_TOKENS = [self.OUTSIDE_ID, self.PAD_ID]

        self.idx2tag = {v: k for k, v in self.tag2idx.items()}

        self.tokenizer = tokenizer

        # args.vocab_size = self.size()

    def size(self):
        return len(self.tag2idx)



In [234]:
import argparse
import subprocess
from collections import Counter

import torch.optim
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau


def capitalize(text: str) -> str:
    return text[0].upper() + text[1:]


def snip(string, search, keep, keep_search):
    pos = string.find(search)
    if pos != -1:
        if keep == "left":
            if keep_search:
                pos += len(search)
            string = string[:pos]
        if keep == "right":
            if not keep_search:
                pos += len(search)
            string = string[pos:]
    return string


def snip_anchor(text: str) -> str:
    return snip(text, "#", keep="left", keep_search=False)


def normalize_wiki_entity(i, replace_ws=False):
    i = snip_anchor(i)
    if len(i) == 0:
        return None
    i = capitalize(i)
    if replace_ws:
        return i.replace(" ", "_")
    return i


# most frequent English words from English Wikipedia
stopwords = {
    "a",
    "also",
    "an",
    "are",
    "as",
    "at",
    "be",
    "by",
    "city",
    "company",
    "film",
    "first",
    "for",
    "from",
    "had",
    "has",
    "her",
    "his",
    "in",
    "is",
    "its",
    "john",
    "national",
    "new",
    "of",
    "on",
    "one",
    "people",
    "school",
    "state",
    "the",
    "their",
    "these",
    "this",
    "time",
    "to",
    "two",
    "university",
    "was",
    "were",
    "with",
    "world",
}


def get_stopwordless_token_set(s):
    result = set(s.lower().split(" "))
    result_minus_stopwords = result.difference(stopwords)
    if len(result_minus_stopwords) == 0:
        return result
    else:
        return result_minus_stopwords


def argparse_bool_type(v):
    "Type for argparse that correctly treats Boolean values"
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def get_gpu_memory_map():
    """Get the current gpu usage.
    Returns
    -------
    usage: dict
        Keys are device ids as integers.
        Values are memory usage as integers in MB.
    """
    result = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"], encoding="utf-8"
    )
    # Convert lines into a dictionary
    gpu_memory = [int(x) for x in result.strip().split("\n")]
    gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
    return gpu_memory_map


def create_chunks(a_list, n):
    for i in range(0, len(a_list), n):
        yield a_list[i : i + n]


def unescape(s):
    if s.startswith('"'):
        s = s[1:-1]
    return s.replace('""""', '"').replace('""', '"')


def create_overlapping_chunks(a_list, n, overlap):
    for i in range(0, len(a_list), n - overlap):
        yield a_list[i : i + n]


def running_mean(new, old=None, momentum=0.9):
    if old is None:
        return new
    else:
        return momentum * old + (1 - momentum) * new


def get_topk_ids_aggregated_from_seq_prediction(logits, topk_per_token, topk_from_batch):
    topk_logit_per_token, topk_eids_per_token = logits.topk(topk_per_token, sorted=False, dim=-1)

    i = torch.cat(
        [
            topk_eids_per_token.view(1, -1),
            torch.zeros(topk_eids_per_token.view(-1).size(), dtype=torch.long, device=topk_eids_per_token.device).view(
                1, -1
            ),
        ],
        dim=0,
    )
    v = topk_logit_per_token.view(-1)
    st = torch.sparse.FloatTensor(i, v)
    stc = st.coalesce()
    topk_indices = stc._values().sort(descending=True)[1][:topk_from_batch]
    result = stc._indices()[0, topk_indices]

    return result.cpu().tolist()


def get_entity_annotations(t, outside_id):
    annos = list()
    begin = -1
    in_entity = -1
    for i, j in enumerate(t):
        if j < outside_id and begin == -1:
            begin = i
            in_entity = j.item()
        elif j < outside_id and j != in_entity:
            annos.append((tuple(range(begin, i)), in_entity))
            begin = i
            in_entity = j.item()
        elif j == outside_id and begin != -1:
            annos.append((tuple(range(begin, i)), in_entity))
            begin = -1
    return annos


def get_entity_annotations_with_gold_spans(t, t_gold, outside_id):
    annos = list()
    begin = -1
    in_gold_entity = -1
    collected_entities_in_span = Counter()
    for i, (j, j_gold) in enumerate(zip(t, t_gold)):
        if j_gold < outside_id and begin == -1:
            begin = i
            in_gold_entity = j_gold.item()
            collected_entities_in_span[j.item()] += 1
        elif j_gold != in_gold_entity and begin != -1:
            in_entity = collected_entities_in_span.most_common()[0][0]
            annos.append((tuple(range(begin, i)), in_entity))
            collected_entities_in_span = Counter()
            begin = i
            in_gold_entity = j_gold.item()
            collected_entities_in_span[j.item()] += 1
        elif j_gold == outside_id and begin != -1:
            in_entity = collected_entities_in_span.most_common()[0][0]
            annos.append((tuple(range(begin, i)), in_entity))
            collected_entities_in_span = Counter()
            begin = -1
    return annos


class DummyOptimizer(torch.optim.Optimizer):
    def step(self, closure=None):
        pass


class LRMilestones(_LRScheduler):
    """Set the learning rate of every parameter group to the value attached
    to the latest milestone whose epoch has been reached. Before the first
    milestone the optimizer's current lr is kept.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of (epoch, lr) pairs. Must be increasing.
        last_epoch (int): The index of last epoch. Default: -1.
    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 80
        >>> # lr = 0.0005   if epoch >= 80
        >>> scheduler = LRMilestones(optimizer, milestones=[(30, 0.005), (80, 0.0005)])
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, milestones, last_epoch=-1):
        if not list(milestones) == sorted(milestones):
            raise ValueError("Milestones should be a list of increasing (epoch, lr) pairs. Got {}".format(milestones))
        # The base constructor performs an initial step that calls get_lr(),
        # so the milestones must be set before it runs (and it must run only once).
        self.milestones = milestones
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Default: keep the current lr of every parameter group.
        lrs = [group["lr"] for group in self.optimizer.param_groups]
        for ep, lr in self.milestones:
            if self.last_epoch >= ep:
                # The base class expects one lr per parameter group.
                lrs = [lr] * len(lrs)
        return lrs


def pad_to(arr, max_len, pad_id, cls_id, sep_id):
    return [cls_id] + arr + [sep_id] + [pad_id] * (max_len - len(arr) - 2)


def set_out_id(t, repl, dummy=-1):
    t[(t == dummy)] = repl
    return t


class LRSchedulers:
    ReduceLROnPlateau = ReduceLROnPlateau
    LRMilestones = LRMilestones
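
The `LRMilestones` scheduler above takes `(epoch, lr)` pairs. As a library-free sketch of one plausible reading of that schedule, in which the most recently reached milestone determines the learning rate, consider the hypothetical helper `lr_at_epoch` below. The initial rate and the milestone values are made-up illustrations.

```python
def lr_at_epoch(milestones, epoch, initial_lr):
    """Return the lr of the latest milestone reached, or initial_lr before the first one.

    milestones: list of (epoch, lr) pairs sorted by increasing epoch.
    """
    lr = initial_lr
    for milestone_epoch, milestone_lr in milestones:
        if epoch >= milestone_epoch:
            lr = milestone_lr
        else:
            break  # milestones are sorted, so later ones are not reached either
    return lr


milestones = [(30, 0.005), (80, 0.0005)]
for epoch in (0, 29, 30, 79, 80, 99):
    print(epoch, lr_at_epoch(milestones, epoch, initial_lr=0.05))
```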

In [235]:
import os
from collections import OrderedDict

import pandas
from itertools import cycle
from operator import gt, lt


class Metrics:

    meta = OrderedDict(
        [
            ("epoch", {"comp": gt, "type": int, "str": lambda a: a}),
            ("step", {"comp": gt, "type": int, "str": lambda a: a}),
            ("num_gold", {"comp": gt, "type": int, "str": lambda a: a}),
            ("num_correct", {"comp": gt, "type": int, "str": lambda a: a}),
            ("num_proposed", {"comp": gt, "type": int, "str": lambda a: a}),
            ("f1", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}),
            ("f05", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}),
            ("precision", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}),
            ("recall", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}),
            ("span_f1", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}),
            ("span_precision", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}),
            ("span_recall", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}),
            ("lenient_span_f1", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}),
            ("lenient_span_precision", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}),
            ("lenient_span_recall", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}),
            ("precision_gold_mentions", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}),
            ("avg_loss", {"comp": lt, "type": float, "str": lambda a: f"{a:.5f}"}),
        ]
    )

    def __init__(self, epoch=0, step=0, num_correct=0, num_gold=0, num_proposed=0, model_selection="f1", num_best_checkpoints=4, **kwargs):

        self.epoch = epoch
        self.step = step
        self.num_correct = num_correct
        self.num_gold = num_gold
        self.num_proposed = num_proposed

        self.precision = Metrics.compute_precision(num_correct, num_proposed)
        self.recall = Metrics.compute_recall(num_correct, num_gold)
        self.f1 = Metrics.compute_fmeasure(self.precision, self.recall)
        self.f05 = Metrics.compute_fmeasure(self.precision, self.recall, weight=1.5)

        self.avg_loss = float("inf")

        for k,v in kwargs.items():
            if k in self.meta:
                self.__dict__[k] = v

        self.model_selection = model_selection
        self.checkpoint_cycle = cycle(range(num_best_checkpoints),)

    @staticmethod
    def compute_precision(correct, proposed):
        try:
            precision = correct/proposed
        except ZeroDivisionError:
            precision = 0.0
        return precision

    @staticmethod
    def compute_recall(correct, gold):
        try:
            recall = correct/ gold
        except ZeroDivisionError:
            recall = 0.0
        return recall

    @staticmethod
    def compute_fmeasure(precision, recall, weight=2.0):
        try:
            f = weight * precision * recall / (precision + recall)
        except ZeroDivisionError:
            f = 0.0
        return f

    def was_improved(self, other: "Metrics"):
        return Metrics.meta[self.model_selection]["comp"](
            other.get_model_selection_metric(), self.get_model_selection_metric()
        )

    def update(self, other: "Metrics"):
        if self.was_improved(other):
            for key, val in other.__dict__.items():
                self.__setattr__(key, other.__dict__.get(key))

    def get_model_selection_metric(self):
        return self.__dict__.get(self.model_selection)

    def get_best_checkpoint_filename(self):
        return f"best_{self.model_selection}-{next(self.checkpoint_cycle)}"

    def to_csv(self, epoch, step, args):
        header = (
            [k for k in list(self.meta.keys()) if k in self.__dict__]
            if not os.path.exists("{}/log.csv".format(args.logdir))
            else False
        )
        pandas.DataFrame(
            [[self.__dict__[k] for k in list(self.meta.keys()) if k in self.__dict__]]
        ).to_csv(f"{args.logdir}/log.csv", mode="a", header=header)

    def dict(self):
        return self.__dict__

    def __repr__(self):
        return str(self.__dict__)

    def report(self, filter=None):
        if not filter:
            filter = set(self.meta.keys())
        return [f"{k}: {Metrics.meta[k]['str'](self.__dict__[k])}" for k in list(self.meta.keys()) if k in filter and k in self.__dict__]
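
The `Metrics` class above derives precision, recall, and F-scores from three counts (`num_correct`, `num_proposed`, `num_gold`). For reference, a minimal standalone sketch of the standard definitions follows, using the general F-beta formula. The function names `precision_recall` and `f_beta` and the counts are hypothetical; this is not the notebook's `compute_fmeasure`, whose `weight` argument is a plain multiplier.

```python
def precision_recall(num_correct, num_proposed, num_gold):
    """Precision = correct / proposed, recall = correct / gold (0.0 when undefined)."""
    precision = num_correct / num_proposed if num_proposed else 0.0
    recall = num_correct / num_gold if num_gold else 0.0
    return precision, recall


def f_beta(precision, recall, beta=1.0):
    """General F-beta: (1 + beta^2) * P * R / (beta^2 * P + R)."""
    denom = beta ** 2 * precision + recall
    if denom == 0:
        return 0.0
    return (1 + beta ** 2) * precision * recall / denom


p, r = precision_recall(num_correct=6, num_proposed=8, num_gold=12)
print(p, r)                    # precision and recall
print(f_beta(p, r))            # F1
print(f_beta(p, r, beta=0.5))  # F0.5 weights precision more heavily than recall
```

With `beta=1` the formula reduces to `2 * P * R / (P + R)`, which matches the default `weight=2.0` case in the class.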

In [236]:
import logging
import os
from itertools import chain

import numpy
import torch
import torch.nn as nn
import tqdm
from torch import optim
from tqdm import trange
# from metrics import Metrics
# !pip install metrics  metrics.bumpversion metrics.gitinfo metrics.pylint metrics.pytest-cov
!pip install pytorch_pretrained_bert
# from data_loader_wiki import EDLDataset_collate_func
# from skmisc import running_mean, get_topk_ids_aggregated_from_seq_prediction, DummyOptimizer, LRSchedulers
from pytorch_pretrained_bert import BertModel

# Metric = METRICS_FILENAME



In [None]:
from transformers import AlbertConfig, AlbertModel
# Initializing an ALBERT-xxlarge style configuration
albert_xxlarge_configuration = AlbertConfig()

# Initializing an ALBERT-base style configuration
albert_base_configuration = AlbertConfig(hidden_size=768, num_attention_heads=12, intermediate_size=3072)

# Initializing a model from the ALBERT-xxlarge style configuration
model = AlbertModel(albert_xxlarge_configuration)

# Accessing the model configuration
configuration = model.config

**Task 3:** Write NERNet, a PyTorch Module for labeling words with NER tags. 

*input_size:* the size of the vocabulary

*embedding_size:* the size of the embeddings

*hidden_size:* the LSTM hidden size

*output_size:* the number of tags we are predicting

*n_layers:* the number of layers we want to use in LSTM

*directions:* can be 1 or 2, indicating a unidirectional or bidirectional LSTM, respectively

The input for your forward function should be a single sentence tensor.
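
A minimal sketch of a module with this interface is shown below. It is one possible shape for `NERNet`, not the solution used in the rest of this notebook (which builds a BERT-based `Net` instead). It assumes the single-sentence input is a 1-D `int64` tensor of word ids; the toy sizes in the usage lines are arbitrary.

```python
import torch
import torch.nn as nn


class NERNet(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, output_size, n_layers, directions):
        super().__init__()
        self.embedding = nn.Embedding(input_size, embedding_size)
        self.lstm = nn.LSTM(
            embedding_size,
            hidden_size,
            num_layers=n_layers,
            bidirectional=(directions == 2),
        )
        # A bidirectional LSTM concatenates forward and backward states.
        self.out = nn.Linear(hidden_size * directions, output_size)

    def forward(self, sentence):
        # sentence: (T,) tensor of word ids
        embedded = self.embedding(sentence).unsqueeze(1)  # (T, 1, E): batch of one
        lstm_out, _ = self.lstm(embedded)                 # (T, 1, H * directions)
        return self.out(lstm_out.squeeze(1))              # (T, output_size) tag scores


model = NERNet(input_size=100, embedding_size=16, hidden_size=32, output_size=7, n_layers=2, directions=2)
scores = model(torch.tensor([1, 2, 3, 4]))
print(scores.shape)
```

The output holds one row of unnormalized tag scores per word, so it can be passed directly to `nn.CrossEntropyLoss` together with a `(T,)` tensor of gold tag ids.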

In [291]:
class Net(nn.Module):
    def __init__(
        self, args, vocab_size=None,
    ):
        super().__init__()
        # if args.uncased:
        if True:
            self.bert = BertModel.from_pretrained("bert-base-uncased")
        else:
            self.bert = BertModel.from_pretrained("bert-base-cased")

        # self.top_rnns = args.top_rnns
        self.top_rnns = False
        if self.top_rnns:
        # if args.top_rnns:
          self.rnn = nn.LSTM(bidirectional=True, num_layers=2, input_size=768, hidden_size=768 // 2, batch_first=True)
        self.fc = None
        # if args.project:
        #     self.fc = nn.Linear(768, args.entity_embedding_size)
        entity_embedding_size = 768
        asparse = True
        self.fc = nn.Linear(768, entity_embedding_size)
        self.out = nn.Embedding(num_embeddings=vocab_size, embedding_dim=entity_embedding_size, sparse=asparse)
        # torch.nn.init.normal_(self.out, std=0.1)
        finetuning = 3
        # self.device = args.device
        # self.out_device = args.out_device
        # self.finetuning = args.finetuning == 0
        self.device = 0
        self.out_device = 0
        self.finetuning = finetuning == 0
        self.vocab_size = vocab_size

    def to(self, device, out_device):
        self.bert.to(device)
        if self.fc:
            self.fc.to(device)
        self.out.to(out_device)
        self.device = device
        self.out_device = out_device

    def forward(self, x, y=None, probs=None, enc=None):
        """
        x: (N, T). int64
        y: (N, T). int64

        Returns
        enc: (N, T, VOCAB)
        """
        if y is not None:
            y = y.to(self.out_device)
        if probs is not None:
            probs = probs.to(self.out_device)

            # fake_y = torch.Tensor(range(10)).long().to(self.device)

        if enc is None:
            x = x.to(self.device)
            if self.training:
                if self.finetuning:
                    # print("->bert.train()")
                    self.bert.train()
                    encoded_layers, _ = self.bert(x)
                    enc = encoded_layers[-1]
                else:
                    self.bert.eval()
                    with torch.no_grad():
                        encoded_layers, _ = self.bert(x)
                        enc = encoded_layers[-1]
            else:
                encoded_layers, _ = self.bert(x)
                enc = encoded_layers[-1]

            if self.top_rnns:
                enc, _ = self.rnn(enc)

            if self.fc:
                enc = self.fc(enc)

            enc = enc.to(self.out_device)

        if y is not None:
            out = self.out(y)
            logits = enc.matmul(out.transpose(0, 1))
            y_hat = logits.argmax(-1)
            return logits, y, y_hat, probs, out, enc
        else:
            with torch.no_grad():
                out = self.out.weight
                logits = enc.matmul(out.transpose(0, 1))
                y_hat = logits.argmax(-1)
                return logits, None, y_hat, None, None, enc

    @staticmethod
    def train_one_epoch(
        args,
        model,
        train_iter,
        optimizers,
        criterion,
        eval_iter,
        vocab,
        epoch,
        # metrics=Metrics(),
        metrics = None,
        loss_aggr=None,
    ):
        labels_with_high_model_score = None

        with trange(len(train_iter)) as t:
            for iter, batch in enumerate(train_iter):

                model.to(
                    args.device, args.out_device,
                )
                model.train()

                batch_token_ids, label_ids, label_probs, eval_mask, _, _, orig_batch, loaded_batch = batch

                enc = None

                if (
                    args.collect_most_popular_labels_steps is not None
                    and args.collect_most_popular_labels_steps > 0
                    and iter > 0
                    and iter % args.collect_most_popular_labels_steps == 0
                ):
                    model.to(args.device, args.eval_device)
                    with torch.no_grad():
                        logits_, _, _, _, _, enc = model(
                            batch_token_ids, None, None,
                        )  # logits: (N, T, VOCAB), y: (N, T)
                        labels_with_high_model_score = get_topk_ids_aggregated_from_seq_prediction(
                            logits_, topk_from_batch=args.label_size, topk_per_token=args.topk_neg_examples
                        )
                        batch_token_ids, label_ids, label_probs, eval_mask, _, _, _, _ = EDLDataset_collate_func(
                            args=args,
                            labels_with_high_model_score=labels_with_high_model_score,
                            batch=orig_batch,
                            return_labels=True,
                            vocab=vocab,
                            is_training=False,
                            loaded_batch=loaded_batch,
                        )

                # if args.label_size is not None:
                logits, y, y_hat, label_probs, sparse_params, _ = model(
                    batch_token_ids, label_ids, label_probs, enc=enc
                )  # logits: (N, T, VOCAB), y: (N, T)
                logits = logits.view(-1)  # flattened to 1-D
                label_probs = label_probs.view(-1)  # flattened to 1-D

                loss = criterion(logits, label_probs)

                loss.backward()

                if (iter + 1) % args.accumulate_batch_gradients == 0:
                    for optimizer in optimizers:
                        optimizer.step()
                        optimizer.zero_grad()

                if iter == 0:
                    logging.debug("Sanity check")
                    # logging formats extra arguments lazily, so each message needs a %s placeholder.
                    logging.debug("x: %s", batch_token_ids.cpu().numpy()[0])
                    logging.debug("tokens: %s", vocab.tokenizer.convert_ids_to_tokens(batch_token_ids.cpu().numpy()[0]))
                    logging.debug("y: %s", label_probs.cpu().numpy()[0])

                loss_aggr = running_mean(loss.detach().item(), loss_aggr)

                if iter > 0 and iter % args.checkpoint_eval_steps == 0:
                    metrics = Net.evaluate(
                        args=args,
                        model=model,
                        iterator=eval_iter,
                        optimizers=optimizers,
                        step=iter,
                        epoch=epoch,
                        save_checkpoint=iter % args.checkpoint_save_steps == 0,
                        sampled_evaluation=False,
                        metrics=metrics,
                        vocab=vocab,
                    )

                t.set_postfix(
                    loss=loss_aggr,
                    nr_labels=len(label_ids),
                    aggr_labels=len(labels_with_high_model_score) if labels_with_high_model_score else 0,
                    # metrics defaults to None, so guard the report call.
                    last_eval=metrics.report(filter={"f1", "num_proposed", "epoch", "step"}) if metrics is not None else None,
                )
                t.update()

        for optimizer in optimizers:
            optimizer.step()
            optimizer.zero_grad()

        return metrics

    @staticmethod
    # def evaluate(
    #     args,
    #     model,
    #     iterator,
    #     vocab,
    #     optimizers,
    #     step=0,
    #     epoch=0,
    #     save_checkpoint=True,
    #     save_predictions=True,
    #     save_csv=True,
    #     sampled_evaluation=False,
    #     metrics = None
    #     # metrics=Metrics(),
    # ):

    #     print()
    #     logging.info(f"Start evaluation on split {'test' if args_eval_on_test_only else 'valid'}")

    #     model.eval()
    #     model.to(args.device, args.eval_device)

    #     all_words, all_tags, all_y, all_y_hat, all_predicted, all_token_ids = [], [], [], [], [], []
    #     with torch.no_grad():
    #         for iter, batch in enumerate(tqdm.tqdm(iterator)):
    #             (
    #                 batch_token_ids,
    #                 label_ids,
    #                 label_probs,
    #                 eval_mask,
    #                 label_id_to_entity_id_dict,
    #                 batch_entity_ids,
    #                 orig_batch,
    #                 _,
    #             ) = batch

    #             logits, y, y_hat, probs, _, _ = model(batch_token_ids, None, None)  # logits: (N, T, VOCAB), y: (N, T)

    #             tags = list()
    #             predtags = list()
    #             y_resolved_list = list()
    #             y_hat_resolved_list = list()
    #             token_list = list()

    #             chunk_len = args.create_integerized_training_instance_text_length
    #             chunk_overlap = args.create_integerized_training_instance_text_overlap

    #             for batch_id, seq in enumerate(label_probs.max(-1)[1]):
    #                 for tok_id, label_id in enumerate(seq[chunk_overlap : -chunk_overlap]):
    #                     y_resolved = (
    #                         vocab.PAD_ID
    #                         if eval_mask[batch_id][tok_id + chunk_overlap] == 0
    #                         else label_ids[label_id].item()
    #                     )
    #                     y_resolved_list.append(y_resolved)
    #                     tags.append(vocab.idx2tag[y_resolved])
    #                     if sampled_evaluation:
    #                         y_hat_resolved = (
    #                             vocab.PAD_ID
    #                             if eval_mask[batch_id][tok_id + chunk_overlap] == 0
    #                             else label_ids[y_hat[batch_id][tok_id + chunk_overlap]].item()
    #                         )
    #                     else:
    #                         y_hat_resolved = y_hat[batch_id][tok_id + chunk_overlap].item()
    #                     y_hat_resolved_list.append(y_hat_resolved)
    #                     predtags.append(vocab.idx2tag[y_hat_resolved])
    #                     token_list.append(batch_token_ids[batch_id][tok_id + chunk_overlap].item())

    #             all_y.append(y_resolved_list)
    #             all_y_hat.append(y_hat_resolved_list)
    #             all_tags.append(tags)
    #             all_predicted.append(predtags)
    #             all_words.append(vocab.tokenizer.convert_ids_to_tokens(token_list))
    #             all_token_ids.append(token_list)

    #     ## calc metric
    #     y_true = numpy.array(list(chain(*all_y)))
    #     y_pred = numpy.array(list(chain(*all_y_hat)))
    #     all_token_ids = numpy.array(list(chain(*all_token_ids)))

    #     num_proposed = len(y_pred[(vocab.OUTSIDE_ID > y_pred) & (all_token_ids > 0)])
    #     num_correct = (((y_true == y_pred) & (vocab.OUTSIDE_ID > y_true) & (all_token_ids > 0))).astype(numpy.int).sum()
    #     num_gold = len(y_true[(vocab.OUTSIDE_ID > y_true) & (all_token_ids > 0)])

    #     new_metrics = Metrics(
    #         epoch=epoch, step=step, num_correct=num_correct, num_proposed=num_proposed, num_gold=num_gold,
    #     )

    #     if save_predictions:
    #         final = args.logdir + "/%s.P%.2f_R%.2f_F%.2f" % (
    #             "{}-{}".format(str(epoch), str(step)),
    #             new_metrics.precision,
    #             new_metrics.recall,
    #             new_metrics.f1,
    #         )
    #         with open(final, "w") as fout:

    #             for words, tags, y_hat, preds in zip(all_words, all_tags, all_y_hat, all_predicted):
    #                 assert len(preds) == len(words) == len(tags)
    #                 for w, t, p in zip(words, tags, preds):
    #                     fout.write(f"{w}\t{t}\t{p}\n")
    #                 fout.write("\n")

    #             fout.write(f"num_proposed:{num_proposed}\n")
    #             fout.write(f"num_correct:{num_correct}\n")
    #             fout.write(f"num_gold:{num_gold}\n")
    #             fout.write(f"precision={new_metrics.precision}\n")
    #             fout.write(f"recall={new_metrics.recall}\n")
    #             fout.write(f"f1={new_metrics.f1}\n")

    #     if not args.dont_save_checkpoints:

    #         if save_checkpoint and metrics.was_improved(new_metrics):
    #             config = {
    #                 "args": args,
    #                 "optimizer_dense": optimizers[0].state_dict(),
    #                 "optimizer_sparse": optimizers[1].state_dict(),
    #                 "model": model.state_dict(),
    #                 "epoch": epoch,
    #                 "step": step,
    #                 "performance": new_metrics.dict(),
    #             }
    #             fname = os.path.join(args.logdir, "{}-{}".format(str(epoch), str(step)))
    #             torch.save(config, f"{fname}.pt")
    #             fname = os.path.join(args.logdir, new_metrics.get_best_checkpoint_filename())
    #             torch.save(config, f"{fname}.pt")
    #             logging.info(f"weights were saved to {fname}.pt")

    #     if save_csv:
    #         new_metrics.to_csv(epoch=epoch, step=step, args=args)

    #     if metrics.was_improved(new_metrics):
    #         metrics.update(new_metrics)

    #     logging.info("Finished evaluation")

    #     return metrics
    def evaluate(
        args,
        model,
        iterator,
        vocab,
        optimizers,
        step=0,
        epoch=0,
        save_checkpoint=True,
        save_predictions=True,
        save_csv=True,
        sampled_evaluation=False,
        metrics = None
        # metrics=Metrics(),
    ):
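          if metrics is None:
              # No running metrics were passed in; start from an empty Metrics()
              # so that metrics.was_improved(...) below does not fail on None.
              metrics = Metrics()
          args_device = 0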
          args_eval_on_test_only = False
          args_eval_device = 0
          args_logdir = '/content/drive/My Drive/data/checkpoints/dummy_wiki_00001'
          
          args_create_integerized_training_instance_text_length = 64
          args_create_integerized_training_instance_text_overlap = 128
          args_dont_save_checkpoints = False
          print()
          logging.info(f"Start evaluation on split {'test' if args_eval_on_test_only else 'valid'}")

          model.eval()
          model.to(args_device, args_eval_device)

          all_words, all_tags, all_y, all_y_hat, all_predicted, all_token_ids = [], [], [], [], [], []
          with torch.no_grad():
              # for iter, batch in enumerate(tqdm.tqdm(iterator)):
              for iter, batch in enumerate(tqdm(iterator)):
                  (
                      batch_token_ids,
                      label_ids,
                      label_probs,
                      eval_mask,
                      label_id_to_entity_id_dict,
                      batch_entity_ids,
                      orig_batch,
                      _,
                  ) = batch

                  logits, y, y_hat, probs, _, _ = model(batch_token_ids, None, None)  # logits: (N, T, VOCAB), y: (N, T)

                  tags = list()
                  predtags = list()
                  y_resolved_list = list()
                  y_hat_resolved_list = list()
                  token_list = list()

                  chunk_len = args_create_integerized_training_instance_text_length
                  chunk_overlap = args_create_integerized_training_instance_text_overlap

                  for batch_id, seq in enumerate(label_probs.max(-1)[1]):
                      for tok_id, label_id in enumerate(seq[chunk_overlap : -chunk_overlap]):
                          y_resolved = (
                              vocab.PAD_ID
                              if eval_mask[batch_id][tok_id + chunk_overlap] == 0
                              else label_ids[label_id].item()
                          )
                          y_resolved_list.append(y_resolved)
                          tags.append(vocab.idx2tag[y_resolved])
                          if sampled_evaluation:
                              y_hat_resolved = (
                                  vocab.PAD_ID
                                  if eval_mask[batch_id][tok_id + chunk_overlap] == 0
                                  else label_ids[y_hat[batch_id][tok_id + chunk_overlap]].item()
                              )
                          else:
                              y_hat_resolved = y_hat[batch_id][tok_id + chunk_overlap].item()
                          y_hat_resolved_list.append(y_hat_resolved)
                          predtags.append(vocab.idx2tag[y_hat_resolved])
                          token_list.append(batch_token_ids[batch_id][tok_id + chunk_overlap].item())

                  all_y.append(y_resolved_list)
                  all_y_hat.append(y_hat_resolved_list)
                  all_tags.append(tags)
                  all_predicted.append(predtags)
                  all_words.append(vocab.tokenizer.convert_ids_to_tokens(token_list))
                  all_token_ids.append(token_list)

          ## calc metric
          y_true = numpy.array(list(chain(*all_y)))
          y_pred = numpy.array(list(chain(*all_y_hat)))
          all_token_ids = numpy.array(list(chain(*all_token_ids)))

          num_proposed = len(y_pred[(vocab.OUTSIDE_ID > y_pred) & (all_token_ids > 0)])
          # numpy.int was deprecated in NumPy 1.20 and removed in 1.24; the builtin int is equivalent here.
          num_correct = ((y_true == y_pred) & (vocab.OUTSIDE_ID > y_true) & (all_token_ids > 0)).astype(int).sum()
          num_gold = len(y_true[(vocab.OUTSIDE_ID > y_true) & (all_token_ids > 0)])

          new_metrics = Metrics(
              epoch=epoch, step=step, num_correct=num_correct, num_proposed=num_proposed, num_gold=num_gold,
          )

          if save_predictions:
              final = args_logdir + "/%s.P%.2f_R%.2f_F%.2f" % (
                  "{}-{}".format(str(epoch), str(step)),
                  new_metrics.precision,
                  new_metrics.recall,
                  new_metrics.f1,
              )
              with open(final, "w") as fout:

                  for words, tags, y_hat, preds in zip(all_words, all_tags, all_y_hat, all_predicted):
                      assert len(preds) == len(words) == len(tags)
                      for w, t, p in zip(words, tags, preds):
                          fout.write(f"{w}\t{t}\t{p}\n")
                      fout.write("\n")

                  fout.write(f"num_proposed:{num_proposed}\n")
                  fout.write(f"num_correct:{num_correct}\n")
                  fout.write(f"num_gold:{num_gold}\n")
                  fout.write(f"precision={new_metrics.precision}\n")
                  fout.write(f"recall={new_metrics.recall}\n")
                  fout.write(f"f1={new_metrics.f1}\n")

          if not args_dont_save_checkpoints:

              if save_checkpoint and metrics.was_improved(new_metrics):
                  config = {
                      "args": args,
                      "optimizer_dense": optimizers[0].state_dict(),
                      "optimizer_sparse": optimizers[1].state_dict(),
                      "model": model.state_dict(),
                      "epoch": epoch,
                      "step": step,
                      "performance": new_metrics.dict(),
                  }
                  fname = os.path.join(args_logdir, "{}-{}".format(str(epoch), str(step)))
                  torch.save(config, f"{fname}.pt")
                  fname = os.path.join(args_logdir, new_metrics.get_best_checkpoint_filename())
                  torch.save(config, f"{fname}.pt")
                  logging.info(f"weights were saved to {fname}.pt")

          if save_csv:
              new_metrics.to_csv(epoch=epoch, step=step, args=args)

          if metrics.was_improved(new_metrics):
              metrics.update(new_metrics)

          logging.info("Finished evaluation")

          return metrics

    # def get_optimizers(self, args, checkpoint):

    #     optimizers = list()
    #     args_encoder_lr = 5e-05
    #     args_encoder_weight_decay = 0.0
    #     args_sparse = 0
    #     args_encoder_lr_scheduler_config = 0
    #     if args_encoder_lr_scheduler_config:
    #       args_encoder_lr_scheduler_config = ast.literal_eval(args_encoder_lr_scheduler_config)
    #     if args_decoder_lr_scheduler_config:
    #       args_decoder_lr_scheduler_config = ast.literal_eval(args.decoder_lr_scheduler_config)
    #     if args.segm_decoder_lr_scheduler_config:
    #     args.segm_decoder_lr_scheduler_config = ast.literal_eval(args.segm_decoder_lr_scheduler_config)

    #     args.eval_batch_size = args.eval_batch_size if args.eval_batch_size else args.batch_size
        
    #     args.encoder_lr_scheduler_config = 0
    #     args.encoder_lr_scheduler = 0
    #     if args.encoder_lr > 0:
    #         optimizer_encoder = optim.Adam(
    #             list(self.bert.parameters()) + list(self.fc.parameters() if args.project else list()),
    #             lr=args.encoder_lr,
    #         )
    #         if args.resume_from_checkpoint is not None:
    #             optimizer_encoder.load_state_dict(checkpoint["optimizer_dense"])
    #             optimizer_encoder.param_groups[0]["lr"] = args.encoder_lr
    #             optimizer_encoder.param_groups[0]["weight_decay"] = args.encoder_weight_decay
    #         optimizers.append(optimizer_encoder)
    #     else:
    #         optimizers.append(DummyOptimizer(self.out.parameters(), defaults={}))

    #     if args.decoder_lr > 0:
    #         if args.sparse:
    #             optimizer_decoder = optim.SparseAdam(self.out.parameters(), lr=args.decoder_lr)
    #         else:
    #             optimizer_decoder = optim.Adam(self.out.parameters(), lr=args.decoder_lr)
    #         if args.resume_from_checkpoint is not None:
    #             optimizer_decoder.load_state_dict(checkpoint["optimizer_sparse"])
    #             if "weight_decay" not in optimizer_decoder.param_groups[0]:
    #                 optimizer_decoder.param_groups[0]["weight_decay"] = 0
    #             optimizer_decoder.param_groups[0]["lr"] = args.decoder_lr
    #             if not args.sparse:
    #                 optimizer_decoder.param_groups[0]["weight_decay"] = args.decoder_weight_decay
    #         optimizers.append(optimizer_decoder)
    #     else:
    #         optimizers.append(DummyOptimizer(self.out.parameters(), defaults={}))

    #     lr_schedulers = [
    #         getattr(LRSchedulers, lr_scheduler)(optimizer=optimizer, **lr_scheduler_config)
    #         for optimizer, (lr_scheduler, lr_scheduler_config) in zip(
    #             optimizers,
    #             [
    #                 (args.encoder_lr_scheduler, args.encoder_lr_scheduler_config),
    #                 (args.decoder_lr_scheduler, args.decoder_lr_scheduler_config),
    #             ],
    #         )
    #         if lr_scheduler is not None  # and not isinstance(optimizer, DummyOptimizer)
    #     ]

    #     return tuple(optimizers), tuple(lr_schedulers)
    def get_optimizers(self, args, checkpoint):

        optimizers = list()
        args_encoder_lr = 5e-05
        args_encoder_weight_decay = 0.0
        args_sparse = 0
        args_encoder_lr_scheduler_config = 0
        args_decoder_lr_scheduler_config = 0
        args_segm_decoder_lr_scheduler_config = 0
        if args_encoder_lr_scheduler_config:
          args_encoder_lr_scheduler_config = ast.literal_eval(args_encoder_lr_scheduler_config)
        if args_decoder_lr_scheduler_config:
          args_decoder_lr_scheduler_config = ast.literal_eval(args_decoder_lr_scheduler_config)
        if args_segm_decoder_lr_scheduler_config:
          args_segm_decoder_lr_scheduler_config = ast.literal_eval(args_segm_decoder_lr_scheduler_config)
        args_eval_batch_size = 1
        args_batch_size = 16
        args_eval_batch_size = args_eval_batch_size if args_eval_batch_size else args_batch_size
        args_project = 0
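        args_decoder_lr = 0.1
        # Assumed default (mirrors args_encoder_weight_decay); only read when resuming from a checkpoint.
        args_decoder_weight_decay = 0.0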
        args_encoder_lr_scheduler_config = 0
        args_encoder_lr_scheduler = 0
        args_resume_from_checkpoint = None
        args_decoder_lr_scheduler = 0
        args_decoder_lr_scheduler_config = 0
        if args_encoder_lr > 0:
            optimizer_encoder = optim.Adam(
                list(self.bert.parameters()) + list(self.fc.parameters() if args_project else list()),
                lr=args_encoder_lr,
            )
            if args_resume_from_checkpoint is not None:
                optimizer_encoder.load_state_dict(checkpoint["optimizer_dense"])
                optimizer_encoder.param_groups[0]["lr"] = args_encoder_lr
                optimizer_encoder.param_groups[0]["weight_decay"] = args_encoder_weight_decay
            optimizers.append(optimizer_encoder)
        else:
            optimizers.append(DummyOptimizer(self.out.parameters(), defaults={}))

        if args_decoder_lr > 0:
            if args_sparse:
                optimizer_decoder = optim.SparseAdam(self.out.parameters(), lr=args_decoder_lr)
            else:
                optimizer_decoder = optim.Adam(self.out.parameters(), lr=args_decoder_lr)
            if args_resume_from_checkpoint is not None:
                optimizer_decoder.load_state_dict(checkpoint["optimizer_sparse"])
                if "weight_decay" not in optimizer_decoder.param_groups[0]:
                    optimizer_decoder.param_groups[0]["weight_decay"] = 0
                optimizer_decoder.param_groups[0]["lr"] = args_decoder_lr
                if not args_sparse:
                    optimizer_decoder.param_groups[0]["weight_decay"] = args_decoder_weight_decay
            optimizers.append(optimizer_decoder)
        else:
            optimizers.append(DummyOptimizer(self.out.parameters(), defaults={}))
        optimizer = DummyOptimizer(self.out.parameters(), defaults={})
        lr_schedulers = [LRMilestones(optimizer, milestones=[(30, 0.1), (80, 0.2), ])]
        #     getattr(LRSchedulers, lr_scheduler)(optimizer=optimizer, **lr_scheduler_config)
            # for optimizer, (lr_scheduler, lr_scheduler_config) in zip(
            #     optimizers,
            #     [
            #         (args_encoder_lr_scheduler, args_encoder_lr_scheduler_config),
            #         (args_decoder_lr_scheduler, args_decoder_lr_scheduler_config),
            #     ],
            # )
        #     if lr_scheduler is not None  # and not isinstance(optimizer, DummyOptimizer)
        # ]
            # if lr_scheduler is not None  # and not isinstance(optimizer, DummyOptimizer)
    #     ]

        return tuple(optimizers), tuple(lr_schedulers)

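The `evaluate` method above reduces every batch to three counts (`num_proposed`, `num_correct`, `num_gold`) and hands them to `Metrics`, which is not shown in this notebook. The sketch below shows how such counts are usually turned into token-level precision, recall and F1. It assumes, as the masks in `evaluate` suggest, that every label id below `OUTSIDE_ID` is an entity label and that token id `0` is padding. The function name `token_prf` and the plain-list inputs are illustrative only and are not part of the original code.

```python
def token_prf(y_true, y_pred, token_ids, outside_id):
    """Token-level precision, recall and F1 over entity labels only.

    Hypothetical helper: mirrors the boolean masks used in `evaluate`,
    but on plain Python lists instead of numpy arrays.
    """
    # A position is scored only if it holds a real token (id > 0, i.e. not padding).
    real = [t > 0 for t in token_ids]

    # Labels strictly below outside_id are treated as entity labels.
    num_proposed = sum(1 for p, r in zip(y_pred, real) if r and p < outside_id)
    num_gold = sum(1 for g, r in zip(y_true, real) if r and g < outside_id)
    num_correct = sum(
        1 for g, p, r in zip(y_true, y_pred, real) if r and g == p and g < outside_id
    )

    # Guard against empty denominators (no predictions or no gold entities).
    precision = num_correct / num_proposed if num_proposed else 0.0
    recall = num_correct / num_gold if num_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Usage: label 9 plays the role of OUTSIDE_ID, the last position is padding.
p, r, f = token_prf(
    y_true=[1, 2, 9, 9, 0],
    y_pred=[1, 3, 2, 9, 0],
    token_ids=[5, 6, 7, 8, 0],
    outside_id=9,
)
print(f"precision={p:.2f} recall={r:.2f} f1={f:.2f}")
```

Because outside and padded positions are excluded from all three counts, a model that predicts "outside" everywhere gets a score of zero rather than a high accuracy.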
In [292]:
import ast
import os
import torch.cuda
import yaml
import argparse
# from torchfun import argparse_bool_type

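The collate function in the next cell does not score each token against the whole entity vocabulary. It maps the entity ids that occur in a batch to dense batch-local indices and, when `args_label_size` is set, pads that list with randomly drawn negatives. The sketch below isolates that step under simplifying assumptions: `build_batch_labels`, `num_entities` and `seed` are hypothetical names, `random.sample` stands in for `numpy.random.choice`, and the negatives are sorted only to make the result reproducible.

```python
import random
from collections import OrderedDict


def build_batch_labels(batch_entity_ids, label_size, num_entities, seed=0):
    """Return (entity id -> batch-local index, shared label id list)."""
    # Assign dense indices in order of first appearance within the batch.
    local_index = OrderedDict()
    for item in batch_entity_ids:        # one entry per sequence
        for token_entity_ids in item:    # one list of entity ids per token
            for eid in token_entity_ids:
                if eid not in local_index:
                    local_index[eid] = len(local_index)

    shared = list(local_index.keys())

    # Pad with random negatives that are not already among the true entities.
    if len(shared) < label_size:
        rng = random.Random(seed)
        negatives = set(rng.sample(range(num_entities), min(label_size, num_entities)))
        negatives.difference_update(shared)
        shared += sorted(negatives)

    # True entities come first, so truncation only ever drops negatives
    # (as long as the batch holds at most label_size distinct entities).
    return local_index, shared[:label_size]


# Usage: two sequences with two tokens each.
index, labels = build_batch_labels(
    [[[324], [24, 2]], [[2], [987]]], label_size=6, num_entities=1000
)
print(dict(index))
print(labels[:4])
```

The per-token target tensor then needs only `len(labels)` columns instead of one column per entity in the vocabulary.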
In [293]:
import os
import pickle
from collections import OrderedDict
from operator import itemgetter

import numpy
import torch
from torch.utils import data
from tqdm import tqdm


class EDLDataset(data.Dataset):
    def __init__(self, args, split, vocab, device, label_size=None):
        args_train_loc_file = 'train.loc'
        args_valid_loc_file = 'valid.loc'
        args_test_loc_file = 'valid.loc'
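        # Directory names must be strings; the bare name `data` refers to the torch.utils.data module.
        args_train_data_dir = 'data'
        args_valid_data_dir = 'data'
        args_test_data_dir = 'data'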
        args_data_workers = 24
        args_collect_most_popular_labels_steps = 1
        args_label_size = 8192
        args_vocab_size = 0
        if split == "train":
            loc_file_name = args_train_loc_file
            self.data_dir = args_train_data_dir
        elif split == "valid":
            loc_file_name = args_valid_loc_file
            self.data_dir = args_valid_data_dir
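        elif split == "test":
            loc_file_name = args_test_loc_file
            self.data_dir = args_test_data_dir
        else:
            # Fail early instead of hitting an undefined loc_file_name below.
            raise ValueError(f"unknown split: {split!r}")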

        self.data_path = "/content/drive/My Drive/data/versions/dummy/wiki_training/integerized/enwiki/"
        #self.data_path = f"data/versions/{args_data_version_name}/wiki_training/integerized/{args_wiki_lang_version}/"
        self.item_locs = None
        self.device = device
        if os.path.exists("{}.pickle".format(self.data_path + loc_file_name)):
            with open("{}.pickle".format(self.data_path + loc_file_name), "rb") as f:
                self.item_locs = pickle.load(f)
        else:
            with open(self.data_path + loc_file_name) as f:
                self.item_locs = list(map(lambda x: list(map(int, x.strip().split())), tqdm(f.readlines())))
            with open("{}.pickle".format(self.data_path + loc_file_name), "wb") as f:
                pickle.dump(self.item_locs, f)
        self.pad_token_id = vocab.PAD_ID
        self.label_size = label_size
        self.is_training = split == "train"

    def get_data_iter(
        self, args, batch_size, vocab, train,
    ):
        args_collect_most_popular_labels_steps = 1
        args_data_workers = 24
        return data.DataLoader(
            dataset=self.item_locs,
            batch_size=batch_size,
            shuffle=train,
            num_workers=args_data_workers,
            collate_fn=self.collate_func(
                args=args,
                vocab=vocab,
                return_labels=args_collect_most_popular_labels_steps is not None
                and args_collect_most_popular_labels_steps > 0
                if train
                else True,
            ),
        )

    # def collate_func(self, args, vocab, return_labels, shards, shards_locks):
    def collate_func(
        self, args, vocab, return_labels, in_queue=None, out_queue=None,
    ):
        def collate(batch):
            return EDLDataset_collate_func(
                batch=batch,
                labels_with_high_model_score=None,
                args=args,
                return_labels=return_labels,
                data_path=self.data_path,
                vocab=vocab,
                is_training=self.is_training,
            )

        return collate


def EDLDataset_collate_func(
    batch,
    labels_with_high_model_score,
    args,
    return_labels,
    vocab: Vocab,
    data_path=None,
    is_training=True,
    drop_entity_mentions_prob=0.0,
    loaded_batch=None,
):
    args_label_size = 8192
    if loaded_batch is None:
        batch_dict_list = list()
        for shard, offset in batch:
            # print('{}/{}.dat'.format(data_path, shard), offset)
            with open("{}/{}.dat".format(data_path, shard), "rb") as f:
                f.seek(offset)
                (
                    token_ids_chunk,
                    mention_entity_ids_chunk,
                    mention_entity_probs_chunk,
                    mention_probs_chunk,
                ) = pickle.load(f)
                try:
                    eval_mask = list(map(is_a_wikilink_or_keyword, mention_probs_chunk))
                    mention_entity_ids_chunk = list(map(itemgetter(0), mention_entity_ids_chunk))
                    mention_entity_probs_chunk = list(map(itemgetter(0), mention_entity_probs_chunk))
                    batch_dict_list.append(
                        {
                            "token_ids": token_ids_chunk,
                            "entity_ids": mention_entity_ids_chunk,
                            "entity_probs": mention_entity_probs_chunk,
                            "eval_mask": eval_mask,
                        }
                    )
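                except Exception as e:
                    # pickle.load has already succeeded here; the failure is in post-processing the chunk.
                    print(f"processing chunk from shard {shard} at offset {offset} failed: {e}")
                    print(mention_entity_ids_chunk)
                    print(mention_entity_probs_chunk)
                    raise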

        f = lambda x: [sample[x] for sample in batch_dict_list]
        # print(batch)
        batch_token_ids = f("token_ids")
        batch_entity_ids = f("entity_ids")
        batch_entity_probs = f("entity_probs")
        eval_mask = f("eval_mask")
        maxlen = max([len(chunk) for chunk in batch_token_ids])

        eval_mask = torch.LongTensor([sample + [0] * (maxlen - len(sample)) for sample in eval_mask])

        # create dictionary mapping the vocabulary entity id to a batch label id
        #
        # e.g.
        # all_batch_entity_ids[324] = 0
        # all_batch_entity_ids[24]  = 1
        # all_batch_entity_ids[2]   = 2
        # all_batch_entity_ids[987] = 3
        #
        all_batch_entity_ids = OrderedDict()

        for batch_offset, (batch_item_token_item_entity_ids, batch_item_token_entity_probs) in enumerate(
            zip(batch_entity_ids, batch_entity_probs)
        ):
            for tok_id, (token_entity_ids, token_entity_probs) in enumerate(
                zip(batch_item_token_item_entity_ids, batch_item_token_entity_probs)
            ):
                for eid in token_entity_ids:
                    if eid not in all_batch_entity_ids:
                        all_batch_entity_ids[eid] = len(all_batch_entity_ids)

        loaded_batch = (
            batch_token_ids,
            batch_entity_ids,
            batch_entity_probs,
            eval_mask,
            all_batch_entity_ids,
            maxlen,
        )

    else:
        (batch_token_ids, batch_entity_ids, batch_entity_probs, eval_mask, all_batch_entity_ids, maxlen,) = loaded_batch

    batch_token_ids = torch.LongTensor([sample + [0] * (maxlen - len(sample)) for sample in batch_token_ids])

    if return_labels:

        # if labels for each token should be over
        # a. the whole entity vocabulary
        # b. a reduced set of entities composed of:
        #       set of batch's true entities, entities
        #       set of entities with the largest logits
        #       set of negative samples

        if args_label_size is None:

            batch_shared_label_ids = list(all_batch_entity_ids.keys())
            label_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), args_vocab_size)

        else:

            # batch_shared_label_ids are constructing by incrementally concatenating
            #       set of batch's true entities, entities
            #       set of entities with the largest logits
            #       set of negative samples

            batch_shared_label_ids = list(all_batch_entity_ids.keys())

            if len(batch_shared_label_ids) < args_label_size and labels_with_high_model_score is not None:
                # print(labels_with_high_model_score)
                negative_examples = set(labels_with_high_model_score)
                negative_examples.difference_update(batch_shared_label_ids)
                batch_shared_label_ids += list(negative_examples)

            if len(batch_shared_label_ids) < args_label_size:
                negative_samples = set(numpy.random.choice(vocab.OUTSIDE_ID, args_label_size, replace=False))
                negative_samples.difference_update(batch_shared_label_ids)
                batch_shared_label_ids += list(negative_samples)

            batch_shared_label_ids = batch_shared_label_ids[: args_label_size]

            label_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), len(batch_shared_label_ids))

        drop_probs = None
        if drop_entity_mentions_prob > 0 and is_training:
            drop_probs = torch.rand((batch_token_ids.size(0), batch_token_ids.size(1)),) < drop_entity_mentions_prob

        # loop through the batch x tokens x (label_ids, label_probs)
        for batch_offset, (batch_item_token_item_entity_ids, batch_item_token_entity_probs) in enumerate(
            zip(batch_entity_ids, batch_entity_probs)
        ):
            # loop through tokens x (label_ids, label_probs)
            for tok_id, (token_entity_ids, token_entity_probs) in enumerate(
                zip(batch_item_token_item_entity_ids, batch_item_token_entity_probs)
            ):
                if drop_entity_mentions_prob > 0 and is_training and drop_probs[batch_offset][tok_id].item() == 1:
                    batch_token_ids[batch_offset][tok_id] = vocab.tokenizer.vocab["[MASK]"]

                if args_label_size is None:
                    # Assign this token's probabilities (not the item's entity ids) at its entity positions.
                    label_probs[batch_offset][tok_id][torch.LongTensor(token_entity_ids)] = torch.Tensor(
                        token_entity_probs
                    )
                else:
                    label_probs[batch_offset][tok_id][
                        torch.LongTensor(list(map(all_batch_entity_ids.__getitem__, token_entity_ids)))
                    ] = torch.Tensor(token_entity_probs)

        label_ids = torch.LongTensor(batch_shared_label_ids)

        return (
            batch_token_ids,
            label_ids,
            label_probs,
            torch.LongTensor(eval_mask),
            {v: k for k, v in all_batch_entity_ids.items()},
            batch_entity_ids,
            batch,
            loaded_batch,
        )

    else:

        return batch_token_ids, None, None, None, None, None, batch, loaded_batch

# hack to detect if an entity annotation was a
# wikilink (== only one entity label) or a
# keyword matcher annotation (== multiple entity labels)
def is_a_wikilink_or_keyword(item):
    if len(item) == 1:
        return 1
    else:
        return 0
# from vocab import Vocab
# class EDLDataset(data.Dataset):
#     def __init__(self, args, split, vocab, device, label_size=None):

#         if split == "train":
#             loc_file_name = args.train_loc_file
#             self.data_dir = args.train_data_dir
#         elif split == "valid":
#             loc_file_name = args.valid_loc_file
#             self.data_dir = args.valid_data_dir
#         elif split == "test":
#             loc_file_name = args.test_loc_file
#             self.data_dir = args.test_data_dir


#         self.data_path = f"data/versions/{args.data_version_name}/wiki_training/integerized/{args.wiki_lang_version}/"
#         self.item_locs = None
#         self.device = device
#         if os.path.exists("{}.pickle".format(self.data_path + loc_file_name)):
#             with open("{}.pickle".format(self.data_path + loc_file_name), "rb") as f:
#                 self.item_locs = pickle.load(f)
#         else:
#             with open(self.data_path + loc_file_name) as f:
#                 self.item_locs = list(map(lambda x: list(map(int, x.strip().split())), tqdm(f.readlines())))
#             with open("{}.pickle".format(self.data_path + loc_file_name), "wb") as f:
#                 pickle.dump(self.item_locs, f)
#         self.pad_token_id = vocab.PAD_ID
#         self.label_size = label_size
#         self.is_training = split == "train"

#     def get_data_iter(
#         self, args, batch_size, vocab, train,
#     ):
#         return data.DataLoader(
#             dataset=self.item_locs,
#             batch_size=batch_size,
#             shuffle=train,
#             num_workers=args.data_workers,
#             collate_fn=self.collate_func(
#                 args=args,
#                 vocab=vocab,
#                 return_labels=args.collect_most_popular_labels_steps is not None
#                 and args.collect_most_popular_labels_steps > 0
#                 if train
#                 else True,
#             ),
#         )

#     # def collate_func(self, args, vocab, return_labels, shards, shards_locks):
#     def collate_func(
#         self, args, vocab, return_labels, in_queue=None, out_queue=None,
#     ):
#         def collate(batch):
#             return EDLDataset_collate_func(
#                 batch=batch,
#                 labels_with_high_model_score=None,
#                 args=args,
#                 return_labels=return_labels,
#                 data_path=self.data_path,
#                 vocab=vocab,
#                 is_training=self.is_training,
#             )

#         return collate


# def EDLDataset_collate_func(
#     batch,
#     labels_with_high_model_score,
#     args,
#     return_labels,
#     vocab: Vocab,
#     data_path=None,
#     is_training=True,
#     drop_entity_mentions_prob=0.0,
#     loaded_batch=None,
# ):
#     if loaded_batch is None:
#         batch_dict_list = list()
#         for shard, offset in batch:
#             # print('{}/{}.dat'.format(data_path, shard), offset)
#             with open("{}/{}.dat".format(data_path, shard), "rb") as f:
#                 f.seek(offset)
#                 (
#                     token_ids_chunk,
#                     mention_entity_ids_chunk,
#                     mention_entity_probs_chunk,
#                     mention_probs_chunk,
#                 ) = pickle.load(f)
#                 try:
#                     eval_mask = list(map(is_a_wikilink_or_keyword, mention_probs_chunk))
#                     mention_entity_ids_chunk = list(map(itemgetter(0), mention_entity_ids_chunk))
#                     mention_entity_probs_chunk = list(map(itemgetter(0), mention_entity_probs_chunk))
#                     batch_dict_list.append(
#                         {
#                             "token_ids": token_ids_chunk,
#                             "entity_ids": mention_entity_ids_chunk,
#                             "entity_probs": mention_entity_probs_chunk,
#                             "eval_mask": eval_mask,
#                         }
#                     )
#                 except Exception as e:
#                     print(f"pickle.load(shards[shard]) failed {e}")
#                     print(mention_entity_ids_chunk)
#                     print(mention_entity_probs_chunk)
#                     raise e

#         f = lambda x: [sample[x] for sample in batch_dict_list]
#         # print(batch)
#         batch_token_ids = f("token_ids")
#         batch_entity_ids = f("entity_ids")
#         batch_entity_probs = f("entity_probs")
#         eval_mask = f("eval_mask")
#         maxlen = max([len(chunk) for chunk in batch_token_ids])

#         eval_mask = torch.LongTensor([sample + [0] * (maxlen - len(sample)) for sample in eval_mask])

#         # create dictionary mapping the vocabulary entity id to a batch label id
#         #
#         # e.g.
#         # all_batch_entity_ids[324] = 0
#         # all_batch_entity_ids[24]  = 1
#         # all_batch_entity_ids[2]   = 2
#         # all_batch_entity_ids[987] = 3
#         #
#         all_batch_entity_ids = OrderedDict()

#         for batch_offset, (batch_item_token_item_entity_ids, batch_item_token_entity_probs) in enumerate(
#             zip(batch_entity_ids, batch_entity_probs)
#         ):
#             for tok_id, (token_entity_ids, token_entity_probs) in enumerate(
#                 zip(batch_item_token_item_entity_ids, batch_item_token_entity_probs)
#             ):
#                 for eid in token_entity_ids:
#                     if eid not in all_batch_entity_ids:
#                         all_batch_entity_ids[eid] = len(all_batch_entity_ids)

#         loaded_batch = (
#             batch_token_ids,
#             batch_entity_ids,
#             batch_entity_probs,
#             eval_mask,
#             all_batch_entity_ids,
#             maxlen,
#         )

#     else:
#         (batch_token_ids, batch_entity_ids, batch_entity_probs, eval_mask, all_batch_entity_ids, maxlen,) = loaded_batch

#     batch_token_ids = torch.LongTensor([sample + [0] * (maxlen - len(sample)) for sample in batch_token_ids])

#     if return_labels:

#         # if labels for each token should be over
#         # a. the whole entity vocabulary
#         # b. a reduced set of entities composed of:
#         #       set of batch's true entities, entities
#         #       set of entities with the largest logits
#         #       set of negative samples

#         if args.label_size is None:

#             batch_shared_label_ids = list(all_batch_entity_ids.keys())
#             label_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), args.vocab_size)

#         else:

#             # batch_shared_label_ids are constructing by incrementally concatenating
#             #       set of batch's true entities, entities
#             #       set of entities with the largest logits
#             #       set of negative samples

#             batch_shared_label_ids = list(all_batch_entity_ids.keys())

#             if len(batch_shared_label_ids) < args.label_size and labels_with_high_model_score is not None:
#                 # print(labels_with_high_model_score)
#                 negative_examples = set(labels_with_high_model_score)
#                 negative_examples.difference_update(batch_shared_label_ids)
#                 batch_shared_label_ids += list(negative_examples)

#             if len(batch_shared_label_ids) < args.label_size:
#                 negative_samples = set(numpy.random.choice(vocab.OUTSIDE_ID, args.label_size, replace=False))
#                 negative_samples.difference_update(batch_shared_label_ids)
#                 batch_shared_label_ids += list(negative_samples)

#             batch_shared_label_ids = batch_shared_label_ids[: args.label_size]

#             label_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), len(batch_shared_label_ids))

#         drop_probs = None
#         if drop_entity_mentions_prob > 0 and is_training:
#             drop_probs = torch.rand((batch_token_ids.size(0), batch_token_ids.size(1)),) < drop_entity_mentions_prob

#         # loop through the batch x tokens x (label_ids, label_probs)
#         for batch_offset, (batch_item_token_item_entity_ids, batch_item_token_entity_probs) in enumerate(
#             zip(batch_entity_ids, batch_entity_probs)
#         ):
#             # loop through tokens x (label_ids, label_probs)
#             for tok_id, (token_entity_ids, token_entity_probs) in enumerate(
#                 zip(batch_item_token_item_entity_ids, batch_item_token_entity_probs)
#             ):
#                 if drop_entity_mentions_prob > 0 and is_training and drop_probs[batch_offset][tok_id].item() == 1:
#                     batch_token_ids[batch_offset][tok_id] = vocab.tokenizer.vocab["[MASK]"]

#                 if args.label_size is None:
#                     label_probs[batch_offset][tok_id][torch.LongTensor(token_entity_ids)] = torch.Tensor(
#                         batch_item_token_item_entity_ids
#                     )
#                 else:
#                     label_probs[batch_offset][tok_id][
#                         torch.LongTensor(list(map(all_batch_entity_ids.__getitem__, token_entity_ids)))
#                     ] = torch.Tensor(token_entity_probs)

#         label_ids = torch.LongTensor(batch_shared_label_ids)

#         return (
#             batch_token_ids,
#             label_ids,
#             label_probs,
#             torch.LongTensor(eval_mask),
#             {v: k for k, v in all_batch_entity_ids.items()},
#             batch_entity_ids,
#             batch,
#             loaded_batch,
#         )

#     else:

#         return batch_token_ids, None, None, None, None, None, batch, loaded_batch

# # hack to detect if an entity annotation was a
# # wikilink (== only one entity label) or a
# # keyword matcher annotation (== multiple entity labels)
# def is_a_wikilink_or_keyword(item):
#     if len(item) == 1:
#         return 1
#     else:
#         return 0
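
The commented-out collate function above builds the per-batch label set in three stages: the batch's gold entity ids, then entities the model currently scores highly, then random negatives, with the result truncated to `label_size`. A minimal, framework-free sketch of that idea follows. The helper `build_label_set`, its parameters, and the use of `random.sample` (rather than `numpy.random.choice`) are illustrative assumptions, not the notebook's actual code.

```python
import random


def build_label_set(gold_ids, high_score_ids, num_entities, label_size, seed=0):
    """Return up to `label_size` distinct entity ids: gold, then hard negatives, then random.

    Hypothetical helper for illustration only.
    """
    rng = random.Random(seed)

    # Stage 1: gold ids of the batch, de-duplicated in first-seen order.
    labels = list(dict.fromkeys(gold_ids))
    seen = set(labels)

    # Stage 2: entities the model currently ranks highly act as hard negatives.
    for eid in high_score_ids:
        if len(labels) >= label_size:
            break
        if eid not in seen:
            labels.append(eid)
            seen.add(eid)

    # Stage 3: fill any remaining slots with random, not-yet-used entity ids.
    remaining = [eid for eid in range(num_entities) if eid not in seen]
    n_missing = min(max(label_size - len(labels), 0), len(remaining))
    labels.extend(rng.sample(remaining, n_missing))

    return labels[:label_size]


labels = build_label_set(gold_ids=[7, 3, 7], high_score_ids=[3, 11, 42], num_entities=100, label_size=6)
print(labels[:4])  # → [7, 3, 11, 42]
```

Scanning `range(num_entities)` keeps the sketch short; with a large entity vocabulary, sampling first and filtering afterwards (as the notebook's code does) would likely be cheaper.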


In [294]:
class CONLLEDLDataset(data.Dataset):
    def __init__(self, args, split, vocab, device, label_size=None):

        if split == "train":
            train_valid_test_int = 0
        elif split in ("small_valid", "valid"):
            train_valid_test_int = 1
        elif split == "test":
            train_valid_test_int = 2
        else:
            raise ValueError(f"Unknown split: {split!r}")

        chunk_len = args.create_integerized_training_instance_text_length
        chunk_overlap = args.create_integerized_training_instance_text_overlap

        self.item_locs = None
        self.device = device
        with open(args.data_path_conll, "rb") as f:
            train_valid_test = pickle.load(f)
            self.conll_docs = torch.LongTensor(
                [
                    [
                        pad_to(
                            [tok_id for _, tok_id, _, _, _, _, _ in doc],
                            max_len=chunk_len + 2,
                            pad_id=0,
                            cls_id=101,
                            sep_id=102,
                        )
                        for doc in train_valid_test[train_valid_test_int]
                    ],
                    [
                        pad_to(
                            [bio_id for _, _, _, bio_id, _, _, _ in doc],
                            max_len=chunk_len + 2,
                            pad_id=2,
                            cls_id=2,
                            sep_id=2,
                        )
                        for doc in train_valid_test[train_valid_test_int]
                    ],
                    [
                        pad_to(
                            [wiki_id for _, _, _, _, _, wiki_id, _ in doc],
                            max_len=chunk_len + 2,
                            pad_id=vocab.PAD_ID,
                            cls_id=vocab.PAD_ID,
                            sep_id=vocab.PAD_ID,
                        )
                        for doc in train_valid_test[train_valid_test_int]
                    ],
                    [
                        pad_to(
                            [doc_id for _, _, _, _, _, _, doc_id in doc],
                            max_len=chunk_len + 2,
                            pad_id=0,
                            cls_id=0,
                            sep_id=0,
                        )
                        for doc in train_valid_test[train_valid_test_int]
                    ],
                ]
            ).permute(1, 0, 2)
            self.conll_docs[:, 2] = set_out_id(self.conll_docs[:, 2], vocab.OUTSIDE_ID)

        self.pad_token_id = vocab.PAD_ID
        self.label_size = label_size
        self.labels = None
        self.train_valid_test_int = train_valid_test_int

    def get_data_iter(
        self, args, batch_size, vocab, train,
    ):
        return data.DataLoader(
            dataset=self.conll_docs,
            batch_size=batch_size,
            shuffle=train,
            num_workers=args.data_workers,
            collate_fn=self.collate_func(
                args,
                return_labels=(
                    (
                        args.collect_most_popular_labels_steps is not None
                        and args.collect_most_popular_labels_steps > 0
                    )
                    if train
                    else True
                ),
                vocab=vocab,
            ),
        )

    def collate_func(self, args, vocab, return_labels):
        def collate(batch):
            return CONLLEDLDataset_collate_func(
                batch=batch,
                labels_with_high_model_score=self.labels,
                args=args,
                return_labels=return_labels,
                vocab=vocab,
                is_training=self.train_valid_test_int == 0,
            )

        return collate


def CONLLEDLDataset_collate_func(
    batch, labels_with_high_model_score, args, return_labels, vocab: Vocab, is_training=False,
):
    drop_entity_mentions_prob = args.maskout_entity_prob
    # print([b[0] for b in batch])
    label_size = args.label_size
    batch_token_ids = torch.LongTensor([b[0].tolist() for b in batch])
    batch_bio_ids = [b[1].tolist() for b in batch]
    batch_entity_ids = [b[2].tolist() for b in batch]
    batch_doc_ids = [b[3, 0].item() for b in batch]

    if return_labels:

        all_batch_entity_ids = OrderedDict()

        for batch_offset, one_item_entity_ids in enumerate(batch_entity_ids):
            for tok_id, eid in enumerate(one_item_entity_ids):
                if eid not in all_batch_entity_ids:
                    all_batch_entity_ids[eid] = len(all_batch_entity_ids)

        if label_size is not None:

            batch_shared_label_ids = all_batch_entity_ids.keys()
            negative_samples = set()
            if labels_with_high_model_score is not None:
                # print(labels_with_high_model_score)
                negative_samples = set(labels_with_high_model_score)
            # else:
            #     negative_samples = set(numpy.random.choice(vocab.OUTSIDE_ID, label_size, replace=False))
            if len(negative_samples) < label_size:
                random_negative_samples = set(numpy.random.choice(vocab.OUTSIDE_ID, label_size, replace=False))
                negative_samples = negative_samples.union(random_negative_samples)

            negative_samples.difference_update(batch_shared_label_ids)

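            if len(batch_shared_label_ids) + len(negative_samples) < label_size:
                # Still short of label_size: add another round of random negatives,
                # then drop any that collide with the batch's own entity ids.
                negative_samples.update(
                    set(numpy.random.choice(vocab.OUTSIDE_ID, label_size, replace=False))
                )
                negative_samples.difference_update(batch_shared_label_ids)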

            batch_shared_label_ids = (list(batch_shared_label_ids) + list(negative_samples))[:label_size]
            label_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), len(batch_shared_label_ids))
            bio_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), 3)

        else:

            batch_shared_label_ids = list(all_batch_entity_ids.keys())
            label_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), args.vocab_size)
            bio_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), 3)

        drop_probs = None
        if drop_entity_mentions_prob > 0 and is_training:
            drop_probs = torch.rand((batch_token_ids.size(0), batch_token_ids.size(1)),) < drop_entity_mentions_prob

        for batch_offset, (one_item_entity_ids, one_item_bio_ids) in enumerate(zip(batch_entity_ids, batch_bio_ids)):
            for tok_id, one_entity_ids in enumerate(one_item_entity_ids):

                if (
                    is_training
                    and vocab.OUTSIDE_ID != one_entity_ids
                    and drop_entity_mentions_prob > 0
                    and drop_probs[batch_offset][tok_id].item() == 1
                ):
                    batch_token_ids[batch_offset][tok_id] = vocab.tokenizer.vocab["[MASK]"]

                if label_size is not None:
                    label_probs[batch_offset][tok_id][torch.LongTensor([all_batch_entity_ids[one_entity_ids]])] = 1.0
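                else:
                    # Index with the id itself: torch.LongTensor(int) allocates an
                    # uninitialized tensor of that length instead of wrapping the value.
                    label_probs[batch_offset][tok_id][one_entity_ids] = 1.0
                # Mark only this token's BIO class, not every class present in the item.
                bio_probs[batch_offset][tok_id][one_item_bio_ids[tok_id]] = 1.0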

        label_ids = torch.LongTensor(batch_shared_label_ids)

        return (
            batch_token_ids,
            label_ids,
            torch.LongTensor(batch_bio_ids),
            torch.FloatTensor(label_probs),
            bio_probs,
            None,
            {v: k for k, v in all_batch_entity_ids.items()},
            batch_entity_ids,
            batch_doc_ids,
            batch,
        )

    else:

        return batch_token_ids, None, None, None, None, None, None, None, None, batch
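
When `label_size` is set, `CONLLEDLDataset_collate_func` above maps each entity id that occurs in the batch to a batch-local column index and marks that column for every token. The sketch below reproduces only that remapping, using plain Python lists. The helper `one_hot_targets` and the toy ids are hypothetical; the real function works on tensors and appends negative samples as extra columns.

```python
from collections import OrderedDict


def one_hot_targets(batch_entity_ids):
    """Map entity ids to batch-local columns and build one one-hot row per token.

    Hypothetical helper for illustration; uses nested lists instead of tensors.
    """
    # Assign column indices in order of first appearance.
    entity_to_col = OrderedDict()
    for item in batch_entity_ids:
        for eid in item:
            if eid not in entity_to_col:
                entity_to_col[eid] = len(entity_to_col)

    num_cols = len(entity_to_col)
    targets = []
    for item in batch_entity_ids:
        rows = []
        for eid in item:
            row = [0.0] * num_cols
            row[entity_to_col[eid]] = 1.0
            rows.append(row)
        targets.append(rows)

    # Inverse mapping, needed to turn predicted columns back into entity ids.
    col_to_entity = {col: eid for eid, col in entity_to_col.items()}
    return targets, col_to_entity


OUTSIDE = 999  # stand-in for vocab.OUTSIDE_ID (assumed value)
targets, col_to_entity = one_hot_targets([[OUTSIDE, 12, 12], [OUTSIDE, 5, OUTSIDE]])
print(col_to_entity)  # → {0: 999, 1: 12, 2: 5}
print(targets[1][1])  # → [0.0, 0.0, 1.0]
```

The inverse dictionary plays the role of the `{v: k for k, v in all_batch_entity_ids.items()}` value that the collate function returns alongside the targets.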

In [295]:
import logging
import os
import re

import numpy
import torch
import torch.nn as nn
from torch import optim
from tqdm import trange, tqdm

# from metrics import Metrics
# from data_loader_conll import CONLLEDLDataset_collate_func
# from misc import (
#     running_mean,
#     get_entity_annotations,
#     get_entity_annotations_with_gold_spans,
#     DummyOptimizer,
#     LRSchedulers,
#     create_overlapping_chunks,
#     get_topk_ids_aggregated_from_seq_prediction,
# )
# !pip install transformers
from transformers import BertModel

bio_id = {
    "B": 0,
    "I": 1,
    "O": 2,
}
bio_id_inv = {
    0: "B",
    1: "I",
    2: "O",
}


class ConllNet(nn.Module):
    def __init__(
        self, args, vocab_size=None,
    ):
        super().__init__()
        if args.uncased:
            self.bert = BertModel.from_pretrained("bert-base-uncased")
        else:
            self.bert = BertModel.from_pretrained("bert-base-cased")

        self.top_rnns = args.top_rnns
        if args.top_rnns:
            self.rnn = nn.LSTM(bidirectional=True, num_layers=2, input_size=768, hidden_size=768 // 2, batch_first=True)
        self.fc = None
        if args.project:
            self.fc = nn.Linear(768, args.entity_embedding_size)
        self.out = nn.Embedding(num_embeddings=vocab_size, embedding_dim=args.entity_embedding_size, sparse=args.sparse)

        self.out_segm = nn.Sequential(nn.Dropout(args.bert_dropout), nn.Linear(768, 768), nn.Tanh(), nn.Linear(768, 3),)

        # torch.nn.init.normal_(self.out, std=0.1)

        if args.bert_dropout and args.bert_dropout > 0:
            for m in self.bert.modules():
                if isinstance(m, torch.nn.Dropout):
                    m.p = args.bert_dropout

        self.device = args.device
        self.out_device = args.out_device
        self.finetuning = args.finetuning
        self.vocab_size = vocab_size

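    def to(self, device, out_device):
        # Encoder-side modules go to `device`; the entity embedding goes to `out_device`.
        self.bert.to(device)
        if self.top_rnns:
            self.rnn.to(device)
        if self.fc is not None:
            self.fc.to(device)
        self.out.to(out_device)
        self.out_segm.to(device)
        self.device = device
        self.out_device = out_device
        return self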

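    def forward(self, x, y=None, probs=None, segm_probs=None, enc=None):
        """
        x: (N, T) int64 token ids.
        y: (L,) int64 candidate entity ids to score; if None, all VOCAB entities are scored.
        probs, segm_probs: optional targets, moved to the output device and returned.
        enc: optional precomputed encoder output; if given, BERT is not run again.
        Returns
        (logits, y, y_hat, probs, segm_probs, out, enc, logits_segm, bio_y_hat),
        with logits of shape (N, T, L), or (N, T, VOCAB) when y is None.
        """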
        if y is not None:
            y = y.to(self.out_device)
        if probs is not None:
            probs = probs.to(self.out_device)
        if segm_probs is not None:
            segm_probs = segm_probs.to(self.out_device)

            # fake_y = torch.Tensor(range(10)).long().to(self.device)

        if enc is None:
            x = x.to(self.device)
            if self.training:
                if self.finetuning:
                    # print("->bert.train()")
                    self.bert.train()
                    # Index 0 is the last hidden state for both tuple and ModelOutput returns.
                    enc = self.bert(x)[0]
                else:
                    self.bert.eval()
                    with torch.no_grad():
                        enc = self.bert(x)[0]
            else:
                enc = self.bert(x)[0]

            if self.top_rnns:
                enc, _ = self.rnn(enc)

            if self.fc:
                enc = self.fc(enc)

            enc = enc.to(self.out_device)

        logits_segm = self.out_segm(enc)

        if y is not None:
            out = self.out(y)
            logits = enc.matmul(out.transpose(0, 1))
            y_hat = logits.argmax(-1)
            bio_y_hat = logits_segm.argmax(-1)
            return logits, y, y_hat, probs, segm_probs, out, enc, logits_segm, bio_y_hat
        else:
            with torch.no_grad():
                out = self.out.weight
                logits = enc.matmul(out.transpose(0, 1))
                y_hat = logits.argmax(-1)
                bio_y_hat = logits_segm.argmax(-1)
                return logits, None, y_hat, None, None, None, enc, None, bio_y_hat

    @staticmethod
    def train_one_epoch(
        args,
        model,
        train_iter,
        optimizers,
        criterion,
        eval_iter,
        vocab,
        epoch,
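        metrics=None,
        loss_aggr=None,
    ):
        # Build the default per call; a `Metrics()` default argument is created
        # once at definition time and shared by every call.
        if metrics is None:
            metrics = Metrics()
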
        with trange(len(train_iter)) as t:
            for iter, batch in enumerate(train_iter):

                model.to(
                    args.device, args.out_device,
                )
                model.train()

                (
                    batch_token_ids,
                    label_ids,
                    _,
                    label_probs,
                    batch_bio_probs,
                    _,
                    label_id_to_entity_id_dict,
                    batch_entity_ids,
                    batch_doc_ids,
                    orig_batch,
                ) = batch

                enc = None

                labels_with_high_model_score = list()
                if (
                    args.collect_most_popular_labels_steps is not None
                    and args.collect_most_popular_labels_steps > 0
                    and iter > 0
                    and iter % args.collect_most_popular_labels_steps == 0
                ):
                    model.to(args.device, args.eval_device)
                    logits, _, y_hat, _, _, _, enc, segm_logits, segm_pred = model(
                        batch_token_ids, None, None, batch_bio_probs
                    )  # logits: (N, T, VOCAB), y: (N, T)
                    labels_with_high_model_score = get_topk_ids_aggregated_from_seq_prediction(
                        logits, topk_from_batch=args.label_size, topk_per_token=args.topk_neg_examples
                    )
                    (
                        batch_token_ids,
                        label_ids,
                        _,
                        label_probs,
                        batch_bio_probs,
                        _,
                        label_id_to_entity_id_dict,
                        batch_entity_ids,
                        batch_doc_ids,
                        orig_batch,
                    ) = CONLLEDLDataset_collate_func(
                        args=args,
                        labels_with_high_model_score=labels_with_high_model_score,
                        batch=orig_batch,
                        return_labels=True,
                        vocab=vocab,
                    )

                # if args.label_size is not None:
                logits, y, y_hat, label_probs, batch_bio_probs, sparse_params, _, segm_logits, segm_pred = model(
                    batch_token_ids, label_ids, label_probs, batch_bio_probs, enc=enc
                )  # logits: (N, T, VOCAB), y: (N, T)
                # else:
                #     logits, y, y_hat, label_probs, sparse_params = model(batch_token_ids, None, label_probs) # logits: (N, T, VOCAB), y: (N, T)

                # logits = logits.view(-1, logits.shape[-1]) # (N*T, VOCAB)
                logits = logits.view(-1)  # flattened to (N*T*L,)
                segm_logits = segm_logits.view(-1)  # flattened to (N*T*3,)
                label_probs = label_probs.view(-1)  # same length as logits
                batch_bio_probs = batch_bio_probs.view(-1)  # same length as segm_logits

                task_importance_ratio = 0.1

                if args.learn_segmentation:
                    loss = (1 - task_importance_ratio) * criterion(
                        logits, label_probs
                    ) + task_importance_ratio * criterion(segm_logits, batch_bio_probs)
                else:
                    loss = criterion(logits, label_probs)

                loss.backward()

                if (iter + 1) % args.accumulate_batch_gradients == 0:
                    for optimizer in optimizers:
                        optimizer.step()
                        optimizer.zero_grad()

                if iter == 0:
                    logging.debug("=====sanity check======")
                    logging.debug("x: %s", batch_token_ids.cpu().numpy()[0])
                    logging.debug("tokens: %s", vocab.tokenizer.convert_ids_to_tokens(batch_token_ids.cpu().numpy()[0]))
                    logging.debug("y: %s", label_probs.cpu().numpy()[0])
                    logging.debug("=======================")

                loss_aggr = running_mean(loss.detach().item(), loss_aggr)

                if iter > 0 and iter % args.checkpoint_eval_steps == 0:
                    metrics = ConllNet.evaluate(
                        args=args,
                        model=model,
                        iterator=eval_iter,
                        optimizers=optimizers,
                        step=iter,
                        epoch=epoch,
                        save_checkpoint=iter % args.checkpoint_save_steps == 0,
                        sampled_evaluation=False,
                        metrics=metrics,
                        vocab=vocab,
                    )

                t.set_postfix(
                    loss=loss_aggr,
                    # nr_labels=len(label_ids),
                    # aggr_labels=len(labels_with_high_model_score) if labels_with_high_model_score else 0,
                    last_eval=metrics.report(filter={"f1", "span_f1", "lenient_span_f1", "epoch", "step"}),
                )
                t.update()

        for optimizer in optimizers:
            optimizer.step()
            optimizer.zero_grad()

        return metrics

    @staticmethod
    def evaluate(
        args,
        model,
        iterator,
        vocab,
        optimizers,
        step=0,
        epoch=0,
        save_checkpoint=True,
        save_predictions=True,
        save_csv=True,
        sampled_evaluation=False,
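        metrics=None,
    ):
        # Build the default per call; a `Metrics()` default argument is created
        # once at definition time and shared by every call.
        if metrics is None:
            metrics = Metrics()
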
        logging.info(f"Start evaluation on split {'test' if args.eval_on_test_only else 'valid'}")

        model.eval()
        model.to(args.device, args.eval_device)

        chunk_len = args.create_integerized_training_instance_text_length
        chunk_overlap = args.create_integerized_training_instance_text_overlap

        all_words = list()
        all_tags = [0] * (len(iterator) * args.eval_batch_size * (chunk_len))
        all_y = [0] * (len(iterator) * args.eval_batch_size * (chunk_len))
        all_y_hat = [0] * (len(iterator) * args.eval_batch_size * (chunk_len))
        all_segm_preds = [0] * (len(iterator) * args.eval_batch_size * (chunk_len))
        all_y_hat_gold_mentions = [0] * (len(iterator) * args.eval_batch_size * (chunk_len))
        all_logits = [0] * (len(iterator) * args.eval_batch_size * (chunk_len))
        all_predicted = [0] * (len(iterator) * args.eval_batch_size * (chunk_len))
        all_token_ids = [0] * (len(iterator) * args.eval_batch_size * (chunk_len))

        all_y_hat_scores = torch.ones(len(iterator) * args.eval_batch_size * (chunk_len)) * -1e10
        all_y_hat_gold_mentions_scores = torch.ones(len(iterator) * args.eval_batch_size * (chunk_len)) * -1e10

        best_scores = torch.ones(len(iterator) * args.eval_batch_size * (chunk_len)) * -1e10

        offset = 0
        last_doc = -1

        # new_best_top1_logit = torch.ones((chunk_len))
        # new_best_top2_logit_gold_mentions = torch.ones((chunk_len))

        with torch.no_grad():

            for iter, batch in enumerate(tqdm(iterator)):

                (
                    batch_token_ids,
                    label_ids,
                    batch_bio_ids,
                    label_probs,
                    batch_bio_probs,
                    _,
                    label_id_to_entity_id_dict,
                    batch_entity_ids,
                    batch_doc_ids,
                    orig_batch,
                ) = batch
                eval_mask = batch_bio_ids == 2
                logits, y, y_hat, probs, batch_bio_probs, _, _, segm_logits, segm_preds = model(
                    batch_token_ids, None, label_probs, batch_bio_probs
                )  # logits: (N, T, VOCAB), y: (N, T)

                logits = logits[:, 1:-1, :]
                label_probs = label_probs[:, 1:-1, :]
                y_hat = y_hat[:, 1:-1]
                segm_preds = segm_preds[:, 1:-1]
                eval_mask = eval_mask[:, 1:-1]
                batch_token_ids = batch_token_ids[:, 1:-1]

                top2_logit, top2 = logits.topk(k=2, dim=-1,)
                top2_select = (y_hat >= vocab.OUTSIDE_ID).long()

                y_hat_gold_mentions = (
                    top2.view(-1, top2.size(-1)).gather(dim=1, index=top2_select.view(-1, 1)).view(y_hat.size())
                )
                top2_logit_gold_mentions = (
                    top2_logit.view(-1, top2_logit.size(-1))
                    .gather(dim=1, index=top2_select.view(-1, 1))
                    .view(y_hat.size())
                    .to("cpu")
                )

                top1_logit, _ = logits.to("cpu").max(dim=-1,)
                top1_probs = torch.sigmoid(top1_logit)

                for batch_id, seq in enumerate(label_probs.max(-1)[1]):

                    if last_doc >= 0:
                        if last_doc == batch_doc_ids[batch_id]:
                            next_step = chunk_len - chunk_overlap
                        else:
                            last_doc = batch_doc_ids[batch_id]
                            next_step = chunk_len

                        offset += next_step
                    else:
                        last_doc = batch_doc_ids[batch_id]

                    new_best_top1_logit = top1_logit[batch_id] > best_scores[offset : offset + chunk_len]
                    new_best_top2_logit_gold_mentions = (
                        top2_logit_gold_mentions[batch_id] > best_scores[offset : offset + chunk_len]
                    )

                    all_y_hat_scores[offset : offset + chunk_len] = (
                        new_best_top1_logit.float() * top1_logit[batch_id]
                        + (1.0 - new_best_top1_logit.float()) * all_y_hat_scores[offset : offset + chunk_len]
                    )
                    all_y_hat_gold_mentions_scores[offset : offset + chunk_len] = (
                        new_best_top2_logit_gold_mentions.float() * top2_logit_gold_mentions[batch_id]
                        + (1.0 - new_best_top2_logit_gold_mentions.float())
                        * all_y_hat_gold_mentions_scores[offset : offset + chunk_len]
                    )

                    for tok_id, label_id in enumerate(seq):

                        y_resolved = (
                            label_ids[label_id].item() if eval_mask[batch_id][tok_id] == 0 else vocab.OUTSIDE_ID
                        )
                        all_y[offset + tok_id] = y_resolved
                        all_tags[offset + tok_id] = vocab.idx2tag[y_resolved]

                        y_hat_resolved = (
                            new_best_top1_logit[tok_id].item() * y_hat[batch_id][tok_id].item()
                            + (1 - new_best_top1_logit[tok_id].item()) * all_y_hat[offset + tok_id]
                        )
                        all_y_hat[offset + tok_id] = y_hat_resolved
                        all_predicted[offset + tok_id] = vocab.idx2tag[y_hat_resolved]

                        y_hat_gold_mentions_resolved = (
                            new_best_top2_logit_gold_mentions[tok_id].item()
                            * y_hat_gold_mentions[batch_id][tok_id].item()
                            + (1 - new_best_top2_logit_gold_mentions[tok_id].item())
                            * all_y_hat_gold_mentions[offset + tok_id]
                        )
                        all_y_hat_gold_mentions[offset + tok_id] = y_hat_gold_mentions_resolved

                        all_segm_preds[offset + tok_id] = segm_preds[batch_id][tok_id].item()
                        all_token_ids[offset + tok_id] = batch_token_ids[batch_id][tok_id].item()
                        all_logits[offset + tok_id] = top1_probs[batch_id][tok_id].item()

        all_tags = all_tags[: offset + chunk_len]
        all_y = all_y[: offset + chunk_len]
        all_y_hat = all_y_hat[: offset + chunk_len]
        all_y_hat_gold_mentions = all_y_hat_gold_mentions[: offset + chunk_len]
        all_logits = all_logits[: offset + chunk_len]
        all_predicted = all_predicted[: offset + chunk_len]
        all_token_ids = all_token_ids[: offset + chunk_len]
        all_segm_preds = all_segm_preds[: offset + chunk_len]

        for chunk in create_overlapping_chunks(all_token_ids, 512, 0):
            all_words.extend(vocab.tokenizer.convert_ids_to_tokens(chunk))

        ## calc metric
        y_true = numpy.array(all_y)
        y_pred = numpy.array(all_y_hat)
        y_pred_gold_mentions = numpy.array(all_y_hat_gold_mentions)
        all_token_ids = numpy.array(all_token_ids)

        spans_true = get_entity_annotations(y_true, vocab.OUTSIDE_ID)
        spans_pred = get_entity_annotations(y_pred, vocab.OUTSIDE_ID)
        spans_pred_gold_mentions = get_entity_annotations_with_gold_spans(
            y_pred_gold_mentions, y_true, vocab.OUTSIDE_ID
        )

        overlaps = list()
        for anno in spans_pred:
            overlaps.extend(filter(lambda s: len(set(anno[0]) & set(s[0])) > 0 and anno[1] == s[1], spans_true))

        overlaps_gold_mentions = list()
        for anno in spans_pred_gold_mentions:
            overlaps_gold_mentions.extend(
                filter(lambda s: len(set(anno[0]) & set(s[0])) > 0 and anno[1] == s[1], spans_true)
            )
        num_lenient_correct_gold_mentions = len(set(overlaps_gold_mentions))

        num_proposed = len(y_pred[(vocab.OUTSIDE_ID > y_pred) & (all_token_ids > 0)])
        # `numpy.int` was removed in NumPy 1.24; the built-in `int` is equivalent here.
        num_correct = (((y_true == y_pred) & (vocab.OUTSIDE_ID > y_true) & (all_token_ids > 0))).astype(int).sum()

        num_correct_gold_mentions = (
            (((y_true == y_pred_gold_mentions) & (vocab.OUTSIDE_ID > y_true) & (all_token_ids > 0)))
            .astype(int)
            .sum()
        )

        num_gold = len(y_true[(vocab.OUTSIDE_ID > y_true) & (all_token_ids > 0)])

        num_spans_correct = len(set(spans_true).intersection(set(spans_pred)))
        num_spans_true = len(set(spans_true))
        num_spans_proposed = len(set(spans_pred))

        num_lenient_correct_spans = len(set(overlaps))

        new_metrics = Metrics(
            epoch=epoch,
            step=step,
            num_correct=num_correct,
            num_gold=num_gold,
            num_proposed=num_proposed,
            # in this setting all gold mentions are scored which is why num_gold == num_proposed
            precision_gold_mentions=Metrics.compute_precision(correct=num_correct_gold_mentions, proposed=num_gold),
            span_precision=Metrics.compute_precision(correct=num_spans_correct, proposed=num_spans_proposed),
            span_recall=Metrics.compute_recall(correct=num_spans_correct, gold=num_spans_true),
            span_f1=Metrics.compute_fmeasure(
                precision=Metrics.compute_precision(correct=num_spans_correct, proposed=num_spans_proposed),
                recall=Metrics.compute_recall(correct=num_spans_correct, gold=num_spans_true),
            ),
            lenient_span_precision=Metrics.compute_precision(
                correct=num_lenient_correct_spans, proposed=num_spans_proposed
            ),
            lenient_span_recall=Metrics.compute_recall(correct=num_lenient_correct_spans, gold=num_spans_true),
            lenient_span_f1=Metrics.compute_fmeasure(
                precision=Metrics.compute_precision(correct=num_lenient_correct_spans, proposed=num_spans_proposed),
                recall=Metrics.compute_recall(correct=num_lenient_correct_spans, gold=num_spans_true),
            ),
        )

        if save_predictions:
            final = (
                args.logdir
                + "/{}-{}-MENTION-S_P_{:.2f}_R_{:.2f}_F1_{:.2f}-LS_P_{:.2f}_R_{:.2f}_F1_{:.2f}-T_P_{:.2f}_R_{:.2f}_F1_{:.2f}-LINK-S_P_{:.2f}.txt".format(
                    epoch,
                    step,
                    new_metrics.span_precision,
                    new_metrics.span_recall,
                    new_metrics.span_f1,
                    new_metrics.lenient_span_precision,
                    new_metrics.lenient_span_recall,
                    new_metrics.lenient_span_f1,
                    new_metrics.precision,
                    new_metrics.recall,
                    new_metrics.f1,
                    new_metrics.precision_gold_mentions,
                )
            )
            with open(final, "w") as fout:

                for words, tags, y_hat, preds, segm_pred, logits in zip(
                    all_words, all_tags, all_y_hat, all_predicted, all_segm_preds, all_logits
                ):
                    fout.write(f"{words}\t{tags}\t{preds}\t{bio_id_inv[segm_pred]}\t{logits}\n")

                fout.write(f"num_proposed:{new_metrics.num_proposed}\n")
                fout.write(f"num_correct:{new_metrics.num_correct}\n")
                fout.write(f"num_gold:{new_metrics.num_gold}\n")
                fout.write(f"precision={new_metrics.precision}\n")
                fout.write(f"precision_gold_mentions={new_metrics.precision_gold_mentions}\n")
                fout.write(f"recall={new_metrics.recall}\n")
                fout.write(f"f1={new_metrics.f1}\n")

        if not args.dont_save_checkpoints:
            if save_checkpoint or metrics.was_improved(new_metrics):
                config = {
                    "args": args,
                    "optimizer_dense": optimizers[0].state_dict() if optimizers else None,
                    "optimizer_sparse": optimizers[1].state_dict() if optimizers else None,
                    "model": model.state_dict(),
                    "epoch": epoch,
                    "step": step,
                    "performance": new_metrics.dict(),
                }
                fname = os.path.join(args.logdir, "{}-{}".format(str(epoch), str(step)))
                torch.save(config, f"{fname}.pt")
                fname = os.path.join(args.logdir, new_metrics.get_best_checkpoint_filename())
                torch.save(config, f"{fname}.pt")
                logging.info(f"weights were saved to {fname}.pt")

        if save_csv:
            new_metrics.to_csv(epoch=epoch, step=step, args=args)

        if metrics.was_improved(new_metrics):
            metrics.update(new_metrics)

        logging.info("Finished evaluation")

        return metrics

    def get_optimizers(self, args, checkpoint):

        optimizers = list()

        if args.encoder_lr > 0:
            if args.exclude_parameter_names_regex is not None:
                regex = re.compile(args.exclude_parameter_names_regex)
                # Keep only the parameters whose names do not match the exclusion pattern.
                bert_parameters = [p for n, p in self.bert.named_parameters() if regex.search(n) is None]
            else:
                bert_parameters = list(self.bert.parameters())
            optimizer_encoder = optim.Adam(
                bert_parameters + list(self.fc.parameters() if args.project else list()), lr=args.encoder_lr
            )
            # optimizer_encoder = BertAdam(bert_parameters + list(self.fc.parameters() if args.project else list()),
            #                      lr=args.encoder_lr,
            # )

            if args.resume_optimizer_from_checkpoint:
                optimizer_encoder.load_state_dict(checkpoint["optimizer_dense"])
                optimizer_encoder.param_groups[0]["lr"] = args.encoder_lr
                optimizer_encoder.param_groups[0]["weight_decay"] = args.encoder_weight_decay
            optimizers.append(optimizer_encoder)
        else:
            optimizers.append(DummyOptimizer(self.out.parameters(), defaults={}))

        if args.decoder_lr > 0:
            if args.sparse:
                optimizer_decoder = optim.SparseAdam(self.out.parameters(), lr=args.decoder_lr)
            else:
                optimizer_decoder = optim.Adam(self.out.parameters(), lr=args.decoder_lr)
            if args.resume_from_checkpoint is not None:
                optimizer_decoder.load_state_dict(checkpoint["optimizer_sparse"])
                if "weight_decay" not in optimizer_decoder.param_groups[0]:
                    optimizer_decoder.param_groups[0]["weight_decay"] = 0
                optimizer_decoder.param_groups[0]["lr"] = args.decoder_lr
                if not args.sparse:
                    optimizer_decoder.param_groups[0]["weight_decay"] = args.decoder_weight_decay
            optimizers.append(optimizer_decoder)
        else:
            optimizers.append(DummyOptimizer(self.out.parameters(), defaults={}))

        if args.segm_decoder_lr > 0:
            optimizer_segm_decoder = optim.Adam(self.out_segm.parameters(), lr=args.segm_decoder_lr)
            if args.resume_optimizer_from_checkpoint:
                optimizer_segm_decoder.param_groups[0]["lr"] = args.segm_decoder_lr
                optimizer_segm_decoder.param_groups[0]["weight_decay"] = args.segm_decoder_weight_decay
            optimizers.append(optimizer_segm_decoder)
        else:
            optimizers.append(DummyOptimizer(self.out.parameters(), defaults={}))

        lr_schedulers = [
            getattr(LRSchedulers, lr_scheduler)(optimizer=optimizer, **lr_scheduler_config)
            for optimizer, (lr_scheduler, lr_scheduler_config) in zip(
                optimizers,
                [
                    (args.encoder_lr_scheduler, args.encoder_lr_scheduler_config),
                    (args.decoder_lr_scheduler, args.decoder_lr_scheduler_config),
                    (args.segm_decoder_lr_scheduler, args.segm_decoder_lr_scheduler_config),
                ],
            )
            if lr_scheduler is not None  # and not isinstance(optimizer, DummyOptimizer)
        ]

        return tuple(optimizers), tuple(lr_schedulers)
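
The `exclude_parameter_names_regex` branch of `get_optimizers` keeps only the encoder parameters whose names do not match the pattern. The sketch below isolates that filter, using plain name strings in place of `(name, parameter)` pairs. The helper name `filter_parameter_names`, the example parameter names, and the pattern are all assumptions made for illustration; none of them comes from the original configuration.

```python
import re


def filter_parameter_names(names, exclude_regex=None):
    """Return the names that do not match exclude_regex (all names if it is None)."""
    if exclude_regex is None:
        return list(names)
    pattern = re.compile(exclude_regex)
    # search() matches anywhere in the name, which mirrors the findall() test in get_optimizers.
    return [name for name in names if pattern.search(name) is None]


# Hypothetical parameter names, chosen only for illustration.
example_names = [
    "embeddings.word_embeddings.weight",
    "encoder.layer.0.attention.self.query.weight",
    "encoder.layer.11.output.dense.bias",
]
# Hypothetical pattern: exclude the embeddings and the first encoder layer.
trainable = filter_parameter_names(example_names, r"embeddings|layer\.0\.")
print(trainable)
```

Escaping the dots in the pattern matters here: an unescaped `layer.0` would also match a name such as `layer10`.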


In [296]:


class CONLLEDLDataset(data.Dataset):
    def __init__(self, args, split, vocab, device, label_size=None):

        if split == "train":
            train_valid_test_int = 0
        elif split in ("small_valid", "valid"):
            train_valid_test_int = 1
        elif split == "test":
            train_valid_test_int = 2
        else:
            raise ValueError(f"unknown split: {split!r}")

        chunk_len = args.create_integerized_training_instance_text_length
        chunk_overlap = args.create_integerized_training_instance_text_overlap

        self.item_locs = None
        self.device = device
        with open(args.data_path_conll, "rb") as f:
            train_valid_test = pickle.load(f)
            self.conll_docs = torch.LongTensor(
                [
                    [
                        pad_to(
                            [tok_id for _, tok_id, _, _, _, _, _ in doc],
                            max_len=chunk_len + 2,
                            pad_id=0,
                            cls_id=101,
                            sep_id=102,
                        )
                        for doc in train_valid_test[train_valid_test_int]
                    ],
                    [
                        pad_to(
                            [bio_id for _, _, _, bio_id, _, _, _ in doc],
                            max_len=chunk_len + 2,
                            pad_id=2,
                            cls_id=2,
                            sep_id=2,
                        )
                        for doc in train_valid_test[train_valid_test_int]
                    ],
                    [
                        pad_to(
                            [wiki_id for _, _, _, _, _, wiki_id, _ in doc],
                            max_len=chunk_len + 2,
                            pad_id=vocab.PAD_ID,
                            cls_id=vocab.PAD_ID,
                            sep_id=vocab.PAD_ID,
                        )
                        for doc in train_valid_test[train_valid_test_int]
                    ],
                    [
                        pad_to(
                            [doc_id for _, _, _, _, _, _, doc_id in doc],
                            max_len=chunk_len + 2,
                            pad_id=0,
                            cls_id=0,
                            sep_id=0,
                        )
                        for doc in train_valid_test[train_valid_test_int]
                    ],
                ]
            ).permute(1, 0, 2)
            self.conll_docs[:, 2] = set_out_id(self.conll_docs[:, 2], vocab.OUTSIDE_ID)

        self.pad_token_id = vocab.PAD_ID
        self.label_size = label_size
        self.labels = None
        self.train_valid_test_int = train_valid_test_int

    def get_data_iter(
        self, args, batch_size, vocab, train,
    ):
        return data.DataLoader(
            dataset=self.conll_docs,
            batch_size=batch_size,
            shuffle=train,
            num_workers=args.data_workers,
            collate_fn=self.collate_func(
                args,
                return_labels=args.collect_most_popular_labels_steps is not None
                and args.collect_most_popular_labels_steps > 0
                if train
                else True,
                vocab=vocab,
            ),
        )

    def collate_func(self, args, vocab, return_labels):
        def collate(batch):
            return CONLLEDLDataset_collate_func(
                batch=batch,
                labels_with_high_model_score=self.labels,
                args=args,
                return_labels=return_labels,
                vocab=vocab,
                is_training=self.train_valid_test_int == 0,
            )

        return collate


def CONLLEDLDataset_collate_func(
    batch, labels_with_high_model_score, args, return_labels, vocab: Vocab, is_training=False,
):
    drop_entity_mentions_prob = args.maskout_entity_prob
    # print([b[0] for b in batch])
    label_size = args.label_size
    batch_token_ids = torch.LongTensor([b[0].tolist() for b in batch])
    batch_bio_ids = [b[1].tolist() for b in batch]
    batch_entity_ids = [b[2].tolist() for b in batch]
    batch_doc_ids = [b[3, 0].item() for b in batch]

    if return_labels:

        all_batch_entity_ids = OrderedDict()

        for batch_offset, one_item_entity_ids in enumerate(batch_entity_ids):
            for tok_id, eid in enumerate(one_item_entity_ids):
                if eid not in all_batch_entity_ids:
                    all_batch_entity_ids[eid] = len(all_batch_entity_ids)

        if label_size is not None:

            batch_shared_label_ids = all_batch_entity_ids.keys()
            negative_samples = set()
            if labels_with_high_model_score is not None:
                # print(labels_with_high_model_score)
                negative_samples = set(labels_with_high_model_score)
            # else:
            #     negative_samples = set(numpy.random.choice(vocab.OUTSIDE_ID, label_size, replace=False))
            if len(negative_samples) < label_size:
                random_negative_samples = set(numpy.random.choice(vocab.OUTSIDE_ID, label_size, replace=False))
                negative_samples = negative_samples.union(random_negative_samples)

            negative_samples.difference_update(batch_shared_label_ids)

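            if len(batch_shared_label_ids) + len(negative_samples) < label_size:
                # Top up with further random negatives, skipping labels already present in the batch.
                negative_samples.update(
                    set(numpy.random.choice(vocab.OUTSIDE_ID, label_size, replace=False)).difference(
                        batch_shared_label_ids
                    )
                )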

            batch_shared_label_ids = (list(batch_shared_label_ids) + list(negative_samples))[:label_size]
            label_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), len(batch_shared_label_ids))
            bio_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), 3)

        else:

            batch_shared_label_ids = list(all_batch_entity_ids.keys())
            label_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), args.vocab_size)
            bio_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), 3)

        drop_probs = None
        if drop_entity_mentions_prob > 0 and is_training:
            drop_probs = torch.rand((batch_token_ids.size(0), batch_token_ids.size(1)),) < drop_entity_mentions_prob

        for batch_offset, (one_item_entity_ids, one_item_bio_ids) in enumerate(zip(batch_entity_ids, batch_bio_ids)):
            for tok_id, one_entity_ids in enumerate(one_item_entity_ids):

                if (
                    is_training
                    and vocab.OUTSIDE_ID != one_entity_ids
                    and drop_entity_mentions_prob > 0
                    and drop_probs[batch_offset][tok_id].item() == 1
                ):
                    batch_token_ids[batch_offset][tok_id] = vocab.tokenizer.vocab["[MASK]"]

                if label_size is not None:
                    # Index into the batch-local label set.
                    label_probs[batch_offset][tok_id][all_batch_entity_ids[one_entity_ids]] = 1.0
                else:
                    # Index into the full entity vocabulary.
                    label_probs[batch_offset][tok_id][one_entity_ids] = 1.0
                # One-hot BIO target for this token only.
                bio_probs[batch_offset][tok_id][one_item_bio_ids[tok_id]] = 1.0

        label_ids = torch.LongTensor(batch_shared_label_ids)

        return (
            batch_token_ids,
            label_ids,
            torch.LongTensor(batch_bio_ids),
            torch.FloatTensor(label_probs),
            bio_probs,
            None,
            {v: k for k, v in all_batch_entity_ids.items()},
            batch_entity_ids,
            batch_doc_ids,
            batch,
        )

    else:

        return batch_token_ids, None, None, None, None, None, None, None, None, batch
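
The collate function builds one shared label set per batch: the entity IDs that occur in the batch come first, and negative samples fill the remaining slots up to `label_size`. The sketch below shows that idea with the standard library only, assuming this reading of the code is the intended behavior. The function name `build_batch_label_ids` and the `num_entities` and `hard_negatives` parameters are hypothetical, and it uses `random.sample` in place of `numpy.random.choice`.

```python
import random


def build_batch_label_ids(batch_entity_ids, label_size, num_entities, hard_negatives=(), seed=None):
    """Return up to label_size label IDs: batch (gold) labels first, then negatives."""
    rng = random.Random(seed)
    # Gold labels: unique entity IDs in first-seen order, like the OrderedDict in the collate function.
    gold = list(dict.fromkeys(eid for item in batch_entity_ids for eid in item))
    gold_set = set(gold)
    # Start from the supplied hard negatives and drop any that are gold in this batch.
    negatives = [eid for eid in dict.fromkeys(hard_negatives) if eid not in gold_set]
    # Top up with random IDs until the label set is full.
    needed = label_size - len(gold) - len(negatives)
    if needed > 0:
        taken = gold_set | set(negatives)
        pool = [eid for eid in range(num_entities) if eid not in taken]
        negatives.extend(rng.sample(pool, min(needed, len(pool))))
    return (gold + negatives)[:label_size]


# Toy batch with two items; entity IDs 7, 3 and 9 are the gold labels.
labels = build_batch_label_ids(
    [[7, 7, 3], [3, 9]], label_size=6, num_entities=20, hard_negatives=[9, 11], seed=0
)
print(labels)
```

Building the candidate pool with a full `range(num_entities)` scan is fine for a toy example; for a large entity vocabulary, rejection sampling would likely be cheaper.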

In [297]:
import copy
import logging
import os
import time

import torch.cuda
import torch.nn as nn

# # from metrics import Metrics
# from data_loader_conll import CONLLEDLDataset
# from data_loader_wiki import EDLDataset
# from model import Net
# from model_conll import ConllNet
# from train_util import get_args
# from vocab import Vocab


class Datasets:
    EDLDataset = EDLDataset
    CONLLEDLDataset = CONLLEDLDataset


class Models:
    Net = Net
    ConllNet = ConllNet


In [298]:
import ast
import os
!pip install configargparse
import configargparse as argparse
import torch.cuda
import yaml
parser = argparse.ArgumentParser()
print(parser)
parser.add_argument("-c", "--config", is_config_file=True, help="config file path")
parser.add_argument("--debug", type=argparse_bool_type, default=False)
parser.add_argument("--device", default=0)
parser.add_argument("--eval_device", default=None)
parser.add_argument("--dataset", default="EDLDataset")
parser.add_argument("--model", default="Net")
parser.add_argument("--data_version_name")
parser.add_argument("--wiki_lang_version")
parser.add_argument("--eval_on_test_only", type=argparse_bool_type, default=False)
parser.add_argument("--out_device", default=None)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--eval_batch_size", type=int, default=128)
parser.add_argument("--accumulate_batch_gradients", type=int, default=1)
parser.add_argument("--sparse", dest="sparse", type=argparse_bool_type)
parser.add_argument("--encoder_lr", type=float, default=5e-5)
parser.add_argument("--decoder_lr", type=float, default=1e-3)
parser.add_argument("--maskout_entity_prob", type=float, default=0)
parser.add_argument("--segm_decoder_lr", type=float, default=1e-3)
parser.add_argument("--encoder_weight_decay", type=float, default=0)
parser.add_argument("--decoder_weight_decay", type=float, default=0)
parser.add_argument("--segm_decoder_weight_decay", type=float, default=0)
parser.add_argument("--learn_segmentation", type=argparse_bool_type, default=False)
parser.add_argument("--label_size", type=int)
parser.add_argument("--vocab_size", type=int)
parser.add_argument("--entity_embedding_size", type=int, default=768)
parser.add_argument("--project", type=argparse_bool_type, default=False)
parser.add_argument("--n_epochs", type=int, default=1000)
parser.add_argument("--collect_most_popular_labels_steps", type=int, default=100)
parser.add_argument("--checkpoint_eval_steps", type=int, default=1000)
parser.add_argument("--checkpoint_save_steps", type=int, default=50000)
parser.add_argument("--finetuning", dest="finetuning", type=int, default=9999999999)
parser.add_argument("--top_rnns", dest="top_rnns", type=argparse_bool_type)
parser.add_argument("--logdir", type=str)
parser.add_argument("--train_loc_file", type=str, default="train.loc")
parser.add_argument("--valid_loc_file", type=str, default="valid.loc")
parser.add_argument("--test_loc_file", type=str, default="test.loc")
parser.add_argument("--resume_from_checkpoint", type=str)
parser.add_argument("--resume_reset_epoch", type=argparse_bool_type, default=False)
parser.add_argument("--resume_optimizer_from_checkpoint", type=argparse_bool_type, default=False)
parser.add_argument("--topk_neg_examples", type=int, default=3)
parser.add_argument("--dont_save_checkpoints", type=argparse_bool_type, default=False)
parser.add_argument("--data_workers", type=int, default=8)
parser.add_argument("--bert_dropout", type=float, default=None)
parser.add_argument("--encoder_lr_scheduler", type=str, default=None)
parser.add_argument("--encoder_lr_scheduler_config", default=None)
parser.add_argument("--decoder_lr_scheduler", type=str, default=None)
parser.add_argument("--decoder_lr_scheduler_config", default=None)
parser.add_argument("--segm_decoder_lr_scheduler", type=str, default=None)
parser.add_argument("--segm_decoder_lr_scheduler_config", default=None)
parser.add_argument("--eval_before_training", type=argparse_bool_type, default=False)
parser.add_argument("--data_path_conll", type=str,)
parser.add_argument("--train_data_dir", type=str, default="data")
parser.add_argument("--valid_data_dir", type=str, default="data")
parser.add_argument("--test_data_dir", type=str, default="data")
parser.add_argument("--exclude_parameter_names_regex", type=str)
print(parser)
print('after add arguments to parser')
print(argparse_bool_type)
# parser.parse_args(['my', 'list', 'of', 'strings']) 

ArgumentParser(prog='ipykernel_launcher.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)
ArgumentParser(prog='ipykernel_launcher.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)
after add arguments to parser
<function argparse_bool_type at 0x7f1102477d90>


In [299]:


# from misc import argparse_bool_type





def get_args():

    args = parser.parse_args()
    
    for k, v in args.__dict__.items():
        print(k, ":", v)
        if v == "None":
            args.__dict__[k] = None

    args.device = (
        int(args.device) if args.device is not None and args.device != "cpu" and torch.cuda.is_available() else "cpu"
    )
    if args.eval_device is not None:
        if args.eval_device != "cpu":
            args.eval_device = int(args.eval_device)
        else:
            args.eval_device = "cpu"
    else:
        args.eval_device = args.device
    if args.out_device is not None:
        if args.out_device != "cpu":
            args.out_device = int(args.out_device)
        else:
            args.out_device = "cpu"
    else:
        args.out_device = args.device

    if args.encoder_lr_scheduler_config:
        args.encoder_lr_scheduler_config = ast.literal_eval(args.encoder_lr_scheduler_config)
    if args.decoder_lr_scheduler_config:
        args.decoder_lr_scheduler_config = ast.literal_eval(args.decoder_lr_scheduler_config)
    if args.segm_decoder_lr_scheduler_config:
        args.segm_decoder_lr_scheduler_config = ast.literal_eval(args.segm_decoder_lr_scheduler_config)

    args.eval_batch_size = args.eval_batch_size if args.eval_batch_size else args.batch_size

    if not args.logdir:
        raise Exception("set args.logdir")

    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)

    if not args.eval_on_test_only:
        config_fname = os.path.join(args.logdir, "config")
        with open(f"{config_fname}.yaml", "w") as f:
            f.writelines(
                [
                    "{}: {}\n".format(k, v)
                    for k, v in args.__dict__.items()
                    if isinstance(v, str) and len(v.strip()) > 0 or not isinstance(v, str) and v is not None
                ]
            )

    with open(f"data/versions/{args.data_version_name}/config.yaml") as f:
        dataset = yaml.load(f, Loader=yaml.UnsafeLoader)

    for k, v in dataset.items():
        if k != "debug":
            args.__setattr__(k, v)

    return args
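
`get_args` maps the literal string `"None"` to `None` and parses each scheduler config string with `ast.literal_eval`. The sketch below shows those two steps in isolation. The helper names `normalize_option` and `parse_scheduler_config` and the example config keys are assumptions made for illustration, not part of the original code.

```python
import ast


def normalize_option(value):
    """Map the string "None" to None; leave every other value unchanged."""
    return None if value == "None" else value


def parse_scheduler_config(raw):
    """Parse a dict literal such as "{'factor': 0.5}" into keyword arguments."""
    raw = normalize_option(raw)
    if not raw:
        # An empty dict keeps a later **config call valid when no config is given.
        return {}
    # literal_eval only accepts Python literals, so it cannot run arbitrary code.
    config = ast.literal_eval(raw)
    if not isinstance(config, dict):
        raise ValueError(f"scheduler config must be a dict literal, got {type(config).__name__}")
    return config


# Hypothetical config string; the keys are illustrative only.
config = parse_scheduler_config("{'factor': 0.5, 'patience': 2}")
print(config)
```

Returning `{}` instead of `None` is a design choice of this sketch: unpacking `None` with `**` raises a `TypeError`, so an empty dict is the safer default when a scheduler is set without a config.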

In [300]:
args = {'logdir': 'data/checkpoints/dummy_wiki_00001',
'debug': False,
'device': 0,
'eval_device': 0,
'dataset': EDLDataset,
'model': Net,
'data_version_name': 'dummy',
'wiki_lang_version': 'enwiki',
'eval_on_test_only': False,
'out_device': 0,
'batch_size': 16,
'eval_batch_size': 1,
'accumulate_batch_gradients': 8,
'sparse': True,
'encoder_lr': 5e-05,
'decoder_lr': 0.1,
'maskout_entity_prob': 0.0,
'encoder_weight_decay': 0.0,
'decoder_weight_decay': 0.0,
'segm_decoder_weight_decay': 0.0,
'learn_segmentation': False,
'label_size': 0,
'entity_embedding_size': 768,
'project': False,
'n_epochs': 100,
'collect_most_popular_labels_steps': 1,
'checkpoint_eval_steps': 1000,
'checkpoint_save_steps': 100000,
'finetuning': 3,
'train_loc_file': 'train.loc',
'resume_reset_epoch': False,
'resume_optimizer_from_checkpoint': False,
'topk_neg_examples': 20,
'dont_save_checkpoints': False,
'data_workers': 24,
'uncased': True,
'eval_before_training': False,
'train_data_dir': 'data',
'valid_data_dir': 'data',
'test_data_dir': 'data',
'vocab_size': 0,
'top_rnns': False,
}
print(args)

{'logdir': 'data/checkpoints/dummy_wiki_00001', 'debug': False, 'device': 0, 'eval_device': 0, 'dataset': <class '__main__.EDLDataset'>, 'model': <class '__main__.Net'>, 'data_version_name': 'dummy', 'wiki_lang_version': 'enwiki', 'eval_on_test_only': False, 'out_device': 0, 'batch_size': 16, 'eval_batch_size': 1, 'accumulate_batch_gradients': 8, 'sparse': True, 'encoder_lr': 5e-05, 'decoder_lr': 0.1, 'maskout_entity_prob': 0.0, 'encoder_weight_decay': 0.0, 'decoder_weight_decay': 0.0, 'segm_decoder_weight_decay': 0.0, 'learn_segmentation': False, 'label_size': 0, 'entity_embedding_size': 768, 'project': False, 'n_epochs': 100, 'collect_most_popular_labels_steps': 1, 'checkpoint_eval_steps': 1000, 'checkpoint_save_steps': 100000, 'finetuning': 3, 'train_loc_file': 'train.loc', 'resume_reset_epoch': False, 'resume_optimizer_from_checkpoint': False, 'topk_neg_examples': 20, 'dont_save_checkpoints': False, 'data_workers': 24, 'uncased': True, 'eval_before_training': False, 'train_data_dir

In [301]:

# if args.debug:
#     logging.basicConfig(level=logging.DEBUG)
# else:
#     logging.basicConfig(level=logging.INFO)

# logging.info(str(("Devices", args.device, args.eval_device, args.out_device)))
# args = 'config/dummy__train_on_wiki.yaml'
# set up the model
vocab = Vocab(args)
model_class = Net
model = model_class(args=args, vocab_size=vocab.size())


In [302]:
checkpoint = None
# if args.resume_from_checkpoint is not None:
resume_from_checkpoint = "/content/drive/My Drive/data/checkpoints/dummy_wiki_00001"
# checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
# model.load_state_dict(checkpoint["model"], strict=False)
if device != "cpu":
    torch.cuda.empty_cache()
    model.to(device, 0)
print(model)

Net(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        

In [303]:

# model_class = getattr(Models, args.model)
# model = model_class(args=args, vocab_size=vocab.size())
# checkpoint = None
# if args.resume_from_checkpoint is not None:
#     checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
#     model.load_state_dict(checkpoint["model"], strict=False)
# if args.device != "cpu":
#     torch.cuda.empty_cache()
#     model.to(args.device, args.out_device)
# print(model)

# set up the optimizers and the loss
# optimizers, lr_schedulers = model.get_optimizers(args, checkpoint=checkpoint)


# else:
#     eval_dataset = getattr(Datasets, args.dataset)(args, split="test", vocab=vocab, device=args.eval_device)
#     eval_iter = eval_dataset.get_data_iter(args=args, batch_size=args.eval_batch_size, vocab=vocab, train=False)



train_dataset = getattr(Datasets, 'EDLDataset')(
    args, split="train", vocab=vocab, device=0, label_size=8192
)
batch_size = 16
eval_batch_size = 1
train_iter = train_dataset.get_data_iter(args=args, batch_size=batch_size, vocab=vocab, train=True)
eval_dataset = getattr(Datasets, 'EDLDataset')(args, split="valid", vocab=vocab, device=0)
eval_iter = eval_dataset.get_data_iter(args=args, batch_size=eval_batch_size, vocab=vocab, train=False)
print('eval_iter')

eval_iter


In [304]:
optimizers = []
optimizer1 = optim.Adam(model.out.parameters(), lr=0.05)
optimizers.append(optimizer1)
lr_schedulers = []
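# Instantiate the scheduler; mode="max" assumes a higher model selection metric is better.
scheduler = ReduceLROnPlateau(optimizer1, mode="max")
#LRMilestones(optimizer1,milestones= [(30, 0.1), (80, 0.2)])
lr_schedulers.append(scheduler)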
criterion = nn.BCEWithLogitsLoss()


In [305]:
# # set up the datasets and dataloaders
# if not args.eval_on_test_only:
# train_dataset = getattr(Datasets, args.dataset)(
#     args, split="train", vocab=vocab, device=args.device, label_size=args.label_size
# )
# train_iter = train_dataset.get_data_iter(args=args, batch_size=args.batch_size, vocab=vocab, train=True)
# eval_dataset = getattr(Datasets, args.dataset)(args, split="valid", vocab=vocab, device=args.eval_device)
# eval_iter = eval_dataset.get_data_iter(args=args, batch_size=args.eval_batch_size, vocab=vocab, train=False)

In [None]:
start_epoch = 1
if checkpoint and not args.resume_reset_epoch:
    start_epoch = checkpoint["epoch"]

metrics = Metrics()
args_eval_on_test_only = False
# if args.eval_before_training or args.eval_on_test_only:
cloned_args = copy.deepcopy(args)
# cloned_args.dont_save_checkpoints = True
metrics = model_class.evaluate(
    cloned_args,
    model,
    eval_iter,
    optimizers=optimizers,
    step=0,
    epoch=0,
    save_checkpoint=False,
    save_csv=args_eval_on_test_only,
    vocab=vocab,
    metrics=metrics,
)















In [None]:
# if not args.eval_on_test_only
args_n_epochs = 100
args_finetuning = False
for epoch in range(start_epoch, args_n_epochs + 1):

    start = time.time()

    model.finetuning = epoch >= args_finetuning if args_finetuning >= 0 else False

    metrics = model_class.train_one_epoch(
        args=args,
        model=model,
        train_iter=train_iter,
        optimizers=optimizers,
        criterion=criterion,
        vocab=vocab,
        eval_iter=eval_iter,
        epoch=epoch,
        metrics=metrics,
    )

    logging.info(f"Evaluate in epoch {epoch}")
    metrics = model_class.evaluate(
        args, model, eval_iter, optimizers=optimizers, step=0, epoch=epoch, vocab=vocab, metrics=metrics,
    )

    logging.info(f"{time.time() - start} per epoch")

    if lr_schedulers:
        for lr_scheduler in lr_schedulers:
            lr_scheduler.step(metrics.get_model_selection_metric())

**Task 4:** Write a training loop, which takes a model (an instance of NERNet) and the number of epochs to train for. The loss is always CrossEntropyLoss and the optimizer is always Adam.

In [None]:
# def train_loop(model, n_epochs):
#   # Loss function
#   criterion = nn.CrossEntropyLoss()

#   # Optimizer (ADAM is a fancy version of SGD)
#   optimizer = optim.Adam(model.parameters(), lr=0.0001)
  
#   for e in range(1, n_epochs + 1):
#     # TODO - your code goes here...



**Task 5:** Write an evaluation loop for a trained model, using the dev and test datasets. The function prints recall (the true positive rate, TPR) and precision (the fraction of predicted labels that are correct) for each label separately (7 labels in total), and for all 6 labels except O together. Use the caption argument as a prefix for everything the function prints.

In [None]:
# def evaluate(model, caption):
#   # TODO - your code goes here
#   print(...)
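
The per-label metrics in Task 5 can be computed from token-level counts of true positives, false positives, and false negatives. The sketch below is one possible way to do that on plain tag lists, without a model. The names `precision_recall` and `report`, the `"ALL"` key for the combined non-O score, and the toy tag sequences are assumptions made for illustration.

```python
from collections import Counter

LABELS = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]


def _ratio(numerator, denominator):
    """Safe division that returns 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def precision_recall(gold_tags, pred_tags, labels=LABELS):
    """Return {label: (precision, recall)} plus an "ALL" entry over the non-O labels."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for gold, pred in zip(gold_tags, pred_tags):
        if gold == pred:
            tp[gold] += 1
        else:
            fp[pred] += 1  # predicted label was wrong
            fn[gold] += 1  # gold label was missed
    scores = {
        label: (_ratio(tp[label], tp[label] + fp[label]), _ratio(tp[label], tp[label] + fn[label]))
        for label in labels
    }
    # Micro-average over every label except O.
    entity_labels = [label for label in labels if label != "O"]
    tp_all = sum(tp[label] for label in entity_labels)
    fp_all = sum(fp[label] for label in entity_labels)
    fn_all = sum(fn[label] for label in entity_labels)
    scores["ALL"] = (_ratio(tp_all, tp_all + fp_all), _ratio(tp_all, tp_all + fn_all))
    return scores


def report(caption, scores):
    """Print one line per label, prefixed with the caption."""
    for label, (precision, recall) in scores.items():
        print(f"{caption} | {label}: precision={precision:.3f} recall={recall:.3f}")


# Toy gold and predicted tags for a five-token sentence.
gold = ["B-PER", "I-PER", "O", "B-LOC", "O"]
pred = ["B-PER", "O", "O", "B-LOC", "B-LOC"]
scores = precision_recall(gold, pred)
report("toy example", scores)
```

In a full evaluation loop, `gold` and `pred` would be the flattened tag sequences collected over the dev or test set, with padding positions removed first.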

**Task 6:** Train and evaluate a few models, all with embedding_size=300, and with the following hyperparameters (you may use these as captions for the models as well):

Model 1: (hidden_size: 500, n_layers: 1, directions: 1)

Model 2: (hidden_size: 500, n_layers: 2, directions: 1)

Model 3: (hidden_size: 500, n_layers: 3, directions: 1)

Model 4: (hidden_size: 500, n_layers: 1, directions: 2)

Model 5: (hidden_size: 500, n_layers: 2, directions: 2)

Model 6: (hidden_size: 500, n_layers: 3, directions: 2)

Model 7: (hidden_size: 800, n_layers: 1, directions: 2)

Model 8: (hidden_size: 800, n_layers: 2, directions: 2)

Model 9: (hidden_size: 800, n_layers: 3, directions: 2)

In [None]:
# TODO - your code goes here...
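
The nine hyperparameter settings listed for Task 6 follow a regular pattern, so they can be generated rather than typed out. The sketch below builds the configurations and a caption for each, numbered 1 to 9 in the listed order. The names `configs` and `make_caption` are assumptions; constructing and training the models is left out.

```python
EMBEDDING_SIZE = 300

# (hidden_size, directions) groups in the order listed, each with 1, 2 and 3 layers.
configs = [
    {"hidden_size": hidden_size, "n_layers": n_layers, "directions": directions}
    for hidden_size, directions in [(500, 1), (500, 2), (800, 2)]
    for n_layers in (1, 2, 3)
]


def make_caption(index, config):
    """Format a caption in the same style as the task description."""
    return (
        f"Model {index}: (hidden_size: {config['hidden_size']}, "
        f"n_layers: {config['n_layers']}, directions: {config['directions']})"
    )


captions = [make_caption(i, config) for i, config in enumerate(configs, start=1)]
for caption in captions:
    print(caption)
```

Each caption can then be passed to the evaluation function, so that the printed results identify the model they belong to.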

**Task 7:** Download the GloVe embeddings from https://nlp.stanford.edu/projects/glove/ (use the 300-dim vectors from glove.6B.zip). Then initialize the nn.Embedding module in your NERNet with these embeddings, so that training starts from pre-trained vectors. Repeat Task 6 and print the results for each model.

Note: make sure that the vectors are aligned with the IDs in your Vocab. For example, the word with ID 0 must correspond to the first row of the GloVe matrix that you initialize nn.Embedding with. For a discussion on how to do that, see this link:
https://discuss.pytorch.org/t/can-we-use-pre-trained-word-embeddings-for-weight-initialization-in-nn-embedding/1222

In [None]:
# TODO - your code goes here...
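
Aligning GloVe vectors with vocabulary IDs amounts to building a matrix whose row `i` holds the vector of the word with ID `i`. The sketch below does this with the standard library on a toy in-memory file. It assumes the GloVe text format (one word followed by space-separated floats per line) and a `word_to_id` dict with contiguous IDs starting at 0; the function name `build_embedding_matrix` and the toy data are hypothetical.

```python
import io
import random


def build_embedding_matrix(glove_lines, word_to_id, dim, seed=0):
    """Return (matrix, found): row i is the vector for the word with ID i."""
    rng = random.Random(seed)
    # Random init for every row, so words missing from GloVe still get a vector.
    matrix = [[rng.uniform(-0.1, 0.1) for _ in range(dim)] for _ in range(len(word_to_id))]
    found = 0
    for line in glove_lines:
        parts = line.rstrip().split(" ")
        word, values = parts[0], parts[1:]
        idx = word_to_id.get(word)
        # Skip words outside the vocabulary and malformed lines.
        if idx is not None and len(values) == dim:
            matrix[idx] = [float(v) for v in values]
            found += 1
    return matrix, found


# Toy stand-in for an opened GloVe file (3 dimensions instead of 300).
toy_glove = io.StringIO("the 0.1 0.2 0.3\ncat 0.4 0.5 0.6\n")
toy_vocab = {"<pad>": 0, "cat": 1, "the": 2, "zyzzyva": 3}
matrix, found = build_embedding_matrix(toy_glove, toy_vocab, dim=3)
print(found, "of", len(toy_vocab), "words found in GloVe")
```

The resulting matrix can be handed to `nn.Embedding.from_pretrained(torch.tensor(matrix), freeze=False)`. Since the glove.6B vectors are uncased, looking words up in lowercase is likely to increase coverage.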

**Good luck!**