In [None]:
!pip install fairseq

In [20]:
import gzip
import json
import os
import time
from statistics import median
from urllib.request import urlopen

import pandas as pd
import torch
from fairseq.models.roberta import RobertaModel

In [21]:
# Root directory holding the model checkpoint and the Amazon taxonomy/reviews.
# Defaults to the original (Drive-style) relative path; set TAXO_BASE_DIR to point elsewhere.
BASE_DIR = os.environ.get('TAXO_BASE_DIR', 'drive/MyDrive/taxo-replica')

In [30]:
# Load the fairseq RoBERTa-large checkpoint (used below through its 'mnli' head).
MODEL_DIR = f'{BASE_DIR}/roberta.large.mnli'  # was `dir`, which shadowed the builtin
if not os.path.exists(f'{MODEL_DIR}/model.pt'):
  # Explicit error rather than `assert`, which is stripped under `python -O`.
  raise FileNotFoundError(f'missing checkpoint: {MODEL_DIR}/model.pt')
roberta = RobertaModel.from_pretrained(MODEL_DIR, checkpoint_file='model.pt')
roberta.eval()  # disable dropout (or leave in train mode to finetune)
if torch.cuda.is_available():
  # Guarded so CPU-only runtimes still work; being inside `if`, the cell no
  # longer echoes the full model repr as its output.
  roberta.cuda()

def similarity(text, class_name):
  """Return the MNLI entailment probability that `text` is about `class_name`.

  The premise is `text`; the hypothesis is
  "this document is about <class_name lowercased>". Premises too long for the
  model are truncated from the end, so the hypothesis is always kept intact.

  Args:
    text: document text (premise).
    class_name: taxonomy class name (inserted into the hypothesis).
  Returns:
    float, probability at index 2 of the softmaxed 'mnli' head logits.
  """
  hypothesis = f'this document is about {class_name.lower()}'
  tokens = roberta.encode(text, hypothesis)
  max_len = roberta.model.max_positions()
  if tokens.size(0) > max_len:
    # fairseq's encode(a, b) yields `<s> a </s> </s> b </s>` and encode(b)
    # yields `<s> b </s>`, so the pair's tail `</s> </s> b </s>` is one token
    # longer than encode(b). Keep that tail and trim the premise instead of
    # letting predict() raise on an over-long sequence.
    tail_len = roberta.encode(hypothesis).size(0) + 1
    tokens = torch.cat([tokens[:max_len - tail_len], tokens[-tail_len:]])
  with torch.no_grad():  # inference only: don't build an autograd graph
    logits = roberta.predict('mnli', tokens, return_logits=True)
  probabilities = logits.softmax(dim=-1).tolist()[0]
  # Assumes fairseq's roberta.large.mnli label order
  # (contradiction, neutral, entailment) — index 2 is entailment.
  entailment_probability = probabilities[2]
  return entailment_probability
    

In [33]:
# Memoized NLI scores shared by every Node: caches[text][class_name] -> score.
caches = dict()

class Node:
  """One class in the label taxonomy, wrapping a level of the nested taxonomy dict.

  Attributes:
    name: class name (the key this node has in its parent's dict).
    dic: dict of this node's children (child name -> that child's own dict).
    parent: parent Node, or None for the root.
    depth: 0 for the root, parent's depth + 1 for children.
  """

  def __init__(self, name, dic, parent, depth):
    self.name = name
    self.dic = dic
    self.parent = parent
    self.depth = depth
    # Not read by this class; scores are memoized in the module-level `caches`.
    self.cache = dict()

  def children(self):
    """Return fresh Node objects for every direct child of this node."""
    return [Node(k, v, self, self.depth + 1) for k, v in self.dic.items()]

  def selected_children(self, doc):
    """Return the (depth + 2) children with the highest similarity to `doc`, best first."""
    return sorted(self.children(), key=lambda c: c.similarity(doc), reverse=True)[:(self.depth + 2)]

  def similarity(self, text):
    """Return the NLI score between `text` and this class name, memoized in `caches`."""
    cache = caches.setdefault(text, dict())
    if self.name not in cache:
      # Module-level similarity() function, not this method.
      cache[self.name] = similarity(text, self.name)
    return cache[self.name]

  def path_score(self, doc):
    """Product of similarity scores along the path from the root (the root contributes 1)."""
    if self.parent is None:
      return 1
    return self.parent.path_score(doc) * self.similarity(doc)

  def confidence(self, text):
    """Margin by which this class out-scores its parent and all of its siblings on `text`.

    Positive only when this node's similarity beats the parent and every sibling.
    Must not be called on the root (it has no parent).
    """
    # parent.children() builds a fresh copy of this node too; exclude it, otherwise
    # the max always includes our own score and the margin can never exceed 0.
    siblings = [c for c in self.parent.children() if c.name != self.name]
    # NOTE(review): for depth-1 nodes the parent is the synthetic 'root' node, whose
    # score comes from the hypothesis "this document is about root" — confirm intent.
    competitors = [self.parent] + siblings
    return self.similarity(text) - max(n.similarity(text) for n in competitors)

  def confidence_threshold(self, all_documents):
    """Median confidence of this class over the documents tagged with its name.

    Raises statistics.StatisticsError if no document in `all_documents` is tagged.
    """
    return median([self.confidence(doc.text) for doc in all_documents if doc.tagged_with(self.name)])


def flatten(list_of_lists):
  """Concatenate the inner iterables of `list_of_lists` into one new list, in order."""
  flat = []
  for sublist in list_of_lists:
    flat.extend(sublist)
  return flat

def aggregate_children(children_list, doc):
  """Merge per-parent child selections and keep the best ones by path score.

  The depth d is read from the first merged child and applied to all of them;
  the top (d + 1) ** 2 nodes by `path_score(doc)` are returned, highest first.
  An empty merge is returned unchanged (an empty list).
  """
  pool = flatten(children_list)
  if len(pool) == 0:
    return pool
  keep = (pool[0].depth + 1) ** 2
  ranked = sorted(pool, key=lambda node: node.path_score(doc), reverse=True)
  return ranked[:keep]

def deeper_nodes(nodes, doc):
  """Expand every node to its selected children for `doc`, then prune the merged set."""
  per_parent = []
  for node in nodes:
    per_parent.append(node.selected_children(doc))
  return aggregate_children(per_parent, doc)

def get_candidates(doc, tree):
  """Collect candidate classes for `doc` by expanding `tree` one level at a time.

  Starts from the root's selected children, then repeatedly replaces the
  current level with `deeper_nodes(...)` until no deeper nodes remain.

  Returns:
    list of Node, all levels concatenated from shallowest to deepest.
  """
  frontier = Node('root', tree, None, 0).selected_children(doc)
  candidates = []
  while frontier:
    candidates.extend(frontier)
    frontier = deeper_nodes(frontier, doc)
  return candidates

class Doc:
  """A document text plus the taxonomy nodes selected as its candidate classes."""

  def __init__(self, text, tree):
    self.text = text
    # Candidate Node objects produced by get_candidates for this text.
    self.candidates = get_candidates(text, tree)
    self.class_names = set(node.name for node in self.candidates)

  def tagged_with(self, name):
    """Return True when `name` is among this document's candidate class names."""
    return name in self.class_names

  def core_classes(self, all_documents):
    """Names of candidates whose confidence on this text reaches the class's threshold.

    The threshold for each candidate comes from its confidence_threshold(all_documents).
    """
    core = []
    for node in self.candidates:
      score = node.confidence(self.text)
      if score >= node.confidence_threshold(all_documents):
        core.append(node.name)
    return core


In [None]:
def get_documents(corpus, tree):
  """Build a Doc for every text in `corpus`, printing progress after each one.

  Args:
    corpus: sized sequence of document texts (len() is used in the progress line).
    tree: nested taxonomy dict passed through to Doc.
  Returns:
    list of Doc, in corpus order.
  """
  all_documents = []
  total = len(corpus)

  for text in corpus:
    # perf_counter is monotonic and meant for measuring durations (time.time can jump).
    start = time.perf_counter()
    all_documents.append(Doc(text, tree))
    end = time.perf_counter()
    print(f'{len(all_documents)} out of {total} complete taking {end - start} seconds')

  return all_documents

def get_corpus():
  """Load the review texts from `{BASE_DIR}/amazon/test.json`.

  Assumes the file holds a JSON list of objects that each have a
  'reviewText' key — TODO confirm the schema against the dataset.

  Returns:
    list of review text strings, in file order.
  """
  # Explicit UTF-8: JSON text is UTF-8, while open()'s default is locale-dependent.
  with open(f'{BASE_DIR}/amazon/test.json', encoding='utf-8') as reviews_file:
    reviews = json.load(reviews_file)
  return [review['reviewText'] for review in reviews]

# Candidate sets for a small sample of reviews, then the core classes of the first one.
N_DOCS = 10  # number of reviews to process; each one needs many NLI calls

# Read the taxonomy and close the file before the long-running computation.
with open(f'{BASE_DIR}/amazon/taxonomy.json', encoding='utf-8') as f:
  tree = json.load(f)  # nested dict: class name -> dict of its children

corpus = get_corpus()[:N_DOCS]
all_documents = get_documents(corpus, tree)

doc = all_documents[0]
print(doc.core_classes(all_documents))


