In [16]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys
import json
import re
import pickle
import random
import time
from xml.dom import minidom
import xml.etree.ElementTree as ET
from tqdm import tqdm

sys.path.append('../../')
from utils import *

In [43]:
# Raw Blurb Genre Collection (English) split files, relative to the working directory.
train_data_path, dev_data_path, test_data_path = (
    './BlurbGenreCollection_EN_train.txt',
    './BlurbGenreCollection_EN_dev.txt',
    './BlurbGenreCollection_EN_test.txt',
)

In [44]:
# Load the raw train split as one string; it is parsed as XML in the next cell.
# The encoding is explicit because the corpus contains non-ASCII characters (e.g. '’'),
# and the platform default encoding is not UTF-8 everywhere.
with open(train_data_path, 'r', encoding='utf-8') as f:
    data = f.read()

In [45]:
# Parse the train split: one record per <book>. The token is the title and blurb
# text; the label list collects the <d0>, <d1>, ... topic elements, depth by depth.
# NOTE(review): every '&' is escaped before parsing, presumably because the raw file
# contains bare ampersands; any entity already present would be double-escaped -- confirm.
train_data = ET.fromstring(data.replace('&', '&amp;'))
train_books = train_data.findall('book')  # computed once instead of twice
total_labels = 0
train_data_dict = []
for book in tqdm(train_books, total=len(train_books)):
    # A missing <title>/<body> text is treated as empty instead of crashing on None + str.
    title = book.find('title').text or ''
    text = book.find('body').text or ''
    topics = book.find('metadata').find('topics')

    # Collect labels depth by depth until a depth has no <d{depth}> element.
    labels = []
    depth = 0
    while True:
        level = topics.findall(f'd{depth}')
        if not level:
            break
        labels.extend(node.text for node in level)
        total_labels += len(level)
        depth += 1
    train_data_dict.append({'token': 'Title: ' + title + '. ' + 'Text: ' + text, 'label': labels})

# Store train_data_dict as JSON Lines. UTF-8 is explicit because ensure_ascii=False
# writes non-ASCII characters verbatim.
with open('./train_data.jsonl', 'w', encoding='utf-8') as f:
    for record in train_data_dict:
        json.dump(record, f, ensure_ascii=False)
        f.write('\n')
total_labels

100%|██████████| 58715/58715 [00:00<00:00, 307973.13it/s]


176558

In [46]:
# Parse the dev split the same way as the train split and store it as JSON Lines.
# Explicit UTF-8: the corpus contains non-ASCII characters.
with open(dev_data_path, 'r', encoding='utf-8') as f:
    data = f.read()

# NOTE(review): bare '&' is escaped so the XML parser accepts the file; any entity
# already present in the raw text would be double-escaped -- confirm.
dev_data = ET.fromstring(data.replace('&', '&amp;'))

dev_data_dict = []
for book in dev_data.findall('book'):
    # A missing <title>/<body> text is treated as empty instead of crashing on None + str.
    title = book.find('title').text or ''
    text = book.find('body').text or ''
    topics = book.find('metadata').find('topics')

    # Collect labels depth by depth (<d0>, <d1>, ...) until a depth has none.
    labels = []
    depth = 0
    while True:
        level = topics.findall(f'd{depth}')
        if not level:
            break
        labels.extend(node.text for node in level)
        depth += 1
    dev_data_dict.append({'token': 'Title: ' + title + '. ' + 'Text: ' + text, 'label': labels})

# Store dev_data_dict as JSON Lines. UTF-8 is explicit because ensure_ascii=False
# writes non-ASCII characters verbatim.
with open('./dev_data.jsonl', 'w', encoding='utf-8') as f:
    for record in dev_data_dict:
        json.dump(record, f, ensure_ascii=False)
        f.write('\n')

In [None]:
# Parse the test split the same way as the train split and store it as JSON Lines.
# Explicit UTF-8: the corpus contains non-ASCII characters.
with open(test_data_path, 'r', encoding='utf-8') as f:
    data = f.read()

# NOTE(review): bare '&' is escaped so the XML parser accepts the file; any entity
# already present in the raw text would be double-escaped -- confirm.
test_data = ET.fromstring(data.replace('&', '&amp;'))

test_data_dict = []
for book in test_data.findall('book'):
    # A missing <title>/<body> text is treated as empty instead of crashing on None + str.
    title = book.find('title').text or ''
    text = book.find('body').text or ''
    topics = book.find('metadata').find('topics')

    # Collect labels depth by depth (<d0>, <d1>, ...) until a depth has none.
    labels = []
    depth = 0
    while True:
        level = topics.findall(f'd{depth}')
        if not level:
            break
        labels.extend(node.text for node in level)
        depth += 1
    test_data_dict.append({'token': 'Title: ' + title + '. ' + 'Text: ' + text, 'label': labels})

# Store test_data_dict as JSON Lines. UTF-8 is explicit because ensure_ascii=False
# writes non-ASCII characters verbatim.
with open('./test_data.jsonl', 'w', encoding='utf-8') as f:
    for record in test_data_dict:
        json.dump(record, f, ensure_ascii=False)
        f.write('\n')

In [28]:
# Load the label hierarchy from hierarchy.txt.
# A line with exactly two tab-separated fields is a "parent<TAB>child" edge; any
# other line contributes only its first field as a label. This cell builds:
#   hiera   : parent -> list of children, in file order
#   r_hiera : child -> parent, with labels that have no parent mapped to 'Root'
#   labels  : every label seen, in order of first appearance
from collections import defaultdict

hiera = defaultdict(list)
# A dict serves as an insertion-ordered set. A plain set of strings iterates in an
# order that changes between runs (hash randomization). That order decided the order
# of top-level labels in r_hiera, and therefore in bgc.taxonomy and the label ids.
seen_labels = {}
with open('./hierarchy.txt', 'r', encoding='utf-8') as f:
    for raw_line in f:
        fields = raw_line.rstrip('\r\n').split('\t')  # also tolerate CRLF line endings
        if fields == ['']:
            continue  # blank line: would otherwise register an empty '' label
        seen_labels.setdefault(fields[0])
        if len(fields) != 2:
            continue
        seen_labels.setdefault(fields[1])
        hiera[fields[0]].append(fields[1])

r_hiera = {}
for parent, children in hiera.items():
    for child in children:
        r_hiera[child] = parent

# Labels that never appear as a child are top-level and hang off the virtual 'Root'.
labels = list(seen_labels)
for label in labels:
    r_hiera.setdefault(label, 'Root')


In [29]:
def compute_path_no(labels, parent_of=None):
    """Count the root-to-leaf paths covered by a collection of hierarchical labels.

    A label ends a path when it is not an ancestor of any other label in
    ``labels``. Ancestors are found by walking ``parent_of`` all the way up to
    ``'Root'``. The result therefore does not depend on the order of ``labels``
    and stays correct when intermediate ancestors are missing from ``labels``.

    Args:
        labels: iterable of label names. Duplicates are ignored.
        parent_of: mapping child label -> parent label, with top-level labels
            mapped to ``'Root'``. Defaults to the notebook-level ``r_hiera``.

    Returns:
        int: number of distinct labels that are not an ancestor of another
        label in ``labels`` (0 for empty input).

    Raises:
        KeyError: if a label, or one of its ancestors, is missing from ``parent_of``.
    """
    if parent_of is None:
        parent_of = r_hiera

    unique_labels = set(labels)
    unique_labels.discard('Root')  # the virtual root is never a path end

    ancestors = set()
    for label in unique_labels:
        node = parent_of[label]
        # Stop at Root, or at a node already collected: everything above it has
        # been collected too. This also guarantees termination on a malformed cycle.
        while node != 'Root' and node not in ancestors:
            ancestors.add(node)
            node = parent_of[node]

    return sum(1 for label in unique_labels if label not in ancestors)

# Attach the number of root-to-leaf label paths ('path_no') to every record of each
# split, then print how often each path count occurs per split (train, dev, test).
from collections import Counter

splits = (train_data_dict, dev_data_dict, test_data_dict)
for split in splits:
    for record in split:
        record['path_no'] = compute_path_no(record['label'])

for split in splits:
    print(Counter(record['path_no'] for record in split))

Counter({1: 34000, 2: 18445, 3: 5656, 4: 606, 5: 8})
Counter({1: 8507, 2: 4659, 3: 1456, 4: 161, 5: 2})
Counter({1: 10636, 2: 5800, 3: 1763, 4: 193, 5: 2})


In [95]:
# Write the hierarchy breadth-first to bgc.taxonomy.
# The first line is "Root" followed by the top-level labels. After that there is one
# line "parent<TAB>child1<TAB>child2..." per label that has an entry in hiera, in BFS order.
from collections import deque

root = [k for k, v in r_hiera.items() if v == 'Root']
# Labels can be non-ASCII (e.g. '’'), so the encoding is explicit.
# NOTE(review): the reader (get_hierarchy_info) should decode UTF-8 as well -- confirm.
with open('./bgc.taxonomy', 'w', encoding='utf-8') as f:
    f.write('Root\t' + '\t'.join(root) + '\n')

    queue = deque(root)  # popleft() is O(1); list.pop(0) is O(n)
    while queue:
        parent = queue.popleft()
        if parent in hiera:  # membership test avoids creating defaultdict entries
            children = hiera[parent]
            f.write(parent + '\t' + '\t'.join(children) + '\n')
            queue.extend(children)

In [103]:
# Re-read bgc.taxonomy with get_hierarchy_info (presumably provided by `from utils import *`).
# Note: this rebinds hiera and r_hiera, replacing the versions built from hierarchy.txt.
hierarchy_info = get_hierarchy_info('./bgc.taxonomy')
hiera, _label_dict, r_hiera, label_depth = hierarchy_info
_label_dict

{'Classics': 0,
 'Teen & Young Adult': 1,
 'Children’s Books': 2,
 'Poetry': 3,
 'Humor': 4,
 'Fiction': 5,
 'Nonfiction': 6,
 'Fiction Classics': 7,
 'Literary Collections': 8,
 'Literary Criticism': 9,
 'Nonfiction Classics': 10,
 'Teen & Young Adult Action & Adventure': 11,
 'Teen & Young Adult Mystery & Suspense': 12,
 'Teen & Young Adult Fantasy Fiction': 13,
 'Teen & Young Adult Nonfiction': 14,
 'Teen & Young Adult Fiction': 15,
 'Teen & Young Adult Social Issues': 16,
 'Teen & Young Adult Historical Fiction': 17,
 'Teen & Young Adult Romance': 18,
 'Teen & Young Adult Science Fiction': 19,
 'Step Into Reading': 20,
 'Children’s Middle Grade Books': 21,
 'Children’s Activity & Novelty Books': 22,
 'Children’s Board Books': 23,
 'Children’s Boxed Sets': 24,
 'Children’s Chapter Books': 25,
 'Childrens Media Tie-In Books': 26,
 'Children’s Picture Books': 27,
 'Fantasy': 28,
 'Gothic & Horror': 29,
 'Graphic Novels & Manga': 30,
 'Historical Fiction': 31,
 'Literary Fiction': 32,


In [108]:
# Convert the parent -> children hierarchy to label ids and save it to slot.pt.
# torch is imported explicitly. The notebook only had it if `from utils import *`
# happened to re-export it.
import torch

label_hier = dict(hiera)
# Drop the virtual root. Presumably 'Root' has no id in _label_dict -- confirm.
label_hier.pop('Root')

label_hier = {
    _label_dict[parent]: {_label_dict[child] for child in children}
    for parent, children in label_hier.items()
}

with open('./slot.pt', 'wb') as f:
    torch.save(label_hier, f)