## Introduction

In this example, we will build a multi-label text classifier to predict the subject areas of arXiv papers from their abstract bodies. This type of classifier can be useful for conference submission portals like [OpenReview](https://openreview.net/). Given a paper abstract, the portal could provide suggestions on which areas the underlying paper would best belong to.

The dataset was collected using the [`arXiv` Python library](https://github.com/lukasschwab/arxiv.py), which provides a wrapper around the [original arXiv API](http://arxiv.org/help/api/index). To learn more about the collection process, refer to [this notebook](https://github.com/soumik12345/multi-label-text-classification/blob/master/arxiv_scrape.ipynb).

## Imports

In [1]:
from tensorflow.keras import layers
from tensorflow import keras
import tensorflow as tf

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from ast import literal_eval
import pandas as pd

## Read data and perform basic EDA

In this section, we first load the dataset into a `pandas` dataframe and then perform some basic exploratory data analysis (EDA).

In [2]:
arxiv_data = pd.read_csv(
    "https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv"
)
arxiv_data.head()

Unnamed: 0,titles,summaries,terms
0,Survey on Semantic Stereo Matching / Semantic ...,Stereo matching is one of the widely used tech...,"['cs.CV', 'cs.LG']"
1,FUTURE-AI: Guiding Principles and Consensus Re...,The recent advancements in artificial intellig...,"['cs.CV', 'cs.AI', 'cs.LG']"
2,Enforcing Mutual Consistency of Hard Regions f...,"In this paper, we proposed a novel mutual cons...","['cs.CV', 'cs.AI']"
3,Parameter Decoupling Strategy for Semi-supervi...,Consistency training has proven to be an advan...,['cs.CV']
4,Background-Foreground Segmentation for Interio...,"To ensure safety in automated driving, the cor...","['cs.CV', 'cs.LG']"


Our text features are present in the `summaries` column and their corresponding labels are in `terms`. Notice that a single entry can have multiple categories associated with it.

In [3]:
print(f"There are {len(arxiv_data)} rows in the dataset.")

There are 51774 rows in the dataset.


Real-world data is noisy. One of the most commonly observed sources of noise is data duplication. Here we notice that our initial dataset has about 13k duplicate entries.

In [4]:
total_duplicate_titles = sum(arxiv_data["titles"].duplicated())
print(f"There are {total_duplicate_titles} duplicate titles.")

There are 12802 duplicate titles.


Before proceeding further we first drop these entries. 

In [5]:
arxiv_data = arxiv_data[~arxiv_data["titles"].duplicated()]
print(f"There are {len(arxiv_data)} rows in the deduplicated dataset.")

There are 38972 rows in the deduplicated dataset.
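
By default, `duplicated()` flags every repeat of a value after its first occurrence, so the filter above keeps one row per title. A minimal plain-Python sketch of the same keep-first behaviour, using made-up titles (the helper name `keep_first_occurrences` is ours, not part of pandas):

```python
def keep_first_occurrences(titles):
    """Return the titles with every repeat after the first occurrence removed."""
    seen = set()
    unique_titles = []
    for title in titles:
        if title not in seen:
            seen.add(title)
            unique_titles.append(title)
    return unique_titles


# Hypothetical titles, for illustration only.
toy_titles = ["Paper A", "Paper B", "Paper A", "Paper C", "Paper B"]
deduplicated = keep_first_occurrences(toy_titles)
print(deduplicated)
```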


In [6]:
# There are some terms with occurrence as low as 1.
print(sum(arxiv_data["terms"].value_counts() == 1))

2321


In [7]:
# How many unique terms?
print(arxiv_data["terms"].nunique())

3157


As observed above, out of 3157 unique combinations of `terms`, 2321 occur only once. To prepare our train, validation, and test sets with [stratification](https://en.wikipedia.org/wiki/Stratified_sampling), we need to drop these combinations, because a combination that appears only once cannot be represented in more than one split.

In [8]:
# Filtering the rare terms.
arxiv_data_filtered = arxiv_data.groupby("terms").filter(lambda x: len(x) > 1)
arxiv_data_filtered.shape

(36651, 3)
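
To make the filtering step concrete, here is a small plain-Python sketch of the same idea: count how often each label combination occurs and keep only the rows whose combination appears more than once. The label strings below are made up for illustration and only stand in for the `terms` column.

```python
from collections import Counter

# Hypothetical label strings standing in for the `terms` column.
toy_terms = [
    "['cs.CV']",
    "['cs.CV']",
    "['cs.CV', 'cs.LG']",
    "['cs.CV', 'cs.LG']",
    "['stat.ML', 'cs.RO']",  # occurs only once
]

# Count each combination, then keep rows whose combination is not a singleton.
combination_counts = Counter(toy_terms)
singletons = [combo for combo, count in combination_counts.items() if count == 1]
kept = [terms for terms in toy_terms if combination_counts[terms] > 1]

print(f"Singleton combinations: {singletons}")
print(f"Rows kept: {len(kept)} of {len(toy_terms)}")
```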

## Convert the string labels to list of strings

The initial labels are represented as raw strings. Here we convert them to `List[str]` so that each individual label can be accessed directly.

In [9]:
arxiv_data_filtered["terms"] = arxiv_data_filtered["terms"].apply(literal_eval)
arxiv_data_filtered["terms"].values[:5]

array([list(['cs.CV', 'cs.LG']), list(['cs.CV', 'cs.AI', 'cs.LG']),
       list(['cs.CV', 'cs.AI']), list(['cs.CV']),
       list(['cs.CV', 'cs.LG'])], dtype=object)
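
As a quick illustration of what `literal_eval` does to a single label, the sketch below parses one raw string of the kind stored in the CSV into a Python list. The sample string is copied from the preview of the dataset shown earlier.

```python
from ast import literal_eval

raw_label = "['cs.CV', 'cs.LG']"  # a label as it is stored in the CSV
parsed_label = literal_eval(raw_label)

# The string becomes a real list whose items can be accessed individually.
print(type(raw_label).__name__, "->", type(parsed_label).__name__)
print(parsed_label)
```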

## Stratified splits because of class imbalance

The dataset has a [class imbalance problem](https://developers.google.com/machine-learning/glossary/#class-imbalanced-dataset). So, to get a fair evaluation result, we need to ensure that the splits are sampled with stratification. To learn more about different strategies for dealing with class imbalance, you can follow [this tutorial](https://www.tensorflow.org/tutorials/structured_data/imbalanced_data).
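
The idea behind stratification can be sketched in plain Python: split each label group separately so that every group contributes roughly the same share to the test set. This is a simplified illustration with made-up labels and a hypothetical helper name, not how scikit-learn implements `train_test_split`.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_fraction, seed=0):
    """Split row indices so each label keeps roughly the same test share."""
    rng = random.Random(seed)
    indices_by_label = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_label[label].append(index)

    train_indices, test_indices = [], []
    for indices in indices_by_label.values():
        rng.shuffle(indices)
        n_test = round(len(indices) * test_fraction)
        test_indices.extend(indices[:n_test])
        train_indices.extend(indices[n_test:])
    return sorted(train_indices), sorted(test_indices)


# Hypothetical imbalanced labels: 90 rows of "common", 10 rows of "rare".
toy_labels = ["common"] * 90 + ["rare"] * 10
train_idx, test_idx = stratified_split(toy_labels, test_fraction=0.1)

rare_in_test = sum(toy_labels[i] == "rare" for i in test_idx)
print(f"Test rows: {len(test_idx)}, of which rare: {rare_in_test}")
```

With a purely random split, the rare label could easily end up missing from the small test set; splitting per group avoids that.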

In [10]:
test_split = 0.1

# Initial train and test split.
train_df, test_df = train_test_split(
    arxiv_data_filtered,
    test_size=test_split,
    stratify=arxiv_data_filtered["terms"].values,
)

# Splitting the test set further into validation
# and new test sets.
val_df = test_df.sample(frac=0.5)
test_df = test_df.drop(val_df.index)

print(f"Number of rows in training set: {len(train_df)}")
print(f"Number of rows in validation set: {len(val_df)}")
print(f"Number of rows in test set: {len(test_df)}")

Number of rows in training set: 32985
Number of rows in validation set: 1833
Number of rows in test set: 1833


## Multi-label binarization

Now we preprocess our labels using [`MultiLabelBinarizer`](http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MultiLabelBinarizer.html). 

In [11]:
mlb = MultiLabelBinarizer()
mlb.fit(train_df["terms"])
mlb.classes_

array(['14J60 (Primary) 14F05, 14J26 (Secondary)', '62H30', '62H35',
       '62H99', '65D19', '68', '68Q32', '68T01', '68T05', '68T07',
       '68T10', '68T30', '68T45', '68T99', '68Txx', '68U01', '68U10',
       'E.5; E.4; E.2; H.1.1; F.1.1; F.1.3', 'F.2.2; I.2.7', 'G.3',
       'H.3.1; H.3.3; I.2.6; I.2.7', 'H.3.1; I.2.6; I.2.7', 'I.2',
       'I.2.0; I.2.6', 'I.2.1', 'I.2.10', 'I.2.10; I.2.6',
       'I.2.10; I.4.8', 'I.2.10; I.4.8; I.5.4', 'I.2.10; I.4; I.5',
       'I.2.10; I.5.1; I.4.8', 'I.2.1; J.3', 'I.2.6', 'I.2.6, I.5.4',
       'I.2.6; I.2.10', 'I.2.6; I.2.7', 'I.2.6; I.2.7; H.3.1; H.3.3',
       'I.2.6; I.2.8', 'I.2.6; I.2.9', 'I.2.6; I.5.1', 'I.2.6; I.5.4',
       'I.2.7', 'I.2.8', 'I.2; I.2.6; I.2.7', 'I.2; I.4; I.5', 'I.2; I.5',
       'I.2; J.2', 'I.4', 'I.4.0', 'I.4.3', 'I.4.4', 'I.4.5', 'I.4.6',
       'I.4.6; I.4.8', 'I.4.8', 'I.4.9', 'I.4.9; I.5.4', 'I.4; I.5',
       'I.5.4', 'K.3.2', 'astro-ph.IM', 'cond-mat.dis-nn',
       'cond-mat.mtrl-sci', 'cond-mat.soft', 'c

`MultiLabelBinarizer` separates out the individual unique classes available from the label pool and then uses this information to represent a given label set with 0s and 1s. Below is an example.

In [12]:
sample_label = train_df["terms"].iloc[0]
print(f"Original label: {sample_label}")

label_binarized = mlb.transform([sample_label])
print(f"Label-binarized representation: {label_binarized}")

Original label: ['cs.CV']
Label-binarized representation: [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0]]
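
Conceptually, this encoding can be reproduced in a few lines of plain Python: collect the sorted set of classes, then mark which of them are present in a given label set. The sketch below uses made-up label sets and hypothetical helper names; it illustrates the idea only and is not scikit-learn's implementation.

```python
def fit_classes(label_sets):
    """Collect the sorted list of unique classes across all label sets."""
    return sorted({label for labels in label_sets for label in labels})


def multi_hot(labels, classes):
    """Encode one label set as a 0/1 vector aligned with `classes`."""
    present = set(labels)
    return [1 if cls in present else 0 for cls in classes]


# Hypothetical label sets, for illustration only.
toy_label_sets = [["cs.CV", "cs.LG"], ["cs.CV"], ["cs.AI", "stat.ML"]]
classes = fit_classes(toy_label_sets)
encoded = multi_hot(["cs.CV", "cs.LG"], classes)

print(classes)
print(encoded)
```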


## Data preprocessing and `tf.data.Dataset` objects

We first get percentile estimates of the sequence lengths. The purpose will be clear in a moment.

In [13]:
train_df["summaries"].apply(lambda x: len(x.split(" "))).describe()

count    32985.000000
mean       156.551342
std         41.456720
min          5.000000
25%        128.000000
50%        155.000000
75%        183.000000
max        462.000000
Name: summaries, dtype: float64

Notice that 50% of the abstracts have at most 155 words (you may get a different number depending on the split). So, any number near that is a reasonable choice for the maximum sequence length.

Now, we write utilities to prepare our datasets that would go straight to the text classifier model. 

In [14]:
max_seqlen = 150
batch_size = 128
padding_token = "<pad>"


def unify_text_length(text, label):
    # Split the given abstract into words and count them.
    word_splits = tf.strings.split(text, sep=" ")
    sequence_length = tf.shape(word_splits)[0]

    # Calculate the padding amount.
    padding_amount = max_seqlen - sequence_length

    # Check if we need to pad or truncate.
    if padding_amount > 0:
        # Append `padding_amount` copies of the padding token to the abstract.
        unified_text = tf.pad(
            [text], [[0, padding_amount]], constant_values=padding_token
        )
        unified_text = tf.strings.reduce_join(unified_text, separator="")
    else:
        # Keep only the first `max_seqlen` words.
        unified_text = tf.strings.reduce_join(word_splits[:max_seqlen], separator=" ")

    # The expansion is needed for subsequent vectorization.
    return tf.expand_dims(unified_text, -1), label


def make_dataset(dataframe, train=True):
    label_binarized = mlb.transform(dataframe["terms"].values)
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["summaries"].values, label_binarized)
    )
    if train:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(unify_text_length).cache()
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
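
The pad-or-truncate logic is easier to follow without the tensor operations. Below is a simplified plain-Python analogue with a hypothetical helper name. Note one deliberate difference: this sketch pads at the word level and joins everything with spaces, whereas `unify_text_length` above appends the padding tokens directly to the abstract string.

```python
def pad_or_truncate(text, max_words, pad_token="<pad>"):
    """Return `text` cut or padded to exactly `max_words` space-separated words."""
    words = text.split(" ")
    if len(words) < max_words:
        # Too short: append padding tokens until the target length is reached.
        words = words + [pad_token] * (max_words - len(words))
    else:
        # Long enough: keep only the first `max_words` words.
        words = words[:max_words]
    return " ".join(words)


short_text = pad_or_truncate("graph neural networks", max_words=5)
long_text = pad_or_truncate("a b c d e f g", max_words=5)
print(short_text)
print(long_text)
```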

Now we can prepare the `tf.data.Dataset` objects. 

In [15]:
train_dataset = make_dataset(train_df)
validation_dataset = make_dataset(val_df, False)
test_dataset = make_dataset(test_df, False)

## Dataset preview

In [16]:
text_batch, label_batch = next(iter(train_dataset))

for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text[0]}")
    print(f"Label(s): {mlb.inverse_transform(label)[0]}")
    print(" ")

Abstract: b'Supply chain network data is a valuable asset for businesses wishing to\nunderstand their ethical profile, security of supply, and efficiency.\nPossession of a dataset alone however is not a sufficient enabler of actionable\ndecisions due to incomplete information. In this paper, we present a graph\nrepresentation learning approach to uncover hidden dependency links that focal\ncompanies may not be aware of. To the best of our knowledge, our work is the\nfirst to represent a supply chain as a heterogeneous knowledge graph with\nlearnable embeddings. We demonstrate that our representation facilitates\nstate-of-the-art performance on link prediction of a global automotive supply\nchain network using a relational graph convolutional network. It is anticipated\nthat our method will be directly applicable to businesses wishing to sever\nlinks with nefarious entities and mitigate risk of supply failure. More\nabstractly, it is anticipated that our method will be useful to inform\

## Vocabulary size for vectorization

Before we feed the data to our model we need to represent it as numbers. For that purpose, we will use the [`TextVectorization` layer](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization). It can operate as a part of your main model, so that the core preprocessing logic ships with the model. This greatly reduces the chances of training/serving skew.

We first compute a value to use as the vocabulary size. Note that the code below takes the word count of the longest abstract, not the number of unique words across the corpus, so it acts only as a small cap on the vocabulary.

In [17]:
vocabulary_size = train_df["summaries"].str.split().str.len().max()
print(f"Vocabulary size: {vocabulary_size}")

Vocabulary size: 498
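
The value printed above is the word count of the longest abstract. The number of distinct words across all abstracts is a different quantity. The sketch below contrasts the two on made-up abstracts, assuming simple lowercase, whitespace-based tokenization (the `TextVectorization` layer applies its own standardization, so real counts would differ).

```python
# Hypothetical abstracts, for illustration only.
toy_abstracts = [
    "graph networks learn node embeddings",
    "networks learn image features",
]

# Word count of the longest abstract.
longest_abstract = max(len(abstract.split()) for abstract in toy_abstracts)

# Number of distinct words across all abstracts.
vocabulary = set()
for abstract in toy_abstracts:
    vocabulary.update(abstract.lower().split())

print(f"Longest abstract: {longest_abstract} words")
print(f"Unique words: {len(vocabulary)}")
```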




Now we can create our text classifier model with the `TextVectorization` layer present inside it. 

## Create model with `TextVectorization`

A batch of raw text will first go through the `TextVectorization` layer, which converts it into numerical vectors (TF-IDF scores over n-grams, given the settings below). These are then passed to the shallow model responsible for text classification.
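
The model below configures the layer with `ngrams=2`. As we understand the Keras documentation, an integer value creates n-grams up to that size, so the vocabulary is built from single words and pairs of adjacent words. A plain-Python sketch of that n-gram generation, with a hypothetical helper name:

```python
def ngrams_up_to(tokens, n):
    """Return all contiguous n-grams of length 1..n, joined with spaces."""
    grams = []
    for size in range(1, n + 1):
        for start in range(len(tokens) - size + 1):
            grams.append(" ".join(tokens[start : start + size]))
    return grams


tokens = "graph neural networks".split()
print(ngrams_up_to(tokens, 2))
```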

In [18]:
text_vectorizer = layers.TextVectorization(
    max_tokens=vocabulary_size, ngrams=2, output_mode="tf_idf"
)

# `TextVectorization` needs to be adapted as per the vocabulary from our
# training set.
with tf.device("/CPU:0"):
    text_vectorizer.adapt(train_dataset.map(lambda text, label: text))


def make_model():
    shallow_mlp_model = keras.Sequential(
        [
            keras.Input(shape=(), dtype=tf.string),
            text_vectorizer,
            layers.Dense(512, activation="relu"),
            layers.Dense(256, activation="relu"),
            layers.Dense(len(mlb.classes_), activation="sigmoid"),
        ]
    )
    return shallow_mlp_model

Without the CPU placement, we run into: 

```
(1) Invalid argument: During Variant Host->Device Copy: non-DMA-copy attempted of tensor type: string
```

In [19]:
shallow_mlp_model = make_model()
shallow_mlp_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
text_vectorization (TextVect (None, 498)               1         
_________________________________________________________________
dense (Dense)                (None, 512)               255488    
_________________________________________________________________
dense_1 (Dense)              (None, 256)               131328    
_________________________________________________________________
dense_2 (Dense)              (None, 152)               39064     
Total params: 425,881
Trainable params: 425,880
Non-trainable params: 1
_________________________________________________________________


## Train the model

We will train our model using the binary cross-entropy loss. This is because the labels are not disjoint: a given abstract may belong to multiple categories. So, we divide the prediction task into a series of binary classification problems, one per label. This is also why we set the activation function of the classification layer in our model to sigmoid. Researchers have used other combinations of loss function and activation function as well. For example, in [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932), Mahajan et al. used the softmax activation function and cross-entropy loss to train their models.
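
To see how the loss treats each label as its own binary problem, here is a minimal plain-Python sketch of binary cross-entropy for a single example, averaged over its labels. The targets and predicted probabilities are made up, and the clipping constant `eps` is our assumption for numerical safety.

```python
import math


def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy over the labels of a single example."""
    total = 0.0
    for target, prob in zip(y_true, y_prob):
        prob = min(max(prob, eps), 1.0 - eps)  # avoid log(0)
        total += -(target * math.log(prob) + (1 - target) * math.log(1.0 - prob))
    return total / len(y_true)


# Hypothetical targets and sigmoid outputs for three labels.
y_true = [1, 0, 1]
good_prediction = [0.9, 0.1, 0.8]
poor_prediction = [0.2, 0.7, 0.3]

print(round(binary_cross_entropy(y_true, good_prediction), 4))
print(round(binary_cross_entropy(y_true, poor_prediction), 4))
```

Each label contributes independently to the sum, so a confident mistake on any one label raises the loss regardless of how the other labels are predicted.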

In [20]:
epochs = 20

shallow_mlp_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["categorical_accuracy"]
)

shallow_mlp_model.fit(train_dataset, validation_data=validation_dataset, epochs=epochs)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x7fc0002d3810>

## Evaluate the model

In [21]:
_, categorical_acc = shallow_mlp_model.evaluate(test_dataset)
print(f"Categorical accuracy on the test set: {round(categorical_acc * 100, 2)}%.")

Categorical accuracy on the test set: 69.67%.
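
When reading this number, it helps to keep in mind what the metric compares. As we understand it, categorical accuracy checks whether the index of the highest predicted score matches the arg-max of the target vector, which for a multi-hot target is the first position holding a 1. The plain-Python sketch below, with made-up values and hypothetical helper names, illustrates that behaviour.

```python
def argmax(values):
    """Index of the first maximum value."""
    return max(range(len(values)), key=lambda i: values[i])


def categorical_accuracy(y_true_batch, y_pred_batch):
    """Share of examples whose top-scoring class matches the target's arg-max."""
    matches = [
        argmax(y_true) == argmax(y_pred)
        for y_true, y_pred in zip(y_true_batch, y_pred_batch)
    ]
    return sum(matches) / len(matches)


# Hypothetical multi-hot targets and sigmoid outputs for two examples.
y_true_batch = [[0, 1, 1], [1, 0, 0]]
y_pred_batch = [[0.1, 0.6, 0.9], [0.8, 0.3, 0.1]]
print(categorical_accuracy(y_true_batch, y_pred_batch))
```

In the first example the top-scoring class is a correct label, yet it is not the first active label in the target, so the example does not count as a match. The reported figure should therefore be read as a rough indicator rather than a per-label accuracy.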
