In [1]:
!pip install --upgrade pip
!pip3 install segtok

Requirement already up-to-date: pip in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (19.3.1)


In [11]:
from typing import List, Dict, Union

import torch
import logging

from collections import Counter
from collections import defaultdict

from segtok.segmenter import split_single
from segtok.tokenizer import split_contractions
from segtok.tokenizer import word_tokenizer


# Module-level logger; TaggedCorpus reports its corpus statistics through it.
log = logging.getLogger(__name__)


class Dictionary:
    """
    This class holds a dictionary that maps strings to IDs, used to generate one-hot encodings of strings.

    Items are stored internally as UTF-8 encoded bytes; the public accessors accept and return ``str``.
    """

    def __init__(self, add_unk=True):
        """
        :param add_unk: if True, '<unk>' is registered first and therefore receives ID 0
        """
        # init dictionaries (keys and items are UTF-8 encoded bytes, see add_item)
        self.item2idx: Dict[bytes, int] = {}
        self.idx2item: List[bytes] = []

        # in order to deal with unknown tokens, add <unk>
        if add_unk:
            self.add_item('<unk>')

    def add_item(self, item: str) -> int:
        """
        add string - if already in dictionary returns its ID. if not in dictionary, it will get a new ID.
        :param item: a string for which to assign an id
        :return: ID of string
        """
        item = item.encode('utf-8')
        if item not in self.item2idx:
            self.idx2item.append(item)
            self.item2idx[item] = len(self.idx2item) - 1
        return self.item2idx[item]

    def get_idx_for_item(self, item: str) -> int:
        """
        returns the ID of the string, otherwise 0
        :param item: string for which ID is requested
        :return: ID of string, otherwise 0
        """
        return self.item2idx.get(item.encode('utf-8'), 0)

    def get_items(self) -> List[str]:
        """
        :return: all items as strings, ordered by their ID
        """
        return [item.decode('UTF-8') for item in self.idx2item]

    def __len__(self) -> int:
        return len(self.idx2item)

    def get_item_for_index(self, idx):
        """
        :param idx: ID of the requested item
        :return: the item stored under this ID, decoded to a string
        """
        return self.idx2item[idx].decode('UTF-8')

    def save(self, savefile):
        """
        Pickles both mappings into the given file.
        :param savefile: path of the file to write
        """
        import pickle
        with open(savefile, 'wb') as f:
            mappings = {
                'idx2item': self.idx2item,
                'item2idx': self.item2idx
            }
            pickle.dump(mappings, f)

    @classmethod
    def load_from_file(cls, filename: str):
        """
        Restores a dictionary from a file written by save().
        :param filename: path of the pickle file
        :return: an instance of the class this method was called on
        """
        import pickle
        # instantiate via cls so that subclasses get an instance of their own type
        dictionary = cls()
        with open(filename, 'rb') as f:
            # NOTE(security): pickle.load can execute arbitrary code; only load files from trusted sources.
            mappings = pickle.load(f, encoding='latin1')
            dictionary.item2idx = mappings['item2idx']
            dictionary.idx2item = mappings['idx2item']
        return dictionary

    @classmethod
    def load(cls, name: str):
        """
        Loads a dictionary either by well-known name ('chars' / 'common-chars', fetched through cached_path)
        or from a local pickle file.
        :param name: 'chars', 'common-chars' or a file path
        :return: the loaded dictionary
        """
        if name == 'chars' or name == 'common-chars':
            # imported only here so that loading a plain local file does not require flair.file_utils
            from flair.file_utils import cached_path
            base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models/common_characters'
            char_dict = cached_path(base_path, cache_dir='datasets')
            return cls.load_from_file(char_dict)

        return cls.load_from_file(name)


class Label:
    """
    A label as attached to a sentence or token: a string value together with a confidence score.
    The score is expected to lie in [0.0, 1.0]; any score outside that range is replaced by 1.0,
    which is also the default.
    """

    def __init__(self, value: str, score: float = 1.0):
        # both assignments are routed through the validating property setters below
        self.value = value
        self.score = score
        super().__init__()

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, value):
        # the empty string is an accepted value, every other falsy value (e.g. None) is rejected
        if not value and value != '':
            raise ValueError('Incorrect label value provided. Label value needs to be set.')
        self._value = value

    @property
    def score(self):
        return self._score

    @score.setter
    def score(self, score):
        # out-of-range scores silently fall back to full confidence
        self._score = score if 0.0 <= score <= 1.0 else 1.0

    def to_dict(self):
        """
        :return: dictionary with the keys 'value' and 'confidence'
        """
        return dict(value=self.value, confidence=self.score)

    def __str__(self):
        return f"{self._value} ({self._score})"

    def __repr__(self):
        return f"{self._value} ({self._score})"


class Token:
    """
    A single word of a tokenized sentence. A token carries any number of tags (one Label per tag type),
    named embeddings, and may reference its head in a dependency tree through head_id.
    """

    def __init__(self,
                 text: str,
                 sem: str = None,
                 idx: int = None,
                 head_id: int = None,
                 whitespace_after: bool = True,
                 start_position: int = None
                 ):
        self.text: str = text
        self.sem: str = sem
        self.idx: int = idx
        self.head_id: int = head_id
        self.whitespace_after: bool = whitespace_after

        # character offsets of the token; the end offset is only known when a start offset was given
        self.start_pos = start_position
        if start_position is None:
            self.end_pos = None
        else:
            self.end_pos = start_position + len(text)

        # set by Sentence.add_token
        self.sentence: Sentence = None
        self._embeddings: Dict = {}
        self.tags: Dict[str, Label] = {}

    def add_tag(self, tag_type: str, tag_value: str, confidence=1.0):
        """
        Stores a tag of the given type, replacing any tag previously stored under that type.
        """
        self.tags[tag_type] = Label(tag_value, confidence)

    def get_tag(self, tag_type: str) -> Label:
        """
        :return: the Label stored for tag_type, or a fresh empty Label if the token has no such tag
        """
        if tag_type not in self.tags:
            return Label('')
        return self.tags[tag_type]

    def set_sem(self, sem: str):
        self.sem = sem

    def get_head(self):
        """
        :return: the token of the owning sentence whose idx equals this token's head_id
        """
        return self.sentence.get_token(self.head_id)

    def set_embedding(self, name: str, vector: torch.autograd.Variable):
        # the vector is moved to the CPU before it is stored
        self._embeddings[name] = vector.cpu()

    def clear_embeddings(self):
        self._embeddings = {}

    def get_embedding(self) -> torch.FloatTensor:
        """
        :return: all stored embeddings concatenated in alphabetical order of their names,
                 or an empty FloatTensor if nothing is stored
        """
        vectors = []
        for name in sorted(self._embeddings.keys()):
            vectors.append(self._embeddings[name])

        if not vectors:
            return torch.FloatTensor()

        return torch.cat(vectors, dim=0)

    @property
    def start_position(self) -> int:
        return self.start_pos

    @property
    def end_position(self) -> int:
        return self.end_pos

    @property
    def embedding(self):
        return self.get_embedding()

    def __str__(self) -> str:
        if self.idx is None:
            return 'Token: {}'.format(self.text)
        return 'Token: {} {}'.format(self.idx, self.text)

    def __repr__(self) -> str:
        if self.idx is None:
            return 'Token: {}'.format(self.text)
        return 'Token: {} {}'.format(self.idx, self.text)


class Span:
    """
    This class represents one textual span consisting of Tokens. A span may have a tag.
    """

    def __init__(self, tokens: List['Token'], tag: str = None, score=1.):
        """
        :param tokens: the tokens forming the span, in order
        :param tag: optional tag of the span
        :param score: confidence of the span
        """
        self.tokens = tokens
        self.tag = tag
        self.score = score
        self.start_pos = None
        self.end_pos = None

        # character offsets are taken from the first and last token
        if tokens:
            self.start_pos = tokens[0].start_position
            self.end_pos = tokens[-1].end_position

    @property
    def text(self) -> str:
        return ' '.join([t.text for t in self.tokens])

    def to_original_text(self) -> str:
        """
        Rebuilds the text of the span, padding with spaces so that every token sits at its start position
        relative to the first token.

        If any token has no start position the whitespace cannot be reconstructed and the space-joined
        token texts are returned instead. An empty span yields the empty string.
        :return: the reconstructed text
        """
        if not self.tokens:
            return ''
        if any(t.start_pos is None for t in self.tokens):
            return self.text

        pieces = []
        pos = self.tokens[0].start_pos
        for t in self.tokens:
            # pad only when the token starts ahead of the cursor; a token that starts at or before the
            # cursor (overlapping positions) is appended directly, so the loop always terminates
            if t.start_pos > pos:
                pieces.append(' ' * (t.start_pos - pos))
                pos = t.start_pos

            pieces.append(t.text)
            pos += len(t.text)

        return ''.join(pieces)

    def to_dict(self):
        """
        :return: dictionary with the keys 'text', 'start_pos', 'end_pos', 'type' and 'confidence'
        """
        return {
            'text': self.to_original_text(),
            'start_pos': self.start_pos,
            'end_pos': self.end_pos,
            'type': self.tag,
            'confidence': self.score
        }

    def __str__(self) -> str:
        ids = ','.join([str(t.idx) for t in self.tokens])
        return '{}-span [{}]: "{}"'.format(self.tag, ids, self.text) \
            if self.tag is not None else 'span [{}]: "{}"'.format(ids, self.text)

    def __repr__(self) -> str:
        ids = ','.join([str(t.idx) for t in self.tokens])
        return '<{}-span ({}): "{}">'.format(self.tag, ids, self.text) \
            if self.tag is not None else '<span ({}): "{}">'.format(ids, self.text)


class Sentence:
    """
    A Sentence is a list of Tokens and is used to represent a sentence or text fragment.
    """

    def __init__(self, text: str = None, use_tokenizer: bool = False,
                 labels: Union[List['Label'], List[str]] = None):
        """
        :param text: optional text from which the tokens are created
        :param use_tokenizer: if True the text is tokenized with segtok, otherwise it is split on single spaces
        :param labels: optional sentence-level labels, given as Label objects or plain strings
        """
        super(Sentence, self).__init__()

        self.tokens: List[Token] = []

        self.labels: List[Label] = []
        if labels is not None: self.add_labels(labels)

        self._embeddings: Dict = {}

        # if text is passed, instantiate sentence with tokens (words)
        if text is not None:
            if use_tokenizer:
                self._add_segtok_tokens(text)
            else:
                self._add_whitespace_tokens(text)

    def _add_segtok_tokens(self, text: str):
        """Tokenizes text with segtok and adds the tokens together with their offsets and whitespace_after flags."""
        words = []
        for sentence in split_single(text):
            words.extend(split_contractions(word_tokenizer(sentence)))

        # determine offsets for whitespace_after field
        running_offset = 0
        last_word_offset = -1
        last_token = None
        for word in words:
            try:
                word_offset = text.index(word, running_offset)
                start_position = word_offset
            except ValueError:
                # the word cannot be located verbatim in the text: estimate its position instead
                word_offset = last_word_offset + 1
                start_position = running_offset + 1 if running_offset > 0 else running_offset

            token = Token(word, start_position=start_position)
            self.add_token(token)

            # a word that starts directly after the previous one means there was no whitespace in between
            if word_offset - 1 == last_word_offset and last_token is not None:
                last_token.whitespace_after = False

            running_offset = word_offset + len(word)
            last_word_offset = running_offset - 1
            last_token = token

    def _add_whitespace_tokens(self, text: str):
        """Splits text on single spaces and adds every non-empty piece as a token with its character offset."""
        offset = 0
        for word in text.split(' '):
            if not word:
                continue
            try:
                word_offset = text.index(word, offset)
            except ValueError:
                word_offset = offset

            self.add_token(Token(word, start_position=word_offset))
            # continue searching right behind this word; advancing by len(word) + 1 from the old offset
            # falls behind on runs of spaces and can then match inside an earlier word
            offset = word_offset + len(word)

    def get_token(self, token_id: int) -> 'Token':
        """
        :return: the first token whose idx equals token_id, or None if there is no such token
        """
        for token in self.tokens:
            if token.idx == token_id:
                return token
        return None

    def add_token(self, token: 'Token'):
        """Appends the token, links it to this sentence and assigns a 1-based idx if it has none."""
        self.tokens.append(token)

        # set token idx if not set
        token.sentence = self
        if token.idx is None:
            token.idx = len(self.tokens)

    @staticmethod
    def _append_span(spans, span_tokens, tag_weights, tag_type, min_score):
        """Appends a Span for span_tokens to spans if the mean tag score of its tokens exceeds min_score."""
        scores = [t.get_tag(tag_type).score for t in span_tokens]
        span_score = sum(scores) / len(scores)
        if span_score > min_score:
            # the span gets the tag with the highest accumulated weight (first one wins on ties)
            best_tag = max(tag_weights.items(), key=lambda k_v: k_v[1])[0]
            spans.append(Span(span_tokens, tag=best_tag, score=span_score))

    def get_spans(self, tag_type: str, min_score=-1) -> List['Span']:
        """
        Groups consecutive tagged tokens into spans.
        :param tag_type: the tag type to read from every token
        :param min_score: only spans whose mean token score is greater than this value are returned
        :return: list of spans in sentence order
        """
        spans: List[Span] = []

        current_span = []

        tags = defaultdict(lambda: 0.0)

        previous_tag_value: str = 'O'
        for token in self:

            tag_value = token.get_tag(tag_type).value

            # non-set tags are OUT tags
            if tag_value == '' or tag_value == 'O':
                tag_value = 'O-'

            # anything that is not a BIOES tag is a SINGLE tag
            if tag_value[0:2] not in ['B-', 'I-', 'O-', 'E-', 'S-']:
                tag_value = 'S-' + tag_value

            # anything that is not OUT is IN
            in_span = tag_value[0:2] != 'O-'

            # single and begin tags start a new span
            starts_new_span = tag_value[0:2] in ['B-', 'S-']

            # so does a change of tag directly after a single tag
            if previous_tag_value[0:2] == 'S-' and previous_tag_value[2:] != tag_value[2:] and in_span:
                starts_new_span = True

            if (starts_new_span or not in_span) and len(current_span) > 0:
                self._append_span(spans, current_span, tags, tag_type, min_score)
                current_span = []
                tags = defaultdict(lambda: 0.0)

            if in_span:
                current_span.append(token)
                # the tag that opens the span counts slightly more when picking the span tag
                weight = 1.1 if starts_new_span else 1.0
                tags[tag_value[2:]] += weight

            # remember previous tag
            previous_tag_value = tag_value

        if len(current_span) > 0:
            self._append_span(spans, current_span, tags, tag_type, min_score)

        return spans

    def add_label(self, label: Union['Label', str]):
        """Adds a sentence-level label; strings are wrapped into a Label, other types are ignored."""
        if isinstance(label, Label):
            self.labels.append(label)

        elif isinstance(label, str):
            self.labels.append(Label(label))

    def add_labels(self, labels: Union[List['Label'], List[str]]):
        for label in labels:
            self.add_label(label)

    def get_label_names(self) -> List[str]:
        return [label.value for label in self.labels]

    @property
    def embedding(self):
        return self.get_embedding()

    def set_embedding(self, name: str, vector):
        # the vector is moved to the CPU before it is stored
        self._embeddings[name] = vector.cpu()

    def get_embedding(self) -> torch.autograd.Variable:
        """
        :return: all sentence embeddings concatenated in alphabetical order of their names,
                 or an empty FloatTensor if nothing is stored
        """
        embeddings = [self._embeddings[name] for name in sorted(self._embeddings.keys())]

        if embeddings:
            return torch.cat(embeddings, dim=0)

        return torch.FloatTensor()

    def clear_embeddings(self, also_clear_word_embeddings: bool = True):
        self._embeddings: Dict = {}

        if also_clear_word_embeddings:
            for token in self:
                token.clear_embeddings()

    def cpu_embeddings(self):
        for name, vector in self._embeddings.items():
            self._embeddings[name] = vector.cpu()

    def to_tagged_string(self, main_tag=None) -> str:
        """
        :param main_tag: if given, only tags of this type are rendered
        :return: the token texts, each followed by '<tag/...>' when the token has non-empty, non-'O' tags
        """
        parts = []
        for token in self.tokens:
            parts.append(token.text)

            tags: List[str] = []
            for tag_type in token.tags.keys():

                if main_tag is not None and main_tag != tag_type: continue

                value = token.get_tag(tag_type).value
                if value == '' or value == 'O': continue
                tags.append(value)
            if tags:
                parts.append('<' + '/'.join(tags) + '>')
        return ' '.join(parts)

    def get_tags(self, main_tag=None) -> List[str]:
        """
        Returns exactly one tag value per token.
        :param main_tag: the tag type to read; if None the first tag type stored on the token is used
        :return: list with one entry per token; 'O' for tokens that carry no matching tag
        """
        output = []
        for token in self.tokens:
            value = 'O'
            for tag_type in token.tags.keys():
                if main_tag is not None and main_tag != tag_type:
                    continue
                value = token.tags[tag_type].value
                break
            output.append(value)
        return output

    def to_offset_tags(self, sent_offset, tag_type: str = 'ner'):
        """
        Converts B-/I-/E-/S- tags into entity records with character offsets.

        Offsets are counted over the token texts plus one space after every token whose whitespace_after
        is set, shifted by sent_offset.
        :param sent_offset: offset added to every start and end offset
        :param tag_type: the tag type to read from every token
        :return: list of [tag, start offset, end offset, entity text] records, all entries being strings
        """
        entities = []
        offset = 0
        # -1 marks "no entity currently open"
        entity_start = -1
        entity_text = ""
        for token in self.tokens:
            token_start = offset
            offset += len(token.text)
            tag_value = token.get_tag(tag_type).value

            if tag_value.startswith("E-"):
                # an E- tag without a preceding B-/I- tag opens the entity at this token
                if entity_start == -1:
                    entity_start = token_start
                entity_text += token.text
                entities.append([tag_value[2:], str(sent_offset + entity_start),
                                 str(sent_offset + offset), str(entity_text)])
                entity_text = ""
                entity_start = -1
            elif tag_value.startswith("B-"):
                entity_text = token.text
                entity_start = token_start
                if token.whitespace_after:
                    entity_text += " "
            elif tag_value.startswith("I-"):
                # an I- tag without a preceding B- tag opens the entity at this token
                if entity_start == -1:
                    entity_start = token_start
                entity_text += token.text
                if token.whitespace_after:
                    entity_text += " "
            elif tag_value.startswith("S-"):
                entities.append([tag_value[2:], str(sent_offset + token_start),
                                 str(sent_offset + offset), str(token.text)])
                entity_text = ""
                entity_start = -1

            if token.whitespace_after:
                offset += 1
        return entities

    def to_tokenized_string(self) -> str:
        return ' '.join([t.text for t in self.tokens])

    def to_plain_string(self):
        """
        :return: the token texts joined according to their whitespace_after flags, without trailing whitespace
        """
        plain = ''
        for token in self.tokens:
            plain += token.text
            if token.whitespace_after: plain += ' '
        return plain.rstrip()

    def convert_tag_scheme(self, tag_type: str = 'ner', target_scheme: str = 'iob'):
        """
        Converts the tags of the given type to the target scheme ('iob' or 'iobes') and stores them
        back on the tokens, keeping the score each tag had before the conversion.
        """
        labels = [token.get_tag(tag_type) for token in self.tokens]
        scores = [label.score for label in labels]

        converted = labels
        if target_scheme == 'iob':
            iob2(labels)

        if target_scheme == 'iobes':
            iob2(labels)
            converted = iob_iobes(labels)

        for token, tag, score in zip(self.tokens, converted, scores):
            # iob2 leaves Label objects, iob_iobes returns plain strings; add_tag expects the string value
            value = tag if isinstance(tag, str) else tag.value
            token.add_tag(tag_type, value, score)

    def infer_space_after(self):
        """
        Heuristics in case you wish to infer whitespace_after values for tokenized text. This is useful for some old NLP
        tasks (such as CoNLL-03 and CoNLL-2000) that provide only tokenized data with no info of original whitespacing.
        :return: this sentence
        """
        last_token = None
        quote_count: int = 0
        # infer whitespace after field

        for token in self.tokens:
            if token.text == '"':
                quote_count += 1
                # odd quote opens a quotation (no space after it), even quote closes it (no space before it)
                if quote_count % 2 != 0:
                    token.whitespace_after = False
                elif last_token is not None:
                    last_token.whitespace_after = False

            if last_token is not None:

                if token.text in ['.', ':', ',', ';', ')', 'n\'t', '!', '?']:
                    last_token.whitespace_after = False

                if token.text.startswith('\''):
                    last_token.whitespace_after = False

            if token.text in ['(']:
                token.whitespace_after = False

            last_token = token
        return self

    def to_original_text(self) -> str:
        """
        Rebuilds the original text by padding with spaces so that every token sits at its start position.

        If any token has no start position (e.g. tokens created without offsets) the original whitespace
        cannot be reconstructed and the result of to_plain_string() is returned instead.
        :return: the reconstructed text
        """
        if any(t.start_pos is None for t in self.tokens):
            return self.to_plain_string()

        pieces = []
        pos = 0
        for t in self.tokens:
            # pad only when the token starts ahead of the cursor; a token that starts at or before the
            # cursor (overlapping positions) is appended directly, so the loop always terminates
            if t.start_pos > pos:
                pieces.append(' ' * (t.start_pos - pos))
                pos = t.start_pos

            pieces.append(t.text)
            pos += len(t.text)

        return ''.join(pieces)

    def to_dict(self, tag_type: str = None):
        """
        :param tag_type: if given, the spans of this tag type are exported as entities
        :return: dictionary with the keys 'text', 'labels' and 'entities'
        """
        labels = []
        entities = []

        if tag_type:
            entities = [span.to_dict() for span in self.get_spans(tag_type)]
        if self.labels:
            labels = [l.to_dict() for l in self.labels]

        return {
            'text': self.to_original_text(),
            'labels': labels,
            'entities': entities
        }

    def __getitem__(self, idx: int) -> 'Token':
        return self.tokens[idx]

    def __iter__(self):
        return iter(self.tokens)

    def __repr__(self):
        return 'Sentence: "{}" - {} Tokens'.format(' '.join([t.text for t in self.tokens]), len(self))

    def __copy__(self):
        # copies token texts and tags only; embeddings, labels and character offsets are not carried over
        s = Sentence()
        for token in self.tokens:
            nt = Token(token.text)
            for tag_type in token.tags:
                nt.add_tag(tag_type, token.get_tag(tag_type).value, token.get_tag(tag_type).score)

            s.add_token(nt)
        return s

    def __str__(self) -> str:
        return 'Sentence: "{}" - {} Tokens'.format(' '.join([t.text for t in self.tokens]), len(self))

    def __len__(self) -> int:
        return len(self.tokens)


class TaggedCorpus:
    """
    A corpus of sentences divided into a train, a dev and a test split.
    """

    def __init__(self, train: List[Sentence], dev: List[Sentence], test: List[Sentence]):
        self.train: List[Sentence] = train
        self.dev: List[Sentence] = dev
        self.test: List[Sentence] = test

    def downsample(self, percentage: float = 0.1, only_downsample_train=False):
        """
        Shrinks the splits to roughly the given proportion of their sentences.
        :param percentage: proportion of sentences to keep
        :param only_downsample_train: if True, dev and test are left untouched
        :return: this corpus
        """
        self.train = self._downsample_to_proportion(self.train, percentage)
        if only_downsample_train:
            return self

        self.dev = self._downsample_to_proportion(self.dev, percentage)
        self.test = self._downsample_to_proportion(self.test, percentage)
        return self

    def clear_embeddings(self):
        """Clears the embeddings stored on every token of every split."""
        for sentence in self.get_all_sentences():
            for token in sentence.tokens:
                token.clear_embeddings()

    def get_all_sentences(self) -> List[Sentence]:
        """
        :return: a new list holding the train, dev and test sentences, in that order
        """
        return [*self.train, *self.dev, *self.test]

    def make_tag_dictionary(self, tag_type: str) -> Dictionary:
        """
        Builds a dictionary holding 'O', every value of the given tag type found in the corpus,
        '<START>' and '<STOP>', in addition to the default entry of a new Dictionary.
        :param tag_type: the tag type to collect
        :return: dictionary of tags
        """
        tag_dictionary: Dictionary = Dictionary()
        tag_dictionary.add_item('O')
        for sentence in self.get_all_sentences():
            for token in sentence.tokens:
                tag_dictionary.add_item(token.get_tag(tag_type).value)
        for marker in ('<START>', '<STOP>'):
            tag_dictionary.add_item(marker)
        return tag_dictionary

    def make_label_dictionary(self) -> Dictionary:
        """
        Builds a dictionary (without an unknown entry) of all distinct labels assigned to training sentences.
        :return: dictionary of labels
        """
        label_dictionary: Dictionary = Dictionary(add_unk=False)
        for label in set(self._get_all_label_names()):
            label_dictionary.add_item(label)

        return label_dictionary

    def make_vocab_dictionary(self, max_tokens=-1, min_freq=1) -> Dictionary:
        """
        Builds a dictionary of the tokens contained in the training split.
        Tokens are added in order of decreasing frequency. `max_tokens` caps the number of tokens that
        are added; `min_freq` drops tokens that occur fewer than `min_freq` times.
        :param max_tokens: the maximum number of tokens that should be added to the dictionary (-1 = take all tokens)
        :param min_freq: a token needs to occur at least `min_freq` times to be added to the dictionary (-1 = there is no limitation)
        :return: dictionary of tokens
        """
        vocab_dictionary: Dictionary = Dictionary()
        for token in self._get_most_common_tokens(max_tokens, min_freq):
            vocab_dictionary.add_item(token)

        return vocab_dictionary

    def _get_most_common_tokens(self, max_tokens, min_freq) -> List[str]:
        ranked = Counter(self._get_all_tokens()).most_common()

        selected = []
        for token, freq in ranked:
            # the ranking is sorted by frequency, so the first token that is too rare ends the selection
            too_rare = min_freq != -1 and freq < min_freq
            limit_reached = max_tokens != -1 and len(selected) == max_tokens
            if too_rare or limit_reached:
                break
            selected.append(token)
        return selected

    def _get_all_label_names(self) -> List[str]:
        names = []
        for sent in self.train:
            for label in sent.labels:
                names.append(label.value)
        return names

    def _get_all_tokens(self) -> List[str]:
        return [token.text for sentence in self.train for token in sentence.tokens]

    def _downsample_to_proportion(self, list: List, proportion: float):
        # an item is kept whenever the integer part of the accumulated proportion changes;
        # the very first item is always kept because the last seen value starts out as None
        kept: List = []
        accumulated = 0.0
        last_bucket = None

        for item in list:
            accumulated += proportion
            bucket = int(accumulated)
            if bucket != last_bucket:
                kept.append(item)
                last_bucket = bucket
        return kept

    def print_statistics(self):
        """
        Logs the class distribution (sentence labels only) and the sentence sizes of the train,
        test and dev split.
        """
        for sentences, name in ((self.train, "TRAIN"), (self.test, "TEST"), (self.dev, "DEV")):
            self._print_statistics_for(sentences, name)

    @staticmethod
    def _print_statistics_for(sentences, name):
        # nothing is logged for an empty split
        if len(sentences) == 0:
            return

        tokens_per_sentence = TaggedCorpus._get_tokens_per_sentence(sentences)

        size_dict = dict(TaggedCorpus._get_classes_to_count(sentences).items())
        size_dict['total'] = len(sentences)

        token_total = sum(tokens_per_sentence)
        stats = {
            'dataset': name,
            'number_of_documents': size_dict,
            'number_of_tokens': {
                'total': token_total,
                'min': min(tokens_per_sentence),
                'max': max(tokens_per_sentence),
                'avg': token_total / len(sentences)
            }
        }

        log.info(stats)

    @staticmethod
    def _get_tokens_per_sentence(sentences):
        return [len(sentence.tokens) for sentence in sentences]

    @staticmethod
    def _get_classes_to_count(sentences):
        classes_to_count = defaultdict(lambda: 0)
        for label_value in (label.value for sent in sentences for label in sent.labels):
            classes_to_count[label_value] += 1
        return classes_to_count

    def __str__(self) -> str:
        sizes = (len(self.train), len(self.dev), len(self.test))
        return 'TaggedCorpus: %d train + %d dev + %d test sentences' % sizes


def iob2(tags):
    """
    Validates a sequence of IOB tags and rewrites IOB1 tags to IOB2 in place.

    Every element needs a `value` attribute holding 'O' or a tag of the form 'I-<type>' / 'B-<type>'.
    An 'I-' tag that opens an entity (first position, after 'O', or after a tag of another type)
    is turned into the matching 'B-' tag.
    :return: False as soon as a malformed tag is met (earlier tags may already be rewritten), otherwise True
    """
    for position, label in enumerate(tags):
        current = label.value
        if current == 'O':
            continue

        parts = current.split('-')
        if len(parts) != 2 or parts[0] not in ('I', 'B'):
            return False
        if parts[0] == 'B':
            continue

        # an I- tag only stays as it is when it continues an entity of the same type
        continues_entity = (
            position != 0
            and tags[position - 1].value != 'O'
            and tags[position - 1].value[1:] == current[1:]
        )
        if not continues_entity:  # conversion IOB1 to IOB2
            tags[position].value = 'B' + current[1:]
    return True


def iob_iobes(tags):
    """
    IOB -> IOBES

    Every element needs a `value` attribute holding 'O' or a 'B-'/'I-' tag. A 'B-' tag that is not
    followed by an 'I-' tag becomes 'S-'; an 'I-' tag that is not followed by an 'I-' tag becomes 'E-'.
    :param tags: sequence of tag objects in IOB2 format
    :return: list of tag strings in IOBES format
    :raises Exception: if a tag is neither 'O' nor starts with 'B' or 'I'
    """
    new_tags = []
    for i, tag in enumerate(tags):
        value = tag.value
        if value == 'O':
            new_tags.append(value)
            continue

        prefix = value.split('-')[0]
        next_is_inside = i + 1 < len(tags) and tags[i + 1].value.split('-')[0] == 'I'

        if prefix == 'B':
            # replace only the leading prefix: a plain replace would also rewrite 'B-' inside the label
            new_tags.append(value if next_is_inside else value.replace('B-', 'S-', 1))
        elif prefix == 'I':
            # replace only the leading prefix: a plain replace would also rewrite 'I-' inside the label
            new_tags.append(value if next_is_inside else value.replace('I-', 'E-', 1))
        else:
            raise Exception('Invalid IOB format! Offending tag: {}'.format(tag))
    return new_tags



In [9]:
"""
Utilities for working with the local dataset cache. Copied from AllenNLP
"""

from typing import Tuple
import os
import base64
import logging
import shutil
import tempfile
import re
from urllib.parse import urlparse

import requests

# from allennlp.common.tqdm import Tqdm

# Module-level logger used by the cache helpers below to report download progress.
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

# Root directory of the local cache: the ".flair" folder inside the user's home directory.
CACHE_ROOT = os.path.expanduser(os.path.join('~', '.flair'))


def url_to_filename(url: str, etag: str = None) -> str:
    """
    Encodes a url as a filename such that the url can be recovered from it again.
    The name is the base64 encoding of the url. When an `etag` is given it is appended
    behind a period (a character the base64 alphabet does not contain), with all double
    quotes stripped because Windows does not accept them in filenames.
    """
    encoded = base64.b64encode(url.encode('utf-8')).decode('utf-8')

    if not etag:
        return encoded

    # Remove quotes from etag
    unquoted_etag = etag.replace('"', '')
    return "{}.{}".format(encoded, unquoted_etag)


def filename_to_url(filename: str) -> Tuple[str, str]:
    """
    Inverse of the filename encoding: decodes the url from the filename and returns it
    together with the ETag (``None`` when the filename carries no ETag).
    """
    # With an etag present, it is everything behind the first period
    encoded, separator, etag = filename.partition(".")
    if not separator:
        # No period at all, hence no etag
        etag = None

    url = base64.b64decode(encoded.encode('utf-8')).decode('utf-8')
    return url, etag


def cached_path(url_or_filename: str, cache_dir: str) -> str:
    """
    Resolves something that is either a URL or a local path to a local file path.
    An http(s) URL is fetched through the cache located in `cache_dir` below the cache root
    and the path of the cached file is returned. A local path is returned unchanged,
    provided the file exists.
    :raises FileNotFoundError: if a local path does not exist
    :raises ValueError: if the argument has a scheme other than http(s)
    """
    dataset_cache = os.path.join(CACHE_ROOT, cache_dir)

    scheme = urlparse(url_or_filename).scheme

    if scheme in ('http', 'https'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, dataset_cache)

    if scheme != '':
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if not os.path.exists(url_or_filename):
        # File, but it doesn't exist.
        raise FileNotFoundError("file {} not found".format(url_or_filename))

    # File, and it exists.
    return url_or_filename


# TODO(joelgrus): do we want to do checksums or anything like that?
def get_from_cache(url: str, cache_dir: str = None) -> str:
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    :param url: URL of the resource; the cache entry is named after the part
        of the URL following its last ``/``
    :param cache_dir: directory holding the cache; it is created if missing.
        Despite the ``None`` default a real path is required, because it is
        passed straight to ``os.makedirs``.
    :return: path of the cached file
    :raises IOError: if the HEAD request for ``url`` does not return status 200
    """
    os.makedirs(cache_dir, exist_ok=True)

    filename = re.sub(r'.+/', '', url)
    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)
    if os.path.exists(cache_path):
        return cache_path

    # make HEAD request to check that the resource is reachable
    # (the ETag is currently not used in the cache file name)
    response = requests.head(url)
    if response.status_code != 200:
        raise IOError("HEAD request failed for url {}".format(url))

    # Download to temporary file, then copy to cache dir once finished.
    # Otherwise you get corrupt cache entries if the download gets interrupted.
    temp_fd, temp_filename = tempfile.mkstemp()
    # mkstemp hands back an open OS-level descriptor; close it right away so
    # it is not leaked - the file is reopened by name below.
    os.close(temp_fd)
    logger.info("%s not found in cache, downloading to %s", url, temp_filename)

    try:
        # GET file object
        req = requests.get(url, stream=True)
        content_length = req.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None
        progress = Tqdm.tqdm(unit="B", total=total)
        try:
            with open(temp_filename, 'wb') as temp_file:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        progress.update(len(chunk))
                        temp_file.write(chunk)
        finally:
            progress.close()

        logger.info("copying %s to cache at %s", temp_filename, cache_path)
        shutil.copyfile(temp_filename, cache_path)
    finally:
        # always clean up the temp file, also when the download or copy failed
        logger.info("removing temp file %s", temp_filename)
        os.remove(temp_filename)

    return cache_path


from tqdm import tqdm as _tqdm


class Tqdm:
    """Thin wrapper around ``tqdm`` that injects a configurable default ``mininterval``."""

    # These defaults are the same as the argument defaults in tqdm.
    default_mininterval: float = 0.1

    @staticmethod
    def set_default_mininterval(value: float) -> None:
        """Set the ``mininterval`` used by subsequent ``Tqdm.tqdm`` calls."""
        Tqdm.default_mininterval = value

    @staticmethod
    def set_slower_interval(use_slower_interval: bool) -> None:
        """
        Switch between a slow (10.0) and the regular (0.1) default ``mininterval``.

        The regular output rate suits interactive use; the slow one is meant
        for situations where progress is mostly read from log files rather
        than from a terminal.
        """
        Tqdm.default_mininterval = 10.0 if use_slower_interval else 0.1

    @staticmethod
    def tqdm(*args, **kwargs):
        """Create a ``tqdm`` progress bar; an explicit ``mininterval`` keyword wins over the default."""
        merged_kwargs = {'mininterval': Tqdm.default_mininterval}
        merged_kwargs.update(kwargs)
        return _tqdm(*args, **merged_kwargs)


In [19]:
#!pip install flair



In [1]:
import torch.nn


class LockedDropout(torch.nn.Module):
    """
    Implementation of locked (or variational) dropout. Randomly drops out entire parameters in embedding space.

    One Bernoulli mask of shape ``(1, x.size(1), x.size(2))`` is sampled per
    call and broadcast over the first dimension of the input; kept entries
    are rescaled by ``1 / (1 - dropout_rate)``.
    """
    def __init__(self, dropout_rate=0.5):
        super(LockedDropout, self).__init__()
        self.dropout_rate = dropout_rate

    def forward(self, x):
        """Apply the locked mask in training mode; otherwise return ``x`` unchanged."""
        if self.training and self.dropout_rate:
            keep_probability = 1 - self.dropout_rate
            noise = x.data.new(1, x.size(1), x.size(2)).bernoulli_(keep_probability)
            scaled_mask = torch.autograd.Variable(noise, requires_grad=False) / keep_probability
            return scaled_mask.expand_as(x) * x

        return x


class WordDropout(torch.nn.Module):
    """
    Implementation of word dropout. Randomly drops out entire words (or characters) in embedding space.

    One Bernoulli value is sampled per index of the first input dimension
    (mask shape ``(x.size(0), 1, 1)``) and broadcast over the remaining two
    dimensions; kept entries are not rescaled.
    """
    def __init__(self, dropout_rate=0.05):
        super(WordDropout, self).__init__()
        self.dropout_rate = dropout_rate

    def forward(self, x):
        """Apply the mask in training mode; otherwise return ``x`` unchanged."""
        if self.training and self.dropout_rate:
            noise = x.data.new(x.size(0), 1, 1).bernoulli_(1 - self.dropout_rate)
            keep_mask = torch.autograd.Variable(noise, requires_grad=False)
            return keep_mask.expand_as(x) * x

        return x

In [1]:
import itertools
import random
import logging
import os
from collections import defaultdict
from typing import List
# from flair.data import Dictionary, Sentence
from functools import reduce

MICRO_AVG_METRIC = 'MICRO_AVG'

log = logging.getLogger(__name__)


class Metric(object):
    """
    Accumulates true/false positive/negative counts, optionally per class,
    and derives precision, recall, F-score and accuracy from them.

    Counts recorded without a ``class_name`` are stored under the key ``None``
    and represent the overall (class-independent) tally.
    """

    def __init__(self, name):
        self.name = name

        # per-class counters; the key None holds the class-independent counts
        self._tps = defaultdict(int)
        self._fps = defaultdict(int)
        self._tns = defaultdict(int)
        self._fns = defaultdict(int)

    def tp(self, class_name=None):
        """Record one true positive for ``class_name``."""
        self._tps[class_name] += 1

    def tn(self, class_name=None):
        """Record one true negative for ``class_name``."""
        self._tns[class_name] += 1

    def fp(self, class_name=None):
        """Record one false positive for ``class_name``."""
        self._fps[class_name] += 1

    def fn(self, class_name=None):
        """Record one false negative for ``class_name``."""
        self._fns[class_name] += 1

    def get_tp(self, class_name=None):
        """Return the number of true positives recorded for ``class_name``."""
        return self._tps[class_name]

    def get_tn(self, class_name=None):
        """Return the number of true negatives recorded for ``class_name``."""
        return self._tns[class_name]

    def get_fp(self, class_name=None):
        """Return the number of false positives recorded for ``class_name``."""
        return self._fps[class_name]

    def get_fn(self, class_name=None):
        """Return the number of false negatives recorded for ``class_name``."""
        return self._fns[class_name]

    def precision(self, class_name=None):
        """Return tp / (tp + fp) rounded to 4 decimals, or 0.0 if undefined."""
        if self._tps[class_name] + self._fps[class_name] > 0:
            return round(self._tps[class_name] / (self._tps[class_name] + self._fps[class_name]), 4)
        return 0.0

    def recall(self, class_name=None):
        """Return tp / (tp + fn) rounded to 4 decimals, or 0.0 if undefined."""
        if self._tps[class_name] + self._fns[class_name] > 0:
            return round(self._tps[class_name] / (self._tps[class_name] + self._fns[class_name]), 4)
        return 0.0

    def f_score(self, class_name=None):
        """Return the harmonic mean of precision and recall (4 decimals), or 0.0 if both are 0."""
        if self.precision(class_name) + self.recall(class_name) > 0:
            return round(2 * (self.precision(class_name) * self.recall(class_name))
                         / (self.precision(class_name) + self.recall(class_name)), 4)
        return 0.0

    def micro_avg_f_score(self):
        """
        Return the micro-averaged F-score: counts are summed over all named
        classes before precision and recall are computed.
        """
        # Use the read-only getters here: tp()/fp()/fn() *increment* the
        # counters and return None, which would both corrupt the counts and
        # make sum() fail.
        all_tps = sum([self.get_tp(class_name) for class_name in self.get_classes()])
        all_fps = sum([self.get_fp(class_name) for class_name in self.get_classes()])
        all_fns = sum([self.get_fn(class_name) for class_name in self.get_classes()])
        micro_precision = 0.0
        micro_recall = 0.0
        if all_tps + all_fps > 0:
            micro_precision = round(all_tps / (all_tps + all_fps), 4)
        if all_tps + all_fns > 0:
            micro_recall = round(all_tps / (all_tps + all_fns), 4)
        if micro_precision + micro_recall > 0:
            return round(2 * (micro_precision * micro_recall)
                         / (micro_precision + micro_recall), 4)
        return 0.0

    def macro_avg_f_score(self):
        """
        Return the macro-averaged F-score: per-class precision and recall are
        averaged over all named classes first. Returns 0.0 when no class has
        been recorded.
        """
        classes = self.get_classes()
        if not classes:
            # nothing recorded yet - avoid dividing by zero below
            return 0.0

        class_precisions = [self.precision(class_name) for class_name in classes]
        class_recalls = [self.recall(class_name) for class_name in classes]
        macro_precision = sum(class_precisions) / len(class_precisions)
        macro_recall = sum(class_recalls) / len(class_recalls)
        if macro_precision + macro_recall > 0:
            return round(2 * (macro_precision * macro_recall)
                         / (macro_precision + macro_recall), 4)
        return 0.0

    def accuracy(self, class_name=None):
        """Return (tp + tn) / (tp + tn + fp + fn) rounded to 4 decimals, or 0.0 if undefined."""
        if self._tps[class_name] + self._tns[class_name] + self._fps[class_name] + self._fns[class_name] > 0:
            return round(
                (self._tps[class_name] + self._tns[class_name])
                / (self._tps[class_name] + self._tns[class_name] + self._fps[class_name] + self._fns[class_name]),
                4)
        return 0.0

    def to_tsv(self):
        """Return the class-independent counts and scores as one tab-separated line."""
        return '{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(
            self.get_tp(), self.get_tn(), self.get_fp(), self.get_fn(), self.precision(), self.recall(), self.f_score(),
            self.accuracy())

    def print(self):
        """Log the string representation of this metric at INFO level."""
        log.info(self)

    @staticmethod
    def tsv_header(prefix=None):
        """Return the tab-separated column header matching ``to_tsv``, optionally prefixed."""
        if prefix:
            return '{0}_TP\t{0}_TN\t{0}_FP\t{0}_FN\t{0}_PRECISION\t{0}_RECALL\t{0}_F-SCORE\t{0}_ACCURACY'.format(
                prefix)

        return 'TP\tTN\tFP\tFN\tPRECISION\tRECALL\tF-SCORE\tACCURACY'

    @staticmethod
    def to_empty_tsv():
        """Return a placeholder line with the same number of columns as ``to_tsv``."""
        return '_\t_\t_\t_\t_\t_\t_\t_'

    def __str__(self):
        # first line: class-independent counts (labelled with the metric name),
        # followed by one line per named class
        all_classes = self.get_classes()
        all_classes = [None] + all_classes
        all_lines = [
            '{0:<10}\ttp: {1} - fp: {2} - fn: {3} - tn: {4} - precision: {5:.4f} - recall: {6:.4f} - accuracy: {7:.4f} - f1-score: {8:.4f}'.format(
                self.name if class_name is None else class_name,
                self._tps[class_name], self._fps[class_name], self._fns[class_name], self._tns[class_name],
                self.precision(class_name), self.recall(class_name), self.accuracy(class_name),
                self.f_score(class_name))
            for class_name in all_classes]
        return '\n'.join(all_lines)

    def get_classes(self) -> List:
        """Return the sorted list of all class names (excluding ``None``) seen in any counter."""
        all_classes = set(itertools.chain(*[list(keys) for keys
                                                 in [self._tps.keys(), self._fps.keys(), self._tns.keys(),
                                                     self._fns.keys()]]))
        all_classes = [class_name for class_name in all_classes if class_name is not None]
        all_classes.sort()
        return all_classes


class WeightExtractor(object):
    """
    Tracks a small, randomly chosen set of entries of every tensor in a
    state dict and appends their values to ``weights.txt`` on each call.
    """

    def __init__(self, directory: str, number_of_weights: int = 10):
        """
        :param directory: directory in which ``weights.txt`` is (re)created
        :param number_of_weights: maximum number of entries tracked per tensor
        """
        self.weights_file = init_output_file(directory, 'weights.txt')
        self.weights_dict = defaultdict(lambda: defaultdict(lambda: list()))
        self.number_of_weights = number_of_weights

    def extract_weights(self, state_dict, iteration):
        """
        Append one ``iteration<TAB>key<TAB>i<TAB>value`` line per tracked entry.

        :param state_dict: mapping from parameter name to tensor
        :param iteration: value written to the first column of every line
        """
        # open the output once per call instead of once per written value
        with open(self.weights_file, 'a') as f:
            for key in state_dict.keys():

                vec = state_dict[key]
                # Number of elements of the tensor. The initial value 1 makes
                # this work for 0-dim tensors too, whose size() is empty.
                element_count = reduce(lambda x, y: x * y, list(vec.size()), 1)
                weights_to_watch = min(self.number_of_weights, element_count)

                if key not in self.weights_dict:
                    self._init_weights_index(key, state_dict, weights_to_watch)

                for i in range(weights_to_watch):
                    # walk down the stored index path to reach a single element
                    vec = state_dict[key]
                    for index in self.weights_dict[key][i]:
                        vec = vec[index]

                    value = vec.item()

                    f.write('{}\t{}\t{}\t{}\n'.format(iteration, key, i, float(value)))

    def _init_weights_index(self, key, state_dict, weights_to_watch):
        """Pick ``weights_to_watch`` distinct random index paths into ``state_dict[key]``."""
        indices = {}

        i = 0
        while len(indices) < weights_to_watch:
            vec = state_dict[key]
            cur_indices = []

            # draw one random index per dimension, descending into the tensor
            for x in range(len(vec.size())):
                index = random.randint(0, len(vec) - 1)
                vec = vec[index]
                cur_indices.append(index)

            # only keep index paths that have not been chosen before
            if cur_indices not in list(indices.values()):
                indices[i] = cur_indices
                i += 1

        self.weights_dict[key] = indices


def clear_embeddings(sentences: List[Sentence], also_clear_word_embeddings=False):
    """
    Intended to clear the embeddings from all given sentences.

    The per-sentence clearing call is currently commented out, so this
    function does not modify the sentences.
    :param sentences: list of sentences
    :param also_clear_word_embeddings: accepted but currently without effect
    """
    # Disabled clearing logic, kept for reference:
    #     for sentence in sentences:
    #         sentence.clear_embeddings(also_clear_word_embeddings=False)
    return None


def init_output_file(base_path: str, file_name: str):
    """
    Creates an empty local file, truncating it if it already exists.
    Missing parent directories are created as well.
    :param base_path: the path to the directory
    :param file_name: the file name
    :return: the path of the created file
    """
    os.makedirs(base_path, exist_ok=True)
    target = os.path.join(base_path, file_name)

    # opening in write mode creates the file or empties an existing one
    with open(target, "w", encoding='utf-8'):
        pass

    return target


def convert_labels_to_one_hot(label_list: List[List[str]], label_dict: Dictionary) -> List[List[int]]:
    """
    Convert list of labels (strings) to a one hot list.

    Each entry of ``label_list`` becomes one row with a position for every
    item of ``label_dict`` (in ``get_items()`` order): 1 if that item is among
    the entry's labels, else 0.
    :param label_list: list of labels
    :param label_dict: label dictionary
    :return: converted label list
    """
    one_hot_rows = []
    for labels in label_list:
        row = []
        for known_label in label_dict.get_items():
            row.append(1 if known_label in labels else 0)
        one_hot_rows.append(row)
    return one_hot_rows


In [22]:
import os
import re
from abc import abstractmethod
from typing import List, Union, Dict
import logging
from pathlib import Path
import gensim
import numpy as np
import torch
from deprecated import deprecated
from bpemb import BPEmb
#from .nn import LockedDropout, WordDropout
#from .data import Dictionary, Token, Sentence
#from .file_utils import cached_path
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertModel, PRETRAINED_MODEL_ARCHIVE_MAP
#import flair
#from .nn import LockedDropout, WordDropout
#from .data import Dictionary, Token, Sentence

#from .file_utils import cached_path
log = logging.getLogger('flair')
import json
# dean add
# NOTE(review): this value is a URL, but it is used as a filesystem location:
# cached_path() joins it with a cache sub-directory via os.path.join, and
# BPEmbSerializable / BytePairEmbeddings append '/embeddings' to it to form a
# cache directory. Confirm that a URL (rather than a local directory) is
# really intended here.
CACHE_ROOT='https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/cache'

class Embeddings(torch.nn.Module):
    """Abstract base class for all embeddings. Every new type of embedding must implement these methods."""

    @property
    @abstractmethod
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""

    @property
    @abstractmethod
    def embedding_type(self) -> str:
        """Returns the embedding level; ``embed`` treats 'word-level' as per-token, anything else as per-sentence."""

    def embed(self, sentences: Union[Sentence, List[Sentence]]) -> List[Sentence]:
        """Add embeddings to all words in a list of sentences. If embeddings are already added, updates only if embeddings
        are non-static."""

        # if only one sentence is passed, convert to list of sentence
        if type(sentences) is Sentence:
            sentences = [sentences]

        # word-level embeddings are stored on tokens, everything else on the sentence itself
        if self.embedding_type == 'word-level':
            carriers = [token for sentence in sentences for token in sentence.tokens]
        else:
            carriers = list(sentences)

        everything_embedded: bool = all(self.name in carrier._embeddings.keys() for carrier in carriers)

        # recompute unless every carrier already holds this (static) embedding
        if not (everything_embedded and self.static_embeddings):
            self._add_embeddings_internal(sentences)

        return sentences

    @abstractmethod
    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """Private method for adding embeddings to all words in a list of sentences."""


class TokenEmbeddings(Embeddings):
    """Abstract base class for all token-level embeddings. Every new type of word embedding must implement these methods."""

    @property
    @abstractmethod
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""

    @property
    def embedding_type(self) -> str:
        """Token embeddings are always 'word-level'."""
        level = 'word-level'
        return level


class DocumentEmbeddings(Embeddings):
    """Abstract base class for all document-level embeddings. Every new type of document embedding must implement these methods."""

    @property
    @abstractmethod
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""

    @property
    def embedding_type(self) -> str:
        """Document embeddings are always 'sentence-level'."""
        level = 'sentence-level'
        return level


class StackedEmbeddings(TokenEmbeddings):
    """A stack of embeddings, used if you need to combine several different embedding types."""

    def __init__(self, embeddings: List[TokenEmbeddings], detach: bool = True):
        """
        The constructor takes a list of embeddings to be combined.

        :param embeddings: non-empty list of embeddings; the embedding type of
            the stack is taken from the first entry and the embedding length
            is the sum of the lengths of all entries
        :param detach: stored on the instance as ``self.detach``
        """
        super().__init__()

        self.embeddings = embeddings

        # IMPORTANT: add embeddings as torch modules
        for i, embedding in enumerate(embeddings):
            self.add_module('list_embedding_{}'.format(i), embedding)

        self.detach: bool = detach
        self.name: str = 'Stack'
        self.static_embeddings: bool = True

        self.__embedding_type: str = embeddings[0].embedding_type

        self.__embedding_length: int = 0
        for embedding in embeddings:
            self.__embedding_length += embedding.embedding_length

    def embed(self, sentences: Union[Sentence, List[Sentence]], static_embeddings: bool = True):
        """
        Let every stacked embedding embed the given sentences.

        :param sentences: a single sentence or a list of sentences
        :param static_embeddings: unused; kept for interface compatibility
        :return: the list of sentences, like ``Embeddings.embed``
        """
        # if only one sentence is passed, convert to list of sentence
        if type(sentences) is Sentence:
            sentences = [sentences]

        for embedding in self.embeddings:
            embedding.embed(sentences)

        # return the sentences so this override honours the base-class contract
        return sentences

    @property
    def embedding_type(self) -> str:
        return self.__embedding_type

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """Delegate to ``_add_embeddings_internal`` of every stacked embedding."""

        for embedding in self.embeddings:
            embedding._add_embeddings_internal(sentences)

        return sentences


class WordEmbeddings(TokenEmbeddings):
    """Standard static word embeddings, such as GloVe or FastText."""

    def __init__(self, embeddings: str):
        """
        Initializes classic word embeddings. Constructor downloads required files if not there.
        :param embeddings: one of: 'glove', 'extvec', 'crawl', 'news' or two-letter language code.
        If you want to use a custom embedding file, just pass the path to the embeddings as embeddings variable.
        """

        old_base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/'
        base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.3/'

        requested = embeddings.lower()

        # Each named model consists of two files; the '.vectors.npy' companion
        # is fetched first, then the main file whose cached path is loaded below.

        # GLOVE embeddings
        if requested in ('glove', 'en-glove'):
            cached_path(os.path.join(old_base_path, 'glove.gensim.vectors.npy'), cache_dir='embeddings')
            embeddings = cached_path(os.path.join(old_base_path, 'glove.gensim'), cache_dir='embeddings')

        # KOMNIOS embeddings
        elif requested in ('extvec', 'en-extvec'):
            cached_path(os.path.join(old_base_path, 'extvec.gensim.vectors.npy'), cache_dir='embeddings')
            embeddings = cached_path(os.path.join(old_base_path, 'extvec.gensim'), cache_dir='embeddings')

        # FT-CRAWL embeddings
        elif requested in ('crawl', 'en-crawl'):
            cached_path(os.path.join(base_path, 'en-fasttext-crawl-300d-1M.vectors.npy'), cache_dir='embeddings')
            embeddings = cached_path(os.path.join(base_path, 'en-fasttext-crawl-300d-1M'), cache_dir='embeddings')

        # FT-NEWS embeddings
        elif requested in ('news', 'en-news', 'en'):
            cached_path(os.path.join(base_path, 'en-fasttext-news-300d-1M.vectors.npy'), cache_dir='embeddings')
            embeddings = cached_path(os.path.join(base_path, 'en-fasttext-news-300d-1M'), cache_dir='embeddings')

        # other language fasttext embeddings
        elif len(requested) == 2 and requested != 'en':
            cached_path(os.path.join(base_path, '{}-fasttext-300d-1M.vectors.npy'.format(embeddings)),
                        cache_dir='embeddings')
            embeddings = cached_path(os.path.join(base_path, '{}-fasttext-300d-1M'.format(embeddings)),
                                     cache_dir='embeddings')

        # anything else must be an existing local path
        elif not os.path.exists(embeddings):
            raise ValueError(f'The given embeddings "{embeddings}" is not available or is not a valid path.')

        self.name = embeddings
        self.static_embeddings = True

        # a few specific files are loaded via load_word2vec_format, everything else via load
        resolved = embeddings.lower()
        use_word2vec_format = (
            resolved in ['mimic3-i2b22010-vector-50', 'mimic3-i2b22014-vector-50']
            or resolved == 'flair/elmo_embeddings/mimic_elmo_embedding'
            or embeddings == "/lrlhps/users/c272987/embeddings/PubMed-w2v.txt"
        )
        if use_word2vec_format:
            self.precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format(embeddings)
        else:
            self.precomputed_word_embeddings = gensim.models.KeyedVectors.load(embeddings)

        self.__embedding_length: int = self.precomputed_word_embeddings.vector_size
        super().__init__()

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    @staticmethod
    def _candidate_keys(text: str):
        """Yield the lookup keys for a token text in order of preference."""
        yield text
        lowered = text.lower()
        yield lowered
        # digits masked with '#', then with '0'
        yield re.sub(r'\d', '#', lowered)
        yield re.sub(r'\d', '0', lowered)

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """Attach the precomputed vector (or a zero vector if unknown) to every token."""

        vectors = self.precomputed_word_embeddings

        for sentence in sentences:
            for token in sentence.tokens:

                # first matching key wins; fall back to zeros when none matches
                for candidate in self._candidate_keys(token.text):
                    if candidate in vectors:
                        word_embedding = vectors[candidate]
                        break
                else:
                    word_embedding = np.zeros(self.embedding_length, dtype='float')

                token.set_embedding(self.name, torch.FloatTensor(word_embedding))

        return sentences

class BioBertEmbeddings(TokenEmbeddings):
    """Token embeddings looked up from a file of precomputed BioBERT features (JSON lines)."""

    def __init__(self, biobert_file: str):
        """
        :param biobert_file: path to a file with one JSON object per line, each
            carrying a ``features`` list of ``{"token", "layers"}`` entries
        """
        self.sentenceEmbedding = self.readBiobertEmbedding(biobert_file)
        self.name = "biobert"
        self.static_embeddings = True
        self.__embedding_length:int = 768
        super().__init__()

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def readBiobertEmbedding(self, biobert_file):
        """
        Build a dict mapping the concatenated token texts of a sentence to the
        list of its per-token vectors (the ``values`` of the layer with index -1).
        A sentence is closed and stored whenever a '[SEP]' token is read.
        """
        embeddings_by_sentence = {}
        with open(biobert_file) as handle:
            sentence_text = ""
            sentence_vectors = []
            for raw_line in handle:
                record = json.loads(raw_line)
                # record looks like: {"linex_index": 0, "features": [{"token": "[CLS]", "layers": [{"index": -1, "values": [...
                for feature in record['features']:
                    piece = feature["token"]

                    if piece == '[SEP]':
                        # end of sentence: store what was collected and start over
                        embeddings_by_sentence[sentence_text.strip()] = sentence_vectors
                        sentence_text = ""
                        sentence_vectors = []
                        continue

                    if piece == '[CLS]':
                        continue

                    sentence_text += piece
                    for layer in feature["layers"]:
                        if layer["index"] == -1:
                            sentence_vectors.append(layer["values"])
        return embeddings_by_sentence

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """
        Attach the stored vectors to the tokens of each sentence. When no
        stored entry with exactly one vector per token exists, every token
        gets a vector of 768 ones instead.
        """
        for sentence in sentences:
            # stored keys are token texts without spaces
            lookup_key = sentence.to_plain_string().replace(" ", "")
            sent_embedding = self.sentenceEmbedding.get(lookup_key, [])
            aligned = len(sent_embedding) == len(sentence.tokens)

            for token_idx, token in enumerate(sentence.tokens):
                values = sent_embedding[token_idx] if aligned else [1] * 768
                token.set_embedding(self.name, torch.FloatTensor(values))
        return sentences

class SemanticEmbeddings(TokenEmbeddings):
    """Static embeddings that are looked up by a token's semantic label (``token.sem``) instead of its text."""

    def __init__(self, embeddings: str, semantic_file: str):
        """
        Loads precomputed embeddings and the sentence-to-semantic-label mapping.
        :param embeddings: path of the embedding file; it is also used as the name
            of this embedding. A few specific names are read with
            ``load_word2vec_format``, everything else with ``KeyedVectors.load``.
        :param semantic_file: path of a file with one ``word<TAB>label`` pair per
            line and a blank line between sentences
        """

        self.name = embeddings
        self.static_embeddings = True
        self.semantic_file = semantic_file
        if embeddings.lower() in ['mimic3-i2b22014-vector-dic-50', 'mimic3-i2b22010-vector-dic-50', 'flair/elmo_embeddings/mimic_semantic_elmo_embedding']:
            self.precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format(embeddings)
        else:
            self.precomputed_word_embeddings = gensim.models.KeyedVectors.load(embeddings)
        self.semantic_dict = {}
        self.initSemanticDict()

        self.__embedding_length: int = self.precomputed_word_embeddings.vector_size
        super().__init__()

    def initSemanticDict(self):
        """
        Fill ``self.semantic_dict``: key is the space-joined words of a
        sentence, value the list of their semantic labels.
        """
        words = []
        sems = []
        # use a context manager so the file handle is closed after reading
        with open(self.semantic_file) as semantic_lines:
            for line in semantic_lines:
                if line.strip() == "":
                    # blank line: the current sentence is complete
                    self.semantic_dict[" ".join(words)] = sems
                    words = []
                    sems = []
                else:
                    items = line.strip().split("\t")
                    words.append(items[0])
                    sems.append(items[1])

        # store the last sentence (the file need not end with a blank line)
        self.semantic_dict[" ".join(words)] = sems

    def getSemanticSentence(self, tokens: List[Token]):
        """
        Set the semantic label on each token from ``self.semantic_dict``.
        If the sentence is unknown, a message is printed and the tokens are
        returned unchanged.
        """
        key = " ".join([token.text for token in tokens])
        sems = []
        if key in self.semantic_dict:
            sems = self.semantic_dict[key]
        else:
            print (key + " was not found")
        for i in range(len(sems)):
            tokens[i].set_sem(sems[i])
        return tokens

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """Attach to every token the vector of its semantic label (zeros if the label is unknown)."""

        for i, sentence in enumerate(sentences):
            sem_tokens = self.getSemanticSentence(sentence.tokens)
            for token, token_idx in zip(sem_tokens, range(len(sem_tokens))):
                token: Token = token

                # Every branch must test the same key it reads afterwards: the
                # semantic label, not the token text.
                if token.sem in self.precomputed_word_embeddings:
                    word_embedding = self.precomputed_word_embeddings[token.sem]
                elif token.sem.lower() in self.precomputed_word_embeddings:
                    word_embedding = self.precomputed_word_embeddings[token.sem.lower()]
                elif re.sub(r'\d', '#', token.sem.lower()) in self.precomputed_word_embeddings:
                    word_embedding = self.precomputed_word_embeddings[re.sub(r'\d', '#', token.sem.lower())]
                elif re.sub(r'\d', '0', token.sem.lower()) in self.precomputed_word_embeddings:
                    word_embedding = self.precomputed_word_embeddings[re.sub(r'\d', '0', token.sem.lower())]
                else:
                    word_embedding = np.zeros(self.embedding_length, dtype='float')

                word_embedding = torch.FloatTensor(word_embedding)

                token.set_embedding(self.name, word_embedding)

        return sentences

class MemoryEmbeddings(TokenEmbeddings):
    """
    Non-static embeddings that remember, per token text, how often each tag
    was seen while in training mode, and expose the L2-normalized counts as
    the token's embedding.
    """

    def __init__(self, tag_type: str, tag_dictionary: Dictionary):
        """
        :param tag_type: name of the tag read from each token during training
        :param tag_dictionary: dictionary of tags; its size is the embedding length
        """

        self.name = "memory"
        self.static_embeddings = False
        self.tag_type: str = tag_type
        self.tag_dictionary: Dictionary = tag_dictionary
        self.__embedding_length: int = len(tag_dictionary)

        # token text -> list of per-tag counts
        self.memory: Dict[str, List] = {}

        super().__init__()

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def train(self, mode=True):
        """Switch mode as usual; entering training mode additionally resets the memory."""
        super().train(mode=mode)
        if mode:
            self.memory: Dict[str, List] = {}

    def update_embedding(self, text: str, tag: str):
        """Increment the count of ``tag`` for ``text``; ``text`` must already be in the memory."""
        self.memory[text][self.tag_dictionary.get_idx_for_item(tag)] += 1

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """Attach the normalized tag counts to every token and, in training mode, update them."""
        # imported once per call rather than once per token
        import torch.nn.functional as F

        for i, sentence in enumerate(sentences):

            for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
                token: Token = token

                if token.text not in self.memory:
                    self.memory[token.text] = [0] * self.__embedding_length

                word_embedding = torch.FloatTensor(self.memory[token.text])
                word_embedding = F.normalize(word_embedding, p=2, dim=0)

                token.set_embedding(self.name, word_embedding)

                # add label if in training mode (after the embedding was set,
                # so the current token's own tag is not part of its embedding)
                if self.training:
                    self.update_embedding(token.text, token.get_tag(self.tag_type).value)

        return sentences


class BPEmbSerializable(BPEmb):
    """``BPEmb`` variant that pickles the sentence piece model as bytes instead of as a file path."""

    def __getstate__(self):
        state = self.__dict__.copy()
        # save the sentence piece model as binary file (not as path which may change)
        with open(self.model_file, mode='rb') as model_binary:
            state['spm_model_binary'] = model_binary.read()
        state['spm'] = None
        return state

    def __setstate__(self, state):
        from bpemb.util import sentencepiece_load
        model_file = self.model_tpl.format(lang=state['lang'], vs=state['vs'])
        self.__dict__ = state

        # write out the binary sentence piece model into the expected directory
#         self.cache_dir: Path = Path(flair.file_utils.CACHE_ROOT) / 'embeddings'
        # dean add
        # cache_dir must be a Path: the '/' joins below do not work on plain strings
        self.cache_dir: Path = Path(CACHE_ROOT) / 'embeddings'
        if 'spm_model_binary' in self.__dict__:
            # the model was saved as binary: make sure the language directory
            # exists and write the model to the appropriate path
            os.makedirs(self.cache_dir / state['lang'], exist_ok=True)
            self.model_file = self.cache_dir / model_file
            with open(self.model_file, 'wb') as out:
                out.write(self.__dict__['spm_model_binary'])
        else:
            # otherwise, use normal process and potentially trigger another download
            self.model_file = self._load_file(model_file)

        # once the model is there, load it with sentence piece
        state['spm'] = sentencepiece_load(self.model_file)

class BytePairEmbeddings(TokenEmbeddings):
    """Token embeddings built from byte-pair (subword) embeddings provided by ``BPEmbSerializable``."""

    def __init__(self, language: str, dim: int = 50, syllables: int = 100000, cache_dir = CACHE_ROOT+ '/embeddings'):
#         def __init__(self, language: str, dim: int = 50, syllables: int = 100000, cache_dir = Path(flair.file_utils.CACHE_ROOT) / 'embeddings'):
        """
        Initializes BP embeddings. Constructor downloads required files if not there.
        """

        self.name: str = f'bpe-{language}-{syllables}-{dim}'
        self.static_embeddings = True
        self.embedder = BPEmbSerializable(lang=language, vs=syllables, dim=dim, cache_dir=cache_dir)

        # first and last subword vectors are concatenated, hence twice the size
        self.__embedding_length: int = self.embedder.emb.vector_size * 2
        super().__init__()

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """Embed every token as the concatenation of its first and last subword vector."""

        for sentence in sentences:
            for token in sentence.tokens:

                # use the token text unless a (non-None) 'field' attribute names a tag to embed instead
                if self.__dict__.get('field') is None:
                    word = token.text
                else:
                    word = token.get_tag(self.field).value

                subword_vectors = self.embedder.embed(word.lower())
                combined = np.concatenate((subword_vectors[0], subword_vectors[-1]))
                token.set_embedding(self.name, torch.tensor(combined, dtype=torch.float))

        return sentences

    def __str__(self):
        return self.name



class ELMoEmbeddings(TokenEmbeddings):
    """Contextual word embeddings using word-level LM, as proposed in Peters et al., 2018."""

    # model name -> (options file, weight file); names not listed here use the allennlp defaults
    _MODEL_FILES = {
        'small': (
            'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json',
            'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5',
        ),
        'medium': (
            'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json',
            'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5',
        ),
        'pt': (
            'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_options.json',
            'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_weights.hdf5',
        ),
        'pubmed': (
            'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json',
            'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5',
        ),
    }
    _MODEL_FILES['portuguese'] = _MODEL_FILES['pt']

    def __init__(self, model: str = 'original'):
        """
        Loads an ELMo embedder.
        :param model: 'small', 'medium', 'pt' / 'portuguese' or 'pubmed' select the corresponding files;
                any other value (including the default 'original') uses the allennlp default files
        :raises ImportError: if the "allennlp" library cannot be imported
        """
        super().__init__()

        try:
            import allennlp.commands.elmo
        except ImportError:
            log.warning('-' * 100)
            log.warning('ATTENTION! The library "allennlp" is not installed!')
            log.warning('To use ELMoEmbeddings, please first install with "pip install allennlp"')
            log.warning('-' * 100)
            # nothing below can work without allennlp; fail with the real cause instead of an
            # UnboundLocalError on the first use of the module
            raise

        self.name = 'elmo-' + model
        self.static_embeddings = True

        # the default model for ELMo is the 'original' model, which is very large;
        # alternatively, a small, medium, portuguese or pubmed model can be selected by name
        options_file, weight_file = self._MODEL_FILES.get(
            model,
            (allennlp.commands.elmo.DEFAULT_OPTIONS_FILE, allennlp.commands.elmo.DEFAULT_WEIGHT_FILE))

        # put on Cuda if available
        cuda_device = 0 if str(device) != 'cpu' else -1
        self.ee = allennlp.commands.elmo.ElmoEmbedder(options_file=options_file,
                                                      weight_file=weight_file,
                                                      cuda_device=cuda_device)

        # embed a dummy sentence to determine embedding_length
        dummy_sentence: Sentence = Sentence()
        dummy_sentence.add_token(Token('hello'))
        embedded_dummy = self.embed(dummy_sentence)
        self.__embedding_length: int = len(embedded_dummy[0].get_token(1).get_embedding())

    @property
    def embedding_length(self) -> int:
        """Length of the embedding vector, measured once on a dummy sentence in the constructor."""
        return self.__embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """
        Embeds all sentences in one batch and sets, for every token, the concatenation of
        the vectors at indices 0, 1 and 2 of the first axis of the embedder output.
        :param sentences: sentences whose tokens receive an embedding under `self.name`
        :return: the same list of sentences
        """
        sentence_words: List[List[str]] = [[token.text for token in sentence] for sentence in sentences]

        embeddings = self.ee.embed_batch(sentence_words)

        for sentence, sentence_embeddings in zip(sentences, embeddings):

            for token_idx, token in enumerate(sentence.tokens):
                token: Token = token

                word_embedding = torch.cat([
                    torch.FloatTensor(sentence_embeddings[0, token_idx, :]),
                    torch.FloatTensor(sentence_embeddings[1, token_idx, :]),
                    torch.FloatTensor(sentence_embeddings[2, token_idx, :])
                ], 0)

                token.set_embedding(self.name, word_embedding)

        return sentences

    def extra_repr(self):
        return 'model={}'.format(self.name)

    def __str__(self):
        return self.name


class ELMoTransformerEmbeddings(TokenEmbeddings):
    """Contextual word embeddings using word-level Transformer-based LM, as proposed in Peters et al., 2018."""

    def __init__(self, model_file: str):
        """
        Loads a bidirectional language model token embedder from an archive.
        :param model_file: archive file handed to the allennlp token embedder
        :raises ImportError: if the required "allennlp" modules cannot be imported
        """
        super().__init__()

        try:
            from allennlp.modules.token_embedders.bidirectional_language_model_token_embedder import \
                BidirectionalLanguageModelTokenEmbedder
            from allennlp.data.token_indexers.elmo_indexer import ELMoTokenCharactersIndexer
        except ImportError:
            log.warning('-' * 100)
            log.warning('ATTENTION! The library "allennlp" is not installed!')
            log.warning(
                'To use ELMoTransformerEmbeddings, please first install a recent version from https://github.com/allenai/allennlp')
            log.warning('-' * 100)
            # nothing below can work without these classes; fail with the real cause instead of
            # an UnboundLocalError on their first use
            raise

        self.name = 'elmo-transformer'
        self.static_embeddings = True
        self.lm_embedder = BidirectionalLanguageModelTokenEmbedder(
            archive_file=model_file,
            dropout=0.2,
            bos_eos_tokens=("<S>", "</S>"),
            remove_bos_eos=True,
            requires_grad=False
        )
        # NOTE(review): other classes in this file read a bare `device` instead of `flair.device`;
        # confirm that `flair` is importable in this notebook
        self.lm_embedder = self.lm_embedder.to(device=flair.device)
        self.vocab = self.lm_embedder._lm.vocab
        self.indexer = ELMoTokenCharactersIndexer()

        # embed a dummy sentence to determine embedding_length
        dummy_sentence: Sentence = Sentence()
        dummy_sentence.add_token(Token('hello'))
        embedded_dummy = self.embed(dummy_sentence)
        self.__embedding_length: int = len(embedded_dummy[0].get_token(1).get_embedding())

    @property
    def embedding_length(self) -> int:
        """Length of the embedding vector, measured once on a dummy sentence in the constructor."""
        return self.__embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """
        Embeds the sentences one at a time and sets one vector per token.
        :param sentences: sentences whose tokens receive an embedding under `self.name`
        :return: the same list of sentences
        """
        # Avoid conflicts with flair's Token class
        import allennlp.data.tokenizers.token as allen_nlp_token

        indexer = self.indexer
        vocab = self.vocab

        for sentence in sentences:
            character_indices = indexer.tokens_to_indices([allen_nlp_token.Token(token.text) for token in sentence],
                                                          vocab, "elmo")["elmo"]

            # batch of one sentence
            indices_tensor = torch.LongTensor([character_indices])
            indices_tensor = indices_tensor.to(device=flair.device)
            embeddings = self.lm_embedder(indices_tensor)[0].detach().cpu().numpy()

            for token_idx, token in enumerate(sentence.tokens):
                token: Token = token
                word_embedding = torch.FloatTensor(embeddings[token_idx])
                token.set_embedding(self.name, word_embedding)

        return sentences

    def extra_repr(self):
        return 'model={}'.format(self.name)

    def __str__(self):
        return self.name

class CharacterEmbeddings(TokenEmbeddings):
    """Character embeddings of words, as proposed in Lample et al., 2016."""

    def __init__(self, path_to_char_dict: str = None):
        """
        Uses the default character dictionary if none provided.
        :param path_to_char_dict: file to load the character dictionary from; if None, the
                'common-chars' dictionary is loaded
        """

        super(CharacterEmbeddings, self).__init__()
        self.name = 'Char'
        # the character embedding and the LSTM are trainable, so embeddings are not static
        self.static_embeddings = False

        # use list of common characters if none provided
        if path_to_char_dict is None:
            self.char_dictionary: Dictionary = Dictionary.load('common-chars')
        else:
            self.char_dictionary: Dictionary = Dictionary.load_from_file(path_to_char_dict)

        self.char_embedding_dim: int = 25
        self.hidden_size_char: int = 25
        self.char_embedding = torch.nn.Embedding(len(self.char_dictionary.item2idx), self.char_embedding_dim)
        self.char_rnn = torch.nn.LSTM(self.char_embedding_dim, self.hidden_size_char, num_layers=1,
                                      bidirectional=True)

        # the bidirectional LSTM emits both directions, i.e. twice its hidden size
        self.__embedding_length = self.hidden_size_char * 2

    @property
    def embedding_length(self) -> int:
        """Length of the embedding vector set on each token."""
        return self.__embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]):
        """
        Runs the character LSTM over every token of every sentence and sets the result as
        the token embedding under `self.name`. Sentences without tokens are skipped.
        :param sentences: sentences whose tokens receive an embedding
        """
        # build tensors on the device the module's weights live on, so inputs and weights
        # always agree regardless of whether CUDA happens to be available
        target_device = self.char_embedding.weight.device

        for sentence in sentences:

            # nothing to embed, and max() below would fail on an empty list
            if len(sentence.tokens) == 0:
                continue

            # translate words in sentence into ints using dictionary
            tokens_char_indices = [
                [self.char_dictionary.get_idx_for_item(char) for char in token.text]
                for token in sentence.tokens
            ]

            # sort token POSITIONS by token length (longest first), for batching and masking.
            # Tracking positions rather than comparing contents keeps the mapping back to the
            # original order correct when the same word occurs more than once.
            sorted_positions = sorted(range(len(tokens_char_indices)),
                                      key=lambda position: len(tokens_char_indices[position]),
                                      reverse=True)
            tokens_sorted_by_length = [tokens_char_indices[position] for position in sorted_positions]

            # NOTE: pack_padded_sequence rejects zero lengths, so token texts must be non-empty
            chars2_length = [len(c) for c in tokens_sorted_by_length]
            longest_token_in_sentence = max(chars2_length)
            tokens_mask = np.zeros((len(tokens_sorted_by_length), longest_token_in_sentence), dtype='int')
            for i, c in enumerate(tokens_sorted_by_length):
                tokens_mask[i, :chars2_length[i]] = c

            # chars for rnn processing
            chars = torch.LongTensor(tokens_mask).to(target_device)

            character_embeddings = self.char_embedding(chars).transpose(0, 1)

            packed = torch.nn.utils.rnn.pack_padded_sequence(character_embeddings, chars2_length)

            lstm_out, self.hidden = self.char_rnn(packed)

            outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out)
            outputs = outputs.transpose(0, 1)

            # for each (sorted) token take the LSTM output at its last character
            chars_embeds_temp = torch.zeros((outputs.size(0), outputs.size(2)), device=target_device)
            for i, index in enumerate(output_lengths):
                chars_embeds_temp[i] = outputs[i, index - 1]

            # undo the sorting: row `sorted_row` belongs to the token at `original_position`
            character_embeddings = chars_embeds_temp.clone()
            for sorted_row, original_position in enumerate(sorted_positions):
                character_embeddings[original_position] = chars_embeds_temp[sorted_row]

            for token_number, token in enumerate(sentence.tokens):
                token.set_embedding(self.name, character_embeddings[token_number])


class FlairEmbeddings(TokenEmbeddings):
    """Contextual string embeddings of words, as proposed in Akbik et al., 2018."""

    # common prefix of all downloadable language models
    _RESOURCE_ROOT = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources'

    # (accepted lower-case model names, file location below _RESOURCE_ROOT)
    _PRETRAINED_MODELS = (
        # multilingual (English, German, French, Italian, Dutch, Polish)
        (('multi-forward',), 'embeddings-v0.4/lm-multi-forward-v0.1.pt'),
        (('multi-backward',), 'embeddings-v0.4/lm-multi-backward-v0.1.pt'),
        (('multi-forward-fast',), 'embeddings-v0.4/lm-multi-forward-fast-v0.1.pt'),
        (('multi-backward-fast',), 'embeddings-v0.4/lm-multi-backward-fast-v0.1.pt'),
        # English news
        (('news-forward',), 'embeddings/lm-news-english-forward-v0.2rc.pt'),
        (('news-backward',), 'embeddings/lm-news-english-backward-v0.2rc.pt'),
        (('news-forward-fast',), 'embeddings/lm-news-english-forward-1024-v0.2rc.pt'),
        (('news-backward-fast',), 'embeddings/lm-news-english-backward-1024-v0.2rc.pt'),
        # English mix
        (('mix-forward',), 'embeddings/lm-mix-english-forward-v0.2rc.pt'),
        (('mix-backward',), 'embeddings/lm-mix-english-backward-v0.2rc.pt'),
        # German mix
        (('german-forward', 'de-forward'), 'embeddings/lm-mix-german-forward-v0.2rc.pt'),
        (('german-backward', 'de-backward'), 'embeddings/lm-mix-german-backward-v0.2rc.pt'),
        # common crawl Polish
        (('polish-forward', 'pl-forward'), 'embeddings/lm-polish-forward-v0.2.pt'),
        (('polish-backward', 'pl-backward'), 'embeddings/lm-polish-backward-v0.2.pt'),
        # Slovenian
        (('slovenian-forward', 'sl-forward'), 'embeddings-v0.3/lm-sl-large-forward-v0.1.pt'),
        (('slovenian-backward', 'sl-backward'), 'embeddings-v0.3/lm-sl-large-backward-v0.1.pt'),
        # Bulgarian
        (('bulgarian-forward', 'bg-forward'), 'embeddings-v0.3/lm-bg-small-forward-v0.1.pt'),
        (('bulgarian-backward', 'bg-backward'), 'embeddings-v0.3/lm-bg-small-backward-v0.1.pt'),
        # Dutch
        (('dutch-forward', 'nl-forward'), 'embeddings-v0.4/lm-nl-large-forward-v0.1.pt'),
        (('dutch-backward', 'nl-backward'), 'embeddings-v0.4/lm-nl-large-backward-v0.1.pt'),
        # Swedish
        (('swedish-forward', 'sv-forward'), 'embeddings-v0.4/lm-sv-large-forward-v0.1.pt'),
        (('swedish-backward', 'sv-backward'), 'embeddings-v0.4/lm-sv-large-backward-v0.1.pt'),
        # French
        (('french-forward', 'fr-forward'), 'embeddings/lm-fr-charlm-forward.pt'),
        (('french-backward', 'fr-backward'), 'embeddings/lm-fr-charlm-backward.pt'),
        # Czech
        (('czech-forward', 'cs-forward'), 'embeddings-v0.4/lm-cs-large-forward-v0.1.pt'),
        (('czech-backward', 'cs-backward'), 'embeddings-v0.4/lm-cs-large-backward-v0.1.pt'),
        # Portuguese
        (('portuguese-forward', 'pt-forward'), 'embeddings-v0.4/lm-pt-forward.pt'),
        (('portuguese-backward', 'pt-backward'), 'embeddings-v0.4/lm-pt-backward.pt'),
        # Basque
        (('basque-forward', 'eu-forward'), 'embeddings-v0.4/lm-eu-large-forward-v0.1.pt'),
        (('basque-backward', 'eu-backward'), 'embeddings-v0.4/lm-eu-large-backward-v0.1.pt'),
        # Spanish fast
        (('spanish-forward-fast', 'es-forward-fast'),
         'embeddings-v0.4/language_model_es_forward/lm-es-forward-fast.pt'),
        (('spanish-backward-fast', 'es-backward-fast'),
         'embeddings-v0.4/language_model_es_backward/lm-es-backward-fast.pt'),
        # Spanish
        (('spanish-forward', 'es-forward'), 'embeddings-v0.4/language_model_es_forward_long/lm-es-forward.pt'),
        (('spanish-backward', 'es-backward'), 'embeddings-v0.4/language_model_es_backward_long/lm-es-backward.pt'),
        # Pubmed
        (('pubmed-forward',), 'embeddings-v0.4.1/pubmed-2015-fw-lm.pt'),
        (('pubmed-backward',), 'embeddings-v0.4.1/pubmed-2015-bw-lm.pt'),
    )

    # flat lookup: every accepted model name -> file location below _RESOURCE_ROOT
    _MODEL_NAME_TO_FILE = {name: location for names, location in _PRETRAINED_MODELS for name in names}

    def __init__(self, model: str, use_cache: bool = False, cache_directory: Path = None, chars_per_chunk: int = 512):
        """
        initializes contextual string embeddings using a character-level language model.
        :param model: model string, one of 'news-forward', 'news-backward', 'news-forward-fast', 'news-backward-fast',
                'mix-forward', 'mix-backward', 'german-forward', 'german-backward', 'polish-backward', 'polish-forward'
                (see _PRETRAINED_MODELS for the full list; matching is case-insensitive), or a path to an
                existing language model file.
        :param use_cache: if set to False, will not write embeddings to file for later retrieval. this saves disk space but will
                not allow re-use of once computed embeddings that do not fit into memory
        :param cache_directory: if cache_directory is not set, the cache file is named after the model path with
                '-tmp-cache.sqllite' appended. otherwise the cache is written to the provided directory.
        :param  chars_per_chunk: max number of chars per rnn pass to control speed/memory tradeoff. Higher means faster but requires
                more memory. Lower means slower but less memory.
        :raises ValueError: if `model` is neither a known model name nor an existing path
        """
        super().__init__()

        cache_dir = Path('embeddings')

        model_location = self._MODEL_NAME_TO_FILE.get(model.lower())
        if model_location is not None:
            base_path = f'{self._RESOURCE_ROOT}/{model_location}'
            model = cached_path(base_path, cache_dir=cache_dir)
        elif not Path(model).exists():
            raise ValueError(f'The given model "{model}" is not available or is not a valid path.')

        self.name = str(model)
        self.static_embeddings = True

        self.lm = LanguageModel.load_language_model(model)

        self.is_forward_lm: bool = self.lm.is_forward_lm
        self.chars_per_chunk: int = chars_per_chunk

        # initialize cache if use_cache set
        self.cache = None
        if use_cache:
            from sqlitedict import SqliteDict
            if cache_directory:
                # self.name is a full model path: joining it onto the directory as-is would
                # discard the directory (absolute path) or point into a missing sub-directory
                # (relative path), so only the model's file name is used here
                cache_path = Path(cache_directory) / f'{Path(self.name).name}-tmp-cache.sqllite'
            else:
                cache_path = Path(f'{self.name}-tmp-cache.sqllite')
            self.cache = SqliteDict(str(cache_path), autocommit=True)

        # embed a dummy sentence to determine embedding_length
        dummy_sentence: Sentence = Sentence()
        dummy_sentence.add_token(Token('hello'))
        embedded_dummy = self.embed(dummy_sentence)
        self.__embedding_length: int = len(embedded_dummy[0].get_token(1).get_embedding())

        # set to eval mode
        self.eval()

    def train(self, mode=True):
        # intentionally a no-op: train()/eval() calls do not change this module's mode
        pass

    def __getstate__(self):
        # Copy the object's state from self.__dict__ which contains
        # all our instance attributes. Always use the dict.copy()
        # method to avoid modifying the original state.
        state = self.__dict__.copy()
        # Remove the unpicklable entries.
        state['cache'] = None
        return state

    @property
    def embedding_length(self) -> int:
        """Length of the embedding vector, measured once on a dummy sentence in the constructor."""
        return self.__embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """
        Sets one language model hidden state per token as its embedding under `self.name`.
        If a cache is configured, embeddings are read from it when present for all sentences
        and written to it after being computed.
        :param sentences: sentences to embed
        :return: the same list of sentences
        """
        # nothing to do, and max() below would fail on an empty batch
        if not sentences:
            return sentences

        # make compatible with serialized models
        if 'chars_per_chunk' not in self.__dict__:
            self.chars_per_chunk = 512

        cache_enabled = 'cache' in self.__dict__ and self.cache is not None

        # if cache is used, try setting embeddings from cache first
        if cache_enabled:

            # try populating embeddings from cache
            all_embeddings_retrieved_from_cache: bool = True
            for sentence in sentences:
                key = sentence.to_tokenized_string()
                embeddings = self.cache.get(key)

                if not embeddings:
                    # one miss is enough to recompute the whole batch below
                    all_embeddings_retrieved_from_cache = False
                    break
                else:
                    for token, embedding in zip(sentence, embeddings):
                        token.set_embedding(self.name, torch.FloatTensor(embedding))

            if all_embeddings_retrieved_from_cache:
                return sentences

        with torch.no_grad():

            # if this is not possible, use LM to generate embedding. First, get text sentences
            text_sentences = [sentence.to_tokenized_string() for sentence in sentences]

            longest_character_sequence_in_batch: int = len(max(text_sentences, key=len))

            start_marker = '\n'
            end_marker = ' '
            extra_offset = len(start_marker)

            # pad strings with whitespaces to longest sentence; a backward LM reads the text reversed
            sentences_padded: List[str] = []
            for sentence_text in text_sentences:
                pad_by = longest_character_sequence_in_batch - len(sentence_text)
                text_for_lm = sentence_text if self.is_forward_lm else sentence_text[::-1]
                sentences_padded.append('{}{}{}{}'.format(start_marker, text_for_lm, end_marker, pad_by * ' '))

            # get hidden states from language model
            all_hidden_states_in_lm = self.lm.get_representation(sentences_padded, self.chars_per_chunk)

            # take first or last hidden states from language model as word representation
            for i, sentence in enumerate(sentences):
                sentence_text = sentence.to_tokenized_string()

                # forward: index right after the current token in the padded string
                offset_forward: int = extra_offset
                # backward: index right after the current token's reversed characters in the
                # reversed padded string; starts at the end of the text and moves towards the front
                offset_backward: int = len(sentence_text) + extra_offset

                for token in sentence.tokens:
                    token: Token = token

                    offset_forward += len(token.text)

                    if self.is_forward_lm:
                        offset = offset_forward
                    else:
                        offset = offset_backward

                    embedding = all_hidden_states_in_lm[offset, i, :]

                    # skip the whitespace that separates tokens in the tokenized string
                    offset_forward += 1
                    offset_backward -= 1

                    offset_backward -= len(token.text)

                    token.set_embedding(self.name, embedding.clone().detach())

            # drop the reference to the hidden states once the per-token copies are taken
            all_hidden_states_in_lm = None

        if cache_enabled:
            for sentence in sentences:
                self.cache[sentence.to_tokenized_string()] = [token._embeddings[self.name].tolist() for token in
                                                              sentence]

        return sentences

    def __str__(self):
        return self.name



class PooledFlairEmbeddings(TokenEmbeddings):
    """Wraps FlairEmbeddings and adds, per token, an embedding pooled over all occurrences
    of the same word text seen so far (the "memory")."""

    # how a new occurrence of a word is combined with the embedding stored in memory
    _AGGREGATE_OPS = {
        'mean': torch.add,   # running sum; divided by the occurrence count when read
        'fade': torch.add,   # sum with the stored value, halved right after
        'max': torch.max,
        'min': torch.min,
    }

    def __init__(self,
                 contextual_embeddings: Union[str, FlairEmbeddings],
                 pooling: str = 'fade',
                 only_capitalized: bool = False,
                 **kwargs):
        """
        :param contextual_embeddings: a FlairEmbeddings instance, or a model string from which one is created
        :param pooling: one of 'mean', 'fade', 'max', 'min'
        :param only_capitalized: if True, only words whose first character is upper-case are added to memory
        :param kwargs: forwarded to FlairEmbeddings when `contextual_embeddings` is a string
        :raises ValueError: if `pooling` is not a supported value
        """

        super().__init__()

        # fail early: previously an unknown value left aggregate_op undefined and only
        # surfaced later, while embedding, as an AttributeError
        if pooling not in self._AGGREGATE_OPS:
            raise ValueError(
                f'Unknown pooling "{pooling}", expected one of {sorted(self._AGGREGATE_OPS)}.')

        # use the character language model embeddings as basis
        if type(contextual_embeddings) is str:
            self.context_embeddings: FlairEmbeddings = FlairEmbeddings(contextual_embeddings, **kwargs)
        else:
            self.context_embeddings: FlairEmbeddings = contextual_embeddings

        # length is twice the original character LM embedding length.
        # NOTE: this instance attribute shadows the embedding_length method defined below,
        # so `obj.embedding_length` yields the int directly.
        self.embedding_length = self.context_embeddings.embedding_length * 2
        self.name = self.context_embeddings.name + '-context'

        # these fields are for the embedding memory
        self.word_embeddings = {}
        self.word_count = {}

        # whether to add only capitalized words to memory (faster runtime and lower memory consumption)
        self.only_capitalized = only_capitalized

        # we re-compute embeddings dynamically at each epoch
        self.static_embeddings = False

        # set the memory method
        self.pooling = pooling
        self.aggregate_op = self._AGGREGATE_OPS[pooling]

    def train(self, mode=True):
        super().train(mode=mode)
        if mode:
            # memory is wiped each time we do a training run
            print('train mode resetting embeddings')
            self.word_embeddings = {}
            self.word_count = {}

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """
        Embeds the sentences with the wrapped embeddings, updates the memory with every token
        and then sets the pooled embedding of each token under `self.name`.
        :param sentences: sentences to embed
        :return: the same list of sentences
        """

        self.context_embeddings.embed(sentences)

        # if we keep a pooling, it needs to be updated continuously
        for sentence in sentences:
            for token in sentence.tokens:

                # update embedding
                local_embedding = token._embeddings[self.context_embeddings.name]

                # slicing instead of indexing keeps an empty token text from raising IndexError
                if token.text[:1].isupper() or not self.only_capitalized:

                    if token.text not in self.word_embeddings:
                        self.word_embeddings[token.text] = local_embedding
                        self.word_count[token.text] = 1
                    else:
                        aggregated_embedding = self.aggregate_op(self.word_embeddings[token.text], local_embedding)
                        if self.pooling == 'fade':
                            aggregated_embedding /= 2
                        self.word_embeddings[token.text] = aggregated_embedding
                        self.word_count[token.text] += 1

        # add embeddings after updating
        for sentence in sentences:
            for token in sentence.tokens:
                if token.text in self.word_embeddings:
                    # 'mean' stores a running sum, so divide by the number of occurrences
                    base = self.word_embeddings[token.text] / self.word_count[token.text] \
                        if self.pooling == 'mean' else self.word_embeddings[token.text]
                else:
                    # word was not added to memory: fall back to its contextual embedding
                    base = token._embeddings[self.context_embeddings.name]

                token.set_embedding(self.name, base)

        return sentences

    def embedding_length(self) -> int:
        # shadowed on instances by the attribute of the same name set in __init__
        return self.embedding_length


class BertEmbeddings(TokenEmbeddings):
    """Token embeddings built from the hidden layers of a pretrained BERT model."""

    def __init__(self,
                 bert_model: str = 'bert-base-uncased',
                 layers: str = '-1,-2,-3,-4',
                 pooling_operation: str = 'first'):
        """
        Bidirectional transformer embeddings of words, as proposed in Devlin et al., 2018.
        :param bert_model: name of BERT model; must be a key of PRETRAINED_MODEL_ARCHIVE_MAP
        :param layers: comma-separated string indicating which encoder layers to take for the embedding
        :param pooling_operation: how to get from token piece embeddings to token embedding. 'first' uses the first
        word piece embedding as token embedding; any other value (e.g. 'mean') averages all word piece embeddings
        of the token
        :raises ValueError: if bert_model is not a key of PRETRAINED_MODEL_ARCHIVE_MAP
        """
        super().__init__()

        if bert_model not in PRETRAINED_MODEL_ARCHIVE_MAP:
            raise ValueError('Provided bert-model is not available.')

        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.model = BertModel.from_pretrained(bert_model)
        self.layer_indexes = [int(x) for x in layers.split(",")]
        self.pooling_operation = pooling_operation
        self.name = str(bert_model)
        self.static_embeddings = True

    class BertInputFeatures(object):
        """Private helper class for holding BERT-formatted features"""

        def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids, token_subtoken_count):
            self.unique_id = unique_id
            self.tokens = tokens
            self.input_ids = input_ids
            self.input_mask = input_mask
            self.input_type_ids = input_type_ids
            self.token_subtoken_count = token_subtoken_count

    def _convert_sentences_to_features(self, sentences, max_sequence_length: int) -> [BertInputFeatures]:
        """Turn every sentence into padded BERT inputs.

        :param sentences: the sentences to convert
        :param max_sequence_length: maximum number of word pieces per sentence, not counting [CLS] and [SEP]
        :return: one BertInputFeatures per sentence, in input order; all id/mask lists have the same length
        """
        # reserve two extra positions for the [CLS] and [SEP] markers
        max_sequence_length = max_sequence_length + 2

        features: List[BertEmbeddings.BertInputFeatures] = []
        for sentence_index, sentence in enumerate(sentences):

            bert_tokenization: List[str] = []
            # maps token.idx -> number of word pieces the token was split into
            token_subtoken_count: Dict[int, int] = {}

            for token in sentence:
                subtokens = self.tokenizer.tokenize(token.text)
                bert_tokenization.extend(subtokens)
                token_subtoken_count[token.idx] = len(subtokens)

            # truncate over-long sentences (slicing is a no-op when the sentence already fits)
            bert_tokenization = bert_tokenization[:max_sequence_length - 2]

            tokens = ["[CLS]"] + bert_tokenization + ["[SEP]"]
            input_type_ids = [0] * len(tokens)

            input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding = [0] * (max_sequence_length - len(input_ids))
            input_ids.extend(padding)
            input_mask.extend(padding)
            input_type_ids.extend(padding)

            features.append(BertEmbeddings.BertInputFeatures(
                unique_id=sentence_index,
                tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                input_type_ids=input_type_ids,
                token_subtoken_count=token_subtoken_count))

        return features

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """Add embeddings to all words in a list of sentences.

        :param sentences: the sentences to embed; an empty list is returned unchanged
        :return: the same sentences, each token carrying an embedding stored under self.name
        """
        # nothing to do for an empty batch (max() below would raise on it)
        if not sentences:
            return sentences

        # first, find longest sentence in batch (measured in word pieces)
        longest_sentence_in_batch: int = max(
            len(self.tokenizer.tokenize(sentence.to_tokenized_string())) for sentence in sentences)

        # prepare id maps for BERT model
        features = self._convert_sentences_to_features(sentences, longest_sentence_in_batch)
        all_input_ids = torch.LongTensor([f.input_ids for f in features]).to(flair.device)
        all_input_masks = torch.LongTensor([f.input_mask for f in features]).to(flair.device)

        self.model.to(flair.device)
        self.model.eval()

        # the embeddings are static, so the forward pass itself must not build an autograd graph
        with torch.no_grad():

            # put encoded batch through BERT model to get all hidden states of all encoder layers
            all_encoder_layers, _ = self.model(all_input_ids, token_type_ids=None, attention_mask=all_input_masks)

            # copy each requested layer to the CPU once per batch instead of once per word piece
            selected_layers = [all_encoder_layers[layer_index].detach().cpu()
                               for layer_index in self.layer_indexes]

            for sentence_index, sentence in enumerate(sentences):

                feature = features[sentence_index]

                # concatenate the requested layers for each BERT-subtoken in sentence
                # (position 0 is [CLS], the last position is [SEP])
                subtoken_embeddings = [
                    torch.cat([layer[sentence_index][token_index] for layer in selected_layers])
                    for token_index in range(len(feature.tokens))
                ]

                # index of the first word piece of the current token; 0 is [CLS], so start at 1
                first_piece = 1
                for token in sentence:
                    piece_count = feature.token_subtoken_count[token.idx]

                    if self.pooling_operation == 'first':
                        # use first subword embedding if pooling operation is 'first'
                        token.set_embedding(self.name, subtoken_embeddings[first_piece])
                    else:
                        # otherwise, do a mean over all subwords in token
                        pieces = subtoken_embeddings[first_piece:first_piece + piece_count]
                        pieces = [piece.unsqueeze(0) for piece in pieces]
                        mean = torch.mean(torch.cat(pieces, dim=0), dim=0)
                        token.set_embedding(self.name, mean)

                    first_piece += piece_count

        return sentences

    @property
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""
        return len(self.layer_indexes) * self.model.config.hidden_size

class CharLMEmbeddings(TokenEmbeddings):
    """Contextual string embeddings of words, as proposed in Akbik et al., 2018."""

    def __init__(self, model, detach: bool = True, use_cache: bool = True, cache_directory: str = None):
        """
        Set up contextual string embeddings backed by a character-level language model.
        :param model: either the (case-insensitive) name of a pretrained model, e.g. 'news-forward', 'news-backward',
                'news-forward-fast', 'news-backward-fast', 'mix-forward', 'mix-backward', 'german-forward',
                'german-backward', 'polish-forward', 'polish-backward', or a path to a language model file.
        :param detach: if set to False, the gradient will propagate into the language model. this dramatically slows
                down training and often leads to worse results, so not recommended.
        :param use_cache: if set to False, will not write embeddings to file for later retrieval. this saves disk
                space but will not allow re-use of once computed embeddings that do not fit into memory
        :param cache_directory: directory for the cache file; when not set, the cache file name is derived from the
                model name alone.
        :raises ValueError: if model is neither a known model name nor an existing path
        """
        super().__init__()

        # model name -> download location of the pretrained language model
        pretrained_models = {
            # English, news corpus
            'news-forward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-news-english-forward-v0.2rc.pt',
            'news-backward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-news-english-backward-v0.2rc.pt',
            'news-forward-fast':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-news-english-forward-1024-v0.2rc.pt',
            'news-backward-fast':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-news-english-backward-1024-v0.2rc.pt',
            # English, mixed corpus
            'mix-forward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-english-forward-v0.2rc.pt',
            'mix-backward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-english-backward-v0.2rc.pt',
            # German, mixed corpus
            'german-forward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-german-forward-v0.2rc.pt',
            'german-backward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-german-backward-v0.2rc.pt',
            # Polish, common crawl
            'polish-forward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-polish-forward-v0.2.pt',
            'polish-backward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-polish-backward-v0.2.pt',
            # Slovenian
            'slovenian-forward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.3/lm-sl-large-forward-v0.1.pt',
            'sl-forward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.3/lm-sl-large-forward-v0.1.pt',
            'slovenian-backward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.3/lm-sl-large-backward-v0.1.pt',
            'sl-backward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.3/lm-sl-large-backward-v0.1.pt',
            # Bulgarian
            'bulgarian-forward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.3/lm-bg-small-forward-v0.1.pt',
            'bg-forward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.3/lm-bg-small-forward-v0.1.pt',
            'bulgarian-backward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.3/lm-bg-small-backward-v0.1.pt',
            'bg-backward':
                'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.3/lm-bg-small-backward-v0.1.pt',
            # Pubmed
            'pubmed-forward':
                "https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.4.1/pubmed-2015-fw-lm.pt",
            'pubmed-backward':
                "https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.4.1/pubmed-2015-bw-lm.pt",
        }

        model_key = model.lower()
        if model_key in pretrained_models:
            # known name: resolve it to a local file via cached_path
            model = cached_path(pretrained_models[model_key], cache_dir='embeddings')
        elif not os.path.exists(model):
            raise ValueError(f'The given model "{model}" is not available or is not a valid path.')

        self.name = model
        self.static_embeddings = detach

        from flair.models import LanguageModel
        self.lm = LanguageModel.load_language_model(model)
        self.detach = detach

        self.is_forward_lm: bool = self.lm.is_forward_lm

        # caching variables; the cache object itself is created lazily on first use
        self.use_cache: bool = use_cache
        self.cache = None
        self.cache_directory: str = cache_directory

        # embed a one-token probe sentence to find out how long the produced vectors are
        probe_sentence: Sentence = Sentence()
        probe_sentence.add_token(Token('hello'))
        embedded_probe = self.embed(probe_sentence)
        self.__embedding_length: int = len(embedded_probe[0].get_token(1).get_embedding())

        # set to eval mode
        self.eval()

    def train(self, mode=True):
        """Deliberately does nothing: requests to change the train/eval mode are ignored."""

    def __getstate__(self):
        """Return the instance state for pickling, with all cache settings neutralised."""
        # work on a shallow copy so that the live object keeps its cache
        state = dict(self.__dict__)
        # the cache entries are not picklable; a restored object starts without caching
        state.update(cache=None, use_cache=False, cache_directory=None)
        return state

    @property
    def embedding_length(self) -> int:
        """Length of the vectors produced by this embedding, as measured in the constructor."""
        return self.__embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """Set a contextual string embedding on every token of the given sentences.

        Embeddings are taken from the on-disk cache when caching is active and every sentence is found
        there; otherwise they are computed with the language model (and written to the cache if active).

        :param sentences: sentences whose tokens are to be embedded
        :return: the same sentences with embeddings set under self.name
        """
        # this whole block is for compatibility with older serialized models  TODO: remove in version 0.4
        has_cache_attributes = 'cache' in self.__dict__ and 'cache_directory' in self.__dict__
        if not has_cache_attributes:
            self.use_cache = False
            self.cache_directory = None
        else:
            if self.cache_directory:
                cache_path = os.path.join(
                    self.cache_directory, '{}-tmp-cache.sqllite'.format(os.path.basename(self.name)))
            else:
                cache_path = '{}-tmp-cache.sqllite'.format(self.name)
            # caching is switched off when the cache file does not exist
            if not os.path.exists(cache_path):
                self.use_cache = False
                self.cache_directory = None

        # if cache is used, try setting embeddings from cache first
        if self.use_cache:

            # lazy initialization of cache
            if not self.cache:
                from sqlitedict import SqliteDict
                self.cache = SqliteDict(cache_path, autocommit=True)

            # try populating embeddings from cache; stop at the first sentence that is missing
            for sentence in sentences:
                cached_vectors = self.cache.get(sentence.to_tokenized_string())
                if not cached_vectors:
                    break
                for token, vector in zip(sentence, cached_vectors):
                    token.set_embedding(self.name, torch.FloatTensor(vector))
            else:
                # loop ran to completion: every sentence was served from the cache
                return sentences

        # if this is not possible, use LM to generate embedding. First, get text sentences
        text_sentences = [sentence.to_tokenized_string() for sentence in sentences]

        longest_character_sequence_in_batch: int = len(max(text_sentences, key=len))

        end_marker = ' '
        extra_offset = 1

        # every LM input is '\n' + text (reversed for a backward LM) + end marker,
        # padded with whitespace to the length of the longest sentence
        sentences_padded: List[str] = [
            '\n{}{}{}'.format(
                sentence_text if self.is_forward_lm else sentence_text[::-1],
                end_marker,
                (longest_character_sequence_in_batch - len(sentence_text)) * ' ')
            for sentence_text in text_sentences
        ]

        # get hidden states from language model
        all_hidden_states_in_lm = self.lm.get_representation(sentences_padded, self.detach)

        # take first or last hidden states from language model as word representation
        for sentence_no, sentence in enumerate(sentences):
            sentence_text = sentence.to_tokenized_string()

            # extra_offset accounts for the leading '\n' of the padded input
            offset_forward: int = extra_offset
            offset_backward: int = len(sentence_text) + extra_offset

            for token in sentence.tokens:
                token_length = len(token.text)

                # the forward offset is read after advancing over the token,
                # the backward offset before stepping back over it
                offset_forward += token_length
                offset = offset_forward if self.is_forward_lm else offset_backward

                token.set_embedding(self.name, all_hidden_states_in_lm[offset, sentence_no, :])

                # skip the whitespace separating this token from the next one
                offset_forward += 1
                offset_backward -= token_length + 1

        if self.use_cache:
            for sentence in sentences:
                vectors = [token._embeddings[self.name].tolist() for token in sentence]
                self.cache[sentence.to_tokenized_string()] = vectors

        return sentences


class DocumentMeanEmbeddings(DocumentEmbeddings):
    """Document embedding computed as the mean of the stacked token embeddings of a sentence."""

    @deprecated(version='0.3.1', reason="The functionality of this class is moved to 'DocumentPoolEmbeddings'")
    def __init__(self, token_embeddings: List[TokenEmbeddings]):
        """The constructor takes a list of embeddings to be combined.
        :param token_embeddings: token embeddings that are stacked and then averaged over all tokens
        """
        super().__init__()

        self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings)
        self.name: str = 'document_mean'

        self.__embedding_length: int = self.embeddings.embedding_length

        if torch.cuda.is_available():
            self.cuda()

    @property
    def embedding_length(self) -> int:
        """Length of the document embedding; equals the length of the stacked token embeddings."""
        return self.__embedding_length

    def embed(self, sentences: Union[List[Sentence], Sentence]):
        """Add a mean embedding to every sentence in the given list of sentences. Nothing is recomputed when
        every sentence already carries an embedding under this embedding's name.
        :param sentences: a single sentence or a list of sentences
        """
        # if only one sentence is passed, convert to list of sentence
        # (isinstance rather than an exact type check, so subclasses of Sentence are wrapped as well)
        if isinstance(sentences, Sentence):
            sentences = [sentences]

        if all(self.name in sentence._embeddings for sentence in sentences):
            return

        self.embeddings.embed(sentences)

        for sentence in sentences:
            # one row per token
            word_embeddings = torch.cat(
                [token.get_embedding().unsqueeze(0) for token in sentence.tokens], dim=0)
            if torch.cuda.is_available():
                word_embeddings = word_embeddings.cuda()

            mean_embedding = torch.mean(word_embeddings, 0)

            sentence.set_embedding(self.name, mean_embedding.unsqueeze(0))

    def _add_embeddings_internal(self, sentences: List[Sentence]):
        """Unused: all work happens in embed()."""
        pass


class DocumentPoolEmbeddings(DocumentEmbeddings):
    """Document embedding obtained by pooling (mean, max or min) the stacked token embeddings."""

    def __init__(self, token_embeddings: List[TokenEmbeddings], mode: str = 'mean'):
        """Build a pooled document embedding on top of the given token embeddings.
        :param token_embeddings: a list of token embeddings
        :param mode: the pooling operation, one of 'mean', 'max' or 'min'
        :raises ValueError: if mode is none of the supported pooling operations
        """
        super().__init__()

        self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings)

        self.__embedding_length: int = self.embeddings.embedding_length

        if torch.cuda.is_available():
            self.cuda()

        self.mode = mode
        # pick the torch reduction that belongs to the requested mode
        for supported_mode, operation in (('mean', torch.mean), ('max', torch.max), ('min', torch.min)):
            if self.mode == supported_mode:
                self.pool_op = operation
                break
        else:
            raise ValueError(f'Pooling operation for {self.mode!r} is not defined')
        self.name: str = f'document_{self.mode}'

    @property
    def embedding_length(self) -> int:
        """Length of the document embedding; equals the length of the stacked token embeddings."""
        return self.__embedding_length

    def embed(self, sentences: Union[List[Sentence], Sentence]):
        """Attach a pooled embedding to every given sentence. Nothing is recomputed when every sentence
        already carries an embedding under this embedding's name.
        :param sentences: a single sentence or a list of sentences
        """
        # a lone sentence is handled as a batch of one
        if isinstance(sentences, Sentence):
            sentences = [sentences]

        if all(self.name in sentence._embeddings.keys() for sentence in sentences):
            return

        self.embeddings.embed(sentences)

        for sentence in sentences:
            # one row per token
            token_matrix = torch.cat(
                [token.get_embedding().unsqueeze(0) for token in sentence.tokens], dim=0)
            if torch.cuda.is_available():
                token_matrix = token_matrix.cuda()

            pooled_embedding = self.pool_op(token_matrix, 0)
            if self.mode != 'mean':
                # max/min along a dimension return (values, indices); only the values are kept
                pooled_embedding, _ = pooled_embedding

            sentence.set_embedding(self.name, pooled_embedding.unsqueeze(0))

    def _add_embeddings_internal(self, sentences: List[Sentence]):
        """Unused: all work happens in embed()."""
        pass


class DocumentLSTMEmbeddings(DocumentEmbeddings):
    """Document embedding taken from a recurrent network run over the token embeddings.

    Note: despite the class name, the recurrent layer created in the constructor is a torch.nn.GRU.
    """

    def __init__(self,
                 token_embeddings: List[TokenEmbeddings],
                 hidden_states=128,
                 num_layers=1,
                 reproject_words: bool = True,
                 reproject_words_dimension: int = None,
                 bidirectional: bool = False,
                 use_word_dropout: bool = False,
                 use_locked_dropout: bool = False):
        """The constructor takes a list of embeddings to be combined.
        :param token_embeddings: a list of token embeddings
        :param hidden_states: the number of hidden states in the rnn
        :param num_layers: the number of layers for the rnn
        :param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate linear
        layer before putting them into the rnn or not
        :param reproject_words_dimension: output dimension of reprojecting token embeddings. If None the same output
        dimension as before will be taken.
        :param bidirectional: boolean value, indicating whether to use a bidirectional rnn or not. If True, the
        outputs at the first and the last position are concatenated to form the final document embedding.
        :param use_word_dropout: boolean value, indicating whether to use word dropout or not.
        :param use_locked_dropout: boolean value, indicating whether to use locked dropout or not.
        """
        super().__init__()

        self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=token_embeddings)

        self.reproject_words = reproject_words
        self.bidirectional = bidirectional

        self.length_of_all_token_embeddings: int = self.embeddings.embedding_length

        self.name = 'document_lstm'
        self.static_embeddings = False

        # bidirectional: each output position holds 2 * hidden_states values, and embed()
        # concatenates the first and the last position, hence the factor 4
        self.__embedding_length: int = hidden_states
        if self.bidirectional:
            self.__embedding_length *= 4

        self.embeddings_dimension: int = self.length_of_all_token_embeddings
        if self.reproject_words and reproject_words_dimension is not None:
            self.embeddings_dimension = reproject_words_dimension

        # recurrent layer on top of the (optionally reprojected) embedding layer
        self.word_reprojection_map = torch.nn.Linear(self.length_of_all_token_embeddings,
                                                     self.embeddings_dimension)
        self.rnn = torch.nn.GRU(self.embeddings_dimension, hidden_states, num_layers=num_layers,
                                bidirectional=self.bidirectional)

        # dropouts
        if use_locked_dropout:
            self.dropout: torch.nn.Module = LockedDropout(0.5)
        else:
            self.dropout = torch.nn.Dropout(0.5)

        self.use_word_dropout: bool = use_word_dropout
        if self.use_word_dropout:
            self.word_dropout = WordDropout(0.05)

        torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight)

        if torch.cuda.is_available():
            self.cuda()

    @property
    def embedding_length(self) -> int:
        """Length of the document embedding (hidden_states, or 4 * hidden_states when bidirectional)."""
        return self.__embedding_length

    def embed(self, sentences: Union[List[Sentence], Sentence]):
        """Add a document embedding to every sentence in the given list of sentences.

        Side effect: a list passed in is sorted in place by sentence length, longest first.
        :param sentences: a single sentence or a list of sentences; an empty list is a no-op
        """
        # isinstance rather than an exact type check, so subclasses of Sentence are wrapped as well
        if isinstance(sentences, Sentence):
            sentences = [sentences]

        # nothing to embed; also avoids an IndexError on sentences[0] below
        if not sentences:
            return

        self.rnn.zero_grad()

        # sort sentences by number of tokens, longest first, as required for packing below
        sentences.sort(key=len, reverse=True)

        self.embeddings.embed(sentences)

        longest_token_sequence_in_batch: int = len(sentences[0])

        all_sentence_tensors = []
        lengths: List[int] = []

        # go through each sentence in batch
        for sentence in sentences:

            lengths.append(len(sentence.tokens))

            # one row per token
            word_embeddings = [token.get_embedding().unsqueeze(0) for token in sentence.tokens]

            # PADDING: pad shorter sentences out with zero rows
            missing_rows = longest_token_sequence_in_batch - len(sentence.tokens)
            if missing_rows > 0:
                word_embeddings.append(
                    torch.zeros(missing_rows, self.length_of_all_token_embeddings, dtype=torch.float))

            # ADD TO SENTENCE LIST: add the representation, with a new batch dimension at position 1
            all_sentence_tensors.append(torch.cat(word_embeddings, 0).unsqueeze(1))

        # --------------------------------------------------------------------
        # GET REPRESENTATION FOR ENTIRE BATCH
        # --------------------------------------------------------------------
        sentence_tensor = torch.cat(all_sentence_tensors, 1)
        if torch.cuda.is_available():
            sentence_tensor = sentence_tensor.cuda()

        # --------------------------------------------------------------------
        # FF PART
        # --------------------------------------------------------------------
        # use word dropout if set
        if self.use_word_dropout:
            sentence_tensor = self.word_dropout(sentence_tensor)

        if self.reproject_words:
            sentence_tensor = self.word_reprojection_map(sentence_tensor)

        sentence_tensor = self.dropout(sentence_tensor)

        packed = torch.nn.utils.rnn.pack_padded_sequence(sentence_tensor, lengths)

        self.rnn.flatten_parameters()

        rnn_out, _ = self.rnn(packed)

        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(rnn_out)

        outputs = self.dropout(outputs)

        # --------------------------------------------------------------------
        # EXTRACT EMBEDDINGS FROM RNN
        # --------------------------------------------------------------------
        for sentence_no, length in enumerate(lengths):
            # output at the last real (non-padding) position of this sentence
            last_rep = outputs[length - 1, sentence_no].unsqueeze(0)

            embedding = last_rep
            if self.bidirectional:
                first_rep = outputs[0, sentence_no].unsqueeze(0)
                embedding = torch.cat([first_rep, last_rep], 1)

            sentences[sentence_no].set_embedding(self.name, embedding)

    def _add_embeddings_internal(self, sentences: List[Sentence]):
        """Unused: all work happens in embed()."""
        pass


class DocumentLMEmbeddings(DocumentEmbeddings):
    """Document embeddings read off the token embeddings of character language models: the last token's
    embedding for a forward model, the first token's embedding for a backward model."""

    def __init__(self, charlm_embeddings: List[CharLMEmbeddings], detach: bool = True):
        """
        :param charlm_embeddings: the character LM embeddings to take document embeddings from
        :param detach: stored as both static_embeddings and detach
        """
        super().__init__()

        self.embeddings = charlm_embeddings

        self.static_embeddings = detach
        self.detach = detach

        # embed a dummy sentence to find out how long the resulting document embedding is
        dummy: Sentence = Sentence('jo')
        self.embed([dummy])
        self._embedding_length: int = len(dummy.embedding)

    @property
    def embedding_length(self) -> int:
        """Length of the document embedding, as measured on a dummy sentence in the constructor."""
        return self._embedding_length

    def embed(self, sentences: Union[List[Sentence], Sentence]):
        """Embed the sentences with every wrapped LM embedding and store one document embedding per LM
        (under that LM embedding's name) on each sentence.
        :param sentences: a single sentence or a list of sentences
        :return: the list of embedded sentences
        """
        # isinstance rather than an exact type check, so subclasses of Sentence are wrapped as well
        if isinstance(sentences, Sentence):
            sentences = [sentences]

        for embedding in self.embeddings:
            embedding.embed(sentences)

            # iterate over sentences
            for sentence in sentences:

                # NOTE(review): the indexes len(sentence) and 1 presumably rely on Sentence.__getitem__
                # being 1-based — confirm against the Sentence class

                # if its a forward LM, take last state
                if embedding.is_forward_lm:
                    sentence.set_embedding(embedding.name, sentence[len(sentence)]._embeddings[embedding.name])
                else:
                    sentence.set_embedding(embedding.name, sentence[1]._embeddings[embedding.name])

        return sentences

    def _add_embeddings_internal(self, sentences: List[Sentence]):
        """Unused: all work happens in embed()."""
        pass


In [23]:
import torch.nn as nn
import torch
import math
from torch.autograd import Variable
from typing import List
#from flair.data import Dictionary


class LanguageModel(nn.Module):
    """
    Container module with an encoder, a recurrent module, and a decoder.

    The encoder is an embedding layer over the items of `dictionary`, the recurrent module is a
    (possibly stacked) LSTM and the decoder is a linear layer mapping back onto the dictionary.
    If `nout` is given, a linear projection to `nout` dimensions sits between LSTM and decoder.
    """

    def __init__(self,
                 dictionary: Dictionary,
                 is_forward_lm: bool,
                 hidden_size: int,
                 nlayers: int,
                 embedding_size: int = 100,
                 nout=None,
                 dropout=0.5,
                 best_score=None):
        """
        :param dictionary: dictionary of items; its length fixes the encoder and decoder sizes
        :param is_forward_lm: stored flag telling whether this is a forward language model
        :param hidden_size: hidden size of the LSTM
        :param nlayers: number of LSTM layers
        :param embedding_size: size of the item embeddings fed to the LSTM
        :param nout: if not None, size of a linear projection applied to the LSTM output
        :param dropout: dropout probability (on embeddings, on outputs, and between LSTM layers)
        :param best_score: stored as-is and written into checkpoints by `save`
        """
        super(LanguageModel, self).__init__()

        self.dictionary = dictionary
        self.is_forward_lm: bool = is_forward_lm

        self.dropout = dropout
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.nlayers = nlayers

        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(len(dictionary), embedding_size)

        # the LSTM's own dropout argument is only passed for stacked LSTMs
        if nlayers == 1:
            self.rnn = nn.LSTM(embedding_size, hidden_size, nlayers)
        else:
            self.rnn = nn.LSTM(embedding_size, hidden_size, nlayers, dropout=dropout)

        self.hidden = None

        self.nout = nout
        if nout is not None:
            self.proj = nn.Linear(hidden_size, nout)
            self.initialize(self.proj.weight)
            self.decoder = nn.Linear(nout, len(dictionary))
        else:
            self.proj = None
            self.decoder = nn.Linear(hidden_size, len(dictionary))

        self.init_weights()

        self.best_score = best_score

        # auto-spawn on GPU if available
        if torch.cuda.is_available():
            self.cuda()

    def init_weights(self):
        """Initializes encoder and decoder weights uniformly in [-0.1, 0.1] and zeroes the decoder bias."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def set_hidden(self, hidden):
        """Stores `hidden` on the model; nothing in this class reads it back."""
        self.hidden = hidden

    def forward(self, input, hidden, ordered_sequence_lengths=None):
        """
        Encodes the input indices, runs the LSTM and decodes onto the dictionary.
        :param input: tensor of item indices; `get_representation` passes it as (characters, strings)
        :param hidden: LSTM state, e.g. as produced by `init_hidden`
        :param ordered_sequence_lengths: accepted for interface compatibility, not used
        :return: tuple (decoded scores of shape (dim 0, dim 1, len(dictionary)),
                 LSTM output after optional projection and dropout, new hidden state)
        """
        encoded = self.encoder(input)
        emb = self.drop(encoded)

        self.rnn.flatten_parameters()

        output, hidden = self.rnn(emb, hidden)

        if self.proj is not None:
            output = self.proj(output)

        output = self.drop(output)

        # the decoder works on a 2-d view; its result is reshaped back to three dimensions
        decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))

        return decoded.view(output.size(0), output.size(1), decoded.size(1)), output, hidden

    def init_hidden(self, bsz):
        """
        Creates a zero LSTM state (h, c) of shape (nlayers, bsz, hidden_size), allocated with the
        same tensor type as the model parameters.
        :param bsz: batch size
        """
        weight = next(self.parameters()).data
        return (Variable(weight.new(self.nlayers, bsz, self.hidden_size).zero_()),
                Variable(weight.new(self.nlayers, bsz, self.hidden_size).zero_()))

    def get_representation(self, strings: List[str], detach_from_lm=True):
        """
        Runs the language model over a batch of strings and returns the LSTM outputs.
        The strings are turned into one rectangular index tensor, so they are expected to be
        of equal length (presumably padded by the caller).
        :param strings: the strings to feed character by character
        :param detach_from_lm: if True, the returned output is detached from its history
        :return: LSTM output for every character position of every string
        """
        sequences_as_char_indices: List[List[int]] = []
        for string in strings:
            char_indices = [self.dictionary.get_idx_for_item(char) for char in string]
            sequences_as_char_indices.append(char_indices)

        batch = Variable(torch.LongTensor(sequences_as_char_indices).transpose(0, 1))

        # follow the device of the parameters (as init_hidden does) instead of moving to CUDA
        # whenever CUDA exists: the model may have been loaded onto or moved to the CPU
        batch = batch.to(next(self.parameters()).device)

        hidden = self.init_hidden(len(strings))
        prediction, rnn_output, hidden = self.forward(batch, hidden)

        if detach_from_lm:
            rnn_output = self.repackage_hidden(rnn_output)

        return rnn_output

    def repackage_hidden(self, h):
        """Wraps hidden states in new Variables, to detach them from their history."""
        if isinstance(h, torch.Tensor):
            return Variable(h.data)
        else:
            return tuple(self.repackage_hidden(v) for v in h)

    def initialize(self, matrix):
        """
        Fills a 2-d matrix uniformly in [-stdv, stdv] with stdv = sqrt(3 / (rows + columns)).
        :param matrix: the weight matrix to initialize in place
        """
        in_, out_ = matrix.size()
        stdv = math.sqrt(3. / (in_ + out_))
        matrix.data.uniform_(-stdv, stdv)

    @classmethod
    def load_language_model(cls, model_file):
        """
        Restores a model from a checkpoint written by `save` and puts it into eval mode.
        :param model_file: path (or file object) of the checkpoint
        :return: the restored model, an instance of the class the method was called on
        """
        # NOTE: torch.load unpickles arbitrary objects - only load checkpoints from trusted sources
        map_location = None if torch.cuda.is_available() else 'cpu'
        state = torch.load(model_file, map_location=map_location)

        # checkpoints without a best score are accepted
        best_score = state.get('best_score')

        # build through cls so that subclasses get an instance of their own type
        model = cls(state['dictionary'],
                    state['is_forward_lm'],
                    state['hidden_size'],
                    state['nlayers'],
                    state['embedding_size'],
                    state['nout'],
                    state['dropout'],
                    best_score)
        model.load_state_dict(state['state_dict'])
        model.eval()
        if torch.cuda.is_available():
            model.cuda()
        return model

    def save(self, file):
        """
        Writes the weights together with all constructor arguments to `file`.
        :param file: target path (or file object) handed to torch.save
        """
        model_state = {
            'state_dict': self.state_dict(),
            'dictionary': self.dictionary,
            'is_forward_lm': self.is_forward_lm,
            'hidden_size': self.hidden_size,
            'nlayers': self.nlayers,
            'embedding_size': self.embedding_size,
            'nout': self.nout,
            'dropout': self.dropout,
            'best_score': self.best_score
        }
        torch.save(model_state, file, pickle_protocol=4)

    def generate_text(self, number_of_characters=1000, temperature: float = 1.0) -> str:
        """
        Samples text from the model, starting from a randomly drawn dictionary item.
        :param number_of_characters: number of items to sample
        :param temperature: the scores are divided by this value before sampling; 1.0 samples
                            from the unmodified scores
        :return: the sampled items joined into one string
        :raises ValueError: if temperature is not positive
        """
        if temperature <= 0:
            raise ValueError('temperature must be positive, got {}'.format(temperature))

        with torch.no_grad():
            characters = []

            idx2item = self.dictionary.idx2item

            # initial hidden state and a random first item, on the device of the parameters
            hidden = self.init_hidden(1)
            current = torch.rand(1, 1).mul(len(idx2item)).long()
            current = current.to(next(self.parameters()).device)

            for _ in range(number_of_characters):
                prediction, rnn_output, hidden = self.forward(current, hidden)
                word_weights = prediction.squeeze().data.div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0].item()
                # the sampled item is the input of the next step
                current.data.fill_(word_idx)
                characters.append(idx2item[word_idx].decode('UTF-8'))

            return ''.join(characters)

In [15]:
!pip install tiny-tokenizer



In [16]:
#!pip install flair



In [17]:

import warnings
import logging

import torch.autograd as autograd
import torch.nn
import flair.nn
import torch

#import flair.embeddings
#from flair.data import Dictionary, Sentence, Token
#from flair.file_utils import cached_path

from typing import List, Tuple, Union

#from flair.training_utils import clear_embeddings

import os
# base directory (below the current working directory) that SequenceTagger.load prepends to
# the 'protocol' and 'medner' model files
current_path = os.getcwd() + "/ner/"

# logger of this module
log = logging.getLogger(__name__)

# items of the tag dictionary that mark the start and the stop of a tag sequence;
# SequenceTagger looks up their ids for its transition scores
START_TAG: str = '<START>'
STOP_TAG: str = '<STOP>'


def to_scalar(var):
    """
    Returns the first element of a tensor as a plain Python number.
    :param var: tensor with at least one element
    :return: first element of the flattened tensor
    """
    flat_values = var.view(-1).data.tolist()
    return flat_values[0]


def argmax(vec):
    """
    Returns the index of the largest entry along dimension 1, for the first row of `vec`.
    :param vec: 2-d tensor
    :return: index as a Python number
    """
    best_indices = torch.max(vec, 1)[1]
    return to_scalar(best_indices)


def log_sum_exp(vec):
    """
    Computes log(sum(exp(vec))) with the maximum subtracted before exponentiation.
    :param vec: tensor with a single row (the maximum is read from row 0)
    :return: tensor holding the log-sum-exp of all entries
    """
    peak = vec[0, argmax(vec)]
    # the maximum, laid out like the row it is subtracted from
    peak_row = peak.view(1, -1).expand(1, vec.size()[1])
    shifted_sum = torch.sum(torch.exp(vec - peak_row))
    return peak + torch.log(shifted_sum)


def argmax_batch(vecs):
    """
    Returns, for every row of `vecs`, the index of its largest entry along dimension 1.
    :param vecs: tensor with at least two dimensions
    :return: tensor of indices
    """
    return torch.max(vecs, 1)[1]


def log_sum_exp_batch(vecs):
    """
    Computes log(sum(exp(row))) for every row of a 2-d tensor, subtracting the row maximum
    before exponentiation.
    :param vecs: 2-d tensor, one vector per row
    :return: 1-d tensor with one value per row
    """
    row_max, _ = torch.max(vecs, 1)
    row_max_grid = row_max[:, None].expand(vecs.shape[0], vecs.shape[1])
    shifted_sums = torch.sum(torch.exp(vecs - row_max_grid), 1)
    return row_max + torch.log(shifted_sums)


def pad_tensors(tensor_list, type_=torch.FloatTensor):
    """
    Stacks tensors of different first-dimension length into one zero-padded tensor.
    :param tensor_list: non-empty list of tensors that agree in all dimensions but the first
    :param type_: tensor constructor used for the padded result
    :return: tuple (padded tensor of shape (number of tensors, longest length, ...),
             list with the original first-dimension length of every tensor)
    """
    lens_ = [tensor.shape[0] for tensor in tensor_list]

    padded = type_(len(tensor_list), max(lens_), *tensor_list[0].shape[1:])
    padded.fill_(0)

    for row, (tensor, length) in enumerate(zip(tensor_list, lens_)):
        padded[row, :length] = tensor

    return padded, lens_


class SequenceTagger(torch.nn.Module):

    def __init__(self,
                 hidden_size: int,
                 embeddings: flair.embeddings.TokenEmbeddings,
                 tag_dictionary: Dictionary,
                 tag_type: str,
                 use_crf: bool = True,
                 use_rnn: bool = True,
                 rnn_layers: int = 1,
                 use_dropout: float = 0.0,
                 use_word_dropout: float = 0.05,
                 use_locked_dropout: float = 0.5,
                 ):
        """
        :param hidden_size: hidden size of the bidirectional RNN
        :param embeddings: token embeddings that provide the input vectors
        :param tag_dictionary: dictionary of the tags to predict
        :param tag_type: name of the tag layer that is read from and written to tokens
        :param use_crf: if True, learn a transition matrix between tags
        :param use_rnn: if True, the final linear layer works on the RNN output
        :param rnn_layers: number of RNN layers
        :param use_dropout: dropout probability; the layer is only created if > 0
        :param use_word_dropout: word dropout probability; the layer is only created if > 0
        :param use_locked_dropout: locked dropout probability; the layer is only created if > 0
        """
        super(SequenceTagger, self).__init__()

        # plain configuration
        self.use_rnn = use_rnn
        self.hidden_size = hidden_size
        self.use_crf: bool = use_crf
        self.rnn_layers: int = rnn_layers

        self.trained_epochs: int = 0

        self.embeddings = embeddings

        # tag inventory
        self.tag_dictionary: Dictionary = tag_dictionary
        self.tag_type: str = tag_type
        self.tagset_size: int = len(tag_dictionary)

        # network architecture
        self.nlayers: int = rnn_layers
        self.hidden_word = None

        # dropout settings; every dropout layer only exists when its probability is positive
        self.use_dropout: float = use_dropout
        self.use_word_dropout: float = use_word_dropout
        self.use_locked_dropout: float = use_locked_dropout

        if use_dropout > 0.0:
            self.dropout = torch.nn.Dropout(use_dropout)

        if use_word_dropout > 0.0:
            self.word_dropout = flair.nn.WordDropout(use_word_dropout)

        if use_locked_dropout > 0.0:
            self.locked_dropout = flair.nn.LockedDropout(use_locked_dropout)

        rnn_input_dim: int = self.embeddings.embedding_length

        # linear re-projection of the embeddings, same size in and out
        self.relearn_embeddings: bool = True

        if self.relearn_embeddings:
            self.embedding2nn = torch.nn.Linear(rnn_input_dim, rnn_input_dim)

        # bidirectional RNN on top of the embedding layer
        self.rnn_type = 'LSTM'
        if self.rnn_type in ['LSTM', 'GRU']:
            rnn_class = getattr(torch.nn, self.rnn_type)
            rnn_options = {'num_layers': self.nlayers, 'bidirectional': True}
            # dropout between layers is only passed for stacked RNNs
            if self.nlayers != 1:
                rnn_options['dropout'] = 0.5
            self.rnn = rnn_class(rnn_input_dim, hidden_size, **rnn_options)

        # final linear map to tag space; the bidirectional RNN doubles the feature size
        if self.use_rnn:
            linear_input_dim = hidden_size * 2
        else:
            linear_input_dim = self.embeddings.embedding_length
        self.linear = torch.nn.Linear(linear_input_dim, len(tag_dictionary))

        if self.use_crf:
            self.transitions = torch.nn.Parameter(
                torch.randn(self.tagset_size, self.tagset_size))
            start_idx = self.tag_dictionary.get_idx_for_item(START_TAG)
            stop_idx = self.tag_dictionary.get_idx_for_item(STOP_TAG)
            # the whole START row and the whole STOP column get a very low score
            self.transitions.data[start_idx, :] = -10000
            self.transitions.data[:, stop_idx] = -10000

        # the automatic move to the GPU is switched off in this notebook

    def save(self, model_file: str):
        """
        Saves the weights together with the constructor arguments needed by `load_from_file`.
        :param model_file: the model file
        """
        model_state = {
            'state_dict': self.state_dict(),
            'embeddings': self.embeddings,
            'hidden_size': self.hidden_size,
            'tag_dictionary': self.tag_dictionary,
            'tag_type': self.tag_type,
            'use_crf': self.use_crf,
            'use_rnn': self.use_rnn,
            'rnn_layers': self.rnn_layers,
            # load_from_file reads this key; without it the setting is reset to 0.0 on reload
            'use_dropout': self.use_dropout,
            'use_word_dropout': self.use_word_dropout,
            'use_locked_dropout': self.use_locked_dropout,
        }

        torch.save(model_state, model_file, pickle_protocol=4)


    @classmethod
    def load_from_file(cls, model_file):
        """
        Restores a tagger from a file written by `save` and puts it into eval mode.
        Dropout settings missing from the file default to 0.0.
        :param model_file: the model file
        :return: the restored tagger, an instance of the class the method was called on
        """
        # NOTE: torch.load unpickles arbitrary objects - only load model files from trusted sources
        # suppress torch warnings:
        # https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            state = torch.load(model_file, map_location={'cuda:0': 'cpu'})

        # build through cls so that subclasses get an instance of their own type
        model = cls(
            hidden_size=state['hidden_size'],
            embeddings=state['embeddings'],
            tag_dictionary=state['tag_dictionary'],
            tag_type=state['tag_type'],
            use_crf=state['use_crf'],
            use_rnn=state['use_rnn'],
            rnn_layers=state['rnn_layers'],
            use_dropout=state.get('use_dropout', 0.0),
            use_word_dropout=state.get('use_word_dropout', 0.0),
            use_locked_dropout=state.get('use_locked_dropout', 0.0),
        )

        model.load_state_dict(state['state_dict'])
        model.eval()
        # the automatic move to the GPU is switched off in this notebook
        return model

    def forward(self, sentences: List[Sentence]):
        """
        Embeds a batch of sentences and computes a score for every tag at every token.
        Note that `sentences` is sorted in place, longest sentence first, and that the gradients
        of the model are zeroed on every call.
        :param sentences: non-empty list of sentences
        :return: tuple (features of shape (sentences, longest sentence, tags),
                 list of sentence lengths in sorted order,
                 list with one LongTensor of gold tag ids per sentence)
        """
        self.zero_grad()

        self.embeddings.embed(sentences)

        # longest sentence first: pack_padded_sequence below expects decreasing lengths
        sentences.sort(key=len, reverse=True)
        longest_token_sequence_in_batch: int = len(sentences[0])

        lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
        tag_list: List = []

        # zero-padded word embeddings, one row per sentence
        sentence_tensor = torch.zeros([len(sentences),
                                       longest_token_sequence_in_batch,
                                       self.embeddings.embedding_length],
                                      dtype=torch.float)

        for row, sentence in enumerate(sentences):
            # word embeddings of this sentence
            token_vectors = [token.get_embedding().unsqueeze(0) for token in sentence]
            sentence_tensor[row][:len(sentence)] = torch.cat(token_vectors, 0)

            # ids of the tags of this sentence, kept as a CPU tensor
            tag_ids: List[int] = [self.tag_dictionary.get_idx_for_item(token.get_tag(self.tag_type).value)
                                  for token in sentence]
            tag_list.append(torch.LongTensor(tag_ids))

        # switch to (tokens, sentences, features) for the layers below
        sentence_tensor.transpose_(0, 1)

        # --------------------------------------------------------------------
        # FF PART
        # --------------------------------------------------------------------
        if self.use_dropout > 0.0:
            sentence_tensor = self.dropout(sentence_tensor)
        if self.use_word_dropout > 0.0:
            sentence_tensor = self.word_dropout(sentence_tensor)
        if self.use_locked_dropout > 0.0:
            sentence_tensor = self.locked_dropout(sentence_tensor)

        if self.relearn_embeddings:
            sentence_tensor = self.embedding2nn(sentence_tensor)

        if self.use_rnn:
            packed = torch.nn.utils.rnn.pack_padded_sequence(sentence_tensor, lengths)

            rnn_output, hidden = self.rnn(packed)

            sentence_tensor, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(rnn_output)

            # word dropout is only applied before the RNN
            if self.use_dropout > 0.0:
                sentence_tensor = self.dropout(sentence_tensor)
            if self.use_locked_dropout > 0.0:
                sentence_tensor = self.locked_dropout(sentence_tensor)

        features = self.linear(sentence_tensor)

        # back to (sentences, tokens, tags)
        return features.transpose_(0, 1), lengths, tag_list

    def _score_sentence(self, feats, tags, lens_):
        """
        Scores the given tag sequences: the sum of the transition scores along each sequence
        (including the step from START and the step into STOP) plus the sum of the features of
        the chosen tags.
        :param feats: features indexed as (sentence, token, tag)
        :param tags: LongTensor of tag ids indexed as (sentence, token), padded to equal length
        :param lens_: true length of every sentence
        :return: FloatTensor with one score per sentence
        """
        start_idx = self.tag_dictionary.get_idx_for_item(START_TAG)
        stop_idx = self.tag_dictionary.get_idx_for_item(STOP_TAG)

        # one START / STOP column per sentence
        start = torch.LongTensor([start_idx])[None, :].repeat(tags.shape[0], 1)
        stop = torch.LongTensor([stop_idx])[None, :].repeat(tags.shape[0], 1)

        # position k of pad_start_tags is the tag before position k of pad_stop_tags
        pad_start_tags = torch.cat([start, tags], 1)
        pad_stop_tags = torch.cat([tags, stop], 1)

        # everything behind the true end of a sentence counts as STOP
        for i in range(len(lens_)):
            pad_stop_tags[i, lens_[i]:] = stop_idx

        score = torch.FloatTensor(feats.shape[0])

        for i in range(feats.shape[0]):
            r = torch.LongTensor(range(lens_[i]))

            # transitions are indexed as [current tag, previous tag]
            score[i] = \
                torch.sum(
                    self.transitions[pad_stop_tags[i, :lens_[i] + 1], pad_start_tags[i, :lens_[i] + 1]]
                ) + \
                torch.sum(feats[i, r, tags[i, :lens_[i]]])

        return score

    def viterbi_decode(self, feats):
        """
        Finds the highest scoring tag sequence for one sentence with the Viterbi algorithm.
        :param feats: per-token tag scores of one sentence; iterated token by token, each entry
                      holding one score per tag
        :return: tuple (best_scores, best_path): for every token the softmax probability of the
                 highest Viterbi score at that token, and the list of decoded tag ids
        """
        backpointers = []
        backscores = []

        # only the START tag is a possible predecessor of the first token
        init_vvars = torch.Tensor(1, self.tagset_size).fill_(-10000.)
        init_vvars[0][self.tag_dictionary.get_idx_for_item(START_TAG)] = 0
        forward_var = autograd.Variable(init_vvars)
        #if torch.cuda.is_available():
        #    forward_var = forward_var.cuda()

        import torch.nn.functional as F
        for feat in feats:
            # entry [i, j]: score of the best path ending in tag j, plus the transition into tag i
            next_tag_var = forward_var.view(1, -1).expand(self.tagset_size, self.tagset_size) + self.transitions
            # best previous tag for every current tag
            _, bptrs_t = torch.max(next_tag_var, dim=1)
            bptrs_t = bptrs_t.squeeze().data.cpu().numpy()
            next_tag_var = next_tag_var.data.cpu().numpy()
            # score of the best path into every current tag (picked with numpy, outside autograd)
            viterbivars_t = next_tag_var[range(len(bptrs_t)), bptrs_t]
            viterbivars_t = autograd.Variable(torch.FloatTensor(viterbivars_t))
            #if torch.cuda.is_available():
            #    viterbivars_t = viterbivars_t.cuda()
            # add the scores of the current token
            forward_var = viterbivars_t + feat
            backscores.append(forward_var)
            backpointers.append(bptrs_t)

        # add the transition into STOP; START and STOP themselves must not be the last tag
        terminal_var = forward_var + self.transitions[self.tag_dictionary.get_idx_for_item(STOP_TAG)]
        terminal_var.data[self.tag_dictionary.get_idx_for_item(STOP_TAG)] = -10000.
        terminal_var.data[self.tag_dictionary.get_idx_for_item(START_TAG)] = -10000.
        best_tag_id = argmax(terminal_var.unsqueeze(0))

        best_path = [best_tag_id]

        # follow the back pointers from the last token to the first
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)

        # confidence per token: softmax probability of the highest Viterbi score at that token
        # NOTE(review): this uses the per-token maximum, which is not necessarily the tag on the
        # decoded path - confirm that this is the intended confidence
        best_scores = []
        for backscore in backscores:
            softmax = F.softmax(backscore, dim=0)
            _, idx = torch.max(backscore, 0)
            prediction = idx.item()
            best_scores.append(softmax[prediction].item())

        # the path was collected backwards and ends in the START tag, which is dropped
        start = best_path.pop()
        assert start == self.tag_dictionary.get_idx_for_item(START_TAG)
        best_path.reverse()
        return best_scores, best_path

    def neg_log_likelihood(self, sentences: List[Sentence]):
        """
        Computes the training loss of a batch of sentences.
        With a CRF, the loss is the forward score minus the score of the gold tag sequence,
        summed over the batch. Without a CRF, it is the sum of the cross entropy losses of the
        individual sentences.
        :param sentences: list of sentences carrying gold tags of type `self.tag_type`
        :return: the loss
        """
        features, lengths, tags = self.forward(sentences)

        if self.use_crf:
            # the batch CRF works on one padded tag tensor
            tags, _ = pad_tensors(tags, torch.LongTensor)

            forward_score = self._forward_alg(features, lengths)
            gold_score = self._score_sentence(features, tags, lengths)

            score = forward_score - gold_score

            return score.sum()

        score = 0
        for sentence_feats, sentence_tags, sentence_length in zip(features, tags, lengths):
            # cut off the padding of this sentence
            sentence_feats = sentence_feats[:sentence_length]

            tag_tensor = autograd.Variable(torch.LongTensor(sentence_tags))
            score += torch.nn.functional.cross_entropy(sentence_feats, tag_tensor)

        return score

    def _forward_alg(self, feats, lens_):
        """
        Batched forward algorithm of the CRF: computes, for every sentence, the log of the summed
        exponentiated scores of all tag sequences.
        :param feats: features indexed as (sentence, token, tag)
        :param lens_: true length of every sentence
        :return: tensor with one value per sentence
        """

        # only the START tag is possible before the first token
        init_alphas = torch.Tensor(self.tagset_size).fill_(-10000.)
        init_alphas[self.tag_dictionary.get_idx_for_item(START_TAG)] = 0.

        # one slot more than tokens: slot 0 holds the initial scores, slot i + 1 the scores after token i
        forward_var = torch.FloatTensor(
            feats.shape[0],
            feats.shape[1] + 1,
            feats.shape[2],
        ).fill_(0)

        forward_var[:, 0, :] = init_alphas[None, :].repeat(feats.shape[0], 1)

        #if torch.cuda.is_available():
        #    forward_var = forward_var.cuda()

        # one copy of the transition matrix per sentence
        transitions = self.transitions.view(
            1,
            self.transitions.shape[0],
            self.transitions.shape[1],
        ).repeat(feats.shape[0], 1, 1)

        for i in range(feats.shape[1]):
            emit_score = feats[:, i, :]

            # entry [sentence, current tag, previous tag]: feature of the current tag
            # + transition score + forward score of the previous tag
            tag_var = \
                emit_score[:, :, None].repeat(1, 1, transitions.shape[2]) + \
                transitions + \
                forward_var[:, i, :][:, :, None].repeat(1, 1, transitions.shape[2]).transpose(2, 1)

            # log-sum-exp over the previous tag, with the maximum subtracted before exponentiation
            max_tag_var, _ = torch.max(tag_var, dim=2)

            tag_var = tag_var - \
                      max_tag_var[:, :, None].repeat(1, 1, transitions.shape[2])

            agg_ = torch.log(torch.sum(torch.exp(tag_var), dim=2))

            # the next slot is written into a copy instead of in place
            # (presumably to keep autograd's history of forward_var intact - TODO confirm)
            cloned = forward_var.clone()
            cloned[:, i + 1, :] = max_tag_var + agg_

            forward_var = cloned

        # for every sentence, take the slot after its last real token
        forward_var = forward_var[range(forward_var.shape[0]), lens_, :]

        # add the transition into STOP
        terminal_var = forward_var + \
                       self.transitions[self.tag_dictionary.get_idx_for_item(STOP_TAG)][None, :].repeat(
                           forward_var.shape[0], 1)

        alpha = log_sum_exp_batch(terminal_var)

        return alpha

    def predict(self, sentences: Union[List[Sentence], Sentence], mini_batch_size=32) -> List[Sentence]:
        """
        Predicts tags for the given sentences and adds them, with their score, to the tokens.
        Sentences without tokens are skipped. Previous embeddings of the sentences are cleared.
        :param sentences: a single Sentence or a list of Sentences
        :param mini_batch_size: number of sentences that are processed together
        :return: the list of sentences (a single Sentence is wrapped into a list)
        """
        with torch.no_grad():
            if type(sentences) is Sentence:
                sentences = [sentences]

            filtered_sentences = self._filter_empty_sentences(sentences)

            # remove previous embeddings
            clear_embeddings(filtered_sentences, also_clear_word_embeddings=True)

            # walk over the sentences in mini-batches
            for offset in range(0, len(filtered_sentences), mini_batch_size):
                batch = filtered_sentences[offset:offset + mini_batch_size]

                scores, predicted_ids = self._predict_scores_batch(batch)

                # the predictions are flat lists in the order of the tokens of the batch
                # (the batch has been reordered by the prediction step at this point)
                all_tokens = [token for sentence in batch for token in sentence.tokens]

                for token, score, predicted_id in zip(all_tokens, scores, predicted_ids):
                    predicted_tag = self.tag_dictionary.get_item_for_index(predicted_id)
                    token.add_tag(self.tag_type, predicted_tag, score)

            return sentences

    def _predict_scores_batch(self, sentences: List[Sentence]):
        """
        Predicts tag ids and confidences for one batch of sentences.
        With a CRF the tags are decoded with Viterbi; otherwise every token gets the tag with
        the highest score, and the softmax probability of that tag as confidence.
        :param sentences: the batch (it is reordered in place by `forward`)
        :return: tuple (confidences, tag ids), both flat lists over all tokens of the batch
        """
        feature, lengths, tags = self.forward(sentences)

        all_confidences = []
        all_tags_seqs = []

        for feats, length in zip(feature, lengths):
            # only the real tokens of the sentence, without padding
            token_feats = feats[:length]

            if self.use_crf:
                confidences, tag_seq = self.viterbi_decode(token_feats)
            else:
                tag_seq, confidences = [], []
                for token_scores in token_feats:
                    probabilities = torch.nn.functional.softmax(token_scores, dim=0)
                    best = torch.max(token_scores, 0)[1].item()
                    tag_seq.append(best)
                    confidences.append(probabilities[best].item())

            all_tags_seqs.extend(tag_seq)
            all_confidences.extend(confidences)

        return all_confidences, all_tags_seqs

    @staticmethod
    def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
        """
        Drops all sentences without tokens and logs a warning if any were dropped.
        :param sentences: list of sentences
        :return: new list with the sentences that have at least one token
        """
        non_empty = [sentence for sentence in sentences if sentence.tokens]
        dropped = len(sentences) - len(non_empty)
        if dropped != 0:
            log.warning('Ignore {} sentence(s) with no tokens.'.format(dropped))
        return non_empty

    @staticmethod
    def load(model: str):
        """
        Loads one of the known taggers by name.
        Local models are read from disk; all other known models are fetched through
        `cached_path` into the 'models' cache directory.
        :param model: name of the model
        :return: the loaded tagger, or None if the name is unknown
        """
        aws_resource_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models-v0.2'

        # local models whose name has to match exactly
        exact_name_files = {
            "JNLPBA": "resources/taggers/pubmed/+JNLPBA/final-model.pt",
            "BC2GM": "resources/taggers/pubmed/+BC2GM/final-model.pt",
            "BC5CDR": "resources/taggers/pubmed/+BC5CDR/final-model.pt",
            'protocol': current_path + "/resources/taggers/c4r/final-model.pt",
        }

        # local models whose name is matched case-insensitively
        lowercase_name_files = {
            "medner": current_path + "/resources/taggers/medner/final-model.pt",
            "medner_big": "resources/taggers/result/elmo+flair+bigcorpus/final-model.pt",
        }

        # remote models (name matched case-insensitively): folder and file below the resource path
        remote_models = {
            'ner': ('NER-conll03--h256-l1-b32-%2Bglove%2Bnews-forward%2Bnews-backward--v0.2',
                    'en-ner-conll03-v0.2.pt'),
            'ner-fast': ('NER-conll03--h256-l1-b32-experimental--fast-v0.2',
                         'en-ner-fast-conll03-v0.2.pt'),
            'ner-ontonotes': ('NER-ontoner--h256-l1-b32-%2Bcrawl%2Bnews-forward%2Bnews-backward--v0.2',
                              'en-ner-ontonotes-v0.3.pt'),
            'ner-ontonotes-fast': ('NER-ontoner--h256-l1-b32-%2Bcrawl%2Bnews-forward-fast%2Bnews-backward-fast--v0.2',
                                   'en-ner-ontonotes-fast-v0.3.pt'),
            'pos': ('POS-ontonotes--h256-l1-b32-%2Bmix-forward%2Bmix-backward--v0.2',
                    'en-pos-ontonotes-v0.2.pt'),
            'pos-fast': ('POS-ontonotes--h256-l1-b32-%2Bnews-forward-fast%2Bnews-backward-fast--v0.2',
                         'en-pos-ontonotes-fast-v0.2.pt'),
            'frame': ('FRAME-conll12--h256-l1-b8-%2Bnews%2Bnews-forward%2Bnews-backward--v0.2',
                      'en-frame-ontonotes-v0.2.pt'),
            'frame-fast': ('FRAME-conll12--h256-l1-b8-%2Bnews%2Bnews-forward-fast%2Bnews-backward-fast--v0.2',
                           'en-frame-ontonotes-fast-v0.2.pt'),
            'chunk': ('NP-conll2000--h256-l1-b32-%2Bnews-forward%2Bnews-backward--v0.2',
                      'en-chunk-conll2000-v0.2.pt'),
            'chunk-fast': ('NP-conll2000--h256-l1-b32-%2Bnews-forward-fast%2Bnews-backward-fast--v0.2',
                           'en-chunk-conll2000-fast-v0.2.pt'),
            'de-pos': ('UPOS-udgerman--h256-l1-b8-%2Bgerman-forward%2Bgerman-backward--v0.2',
                       'de-pos-ud-v0.2.pt'),
            'de-ner': ('NER-conll03ger--h256-l1-b32-%2Bde-fasttext%2Bgerman-forward%2Bgerman-backward--v0.2',
                       'de-ner-conll03-v0.3.pt'),
            'de-ner-germeval': ('NER-germeval--h256-l1-b32-%2Bde-fasttext%2Bgerman-forward%2Bgerman-backward--v0.2',
                                'de-ner-germeval-v0.3.pt'),
        }

        model_file = None
        if model in exact_name_files:
            model_file = exact_name_files[model]

        lowered = model.lower()
        if lowered in lowercase_name_files:
            model_file = lowercase_name_files[lowered]
        elif lowered in remote_models:
            folder, file_name = remote_models[lowered]
            base_path = '/'.join([aws_resource_path, folder, file_name])
            model_file = cached_path(base_path, cache_dir='models')

        if model_file is None:
            return None

        return SequenceTagger.load_from_file(model_file)


In [8]:
import warnings
import logging
from typing import List, Union

import torch
import torch.nn as nn

#import flair.embeddings
#from flair.data import Dictionary, Sentence, Label
#from flair.training_utils import convert_labels_to_one_hot, clear_embeddings


log = logging.getLogger(__name__)


class TextClassifier(nn.Module):
    """
    Text Classification Model
    The model asks a document embedding object for one embedding per sentence and feeds those embeddings into a
    single linear layer that produces one score per entry of the label dictionary.
    The model can handle single and multi class data sets.

    Project types are annotated as strings (forward references) so that defining this class does not require
    `flair`, `Sentence`, `Label` or `Dictionary` to be bound at definition time; the flair imports of this cell
    are commented out.
    """

    def __init__(self,
                 document_embeddings: 'flair.embeddings.DocumentEmbeddings',
                 label_dictionary: 'Dictionary',
                 multi_label: bool):
        """
        :param document_embeddings: object providing `embed(sentences)` and an `embedding_length` attribute
        :param label_dictionary: dictionary of all labels; its length defines the number of output scores
        :param multi_label: if True a BCELoss over sigmoid scores is used, otherwise a CrossEntropyLoss
        """

        super(TextClassifier, self).__init__()

        self.document_embeddings = document_embeddings
        self.label_dictionary = label_dictionary
        self.multi_label = multi_label

        self.decoder = nn.Linear(self.document_embeddings.embedding_length, len(self.label_dictionary))

        self._init_weights()

        if multi_label:
            self.loss_function = nn.BCELoss()
        else:
            self.loss_function = nn.CrossEntropyLoss()

        # auto-spawn on GPU if available
        if torch.cuda.is_available():
            self.cuda()

    @property
    def _device(self) -> torch.device:
        """
        Device that currently holds the decoder weights. Inputs and targets are moved here instead of
        unconditionally to CUDA, so the model keeps working after it has been moved back to the CPU.
        """
        return self.decoder.weight.device

    def _init_weights(self):
        """Initialises the decoder weight matrix with Xavier uniform initialisation."""
        nn.init.xavier_uniform_(self.decoder.weight)

    def forward(self, sentences) -> torch.Tensor:
        """
        Embeds the sentences and computes the label scores.
        :param sentences: list of sentences; each must offer `get_embedding()` after being embedded
        :return: tensor with the decoder output for the concatenated sentence embeddings
        """
        self.document_embeddings.embed(sentences)

        # NOTE(review): concatenation along dim 0 presumably relies on every sentence embedding having a
        # leading dimension of size 1 - confirm against the embedding implementation.
        text_embedding_list = [sentence.get_embedding() for sentence in sentences]
        text_embedding_tensor = torch.cat(text_embedding_list, 0).to(self._device)

        return self.decoder(text_embedding_tensor)

    def save(self, model_file: str):
        """
        Saves the current model to the provided file.
        :param model_file: the model file
        """
        model_state = {
            'state_dict': self.state_dict(),
            'document_embeddings': self.document_embeddings,
            'label_dictionary': self.label_dictionary,
            'multi_label': self.multi_label,
        }
        torch.save(model_state, model_file, pickle_protocol=4)

    @classmethod
    def load_from_file(cls, model_file):
        """
        Loads the model from the given file.
        :param model_file: the model file
        :return: the loaded text classifier model, switched to eval mode
        """

        # ATTENTION: suppressing torch serialization warnings. This needs to be taken out once we sort out recursive
        # serialization of torch objects
        # https://docs.python.org/3/library/warnings.html#temporarily-suppressing-warnings
        # SECURITY: torch.load unpickles arbitrary objects - only load model files from trusted sources.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            if torch.cuda.is_available():
                state = torch.load(model_file)
            else:
                # map every storage to the CPU, not only those that were saved from 'cuda:0'
                state = torch.load(model_file, map_location='cpu')

        model = cls(
            document_embeddings=state['document_embeddings'],
            label_dictionary=state['label_dictionary'],
            multi_label=state['multi_label']
        )

        model.load_state_dict(state['state_dict'])
        model.eval()
        return model

    def predict(self, sentences: Union['Sentence', List['Sentence']], mini_batch_size: int = 32) -> List['Sentence']:
        """
        Predicts the class labels for the given sentences. The labels are directly added to the sentences.
        Sentences without tokens are skipped (they are returned unchanged).
        :param sentences: a single sentence or a list of sentences
        :param mini_batch_size: mini batch size to use
        :return: the list of sentences containing the labels
        """
        with torch.no_grad():
            # isinstance (rather than an exact type check) also accepts subclasses of Sentence
            if isinstance(sentences, Sentence):
                sentences = [sentences]

            filtered_sentences = self._filter_empty_sentences(sentences)

            batches = [filtered_sentences[x:x + mini_batch_size]
                       for x in range(0, len(filtered_sentences), mini_batch_size)]

            for batch in batches:
                scores = self.forward(batch)
                predicted_labels = self.obtain_labels(scores)

                for (sentence, labels) in zip(batch, predicted_labels):
                    sentence.labels = labels

                # NOTE(review): `clear_embeddings` must be defined by an earlier cell - its import from
                # flair.training_utils is commented out in this cell, and the recorded output of the last cell
                # of this notebook ends with "NameError: name 'clear_embeddings' is not defined".
                clear_embeddings(batch)

            return sentences

    @staticmethod
    def _filter_empty_sentences(sentences: List['Sentence']) -> List['Sentence']:
        """
        Drops sentences whose `tokens` attribute is empty and logs a warning with the number of dropped sentences.
        :param sentences: list of sentences
        :return: list of the sentences that have at least one token
        """
        filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
        if len(sentences) != len(filtered_sentences):
            log.warning('Ignore {} sentence(s) with no tokens.'.format(len(sentences) - len(filtered_sentences)))
        return filtered_sentences

    def calculate_loss(self, scores: torch.Tensor, sentences: List['Sentence']) -> torch.Tensor:
        """
        Calculates the loss.
        :param scores: the prediction scores from the model
        :param sentences: list of sentences
        :return: loss value as returned by the configured loss function
        """
        if self.multi_label:
            return self._calculate_multi_label_loss(scores, sentences)

        return self._calculate_single_label_loss(scores, sentences)

    def obtain_labels(self, scores: torch.Tensor) -> List[List['Label']]:
        """
        Predicts the labels of sentences.
        :param scores: the prediction scores from the model
        :return: list of predicted labels
        """

        if self.multi_label:
            return [self._get_multi_label(s) for s in scores]

        return [self._get_single_label(s) for s in scores]

    def _get_multi_label(self, label_scores) -> List['Label']:
        """
        Returns one label for every class whose sigmoid score is above 0.5.
        :param label_scores: scores of a single sentence, one entry per class
        :return: list of labels, possibly empty
        """
        labels = []

        for idx, conf in enumerate(torch.sigmoid(label_scores)):
            if conf > 0.5:
                label = self.label_dictionary.get_item_for_index(idx)
                labels.append(Label(label, conf.item()))

        return labels

    def _get_single_label(self, label_scores) -> List['Label']:
        """
        Returns the label with the highest score.
        :param label_scores: scores of a single sentence, one entry per class
        :return: list holding exactly one label
        """
        # NOTE(review): `conf` is the raw decoder output of the winning class, not a probability.
        conf, idx = torch.max(label_scores, 0)
        label = self.label_dictionary.get_item_for_index(idx.item())

        return [Label(label, conf.item())]

    def _calculate_multi_label_loss(self, label_scores, sentences: List['Sentence']) -> torch.Tensor:
        """Applies a sigmoid to the scores and compares them with the one-hot encoded gold labels."""
        return self.loss_function(torch.sigmoid(label_scores), self._labels_to_one_hot(sentences))

    def _calculate_single_label_loss(self, label_scores, sentences: List['Sentence']) -> torch.Tensor:
        """Compares the raw scores with the gold label indices."""
        return self.loss_function(label_scores, self._labels_to_indices(sentences))

    def _labels_to_one_hot(self, sentences: List['Sentence']):
        """
        Builds the multi-label target tensor, one row per sentence, on the device of the model.
        :param sentences: list of sentences
        :return: float tensor of stacked one-hot rows
        """
        label_list = [sentence.get_label_names() for sentence in sentences]
        # NOTE(review): `convert_labels_to_one_hot` must be defined by an earlier cell - its import from
        # flair.training_utils is commented out in this cell.
        one_hot = convert_labels_to_one_hot(label_list, self.label_dictionary)
        one_hot = [torch.FloatTensor(l).unsqueeze(0) for l in one_hot]
        return torch.cat(one_hot, 0).to(self._device)

    def _labels_to_indices(self, sentences: List['Sentence']):
        """
        Builds the single-label target tensor on the device of the model.
        Every label of every sentence contributes one index, so the result only lines up with the scores when
        each sentence carries exactly one label.
        :param sentences: list of sentences
        :return: long tensor of label indices
        """
        indices = [
            torch.LongTensor([self.label_dictionary.get_idx_for_item(label.value) for label in sentence.labels])
            for sentence in sentences
        ]

        return torch.cat(indices, 0).to(self._device)


In [29]:
#!pip install flair

In [18]:
#from flair.data import Sentence
#from flair.models import SequenceTagger
#import flair
import torch
import os
import sys

def realign_span(text, start, end, surface, max_shift=5):
    """
    Finds the character span of `surface` in `text` at or near the predicted position.

    The predicted span is accepted when `text[start:end]` already equals `surface`. Otherwise the span is
    shifted by 1..max_shift characters to the left and to the right (nearest shift first) and the first
    position whose slice equals `surface` is returned. Shifts that would leave the text are skipped, so a
    negative start index can never wrap around to the end of the text.

    :param text: full text the offsets refer to
    :param start: predicted start offset
    :param end: predicted end offset
    :param surface: text of the entity
    :param max_shift: largest shift (in characters) that is tried in either direction
    :return: tuple (start, end) of the matching span, or None if no position within max_shift matches
    """
    if text[start:end] == surface:
        return start, end

    for distance in range(1, max_shift + 1):
        for shift in (-distance, distance):
            new_start, new_end = start + shift, end + shift
            if new_start < 0 or new_end > len(text):
                continue
            if text[new_start:new_end] == surface:
                return new_start, new_end

    return None


def tag_file(tagger, path):
    """
    Runs the tagger over every line of a text file.

    Each stripped line becomes one Sentence; the entities are obtained from `sentence.to_offset_tags` with the
    number of characters that precede the line in the file.

    :param tagger: object with a `predict(sentence)` method
    :param path: path of the text file
    :return: tuple (text, entities) with the full file content and the entities of all lines
    """
    with open(path) as handle:
        lines = handle.readlines()

    entities = []
    char_offset = 0
    for line in lines:
        print("--------")
        print(line.strip())
        sentence = Sentence(line.strip())
        tagger.predict(sentence)
        print(sentence.to_tagged_string())
        # all entities are kept (a stop-phrase filter on the entity text used to be applied here)
        entities += sentence.to_offset_tags(char_offset)
        char_offset += len(line)

    return ''.join(lines), entities


def write_brat_files(text, entities, txt_path, ann_path):
    """
    Writes a copy of the text and the matching standoff annotation file.

    Every entity is indexed as [type, start, end, surface]. One line "T<i>\\t<type> <start> <end>\\t<surface>"
    is written per entity whose surface text is found at (or within 5 characters of) its offsets; other
    entities are dropped. <i> is the 1-based position of the entity in `entities`, so dropped entities leave
    gaps in the numbering.

    :param text: full text the offsets refer to
    :param entities: list of entities as described above
    :param txt_path: path of the text copy to write
    :param ann_path: path of the annotation file to write
    """
    with open(txt_path, "w") as txt_out:
        txt_out.write(text)

    with open(ann_path, "w") as ann_out:
        for index, entity in enumerate(entities, start=1):
            span = realign_span(text, int(entity[1]), int(entity[2]), entity[3])
            if span is None:
                continue
            ann_out.write("T{}\t{} {} {}\t{}\n".format(index, entity[0], span[0], span[1], entity[3]))


if __name__ == "__main__":
    # NOTE(review): `flair` has to be bound by an earlier cell - `import flair` is commented out above.
    flair.device = torch.device('cpu')
    tagger: SequenceTagger = SequenceTagger.load("ner") # protocol ner

    # Fixed folders instead of overwriting sys.argv[1] / sys.argv[2] (which in a notebook kernel hold the
    # kernel's own command line arguments).
    INPUT_DIR = './pipeline/'
    OUTPUT_DIR = './output/'

    print ("load protocol model")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for filename in os.listdir(INPUT_DIR):
        print(filename)
        if not filename.endswith(".txt"):
            continue

        source_path = os.path.join(INPUT_DIR, filename)
        print (source_path)
        text, all_entities = tag_file(tagger, source_path)

        # replace only the extension; str.replace would also rewrite ".txt" inside the base name
        ann_name = os.path.splitext(filename)[0] + ".ann"
        write_brat_files(text, all_entities,
                         os.path.join(OUTPUT_DIR, filename),
                         os.path.join(OUTPUT_DIR, ann_name))

load protocol model
.ipynb_checkpoints
requirements4.txt
./pipeline/requirements4.txt
--------
opencv-python==4.1.0.25


NameError: name 'clear_embeddings' is not defined