**Setup**

In [1]:
# Reload imported modules before each cell executes, so edits to the local
# .py modules imported below take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2
import os

# Environment for the Hugging Face stack, set before `model_definitions` is
# imported in a later cell (presumably it loads Hugging Face libraries, which
# read HF_HOME when they are imported):
#   HF_HOME                -> cache location on shared storage
#   TOKENIZERS_PARALLELISM -> disable the tokenizers library's internal parallelism
os.environ.update({
    'HF_HOME': '/shared/data3/pk36/.cache',
    "TOKENIZERS_PARALLELISM": "false",
})

In [2]:
# `!export ...` runs in a throwaway child shell, so it never changed this
# kernel's environment. Set the variable on the kernel process instead;
# commands launched later (including `!cmd` lines) inherit os.environ.
os.environ["HF_HOME"] = "/shared/data3/pk36/.cache"

In [28]:
from model_definitions import initializeLLM, promptLLM, constructPrompt
import json
from utils import clean_json_string
from collections import deque
from taxonomy import Node

In [None]:
class Args:
    """Run configuration passed to the LLM helpers and used by the taxonomy loop.

    Every setting can be overridden by keyword; calling ``Args()`` with no
    arguments reproduces the original hard-coded values. Dataset file names and
    the data directory are derived from ``dataset``, so they stay consistent
    when it is overridden.
    """

    def __init__(
        self,
        topic="natural language processing",
        init_dims=2,
        llm='samba',
        dataset="ner_event_kgc",
        groundtruth="groundtruth.txt",
        length=512,
        dim=768,
        iters=4,
    ):
        # Root topic; the taxonomy cell turns it into a snake_case root label.
        self.topic = topic
        # Depth limit: a newly created node is queued for expansion only while
        # its level is < init_dims.
        self.init_dims = init_dims
        # LLM backend name, presumably consumed by initializeLLM/promptLLM — TODO confirm.
        self.llm = llm

        self.dataset = dataset
        self.data_dir = f"datasets/gen_kgc/{self.dataset}/"
        self.internal = f"{self.dataset}.txt"
        self.external = f"{self.dataset}_external.txt"
        self.groundtruth = groundtruth

        # NOTE(review): length/dim/iters are not read by any cell shown here;
        # presumably consumed inside model_definitions — confirm before changing.
        self.length = length
        self.dim = dim
        self.iters = iters


args = Args()

In [6]:
# Initialize the LLM backend from the run config (presumably keyed on
# `args.llm`); run this before the cells that call promptLLM.
initializeLLM(args)



**Construct a 2-Level Multi-Dimensional Taxonomy**

In [7]:
from prompts import multi_dim_prompt

In [85]:
# Root label: the topic lowercased with spaces replaced by underscores, the
# same normalization applied to every child label during expansion.
mod_topic = args.topic.replace(' ', '_').lower()
root = Node(id=0, label=mod_topic)

# Lookup tables filled in as the taxonomy grows:
#   id2node    -- node id -> Node
#   label2node -- normalized label -> Node. A label proposed again under another
#                 parent reuses the existing node, so the result is a DAG, not a tree.
id2node = {0: root}
label2node = {mod_topic: root}

In [86]:
def _request_children(node):
    """Prompt the LLM to expand `node` and return its proposed children.

    Returns the mapping stored under ``node.label`` in the model's JSON reply:
    child label -> dict with 'description', 'datasets', 'methodologies',
    'evaluation_methods' and 'applications' entries (any may be missing).
    """
    system_instruction, main_prompt, json_output_format = multi_dim_prompt(node)
    prompts = [constructPrompt(args, system_instruction, main_prompt + "\n\n" + json_output_format)]
    raw = promptLLM(args=args, prompts=prompts, max_new_tokens=2000, json_mode=True,
                    temperature=0.1, top_p=0.99)[0]
    # Only route through clean_json_string when the reply contains a ``` fence.
    parsed = json.loads(clean_json_string(raw)) if "```" in raw else json.loads(raw.strip())
    # NOTE(review): assumes the reply is keyed by the exact (normalized) label sent.
    return parsed[node.label]


# Breadth-first expansion: each dequeued node is expanded once by the LLM.
queue = deque([root])

while queue:
    curr_node = queue.popleft()
    outputs = _request_children(curr_node)

    seen = set()
    for key, value in outputs.items():
        key = key.replace(' ', '_').lower()
        # Different raw keys can normalize to the same label; link each child once.
        if key in seen:
            continue
        seen.add(key)

        if key not in label2node:
            child_node = Node(
                id=len(id2node),
                label=key,
                # Root is built without a description (shown as None), so None is accepted.
                description=value.get('description'),
                datasets=value.get('datasets', []),
                methodologies=value.get('methodologies', []),
                evaluation_methods=value.get('evaluation_methods', []),
                applications=value.get('applications', []),
                parents=[curr_node],
            )
            curr_node.add_child(key, child_node)
            id2node[child_node.id] = child_node
            label2node[key] = child_node
            if child_node.level < args.init_dims:
                queue.append(child_node)
        else:
            child_node = label2node[key]
            # Only link downward. A node at the same or a shallower level may be
            # curr_node itself or one of its ancestors, and linking it would make the
            # DAG cyclic (assumes a node's level is parent level + 1, fixed at creation).
            if child_node.level <= curr_node.level:
                continue
            child_node.add_parent(curr_node)
            # Record the edge from the parent's side too, as the new-node branch does.
            curr_node.add_child(key, child_node)
            child_node.add_dataset(value.get('datasets', []))
            child_node.add_methodology(value.get('methodologies', []))
            child_node.add_evaluation_method(value.get('evaluation_methods', []))
            child_node.add_application(value.get('applications', []))

In [None]:
# Print the taxonomy starting at `root` and descending into its children (see
# output). The leading 0 is presumably the starting indent level — confirm in
# taxonomy.Node.display.
root.display(0, indent_multiplier=5, simple=True)

Label: natural_language_processing
Description: None
Level: 0
----------------------------------------
Children:
     Label: text_classification
     Description: The task of assigning predefined categories to text based on its content.
     Level: 1
     Datasets: ['A dataset of labeled text from news articles for training text classification models.', 'A collection of news articles categorized by topic', 'A dataset of labeled product reviews for sentiment analysis', 'A dataset of text from product reviews for training sentiment analysis models.']
     Methodologies: ['Convolutional Neural Networks (CNNs)', 'Supervised learning with convolutional neural networks', 'Support Vector Machines (SVMs)', 'Random Forests', 'Transfer learning with pre-trained language models']
     Evaluation Methods: ['A metric to evaluate the robustness of text classification models to adversarial attacks.', 'Precision and recall for topic modeling', 'Accuracy metric to evaluate the performance of text class

In [72]:
# Inspect every node created so far, keyed by id (ids are assigned in creation order).
id2node

{0: Node(label=natural_language_processing, description=None, level=0),
 1: Node(label=text_classification, description=The task of assigning predefined categories to text based on its content., level=1),
 2: Node(label=language_modeling, description=The task of predicting the next word in a sequence of text given the context., level=1),
 3: Node(label=named_entity_recognition, description=The task of identifying and categorizing named entities in text into predefined categories., level=1),
 4: Node(label=machine_translation, description=The task of translating text from one language to another., level=1),
 5: Node(label=question_answering, description=The task of answering questions based on the content of a given text., level=1),
 6: Node(label=text_summarization, description=The task of summarizing a long piece of text into a shorter summary., level=1),
 7: Node(label=dialogue_systems, description=The task of generating responses to user input in a conversational setting., level=1),

In [75]:
# Compare the methodologies collected for two nodes. Ids come from len(id2node)
# at creation time, so they shift whenever the LLM's output order changes; in
# the run shown above, id 3 is named_entity_recognition.
id2node[3].methodologies, id2node[11].methodologies

(['Supervised learning with conditional random fields',
  'Deep learning with recurrent neural networks'],
 ['Supervised learning with conditional random fields',
  'Unsupervised learning with clustering algorithms'])