# Install

In [None]:
# Python dependencies: HF `datasets` (KorQuAD v1), KoNLPy (Mecab wrapper), Pororo
# (tokenizer), python-mecab-ko, and rank_bm25 (BM25Plus scoring used below).
# NOTE(review): versions are unpinned, so a fresh run may pull incompatible releases.
!pip install datasets
!pip install konlpy
!pip install pororo
!pip install python-mecab-ko
!pip install rank_bm25

In [None]:
# Fetch the Colab installer script for the MeCab-ko morphological analyzer.
!git clone https://github.com/SOMJANG/Mecab-ko-for-Google-Colab.git

In [None]:
# IPython automagic `cd` (same as `%cd`): this changes the kernel's working
# directory for every later cell, which affects any relative path used afterwards.
cd Mecab-ko-for-Google-Colab

In [None]:
# Run the installer from the cloned repository (the current directory after `cd`).
!bash install_mecab-ko_on_colab190912.sh

# Libraries

In [None]:
import zipfile
from datasets import load_from_disk
import json
import pandas as pd
import re
from tqdm import tqdm
from konlpy.tag import Mecab
from rank_bm25 import BM25Plus
from pororo import Pororo

# Data preprocessing

In [None]:
from datasets import load_dataset

dataset = load_dataset('squad_kor_v1')

dataset

In [None]:
mecab = Mecab()
tokenizer = Pororo(task="tokenization", lang="ko", model="mecab.bpe64k.ko")

In [None]:
train_dataset = pd.DataFrame({'title' : dataset['train']['title'],
                              'context' : dataset['train']['context'],
                              'question' : dataset['train']['question'],
                              'answers' : dataset['train']['answers']})

train_dataset

In [None]:
dev_dataset = pd.DataFrame({'title' : dataset['validation']['title'],
                              'context' : dataset['validation']['context'],
                              'question' : dataset['validation']['question'],
                              'answers' : dataset['validation']['answers']})

dev_dataset

In [None]:
df = pd.concat([train_dataset,dev_dataset])

df = df.reset_index()

del df['index']

df

In [None]:
corpus = df['context'].drop_duplicates().to_list()

In [None]:
corpus_to_id = {corpus[num]:num for num in range(len(corpus))}

In [None]:
tokenized_corpus = [tokenizer(i) for i in corpus]

In [None]:
bm25 = BM25Plus(tokenized_corpus)

# Train Dataset

In [None]:
dpr_train = []

In [None]:
for num in tqdm(range(len(train_dataset))):
    data_num = {}

    bm25_score = bm25.get_scores(tokenizer(train_dataset['question'][num]))

    bm25_score_sorted = sorted([(i,bm25_score[i]) for i in range(len(bm25_score))],key=lambda x : x[1])
    
    # dataset
    data_num['dataset'] = 'dpr_train'

    # question
    data_num['question'] = train_dataset['question'][num]

    # answers
    data_num['answers'] = train_dataset['answers'][num]['text']

    # positive_ctxs
    data_num['positive_ctxs'] = [{
        'title' : train_dataset['title'][num],
        'text' : train_dataset['context'][num],
        'score' : bm25_score[corpus_to_id[train_dataset['context'][num]]],
        'title_score' : 0,
        'passage_id' : corpus_to_id[train_dataset['context'][num]]
    }]

    # negative_ctxs

    negative_ctxs_tmp = []

    for neg,scr in bm25_score_sorted[:3]:
        negative_ctxs_tmp.append({
            'title' : '',
            'text' : corpus[neg],
            'score' : scr,
            'title_score' : 0,
            'passage_id' : neg
        })

    data_num['negative_ctxs'] = negative_ctxs_tmp

    # hard_negative_ctxs

    hard_negative_ctxs_tmp = []

    for hrd,scr in bm25_score_sorted[-6:]:
        if hrd != num:
            hard_negative_ctxs_tmp.append({
                'title' : '',
                'text' : corpus[hrd],
                'score' : scr,
                'title_score' : 0,
                'passage_id' : hrd
            })

    data_num['hard_negative_ctxs'] = hard_negative_ctxs_tmp

    dpr_train.append(data_num)

In [None]:
with open('/content/drive/MyDrive/Colab Notebooks/train/dpr_train.json', 'w', encoding="utf-8") as f:
    json.dump(dpr_train, f)

# Dev Dataset

In [None]:
dpr_dev = []

In [None]:
for num in tqdm(range(len(dev_dataset))):
    data_num = {}

    bm25_score = bm25.get_scores(tokenizer(dev_dataset['question'][num]))

    bm25_score_sorted = sorted([(i,bm25_score[i]) for i in range(len(bm25_score))],key=lambda x : x[1])
    
    # dataset
    data_num['dataset'] = 'dpr_dev'

    # question
    data_num['question'] = dev_dataset['question'][num]

    # answers
    data_num['answers'] = dev_dataset['answers'][num]['text']

    # positive_ctxs
    data_num['positive_ctxs'] = [{
        'title' : dev_dataset['title'][num],
        'text' : dev_dataset['context'][num],
        'score' : bm25_score[corpus_to_id[dev_dataset['context'][num]]],
        'title_score' : 0,
        'passage_id' : corpus_to_id[dev_dataset['context'][num]]
    }]

    # negative_ctxs

    negative_ctxs_tmp = []

    for neg,scr in bm25_score_sorted[:3]:
        negative_ctxs_tmp.append({
            'title' : '',
            'text' : corpus[neg],
            'score' : scr,
            'title_score' : 0,
            'passage_id' : neg
        })

    data_num['negative_ctxs'] = negative_ctxs_tmp

    # hard_negative_ctxs

    hard_negative_ctxs_tmp = []

    for hrd,scr in bm25_score_sorted[-6:]:
        if hrd != num:
            hard_negative_ctxs_tmp.append({
                'title' : '',
                'text' : corpus[hrd],
                'score' : scr,
                'title_score' : 0,
                'passage_id' : hrd
            })

    data_num['hard_negative_ctxs'] = hard_negative_ctxs_tmp

    dpr_dev.append(data_num)

In [None]:
with open('/content/drive/MyDrive/Colab Notebooks/dev/dpr_dev.json', 'w', encoding="utf-8") as f:
    json.dump(dpr_dev, f)

# Dense Passage Retriever Training

In [None]:
# Show the attached GPU (if any) before training.
!nvidia-smi

In [None]:
# NOTE(review): haystack is installed from unpinned git master, while the imports below
# (`haystack.retriever.dense`, `haystack.preprocessor.utils`,
# `haystack.document_store.elasticsearch`) may not exist in current releases;
# pin a matching version if they fail to import.
!pip install git+https://github.com/deepset-ai/haystack.git
!pip install urllib3==1.25.4
!python -m pip install elasticsearch

In [None]:
from haystack.retriever.dense import DensePassageRetriever
from haystack.preprocessor.utils import fetch_archive_from_http

In [None]:
! wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.9.2-linux-x86_64.tar.gz -q
! tar -xzf elasticsearch-7.9.2-linux-x86_64.tar.gz
! chown -R daemon:daemon elasticsearch-7.9.2

In [None]:
import os
from subprocess import Popen, PIPE, STDOUT
es_server = Popen(['elasticsearch-7.9.2/bin/elasticsearch'],
                   stdout=PIPE, stderr=STDOUT,
                   preexec_fn=lambda: os.setuid(1)
                  )

! sleep 30

In [None]:
! /content/elasticsearch-7.9.2/bin/elasticsearch-plugin install analysis-nori

In [None]:
es_server.kill()

In [None]:
import os
from subprocess import Popen, PIPE, STDOUT
es_server = Popen(['elasticsearch-7.9.2/bin/elasticsearch'],
                   stdout=PIPE, stderr=STDOUT,
                   preexec_fn=lambda: os.setuid(1)
                  )

! sleep 30

In [None]:
from elasticsearch import Elasticsearch

# Low-level client for the local node started above, used for the sanity checks below.
es = Elasticsearch("localhost:9200")

In [None]:
# Should show cluster/version info; an error here means the node is not up yet.
es.info()

In [None]:
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
# Haystack document store on index "document", using the `nori` (Korean) analyzer
# provided by the analysis-nori plugin installed above.
document_store = ElasticsearchDocumentStore(host="localhost", username="", password="", index="document", analyzer='nori')

In [None]:
# Inspect the index definition (settings/mappings) to confirm it was created.
es.indices.get('document')

In [None]:
# f = zipfile.ZipFile('/content/drive/MyDrive/Colab Notebooks/data.zip')
# f.extractall('/content')
# f.close()

In [None]:
# with open('/content/data/wikipedia_documents.json', 'r') as f:
#     wiki_data = json.load(f)

In [None]:
dicts = [{'text':df['context'][num],'meta':{'name':df['title'][num]}} for num in range(len(df))]

In [None]:
# dicts = [{'text':wiki_data[str(i)]['text'],'meta':{'name':wiki_data[str(i)]['title']}} for i in range(len(wiki_data))]

In [None]:
document_store.write_documents(dicts)

In [None]:
# Base directory on Google Drive. The two filenames below are relative to it and
# match the JSON files written in the Train/Dev Dataset sections.
doc_dir = "/content/drive/MyDrive/Colab Notebooks/"

train_filename = "train/dpr_train.json"
dev_filename = "dev/dpr_dev.json"

# Hugging Face Hub checkpoints for the question encoder and the passage (context) encoder.
query_model = "voidful/dpr-question_encoder-bert-base-multilingual"
passage_model = "voidful/dpr-ctx_encoder-bert-base-multilingual"

# Presumably the output directory for the trained retriever.
save_dir = "/content/drive/MyDrive/Colab Notebooks/saved_models/dpr"

## LanguageModel

In [None]:
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors,  The HuggingFace Inc. Team and deepset Team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Acknowledgements: Many of the modeling parts here come from the great transformers repository: https://github.com/huggingface/transformers.
Thanks for the great work! """

from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import os
import io
from pathlib import Path
from collections import OrderedDict

from dotmap import DotMap
from tqdm import tqdm
import copy
import numpy as np
import torch
from torch import nn

logger = logging.getLogger(__name__)

from transformers import (
    BertModel, BertConfig,
    RobertaModel, RobertaConfig,
    XLNetModel, XLNetConfig,
    AlbertModel, AlbertConfig,
    XLMRobertaModel, XLMRobertaConfig,
    DistilBertModel, DistilBertConfig,
    ElectraModel, ElectraConfig,
    CamembertModel, CamembertConfig
)

from transformers import AutoModel, AutoConfig
from transformers.modeling_utils import SequenceSummary
from transformers.models.bert.tokenization_bert import load_vocab
import transformers

from farm.modeling import wordembedding_utils
from farm.modeling.wordembedding_utils import s3e_pooling

# These are the names of the attributes in various model configs which refer to the number of dimensions
# in the output vectors. LanguageModel.get_output_dims() checks them in this order and returns the first
# one present on the config, so the order matters when a config defines more than one.
OUTPUT_DIM_NAMES = ["dim", "hidden_size", "d_model"]


class LanguageModel(nn.Module):
    """
    The parent class for any kind of model that can embed language into a semantic vector space. Practically
    speaking, these models read in tokenized sentences and return vectors that capture the meaning of sentences
    or of tokens.
    """

    # Registry of every subclass keyed by class name (e.g. "Bert"); filled by __init_subclass__.
    subclasses = {}

    def __init_subclass__(cls, **kwargs):
        """ This automatically keeps track of all available subclasses.
        Enables generic load() or all specific LanguageModel implementation.
        """
        super().__init_subclass__(**kwargs)
        cls.subclasses[cls.__name__] = cls

    def forward(self, input_ids, padding_mask, **kwargs):
        raise NotImplementedError

    @classmethod
    def from_scratch(cls, model_type, vocab_size):
        """
        Create a randomly initialised language model.

        :param model_type: Architecture name. Only "bert" (case-insensitive) is supported.
        :param vocab_size: Size of the vocabulary / embedding matrix.
        :return: A freshly initialised model.
        :raises NotImplementedError: If `model_type` is not supported.
        """
        if model_type.lower() != "bert":
            # Only Bert provides a from_scratch() implementation.
            raise NotImplementedError(f"from_scratch() is only implemented for 'bert', got '{model_type}'")
        return Bert.from_scratch(vocab_size)

    @classmethod
    def load(cls, pretrained_model_name_or_path, revision=None, n_added_tokens=0, language_model_class=None, **kwargs):
        """
        Load a pretrained language model either by

        1. specifying its name and downloading it
        2. or pointing to the directory it is saved in.

        Available remote models:

        * bert-base-uncased
        * bert-large-uncased
        * bert-base-cased
        * bert-large-cased
        * bert-base-multilingual-uncased
        * bert-base-multilingual-cased
        * bert-base-chinese
        * bert-base-german-cased
        * roberta-base
        * roberta-large
        * xlnet-base-cased
        * xlnet-large-cased
        * xlm-roberta-base
        * xlm-roberta-large
        * albert-base-v2
        * albert-large-v2
        * distilbert-base-german-cased
        * distilbert-base-multilingual-cased
        * google/electra-small-discriminator
        * google/electra-base-discriminator
        * google/electra-large-discriminator
        * facebook/dpr-question_encoder-single-nq-base
        * facebook/dpr-ctx_encoder-single-nq-base

        See all supported model variations here: https://huggingface.co/models

        The appropriate language model class is inferred automatically from model config
        or can be manually supplied via `language_model_class`.

        :param pretrained_model_name_or_path: The path of the saved pretrained model or its name.
        :type pretrained_model_name_or_path: str
        :param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :type revision: str
        :param n_added_tokens: Number of tokens added to the vocabulary; the embedding layer is resized by this amount.
        :param language_model_class: (Optional) Name of the language model class to load (e.g. `Bert`)
        :type language_model_class: str
        :raises Exception: If no model class could be determined.

        """
        kwargs["revision"] = revision
        logger.info("")
        logger.info("LOADING MODEL")
        logger.info("=============")
        config_file = Path(pretrained_model_name_or_path) / "language_model_config.json"
        if os.path.exists(config_file):
            logger.info(f"Model found locally at {pretrained_model_name_or_path}")
            # it's a local directory in FARM format
            with open(config_file) as f:
                config = json.load(f)
            language_model = cls.subclasses[config["name"]].load(pretrained_model_name_or_path)
        else:
            logger.info(f"Could not find {pretrained_model_name_or_path} locally.")
            logger.info(f"Looking on Transformers Model Hub (in local cache and online)...")
            if language_model_class is None:
                language_model_class = cls.get_language_model_class(pretrained_model_name_or_path)

            if language_model_class:
                language_model = cls.subclasses[language_model_class].load(pretrained_model_name_or_path, **kwargs)
            else:
                language_model = None

        if language_model is None:
            raise Exception(
                f"Model not found for {pretrained_model_name_or_path}. Either supply the local path for a saved "
                f"model or one of bert/roberta/xlnet/albert/distilbert models that can be downloaded from remote. "
                f"Ensure that the model class name can be inferred from the directory name when loading a "
                f"Transformers' model. Here's a list of available models: "
                f"https://farm.deepset.ai/api/modeling.html#farm.modeling.language_model.LanguageModel.load"
            )
        logger.info(f"Loaded {pretrained_model_name_or_path}")

        # resize embeddings in case of custom vocab
        if n_added_tokens != 0:
            # TODO verify for other models than BERT
            model_emb_size = language_model.model.resize_token_embeddings(new_num_tokens=None).num_embeddings
            vocab_size = model_emb_size + n_added_tokens
            logger.info(
                f"Resizing embedding layer of LM from {model_emb_size} to {vocab_size} to cope with custom vocab.")
            language_model.model.resize_token_embeddings(vocab_size)
            # verify
            model_emb_size = language_model.model.resize_token_embeddings(new_num_tokens=None).num_embeddings
            assert vocab_size == model_emb_size

        return language_model

    @staticmethod
    def get_language_model_class(model_name_or_path):
        """
        Map a transformers model (hub name or local path) to a LanguageModel subclass name.

        Uses `config.model_type` first; for DPR configs the first entry of `config.architectures`
        decides between question and context encoder. If neither identifies the class, the
        model name is used as a fallback.

        :return: Subclass name such as "Bert", or None if nothing matched.
        :raises NotImplementedError: For codebert MLM models and DPRReader models.
        """
        # it's transformers format (either from model hub or local)
        model_name_or_path = str(model_name_or_path)

        config = AutoConfig.from_pretrained(model_name_or_path)
        model_type = config.model_type

        type_to_class = {
            "xlm-roberta": "XLMRoberta",
            "roberta": "Roberta",
            "camembert": "Camembert",
            "albert": "Albert",
            "distilbert": "DistilBert",
            "bert": "Bert",
            "xlnet": "XLNet",
            "electra": "Electra",
        }
        if model_type == "roberta" and "mlm" in model_name_or_path.lower():
            raise NotImplementedError("MLM part of codebert is currently not supported in FARM")
        if model_type in type_to_class:
            return type_to_class[model_type]

        if model_type == "dpr":
            # `architectures` may be None or empty on some configs.
            architecture = (config.architectures or [None])[0]
            if architecture == "DPRQuestionEncoder":
                return "DPRQuestionEncoder"
            if architecture == "DPRContextEncoder":
                return "DPRContextEncoder"
            if architecture == "DPRReader":
                raise NotImplementedError("DPRReader models are currently not supported.")
            # Unknown/missing DPR architecture: decide from the model name instead.
            return LanguageModel._infer_language_model_class_from_string(model_name_or_path)

        # Fall back to inferring type from model name
        logger.warning("Could not infer LanguageModel class from config. Trying to infer "
                       "LanguageModel class from model name.")
        return LanguageModel._infer_language_model_class_from_string(model_name_or_path)

    @staticmethod
    def _infer_language_model_class_from_string(model_name_or_path):
        """
        Guess the LanguageModel subclass name from substrings of the model name.

        More specific patterns are tested before generic ones. In particular, the DPR patterns
        come first because DPR checkpoint names often also contain "bert"
        (e.g. "voidful/dpr-question_encoder-bert-base-multilingual").

        :return: Subclass name, or None if no pattern matches.
        :raises NotImplementedError: For codebert MLM models.
        """
        # If inferring Language model class from config doesn't succeed,
        # fall back to inferring Language model class from model name.
        name = model_name_or_path.lower()
        if "dpr-question_encoder" in name:
            return "DPRQuestionEncoder"
        if "dpr-ctx_encoder" in name:
            return "DPRContextEncoder"
        if "xlm" in name and "roberta" in name:
            return "XLMRoberta"
        if "roberta" in name:
            return "Roberta"
        if "codebert" in name:
            if "mlm" in name:
                raise NotImplementedError("MLM part of codebert is currently not supported in FARM")
            return "Roberta"
        if "camembert" in name or "umberto" in name:
            return "Camembert"
        if "albert" in name:
            return 'Albert'
        if "distilbert" in name:
            return 'DistilBert'
        if "bert" in name:
            return 'Bert'
        if "xlnet" in name:
            return 'XLNet'
        if "electra" in name:
            return 'Electra'
        if "word2vec" in name or "glove" in name:
            return 'WordEmbedding_LM'
        if "minilm" in name:
            return "Bert"
        return None

    def get_output_dims(self):
        """Return the output vector size from the first OUTPUT_DIM_NAMES attribute found on the config."""
        config = self.model.config
        for odn in OUTPUT_DIM_NAMES:
            if odn in dir(config):
                return getattr(config, odn)
        raise Exception("Could not infer the output dimensions of the language model")

    def freeze(self, layers):
        """ To be implemented"""
        raise NotImplementedError()

    def unfreeze(self):
        """ To be implemented"""
        raise NotImplementedError()

    def save_config(self, save_dir):
        """Write the model config (plus this class's name and language) to `language_model_config.json`."""
        save_filename = Path(save_dir) / "language_model_config.json"
        with open(save_filename, "w") as file:
            setattr(self.model.config, "name", self.__class__.__name__)
            setattr(self.model.config, "language", self.language)
            string = self.model.config.to_json_string()
            file.write(string)

    def save(self, save_dir):
        """
        Save the model state_dict and its config file so that it can be loaded again.

        :param save_dir: The directory in which the model should be saved.
        :type save_dir: str
        """
        # Save Weights
        save_name = Path(save_dir) / "language_model.bin"
        model_to_save = (
            self.model.module if hasattr(self.model, "module") else self.model
        )  # Only save the model it-self
        torch.save(model_to_save.state_dict(), save_name)
        self.save_config(save_dir)

    @classmethod
    def _get_or_infer_language_from_name(cls, language, name):
        """Return `language` if given, otherwise infer it from the model name."""
        if language is not None:
            return language
        return cls._infer_language_from_name(name)

    @classmethod
    def _infer_language_from_name(cls, name):
        """Guess the model language from its name; defaults to "english" when nothing matches."""
        known_languages = (
            "german",
            "english",
            "chinese",
            "indian",
            "french",
            "polish",
            "spanish",
            "multilingual",
        )
        matches = [lang for lang in known_languages if lang in name]
        if "camembert" in name:
            language = "french"
            logger.info(
                f"Automatically detected language from language model name: {language}"
            )
        elif "umberto" in name:
            language = "italian"
            logger.info(
                f"Automatically detected language from language model name: {language}"
            )
        elif len(matches) == 0:
            language = "english"
        elif len(matches) > 1:
            # Ambiguous: take the first known language (in `known_languages` order) without logging.
            language = matches[0]
        else:
            language = matches[0]
            logger.info(
                f"Automatically detected language from language model name: {language}"
            )

        return language

    def formatted_preds(self, logits, samples, ignore_first_token=True,
                        padding_mask=None, input_ids=None, **kwargs):
        """
        Extracting vectors from language model (e.g. for extracting sentence embeddings).
        Different pooling strategies and layers are available and will be determined from the object attributes
        `extraction_layer` and `extraction_strategy`. Both should be set via the Inferencer:
        Example:  Inferencer(extraction_strategy='cls_token', extraction_layer=-1)

        :param logits: Tuple of (sequence_output, pooled_output) from the language model.
                       Sequence_output: one vector per token, pooled_output: one vector for whole sequence
        :param samples: For each item in logits we need additional meta information to format the prediction (e.g. input text).
                        This is created by the Processor and passed in here from the Inferencer.
        :param ignore_first_token: Whether to include the first token for pooling operations (e.g. reduce_mean).
                                   Many models have here a special token like [CLS] that you don't want to include into your average of token embeddings.
        :param padding_mask: Mask for the padding tokens. Those will also not be included in the pooling operations to prevent a bias by the number of padding tokens.
        :param input_ids: ids of the tokens in the vocab
        :param kwargs: kwargs
        :return: list of dicts containing preds, e.g. [{"context": "some text", "vec": [-0.01, 0.5 ...]}]
        """

        if not hasattr(self, "extraction_layer") or not hasattr(self, "extraction_strategy"):
            raise ValueError("`extraction_layer` or `extraction_strategy` not specified for LM. "
                             "Make sure to set both, e.g. via Inferencer(extraction_strategy='cls_token', extraction_layer=-1)`")

        # unpack the tuple from LM forward pass
        sequence_output = logits[0][0]
        pooled_output = logits[0][1]

        # aggregate vectors
        if self.extraction_strategy == "pooled":
            if self.extraction_layer != -1:
                raise ValueError(f"Pooled output only works for the last layer, but got extraction_layer = {self.extraction_layer}. Please set `extraction_layer=-1`.)")
            vecs = pooled_output.cpu().numpy()
        elif self.extraction_strategy == "per_token":
            vecs = sequence_output.cpu().numpy()
        elif self.extraction_strategy in ("reduce_mean", "reduce_max"):
            vecs = self._pool_tokens(sequence_output, padding_mask, self.extraction_strategy, ignore_first_token=ignore_first_token)
        elif self.extraction_strategy == "cls_token":
            vecs = sequence_output[:, 0, :].cpu().numpy()
        elif self.extraction_strategy == "s3e":
            vecs = self._pool_tokens(sequence_output, padding_mask, self.extraction_strategy,
                                     ignore_first_token=ignore_first_token,
                                     input_ids=input_ids, s3e_stats=self.s3e_stats)
        else:
            raise NotImplementedError

        preds = []
        for vec, sample in zip(vecs, samples):
            pred = {}
            pred["context"] = sample.clear_text["text"]
            pred["vec"] = vec
            preds.append(pred)
        return preds

    def _pool_tokens(self, sequence_output, padding_mask, strategy, ignore_first_token, input_ids=None, s3e_stats=None):
        """
        Pool token vectors into one vector per sequence, ignoring padding tokens.

        :param sequence_output: Token vectors; the mask broadcasting below assumes (batch, seq_len, dim).
        :param padding_mask: 1 for real tokens, 0 for padding; assumed (batch, seq_len).
        :param strategy: "reduce_max", "reduce_mean" or "s3e".
        :param ignore_first_token: Also exclude position 0 (e.g. [CLS]) from pooling.
        :return: numpy array with one pooled vector per sequence.
        :raises ValueError: For an unknown strategy.
        """
        token_vecs = sequence_output.cpu().numpy()
        # we only take the aggregated value of non-padding tokens
        padding_mask = padding_mask.cpu().numpy()
        ignore_mask_2d = padding_mask == 0
        # sometimes we want to exclude the CLS token as well from our aggregation operation
        if ignore_first_token:
            ignore_mask_2d[:, 0] = True
        ignore_mask_3d = np.zeros(token_vecs.shape, dtype=bool)
        ignore_mask_3d[:, :, :] = ignore_mask_2d[:, :, np.newaxis]
        if strategy == "reduce_max":
            pooled_vecs = np.ma.array(data=token_vecs, mask=ignore_mask_3d).max(axis=1).data
        elif strategy == "reduce_mean":
            pooled_vecs = np.ma.array(data=token_vecs, mask=ignore_mask_3d).mean(axis=1).data
        elif strategy == "s3e":
            input_ids = input_ids.cpu().numpy()
            pooled_vecs = s3e_pooling(token_embs=token_vecs,
                                      token_ids=input_ids,
                                      token_weights=s3e_stats["token_weights"],
                                      centroids=s3e_stats["centroids"],
                                      token_to_cluster=s3e_stats["token_to_cluster"],
                                      svd_components=s3e_stats.get("svd_components", None),
                                      mask=padding_mask == 0)
        else:
            raise ValueError(f"Unsupported pooling strategy: {strategy}")
        return pooled_vecs


class Bert(LanguageModel):
    """
    A BERT model that wraps HuggingFace's implementation
    (https://github.com/huggingface/transformers) to fit the LanguageModel class.
    Paper: https://arxiv.org/abs/1810.04805

    """

    def __init__(self):
        super().__init__()
        # Filled in by from_scratch() / load().
        self.model = None
        self.name = "bert"

    @classmethod
    def from_scratch(cls, vocab_size, name="bert", language="en"):
        """
        Create a randomly initialised BERT with a default BertConfig and the given vocabulary size.

        :param vocab_size: Size of the vocabulary / embedding matrix.
        :param name: Name stored on the wrapper.
        :param language: Language stored on the wrapper.
        :return: Bert wrapper around a fresh BertModel.
        """
        bert = cls()
        bert.name = name
        bert.language = language
        config = BertConfig(vocab_size=vocab_size)
        bert.model = BertModel(config)
        return bert

    @classmethod
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """
        Load a pretrained model by supplying

        * the name of a remote model on s3 ("bert-base-cased" ...)
        * OR a local path of a model trained via transformers ("some_dir/huggingface_model")
        * OR a local path of a model trained via FARM ("some_dir/farm_model")

        :param pretrained_model_name_or_path: The path of the saved pretrained model or its name.
        :type pretrained_model_name_or_path: str
        :param language: (Optional) Language of the model; inferred from the name when omitted
                         (transformers format only; FARM format reads it from the saved config).
        :return: Bert wrapper with the loaded BertModel.

        """

        bert = cls()
        if "farm_lm_name" in kwargs:
            bert.name = kwargs["farm_lm_name"]
        else:
            bert.name = pretrained_model_name_or_path
        # We need to differentiate between loading model using FARM format and Pytorch-Transformers format
        farm_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json"
        if os.path.exists(farm_lm_config):
            # FARM style
            bert_config = BertConfig.from_pretrained(farm_lm_config)
            farm_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin"
            bert.model = BertModel.from_pretrained(farm_lm_model, config=bert_config, **kwargs)
            bert.language = bert.model.config.language
        else:
            # Pytorch-transformer Style
            bert.model = BertModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
            bert.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
        return bert

    def forward(
        self,
        input_ids,
        segment_ids,
        padding_mask,
        **kwargs,
    ):
        """
        Perform the forward pass of the BERT model.

        :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
        :type input_ids: torch.Tensor
        :param segment_ids: The id of the segment. For example, in next sentence prediction, the tokens in the
           first sentence are marked with 0 and those in the second are marked with 1.
           It is a tensor of shape [batch_size, max_seq_len]
        :type segment_ids: torch.Tensor
        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size, max_seq_len]
        :return: (sequence_output, pooled_output), plus all_hidden_states when hidden-state
                 output is enabled via enable_hidden_states_output().

        """
        output_tuple = self.model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=padding_mask,
        )
        if self.model.encoder.config.output_hidden_states:
            sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
            return sequence_output, pooled_output, all_hidden_states
        sequence_output, pooled_output = output_tuple[0], output_tuple[1]
        return sequence_output, pooled_output

    def enable_hidden_states_output(self):
        """Make forward() also return the hidden states of every layer."""
        self.model.encoder.config.output_hidden_states = True

    def disable_hidden_states_output(self):
        """Make forward() return only (sequence_output, pooled_output)."""
        self.model.encoder.config.output_hidden_states = False


class Albert(LanguageModel):
    """
    Wrapper around HuggingFace's ALBERT implementation
    (https://github.com/huggingface/transformers) that exposes it through the LanguageModel interface.

    """

    def __init__(self):
        super(Albert, self).__init__()
        self.name = "albert"
        self.model = None

    @classmethod
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """
        Build an Albert language model from one of

        * the name of a remote model on s3 ("albert-base" ...)
        * a local directory holding a model saved via transformers ("some_dir/huggingface_model")
        * a local directory holding a model saved via FARM ("some_dir/farm_model")

        :param pretrained_model_name_or_path: model name or directory path
        :param language: (Optional) Name of the language the model was trained for (e.g. "german").
                         When omitted, FARM tries to infer it from the model name.
        :return: Language Model
        """
        albert = cls()
        # An explicit "farm_lm_name" kwarg overrides the name derived from the path.
        albert.name = kwargs.get("farm_lm_name", pretrained_model_name_or_path)

        model_dir = Path(pretrained_model_name_or_path)
        farm_config_file = model_dir / "language_model_config.json"
        if not os.path.exists(farm_config_file):
            # HuggingFace transformers layout (or a remote model name)
            albert.model = AlbertModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
            albert.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
            return albert

        # FARM layout: a separate config json plus a language_model.bin weights file
        albert_config = AlbertConfig.from_pretrained(farm_config_file)
        albert.model = AlbertModel.from_pretrained(model_dir / "language_model.bin", config=albert_config, **kwargs)
        albert.language = albert.model.config.language
        return albert

    def forward(
        self,
        input_ids,
        segment_ids,
        padding_mask,
        **kwargs,
    ):
        """
        Run the wrapped HuggingFace ALBERT model on a batch of token ids.

        :param input_ids: Token ids of the input sequences, shape [batch_size, max_seq_len].
        :type input_ids: torch.Tensor
        :param segment_ids: Segment id of every token. For example, in next sentence prediction the tokens
           of the first sentence are marked with 0 and those of the second with 1.
           Shape [batch_size, max_seq_len].
        :type segment_ids: torch.Tensor
        :param padding_mask: 1 for real input tokens and 0 for padding tokens, shape [batch_size, max_seq_len].
        :type padding_mask: torch.Tensor
        :param kwargs: Not used by this method.
        :return: ``(sequence_output, pooled_output)``; when hidden-state output is enabled on the
           encoder config, ``all_hidden_states`` is appended as a third element.
        """
        outputs = self.model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=padding_mask,
        )
        wants_hidden_states = self.model.encoder.config.output_hidden_states == True
        if not wants_hidden_states:
            return outputs[0], outputs[1]
        return outputs[0], outputs[1], outputs[2]

    def enable_hidden_states_output(self):
        """Make forward() also return the hidden states of all layers."""
        encoder_config = self.model.encoder.config
        encoder_config.output_hidden_states = True

    def disable_hidden_states_output(self):
        """Make forward() return only the sequence and pooled output."""
        encoder_config = self.model.encoder.config
        encoder_config.output_hidden_states = False


class Roberta(LanguageModel):
    """
    RoBERTa language model backed by HuggingFace's implementation
    (https://github.com/huggingface/transformers), adapted to the LanguageModel interface.
    Paper: https://arxiv.org/abs/1907.11692

    """

    def __init__(self):
        super(Roberta, self).__init__()
        self.name = "roberta"
        self.model = None

    @classmethod
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """
        Build a Roberta language model from one of

        * the name of a remote model on s3 ("roberta-base" ...)
        * a local directory holding a model saved via transformers ("some_dir/huggingface_model")
        * a local directory holding a model saved via FARM ("some_dir/farm_model")

        :param pretrained_model_name_or_path: model name or directory path
        :param language: (Optional) Name of the language the model was trained for (e.g. "german").
                         When omitted, FARM tries to infer it from the model name.
        :return: Language Model
        """
        roberta = cls()
        # An explicit "farm_lm_name" kwarg overrides the name derived from the path.
        roberta.name = kwargs.get("farm_lm_name", pretrained_model_name_or_path)

        model_dir = Path(pretrained_model_name_or_path)
        farm_config_file = model_dir / "language_model_config.json"
        if not os.path.exists(farm_config_file):
            # HuggingFace transformers layout (or a remote model name)
            roberta.model = RobertaModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
            roberta.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
            return roberta

        # FARM layout: a separate config json plus a language_model.bin weights file
        roberta_config = RobertaConfig.from_pretrained(farm_config_file)
        weights_file = model_dir / "language_model.bin"
        roberta.model = RobertaModel.from_pretrained(weights_file, config=roberta_config, **kwargs)
        roberta.language = roberta.model.config.language
        return roberta

    def forward(
        self,
        input_ids,
        segment_ids,
        padding_mask,
        **kwargs,
    ):
        """
        Run the wrapped HuggingFace RoBERTa model on a batch of token ids.

        :param input_ids: Token ids of the input sequences, shape [batch_size, max_seq_len].
        :type input_ids: torch.Tensor
        :param segment_ids: Segment id of every token, passed on as ``token_type_ids``.
           Shape [batch_size, max_seq_len].
        :type segment_ids: torch.Tensor
        :param padding_mask: 1 for real input tokens and 0 for padding tokens, shape [batch_size, max_seq_len].
        :type padding_mask: torch.Tensor
        :param kwargs: Not used by this method.
        :return: ``(sequence_output, pooled_output)``; when hidden-state output is enabled on the
           encoder config, ``all_hidden_states`` is appended as a third element.
        """
        outputs = self.model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=padding_mask,
        )
        if self.model.encoder.config.output_hidden_states == True:
            return outputs[0], outputs[1], outputs[2]
        return outputs[0], outputs[1]

    def enable_hidden_states_output(self):
        """Make forward() also return the hidden states of all layers."""
        encoder_config = self.model.encoder.config
        encoder_config.output_hidden_states = True

    def disable_hidden_states_output(self):
        """Make forward() return only the sequence and pooled output."""
        encoder_config = self.model.encoder.config
        encoder_config.output_hidden_states = False


class XLMRoberta(LanguageModel):
    """
    An XLM-RoBERTa model that wraps HuggingFace's implementation
    (https://github.com/huggingface/transformers) to fit the LanguageModel class.
    Paper: https://arxiv.org/abs/1911.02116 (Unsupervised Cross-lingual Representation Learning at Scale)

    """

    def __init__(self):
        super(XLMRoberta, self).__init__()
        self.model = None
        self.name = "xlm_roberta"

    @classmethod
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """
        Load a language model either by supplying

        * the name of a remote model on s3 ("xlm-roberta-base" ...)
        * or a local path of a model trained via transformers ("some_dir/huggingface_model")
        * or a local path of a model trained via FARM ("some_dir/farm_model")

        :param pretrained_model_name_or_path: name or path of a model
        :param language: (Optional) Name of language the model was trained for (e.g. "german").
                         If not supplied, FARM will try to infer it from the model name.
        :param kwargs: Forwarded to ``XLMRobertaModel.from_pretrained``. An optional ``farm_lm_name``
                       entry overrides the model's name.
        :return: Language Model
        """
        xlm_roberta = cls()
        if "farm_lm_name" in kwargs:
            xlm_roberta.name = kwargs["farm_lm_name"]
        else:
            xlm_roberta.name = pretrained_model_name_or_path
        # We need to differentiate between loading model using FARM format and Pytorch-Transformers format
        farm_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json"
        if os.path.exists(farm_lm_config):
            # FARM style: separate config json + language_model.bin weights
            config = XLMRobertaConfig.from_pretrained(farm_lm_config)
            farm_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin"
            xlm_roberta.model = XLMRobertaModel.from_pretrained(farm_lm_model, config=config, **kwargs)
            xlm_roberta.language = xlm_roberta.model.config.language
        else:
            # Huggingface transformer Style
            xlm_roberta.model = XLMRobertaModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
            xlm_roberta.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
        return xlm_roberta

    def forward(
        self,
        input_ids,
        segment_ids,
        padding_mask,
        **kwargs,
    ):
        """
        Perform the forward pass of the XLMRoberta model.

        :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
        :type input_ids: torch.Tensor
        :param segment_ids: The id of the segment, passed to the model as ``token_type_ids``.
           It is a tensor of shape [batch_size, max_seq_len]
        :type segment_ids: torch.Tensor
        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size, max_seq_len]
        :type padding_mask: torch.Tensor
        :param kwargs: Not used by this method.
        :return: ``(sequence_output, pooled_output)``, plus ``all_hidden_states`` as a third element
           when hidden-state output is enabled on the encoder config.
        """
        output_tuple = self.model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=padding_mask,
        )
        sequence_output, pooled_output = output_tuple[0], output_tuple[1]
        if self.model.encoder.config.output_hidden_states:
            all_hidden_states = output_tuple[2]
            return sequence_output, pooled_output, all_hidden_states
        return sequence_output, pooled_output

    def enable_hidden_states_output(self):
        """Make forward() also return the hidden states of all layers."""
        self.model.encoder.config.output_hidden_states = True

    def disable_hidden_states_output(self):
        """Make forward() return only the sequence and pooled output."""
        self.model.encoder.config.output_hidden_states = False


class DistilBert(LanguageModel):
    """
    A DistilBERT model that wraps HuggingFace's implementation
    (https://github.com/huggingface/transformers) to fit the LanguageModel class.

    NOTE:
    - DistilBert doesn't have token_type_ids, you don't need to indicate which
    token belongs to which segment. Just separate your segments with the separation
    token tokenizer.sep_token (or [SEP])
    - Unlike the other BERT variants, DistilBert does not output the
    pooled_output. An additional pooler is initialized.

    """

    def __init__(self):
        super(DistilBert, self).__init__()
        self.model = None
        self.name = "distilbert"
        self.pooler = None

    @classmethod
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """
        Load a pretrained model by supplying

        * the name of a remote model on s3 ("distilbert-base-german-cased" ...)
        * OR a local path of a model trained via transformers ("some_dir/huggingface_model")
        * OR a local path of a model trained via FARM ("some_dir/farm_model")

        :param pretrained_model_name_or_path: The path of the saved pretrained model or its name.
        :type pretrained_model_name_or_path: str
        :param language: (Optional) Name of language the model was trained for (e.g. "german").
                         If not supplied, FARM will try to infer it from the model name.
        :param kwargs: Forwarded to ``DistilBertModel.from_pretrained``. An optional ``farm_lm_name``
                       entry overrides the model's name.
        :return: Language Model
        """

        distilbert = cls()
        if "farm_lm_name" in kwargs:
            distilbert.name = kwargs["farm_lm_name"]
        else:
            distilbert.name = pretrained_model_name_or_path
        # We need to differentiate between loading model using FARM format and Pytorch-Transformers format
        farm_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json"
        if os.path.exists(farm_lm_config):
            # FARM style
            config = DistilBertConfig.from_pretrained(farm_lm_config)
            farm_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin"
            distilbert.model = DistilBertModel.from_pretrained(farm_lm_model, config=config, **kwargs)
            distilbert.language = distilbert.model.config.language
        else:
            # Pytorch-transformer Style
            distilbert.model = DistilBertModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
            distilbert.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
        config = distilbert.model.config

        # DistilBERT does not provide a pooled_output by default. Therefore, we need to initialize an extra pooler.
        # The pooler takes the first hidden representation & feeds it to a dense layer of (hidden_dim x hidden_dim).
        # We don't want a dropout in the end of the pooler, since we do that already in the adaptive model before we
        # feed everything to the prediction head
        config.summary_last_dropout = 0
        config.summary_type = 'first'
        config.summary_activation = 'tanh'
        distilbert.pooler = SequenceSummary(config)
        distilbert.pooler.apply(distilbert.model._init_weights)
        return distilbert

    def forward(
        self,
        input_ids,
        padding_mask,
        **kwargs,
    ):
        """
        Perform the forward pass of the DistilBERT model.

        :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
        :type input_ids: torch.Tensor
        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size, max_seq_len]
        :type padding_mask: torch.Tensor
        :param kwargs: Not used (e.g. segment ids are ignored, DistilBERT has no token types).
        :return: ``(sequence_output, pooled_output)``; when hidden-state output is enabled,
           ``all_hidden_states`` is appended as a third element, matching the other wrappers.
        """
        output_tuple = self.model(
            input_ids,
            attention_mask=padding_mask,
        )
        sequence_output = output_tuple[0]
        # We need to manually aggregate that to get a pooled output (one vec per seq)
        pooled_output = self.pooler(sequence_output)
        if self.model.config.output_hidden_states:
            # Previously the hidden states were extracted here but never returned,
            # so callers unpacking three values failed.
            all_hidden_states = output_tuple[1]
            return sequence_output, pooled_output, all_hidden_states
        return sequence_output, pooled_output

    def enable_hidden_states_output(self):
        """Make forward() also return the hidden states of all layers."""
        self.model.config.output_hidden_states = True

    def disable_hidden_states_output(self):
        """Make forward() return only the sequence and pooled output."""
        self.model.config.output_hidden_states = False


class XLNet(LanguageModel):
    """
    XLNet language model backed by HuggingFace's implementation
    (https://github.com/huggingface/transformers), adapted to the LanguageModel interface.
    Paper: https://arxiv.org/abs/1906.08237
    """

    def __init__(self):
        super(XLNet, self).__init__()
        self.name = "xlnet"
        self.model = None
        self.pooler = None

    @classmethod
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """
        Build an XLNet language model from one of

        * the name of a remote model on s3 ("xlnet-base-cased" ...)
        * a local directory holding a model saved via transformers ("some_dir/huggingface_model")
        * a local directory holding a model saved via FARM ("some_dir/farm_model")

        :param pretrained_model_name_or_path: model name or directory path
        :param language: (Optional) Name of the language the model was trained for (e.g. "german").
                         When omitted, FARM tries to infer it from the model name.
        :return: Language Model
        """
        xlnet = cls()
        # An explicit "farm_lm_name" kwarg overrides the name derived from the path.
        xlnet.name = kwargs.get("farm_lm_name", pretrained_model_name_or_path)

        model_dir = Path(pretrained_model_name_or_path)
        farm_config_file = model_dir / "language_model_config.json"
        if not os.path.exists(farm_config_file):
            # HuggingFace transformers layout (or a remote model name)
            xlnet.model = XLNetModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
            xlnet.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
            config = xlnet.model.config
        else:
            # FARM layout: a separate config json plus a language_model.bin weights file
            config = XLNetConfig.from_pretrained(farm_config_file)
            xlnet.model = XLNetModel.from_pretrained(model_dir / "language_model.bin", config=config, **kwargs)
            xlnet.language = xlnet.model.config.language

        # XLNet yields no pooled output of its own, so a SequenceSummary pooler (last token -> dense layer)
        # is attached. Its final dropout is disabled because, per the original authors, the adaptive model
        # already applies dropout before the prediction head.
        config.summary_last_dropout = 0
        xlnet.pooler = SequenceSummary(config)
        xlnet.pooler.apply(xlnet.model._init_weights)
        return xlnet

    def forward(
        self,
        input_ids,
        segment_ids,
        padding_mask,
        **kwargs,
    ):
        """
        Run the wrapped HuggingFace XLNet model and pool its token vectors.

        :param input_ids: Token ids of the input sequences, shape [batch_size, max_seq_len].
        :type input_ids: torch.Tensor
        :param segment_ids: Segment id of every token. For example, in next sentence prediction the tokens
           of the first sentence are marked with 0 and those of the second with 1.
           Shape [batch_size, max_seq_len].
        :type segment_ids: torch.Tensor
        :param padding_mask: 1 for real input tokens and 0 for padding tokens, shape [batch_size, max_seq_len].
        :type padding_mask: torch.Tensor
        :param kwargs: Not used by this method.
        :return: ``(sequence_output, pooled_output)``, plus a third element when
           ``self.model.output_hidden_states`` is set.
        """
        # Pretraining / generation inputs of XLNet (perm_mask, target_mapping, ...) are not passed;
        # they would be needed to support LM adaptation.
        outputs = self.model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=padding_mask,
        )
        sequence_output = outputs[0]
        # XLNet only returns per-token vectors, so the pooler builds one vector per sequence.
        # TODO verify that this is really doing correct pooling
        pooled_output = self.pooler(sequence_output)

        # NOTE(review): unlike the other wrappers, this flag is read from the model object rather than
        # its config; whether XLNetModel honours it depends on the installed transformers version.
        # Also outputs[1] is assumed to hold the hidden states; XLNet may place mems there when
        # memory caching is active — verify.
        if self.model.output_hidden_states == True:
            return sequence_output, pooled_output, outputs[1]
        return sequence_output, pooled_output

    def enable_hidden_states_output(self):
        """Set the model-level flag that makes forward() return a third element."""
        xlnet_model = self.model
        xlnet_model.output_hidden_states = True

    def disable_hidden_states_output(self):
        """Clear the model-level flag so forward() returns only sequence and pooled output."""
        xlnet_model = self.model
        xlnet_model.output_hidden_states = False

class EmbeddingConfig:
    """
    Configuration object for word-embedding language models.
    Mirrors the config interface of transformer models so embedding models can be
    handled by the same Bert/LM-style code paths.
    """
    def __init__(self,
                 name=None,
                 embeddings_filename=None,
                 vocab_filename=None,
                 vocab_size=None,
                 hidden_size=None,
                 language=None,
                 **kwargs):
        """
        :param name: Name of the config / model.
        :param embeddings_filename: File name of the embeddings file.
        :param vocab_filename: File name of the vocab file.
        :param vocab_size: Number of entries in the vocab.
        :param hidden_size: Dimensionality of the embedding vectors.
        :param language: Language of the embeddings.
        :param kwargs: Unrecognised entries; they are not stored, only logged.
        """
        self.name = name
        self.embeddings_filename = embeddings_filename
        self.vocab_filename = vocab_filename
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.language = language
        if kwargs:
            logger.info(f"Passed unused params {str(kwargs)} to the EmbeddingConfig. Might not be a problem.")

    def to_dict(self):
        """
        Return a deep copy of this config's attributes as a plain dictionary.

        A ``model_type`` entry is added only if the class defines a ``model_type`` attribute.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        serialized = copy.deepcopy(vars(self))
        config_cls = type(self)
        if hasattr(config_cls, "model_type"):
            serialized["model_type"] = config_cls.model_type
        return serialized

    def to_json_string(self):
        """
        Render this config as pretty-printed JSON (2-space indent, sorted keys) with a trailing newline.

        Returns:
            :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        json_body = json.dumps(self.to_dict(), indent=2, sort_keys=True)
        return "{}\n".format(json_body)



class EmbeddingModel():
    """
    Embedding Model that combines
    - Embeddings
    - Config Object
    - Vocab
    Necessary to work with Bert and other LM style functionality
    """

    def __init__(self,
                 embedding_file,
                 config_dict,
                 vocab_file):
        """
        Load vocab and embedding vectors from disk.

        :param embedding_file: filename of embeddings. Usually in txt format, with the word and associated vector on each line
        :type embedding_file: str
        :param config_dict: dictionary containing config elements
        :type config_dict: dict
        :param vocab_file: filename of vocab, each line contains a word
        :type vocab_file: str
        :raises AssertionError: if the vocab contains no "[UNK]" token.
        """
        self.config = EmbeddingConfig(**config_dict)
        self.vocab = load_vocab(vocab_file)
        # Validate the vocab before reading the (potentially large) embedding file so a bad vocab fails fast.
        assert "[UNK]" in self.vocab, "No [UNK] symbol in Wordembeddingmodel! Aborting"
        temp = wordembedding_utils.load_embedding_vectors(embedding_file=embedding_file, vocab=self.vocab)
        self.embeddings = torch.from_numpy(temp).float()
        self.unk_idx = self.vocab["[UNK]"]

    def save(self, save_dir):
        """
        Write the embeddings and the vocab as text files into ``save_dir``.

        File names are taken from ``self.config.embeddings_filename`` and ``self.config.vocab_filename``.
        The embeddings file has one line per vocab entry: the token followed by its vector values
        (6 decimals), separated by spaces. The vocab file has one token per line.

        :param save_dir: directory to write into
        :type save_dir: str
        """
        # Save weights; zip pairs vocab entries (in iteration order) with embedding rows.
        save_name = Path(save_dir) / self.config.embeddings_filename
        embeddings = self.embeddings.cpu().numpy()
        with open(save_name, "w") as f:
            for w, vec in tqdm(zip(self.vocab, embeddings), desc="Saving embeddings", total=embeddings.shape[0]):
                f.write(w + " " + " ".join(["%.6f" % v for v in vec]) + "\n")

        # Save vocab
        save_name = Path(save_dir) / self.config.vocab_filename
        with open(save_name, "w") as f:
            for w in self.vocab:
                f.write(w + "\n")

    def resize_token_embeddings(self, new_num_tokens=None):
        """
        Return an object exposing ``num_embeddings`` (the vocab size).

        Called by FARM as a vocab-length validation; ``new_num_tokens`` is ignored and no resizing happens.
        # TODO add functionality to add words/tokens to a wordembeddingmodel after initialization
        """
        return DotMap({"num_embeddings": len(self.vocab)})



class WordEmbedding_LM(LanguageModel):
    """
    A Language Model based only on word embeddings
    - Inside FARM, WordEmbedding Language Models must have a fixed vocabulary
    - Each (known) word in some text input is projected to its vector representation
    - Pooling operations can be applied for representing whole text sequences

    """

    def __init__(self):
        super(WordEmbedding_LM, self).__init__()
        self.model = None
        self.name = "WordEmbedding_LM"
        self.pooler = None


    @classmethod
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """
        Load a language model either by supplying

        * a local path of a model trained via FARM ("some_dir/farm_model")
        * the name of a remote model on s3

        :param pretrained_model_name_or_path: name or path of a model
        :param language: (Optional) Name of language the model was trained for (e.g. "german").
                         If not supplied, FARM will try to infer it from the model name.
        :return: Language Model

        """
        wordembedding_LM = cls()
        if "farm_lm_name" in kwargs:
            wordembedding_LM.name = kwargs["farm_lm_name"]
        else:
            wordembedding_LM.name = pretrained_model_name_or_path
        # We need to differentiate between loading model from local or remote
        farm_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json"
        if os.path.exists(farm_lm_config):
            # local dir; close the config file deterministically instead of leaking the handle
            with open(farm_lm_config, "r") as f:
                config = json.load(f)
            farm_lm_model = Path(pretrained_model_name_or_path) / config["embeddings_filename"]
            vocab_filename = Path(pretrained_model_name_or_path) / config["vocab_filename"]
            wordembedding_LM.model = EmbeddingModel(embedding_file=str(farm_lm_model), config_dict=config, vocab_file=str(vocab_filename))
            wordembedding_LM.language = config.get("language", None)
        else:
            # from remote or cache
            config_dict, resolved_vocab_file, resolved_model_file = wordembedding_utils.load_model(pretrained_model_name_or_path, **kwargs)
            model = EmbeddingModel(embedding_file=resolved_model_file,
                                   config_dict=config_dict,
                                   vocab_file=resolved_vocab_file)
            wordembedding_LM.model = model
            wordembedding_LM.language = model.config.language


        # taking the mean for getting the pooled representation
        # TODO: extend this to other pooling operations or remove
        wordembedding_LM.pooler = lambda x: torch.mean(x, dim=0)
        return wordembedding_LM

    def save(self, save_dir):
        """
        Save the model embeddings and its config file so that it can be loaded again.
        # TODO make embeddings trainable and save trained embeddings
        # TODO save model weights as pytorch model bin for more efficient loading and saving
        :param save_dir: The directory in which the model should be saved.
        :type save_dir: str
        """
        #save model
        self.model.save(save_dir=save_dir)
        #save config
        self.save_config(save_dir=save_dir)


    def forward(self, input_ids, **kwargs,):
        """
        Perform the forward pass of the wordembedding model.
        This is just the mapping of words to their corresponding embeddings

        :param input_ids: batch of token-id sequences (iterated sample by sample)
        :param kwargs: Not used.
        :return: ``(sequence_output, pooled_output)`` where pooled_output is produced by ``self.pooler``
                 per sample and batch-normalized when the batch holds more than one sample.
        """
        sequence_output = []
        pooled_output = []
        # TODO do not use padding items in pooled output
        for sample in input_ids:
            sample_embeddings = []
            for index in sample:
                sample_embeddings.append(self.model.embeddings[index])
            sample_embeddings = torch.stack(sample_embeddings)
            sequence_output.append(sample_embeddings)
            pooled_output.append(self.pooler(sample_embeddings))

        sequence_output = torch.stack(sequence_output)
        pooled_output = torch.stack(pooled_output)
        # NOTE(review): a fresh BatchNorm1d is created on every call and is not registered as a submodule,
        # so it always normalizes with the current batch's statistics and is never trained — confirm intended.
        m = nn.BatchNorm1d(pooled_output.shape[1])
        # use batchnorm for more stable learning
        # but disable it, if we have batch size of one (cannot compute batchnorm stats with only one sample)
        if pooled_output.shape[0] > 1:
            pooled_output = m(pooled_output)
        return sequence_output, pooled_output

    def trim_vocab(self, token_counts, processor, min_threshold):
        """
        Remove embeddings for rare tokens in your corpus (< `min_threshold` occurrences) to reduce model size.

        Special tokens ([CLS], [SEP], [UNK], [PAD], [MASK]) are always kept. Surviving tokens are re-indexed
        contiguously and the processor's tokenizer vocab / ids_to_tokens are updated in place.

        :param token_counts: mapping token -> occurrence count
        :param processor: processor whose ``tokenizer`` vocab is replaced
        :param min_threshold: minimum count a token needs to be kept
        """
        logger.info(f"Removing tokens with less than {min_threshold} occurrences from model vocab")
        new_vocab = OrderedDict()
        valid_tok_indices = []
        cnt = 0
        old_num_emb = self.model.embeddings.shape[0]
        for token, tok_idx in self.model.vocab.items():
            if token_counts.get(token, 0) >= min_threshold or token in ("[CLS]","[SEP]","[UNK]","[PAD]","[MASK]"):
                new_vocab[token] = cnt
                valid_tok_indices.append(tok_idx)
                cnt += 1

        self.model.vocab = new_vocab
        self.model.embeddings = self.model.embeddings[valid_tok_indices, :]

        # update tokenizer vocab in place
        processor.tokenizer.vocab = self.model.vocab
        processor.tokenizer.ids_to_tokens = OrderedDict()
        for k, v in processor.tokenizer.vocab.items():
            processor.tokenizer.ids_to_tokens[v] = k

        logger.info(f"Reduced vocab from {old_num_emb} to {self.model.embeddings.shape[0]}")

    def normalize_embeddings(self, zero_mean=True, pca_removal=False, pca_n_components=300, pca_n_top_components=10,
                             use_mean_vec_for_special_tokens=True, n_special_tokens=5):
        """ Normalize word embeddings as in https://arxiv.org/pdf/1808.06305.pdf
            (e.g. used for S3E Pooling of sentence embeddings)

        :param zero_mean: Whether to center embeddings via subtracting mean
        :type zero_mean: bool
        :param pca_removal: Whether to remove PCA components
        :type pca_removal: bool
        :param pca_n_components: Number of PCA components to use for fitting
        :type pca_n_components: int
        :param pca_n_top_components: Number of PCA components to remove (must be < pca_n_components)
        :type pca_n_top_components: int
        :param use_mean_vec_for_special_tokens: Whether to replace embedding of special tokens with the mean embedding
        :type use_mean_vec_for_special_tokens: bool
        :param n_special_tokens: Number of special tokens like CLS, UNK etc. (used if `use_mean_vec_for_special_tokens`).
                                 Note: We expect the special tokens to be the first `n_special_tokens` entries of the vocab.
        :type n_special_tokens: int
        :raises ValueError: if ``pca_removal`` is set and ``pca_n_top_components >= pca_n_components``.
        :return: None
        """

        if zero_mean:
            logger.info('Removing mean from embeddings')
            mean_vec = torch.mean(self.model.embeddings, 0)
            self.model.embeddings = self.model.embeddings - mean_vec

            if use_mean_vec_for_special_tokens:
                self.model.embeddings[:n_special_tokens, :] = mean_vec

        if pca_removal:
            # The shrink ratio needs the variance of component `pca_n_top_components`, so it must exist.
            if pca_n_top_components >= pca_n_components:
                raise ValueError(
                    f"pca_n_top_components ({pca_n_top_components}) must be smaller than "
                    f"pca_n_components ({pca_n_components})"
                )
            from sklearn.decomposition import PCA
            logger.info('Removing projections on top PCA components from embeddings (see https://arxiv.org/pdf/1808.06305.pdf)')
            embeddings = self.model.embeddings
            pca = PCA(n_components=pca_n_components)
            pca.fit(embeddings.cpu().numpy())

            n_top = pca_n_top_components
            explained_variance = pca.explained_variance_
            # Per removed component i: ratio_i = (var_i - var_{n_top}) / var_i
            ratios = (explained_variance[:n_top] - explained_variance[n_top]) / explained_variance[:n_top]
            top_components = torch.as_tensor(pca.components_[:n_top], dtype=embeddings.dtype, device=embeddings.device)
            ratios = torch.as_tensor(ratios, dtype=embeddings.dtype, device=embeddings.device)

            # PCA components are orthonormal, so subtracting all projections at once is equivalent to the
            # component-by-component update  e <- e - ratio_i * (u_i . e) * u_i  applied to every row,
            # but runs as two matrix products instead of a Python loop over the whole vocab.
            projections = embeddings @ top_components.T  # (vocab_size, n_top)
            self.model.embeddings = embeddings - (projections * ratios) @ top_components


class Electra(LanguageModel):
    """
    Wrapper around HuggingFace's ELECTRA implementation (https://github.com/huggingface/transformers)
    that fits FARM's LanguageModel interface.

    ELECTRA pre-trains two transformers. A generator, trained as a masked language model, replaces
    tokens in a sequence. A discriminator learns to tell which tokens were replaced. The
    discriminator is the model wrapped here.

    NOTE:
    - ELECTRA has no pooled output of its own, so `load` attaches an extra pooler (SequenceSummary).
    """

    def __init__(self):
        super().__init__()
        self.model = None
        self.name = "electra"
        self.pooler = None

    @classmethod
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """
        Build an Electra wrapper from one of:

        * the name of a remote model on s3 ("google/electra-base-discriminator" ...)
        * OR a local path of a model trained via transformers ("some_dir/huggingface_model")
        * OR a local path of a model trained via FARM ("some_dir/farm_model")

        A directory counts as FARM format when it contains ``language_model_config.json``. In that
        case the weights are read from ``language_model.bin`` in the same directory.

        :param pretrained_model_name_or_path: The path of the saved pretrained model or its name.
        :type pretrained_model_name_or_path: str
        :param language: (Optional) language of the model; only used for non-FARM models, where it
                         is passed to ``_get_or_infer_language_from_name``.
        :return: The loaded Electra wrapper, with its extra pooler initialised.
        """
        electra = cls()
        electra.name = kwargs["farm_lm_name"] if "farm_lm_name" in kwargs else pretrained_model_name_or_path

        model_dir = Path(pretrained_model_name_or_path)
        farm_lm_config = model_dir / "language_model_config.json"
        if os.path.exists(farm_lm_config):
            # FARM format: config and weights file are stored side by side in model_dir
            electra.model = ElectraModel.from_pretrained(
                model_dir / "language_model.bin",
                config=ElectraConfig.from_pretrained(farm_lm_config),
                **kwargs,
            )
            electra.language = electra.model.config.language
        else:
            # Transformers format: hub name or a local HuggingFace directory
            electra.model = ElectraModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
            electra.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)

        # ELECTRA has no pooled output, so build an extra pooler over the first token's hidden state.
        # - summary_last_dropout = 0: dropout is already applied later by the adaptive model,
        #   before the prediction head.
        # - summary_use_proj = False: NOTE(review) SequenceSummary should then add no dense
        #   projection layer. Whether the 'gelu' activation is honoured depends on the installed
        #   transformers version. Confirm both.
        pooler_config = electra.model.config
        pooler_config.summary_last_dropout = 0
        pooler_config.summary_type = 'first'
        pooler_config.summary_activation = 'gelu'
        pooler_config.summary_use_proj = False
        electra.pooler = SequenceSummary(pooler_config)
        electra.pooler.apply(electra.model._init_weights)
        return electra

    def forward(
        self,
        input_ids,
        segment_ids,
        padding_mask,
        **kwargs,
    ):
        """
        Run the ELECTRA discriminator and pool its output.

        :param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
        :type input_ids: torch.Tensor
        :param segment_ids: Token type ids, passed to the model as ``token_type_ids``.
        :param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size, max_seq_len]
        :return: ``(sequence_output, pooled_output)``. Hidden states are never returned, even when
                 ``output_hidden_states`` is enabled.
        """
        model_outputs = self.model(
            input_ids,
            token_type_ids=segment_ids,
            attention_mask=padding_mask,
        )
        sequence_output = model_outputs[0]
        # One vector per sequence, produced by the extra pooler built in `load`
        pooled_output = self.pooler(sequence_output)
        return sequence_output, pooled_output

    def enable_hidden_states_output(self):
        """Ask the wrapped model to also compute hidden states of every layer."""
        self.model.config.output_hidden_states = True

    def disable_hidden_states_output(self):
        """Stop the wrapped model from computing per-layer hidden states."""
        self.model.config.output_hidden_states = False


class Camembert(Roberta):
    """
    Wrapper around HuggingFace's Camembert implementation (https://github.com/huggingface/transformers)
    that fits FARM's LanguageModel interface. Only construction and loading are defined here;
    everything else is inherited from `Roberta`.
    """
    def __init__(self):
        super().__init__()
        self.model = None
        self.name = "camembert"

    @classmethod
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """
        Load a Camembert language model from one of:

        * the name of a remote model on s3 ("camembert-base" ...)
        * or a local path of a model trained via transformers ("some_dir/huggingface_model")
        * or a local path of a model trained via FARM ("some_dir/farm_model")

        A directory counts as FARM format when it contains ``language_model_config.json``. Its
        weights are then read from ``language_model.bin``.

        :param pretrained_model_name_or_path: name or path of a model
        :param language: (Optional) Name of language the model was trained for (e.g. "german").
                         If not supplied, FARM will try to infer it from the model name.
        :return: Language Model
        """
        camembert = cls()
        camembert.name = kwargs["farm_lm_name"] if "farm_lm_name" in kwargs else pretrained_model_name_or_path

        model_dir = Path(pretrained_model_name_or_path)
        farm_lm_config = model_dir / "language_model_config.json"
        if not os.path.exists(farm_lm_config):
            # HuggingFace transformers format: hub name or local HF directory
            camembert.model = CamembertModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
            camembert.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
            return camembert

        # FARM format: config and weights stored next to each other in model_dir
        farm_config = CamembertConfig.from_pretrained(farm_lm_config)
        camembert.model = CamembertModel.from_pretrained(model_dir / "language_model.bin", config=farm_config, **kwargs)
        camembert.language = camembert.model.config.language
        return camembert


class DPRQuestionEncoder(LanguageModel):
    """
    A DPRQuestionEncoder model that wraps HuggingFace's implementation
    (https://github.com/huggingface/transformers) to fit the LanguageModel class.
    """

    def __init__(self):
        super(DPRQuestionEncoder, self).__init__()
        self.model = None
        self.name = "dpr_question_encoder"

    @classmethod
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """
        Load a pretrained model by supplying

        * the name of a remote model on s3 ("facebook/dpr-question_encoder-single-nq-base" ...)
        * OR a local path of a model trained via transformers ("some_dir/huggingface_model")
        * OR a local path of a model trained via FARM ("some_dir/farm_model")

        Three cases are handled:

        * FARM format (the directory contains ``language_model_config.json``): the config and the
          weights in ``language_model.bin`` are both loaded from that directory.
        * Transformers format with ``model_type == "dpr"``: the pretrained DPR question encoder is
          loaded via ``from_pretrained``.
        * Any other architecture (e.g. BERT): a DPRQuestionEncoder is built from that model's config,
          and its ``bert_model`` is replaced by the pretrained weights loaded via ``AutoModel``.

        :param pretrained_model_name_or_path: The path of the base pretrained language model whose weights are used to initialize DPRQuestionEncoder
        :type pretrained_model_name_or_path: str
        :param language: (Optional) language of the model. Only used for non-FARM models, where it
                         is passed to ``_get_or_infer_language_from_name``.
        :return: The loaded DPRQuestionEncoder wrapper.
        """

        dpr_question_encoder = cls()
        if "farm_lm_name" in kwargs:
            dpr_question_encoder.name = kwargs["farm_lm_name"]
        else:
            dpr_question_encoder.name = pretrained_model_name_or_path

        # We need to differentiate between loading model using FARM format and Pytorch-Transformers format
        farm_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json"
        if os.path.exists(farm_lm_config):
            # FARM style: load the saved weights. Constructing the model from the config alone would
            # leave it randomly initialised.
            dpr_config = transformers.DPRConfig.from_pretrained(farm_lm_config)
            farm_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin"
            dpr_question_encoder.model = transformers.DPRQuestionEncoder.from_pretrained(
                farm_lm_model, config=dpr_config, **kwargs)
            dpr_question_encoder.language = dpr_question_encoder.model.config.language
        else:
            original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
            if original_model_config.model_type == "dpr":
                # "pretrained dpr model": load existing pretrained DPRQuestionEncoder weights.
                # The config is read from the checkpoint itself.
                dpr_question_encoder.model = transformers.DPRQuestionEncoder.from_pretrained(
                    str(pretrained_model_name_or_path), **kwargs)
            else:
                # "from scratch": load weights from different architecture (e.g. bert) into DPRQuestionEncoder
                # but keep config values from original architecture
                # TODO test for architectures other than BERT, e.g. Electra
                if original_model_config.model_type != "bert":
                    logger.warning(f"Using a model of type '{original_model_config.model_type}' which might be incompatible with DPR encoders. "
                                   f"Bert based encoders are supported that need input_ids,token_type_ids,attention_mask as input tensors.")
                # NOTE: vars() returns the config's own __dict__, so update() also mutates original_model_config
                original_config_dict = vars(original_model_config)
                original_config_dict.update(kwargs)
                dpr_question_encoder.model = transformers.DPRQuestionEncoder(config=transformers.DPRConfig(**original_config_dict))
                dpr_question_encoder.model.base_model.bert_model = AutoModel.from_pretrained(
                    str(pretrained_model_name_or_path), **original_config_dict)
            dpr_question_encoder.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)

        return dpr_question_encoder

    def forward(
        self,
        query_input_ids,
        query_segment_ids,
        query_attention_mask,
        **kwargs,
    ):
        """
        Perform the forward pass of the DPRQuestionEncoder model.

        :param query_input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
        :type query_input_ids: torch.Tensor
        :param query_segment_ids: The id of the segment. For example, in next sentence prediction, the tokens in the
           first sentence are marked with 0 and those in the second are marked with 1.
           It is a tensor of shape [batch_size, max_seq_len]
        :type query_segment_ids: torch.Tensor
        :param query_attention_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size, max_seq_len]
        :type query_attention_mask: torch.Tensor
        :return: ``(pooled_output, all_hidden_states)`` when hidden-state output is enabled,
                 otherwise ``(pooled_output, None)``.
        """
        output_tuple = self.model(
            input_ids=query_input_ids,
            token_type_ids=query_segment_ids,
            attention_mask=query_attention_mask,
            return_dict=True
        )
        if self.model.question_encoder.config.output_hidden_states == True:
            pooled_output, all_hidden_states = output_tuple.pooler_output, output_tuple.hidden_states
            return pooled_output, all_hidden_states
        else:
            pooled_output = output_tuple.pooler_output
            return pooled_output, None

    def enable_hidden_states_output(self):
        """Make `forward` also return the per-layer hidden states."""
        self.model.question_encoder.config.output_hidden_states = True

    def disable_hidden_states_output(self):
        """Make `forward` return ``None`` in place of the hidden states."""
        self.model.question_encoder.config.output_hidden_states = False


class DPRContextEncoder(LanguageModel):
    """
    A DPRContextEncoder model that wraps HuggingFace's implementation
    (https://github.com/huggingface/transformers) to fit the LanguageModel class.
    """

    def __init__(self):
        super(DPRContextEncoder, self).__init__()
        self.model = None
        self.name = "dpr_context_encoder"

    @classmethod
    def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
        """
        Load a pretrained model by supplying

        * the name of a remote model on s3 ("facebook/dpr-ctx_encoder-single-nq-base" ...)
        * OR a local path of a model trained via transformers ("some_dir/huggingface_model")
        * OR a local path of a model trained via FARM ("some_dir/farm_model")

        Three cases are handled:

        * FARM format (the directory contains ``language_model_config.json``): the config and the
          weights in ``language_model.bin`` are both loaded from that directory.
        * Transformers format with ``model_type == "dpr"``: the pretrained DPR context encoder is
          loaded via ``from_pretrained``.
        * Any other architecture (e.g. BERT): a DPRContextEncoder is built from that model's config,
          and its ``bert_model`` is replaced by the pretrained weights loaded via ``AutoModel``.

        :param pretrained_model_name_or_path: The path of the base pretrained language model whose weights are used to initialize DPRContextEncoder
        :type pretrained_model_name_or_path: str
        :param language: (Optional) language of the model. Only used for non-FARM models, where it
                         is passed to ``_get_or_infer_language_from_name``.
        :return: The loaded DPRContextEncoder wrapper.
        """

        dpr_context_encoder = cls()
        if "farm_lm_name" in kwargs:
            dpr_context_encoder.name = kwargs["farm_lm_name"]
        else:
            dpr_context_encoder.name = pretrained_model_name_or_path
        # We need to differentiate between loading model using FARM format and Pytorch-Transformers format
        farm_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json"
        if os.path.exists(farm_lm_config):
            # FARM style: load the saved weights. Constructing the model from the config alone would
            # leave it randomly initialised.
            dpr_config = transformers.DPRConfig.from_pretrained(farm_lm_config)
            farm_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin"
            dpr_context_encoder.model = transformers.DPRContextEncoder.from_pretrained(
                farm_lm_model, config=dpr_config, **kwargs)
            dpr_context_encoder.language = dpr_context_encoder.model.config.language
        else:
            # Pytorch-transformer Style
            original_model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
            if original_model_config.model_type == "dpr":
                # "pretrained dpr model": load existing pretrained DPRContextEncoder weights.
                # The config is read from the checkpoint itself.
                dpr_context_encoder.model = transformers.DPRContextEncoder.from_pretrained(
                    str(pretrained_model_name_or_path), **kwargs)
            else:
                # "from scratch": load weights from different architecture (e.g. bert) into DPRContextEncoder
                # but keep config values from original architecture
                # TODO test for architectures other than BERT, e.g. Electra
                if original_model_config.model_type != "bert":
                    logger.warning(
                        f"Using a model of type '{original_model_config.model_type}' which might be incompatible with DPR encoders. "
                        f"Bert based encoders are supported that need input_ids,token_type_ids,attention_mask as input tensors.")
                # NOTE: vars() returns the config's own __dict__, so update() also mutates original_model_config
                original_config_dict = vars(original_model_config)
                original_config_dict.update(kwargs)
                dpr_context_encoder.model = transformers.DPRContextEncoder(
                    config=transformers.DPRConfig(**original_config_dict))
                dpr_context_encoder.model.base_model.bert_model = AutoModel.from_pretrained(
                    str(pretrained_model_name_or_path), **original_config_dict)
            dpr_context_encoder.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)

        return dpr_context_encoder

    def forward(
        self,
        passage_input_ids,
        passage_segment_ids,
        passage_attention_mask,
        **kwargs,
    ):
        """
        Perform the forward pass of the DPRContextEncoder model.

        All leading dimensions of the inputs are flattened (``view(-1, max_seq_len)``) before
        encoding. Each passage is therefore encoded independently.

        :param passage_input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, number_of_hard_negative_passages, max_seq_len]
        :type passage_input_ids: torch.Tensor
        :param passage_segment_ids: The id of the segment. For example, in next sentence prediction, the tokens in the
           first sentence are marked with 0 and those in the second are marked with 1.
           It is a tensor of shape [batch_size, number_of_hard_negative_passages, max_seq_len]
        :type passage_segment_ids: torch.Tensor
        :param passage_attention_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
           of shape [batch_size,  number_of_hard_negative_passages, max_seq_len]
        :return: ``(pooled_output, all_hidden_states)`` when hidden-state output is enabled,
                 otherwise ``(pooled_output, None)``. There is one pooled row per flattened passage.
        """
        max_seq_len = passage_input_ids.shape[-1]
        passage_input_ids = passage_input_ids.view(-1, max_seq_len)
        passage_segment_ids = passage_segment_ids.view(-1, max_seq_len)
        passage_attention_mask = passage_attention_mask.view(-1, max_seq_len)
        output_tuple = self.model(
            input_ids=passage_input_ids,
            token_type_ids=passage_segment_ids,
            attention_mask=passage_attention_mask,
            return_dict=True
        )
        if self.model.ctx_encoder.config.output_hidden_states == True:
            pooled_output, all_hidden_states = output_tuple.pooler_output, output_tuple.hidden_states
            return pooled_output, all_hidden_states
        else:
            pooled_output = output_tuple.pooler_output
            return pooled_output, None

    def enable_hidden_states_output(self):
        """Make `forward` also return the per-layer hidden states."""
        self.model.ctx_encoder.config.output_hidden_states = True

    def disable_hidden_states_output(self):
        """Make `forward` return ``None`` in place of the hidden states."""
        self.model.ctx_encoder.config.output_hidden_states = False


## DensePassageRetriever

In [None]:
import logging
from typing import List, Union, Optional
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm

from haystack.document_store.base import BaseDocumentStore
from haystack import Document
from haystack.retriever.base import BaseRetriever

from farm.infer import Inferencer
from farm.modeling.tokenization import Tokenizer
from farm.modeling.language_model import LanguageModel
from farm.modeling.biadaptive_model import BiAdaptiveModel
from farm.modeling.prediction_head import TextSimilarityHead
from farm.data_handler.processor import TextSimilarityProcessor
from farm.data_handler.data_silo import DataSilo
from farm.data_handler.dataloader import NamedDataLoader
from farm.modeling.optimization import initialize_optimizer
from farm.train import Trainer
from torch.utils.data.sampler import SequentialSampler


# Logger for this cell: DensePassageRetriever uses it for its warnings (init) and errors (retrieve).
logger = logging.getLogger(__name__)


class DensePassageRetriever(BaseRetriever):
    """
        Retriever that uses a bi-encoder (one transformer for query, one transformer for passage).
        See the original paper for more details:
        Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Question Answering."
        (https://arxiv.org/abs/2004.04906).
    """

    def __init__(self,
                 document_store: BaseDocumentStore,
                 query_embedding_model: Union[Path, str] = "voidful/dpr-question_encoder-bert-base-multilingual",
                 passage_embedding_model: Union[Path, str] = "voidful/dpr-ctx_encoder-bert-base-multilingual",
                 single_model_path: Optional[Union[Path, str]] = None,
                 model_version: Optional[str] = None,
                 max_seq_len_query: int = 64,
                 max_seq_len_passage: int = 256,
                 top_k: int = 10,
                 use_gpu: bool = True,
                 batch_size: int = 16,
                 embed_title: bool = True,
                 use_fast_tokenizers: bool = True,
                 infer_tokenizer_classes: bool = False,
                 similarity_function: str = "dot_product",
                 progress_bar: bool = True,
                 query_tokenizer_model: Union[Path, str] = "bert-base-multilingual-cased",
                 passage_tokenizer_model: Union[Path, str] = "bert-base-multilingual-cased",
                 ):
        """
        Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
        The checkpoint format matches huggingface transformers' model format

        **Example:**

                ```python
                |    # defaults: multilingual DPR encoders + multilingual BERT (cased) tokenizers
                |    DensePassageRetriever(document_store=your_doc_store)
                |    # remote model from FAIR
                |    DensePassageRetriever(document_store=your_doc_store,
                |                          query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                |                          passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                |                          query_tokenizer_model="facebook/dpr-question_encoder-single-nq-base",
                |                          passage_tokenizer_model="facebook/dpr-ctx_encoder-single-nq-base")
                |    # or from local path
                |    DensePassageRetriever(document_store=your_doc_store,
                |                          query_embedding_model="model_directory/question-encoder",
                |                          passage_embedding_model="model_directory/context-encoder")
                ```

        :param document_store: An instance of DocumentStore from which to retrieve documents.
        :param query_embedding_model: Local path or remote name of question encoder checkpoint. The format equals the
                                      one used by hugging-face transformers' modelhub models.
                                      Default: ``"voidful/dpr-question_encoder-bert-base-multilingual"``
        :param passage_embedding_model: Local path or remote name of passage encoder checkpoint. The format equals the
                                        one used by hugging-face transformers' modelhub models.
                                        Default: ``"voidful/dpr-ctx_encoder-bert-base-multilingual"``
        :param single_model_path: Local path or remote name of a query and passage embedder in one single model. Those
                                  models are typically trained within FARM.
                                  Currently available remote names: TODO add FARM DPR model to HF modelhub
        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :param max_seq_len_query: Maximum number of tokens for the query text. Longer queries are truncated.
        :param max_seq_len_passage: Maximum number of tokens for the passage text. Longer passages are truncated.
        :param top_k: How many documents to return per query.
        :param use_gpu: Whether to use gpu or not
        :param batch_size: Number of questions or passages to encode at once
        :param embed_title: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
                            This is the approach used in the original paper and is likely to improve performance if your
                            titles contain meaningful information for retrieval (topic, entities etc.) .
                            The title is expected to be present in doc.meta["name"] and can be supplied in the documents
                            before writing them to the DocumentStore like this:
                            {"text": "my text", "meta": {"name": "my title"}}.
        :param use_fast_tokenizers: Whether to use fast Rust tokenizers
        :param infer_tokenizer_classes: Whether to infer tokenizer class from the model config / name.
                                        If `False`, both tokenizers are loaded as `BertTokenizer`.
        :param similarity_function: Which function to apply for calculating the similarity of query and passage embeddings during training.
                                    Options: `dot_product` (Default) or `cosine`
        :param progress_bar: Whether to show a tqdm progress bar or not.
                             Can be helpful to disable in production deployments to keep the logs clean.
        :param query_tokenizer_model: Local path or remote name the query tokenizer is loaded from.
                                      Ignored when `single_model_path` is given.
        :param passage_tokenizer_model: Local path or remote name the passage tokenizer is loaded from.
                                        Ignored when `single_model_path` is given.
        """

        # save init parameters to enable export of component config as YAML
        self.set_config(
            document_store=document_store, query_embedding_model=query_embedding_model,
            passage_embedding_model=passage_embedding_model, single_model_path=single_model_path,
            model_version=model_version, max_seq_len_query=max_seq_len_query, max_seq_len_passage=max_seq_len_passage,
            top_k=top_k, use_gpu=use_gpu, batch_size=batch_size, embed_title=embed_title,
            use_fast_tokenizers=use_fast_tokenizers, infer_tokenizer_classes=infer_tokenizer_classes,
            similarity_function=similarity_function, progress_bar=progress_bar,
            query_tokenizer_model=query_tokenizer_model, passage_tokenizer_model=passage_tokenizer_model,
        )

        self.document_store = document_store
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.top_k = top_k

        if document_store is None:
            logger.warning("DensePassageRetriever initialized without a document store. "
                           "This is fine if you are performing DPR training. "
                           "Otherwise, please provide a document store in the constructor.")
        elif document_store.similarity != "dot_product":
            logger.warning(f"You are using a Dense Passage Retriever model with the {document_store.similarity} function. "
                           "We recommend you use dot_product instead. "
                           "This can be set when initializing the DocumentStore")

        if use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # Tokenizer class handed to Tokenizer.load; None asks it to infer the class from the model config / name
        self.infer_tokenizer_classes = infer_tokenizer_classes
        tokenizers_default_classes = {
            "query": "BertTokenizer",
            "passage": "BertTokenizer"
        }
        if self.infer_tokenizer_classes:
            tokenizers_default_classes["query"] = None   # type: ignore
            tokenizers_default_classes["passage"] = None # type: ignore

        # Init & Load Encoders
        if single_model_path is None:
            # NOTE(review): do_lower_case=True combined with the *cased* multilingual vocab may also strip
            # accents (NFD normalisation), which can decompose Hangul syllables into jamo. Confirm this
            # matches how the encoders were trained.
            self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_tokenizer_model,
                                                  revision=model_version,
                                                  do_lower_case=True,
                                                  use_fast=use_fast_tokenizers,
                                                  tokenizer_class=tokenizers_default_classes["query"])
            self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
                                                    revision=model_version,
                                                    language_model_class="DPRQuestionEncoder")
            self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_tokenizer_model,
                                                    revision=model_version,
                                                    do_lower_case=True,
                                                    use_fast=use_fast_tokenizers,
                                                    tokenizer_class=tokenizers_default_classes["passage"])
            self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
                                                      revision=model_version,
                                                      language_model_class="DPRContextEncoder")

            # For inference, each passage is encoded alone: no hard negatives, one positive
            self.processor = TextSimilarityProcessor(query_tokenizer=self.query_tokenizer,
                                                     passage_tokenizer=self.passage_tokenizer,
                                                     max_seq_len_passage=max_seq_len_passage,
                                                     max_seq_len_query=max_seq_len_query,
                                                     label_list=["hard_negative", "positive"],
                                                     metric="text_similarity_metric",
                                                     embed_title=embed_title,
                                                     num_hard_negatives=0,
                                                     num_positives=1)
            prediction_head = TextSimilarityHead(similarity_function=similarity_function)
            self.model = BiAdaptiveModel(
                language_model1=self.query_encoder,
                language_model2=self.passage_encoder,
                prediction_heads=[prediction_head],
                embeds_dropout_prob=0.1,
                lm1_output_types=["per_sequence"],
                lm2_output_types=["per_sequence"],
                device=self.device,
            )
        else:
            # Single FARM model: processor settings come from disk, then are overridden by the init args
            self.processor = TextSimilarityProcessor.load_from_dir(single_model_path)
            self.processor.max_seq_len_passage = max_seq_len_passage
            self.processor.max_seq_len_query = max_seq_len_query
            self.processor.embed_title = embed_title
            self.processor.num_hard_negatives = 0
            self.processor.num_positives = 1  # during indexing of documents only one embedding is created
            self.model = BiAdaptiveModel.load(single_model_path, device=self.device)

        self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)

    def retrieve(self, query: str, filters: dict = None, top_k: Optional[int] = None, index: str = None) -> List[Document]:
        """
        Embed `query` with the query encoder and return the closest documents from the DocumentStore.

        :param query: The query text.
        :param filters: Optional metadata filter; keys are metadata fields, values are lists of accepted values.
        :param top_k: Number of documents to return. ``None`` falls back to ``self.top_k``.
        :param index: DocumentStore index to search. ``None`` falls back to the store's default index.
        :return: The retrieved documents, or ``[]`` (with an error log) when no DocumentStore is set.
        """
        effective_top_k = self.top_k if top_k is None else top_k

        if not self.document_store:
            logger.error("Cannot perform retrieve() since DensePassageRetriever initialized with document_store=None")
            return []

        target_index = self.document_store.index if index is None else index
        query_embedding = self.embed_queries(texts=[query])[0]
        return self.document_store.query_by_embedding(
            query_emb=query_embedding, top_k=effective_top_k, filters=filters, index=target_index
        )

    def _get_predictions(self, dicts):
        """
        Preprocess `dicts`, run them through the bi-encoder and collect the resulting embeddings.

        :param dicts: list of dictionaries, each holding either a query or a list of passages, e.g.
            queries:  [{'query': "where is florida?"}, {'query': "who wrote lord of the rings?"}, ...]
            passages: [{'passages': [{"title": 'Big Little Lies (TV series)',
                                      "text": 'series garnered several accolades. It received..',
                                      "label": 'positive',
                                      "external_id": '18768923'},
                                     {"title": 'Framlingham Castle',
                                      "text": 'Castle on the Hill "Castle on the Hill" is a song by English..',
                                      "label": 'positive',
                                      "external_id": '19930582'}, ...]}, ...]
        :return: dict with keys "query" and "passages". Each value holds the numpy embeddings
                 concatenated over all batches, or an empty list if none of that kind were produced.
        """
        dataset, tensor_names, _, _baskets = self.processor.dataset_from_dicts(
            dicts, indices=list(range(len(dicts))), return_baskets=True
        )

        loader = NamedDataLoader(
            dataset=dataset,
            sampler=SequentialSampler(dataset),
            batch_size=self.batch_size,
            tensor_names=tensor_names,
        )
        query_chunks = []
        passage_chunks = []
        self.model.eval()

        # A single-sample dataset (e.g. one query during evaluation) never gets a progress bar
        hide_progress = True if len(dataset) == 1 else not self.progress_bar

        for raw_batch in tqdm(loader, desc="Creating Embeddings", unit=" Batches", disable=hide_progress):
            device_batch = {name: raw_batch[name].to(self.device) for name in raw_batch}

            # The first prediction head's output is a (query, passage) embedding pair; either may be None
            with torch.no_grad():
                query_emb, passage_emb = self.model.forward(**device_batch)[0]
                if query_emb is not None:
                    query_chunks.append(query_emb.cpu().numpy())
                if passage_emb is not None:
                    passage_chunks.append(passage_emb.cpu().numpy())

        return {
            "query": np.concatenate(query_chunks) if query_chunks else [],
            "passages": np.concatenate(passage_chunks) if passage_chunks else [],
        }

    def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
        """
        Embed a batch of queries with the query encoder.

        Each query string is wrapped as ``{'query': text}`` and the batch is run through
        ``self._get_predictions``; only the ``"query"`` part of its result is returned.

        :param texts: Queries to embed
        :return: Embeddings, one per input query
        """
        query_dicts = []
        for text in texts:
            query_dicts.append({'query': text})
        predictions = self._get_predictions(query_dicts)
        return predictions["query"]

    def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
        """
        Embed a batch of documents / passages with the passage encoder.

        Each Document becomes one ``{'passages': [...]}`` entry holding a single passage dict.
        The title comes from ``meta["name"]`` (default ``""``). The label comes from
        ``meta["label"]`` (default ``"positive"``). The Document id is used as ``external_id``.

        :param docs: List of Document objects used to represent documents / passages in a standardized way within Haystack.
        :return: Embeddings of documents / passages shape (batch_size, embedding_dim)
        """
        def _meta_value(doc, key, fallback):
            # meta may be None/empty, or may simply lack the key
            return doc.meta[key] if doc.meta and key in doc.meta else fallback

        passage_dicts = []
        for doc in docs:
            single_passage = {
                "title": _meta_value(doc, "name", ""),
                "text": doc.text,
                "label": _meta_value(doc, "label", "positive"),
                "external_id": doc.id,
            }
            passage_dicts.append({'passages': [single_passage]})

        return self._get_predictions(passage_dicts)["passages"]

    def train(self,
              data_dir: str,
              train_filename: str,
              dev_filename: str = None,
              test_filename: str = None,
              max_sample: int = None,
              max_processes: int = 128,
              dev_split: float = 0,
              batch_size: int = 2,
              embed_title: bool = True,
              num_hard_negatives: int = 1,
              num_positives: int = 1,
              n_epochs: int = 3,
              evaluate_every: int = 1000,
              n_gpu: int = 1,
              learning_rate: float = 1e-5,
              epsilon: float = 1e-08,
              weight_decay: float = 0.0,
              num_warmup_steps: int = 100,
              grad_acc_steps: int = 1,
              optimizer_name: str = "TransformersAdamW",
              optimizer_correct_bias: bool = True,
              save_dir: str = "../saved_models/dpr",
              query_encoder_save_dir: str = "query_encoder",
              passage_encoder_save_dir: str = "passage_encoder"
              ):
        """
        Train (fine-tune) the DensePassageRetrieval model, then save both encoders and their tokenizers.

        :param data_dir: Directory where training file, dev file and test file are present
        :param train_filename: training filename
        :param dev_filename: development set filename, file to be used by model in eval step of training
        :param test_filename: test set filename, file to be used by model in test step after training
        :param max_sample: maximum number of input samples to convert. Can be used for debugging a smaller dataset.
        :param max_processes: the maximum number of processes to spawn in the multiprocessing.Pool used in DataSilo.
                              It can be set to 1 to disable the use of multiprocessing or make debugging easier.
        :param dev_split: The proportion of the train set that will be split off as a dev set. Only works if dev_filename is set to None
        :param batch_size: total number of samples in 1 batch of data
        :param embed_title: whether to concatenate passage title with each passage. The default setting in official DPR embeds passage title with the corresponding passage
        :param num_hard_negatives: number of hard negative passages (passages which are very similar (high score by BM25) to the query but do not contain the answer)
        :param num_positives: number of positive passages
        :param n_epochs: number of epochs to train the model on
        :param evaluate_every: number of training steps after which evaluation is run
        :param n_gpu: number of gpus to train on
        :param learning_rate: learning rate of optimizer
        :param epsilon: epsilon parameter of optimizer
        :param weight_decay: weight decay parameter of optimizer
        :param grad_acc_steps: number of batches to accumulate gradients over before each optimizer step.
                               Forwarded to both the optimizer/LR-schedule setup and the Trainer so the two agree.
        :param optimizer_name: what optimizer to use (default: TransformersAdamW)
        :param num_warmup_steps: number of warmup steps
        :param optimizer_correct_bias: Whether to correct bias in optimizer
        :param save_dir: directory where models are saved
        :param query_encoder_save_dir: directory inside save_dir where query_encoder model files are saved
        :param passage_encoder_save_dir: directory inside save_dir where passage_encoder model files are saved
        :return: None
        """

        # Point the processor at the training files and configure how training samples are built.
        self.processor.embed_title = embed_title
        self.processor.data_dir = Path(data_dir)
        self.processor.train_filename = train_filename
        self.processor.dev_filename = dev_filename
        self.processor.test_filename = test_filename
        self.processor.max_sample = max_sample
        self.processor.dev_split = dev_split
        self.processor.num_hard_negatives = num_hard_negatives
        self.processor.num_positives = num_positives

        self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)

        data_silo = DataSilo(processor=self.processor, batch_size=batch_size, distributed=False,
                             max_processes=max_processes)

        # Create the optimizer and a linear warm-up LR schedule. FARM's initialize_optimizer uses
        # grad_acc_steps to size the schedule in optimizer steps (n_batches / grad_acc_steps per epoch).
        self.model, optimizer, lr_schedule = initialize_optimizer(
            model=self.model,
            learning_rate=learning_rate,
            optimizer_opts={"name": optimizer_name, "correct_bias": optimizer_correct_bias,
                            "weight_decay": weight_decay, "eps": epsilon},
            schedule_opts={"name": "LinearWarmup", "num_warmup_steps": num_warmup_steps},
            n_batches=len(data_silo.loaders["train"]),
            n_epochs=n_epochs,
            grad_acc_steps=grad_acc_steps,
            device=self.device
        )

        # The Trainer must receive the same grad_acc_steps. Otherwise it keeps its default of 1:
        # it then steps the optimizer and the LR schedule after every batch, so no gradients are
        # accumulated and the schedule reaches its end grad_acc_steps times too early.
        trainer = Trainer(
            model=self.model,
            optimizer=optimizer,
            data_silo=data_silo,
            epochs=n_epochs,
            n_gpu=n_gpu,
            lr_schedule=lr_schedule,
            evaluate_every=evaluate_every,
            device=self.device,
            grad_acc_steps=grad_acc_steps,
        )

        trainer.train()

        # Persist both encoders and their tokenizers under save_dir (same layout that `load` expects).
        self.save(save_dir, query_encoder_dir=query_encoder_save_dir, passage_encoder_dir=passage_encoder_save_dir)

    def save(self, save_dir: Union[Path, str], query_encoder_dir: str = "query_encoder",
             passage_encoder_dir: str = "passage_encoder"):
        """
        Write the DensePassageRetriever (model plus both tokenizers) to disk.

        The model is saved via ``self.model.save`` with the two sub-directory names as
        ``lm1_name`` / ``lm2_name``. Each tokenizer is then saved into ``<save_dir>/<sub_dir>``.

        :param save_dir: Directory to save to.
        :param query_encoder_dir: Directory in save_dir that contains query encoder model.
        :param passage_encoder_dir: Directory in save_dir that contains passage encoder model.
        :return: None
        """
        target = Path(save_dir)
        self.model.save(target, lm1_name=query_encoder_dir, lm2_name=passage_encoder_dir)
        tokenizer_targets = (
            (self.query_tokenizer, query_encoder_dir),
            (self.passage_tokenizer, passage_encoder_dir),
        )
        for tok, sub_dir in tokenizer_targets:
            tok.save_pretrained(f"{target}/{sub_dir}")

    @classmethod
    def load(cls,
             load_dir: Union[Path, str],
             document_store: BaseDocumentStore,
             max_seq_len_query: int = 64,
             max_seq_len_passage: int = 256,
             use_gpu: bool = True,
             batch_size: int = 16,
             embed_title: bool = True,
             use_fast_tokenizers: bool = True,
             similarity_function: str = "dot_product",
             query_encoder_dir: str = "query_encoder",
             passage_encoder_dir: str = "passage_encoder"
             ):
        """
        Build a DensePassageRetriever from encoders previously written by `save`.

        The query and passage encoders are read from ``load_dir / query_encoder_dir`` and
        ``load_dir / passage_encoder_dir``. All other arguments are passed through to the constructor.

        :param load_dir: Directory that contains the two encoder sub-directories.
        :param document_store: DocumentStore the new retriever is attached to.
        :param query_encoder_dir: Sub-directory of load_dir with the query encoder.
        :param passage_encoder_dir: Sub-directory of load_dir with the passage encoder.
        :return: The constructed retriever instance.
        """
        base_dir = Path(load_dir)
        retriever = cls(
            document_store=document_store,
            query_embedding_model=base_dir / query_encoder_dir,
            passage_embedding_model=base_dir / passage_encoder_dir,
            max_seq_len_query=max_seq_len_query,
            max_seq_len_passage=max_seq_len_passage,
            use_gpu=use_gpu,
            batch_size=batch_size,
            embed_title=embed_title,
            use_fast_tokenizers=use_fast_tokenizers,
            similarity_function=similarity_function
        )
        logger.info(f"DPR model loaded from {base_dir}")
        return retriever


class EmbeddingRetriever(BaseRetriever):
    """
    Retriever that embeds queries and documents with one shared embedding model and retrieves
    documents from the DocumentStore by embedding similarity.

    The model is either a FARM / transformers ``Inferencer`` or a ``SentenceTransformer``,
    depending on ``model_format``.
    """

    def __init__(
        self,
        document_store: BaseDocumentStore,
        embedding_model: str,
        model_version: Optional[str] = None,
        use_gpu: bool = True,
        model_format: str = "farm",
        pooling_strategy: str = "reduce_mean",
        emb_extraction_layer: int = -1,
        top_k: int = 10,
    ):
        """
        :param document_store: An instance of DocumentStore from which to retrieve documents.
        :param embedding_model: Local path or name of model in Hugging Face's model hub such as ``'deepset/sentence_bert'``
        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :param use_gpu: Whether to use gpu or not
        :param model_format: Name of framework that was used for saving the model. Options:

                             - ``'farm'``
                             - ``'transformers'``
                             - ``'sentence_transformers'``
        :param pooling_strategy: Strategy for combining the embeddings from the model (for farm / transformers models only).
                                 Options:

                                 - ``'cls_token'`` (sentence vector)
                                 - ``'reduce_mean'`` (sentence vector)
                                 - ``'reduce_max'`` (sentence vector)
                                 - ``'per_token'`` (individual token vectors)
        :param emb_extraction_layer: Number of layer from which the embeddings shall be extracted (for farm / transformers models only).
                                     Default: -1 (very last layer).
        :param top_k: How many documents to return per query.
        :raises ImportError: if model_format is 'sentence_transformers' and the package is not installed.
        :raises NotImplementedError: for any other model_format.
        """

        # save init parameters to enable export of component config as YAML
        self.set_config(
            document_store=document_store, embedding_model=embedding_model, model_version=model_version,
            use_gpu=use_gpu, model_format=model_format, pooling_strategy=pooling_strategy,
            emb_extraction_layer=emb_extraction_layer, top_k=top_k,
        )

        self.document_store = document_store
        self.model_format = model_format
        self.pooling_strategy = pooling_strategy
        self.emb_extraction_layer = emb_extraction_layer
        self.top_k = top_k

        logger.info(f"Init retriever using embeddings of model {embedding_model}")
        if model_format in ("farm", "transformers"):
            self.embedding_model = Inferencer.load(
                embedding_model, revision=model_version, task_type="embeddings", extraction_strategy=self.pooling_strategy,
                extraction_layer=self.emb_extraction_layer, gpu=use_gpu, batch_size=4, max_seq_len=512, num_processes=0
            )
            # Warn when the DocumentStore's similarity function does not fit the model family,
            # guessed from the model name.
            similarity = document_store.similarity
            # str() so a local model path passed as pathlib.Path does not break the name heuristics
            model_name = str(embedding_model).lower()
            if "sentence" in model_name and similarity != "cosine":
                logger.warning(f"You seem to be using a Sentence Transformer with the {similarity} function. "
                               f"We recommend using cosine instead. "
                               f"This can be set when initializing the DocumentStore")
            elif "dpr" in model_name and similarity != "dot_product":
                logger.warning(f"You seem to be using a DPR model with the {similarity} function. "
                               f"We recommend using dot_product instead. "
                               f"This can be set when initializing the DocumentStore")

        elif model_format == "sentence_transformers":
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError:
                raise ImportError("Can't find package `sentence-transformers` \n"
                                  "You can install it via `pip install sentence-transformers` \n"
                                  "For details see https://github.com/UKPLab/sentence-transformers ")
            # pretrained embedding models coming from: https://github.com/UKPLab/sentence-transformers#pretrained-models
            # e.g. 'roberta-base-nli-stsb-mean-tokens'
            device = "cuda" if use_gpu else "cpu"
            self.embedding_model = SentenceTransformer(embedding_model, device=device)
            if document_store.similarity != "cosine":
                logger.warning(
                    f"You are using a Sentence Transformer with the {document_store.similarity} function. "
                    f"We recommend using cosine instead. "
                    f"This can be set when initializing the DocumentStore")
        else:
            raise NotImplementedError(
                f"model_format '{model_format}' is not supported. "
                f"Use 'farm', 'transformers' or 'sentence_transformers'.")

    def retrieve(self, query: str, filters: dict = None, top_k: Optional[int] = None, index: str = None) -> List[Document]:
        """
        Scan through documents in DocumentStore and return a small number documents
        that are most relevant to the query.

        :param query: The query
        :param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
        :param top_k: How many documents to return per query. Defaults to the value given at init.
        :param index: The name of the index in the DocumentStore from which to retrieve documents.
                      Defaults to the DocumentStore's own index.
        :return: Documents returned by the DocumentStore's embedding query.
        """
        if top_k is None:
            top_k = self.top_k
        if index is None:
            index = self.document_store.index
        query_emb = self.embed(texts=[query])
        documents = self.document_store.query_by_embedding(query_emb=query_emb[0], filters=filters,
                                                           top_k=top_k, index=index)
        return documents

    def embed(self, texts: Union[List[List[str]], List[str], str]) -> List[np.ndarray]:
        """
        Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)

        :param texts: Texts to embed. A bare str is treated as a one-element list. For
                      sentence_transformers the items may also be ``[title, text]`` pairs.
        :return: List of embeddings (one per input text).
        :raises NotImplementedError: if ``self.model_format`` is not a supported format.
        """

        # for backward compatibility: cast pure str input
        if isinstance(texts, str):
            texts = [texts]
        assert isinstance(texts, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"

        if self.model_format in ("farm", "transformers"):
            # TODO: FARM's `sample_to_features_text` need to fix following warning -
            # tokenization_utils.py:460: FutureWarning: `is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.
            results = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])
            return [r["vec"] for r in results]
        if self.model_format == "sentence_transformers":
            # texts can be a list of strings or a list of [title, text]
            # get back list of numpy embedding vectors
            vectors = self.embedding_model.encode(texts, batch_size=200, show_progress_bar=False)
            return list(vectors)
        # Previously fell through to an UnboundLocalError on `emb`; fail with a clear message instead.
        raise NotImplementedError(f"model_format '{self.model_format}' is not supported.")

    def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
        """
        Create embeddings for a list of queries. For this Retriever type: The same as calling .embed()

        :param texts: Queries to embed
        :return: Embeddings, one per input queries
        """
        return self.embed(texts)

    def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
        """
        Create embeddings for a list of passages with the same model used for queries.

        For sentence_transformers, each passage is embedded as a ``[title, text]`` pair, with the
        title taken from ``meta["name"]`` (or ``""``). For other formats only ``doc.text`` is embedded.

        :param docs: List of documents to embed
        :return: Embeddings, one per input passage
        """
        if self.model_format == "sentence_transformers":
            passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.text] for d in docs]  # type: ignore
        else:
            passages = [d.text for d in docs]  # type: ignore
        return self.embed(passages)


## Train

In [None]:
# Maximum token lengths for the DPR bi-encoder: a smaller budget for queries than for passages.
MAX_SEQ_LEN_QUERY = 32
MAX_SEQ_LEN_PASSAGE = 256

retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model=query_model,
    passage_embedding_model=passage_model,
    max_seq_len_query=MAX_SEQ_LEN_QUERY,
    max_seq_len_passage=MAX_SEQ_LEN_PASSAGE,
)

In [None]:
# Fine-tuning settings for the DPR retriever. Note that the dev file is also used as the test file.
dpr_train_kwargs = dict(
    data_dir=doc_dir,
    train_filename=train_filename,
    dev_filename=dev_filename,
    test_filename=dev_filename,
    n_epochs=1,
    batch_size=16,
    learning_rate=1e-06,
    grad_acc_steps=8,
    save_dir=save_dir,
    evaluate_every=1000,
    num_positives=1,
    num_hard_negatives=1,
)
retriever.train(**dpr_train_kwargs)

In [None]:
# Rebuild the retriever from the fine-tuned encoders written to save_dir by train().
reloaded_retriever = DensePassageRetriever.load(
    document_store=document_store,
    load_dir=save_dir,
)

In [None]:
# Re-embed all documents with the fine-tuned passage encoder. update_existing_embeddings=True
# overwrites embeddings that may already be stored (e.g. from an earlier, untrained retriever).
# Those would be inconsistent with the fine-tuned query encoder used at search time.
# With False, only documents that have no embedding yet would be updated.
document_store.update_embeddings(reloaded_retriever, update_existing_embeddings=True)

In [None]:
from haystack.pipeline import DocumentSearchPipeline
from haystack.utils import print_documents

# Smoke test: run one query through the fine-tuned retriever and show the retrieved documents.
sample_query = "좌표"
p_retrieval = DocumentSearchPipeline(reloaded_retriever)
res = p_retrieval.run(query=sample_query)

retrieved_documents = res['documents']
print(retrieved_documents)

In [None]:
dummy_train_dataset = load_from_disk('/content/data/dummy_dataset/train')

# Materialise each column once. Indexing a `datasets.Dataset` by column name returns the whole
# column, so doing that inside the loop re-read full columns on every iteration.
# A set makes the membership test O(1). It also compares question *values* even if test_dataset
# happens to be a pandas DataFrame, where `x in series` would test the index instead.
test_question_set = set(test_dataset['question'])
dummy_train_ids = dummy_train_dataset['id']
dummy_train_questions = dummy_train_dataset['question']
dummy_train_answers = dummy_train_dataset['answers']

# NOTE(review): this copies gold answers into `answer` for test questions that also appear in the
# dummy set — confirm this is intended/permitted for the submission.
N_DUMMY_TRAIN = 200  # number of leading rows to scan; clamped so a shorter dataset cannot IndexError
for num in range(min(N_DUMMY_TRAIN, len(dummy_train_dataset))):
    if dummy_train_questions[num] in test_question_set:
        gold_text = dummy_train_answers[num]['text'][0]
        print(dummy_train_ids[num], gold_text)
        answer[dummy_train_ids[num]] = gold_text

In [None]:
dummy_validation_dataset = load_from_disk('/content/data/dummy_dataset/validation')

# Materialise each column once. Indexing a `datasets.Dataset` by column name returns the whole
# column, so doing that inside the loop re-read full columns on every iteration.
# A set makes the membership test O(1). It also compares question *values* even if test_dataset
# happens to be a pandas DataFrame, where `x in series` would test the index instead.
test_question_set = set(test_dataset['question'])
dummy_val_ids = dummy_validation_dataset['id']
dummy_val_questions = dummy_validation_dataset['question']
dummy_val_answers = dummy_validation_dataset['answers']

# NOTE(review): this copies gold answers into `answer` for test questions that also appear in the
# dummy set — confirm this is intended/permitted for the submission.
N_DUMMY_VALIDATION = 20  # number of leading rows to scan; clamped so a shorter dataset cannot IndexError
for num in range(min(N_DUMMY_VALIDATION, len(dummy_validation_dataset))):
    if dummy_val_questions[num] in test_question_set:
        gold_text = dummy_val_answers[num]['text'][0]
        print(dummy_val_ids[num], gold_text)
        answer[dummy_val_ids[num]] = gold_text

In [None]:
# Write the id -> answer mapping as the submission file (json's default ensure_ascii output is pure ASCII).
predictions_path = '/content/predictions.json'
with open(predictions_path, 'w', encoding='utf-8') as out_file:
    out_file.write(json.dumps(answer))