In [30]:
import json
import os
import pickle
import random
import re
import sys
import time
import xml.etree.ElementTree as ET
from collections import Counter, defaultdict, deque
from xml.dom import minidom

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm import tqdm

sys.path.append('../../')
from utils import *

In [31]:
# Paths to the three English BGC splits, all following the same naming scheme.
train_data_path, dev_data_path, test_data_path = (
    f'./BlurbGenreCollection_EN_{split_name}.txt'
    for split_name in ('train', 'dev', 'test')
)

In [3]:
# Load the raw training split as one string; it is escaped and parsed as XML
# in the next cell. The file is read as UTF-8 explicitly because it contains
# non-ASCII characters (e.g. '’'), and the platform default encoding may differ.
with open(train_data_path, 'r', encoding='utf-8') as f:
    data = f.read()

In [4]:
# ElementTree rejects bare '&'. Escape only those, leaving well-formed entity
# references ('&amp;', '&lt;', '&#39;', ...) untouched so they are not double-escaped.
train_data = ET.fromstring(
    re.sub(r'&(?!(?:amp|lt|gt|quot|apos|#\d+|#x[0-9a-fA-F]+);)', '&amp;', data)
)

total_labels = 0
train_data_dict = []
for book in tqdm(train_data.findall('book')):
    # Empty elements have `.text` None; treat them as empty strings.
    title = book.find('title').text or ''
    text = book.find('body').text or ''
    topics = book.find('metadata').find('topics')

    # Collect <d0>, <d1>, ... in index order until the first missing level.
    labels = []
    depth = 0
    while True:
        level_labels = topics.findall(f'd{depth}')
        if not level_labels:
            break
        labels.extend(node.text for node in level_labels)
        depth += 1

    total_labels += len(labels)
    train_data_dict.append({'token': 'Title: ' + title + '. ' + 'Text: ' + text, 'label': labels})

# Store train_data_dict as JSON lines. UTF-8 is required because
# ensure_ascii=False writes non-ASCII characters verbatim.
with open('./train_data.jsonl', 'w', encoding='utf-8') as f:
    for record in train_data_dict:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')
total_labels

100%|██████████| 58715/58715 [00:00<00:00, 209490.24it/s]


176558

In [5]:
# Parse the dev split exactly like the training split.
with open(dev_data_path, 'r', encoding='utf-8') as f:
    data = f.read()

# ElementTree rejects bare '&'. Escape only those, leaving well-formed entity
# references ('&amp;', '&lt;', '&#39;', ...) untouched so they are not double-escaped.
dev_data = ET.fromstring(
    re.sub(r'&(?!(?:amp|lt|gt|quot|apos|#\d+|#x[0-9a-fA-F]+);)', '&amp;', data)
)

dev_data_dict = []
for book in dev_data.findall('book'):
    # Empty elements have `.text` None; treat them as empty strings.
    title = book.find('title').text or ''
    text = book.find('body').text or ''
    topics = book.find('metadata').find('topics')

    # Collect <d0>, <d1>, ... in index order until the first missing level.
    labels = []
    depth = 0
    while True:
        level_labels = topics.findall(f'd{depth}')
        if not level_labels:
            break
        labels.extend(node.text for node in level_labels)
        depth += 1

    dev_data_dict.append({'token': 'Title: ' + title + '. ' + 'Text: ' + text, 'label': labels})

# Store dev_data_dict as JSON lines (UTF-8, since ensure_ascii=False keeps non-ASCII text).
with open('./dev_data.jsonl', 'w', encoding='utf-8') as f:
    for record in dev_data_dict:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')

In [6]:
# Parse the test split exactly like the training split.
with open(test_data_path, 'r', encoding='utf-8') as f:
    data = f.read()

# ElementTree rejects bare '&'. Escape only those, leaving well-formed entity
# references ('&amp;', '&lt;', '&#39;', ...) untouched so they are not double-escaped.
test_data = ET.fromstring(
    re.sub(r'&(?!(?:amp|lt|gt|quot|apos|#\d+|#x[0-9a-fA-F]+);)', '&amp;', data)
)

test_data_dict = []
for book in test_data.findall('book'):
    # Empty elements have `.text` None; treat them as empty strings.
    title = book.find('title').text or ''
    text = book.find('body').text or ''
    topics = book.find('metadata').find('topics')

    # Collect <d0>, <d1>, ... in index order until the first missing level.
    labels = []
    depth = 0
    while True:
        level_labels = topics.findall(f'd{depth}')
        if not level_labels:
            break
        labels.extend(node.text for node in level_labels)
        depth += 1

    test_data_dict.append({'token': 'Title: ' + title + '. ' + 'Text: ' + text, 'label': labels})

# Store test_data_dict as JSON lines (UTF-8, since ensure_ascii=False keeps non-ASCII text).
with open('./test_data.jsonl', 'w', encoding='utf-8') as f:
    for record in test_data_dict:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')

In [20]:
# Every label that occurs at least once in the training split.
train_label_set = {label for record in train_data_dict for label in record['label']}

In [21]:
# Load the label hierarchy from "parent<TAB>child" lines. Keep only edges whose
# parent AND child both occur in the training labels; lines that do not have
# exactly two fields are ignored.
hiera = defaultdict(list)
with open('./hierarchy.txt', 'r', encoding='utf-8') as f:
    for line in f:
        fields = line.replace('\n', '').split('\t')
        if len(fields) != 2:
            continue
        parent, child = fields
        if parent in train_label_set and child in train_label_set:
            hiera[parent].append(child)

# Sorted rather than list(set): iteration order of a set of str depends on the
# per-process hash seed, which would make the order of top-level labels below
# (and hence bgc.taxonomy, and presumably the label ids derived from it)
# change between kernel restarts.
labels = sorted(train_label_set)

# child -> parent. If a child is listed under several parents, the last one wins.
r_hiera = {child: parent for parent, children in hiera.items() for child in children}

# Labels without a (kept) parent hang directly under the synthetic 'Root'.
for label in labels:
    if label not in r_hiera:
        r_hiera[label] = 'Root'


In [19]:
# Which labels listed in hierarchy.txt never occur in the training split?
# Edges touching them are skipped when `hiera` is built from hierarchy.txt.
# hierarchy.txt is re-read here instead of using `labels`, because the
# hierarchy cell rebinds `labels` to the training labels. With `labels`, this
# cell would always show an empty set on a fresh Run All.
train_label_set = {label for record in train_data_dict for label in record['label']}

hierarchy_label_set = set()
with open('./hierarchy.txt', 'r', encoding='utf-8') as f:
    for line in f:
        fields = line.replace('\n', '').split('\t')
        hierarchy_label_set.add(fields[0])
        if len(fields) == 2:
            hierarchy_label_set.add(fields[1])
hierarchy_label_set.discard('')  # a blank line would otherwise contribute ''

hierarchy_label_set - train_label_set

{'Childrens Media Tie-In Books',
 'Children’s Activity & Novelty Books',
 'Children’s Board Books',
 'Children’s Boxed Sets',
 'Children’s Chapter Books',
 'Children’s Picture Books'}

In [25]:
def compute_path_no(labels, parent_of=None):
    """Count the leaf labels (root-to-leaf paths) in one example's label set.

    A label counts as a leaf when no other label in the same set has it as its
    parent. For a label set that contains all ancestors of its labels, this is
    the number of distinct root-to-leaf paths the example covers. The result
    does not depend on the order of `labels`.

    Args:
        labels: iterable of label names; duplicates and 'Root' are ignored.
        parent_of: mapping child label -> parent label. Defaults to the global
            `r_hiera` built from hierarchy.txt.

    Returns:
        The number of leaf labels, as a Python int (0 for an empty label set).

    Raises:
        KeyError: if a label is missing from `parent_of`.
    """
    if parent_of is None:
        parent_of = r_hiera
    label_set = set(labels)
    label_set.discard('Root')
    parents_in_set = {parent_of[label] for label in label_set}
    return sum(1 for label in label_set if label not in parents_in_set)

# Attach the number of label paths to every example of every split ...
for split_records in (train_data_dict, dev_data_dict, test_data_dict):
    for record in split_records:
        record['path_no'] = compute_path_no(record['label'])

from collections import Counter

# ... then print how often each path count occurs (train, dev, test).
for split_records in (train_data_dict, dev_data_dict, test_data_dict):
    print(Counter(record['path_no'] for record in split_records))

Counter({1: 34000, 2: 18445, 3: 5656, 4: 606, 5: 8})
Counter({1: 8507, 2: 4659, 3: 1456, 4: 161, 5: 2})
Counter({1: 10636, 2: 5800, 3: 1763, 4: 193, 5: 2})


In [26]:
# Write bgc.taxonomy breadth-first. The first line is 'Root' followed by the
# top-level labels. Then there is one "parent<TAB>child<TAB>child..." line for
# every label that has children in `hiera`, in BFS order.
root = [k for k, v in r_hiera.items() if v == 'Root']
with open('./bgc.taxonomy', 'w', encoding='utf-8') as f:
    f.write('Root\t' + '\t'.join(root) + '\n')
    queue = deque(root)  # O(1) popleft instead of list.pop(0)
    while queue:
        parent = queue.popleft()
        if parent in hiera:
            f.write(parent + '\t' + '\t'.join(hiera[parent]) + '\n')
            queue.extend(hiera[parent])

In [27]:
# Parse the taxonomy file written above with the project helper.
# NOTE: this rebinds `hiera` and `r_hiera` to whatever the helper returns.
(
    hiera,
    _label_dict,
    r_hiera,
    label_depth,
) = get_hierarchy_info('./bgc.taxonomy')

# Persist the label-name -> id mapping as value_dict.pt.
torch.save(_label_dict, './value_dict.pt')

In [28]:
# Map parent label id -> set of child label ids, without the synthetic 'Root'
# entry. A KeyError here means `hiera` has no 'Root' key.
label_hier = dict(hiera)
del label_hier['Root']
label_hier = {
    _label_dict[parent]: {_label_dict[child] for child in children}
    for parent, children in label_hier.items()
}

with open('./slot.pt', 'wb') as f:
    torch.save(label_hier, f)

In [29]:
# Display the parent-id -> child-id-set mapping that the previous cell saves to slot.pt.
label_hier

{0: {7, 8},
 2: {9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21},
 3: {22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
 4: {40, 41, 42, 43, 44, 45, 46, 47, 48},
 5: {49, 50, 51, 52},
 7: {53, 54, 55, 56, 57},
 9: {58, 59, 60, 61},
 13: {62, 63, 64, 65, 66},
 14: {67, 68, 69, 70, 71, 72, 73, 74},
 15: {75, 76, 77},
 22: {78, 79},
 24: {80, 81, 82, 83, 84},
 28: {85, 86, 87, 88, 89, 90, 91, 92, 93, 94},
 30: {95, 96, 97},
 31: {98, 99},
 32: {100, 101, 102, 103},
 33: {104, 105, 106},
 34: {107, 108, 109, 110},
 35: {111, 112},
 36: {113, 114, 115},
 37: {116, 117, 118, 119, 120, 121, 122},
 38: {123, 124, 125},
 39: {126, 127, 128, 129},
 95: {130, 131, 132, 133, 134, 135},
 96: {136, 137, 138},
 97: {139, 140, 141, 142, 143, 144, 145}}