In [None]:
!pip install anytree

In [None]:
import pandas as pd
from anytree import Node, RenderTree, AsciiStyle, PreOrderIter
from copy import deepcopy

In [None]:
# Depth at which the class hierarchy is cut. prune_tree counts each root as
# depth 1, so with X = 2 only the roots and their direct children survive.
X = 2

# Load the (SubClass, Class) pairs that the next cell turns into a forest.
df = pd.read_csv(filepath_or_buffer='FinalDataset/DBpediaClasses.csv')

In [None]:
# One anytree Node per distinct class name, keyed by that name.
nodes = {}

# Build the forest from (parent, child) pairs. Note the column roles: the
# 'SubClass' column is used as the parent and 'Class' as the child.
# Zipping the two columns avoids DataFrame.iterrows(), which builds a Series
# for every row (slow) and does not preserve column dtypes.
for parent_name, child_name in zip(df['SubClass'], df['Class']):
    # Create the parent node on first sight; a later row may give it a parent.
    if parent_name not in nodes:
        nodes[parent_name] = Node(parent_name)
    parent_node = nodes[parent_name]

    # Create the child under its parent, or re-attach an existing node.
    # An anytree node has a single parent, so a class listed under several
    # parents ends up attached to the last one seen.
    if child_name not in nodes:
        nodes[child_name] = Node(child_name, parent=parent_node)
    else:
        nodes[child_name].parent = parent_node

# Nodes that never received a parent are the tops of the hierarchy.
root_nodes = [node for node in nodes.values() if node.is_root]

In [None]:
# Print every tree of the unpruned hierarchy as ASCII art, one after another.
for tree_root in root_nodes:
    rendering = RenderTree(tree_root, style=AsciiStyle())
    for prefix, _fill, class_node in rendering:
        print(f"{prefix}{class_node.name}")

In [None]:
# Deep-copy each (disjoint) tree so pruning leaves `root_nodes` intact.
root_nodes_copy = [deepcopy(tree_root) for tree_root in root_nodes]

In [None]:
def prune_tree(node, max_depth, current_depth=1):
    """Cut the subtree under ``node`` so nothing deeper than ``max_depth`` remains.

    The tree is modified in place. Every node at depth ``max_depth`` loses all
    of its children, with ``node`` itself counted as depth ``current_depth``.

    Args:
        node: anytree node at the top of the subtree to prune.
        max_depth: depth whose nodes become leaves. A value <= ``current_depth``
            strips all children from ``node`` itself.
        current_depth: depth of ``node``; callers normally keep the default 1.
    """
    if current_depth >= max_depth:
        # Assigning an empty sequence detaches every child from this node.
        node.children = []
        return
    for child in node.children:
        prune_tree(child, max_depth, current_depth + 1)


# Cut every tree in the copied forest at depth X.
for root in root_nodes_copy:
    prune_tree(root, X)

# Show every pruned tree (the forest may hold several roots, or none), using
# the same ASCII style as the full-hierarchy view.
for root in root_nodes_copy:
    for pre, fill, node in RenderTree(root, style=AsciiStyle()):
        print(f"{pre}{node.name}")

In [None]:
# Load the dataset whose rows are filtered against the pruned hierarchy.
df_new = pd.read_csv(filepath_or_buffer='FinalDataset/final_dataset.csv')

In [None]:
def find_leaves(node):
    """Return the leaf nodes under ``node`` in left-to-right (pre-order) order.

    A node without children is returned as its own single leaf.
    """
    leaves = []
    pending = [node]
    while pending:
        current = pending.pop()
        if current.children:
            # Push children in reverse so the leftmost one is visited first.
            pending.extend(reversed(current.children))
        else:
            leaves.append(current)
    return leaves


# Names of every leaf across all pruned trees, tree by tree.
leaves_names = [
    leaf.name
    for tree_root in root_nodes_copy
    for leaf in find_leaves(tree_root)
]

print(leaves_names)

In [None]:
# Keep the rows whose 'Subclass' is a leaf of the pruned hierarchy and
# expose their 'Class' column under the name 'NewClass'.
filtered_df = (
    df_new.loc[df_new['Subclass'].isin(leaves_names)]
    .rename(columns={'Class': 'NewClass'})
)

# Count the distinct labels left; the count is also baked into the file name.
num_labels = filtered_df['NewClass'].nunique()
print(num_labels)

filtered_df.to_csv(
    f'FinalDataset/polished_dataset_realsub_{num_labels}.csv',
    index=False,
)