# In [None]:
from typing import List, Tuple, Any
import time

import numpy as np
import tensorflow_hub as hub
import tensorflow as tf

from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from deeppavlov.core.commands.utils import expand_path


@register('use_sentence_ranker')
class USESentenceRanker(Component):
    """Score context sentences against a query with the Universal Sentence Encoder (USE).

    The query and its context sentences are embedded with a TF-Hub USE module and
    scored by dot product. See ``__call__`` for the exact outputs.
    """

    def __init__(self, top_n=20, return_vectors=False, active: bool = True,
                 hub_path: str = "general_electrics/hub", **kwargs):
        """
        :param top_n: number of top scores to keep per query when ``active``
        :param return_vectors: return the raw (unranked) USE vectors instead of sentences
        :param active: when not active, keep the scores of all sentences instead of only ``top_n``
        :param hub_path: location of the TF-Hub USE module, resolved with ``expand_path``
        """
        self.embed = hub.Module(str(expand_path(hub_path)))
        self.q_ph = tf.placeholder(shape=(None,), dtype=tf.string)  # question placeholder
        self.c_ph = tf.placeholder(shape=(None,), dtype=tf.string)  # context placeholder
        self.q_emb = self.embed(self.q_ph)  # question embedding
        self.c_emb = self.embed(self.c_ph)  # context embedding
        self.session = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=0.4,
                allow_growth=False,
            )))
        # Run the initializers only after the module has been applied to the
        # placeholders (as in TF-Hub's documented usage), so that initializers
        # created when the module is applied are included.
        self.session.run([tf.global_variables_initializer(), tf.tables_initializer()])
        self.top_n = top_n
        self.return_vectors = return_vectors
        self.active = active

    def __call__(self, query_context: List[Tuple[str, List[str]]]):
        """
        Embed each query and its context sentences with USE and score them.

        Args:
            query_context: [(query1_str, [context1_str, context2_str, ...]),
                            (query2_str, [context...]), ...]

        Returns:
            If ``return_vectors``: a pair ``(vectors, fake_scores)``.
            ``vectors`` holds one ``(query_embedding, context_embeddings)`` tuple
            per query. ``fake_scores`` is ``[0.001] * len(query_context)``.

            Otherwise: a pair ``(texts, scores)``.
            ``texts[i]`` is the i-th query's context sentences joined with single
            spaces, in their ORIGINAL order. ``scores[i]`` is a 1-D array of that
            query's similarity scores sorted in descending order, truncated to
            ``top_n`` when ``active``.

            NOTE(review): sentence re-ordering is intentionally disabled, so
            ``scores[i]`` is not aligned with the sentence order in ``texts[i]``.
        """
        predictions = []
        all_top_scores = []
        fake_scores = [0.001] * len(query_context)

        for el in query_context:
            query, contexts = el[0], el[1]
            # qe: query embedding (one row, for the single query string);
            # ce: context embeddings matrix (one row per context string).
            qe, ce = self.session.run([self.q_emb, self.c_emb],
                                      feed_dict={
                                          self.q_ph: [query],   # 1 query string
                                          self.c_ph: contexts,  # list of context strings
                                      })
            if self.return_vectors:
                predictions.append((qe, ce))
                continue

            # The dot product serves as cosine similarity: USE embeddings are
            # approximately normalized, so they are not re-normalized here.
            # Taking row 0 (the single query) instead of squeezing keeps a 1-D
            # array even when there is exactly one context sentence.
            scores = (qe @ ce.T)[0]
            thresh = self.top_n if self.active else len(contexts)
            all_top_scores.append(np.sort(scores)[::-1][:thresh])
            # Sentences are deliberately NOT re-ordered by score: the contexts
            # are returned in their original order.
            predictions.append(contexts)

        if self.return_vectors:
            return predictions, fake_scores
        return [' '.join(sentences) for sentences in predictions], all_top_scores
