In [9]:
import numpy as np
import pandas as pd
from datasets import Dataset
from torch import logical_and, logical_or, nn
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

In [10]:
# Configurations
model_name = "distilbert-base-uncased"
test_ratio = 0.1
val_ratio = 0.1
batch_size = 8
num_genres = 10

In [11]:
# Load data and nest genres
df = pd.read_csv("data/df_fixed.csv")
df_genres = pd.read_csv("data/df_genres.csv")

# Some data cleaning on the Plot
regex_fixes = [
    ["\s*\[.*\]", " ", True],  # footnotes
    ["–|—", "-", True],  # dash
    ["\r\n", "", True],  # newlines
]

for i, row in enumerate(regex_fixes):
    df["Plot"] = df["Plot"].str.replace(row[0], row[1], regex=row[2])

In [12]:
# Merge similar genres (see https://aclanthology.org/Y18-1007.pdf)
genre_groups = pd.DataFrame(
    [
        (
            "action",
            [
                "action",
                "adventure",
                "sci-fi",
                "superhero",
                "sport",
                "spy",
                "war",
                "worldwar-i",
                "worldwar-ii",
            ],
        ),
        ("comedy", ["comedy", "rom-com", "black-comedy"]),
        ("drama", ["drama", "fantasy", "biodrama", "melodrama"]),
        ("family", ["family", "animation", "musical", "anime", "child"]),
        ("thriller", ["thriller", "mystery"]),
        ("documentary", ["documentary", "biographical", "historical"]),
    ],
    columns=["genre_group", "Genre"],
).explode("Genre")

df_genres = df_genres.merge(genre_groups, how="left", on="Genre")
df_genres.loc[df_genres.genre_group.notna(), "Genre"] = df_genres["genre_group"]

In [13]:
# Only keep the top `num_genres` genres
top_genres = (
    df_genres.query("Genre != 'unknown'")
    .groupby("Genre")
    .agg(n=("Genre", "count"))
    .reset_index()
    .sort_values("n", ascending=False)
    .head(num_genres)
    .Genre.values
)

top_genres

array(['drama', 'comedy', 'action', 'family', 'thriller', 'romance',
       'crime', 'horror', 'western', 'documentary'], dtype=object)

In [14]:
# Encode genre labels to wide arrays
df_genres = (
    df_genres.query("Genre in @top_genres")
    .assign(cnt=1)
    .pivot_table(index=["movieID"], columns="Genre", values=["cnt"])
    .fillna(0)
    # .astype(int)
    .reset_index(col_level=1)  # get movieID out
)

df_genres.columns = [x[1] for x in df_genres.columns]
df_genres = df_genres.set_index("movieID")

genre_names = df_genres.columns.tolist()
labels = df_genres.values.tolist()
df_genres = pd.DataFrame({"movieID": df_genres.index, "labels": labels})

In [15]:
df = (
    df.reset_index()
    .rename(columns={"index": "movieID"})
    .filter(["movieID", "Plot"])
    .merge(df_genres, on="movieID")
    .reset_index(drop=True)
)

df.sample(10)

Unnamed: 0,movieID,Plot,labels
6943,7598,A vagabond family composed of Pop Kwimper (Art...,"[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ..."
5304,5787,Robert Teller (Kirk Douglas) visits a seaport ...,"[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
4249,4541,Connie Dickason is the strong-willed daughter ...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
316,360,Young Jim Hawkins is caught up with the pirate...,"[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
12337,13409,"The Miami Sharks, a once-great American footba...","[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
22407,25223,Sanam Teri Kasam is the story of Sunil (Kamal ...,"[0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, ..."
10805,11812,Marie (Anne Parillaud) is a very appealing mod...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ..."
6137,6712,Gangsters Nat Burdell (Kenne Duncan) and Brad ...,"[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
10014,10982,"Emily Crane (Kelly McGillis), a picture editor...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ..."
26111,32373,Krish (Ram) is a happy-go-lucky guy who lives ...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ..."


In [8]:
df.shape

(27504, 3)

## Multi-label classification

In [17]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Construct dataset
dat = Dataset.from_pandas(df)

# Tokenize the Plot column
dat = dat.map(
    lambda batch: tokenizer.batch_encode_plus(
        batch["Plot"], padding="max_length", truncation=True
    ),
    batched=True,
    remove_columns=["movieID"],
)

# Retrieve tensors of the following columns as model inputs
valid_cols = ["input_ids", "token_type_ids", "attention_mask", "labels"]
cols = [c for c in dat.column_names if c in valid_cols]
dat.set_format(type="torch", columns=cols)

# Train/validation/test split
dat = dat.train_test_split(test_size=test_ratio, seed=42)
dat_train = dat["train"].train_test_split(test_size=val_ratio, seed=42)
dat["train"] = dat_train["train"]
dat["validation"] = dat_train["test"]

  0%|          | 0/28 [00:00<?, ?ba/s]

In [20]:
len(dat["train"])

22277

In [21]:
len(dat["validation"])

2476

In [22]:
len(dat["test"])

2751

In [23]:
# Modify last layer of model
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, problem_type="multi_label_classification", num_labels=num_genres
)
model.to("cuda")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'pre_classi

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
       

In [24]:
training_args = TrainingArguments(
    output_dir="distilbert_multilabel",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=2 * batch_size,
    learning_rate=1e-5,
    num_train_epochs=5,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=1,  # make sure validation loss is logged in each epoch
    seed=42,
)

In [25]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dat["train"],
    eval_dataset=dat["validation"],
)
trainer.train()

The following columns in the training set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: Plot.
***** Running training *****
  Num examples = 22277
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 24
  Gradient Accumulation steps = 1
  Total optimization steps = 4645


Epoch,Training Loss,Validation Loss
1,0.2153,0.24179
2,0.1559,0.218139
3,0.2789,0.211901
4,0.2181,0.21191
5,0.0933,0.21323


The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: Plot.
***** Running Evaluation *****
  Num examples = 2476
  Batch size = 48
Saving model checkpoint to distilbert_multilabel/checkpoint-929
Configuration saved in distilbert_multilabel/checkpoint-929/config.json
Model weights saved in distilbert_multilabel/checkpoint-929/pytorch_model.bin
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: Plot.
***** Running Evaluation *****
  Num examples = 2476
  Batch size = 48
Saving model checkpoint to distilbert_multilabel/checkpoint-1858
Configuration saved in distilbert_multilabel/checkpoint-1858/config.json
Model weights saved in distilbert_multilabel/checkpoint-1858/pytorch_model.bin
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceCla

TrainOutput(global_step=4645, training_loss=0.22298088123935672, metrics={'train_runtime': 1535.6079, 'train_samples_per_second': 72.535, 'train_steps_per_second': 3.025, 'total_flos': 1.4756986258176e+16, 'train_loss': 0.22298088123935672, 'epoch': 5.0})

In [26]:
model.to("cuda:0")
model.eval()

dl = DataLoader(dat["test"], batch_size=8)

In [27]:
proba_labels = []
sigmoid = nn.Sigmoid()
for batch in dl:
    batch = {k: v.to("cuda:0") for k, v in batch.items()}
    logits = model(**batch).get("logits")

    # Make sure there's at least one predicted label
    # by setting the logit of the maximum to 10
    max_proba = logits.argmax(axis=1)
    for i in range(logits.shape[0]):
        logits[i, max_proba[i]] = 10.0

    y_pred = (sigmoid(logits) > 0.5).cpu().detach().numpy()
    proba_labels.append(y_pred)

proba_labels = np.vstack(proba_labels)
y_labels = dat["test"]["labels"].bool().cpu().detach().numpy()

In [28]:
from sklearn.metrics import classification_report

print(classification_report(y_labels, proba_labels, target_names=genre_names))

              precision    recall  f1-score   support

      action       0.67      0.61      0.64       514
      comedy       0.77      0.53      0.63       740
       crime       0.48      0.28      0.35       154
 documentary       0.73      0.11      0.19        74
       drama       0.60      0.70      0.65       969
      family       0.77      0.39      0.52       257
      horror       0.65      0.73      0.68       172
     romance       0.51      0.36      0.42       201
    thriller       0.51      0.30      0.38       228
     western       0.85      0.85      0.85        97

   micro avg       0.65      0.56      0.60      3406
   macro avg       0.65      0.49      0.53      3406
weighted avg       0.65      0.56      0.59      3406
 samples avg       0.65      0.59      0.60      3406



In [29]:
true_pos = np.logical_and(y_labels, proba_labels).sum(axis=1)
pred_pos = np.logical_or(y_labels, proba_labels).sum(axis=1)

hamming_score = np.nansum(true_pos / pred_pos) / y_labels.shape[0]
precision = np.nansum(true_pos / y_labels.sum(axis=1)) / y_labels.shape[0]
recall = np.nansum(true_pos / proba_labels.sum(axis=1)) / y_labels.shape[0]

print(
    f"""
    Hamming accuracy: {hamming_score}
    Precision: {precision}
    Recall: {recall}
"""
)


    Hamming accuracy: 0.573397552405186
    Precision: 0.5925118138858597
    Recall: 0.6506724827335514



In [30]:
# Get genre names and plot text
true_labels = [
    [genre_names[x] for x in np.argwhere(arr == 1).flatten()] for arr in y_labels
]
plots = dat["test"]["Plot"]
pred_labels = [
    [genre_names[x] for x in np.argwhere(arr == 1).flatten()] for arr in proba_labels
]

for i in range(20):
    print(
        f"""
    True label: {true_labels[i]}
    Predicted label: {pred_labels[i]}
    Plot: {plots[i]}
    """
    )


    True label: ['action']
    Predicted label: ['action']
    Plot: A British naval officer volunteers for a dangerous mission to infiltrate the base of pirates who threaten shipping off Madagascar.
    

    True label: ['comedy']
    Predicted label: ['comedy']
    Plot: The film starts off with Calvin "Babyface" Simms (Marlon Wayans) who is a very short convict. He is seen getting released and planning a robbery to steal a diamond with the help of his goofball cohort Percy (Tracy Morgan). After the successful robbery, the duo are almost arrested, but not before Calvin manages to stash the diamond in a nearby woman's purse. The thieves follow the handbag's owner to her home where they discover a couple, Darryl (Shawn Wayans) and Vanessa Edwards (Kerry Washington), who are eager to have a child.Calvin and Percy hatch a plot to pass Calvin off as a baby left on the couple's doorstep. After seeing Calvin, Darryl and Vanessa, wanting a child, immediately adopt the baby as their own. Ho