In [3]:
import pickle

# Load the dataframe, the train/test datasets and `mlb` from their pickle files.
# NOTE: pickle.load can execute code embedded in the file -- only open trusted files.

def _read_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


movies_df = _read_pickle('../data/movies_df.pkl')

# each split file holds a 3-tuple that unpacks into X / y / ids
train_data = _read_pickle('../data/train_data.pkl')
X_train, y_train, ids_train = train_data

test_data = _read_pickle('../data/test_data.pkl')
X_test, y_test, ids_test = test_data

# `mlb.classes_` provides the label names for the classification reports
mlb = _read_pickle('../data/mlb.pkl')

In [4]:
#pip3 install -U "transformers>=4.45" "datasets>=3.0" "accelerate>=1.0" "torch" "evaluate"
import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict

def _as_frame(texts, label_rows):
    """Build a two-column frame: the texts plus each label row as a list of Python floats."""
    # floats rather than ints -- presumably what the multi-label loss expects; TODO confirm
    return pd.DataFrame({
        'text': list(texts),
        'labels': [[float(flag) for flag in row] for row in label_rows],
    })


train_df = _as_frame(X_train, y_train)
test_df = _as_frame(X_test, y_test)

# keep both splits in one DatasetDict so later steps can process them together
ds = DatasetDict({
    split: Dataset.from_pandas(frame, preserve_index=False)
    for split, frame in (('train', train_df), ('test', test_df))
})

# one model output per column of the label matrix
num_labels = y_train.shape[1]
num_labels

  from .autonotebook import tqdm as notebook_tqdm


23

In [5]:
from transformers import AutoTokenizer

# checkpoint name used to load the tokenizer
model_name = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_name)


def tokenize_fn(batch):
    """Tokenize the batch's 'text' column, truncating and padding to 256 tokens."""
    encode_kwargs = {'truncation': True, 'padding': 'max_length', 'max_length': 256}
    return tokenizer(batch['text'], **encode_kwargs)


# batched map drops the raw 'text' column once it has been tokenized
tokenized = ds.map(
    tokenize_fn,
    batched=True,
    remove_columns=['text'],
)
tokenized.set_format('torch')

Map: 100%|██████████| 8000/8000 [00:00<00:00, 14054.00 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 14321.53 examples/s]


In [6]:
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

# Sequence-classification model sized to the label count and configured for
# multi-label targets.
head_config = {
    "num_labels": num_labels,
    "problem_type": "multi_label_classification",
}
model = AutoModelForSequenceClassification.from_pretrained(model_name, **head_config)

def sigmoid(x):
    """Numerically stable logistic function, applied element-wise.

    Parameters
    ----------
    x : scalar or array-like
        Raw scores (logits).

    Returns
    -------
    numpy scalar or ndarray
        ``1 / (1 + exp(-x))`` with the same shape as ``x``.
    """
    x = np.asarray(x)
    # exp(-|x|) lies in (0, 1], so it cannot overflow the way exp(-x) does for
    # large negative logits (which emits a RuntimeWarning).
    decay = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x) -- the same function.
    stable = np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))
    # [()] turns a 0-d result back into a numpy scalar; arrays pass through.
    return stable[()]

def compute_metrics(eval_pred):
    """Score thresholded sigmoid predictions against the multi-label targets.

    `eval_pred` unpacks into (logits, labels). A label counts as predicted when
    its sigmoid probability is at least 0.5. Returns a dict with micro/macro F1
    and micro precision/recall, all computed with `zero_division=0`.
    """
    logits, labels = eval_pred
    preds = (sigmoid(logits) >= 0.5).astype(int)

    def score(metric, average):
        return metric(labels, preds, average=average, zero_division=0)

    return {
        "micro_f1": score(f1_score, 'micro'),
        "macro_f1": score(f1_score, 'macro'),
        "micro_precision": score(precision_score, 'micro'),
        "micro_recall": score(recall_score, 'micro'),
    }

# Training hyperparameters, collected in one mapping so they are easy to scan.
# NOTE(review): the output path says "distilbert" while model_name is bert-tiny.
# NOTE(review): no evaluation strategy is set here -- confirm whether periodic
# evaluation during training is intended.
training_config = dict(
    output_dir="../models/distilbert_multilabel",
    do_eval=True,
    save_steps=500,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    num_train_epochs=50,
    weight_decay=0.01,
    logging_steps=50,
    fp16=False,  # set True if GPU supports it
)
args = TrainingArguments(**training_config)

import inspect

# Newer transformers releases take the tokenizer as `processing_class=` and
# deprecate the `tokenizer=` keyword, while older ones only know `tokenizer=`.
# Pick whichever keyword the installed Trainer accepts so this cell runs
# without a deprecation warning on new versions and without a TypeError on old ones.
_trainer_params = inspect.signature(Trainer.__init__).parameters
_tokenizer_kwarg = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],  # the test split is what trainer.evaluate() scores
    compute_metrics=compute_metrics,
    **{_tokenizer_kwarg: tokenizer},
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


In [None]:
import os
# fine-tune, then score once on the evaluation split
train_result = trainer.train()
eval_metrics = trainer.evaluate()
print(eval_metrics)

# write the metrics as "name: value" lines, sorted by name
os.makedirs("../models", exist_ok=True)
metric_lines = [f"{name}: {value}\n" for name, value in sorted(eval_metrics.items())]
with open("../models/metrics_distilbert.txt", "w") as f:
    f.writelines(metric_lines)

trainer.save_model("../models/distilbert_multilabel")

Imports and training (TF-IDF + OneVsRest LogisticRegression)

In [8]:
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# unigram + bigram TF-IDF features, capped at 50,000 terms, English stop words removed
vectorizer = TfidfVectorizer(
    max_features=50000,
    ngram_range=(1, 2),
    lowercase=True,
    stop_words='english',
)
# logistic regression wrapped in a one-vs-rest classifier
label_model = OneVsRestClassifier(
    LogisticRegression(max_iter=200, C=2.0, solver='liblinear')
)

tfidf_clf = Pipeline(steps=[('tfidf', vectorizer), ('ovr', label_model)])

tfidf_clf.fit(X_train, y_train)

0,1,2
,steps,"[('tfidf', ...), ('ovr', ...)]"
,transform_input,
,memory,
,verbose,False

0,1,2
,input,'content'
,encoding,'utf-8'
,decode_error,'strict'
,strip_accents,
,lowercase,True
,preprocessor,
,tokenizer,
,analyzer,'word'
,stop_words,'english'
,token_pattern,'(?u)\\b\\w\\w+\\b'

0,1,2
,estimator,LogisticRegre...r='liblinear')
,n_jobs,
,verbose,0

0,1,2
,penalty,'l2'
,dual,False
,tol,0.0001
,C,2.0
,fit_intercept,True
,intercept_scaling,1
,class_weight,
,random_state,
,solver,'liblinear'
,max_iter,200


Evaluation

In [9]:
from sklearn.metrics import classification_report, f1_score, precision_score, recall_score

y_pred = tfidf_clf.predict(X_test)

# all aggregates are computed with zero_division=0
micro_opts = {'average': 'micro', 'zero_division': 0}
micro_f1 = f1_score(y_test, y_pred, **micro_opts)
macro_f1 = f1_score(y_test, y_pred, average='macro', zero_division=0)
micro_precision = precision_score(y_test, y_pred, **micro_opts)
micro_recall = recall_score(y_test, y_pred, **micro_opts)
# per-label breakdown, rows named after mlb.classes_
report = classification_report(
    y_test,
    y_pred,
    target_names=list(mlb.classes_),
    zero_division=0,
)

for metric_name, metric_value in (
    ("micro_f1", micro_f1),
    ("macro_f1", macro_f1),
    ("micro_precision", micro_precision),
    ("micro_recall", micro_recall),
):
    print(f"{metric_name}: {metric_value:.4f}")
print()
print(report)

micro_f1: 0.4772
macro_f1: 0.1794
micro_precision: 0.7261
micro_recall: 0.3554

              precision    recall  f1-score   support

      Action       0.76      0.44      0.56       543
       Adult       0.00      0.00      0.00         0
   Adventure       0.84      0.27      0.41       360
   Animation       0.00      0.00      0.00       110
   Biography       0.95      0.13      0.22       151
      Comedy       0.70      0.47      0.56       722
       Crime       0.79      0.38      0.51       411
 Documentary       0.00      0.00      0.00        36
       Drama       0.72      0.76      0.74      1142
      Family       0.00      0.00      0.00       107
     Fantasy       1.00      0.01      0.01       146
     History       0.00      0.00      0.00        74
      Horror       0.82      0.19      0.31       262
       Music       1.00      0.08      0.14        64
     Musical       0.00      0.00      0.00        20
     Mystery       0.50      0.07      0.13       220
 

Persist model and metrics

In [10]:
import os
import pickle

os.makedirs('../models', exist_ok=True)

# store the fitted pipeline together with mlb in a single pickle
bundle = {'pipeline': tfidf_clf, 'mlb': mlb}
with open('../models/baseline_tfidf_logreg.pkl', 'wb') as f:
    pickle.dump(bundle, f)

# headline metrics (unrounded), a blank line, then the per-class report
summary_lines = [
    f"micro_f1: {micro_f1}",
    f"macro_f1: {macro_f1}",
    f"micro_precision: {micro_precision}",
    f"micro_recall: {micro_recall}",
    "",
    report,
]
with open('../models/metrics_tfidf.txt', 'w') as f:
    f.write("\n".join(summary_lines))

Embeddings baseline (SBERT + OneVsRest)

In [12]:
import numpy as np

try:
    from sentence_transformers import SentenceTransformer
except ImportError as err:
    # Only a failed import warrants the install hint. Any other error propagates
    # unchanged instead of being hidden behind a misleading message, and the
    # original cause stays attached via `from err`.
    raise RuntimeError("Please install sentence-transformers: pip install sentence-transformers") from err

# sentence-embedding checkpoint; the name is also stored next to the fitted classifier
sbert_model_name = 'sentence-transformers/all-MiniLM-L6-v2'
sbert = SentenceTransformer(sbert_model_name)

def encode_sbert(texts, batch_size=256):
    """Embed `texts` with the module-level `sbert` model, one chunk at a time.

    Parameters
    ----------
    texts : sequence of str
        Documents to embed; must support `len()` and slicing.
    batch_size : int, default 256
        Number of texts handed to `sbert.encode` per call.

    Returns
    -------
    numpy.ndarray
        The per-chunk embeddings stacked row-wise in input order (encoded with
        `normalize_embeddings=True`). An empty input yields an array with zero
        rows instead of raising.
    """
    if len(texts) == 0:
        # np.vstack([]) raises ValueError, so short-circuit with an empty matrix.
        # float32 is assumed to match the encoder's output dtype -- TODO confirm.
        width = sbert.get_sentence_embedding_dimension() or 0
        return np.empty((0, width), dtype=np.float32)

    chunks = [
        sbert.encode(
            list(texts[start:start + batch_size]),
            convert_to_numpy=True,
            normalize_embeddings=True,
            show_progress_bar=False,
        )
        for start in range(0, len(texts), batch_size)
    ]
    return np.vstack(chunks)

# embed both splits (train first, then test)
X_train_emb, X_test_emb = encode_sbert(X_train), encode_sbert(X_test)

# one-vs-rest logistic regression fitted on the embeddings
base_estimator = LogisticRegression(max_iter=300, C=2.0, solver='liblinear')
sbert_clf = OneVsRestClassifier(base_estimator)
sbert_clf.fit(X_train_emb, y_train)

0,1,2
,estimator,LogisticRegre...r='liblinear')
,n_jobs,
,verbose,0

0,1,2
,penalty,'l2'
,dual,False
,tol,0.0001
,C,2.0
,fit_intercept,True
,intercept_scaling,1
,class_weight,
,random_state,
,solver,'liblinear'
,max_iter,300


In [13]:
y_pred_sbert = sbert_clf.predict(X_test_emb)

# all aggregates are computed with zero_division=0
micro_opts_sbert = {'average': 'micro', 'zero_division': 0}
micro_f1_sbert = f1_score(y_test, y_pred_sbert, **micro_opts_sbert)
macro_f1_sbert = f1_score(y_test, y_pred_sbert, average='macro', zero_division=0)
micro_precision_sbert = precision_score(y_test, y_pred_sbert, **micro_opts_sbert)
micro_recall_sbert = recall_score(y_test, y_pred_sbert, **micro_opts_sbert)
report_sbert = classification_report(
    y_test,
    y_pred_sbert,
    target_names=list(mlb.classes_),
    zero_division=0,
)

# (name, value) pairs reused for both the printed summary and the metrics file
headline_sbert = [
    ("micro_f1", micro_f1_sbert),
    ("macro_f1", macro_f1_sbert),
    ("micro_precision", micro_precision_sbert),
    ("micro_recall", micro_recall_sbert),
]

for metric_name, metric_value in headline_sbert:
    print(f"SBERT {metric_name}: {metric_value:.4f}")
print()
print(report_sbert)

# the classifier is pickled with mlb and the encoder name
sbert_bundle = {'clf': sbert_clf, 'mlb': mlb, 'sbert_model_name': sbert_model_name}
with open('../models/baseline_sbert_logreg.pkl', 'wb') as f:
    pickle.dump(sbert_bundle, f)

# unrounded metrics, a blank line, then the per-class report
metrics_text = "\n".join(
    [f"{metric_name}: {metric_value}" for metric_name, metric_value in headline_sbert]
    + ["", report_sbert]
)
with open('../models/metrics_sbert.txt', 'w') as f:
    f.write(metrics_text)

SBERT micro_f1: 0.5718
SBERT macro_f1: 0.3468
SBERT micro_precision: 0.7063
SBERT micro_recall: 0.4803

              precision    recall  f1-score   support

      Action       0.75      0.59      0.66       543
       Adult       0.00      0.00      0.00         0
   Adventure       0.74      0.50      0.60       360
   Animation       0.53      0.09      0.16       110
   Biography       0.75      0.35      0.48       151
      Comedy       0.70      0.57      0.63       722
       Crime       0.68      0.50      0.58       411
 Documentary       1.00      0.19      0.33        36
       Drama       0.73      0.78      0.76      1142
      Family       0.64      0.08      0.15       107
     Fantasy       0.52      0.08      0.14       146
     History       0.59      0.22      0.32        74
      Horror       0.70      0.43      0.53       262
       Music       0.90      0.42      0.57        64
     Musical       0.00      0.00      0.00        20
     Mystery       0.55      0.