In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
from datasets import Dataset, load_metric
from torch import logical_and, logical_or, nn
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

In [2]:
# Configurations
model_name = "distilbert-base-uncased"
test_ratio = 0.1
val_ratio = 0.1
batch_size = 8
num_genres = 10

In [3]:
# Load data and nest genres
df = pd.read_csv("data/df_fixed.csv")
df_genres = pd.read_csv("data/df_genres.csv")

# Some data cleaning on the Plot
regex_fixes = [
    ["\s*\[.*\]", " ", True],  # footnotes
    ["–|—", "-", True],  # dash
    ["\r\n", "", True],  # newlines
]

for i, row in enumerate(regex_fixes):
    df["Plot"] = df["Plot"].str.replace(row[0], row[1], regex=row[2])

In [4]:
# Merge similar genres (see https://aclanthology.org/Y18-1007.pdf)
genre_groups = pd.DataFrame(
    [
        (
            "action",
            [
                "action",
                "adventure",
                "sci-fi",
                "superhero",
                "sport",
                "spy",
                "war",
                "worldwar-i",
                "worldwar-ii",
            ],
        ),
        ("comedy", ["comedy", "rom-com", "black-comedy"]),
        ("drama", ["drama", "fantasy", "biodrama", "melodrama"]),
        ("family", ["family", "animation", "musical", "anime", "child"]),
        ("thriller", ["thriller", "mystery"]),
        ("documentary", ["documentary", "biographical", "historical"]),
    ],
    columns=["genre_group", "Genre"],
).explode("Genre")

df_genres = df_genres.merge(genre_groups, how="left", on="Genre")
df_genres.loc[df_genres.genre_group.notna(), "Genre"] = df_genres["genre_group"]

In [5]:
# Only keep the top `num_genres` genres
top_genres = (
    df_genres.query("Genre != 'unknown'")
    .groupby("Genre")
    .agg(n=("Genre", "count"))
    .reset_index()
    .sort_values("n", ascending=False)
    .head(num_genres)
    .Genre.values
)

top_genres

array(['drama', 'comedy', 'action', 'family', 'thriller', 'romance',
       'crime', 'horror', 'western', 'documentary'], dtype=object)

In [6]:
# Encode genre labels to wide arrays
df_genres = (
    df_genres.query("Genre in @top_genres")
    .assign(cnt=1)
    .pivot_table(index=["movieID"], columns="Genre", values=["cnt"])
    .fillna(0)
    # .astype(int)
    .reset_index(col_level=1)  # get movieID out
)

df_genres.columns = [x[1] for x in df_genres.columns]
df_genres = df_genres.set_index("movieID")

genre_names = df_genres.columns.tolist()
labels = df_genres.values.tolist()
df_genres = pd.DataFrame({"movieID": df_genres.index, "labels": labels})

In [7]:
df = (
    df.reset_index()
    .rename(columns={"index": "movieID"})
    .filter(["movieID", "Plot"])
    .merge(df_genres, on="movieID")
)

df.sample(10)

Unnamed: 0,movieID,Plot,labels
25286,31088,The film opens when a child named Tarun starts...,"[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
7595,8275,Rancher Taw Jackson (John Wayne) returns to hi...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
28,48,"Only lasting 15 minutes, it is a light-hearted...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
16275,17538,Jake Cullen (Bill Kerr) is babysitting his gra...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, ..."
11099,12110,"Christina Ford (Bo Derek) a beautiful woman, i...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
13311,14398,A team of researchers funded by a New York pha...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ..."
18078,19430,"Ewan McEwan, an easy-going sheep and corn farm...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3082,3198,Ted Scott (John Payne) is a band pianist whose...,"[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
6926,7578,Two American World War II veterans Jeff (John ...,"[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
6953,7608,"Nick Adams is a young, restless man who wants ...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."


## Multi-label classification

In [8]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Construct dataset
dat = Dataset.from_pandas(df)

# Tokenize the Plot column
dat = dat.map(
    lambda batch: tokenizer.batch_encode_plus(
        batch["Plot"], padding="max_length", truncation=True
    ),
    batched=True,
    remove_columns=["__index_level_0__", "movieID"],
)

# Retrieve tensors of the following columns as model inputs
valid_cols = ["input_ids", "token_type_ids", "attention_mask", "labels"]
cols = [c for c in dat.column_names if c in valid_cols]
dat.set_format(type="torch", columns=cols)

# Train/validation/test split
dat = dat.train_test_split(test_size=test_ratio, seed=42)
dat_train = dat["train"].train_test_split(test_size=val_ratio, seed=42)
dat["train"] = dat_train["train"]
dat["validation"] = dat_train["test"]

  0%|          | 0/28 [00:00<?, ?ba/s]

In [9]:
# Modify last layer of model
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, problem_type="multi_label_classification", num_labels=num_genres
)
model.to("cuda")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.bias', 'classifier.w

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
       

In [10]:
training_args = TrainingArguments(
    output_dir="distilbert_multilabel",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=2e-5,
    num_train_epochs=3,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=1,  # make sure validation loss is logged in each epoch
)

In [11]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dat["train"],
    eval_dataset=dat["validation"],
)
trainer.train()

The following columns in the training set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: Plot.
***** Running training *****
  Num examples = 22277
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 24
  Gradient Accumulation steps = 1
  Total optimization steps = 2787


Epoch,Training Loss,Validation Loss
1,0.2183,0.227829
2,0.1162,0.210961
3,0.235,0.211835


The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: Plot.
***** Running Evaluation *****
  Num examples = 2476
  Batch size = 24
Saving model checkpoint to distilbert_multilabel/checkpoint-929
Configuration saved in distilbert_multilabel/checkpoint-929/config.json
Model weights saved in distilbert_multilabel/checkpoint-929/pytorch_model.bin
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: Plot.
***** Running Evaluation *****
  Num examples = 2476
  Batch size = 24
Saving model checkpoint to distilbert_multilabel/checkpoint-1858
Configuration saved in distilbert_multilabel/checkpoint-1858/config.json
Model weights saved in distilbert_multilabel/checkpoint-1858/pytorch_model.bin
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceCla

TrainOutput(global_step=2787, training_loss=0.22178466243527037, metrics={'train_runtime': 943.6196, 'train_samples_per_second': 70.824, 'train_steps_per_second': 2.954, 'total_flos': 8854191754905600.0, 'train_loss': 0.22178466243527037, 'epoch': 3.0})

In [12]:
model.to("cuda:0")
model.eval()

dl = DataLoader(dat["test"], batch_size=8)

In [13]:
proba_labels = []
sigmoid = nn.Sigmoid()
for batch in dl:
    batch = {k: v.to("cuda:0") for k, v in batch.items()}
    logits = model(**batch).get("logits")

    # Make sure there's at least one predicted label
    # by setting the logit of the maximum to 10
    max_proba = logits.argmax(axis=1)
    for i in range(logits.shape[0]):
        logits[i, max_proba[i]] = 10.0

    y_pred = (sigmoid(logits) > 0.5).cpu().detach().numpy()
    proba_labels.append(y_pred)

proba_labels = np.vstack(proba_labels)
y_labels = dat["test"]["labels"].bool().cpu().detach().numpy()

In [14]:
from sklearn.metrics import classification_report

print(classification_report(y_labels, proba_labels, target_names=genre_names))

              precision    recall  f1-score   support

      action       0.65      0.64      0.64       514
      comedy       0.76      0.54      0.63       740
       crime       0.46      0.31      0.37       154
 documentary       0.67      0.19      0.29        74
       drama       0.61      0.70      0.65       969
      family       0.75      0.43      0.54       257
      horror       0.67      0.71      0.69       172
     romance       0.53      0.32      0.40       201
    thriller       0.50      0.30      0.37       228
     western       0.88      0.86      0.87        97

   micro avg       0.65      0.56      0.60      3406
   macro avg       0.65      0.50      0.55      3406
weighted avg       0.65      0.56      0.59      3406
 samples avg       0.66      0.60      0.61      3406



In [15]:
true_pos = np.logical_and(y_labels, proba_labels).sum(axis=1)
pred_pos = np.logical_or(y_labels, proba_labels).sum(axis=1)

hamming_score = np.nansum(true_pos / pred_pos) / y_labels.shape[0]
precision = np.nansum(true_pos / y_labels.sum(axis=1)) / y_labels.shape[0]
recall = np.nansum(true_pos / proba_labels.sum(axis=1)) / y_labels.shape[0]

print(
    f"""
    Hamming accuracy: {hamming_score}
    Precision: {precision}
    Recall: {recall}
"""
)


    Hamming accuracy: 0.5784684357203441
    Precision: 0.6006906579425664
    Recall: 0.6572155579789167



In [16]:
# Get genre names and plot text
true_labels = [
    [genre_names[x] for x in np.argwhere(arr == 1).flatten()] for arr in y_labels
]
plots = dat["test"]["Plot"]
pred_labels = [
    [genre_names[x] for x in np.argwhere(arr == 1).flatten()] for arr in proba_labels
]

for i in range(10):
    print(
        f"""
    True label: {true_labels[i]}
    Predicted label: {pred_labels[i]}
    Plot: {plots[i]}
    """
    )


    True label: ['action']
    Predicted label: ['action']
    Plot: A British naval officer volunteers for a dangerous mission to infiltrate the base of pirates who threaten shipping off Madagascar.
    

    True label: ['comedy']
    Predicted label: ['comedy']
    Plot: The film starts off with Calvin "Babyface" Simms (Marlon Wayans) who is a very short convict. He is seen getting released and planning a robbery to steal a diamond with the help of his goofball cohort Percy (Tracy Morgan). After the successful robbery, the duo are almost arrested, but not before Calvin manages to stash the diamond in a nearby woman's purse. The thieves follow the handbag's owner to her home where they discover a couple, Darryl (Shawn Wayans) and Vanessa Edwards (Kerry Washington), who are eager to have a child.Calvin and Percy hatch a plot to pass Calvin off as a baby left on the couple's doorstep. After seeing Calvin, Darryl and Vanessa, wanting a child, immediately adopt the baby as their own. Ho