**Setup**

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import os
# Process-wide environment for this kernel: the Hugging Face cache location
# and the tokenizers parallelism switch. Both are set in one call, in the same
# order as before.
os.environ.update({
    'HF_HOME': '/shared/data3/pk36/.cache',
    "TOKENIZERS_PARALLELISM": "false",
})

In [2]:
!export HF_HOME=/shared/data3/pk36/.cache

In [3]:
from model_definitions import initializeLLM, promptLLM, constructPrompt
import json
from utils import clean_json_string
from collections import deque
from taxonomy import Node

In [13]:
class Args:
    """Plain container for the settings used throughout this notebook.

    All values are fixed in ``__init__``; the path-like attributes are derived
    from the dataset name.
    """

    def __init__(self):
        # Taxonomy construction settings. `init_dims` is compared against a
        # node's level to decide whether that node gets expanded further.
        self.topic = "natural language processing"
        self.init_dims = 2
        self.llm = 'samba'  # handed to the LLM helpers via `args`

        # Dataset name plus the directory / file names derived from it.
        dataset_name = "Reasoning"
        dataset_slug = dataset_name.lower().replace(' ', '_')
        self.dataset = dataset_name
        self.data_dir = "datasets/multi_dim/" + dataset_slug + "/"
        self.internal = dataset_name + ".txt"
        self.external = dataset_name + "_external.txt"
        self.groundtruth = "groundtruth.txt"

        # NOTE(review): presumably a max sequence length and an embedding
        # width -- neither is consumed in the visible cells; confirm.
        self.length = 512
        self.dim = 768

        # NOTE(review): iteration count for a later stage -- not used in the
        # visible cells; confirm.
        self.iters = 4


args = Args()

In [5]:
# Initialise the LLM helper from model_definitions with the notebook settings.
# NOTE(review): presumably selects the backend from args.llm -- confirm there.
initializeLLM(args)



**Construct a 2-Level Multi-Dimensional Taxonomy**

In [6]:
from prompts import multi_dim_prompt

In [7]:
# Turn the topic into a snake_case label and use it for the root node (id 0).
mod_topic = args.topic.replace(' ', '_').lower()
root = Node(id=0, label=mod_topic)

# Two lookup tables over every node created so far:
#   id2node    -- node id -> Node
#   label2node -- label   -> Node; used to detect a label that already exists
#                 so the node can be shared by several parents (the taxonomy
#                 is meant to be a DAG rather than a tree).
id2node = {}
id2node[0] = root
label2node = {}
label2node[mod_topic] = root

In [8]:
# Breadth-first expansion of the taxonomy, starting from the root. Every node
# taken off the queue is sent to the LLM, and the children it proposes are
# attached to the node.
queue = deque([root])

while len(queue) > 0:
    curr_node = queue.popleft()
    label = curr_node.label

    # Build the prompt for this node and request a single JSON answer.
    system_instruction, main_prompt, json_output_format = multi_dim_prompt(curr_node)
    user_prompt = main_prompt + "\n\n" + json_output_format
    prompts = [constructPrompt(args, system_instruction, user_prompt)]
    outputs = promptLLM(
        args=args,
        prompts=prompts,
        max_new_tokens=2000,
        json_mode=True,
        temperature=0.1,
        top_p=0.99,
    )[0]

    # A reply containing a code fence goes through clean_json_string first
    # (presumably it strips the fence -- confirm in utils); otherwise the
    # stripped text is parsed directly.
    if "```" in outputs:
        outputs = json.loads(clean_json_string(outputs))
    else:
        outputs = json.loads(outputs.strip())

    # The parsed object is indexed by this node's label; the value maps each
    # proposed child name to a dict of its fields.
    outputs = outputs[label]

    for key, value in outputs.items():
        key = key.replace(' ', '_').lower()

        if key in label2node:
            # Label already known: link the existing node to this parent and
            # pass the newly returned fields to its add_* methods.
            child_node = label2node[key]
            child_node.add_parent(curr_node)
            child_node.add_dataset(value['datasets'])
            child_node.add_methodology(value['methodologies'])
            child_node.add_evaluation_method(value['evaluation_methods'])
            child_node.add_application(value['applications'])
            continue

        # New label: create the node, register it in both lookup tables, and
        # queue it for expansion while its level is below args.init_dims.
        child_node = Node(
            id=len(id2node),
            label=key,
            description=value['description'],
            datasets=value['datasets'],
            methodologies=value['methodologies'],
            evaluation_methods=value['evaluation_methods'],
            applications=value['applications'],
            parents=[curr_node],
        )
        curr_node.add_child(key, child_node)
        id2node[child_node.id] = child_node
        label2node[key] = child_node
        if child_node.level < args.init_dims:
            queue.append(child_node)

In [9]:
# Print the taxonomy from the root down in the compact (`simple=True`) form,
# with nested children indented by `indent_multiplier` spaces per level.
# NOTE(review): the leading 0 is presumably the starting indent level -- confirm.
root.display(0, indent_multiplier=5, simple=True)

Label: natural_language_processing
Description: None
Level: 0
----------------------------------------
Children:
     Label: text_classification
     Description: The task of assigning predefined categories to text based on its content.
     Level: 1
     ----------------------------------------
     Children:
          Label: topic_modeling
          Description: Identifying underlying topics or themes in a large corpus of text.
          Level: 2
          ----------------------------------------
          Label: text_classification_for_low_resource_languages
          Description: Classifying text in low-resource languages with limited labeled data.
          Level: 2
          ----------------------------------------
          Label: multimodal_text_classification
          Description: Classifying text with multimodal inputs (e.g. images, audio, video).
          Level: 2
          ----------------------------------------
          Label: adversarial_text_classification
          

**Read in dataset**

In [10]:
from datasets import load_dataset
from tqdm import tqdm

In [11]:
# Create the dataset directory (including missing parents). `exist_ok=True`
# keeps the cell safe to re-run and avoids the gap between checking for the
# directory and creating it.
os.makedirs(args.data_dir, exist_ok=True)

In [12]:
# Load the "TimSchopf/nlp_taxonomy_data" dataset with datasets.load_dataset;
# the 'train' split is what gets written out to the corpus files.
ds = load_dataset("TimSchopf/nlp_taxonomy_data")

In [None]:
# Split the training papers into two JSON-lines files inside args.data_dir:
# papers whose classification labels contain args.dataset go to internal.txt,
# all other papers go to external.txt.
# NOTE(review): Args also defines args.internal / args.external with different
# file names than the literals used here -- confirm which names later stages read.
external_path = os.path.join(args.data_dir, 'external.txt')
internal_path = os.path.join(args.data_dir, 'internal.txt')

with open(external_path, 'w') as external_file, open(internal_path, 'w') as internal_file:
    external_count = 0
    internal_count = 0
    for paper in tqdm(ds['train']):
        # One JSON object per line, holding only the title and abstract.
        record = json.dumps({"Title": paper['title'], "Abstract": paper['abstract']})
        if args.dataset in paper['classification_labels']:
            internal_file.write(f'{record}\n')
            internal_count += 1
        else:
            external_file.write(f'{record}\n')
            external_count += 1
print(f'Internal: {internal_count}, External Count: {external_count}')

100%|██████████| 178521/178521 [00:13<00:00, 13026.63it/s]


8672 169849


**Enrich each node with a set of terms**

In [None]:
# Show the id -> Node lookup table to check which nodes were created.
id2node

{0: Node(label=natural_language_processing, description=None, level=0),
 1: Node(label=text_classification, description=The task of assigning predefined categories to text based on its content., level=1),
 2: Node(label=language_modeling, description=The task of predicting the next word in a sequence of text given the context., level=1),
 3: Node(label=named_entity_recognition, description=The task of identifying and categorizing named entities in text into predefined categories., level=1),
 4: Node(label=machine_translation, description=The task of translating text from one language to another., level=1),
 5: Node(label=question_answering, description=The task of answering questions based on the content of a given text., level=1),
 6: Node(label=text_summarization, description=The task of summarizing a long piece of text into a shorter summary., level=1),
 7: Node(label=dialogue_systems, description=The task of generating responses to user input in a conversational setting., level=1),

In [75]:
# Show the methodologies recorded on node 3 and node 11 side by side.
id2node[3].methodologies, id2node[11].methodologies

(['Supervised learning with conditional random fields',
  'Deep learning with recurrent neural networks'],
 ['Supervised learning with conditional random fields',
  'Unsupervised learning with clustering algorithms'])