In [9]:
import os
import pickle
import errno

import numpy as np
import pandas as pd
import six
from sklearn import (model_selection, linear_model, multiclass,
                     preprocessing)

In [12]:
def load_labels(label_filename, vocab_size, force_offset=-1):
    """Load a whitespace-separated label file into a dense label array.

    Each non-blank line has the form ``<index> <label> [<label> ...]``: an
    integer index followed by one or more integer labels.

    Args:
        label_filename: Path of the label file to read.
        vocab_size: Number of rows in the returned label array.
        force_offset: Added to every file index to obtain its row in the
            returned array. The default of -1 maps 1-based file indices to
            0-based rows.

    Returns:
        A ``(raw_labels, labels)`` tuple. ``raw_labels`` maps each file index
        to its list of labels (a later line for the same index replaces an
        earlier one). ``labels`` is:
          * single-label files (every line has exactly one label): an int32
            vector of shape ``(vocab_size,)`` holding each row's label;
          * multi-label files: an int8 multi-hot matrix of shape
            ``(vocab_size, n_labels)``, columns ordered by sorted unique label.
        Rows with no entry in the file are left as zeros.

    Raises:
        RuntimeError: if the file contains no label lines, or some line has
            an index but no labels.
        IndexError: if ``index + force_offset`` falls outside
            ``[0, vocab_size)`` for some index in the file.
    """
    raw_labels = {}
    min_labels = np.inf
    max_labels = 0
    with open(label_filename) as f:
        for line in f:
            values = [int(x) for x in line.split()]
            if not values:
                # Skip blank lines (e.g. a trailing newline at end of file).
                continue
            raw_labels[values[0]] = values[1:]
            n_line_labels = len(values) - 1
            min_labels = min(n_line_labels, min_labels)
            max_labels = max(n_line_labels, max_labels)
    print("Raw Labels: {}".format(len(raw_labels)))
    if not raw_labels:
        raise RuntimeError("No labels found in file {}".format(label_filename))
    if min_labels < 1:
        raise RuntimeError("Expected 1 or more labels in file {}"
                           .format(label_filename))

    indices = list(raw_labels.keys())
    rows = np.asarray(indices, dtype=np.int64) + force_offset
    # A negative row would silently wrap around to the end of the array, so
    # reject it explicitly alongside rows past the end.
    out_of_range = (rows < 0) | (rows >= vocab_size)
    if out_of_range.any():
        bad_index = indices[int(np.argmax(out_of_range))]
        raise IndexError(
            "Label index {} with offset {} is outside [0, {})"
            .format(bad_index, force_offset, vocab_size))

    # Single label
    if max_labels == 1:
        labels = np.zeros(vocab_size, dtype=np.int32)
        labels[rows] = [raw_labels[i][0] for i in indices]
        return raw_labels, labels

    # Multiple labels
    print("Multi-label classification")
    unique_labels = np.unique(
        [l for labs in raw_labels.values() for l in labs])
    n_labels = len(unique_labels)
    print("Number of labels: {}".format(n_labels))

    # `classes` is keyword-only in current scikit-learn; it pins the column
    # order to the sorted unique labels.
    label_encoder = preprocessing.MultiLabelBinarizer(classes=unique_labels)
    labels = np.zeros((vocab_size, n_labels), dtype=np.int8)
    # Encode all rows in one call instead of re-fitting the encoder per line.
    labels[rows] = label_encoder.fit_transform(
        [raw_labels[i] for i in indices])
    return raw_labels, labels


In [13]:
# Directory holding the label file. Set the LABEL_DIR environment variable
# to point elsewhere instead of editing a hard-coded local path.
label_dir = os.environ.get("LABEL_DIR", "/Users/Ganymedian/Desktop")
label_file = "sorted-labels.txt"
# Number of rows in the label array built by load_labels.
VOCAB_SIZE = 10313
r_labels, all_labels = load_labels(os.path.join(label_dir, label_file),
                                   VOCAB_SIZE)

Raw Labels: 10312
Multi-label classification
Number of labels: 39


In [14]:
# Number of rows in the label array (its size along the first axis).
all_labels.shape[0]

10313

In [16]:
# Label entry for the first row of the label array.
all_labels.take(0, axis=0)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int8)