In [None]:
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
import json
import sys

sys.path.append('..')

from tqdm import tqdm
import logging
import torch
import numpy as np
from colorama import init
from termcolor import colored

import blink.ner as NER
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from blink.biencoder.biencoder import BiEncoderRanker, load_biencoder
from blink.crossencoder.crossencoder import CrossEncoderRanker, load_crossencoder
from blink.biencoder.data_process import (
    process_mention_data,
    get_candidate_representation,
)
import blink.candidate_ranking.utils as utils
from blink.crossencoder.train_cross import modify, evaluate
from blink.crossencoder.data_process import prepare_crossencoder_data
from blink.indexer.faiss_indexer import DenseFlatIndexer, DenseHNSWFlatIndexer

In [None]:
import statistics
import pickle

In [None]:
from addict import Dict

In [None]:
import pandas as pd

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import StandardScaler

class testData(Dataset):
    """Minimal map-style dataset that serves items straight out of ``X_data``."""

    def __init__(self, X_data):
        # Stored under the original attribute name so outside code can still read it.
        self.X_data = X_data

    def __len__(self):
        """Number of items, delegated to the wrapped container."""
        return len(self.X_data)

    def __getitem__(self, index):
        """Return the item at ``index`` exactly as the wrapped container yields it."""
        return self.X_data[index]

class binaryClassification(nn.Module):
    """Small feed-forward network: ``n`` inputs -> 2 hidden units -> 1 output.

    ``forward`` returns the raw output of the last linear layer; no sigmoid is
    applied here, so callers that want probabilities must apply it themselves.
    """

    def __init__(self, n):
        super().__init__()
        # Sub-modules are registered in the same order as before (fc1, fc2, relu,
        # dropout) so parameter initialisation and state-dict layout are unchanged.
        self.fc1 = nn.Linear(n, 2)
        self.fc2 = nn.Linear(2, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, inputs):
        """Apply linear -> ReLU -> dropout -> linear and return the result."""
        hidden = self.dropout(self.relu(self.fc1(inputs)))
        return self.fc2(hidden)

In [None]:
# Highlight names cycled through (by index modulo length) when printing mentions;
# each one is passed as the third positional argument to `colored`.
HIGHLIGHTS = "on_red on_green on_yellow on_blue on_magenta on_cyan".split()


def _print_colorful_text(input_sentence, samples):
    init()  # colorful output
    msg = ""
    if samples and (len(samples) > 0):
        msg += input_sentence[0 : int(samples[0]["start_pos"])]
        for idx, sample in enumerate(samples):
            msg += colored(
                input_sentence[int(sample["start_pos"]) : int(sample["end_pos"])],
                "grey",
                HIGHLIGHTS[idx % len(HIGHLIGHTS)],
            )
            if idx < len(samples) - 1:
                msg += input_sentence[
                    int(sample["end_pos"]) : int(samples[idx + 1]["start_pos"])
                ]
            else:
                msg += input_sentence[int(sample["end_pos"]) :]
    else:
        msg = input_sentence
        print("Failed to identify entity from text:")
    print("\n" + str(msg) + "\n")


def _print_colorful_prediction(
    idx, sample, e_id, e_title, e_text, e_url, show_url=False
):
    print(colored(sample["mention"], "grey", HIGHLIGHTS[idx % len(HIGHLIGHTS)]))
    to_print = "id:{}\ntitle:{}\ntext:{}\n".format(e_id, e_title, e_text[:256])
    if show_url:
        to_print += "url:{}\n".format(e_url)
    print(to_print)


def _annotate(ner_model, input_sentences):
    ner_output_data = ner_model.predict(input_sentences)
    sentences = ner_output_data["sentences"]
    mentions = ner_output_data["mentions"]
    samples = []
    for mention in mentions:
        record = {}
        record["label"] = "unknown"
        record["label_id"] = -1
        # LOWERCASE EVERYTHING !
        record["context_left"] = sentences[mention["sent_idx"]][
            : mention["start_pos"]
        ].lower()
        record["context_right"] = sentences[mention["sent_idx"]][
            mention["end_pos"] :
        ].lower()
        record["mention"] = mention["text"].lower()
        record["start_pos"] = int(mention["start_pos"])
        record["end_pos"] = int(mention["end_pos"])
        record["sent_idx"] = mention["sent_idx"]
        samples.append(record)
    return samples


def _load_candidates(
    entity_catalogue, entity_encoding, faiss_index=None, index_path=None, logger=None
):
    # only load candidate encoding if not using faiss index
    if faiss_index is None:
        candidate_encoding = torch.load(entity_encoding)
        indexer = None
    else:
        if logger:
            logger.info("Using faiss index to retrieve entities.")
        candidate_encoding = None
        assert index_path is not None, "Error! Empty indexer path."
        if faiss_index == "flat":
            indexer = DenseFlatIndexer(1)
        elif faiss_index == "hnsw":
            indexer = DenseHNSWFlatIndexer(1)
        else:
            raise ValueError("Error! Unsupported indexer type! Choose from flat,hnsw.")
        indexer.deserialize_from(index_path)

    # load all the 5903527 entities
    title2id = {}
    id2title = {}
    id2text = {}
    wikipedia_id2local_id = {}
    local_idx = 0
    with open(entity_catalogue, "r") as fin:
        lines = fin.readlines()
        for line in lines:
            entity = json.loads(line)

            if "idx" in entity:
                split = entity["idx"].split("curid=")
                if len(split) > 1:
                    wikipedia_id = int(split[-1].strip())
                else:
                    wikipedia_id = entity["idx"].strip()

                assert wikipedia_id not in wikipedia_id2local_id
                wikipedia_id2local_id[wikipedia_id] = local_idx

            title2id[entity["title"]] = local_idx
            id2title[local_idx] = entity["title"]
            id2text[local_idx] = entity["text"]
            local_idx += 1
    return (
        candidate_encoding,
        title2id,
        id2title,
        id2text,
        wikipedia_id2local_id,
        indexer,
    )


def __map_test_entities(test_entities_path, title2id, logger):
    # load the 732859 tac_kbp_ref_know_base entities
    kb2id = {}
    missing_pages = 0
    n = 0
    with open(test_entities_path, "r") as fin:
        lines = fin.readlines()
        for line in lines:
            entity = json.loads(line)
            if entity["title"] not in title2id:
                missing_pages += 1
            else:
                kb2id[entity["entity_id"]] = title2id[entity["title"]]
            n += 1
    if logger:
        logger.info("missing {}/{} pages".format(missing_pages, n))
    return kb2id


def __load_test(test_filename, kb2id, wikipedia_id2local_id, logger, consider_all=False):
    test_samples = []
    with open(test_filename, "r") as fin:
        lines = fin.readlines()
        for line in lines:
            record = json.loads(line)
            record["label"] = str(record["label_id"])

            # for tac kbp we should use a separate knowledge source to get the entity id (label_id)
            if kb2id and len(kb2id) > 0:
                if record["label"] in kb2id:
                    record["label_id"] = kb2id[record["label"]]
                else:
                    if consider_all:
                        # NIL
                        record["label_id"] = -1
                    else:
                        continue

            # check that each entity id (label_id) is in the entity collection
            elif wikipedia_id2local_id and len(wikipedia_id2local_id) > 0:
                try:
                    key = int(record["label"].strip())
                    if key in wikipedia_id2local_id:
                        record["label_id"] = wikipedia_id2local_id[key]
                    else:
                        if consider_all:
                            # NIL
                            record["label_id"] = -1
                        else:
                            continue
                except:
                    if consider_all:
                        # NIL
                        record["label_id"] = -1
                    else:
                        continue

            # LOWERCASE EVERYTHING !
            record["context_left"] = record["context_left"].lower()
            record["context_right"] = record["context_right"].lower()
            record["mention"] = record["mention"].lower()
            test_samples.append(record)

    if logger:
        logger.info("{}/{} samples considered".format(len(test_samples), len(lines)))
    return test_samples


def _get_test_samples(
    test_filename, test_entities_path, title2id, wikipedia_id2local_id, logger, consider_all=False
):
    kb2id = None
    if test_entities_path:
        kb2id = __map_test_entities(test_entities_path, title2id, logger)
    test_samples = __load_test(test_filename, kb2id, wikipedia_id2local_id, logger, consider_all=consider_all)
    return test_samples


def _process_biencoder_dataloader(samples, tokenizer, biencoder_params):
    _, tensor_data = process_mention_data(
        samples,
        tokenizer,
        biencoder_params["max_context_length"],
        biencoder_params["max_cand_length"],
        silent=True,
        logger=None,
        debug=biencoder_params["debug"],
    )
    sampler = SequentialSampler(tensor_data)
    dataloader = DataLoader(
        tensor_data, sampler=sampler, batch_size=biencoder_params["eval_batch_size"]
    )
    return dataloader


def _run_biencoder(biencoder, dataloader, candidate_encoding, top_k=100, indexer=None, save_encodings=False):
    biencoder.model.eval()
    labels = []
    nns = []
    all_scores = []
    encodings = []
    for batch in tqdm(dataloader):
        context_input, _, label_ids = batch
        with torch.no_grad():
            if indexer is not None:
                context_encoding = biencoder.encode_context(context_input).numpy()
                context_encoding = np.ascontiguousarray(context_encoding)
                if save_encodings:
                    encodings.extend([e.tolist() for e in context_encoding])
                print('encoding_shape', context_encoding.shape)
                global my_enc
                my_enc = context_encoding
                scores, indicies = indexer.search_knn(context_encoding, top_k)
            else:
                scores = biencoder.score_candidate(
                    context_input, None, cand_encs=candidate_encoding  # .to(device)
                )
                scores, indicies = scores.topk(top_k)
                scores = scores.data.numpy()
                indicies = indicies.data.numpy()

        labels.extend(label_ids.data.numpy())
        nns.extend(indicies)
        all_scores.extend(scores)
    return labels, nns, all_scores, encodings


def _process_crossencoder_dataloader(context_input, label_input, crossencoder_params):
    tensor_data = TensorDataset(context_input, label_input)
    sampler = SequentialSampler(tensor_data)
    dataloader = DataLoader(
        tensor_data, sampler=sampler, batch_size=crossencoder_params["eval_batch_size"]
    )
    return dataloader


def _run_crossencoder(crossencoder, dataloader, logger, context_len, device="cuda"):
    crossencoder.model.eval()
    accuracy = 0.0
    crossencoder.to(device)

    res = evaluate(crossencoder, dataloader, device, logger, context_len, zeshel=False, silent=False)
    accuracy = res["normalized_accuracy"]
    logits = res["logits"]

    if accuracy > -1:
        predictions = np.argsort(logits, axis=1)
    else:
        predictions = []

    return accuracy, predictions, logits


def load_models(args, logger=None):
    """Load every model and lookup table needed by ``run``.

    Reads from ``args``: ``biencoder_config`` / ``biencoder_model``,
    ``crossencoder_config`` / ``crossencoder_model`` (skipped when ``fast`` is
    truthy), ``entity_catalogue``, ``entity_encoding`` and, optionally,
    ``faiss_index`` / ``index_path``. When ``nil_prediction`` is truthy the
    pickled NIL classifiers are loaded as well: the biencoder-only one from
    ``nil_prediction_model_bi`` / ``nil_prediction_features_bi`` and the
    combined one from ``nil_prediction_model`` / ``nil_prediction_features``.
    Each classifier is loaded only if both its path and its feature list are
    configured; otherwise it stays ``None`` with an empty feature list.

    Returns:
        A 14-tuple ``(biencoder, biencoder_params, crossencoder,
        crossencoder_params, candidate_encoding, title2id, id2title, id2text,
        wikipedia_id2local_id, faiss_indexer, nil_prediction_model_bi,
        nil_prediction_features_bi, nil_prediction_model,
        nil_prediction_features)``, in the order ``run`` expects after
        ``args`` and ``logger``.
    """

    # load biencoder model
    if logger:
        logger.info("loading biencoder model")
    with open(args.biencoder_config) as json_file:
        biencoder_params = json.load(json_file)
        biencoder_params["path_to_model"] = args.biencoder_model
    biencoder = load_biencoder(biencoder_params)

    crossencoder = None
    crossencoder_params = None
    if not getattr(args, 'fast', False):
        # load crossencoder model
        if logger:
            logger.info("loading crossencoder model")
        with open(args.crossencoder_config) as json_file:
            crossencoder_params = json.load(json_file)
            crossencoder_params["path_to_model"] = args.crossencoder_model
        crossencoder = load_crossencoder(crossencoder_params)

    # load candidate entities
    if logger:
        logger.info("loading candidate entities")
    (
        candidate_encoding,
        title2id,
        id2title,
        id2text,
        wikipedia_id2local_id,
        faiss_indexer,
    ) = _load_candidates(
        args.entity_catalogue,
        args.entity_encoding,
        faiss_index=getattr(args, 'faiss_index', None),
        index_path=getattr(args, 'index_path' , None),
        logger=logger,
    )

    nil_prediction_model_bi = None
    nil_prediction_features_bi = []
    nil_prediction_model = None
    nil_prediction_features = []

    if getattr(args, 'nil_prediction', False):
        # The old guard tested the combined model's attributes but then read
        # the *_bi ones unchecked. Each classifier is now gated on its own
        # settings; truthiness (not hasattr) is used because attribute
        # containers such as addict.Dict report every attribute as present.
        bi_path = getattr(args, 'nil_prediction_model_bi', None)
        bi_features = getattr(args, 'nil_prediction_features_bi', None)
        if bi_path and bi_features:
            nil_prediction_model_bi = _load_pickle_model(bi_path)
            nil_prediction_features_bi = bi_features

        combined_path = getattr(args, 'nil_prediction_model', None)
        combined_features = getattr(args, 'nil_prediction_features', None)
        if combined_path and combined_features:
            nil_prediction_model = _load_pickle_model(combined_path)
            nil_prediction_features = combined_features

    return (
        biencoder,
        biencoder_params,
        crossencoder,
        crossencoder_params,
        candidate_encoding,
        title2id,
        id2title,
        id2text,
        wikipedia_id2local_id,
        faiss_indexer,
        nil_prediction_model_bi,
        nil_prediction_features_bi,
        nil_prediction_model,
        nil_prediction_features
    )




In [None]:
def _scores_get_stats(scores):
    global bi_higher_is_better
    scores = scores.tolist()
    _stats = {
        "max": max(scores),
        "second": sorted(scores, reverse=bi_higher_is_better)[1],
        "min": min(scores),
        "mean": statistics.mean(scores),
        "median": statistics.median(scores),
        "stdev": statistics.stdev(scores),
    }
    return _stats

def _load_pickle_model(path):
    with open(path, 'rb') as fd:
        mdl = pickle.load(fd)
    return mdl

def _load_torch_model(path, n):
    model = binaryClassification(n)
    model.load_state_dict(torch.load(path))
    model.eval()
    return model

In [None]:
def run(
    args,
    logger,
    biencoder,
    biencoder_params,
    crossencoder,
    crossencoder_params,
    candidate_encoding,
    title2id,
    id2title,
    id2text,
    wikipedia_id2local_id,
    faiss_indexer=None,
    nil_prediction_model_bi = None,
    nil_prediction_features_bi = ['max_bi'],
    nil_prediction_model = None,
    nil_prediction_features = ['max_cross'],
    test_data=None,
    local_id2wikipedia_id=None
):
    """Link mentions to entities with the biencoder and, unless ``args.fast``, the crossencoder.

    Two modes are supported:

    * interactive (``args.interactive``): repeatedly reads a sentence from
      stdin, detects mentions with the NER model and prints the predictions.
      The loop ends (returning the "no samples" tuple below) when no mention
      is found in the entered text.
    * dataset: uses ``test_data`` if given, otherwise loads
      ``args.test_mentions`` (and ``args.test_entities``), runs once and
      returns the metrics.

    Optional ``args`` settings read here: ``top_k``, ``fast``, ``show_url``,
    ``keep_all``, ``consider_all``, ``save_encodings``, ``save_scores_bi``,
    ``save_scores_cross`` and ``nil_prediction``. When NIL prediction is
    enabled, the feature frames and probabilities of the most recent call are
    also published as the module globals ``nil_features_bi``,
    ``nil_preds_bi``, ``nil_features`` and ``nil_preds``.

    Returns:
        ``(biencoder_accuracy, recall_at, crossencoder_normalized_accuracy,
        overall_unormalized_accuracy, n_samples, predictions, scores)``.
        Metrics that were not computed are ``-1``. With no samples the tuple is
        ``(-1, -1, -1, -1, 0, [], [])``.

    Raises:
        ValueError: if neither test data, test mentions nor interactive mode
            is configured.
    """
    # Published at module level so later notebook cells can inspect them.
    global nil_features_bi, nil_preds_bi, nil_features, nil_preds

    bi_columns = ['max_bi', 'second_bi', 'min_bi', 'mean_bi', 'median_bi', 'stdev_bi']
    cross_columns = [
        'max_cross', 'second_cross', 'min_cross',
        'mean_cross', 'median_cross', 'stdev_cross',
    ]

    if not test_data and not args.test_mentions and not args.interactive:
        msg = (
            "ERROR: either you start BLINK with the "
            "interactive option (-i) or you pass in input test mentions (--test_mentions) "
            "and test entities (--test_entities)"
        )
        raise ValueError(msg)

    id2url = {
        v: "https://en.wikipedia.org/wiki?curid=%s" % k
        for k, v in wikipedia_id2local_id.items()
    }

    nil_enabled = bool(getattr(args, 'nil_prediction', False))
    fast = bool(getattr(args, 'fast', False))
    show_url = getattr(args, 'show_url', False)
    encodings_path = getattr(args, 'save_encodings', None)
    scores_bi_path = getattr(args, 'save_scores_bi', None)
    scores_cross_path = getattr(args, 'save_scores_cross', None)

    # The NER model is loaded lazily, once, rather than on every loop iteration.
    ner_model = None

    stopping_condition = False
    while not stopping_condition:

        samples = None

        if args.interactive:
            if logger:
                logger.info("interactive mode")

            # biencoder_params["eval_batch_size"] = 1

            # Load NER model
            if ner_model is None:
                ner_model = NER.get_model()

            # Interactive
            text = input("insert text:")

            # Identify mentions
            samples = _annotate(ner_model, [text])

            _print_colorful_text(text, samples)

        else:
            if logger:
                logger.info("test dataset mode")

            if test_data:
                samples = test_data
            else:
                # Load test mentions; consider_all keeps unresolvable labels as
                # NIL (-1) instead of dropping them. It used to be ignored here.
                samples = _get_test_samples(
                    args.test_mentions,
                    args.test_entities,
                    title2id,
                    wikipedia_id2local_id,
                    logger,
                    consider_all=bool(getattr(args, 'consider_all', False)),
                )

            stopping_condition = True

        if len(samples) == 0:
            return (
                -1,
                -1,
                -1,
                -1,
                len(samples),
                [],
                [],
            )

        # don't look at labels
        keep_all = (
            args.interactive
            or samples[0]["label"] == "unknown"
            or samples[0]["label_id"] < 0
            or (hasattr(args, 'keep_all') and args.keep_all)
        )

        # prepare the data for biencoder
        if logger:
            logger.info("preparing data for biencoder")
        dataloader = _process_biencoder_dataloader(
            samples, biencoder.tokenizer, biencoder_params
        )

        # run biencoder
        if logger:
            logger.info("run biencoder")
        top_k = args.top_k
        labels, nns, scores, encodings = _run_biencoder(
            biencoder, dataloader, candidate_encoding, top_k, faiss_indexer, bool(encodings_path)
        )

        if encodings_path:
            with open(encodings_path, 'w') as fd:
                for _enc, _lab in zip(encodings, labels):
                    assert len(_lab) == 1
                    _lab = int(_lab[0])
                    # NIL labels (-1) have no catalogue entry: fall back to the
                    # existing 0 sentinel / a null title instead of KeyError.
                    if local_id2wikipedia_id is None:
                        wikipedia_id = 0
                    else:
                        wikipedia_id = local_id2wikipedia_id.get(_lab, 0)
                    current = {
                        "encoding": _enc,
                        "label": _lab,
                        "wikipedia_id": wikipedia_id,
                        "title": id2title.get(_lab)
                    }
                    json.dump(current, fd)
                    fd.write('\n')

        if scores_bi_path:
            scores_bi = {
                "labels": [l.tolist() for l in labels],
                "scores": [l.tolist() for l in scores],
                "nns": [l.tolist() for l in nns]
            }
            with open(scores_bi_path, 'w') as fd:
                json.dump(scores_bi, fd)

        # NIL probabilities from biencoder scores only. The guard checks the
        # biencoder classifier itself (it used to test the combined model).
        bi_nil_scores = None
        if nil_enabled and nil_prediction_model_bi is not None:
            nil_features_bi = np.array(
                [list(_scores_get_stats(row).values()) for row in scores]
            )
            nil_features_bi = pd.DataFrame(data=nil_features_bi, columns=bi_columns)
            nil_features_bi = nil_features_bi[nil_prediction_features_bi]

            nil_preds_bi = nil_prediction_model_bi.predict_proba(nil_features_bi)
            # keep the second probability column of each row
            nil_preds_bi = np.array([i_1 for _, i_1 in nil_preds_bi])
            bi_nil_scores = nil_preds_bi

        if args.interactive:

            print("\nfast (biencoder) predictions:")

            _print_colorful_text(text, samples)

            # print biencoder prediction; the NIL line is only shown when a
            # NIL probability was actually computed for this call.
            if bi_nil_scores is None:
                nil_column = [None] * len(samples)
            else:
                nil_column = bi_nil_scores
            for idx, (entity_list, sample, _score, nil_p) in enumerate(
                zip(nns, samples, scores, nil_column)
            ):
                e_id = entity_list[0]
                e_title = id2title[e_id]
                e_text = id2text[e_id]
                e_url = id2url[e_id]
                _print_colorful_prediction(
                    idx, sample, e_id, e_title, e_text, e_url, show_url
                )
                print("bi_Score:", _score[0])
                print("all scores:", _score[1:])
                if nil_p is not None:
                    print('NIL:', nil_p)
            print()

            if fast:
                # use only biencoder
                continue

        else:

            biencoder_accuracy = -1
            recall_at = -1
            if not keep_all:
                # Recall for every cutoff 1..top_k inclusive, so the last value
                # really is recall@top_k (the loop used to stop at top_k - 1
                # and failed for top_k == 1).
                top_k = args.top_k
                recall_curve = []
                for cutoff in range(1, top_k + 1):
                    hits = 0.0
                    for label, top in zip(labels, nns):
                        if label in top[:cutoff]:
                            hits += 1
                    if len(labels) > 0:
                        hits /= len(labels)
                    recall_curve.append(hits)
                biencoder_accuracy = recall_curve[0]
                recall_at = recall_curve[-1]
                print("biencoder accuracy: %.4f" % biencoder_accuracy)
                print("biencoder recall@%d: %.4f" % (top_k, recall_at))

            if fast:

                predictions = [
                    [id2title[e_id] for e_id in entity_list] for entity_list in nns
                ]

                # use only biencoder
                return (
                    biencoder_accuracy,
                    recall_at,
                    -1,
                    -1,
                    len(samples),
                    predictions,
                    scores,
                )

        # prepare crossencoder data
        context_input, candidate_input, label_input = prepare_crossencoder_data(
            crossencoder.tokenizer, samples, labels, nns, id2title, id2text, keep_all,
        )

        context_input = modify(
            context_input, candidate_input, crossencoder_params["max_seq_length"]
        )

        dataloader = _process_crossencoder_dataloader(
            context_input, label_input, crossencoder_params
        )

        # run crossencoder and get accuracy
        accuracy, index_array, unsorted_scores = _run_crossencoder(
            crossencoder,
            dataloader,
            logger,
            context_len=biencoder_params["max_context_length"],
        )

        # NIL probabilities from biencoder + crossencoder score statistics.
        cross_nil_scores = None
        if nil_enabled and nil_prediction_model is not None:
            nil_features = np.array([
                list(_scores_get_stats(bi_row).values())
                + list(_scores_get_stats(cross_row).values())
                for bi_row, cross_row in zip(scores, unsorted_scores)
            ])
            nil_features = pd.DataFrame(
                data=nil_features, columns=bi_columns + cross_columns
            )
            nil_features = nil_features[nil_prediction_features]

            nil_preds = nil_prediction_model.predict_proba(nil_features)
            # keep the second probability column of each row
            nil_preds = np.array([i_1 for _, i_1 in nil_preds])
            cross_nil_scores = nil_preds

        if scores_cross_path:
            print('----- Score cross length -----')
            print('labels', len(labels))
            print('unsorted_scores', len(unsorted_scores))
            print('index_array', len(index_array))
            print('nns', len(nns))
            scores_cross = {
                "labels": [l.tolist() for l in labels],
                "unsorted_scores": [l.tolist() for l in unsorted_scores],
                # index_array is a plain empty list when no ranking was produced
                "index_array": np.asarray(index_array).tolist(),
                "nns": [l.tolist() for l in nns]
            }
            with open(scores_cross_path, 'w') as fd:
                json.dump(scores_cross, fd)

        if args.interactive:

            print("\naccurate (crossencoder) predictions:")

            _print_colorful_text(text, samples)

            # print crossencoder prediction
            if cross_nil_scores is None:
                nil_column = [None] * len(samples)
            else:
                nil_column = cross_nil_scores
            for idx, (entity_list, index_list, sample, _scores, _nil) in enumerate(
                zip(nns, index_array, samples, unsorted_scores, nil_column)
            ):
                # index_list is sorted ascending by score: the best is last
                e_id = entity_list[index_list[-1]]
                e_title = id2title[e_id]
                e_text = id2text[e_id]
                e_url = id2url[e_id]
                _print_colorful_prediction(
                    idx, sample, e_id, e_title, e_text, e_url, show_url
                )
                print("cross_score:", _scores[index_list[-1]])
                print("all scores:", _scores)
                if _nil is not None:
                    print("NIL score:", _nil)
            print()
        else:

            scores = []
            predictions = []
            for entity_list, index_list, scores_list in zip(
                nns, index_array, unsorted_scores
            ):

                index_list = index_list.tolist()

                # descending order
                index_list.reverse()

                sample_prediction = []
                sample_scores = []
                for index in index_list:
                    e_id = entity_list[index]
                    e_title = id2title[e_id]
                    sample_prediction.append(e_title)
                    sample_scores.append(scores_list[index])
                predictions.append(sample_prediction)
                scores.append(sample_scores)

            crossencoder_normalized_accuracy = -1
            overall_unormalized_accuracy = -1
            if not keep_all:
                crossencoder_normalized_accuracy = accuracy
                print(
                    "crossencoder normalized accuracy: %.4f"
                    % crossencoder_normalized_accuracy
                )

                if len(samples) > 0:
                    overall_unormalized_accuracy = (
                        crossencoder_normalized_accuracy * len(label_input) / len(samples)
                    )
                print(
                    "overall unnormalized accuracy: %.4f" % overall_unormalized_accuracy
                )
            return (
                biencoder_accuracy,
                recall_at,
                crossencoder_normalized_accuracy,
                overall_unormalized_accuracy,
                len(samples),
                predictions,
                scores,
            )

In [None]:
# interactive
# Configuration for interactive mode with NIL prediction enabled and a flat
# faiss index.
# NOTE: this cell and the two configuration cells after it each rebind `args`
# from scratch, so only the one executed last before `load_models` takes
# effect. Run exactly one of them.
args = Dict()

args.nil_prediction = True
#args.nil_prediction_scaler = "../models/nil_pred/indexer/stdscaler+max_bi+second_bi+min_bi+mean_bi+median_bi+stdev_bi+max_cross+second_cross+min_cross+mean_cross+median_cross+stdev_cross+train_hard_aug.pkl"
#args.nil_prediction_model = "../models/nil_pred/indexer/nrl+max_bi+second_bi+min_bi+mean_bi+median_bi+stdev_bi+max_cross+second_cross+min_cross+mean_cross+median_cross+stdev_cross+train_hard_aug.torch"
# Pickled classifiers and the feature columns each one is fed in `run`.
args.nil_prediction_model_bi = "../models/nil_pred/models_ip/svc_bi+train_hard_aug.pkl"
args.nil_prediction_features_bi = ['max_bi', 'second_bi', 'min_bi', 'mean_bi', 'median_bi' ,'stdev_bi']
args.nil_prediction_model = "../models/nil_pred/models_ip/svc_all+train_hard_aug.pkl"
args.nil_prediction_features = ['max_bi', 'second_bi', 'min_bi', 'mean_bi', 'median_bi' ,'stdev_bi', 'max_cross', 'second_cross', 'min_cross', 'mean_cross', 'median_cross', 'stdev_cross']
args.interactive = True
args.top_k = 10
args.biencoder_config = "../models/biencoder_wiki_large.json"
args.biencoder_model = "../models/biencoder_wiki_large.bin"
args.crossencoder_config = "../models/crossencoder_wiki_large.json"
args.crossencoder_model = "../models/crossencoder_wiki_large.bin"
args.entity_catalogue = "../models/entity.jsonl"
args.entity_encoding = "../models/all_entities_large.t7"
# `bi_higher_is_better` is a module-level flag read by `_scores_get_stats` to
# choose the sort direction for the "second" statistic. The commented-out
# alternative pairs the hnsw index with False; the active one pairs the flat
# index with True.
#bi_higher_is_better = False
#args.faiss_index = "hnsw"
#args.index_path = "../models/faiss_hnsw_index.pkl"
bi_higher_is_better = True
args.faiss_index = "flat"
args.index_path = "../models/faiss_flat_index.pkl"

logger = utils.get_logger(None)

In [None]:
# test file
# Configuration for a one-shot run over a test mentions file, without NIL
# prediction and without a faiss index (candidates are scored against the
# encodings loaded from `entity_encoding`).
# NOTE: rebinds `args`, discarding whatever an earlier configuration cell set.
# `bi_higher_is_better` is not defined by this cell.
args = Dict()

args.nil_prediction = False

args.interactive = False
args.top_k = 10
args.biencoder_config = "../models/biencoder_wiki_large.json"
args.biencoder_model = "../models/biencoder_wiki_large.bin"
args.crossencoder_config = "../models/crossencoder_wiki_large.json"
args.crossencoder_model = "../models/crossencoder_wiki_large.bin"
args.entity_catalogue = "../models/entity.jsonl"
args.entity_encoding = "../models/all_entities_large.t7"
#args.faiss_index = "hnsw"
#args.index_path = "../models/faiss_hnsw_index.pkl"
args.faiss_index = None
args.index_path = None

args.test_mentions = "../data/BLINK_benchmark/ace2004_questions.jsonl"

# Output files for the raw biencoder / crossencoder scores written by `run`.
args.save_scores_bi = "../output/scores_bi.json"
args.save_scores_cross = "../output/scores_cross.json"

# keep_all: `run` then skips the label-based accuracy/recall computation.
# consider_all: supported by `_get_test_samples` / `__load_test`, where it keeps
# samples with unresolvable labels as NIL (label_id = -1).
args.keep_all = True
args.consider_all = True

logger = utils.get_logger(None)

In [None]:
# get encodings (starting from jsonl file)
# Configuration for dumping the biencoder context encodings of a mentions file,
# retrieving candidates through the hnsw faiss index.
# NOTE: rebinds `args`, discarding whatever an earlier configuration cell set.
# `bi_higher_is_better` is not defined by this cell.
args = Dict()

args.interactive = False
args.top_k = 10
args.biencoder_config = "../models/biencoder_wiki_large.json"
args.biencoder_model = "../models/biencoder_wiki_large.bin"
args.crossencoder_config = "../models/crossencoder_wiki_large.json"
args.crossencoder_model = "../models/crossencoder_wiki_large.bin"
args.entity_catalogue = "../models/entity.jsonl"
args.entity_encoding = "../models/all_entities_large.t7"
args.faiss_index = "hnsw"
args.index_path = "../models/faiss_hnsw_index.pkl"

#args.test_mentions = "../data/test10.jsonl"
args.test_mentions = "../data/BLINK_benchmark/AIDA-YAGO2_train.jsonl"

# When set, `run` writes one JSON line per mention (encoding, label,
# wikipedia_id, title) to this path.
args.save_encodings = "../output/encodings/AIDA-YAGO2_train_encodings.jsonl"

logger = utils.get_logger(None)

In [None]:
# do not rerun
# takes time and memory
# `models` is the 14-tuple returned by `load_models`; later cells index into it
# and unpack it positionally into `run`.
models = load_models(args, logger)
print("Models load complete.")
# Reset here; the reverse map is (re)built in a later cell when requested.
local_id2wikipedia_id = None

In [None]:
# When True, the next cell builds `local_id2wikipedia_id` even if
# `args.save_encodings` is not set.
_get_local_id = True

In [None]:
# Name the first ten entries of the tuple returned by `load_models`; the
# remaining entries (the NIL prediction models and feature lists) stay in
# `models` only.
(
    biencoder,
    biencoder_params,
    crossencoder,
    crossencoder_params,
    candidate_encoding,
    title2id,
    id2title,
    id2text,
    wikipedia_id2local_id,
    faiss_indexer,
) = models[:10]

# Build the inverse of wikipedia_id2local_id (local id -> wikipedia id) when
# encodings are going to be saved or when explicitly requested.
if (hasattr(args, 'save_encodings') and args.save_encodings) or _get_local_id:
    local_id2wikipedia_id = {}
    for k, v in wikipedia_id2local_id.items():
        local_id2wikipedia_id[v] = k

In [None]:
# remember to define run function above
# `*models` fills run()'s positional parameters from `biencoder` through
# `nil_prediction_features`, in the order `load_models` returns them.
run(args, logger, *models, local_id2wikipedia_id=local_id2wikipedia_id)

# Biencoder context-encoding tests

In [None]:
# NER model used by `_annotate` to detect mentions in raw text.
ner_model = NER.get_model()

In [None]:
# Example sentence used by the exploratory cells that follow.
text = """Henry, king of England from 22 April 1509, married for the fifth time with Catherine Howard in 1540."""

In [None]:
# One record per detected mention (lower-cased contexts, placeholder labels).
samples = _annotate(ner_model, [text])

In [None]:
samples

In [None]:
# Tokenize the example mentions into a sequential DataLoader for the biencoder.
bi_dataloader = _process_biencoder_dataloader(
            samples, biencoder.tokenizer, biencoder_params
        )

In [None]:
# Encode every mention context batch by batch, then stack the per-batch arrays
# once at the end. `_context_encoding` ends up as a single array with one row
# per mention, or None when the dataloader yields no batches.
biencoder.model.eval()
_context_encoding = []
for batch in bi_dataloader:
    context_input, _, label_ids = batch
    with torch.no_grad():
        context_encoding = np.ascontiguousarray(
            biencoder.encode_context(context_input).numpy()
        )
    _context_encoding.append(context_encoding)
_context_encoding = np.concatenate(_context_encoding) if _context_encoding else None

In [None]:
_context_encoding.shape

In [None]:
import time

start_time = time.monotonic()

retrieved = faiss_indexer.search_knn(_context_encoding, 10)

print('seconds: ', time.monotonic() - start_time)
retrieved

In [None]:
[id2title[i] for i in retrieved[1][0]]

In [None]:
id2text[4679300]

In [None]:
reco = faiss_indexer.index.reconstruct(4679300)
reco.shape

In [None]:
_context_encoding[0].shape

In [None]:
(_context_encoding[0] ** 2).sum()

In [None]:
(np.array([1,2,3])**2).sum()

In [None]:
np.dot(reco[:1024], _context_encoding[0])

In [None]:
a = reco[:1024]
b = _context_encoding[0]

In [None]:
np.linalg.norm(a-b)

In [None]:
np.sqrt(((a-b)**2).sum())

In [None]:
np.dot(reco, reco)

In [None]:
# NOTE(review): `original` is not assigned anywhere in the cells visible here
# (the assignment below is commented out), so this cell depends on kernel state
# left over from an earlier run — confirm where `original` is meant to come from.
#original = b
original.shape

In [None]:
np.dot(b, original)

In [None]:
np.dot(b, a)

# cross

In [None]:
from blink.crossencoder.data_process import prepare_crossencoder_mentions

In [None]:
context_input_list = prepare_crossencoder_mentions(crossencoder.tokenizer, samples)

In [None]:
context_input = torch.LongTensor(context_input_list)
context_input

In [None]:
candidate_input = torch.reshape(context_input, (context_input.shape[0], 1, -1))
context_input = modify(
    context_input, candidate_input, crossencoder_params["max_seq_length"]
)

In [None]:
context_input.shape

In [None]:
context_input_list[0].shape

In [None]:
top_k = args.top_k
labels, nns, scores, encodings = _run_biencoder(
    biencoder, bi_dataloader, candidate_encoding, top_k, faiss_indexer, False
)

In [None]:
samples

In [None]:
nns

In [None]:
id2text[3422691]

In [None]:
# prepare crossencoder data
context_input, candidate_input, label_input = prepare_crossencoder_data(
    crossencoder.tokenizer, samples, labels, nns, id2title, id2text, keep_all=True,
)

context_input = modify(
    context_input, candidate_input, crossencoder_params["max_seq_length"]
)

In [None]:
cross_dataloader = _process_crossencoder_dataloader(
    context_input, label_input, crossencoder_params
)

# Smoke-run the cross-encoder forward pass over every batch; the outputs are
# not kept.
crossencoder.model.eval()
# Use the GPU when one is present, otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
crossencoder.to(device)
for batch in cross_dataloader:
    batch = tuple(t.to(device) for t in batch)
    with torch.no_grad():
        crossencoder(batch[0], batch[1], biencoder_params["max_context_length"])

In [None]:
crossencoder.model.eval()
with torch.no_grad():
    res = crossencoder(context_input.to("cuda"), label_input.to("cuda"), biencoder_params["max_context_length"])
res

In [None]:
res[1][1].max()

In [None]:
cross_dataloader = _process_crossencoder_dataloader(
    context_input, label_input, crossencoder_params
)

# run crossencoder and get accuracy
accuracy, index_array, unsorted_scores = _run_crossencoder(
    crossencoder,
    cross_dataloader,
    logger,
    context_len=biencoder_params["max_context_length"],
)

In [None]:
id2url = {
    v: "https://en.wikipedia.org/wiki?curid=%s" % k
    for k, v in wikipedia_id2local_id.items()
}

In [None]:
print("\naccurate (crossencoder) predictions:")

_print_colorful_text(text, samples)

# For every mention, report the candidate found at the last position of the
# cross-encoder index array (presumably the best-scored one — verify against
# _run_crossencoder), together with its score and all candidate scores.
for idx, (entity_list, index_list, sample, _scores) in enumerate(
    zip(nns, index_array, samples, unsorted_scores)
):
    chosen = index_list[-1]
    e_id = entity_list[chosen]
    e_title, e_text, e_url = id2title[e_id], id2text[e_id], id2url[e_id]
    _print_colorful_prediction(
        idx, sample, e_id, e_title, e_text, e_url, args.show_url
    )
    print("cross_score:", _scores[chosen])
    print("all scores:", _scores)
print()

In [None]:
res

In [None]:
nns

In [None]:
for batch in cross_dataloader:
    print(batch[0].shape, batch[1].shape)

In [None]:
label_input

## save entity ids structures

In [None]:
with open('entity_ids/title2id.pickle', 'wb') as fd:
    pickle.dump(title2id, fd)

In [None]:
with open('entity_ids/id2title.pickle', 'wb') as fd:
    pickle.dump(id2title, fd)

In [None]:
with open('entity_ids/id2text.pickle', 'wb') as fd:
    pickle.dump(id2text, fd)

In [None]:
with open('entity_ids/wikipedia_id2local_id.pickle', 'wb') as fd:
    pickle.dump(wikipedia_id2local_id, fd)

In [None]:
with open('entity_ids/local_id2wikipedia_id.pickle', 'wb') as fd:
    pickle.dump(local_id2wikipedia_id, fd)

In [None]:
id2title[3674818]

In [None]:
wikipedia_id2local_id[243710]

In [None]:
local_id2wikipedia_id[3674818]

# get encodings fast

In [None]:
args.save_encodings = '../output/AIDA-YAGO2_train_encodings.jsonl'
args.test_mentions = '../data/BLINK_benchmark/AIDA-YAGO2_train.jsonl'

In [None]:
samples = _get_test_samples(
                    args.test_mentions,
                    args.test_entities,
                    title2id,
                    wikipedia_id2local_id,
                    logger,
                    consider_all= True
                )

In [None]:
def _run_biencoder_only_encodings(biencoder, dataloader, candidate_encoding, top_k=100, indexer=None, save_encodings=True):
    biencoder.model.eval()
    labels = []
    #nns = []
    #all_scores = []
    encodings = []
    for batch in tqdm(dataloader):
        context_input, _, label_ids = batch
        with torch.no_grad():
            if indexer is not None:
                context_encoding = biencoder.encode_context(context_input).numpy()
                context_encoding = np.ascontiguousarray(context_encoding)
                if save_encodings:
                    encodings.extend([e.tolist() for e in context_encoding])
                #scores, indicies = indexer.search_knn(context_encoding, top_k)
            else:
                raise Exception('not implemented for only getting encodings.')

        labels.extend(label_ids.data.numpy())
    return labels, encodings

In [None]:
dataloader = _process_biencoder_dataloader(
    samples, biencoder.tokenizer, biencoder_params
)

In [None]:
# run biencoder

top_k = args.top_k
labels, encodings = _run_biencoder_only_encodings(
    biencoder, dataloader, candidate_encoding, top_k, faiss_indexer, bool(args.save_encodings) if hasattr(args, 'save_encodings') else False
)

In [None]:
# Write one JSON object per line: the mention encoding, its label, and the
# wikipedia id / title looked up for that label (0 / "**NOTFOUND**" when the
# label is unknown).
with open(args.save_encodings, 'w') as fd:
    for _enc, _lab in zip(encodings, labels):
        assert len(_lab) == 1
        _lab = int(_lab[0])
        if local_id2wikipedia_id is None:
            _wiki_id = 0
        else:
            _wiki_id = local_id2wikipedia_id.get(_lab, 0)
        _title = id2title[_lab] if _lab in id2title else "**NOTFOUND**"
        current = {
            "encoding": _enc,
            "label": _lab,
            "wikipedia_id": _wiki_id,
            "title": _title,
        }
        fd.write(json.dumps(current))
        fd.write('\n')

# encodings

In [None]:
!wc -l ../output/*.jsonl

In [None]:
from blink.indexer.faiss_indexer import DenseFlatIndexer
from sklearn_extra.cluster import KMedoids

In [None]:
def GetMedoid(vX):
    """Fit a single-cluster KMedoids on the stacked vectors and return its ``cluster_centers_``."""
    stacked = np.stack(vX)
    return KMedoids(n_clusters=1).fit(stacked).cluster_centers_

In [None]:
encodings_p = ['../output/AIDA-YAGO2_testa_encodings.jsonl',
 '../output/AIDA-YAGO2_testb_encodings.jsonl',
 '../output/AIDA-YAGO2_train_encodings.jsonl']

In [None]:
train_df = pd.read_json('../output/AIDA-YAGO2_train_encodings.jsonl', lines=True)
testa_df = pd.read_json('../output/AIDA-YAGO2_testa_encodings.jsonl', lines=True)
testb_df = pd.read_json('../output/AIDA-YAGO2_testb_encodings.jsonl', lines=True)

In [None]:
# first
to_index = pd.DataFrame(train_df.query('wikipedia_id > 0').groupby('wikipedia_id')['encoding'].first())
to_index = to_index.sample(frac=1) # shuffle
to_index['index'] = range(to_index.shape[0])
to_index['wikipedia_id'] = to_index.index
to_index = to_index.set_index('index')
to_index

In [None]:
# medoid
to_index = pd.DataFrame(train_df.query('wikipedia_id > 0').groupby('wikipedia_id')['encoding'].apply(
    lambda x: GetMedoid(x)[0]))
to_index = to_index.sample(frac=1) # shuffle
to_index['index'] = range(to_index.shape[0])
to_index['wikipedia_id'] = to_index.index
to_index = to_index.set_index('index')
to_index

In [None]:
# index all (same entity multiple times)
to_index = pd.DataFrame(train_df.query('wikipedia_id > 0')[['wikipedia_id', 'encoding']])
to_index = to_index.sample(frac=1) # shuffle
to_index['index'] = range(to_index.shape[0])
to_index = to_index.set_index('index')
to_index

In [None]:
index_1 = DenseFlatIndexer(1024, 50000) # 1024 dimensions, 50000 default as BLINK
index_1.index.ntotal

In [None]:
# index in batch of 100 mentions
for i in range(100, to_index.shape[0], 100):
    #print(i-100, i)
    index_1.index_data(
        np.stack(
            to_index.iloc[i-100:i]['encoding'].values
        ).astype('float32'))

# index last batch
index_1.index_data(
    np.stack(
        to_index.iloc[i:to_index.shape[0]]['encoding'].values
    ).astype('float32'))

In [None]:
assert index_1.index.ntotal == to_index.shape[0]
index_1.index.ntotal

In [None]:
testa_linking_results = index_1.search_knn(np.stack(testa_df['encoding'].values).astype('float32'), 100)

In [None]:
def myfun(indices):
    """Map row positions of ``to_index`` to the wikipedia ids stored there.

    ``indices`` is an integer array (any shape) of row positions, as returned
    by the index search. Negative positions are reported as -1 instead of
    being resolved from the end of the frame: faiss uses -1 for "no
    neighbour", and that must not be read as the last indexed entity.
    ``to_index`` is looked up when the function is called, so the frame that
    is current at call time is used.

    Returns an array of the same shape as ``indices``.
    """
    positions = np.asarray(indices)
    wiki_ids = to_index['wikipedia_id'].to_numpy()
    found = positions >= 0
    resolved = np.full(positions.shape, -1, dtype=wiki_ids.dtype)
    resolved[found] = wiki_ids[positions[found]]
    return resolved

In [None]:
testa_linking_results_wiki_id = myfun(testa_linking_results[1])

In [None]:
def _eval_isin(x, array):
    array =  array.tolist()
    if x in array:
        return array.index(x)
    else:
        return None

In [None]:
eval_testa = pd.DataFrame(testa_df.apply(lambda x: _eval_isin(x['wikipedia_id'], testa_linking_results_wiki_id[x.name]), axis=1), columns=['found_at'])

In [None]:
eval_testa.dropna().shape[0]/eval_testa.shape[0]

In [None]:
# eval only on entities that are in the index
eval_testa['wikipedia_id'] = testa_df['wikipedia_id']
eval_testa_filtered = eval_testa[eval_testa['wikipedia_id'].isin(to_index['wikipedia_id'])]
eval_testa_filtered.shape[0]/eval_testa.shape[0]

In [None]:
def eval_test(test_df, name):
    """Compute recall@k from a frame of retrieval ranks, print it and return it.

    Args:
        test_df: DataFrame with a ``found_at`` column holding the 0-based rank
            at which the expected entity was retrieved. Rows containing any
            missing value are counted as misses.
        name: label stored in the ``name`` row of the result.

    Returns:
        A single-column DataFrame indexed by ``name`` and ``recall@k`` for
        k in 1, 2, 3, 5, 10, 30, 100. Each recall is the number of rows with
        ``found_at < k`` divided by the total number of rows of ``test_df``;
        it is NaN when ``test_df`` has no rows.

    The table is also printed as markdown and as LaTeX.
    """
    cutoffs = [1, 2, 3, 5, 10, 30, 100]
    total = test_df.shape[0]
    # Drop incomplete rows once instead of once per cutoff.
    hits = test_df.dropna()
    # Guard the division: an empty frame gives NaN recalls instead of raising.
    recalls = [
        hits.query('found_at < %d' % k).shape[0] / total if total else float('nan')
        for k in cutoffs
    ]
    eval_df = pd.DataFrame(
        data=[name] + recalls,
        index=['name'] + ['recall@%d' % k for k in cutoffs],
    )
    print(eval_df.to_markdown())
    print()
    print(eval_df.to_latex())
    return eval_df

In [None]:
eval_test(eval_testa_filtered, 'test a index all brutally')

In [None]:
testb_linking_results = index_1.search_knn(np.stack(testb_df['encoding'].values).astype('float32'), 100)

In [None]:
testb_linking_results_wiki_id = myfun(testb_linking_results[1])

In [None]:
eval_testb = pd.DataFrame(testb_df.apply(lambda x: _eval_isin(x['wikipedia_id'], testb_linking_results_wiki_id[x.name]), axis=1), columns=['found_at'])

In [None]:
# eval only on entities that are in the index
eval_testb['wikipedia_id'] = testb_df['wikipedia_id']
eval_testb_filtered = eval_testb[eval_testb['wikipedia_id'].isin(to_index['wikipedia_id'])]
eval_testb_filtered.shape[0]/eval_testb.shape[0]

In [None]:
eval_test(eval_testb_filtered, 'test b index all brutally')

In [None]:
(to_index['wikipedia_id'].value_counts() > 1).sum()

In [None]:
eval_testb

|            | 0                  |
|:-----------|:-------------------|
| name       | test a first       |
| recall@1   | 0.7374071015689513 |
| recall@2   | 0.8219102669969722 |
| recall@3   | 0.8560418387007982 |
| recall@5   | 0.8882466281310212 |
| recall@10  | 0.9193503991191853 |
| recall@30  | 0.9581612992017616 |
| recall@100 | 0.985962014863749  |

\begin{tabular}{ll}
\toprule
{} &             0 \\
\midrule
name       &  test a first \\
recall@1   &      0.737407 \\
recall@2   &       0.82191 \\
recall@3   &      0.856042 \\
recall@5   &      0.888247 \\
recall@10  &       0.91935 \\
recall@30  &      0.958161 \\
recall@100 &      0.985962 \\
\bottomrule
\end{tabular}

|            | 0                  |
|:-----------|:-------------------|
| name       | test b first       |
| recall@1   | 0.7006125574272588 |
| recall@2   | 0.7974732006125574 |
| recall@3   | 0.832312404287902  |
| recall@5   | 0.8709800918836141 |
| recall@10  | 0.9046707503828484 |
| recall@30  | 0.9525267993874426 |
| recall@100 | 0.9820061255742726 |

\begin{tabular}{ll}
\toprule
{} &             0 \\
\midrule
name       &  test b first \\
recall@1   &      0.700613 \\
recall@2   &      0.797473 \\
recall@3   &      0.832312 \\
recall@5   &       0.87098 \\
recall@10  &      0.904671 \\
recall@30  &      0.952527 \\
recall@100 &      0.982006 \\
\bottomrule
\end{tabular}



|            | 0                  |
|:-----------|:-------------------|
| name       | test a medoid      |
| recall@1   | 0.7946600605560143 |
| recall@2   | 0.8769611890999174 |
| recall@3   | 0.910542251582714  |
| recall@5   | 0.9380677126341866 |
| recall@10  | 0.9587118084227911 |
| recall@30  | 0.9870630333058079 |
| recall@100 | 0.9972474538948527 |

\begin{tabular}{ll}
\toprule
{} &              0 \\
\midrule
name       &  test a medoid \\
recall@1   &        0.79466 \\
recall@2   &       0.876961 \\
recall@3   &       0.910542 \\
recall@5   &       0.938068 \\
recall@10  &       0.958712 \\
recall@30  &       0.987063 \\
recall@100 &       0.997247 \\
\bottomrule
\end{tabular}

|            | 0                  |
|:-----------|:-------------------|
| name       | test b medoid      |
| recall@1   | 0.7687595712098009 |
| recall@2   | 0.8503062787136294 |
| recall@3   | 0.8862940275650842 |
| recall@5   | 0.9142419601837672 |
| recall@10  | 0.9498468606431854 |
| recall@30  | 0.9900459418070444 |
| recall@100 | 0.9973200612557427 |

\begin{tabular}{ll}
\toprule
{} &              0 \\
\midrule
name       &  test b medoid \\
recall@1   &        0.76876 \\
recall@2   &       0.850306 \\
recall@3   &       0.886294 \\
recall@5   &       0.914242 \\
recall@10  &       0.949847 \\
recall@30  &       0.990046 \\
recall@100 &        0.99732 \\
\bottomrule
\end{tabular}

|            | 0                         |
|:-----------|:--------------------------|
| name       | test a index all brutally |
| recall@1   | 0.9388934764657308        |
| recall@2   | 0.9584365538122763        |
| recall@3   | 0.9666941921277181        |
| recall@5   | 0.9774291219377924        |
| recall@10  | 0.9854115056427195        |
| recall@30  | 0.9931186347371318        |
| recall@100 | 0.9958711808422791        |

\begin{tabular}{ll}
\toprule
{} &                          0 \\
\midrule
name       &  test a index all brutally \\
recall@1   &                   0.938893 \\
recall@2   &                   0.958437 \\
recall@3   &                   0.966694 \\
recall@5   &                   0.977429 \\
recall@10  &                   0.985412 \\
recall@30  &                   0.993119 \\
recall@100 &                   0.995871 \\
\bottomrule
\end{tabular}

|            | 0                         |
|:-----------|:--------------------------|
| name       | test b index all brutally |
| recall@1   | 0.9291730474732006        |
| recall@2   | 0.9555895865237366        |
| recall@3   | 0.9655436447166922        |
| recall@5   | 0.9701378254211332        |
| recall@10  | 0.9785604900459418        |
| recall@30  | 0.9896630934150077        |
| recall@100 | 0.9950229709035222        |

\begin{tabular}{ll}
\toprule
{} &                          0 \\
\midrule
name       &  test b index all brutally \\
recall@1   &                   0.929173 \\
recall@2   &                    0.95559 \\
recall@3   &                   0.965544 \\
recall@5   &                   0.970138 \\
recall@10  &                    0.97856 \\
recall@30  &                   0.989663 \\
recall@100 &                   0.995023 \\
\bottomrule
\end{tabular}



In [None]:
# Entities that occur more than 5 times in the training mentions.
# Filter the counts Series directly rather than querying a column by name:
# from pandas 2.0 on value_counts() names the counts 'count' and the index
# 'wikipedia_id', so query('wikipedia_id > 5') would filter on the ids.
_train_mention_counts = train_df.query('wikipedia_id != 0')['wikipedia_id'].value_counts()
over5 = _train_mention_counts[_train_mention_counts > 5].index

In [None]:
eval_testa['wikipedia_id'] = testa_df['wikipedia_id']

In [None]:
eval_testa[eval_testa['wikipedia_id'].isin(to_index['wikipedia_id'])].shape[0]/eval_testa.shape[0]

In [None]:
# Mention counts per entity in train and in testa, kept only for entities that
# occur in both. The count columns are named explicitly because the name
# value_counts() gives its result changed in pandas 2.0; relying on join
# suffixes would produce different column names depending on the version.
_counts_train = train_df.query('wikipedia_id != 0')['wikipedia_id'].value_counts()
_counts_testa = testa_df.query('wikipedia_id != 0')['wikipedia_id'].value_counts()
wiki_count_train_testa = _counts_train.rename('wikipedia_idtrain').to_frame().join(
    _counts_testa.rename('wikipedia_id_testa').to_frame(),
    how='inner')
wiki_count_train_testa['min_count'] = wiki_count_train_testa.min(axis=1)
wiki_count_train_testa

In [None]:
wiki_count_train_testa

In [None]:
eval_test(eval_testa[eval_testa['wikipedia_id'].isin(
    wiki_count_train_testa.query('wikipedia_idtrain < 5 and wikipedia_id_testa > 5').index
)], 'aaa')

In [None]:
testa_linking_results_wiki_id

In [None]:
testa_df

In [None]:
testa_df.loc[1].name

In [None]:
encodings_df.shape

In [None]:
len(encodings_df['encoding'].iloc[0])

In [None]:
encodings_df['encoding'] = encodings_df['encoding'].apply(lambda x: np.array(x))

In [None]:
encodings_df.iloc[0:1000].query('wikipedia_id > 0').groupby('wikipedia_id')['encoding'].first()

In [None]:
encodings_df.query('wikipedia_id == 17867')

In [None]:
first = encodings_df.groupby('wikipedia_id')['encoding'].first()

In [None]:
# Element-wise mean of the encoding vectors of each entity, computed
# explicitly on the stacked vectors.
# NOTE(review): groupby(...).mean() on an object column of arrays relies on
# pandas' fallback path, which appears to differ between versions — confirm.
mean = encodings_df.groupby('wikipedia_id')['encoding'].apply(
    lambda vectors: np.stack(vectors.values).mean(axis=0))

In [None]:
mean = mean.sample(frac=1)
mean.head()

In [None]:
first_index = DenseFlatIndexer(1024, 50000)

In [None]:
mean_batch_1 = mean.iloc[0:100]

In [None]:
np.stack(mean_batch_1.values).astype('float32')

In [None]:
mean_batch_1.values.astype('float32')

In [None]:
first_index.index_data(np.stack(mean_batch_1.values).astype('float32'))

In [None]:
mean_batch_1

In [None]:
encodings_df.query('wikipedia_id == 341466')

In [None]:
new_mention = encodings_df.loc[14882]['encoding'].astype('float32')
new_mention

In [None]:
np.stack([new_mention])

In [None]:
first_index.search_knn(np.stack([new_mention]), 2)

In [None]:
np.dot(mean_batch_1.values[2], new_mention)

In [None]:
mean_batch_1

In [None]:
mean_batch_2 = mean.iloc[100:200]

In [None]:
np.stack(mean_batch_2.values).shape

In [None]:
mean_batch_2

In [None]:
encodings_df.query('wikipedia_id == 5945')

In [None]:
# Select the row with a list (.loc[[452]]) so the result stays a one-row
# DataFrame; .loc[452] returns a Series, which has no .query().
new_mention2 = np.stack(encodings_df.loc[[452]].query('wikipedia_id == 5945')['encoding'].values)

In [None]:
first_index.index.ntotal

In [None]:
first_index.index_data(np.stack(mean_batch_2).astype('float32'))

In [None]:
first_index.search_knn(new_mention2.astype('float32'), 2)

In [None]:
mean_batch_2

In [None]:
# correct. TODO: set up the testing environment