In [None]:
import os
import random

import numpy as np
import pandas as pd
import torch
import transformers
import wandb
from datasets import Dataset, load_dataset
from sklearn import metrics
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.preprocessing import MultiLabelBinarizer
from torch import cuda
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
from tqdm.notebook import tqdm
from transformers import (
    AutoTokenizer,
    BertConfig,
    BertForSequenceClassification,
    BertModel,
    BertTokenizer,
    DataCollatorWithPadding,
)

# Seed every RNG the notebook touches so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = "cuda" if cuda.is_available() else "cpu"

In [None]:
# Load the raw training table (column headers are in Russian; renamed below).
train_csv_path = "data/train.csv"
df = pd.read_csv(train_csv_path)
df

Unnamed: 0,Фильм,Описание,Сюжет,Жанры
0,Дивергент (2014),"Действие фильма «Дивергент» происходит в мире,...","Действие фильма «Дивергент» происходит в мире,...","фантастика, детектив, боевик, мелодрама"
1,Кунг-фу Панда 4 (2024),Однажды ночью на вершине горы возле каменоломн...,Однажды ночью на вершине горы возле каменоломн...,"мультфильм, фэнтези, боевик, комедия, приключения"
2,2046 (2004),Чоу возвращается в Гонконг после нескольких ле...,Чоу возвращается в Гонконг после нескольких ле...,"фантастика, драма, мелодрама"
3,Полицейский из Беверли-Хиллз: Аксель Фоули (2024),Аксель Фоули вернулся в Беверли-Хиллз после то...,Аксель Фоули вернулся в Беверли-Хиллз после то...,"боевик, комедия, криминал, детектив"
4,"Знакомьтесь, Джо Блэк (1998)","История об Ангеле Смерти, который решает взять...","История об Ангеле Смерти, который решает взять...","мелодрама, фэнтези, драма"
...,...,...,...,...
566,Апокалипсис (2006),В 1517 году на полуострове Юкатан племя Лапы Я...,1517 год. Полуостров Юкатан. Группа охотников ...,"боевик, триллер, драма, приключения"
567,Лёд 3 (2024),"Надя, ставшая фигуристкой, стремится выиграть ...",Фильм начинается с истории взросления дочери А...,"мюзикл, мелодрама"
568,Дастур (2023),"Новоиспеченная невеста, которую выдали замуж п...",,"ужасы, фантастика"
569,Не говори никому (2024),Пара вместе с дочерью получают приглашение от ...,Пара вместе с дочерью получают приглашение от ...,"триллер, драма"


In [None]:
# Translate the Russian column headers to English identifiers.
df = df.rename(
    columns={
        "Фильм": "movie",
        "Сюжет": "plot",
        "Жанры": "genres",
        "Описание": "description",
    },
)
# Some rows have no plot text; fall back to the description for those rows.
df["plot"] = df["plot"].fillna(df["description"])
# `rename` silently ignores keys that are missing, so fail fast here if the CSV
# schema changed. This check is idempotent and also passes when the cell is re-run.
assert {"movie", "description", "plot", "genres"} <= set(df.columns), (
    f"unexpected columns: {list(df.columns)}"
)
df.isnull().any()

movie          False
description    False
plot           False
genres         False
dtype: bool

In [None]:
# Genres arrive as one comma-separated string per film. Split on "," and strip
# each piece, so inconsistent spacing ("a,b", "a, b ") does not create spurious
# extra classes. Empty fragments are dropped.
df["genres"] = df["genres"].apply(
    lambda s: [g.strip() for g in s.split(",") if g.strip()]
)
mlb = MultiLabelBinarizer()
# Multi-hot matrix with one column per genre, in the order of mlb.classes_.
y = mlb.fit_transform(df["genres"])
df["target"] = y.tolist()
# Only the text columns and the multi-hot target are needed from here on.
df = df.drop(columns=["genres", "movie"])

In [None]:
df["description"].str.len().max()  # longest description, in characters

np.int64(1719)

In [None]:
df["plot"].str.len().max()  # longest plot, in characters

np.int64(15174)

In [None]:
# --- Configuration ----------------------------------------------------------
# Tokenizer truncation limit for each (description, plot) pair.
max_length = 2048
num_labels = len(mlb.classes_)

# Hugging Face convention: id2label maps class index -> name, and label2id maps
# name -> class index. The index order matches the MultiLabelBinarizer target columns.
class_names = [str(c) for c in mlb.classes_]
id2label = dict(enumerate(class_names))
label2id = {name: idx for idx, name in id2label.items()}

# NOTE: `model` holds the checkpoint *name* here; the model-loading cell
# rebinds it to the loaded model object.
model = "cointegrated/rubert-tiny2"
problem_type = "multi_label_classification"

# DataLoader settings.
batch_size = 16
pin_memory = False
drop_last = False
num_workers = 4
shuffle = True

# Fraction of rows held out for validation, converted to an absolute row count.
val_fraction = 0.2
val_size = int(len(df) * val_fraction)
num_epochs = 30

# Turn off HF tokenizers' internal parallelism. The DataLoaders use worker
# processes (num_workers > 0), which triggers fork warnings otherwise. Also
# silence transformers advisory warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"

In [None]:
# `model` starts out as the checkpoint name (config cell), but this cell rebinds
# it to the loaded model. Resolve the name first so the cell stays re-runnable.
model_checkpoint = model if isinstance(model, str) else model.name_or_path

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Pretrained encoder plus a freshly initialised classification head sized to
# the genre set, trained as a multi-label (sigmoid/BCE) problem.
model = BertForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
    problem_type=problem_type,
    label2id=label2id,
    id2label=id2label,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cointegrated/rubert-tiny2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
def _encode(batch):
    """Tokenize (description, plot) pairs and attach float multi-hot labels."""
    encoded = tokenizer(
        batch["description"], batch["plot"], truncation=True, max_length=max_length
    )
    # BCE-style multi-label loss needs float targets, not ints.
    encoded["label"] = [[float(v) for v in row] for row in batch["target"]]
    return encoded


dataset = Dataset.from_pandas(df).map(
    _encode,
    batched=True,
    remove_columns=["description", "plot", "target"],
)

Map:   0%|          | 0/571 [00:00<?, ? examples/s]

Map:   0%|          | 0/571 [00:00<?, ? examples/s]

In [None]:
# Pads each batch dynamically to the longest sequence in that batch.
data_collator = DataCollatorWithPadding(tokenizer)

# Hold out `val_size` rows for validation (an absolute count from the config
# cell). The explicit seed makes the split reproducible regardless of how much
# of the global NumPy RNG earlier cells consumed.
train_val = dataset.train_test_split(test_size=val_size, seed=42)

train_dataloader = DataLoader(
    train_val["train"],
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    collate_fn=data_collator,
    pin_memory=pin_memory,
    drop_last=drop_last,
)

val_dataloader = DataLoader(
    train_val["test"],
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=data_collator,
    pin_memory=pin_memory,
    # Never drop validation samples, whatever the training setting is.
    drop_last=False,
)

In [None]:
# Default probability cut-off for predicting a genre. A low value favours
# recall over precision.
THRESHOLD = 0.1


def get_report_multilabel(y_true, y_pred, target_names, output_dict, threshold=THRESHOLD):
    """Build an sklearn classification report for multi-label predictions.

    Args:
        y_true: multi-hot ground truth, expected shape (n_samples, n_classes).
        y_pred: per-class probabilities with the same shape as ``y_true``.
        target_names: class names in column order (any iterable).
        output_dict: if True, return a dict; otherwise a formatted string.
        threshold: a class counts as predicted when its probability is
            ``>= threshold``. Defaults to the module-level ``THRESHOLD``.

    Returns:
        The report as returned by ``classification_report``. Metrics that are
        undefined (e.g. no predicted positives) are reported as 0.
    """
    return classification_report(
        y_true,
        np.asarray(y_pred) >= threshold,
        target_names=list(target_names),
        output_dict=output_dict,
        zero_division=0,
    )

In [None]:
def predict(model, dataloader):
    """Run ``model`` over ``dataloader`` without gradients and collect outputs.

    Puts the model in eval mode first, so dropout is disabled during inference.

    Args:
        model: a model with a ``device`` attribute. Its forward pass must return
            ``.loss`` and ``.logits`` when the batch contains ``labels``.
        dataloader: yields collated batches with ``input_ids`` and ``labels``.

    Returns:
        Tuple ``(y_true, y_prob, mean_loss)``. ``y_true`` holds the concatenated
        labels and ``y_prob`` the sigmoid of the logits, both as numpy arrays.
        ``mean_loss`` is the per-sample average loss.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()
    y_true, y_logits = [], []
    loss_sum = 0.0
    n_seen = 0

    with torch.inference_mode():
        # leave=False: the bar disappears when finished, so calling this once
        # per epoch does not flood the notebook output.
        for batch in tqdm(dataloader, leave=False):
            batch = batch.to(model.device)
            output = model(**batch)

            n = batch["input_ids"].size(0)
            # Assumes output.loss is averaged over the batch. Re-weighting by
            # batch size keeps the final average per sample, even when the last
            # batch is smaller or batches were dropped.
            loss_sum += output.loss.item() * n
            n_seen += n
            y_true.append(batch["labels"].cpu())
            y_logits.append(output.logits.cpu())

    if n_seen == 0:
        raise ValueError("dataloader yielded no batches")

    return (
        torch.cat(y_true).numpy(),
        torch.sigmoid(torch.cat(y_logits)).numpy(),
        loss_sum / n_seen,
    )


def train_epoch(model, train_dataloader, optimizer):
    """Run one optimisation pass over ``train_dataloader``.

    Puts the model in train mode first (mirrors ``predict`` switching to eval).

    Args:
        model: a model with a ``device`` attribute. Its forward pass must return
            ``.loss`` and ``.logits`` when the batch contains ``labels``.
        train_dataloader: yields collated batches with ``input_ids`` and ``labels``.
        optimizer: a torch optimizer over ``model``'s parameters.

    Returns:
        Tuple ``(y_true, y_prob, mean_loss)``. ``y_true`` holds the labels and
        ``y_prob`` the sigmoid probabilities, both as numpy arrays, collected
        while training. ``mean_loss`` is the per-sample average training loss.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.train()
    y_true, y_logits = [], []
    loss_sum = 0.0
    n_seen = 0

    # leave=False keeps one finished bar per epoch from piling up in the output.
    for batch in tqdm(train_dataloader, leave=False):
        batch = batch.to(model.device)
        optimizer.zero_grad()
        output = model(**batch)
        output.loss.backward()
        optimizer.step()

        n = batch["input_ids"].size(0)
        # Re-weight the (assumed batch-mean) loss so the epoch average is per
        # sample actually seen, which stays correct with drop_last=True.
        loss_sum += output.loss.item() * n
        n_seen += n
        y_true.append(batch["labels"].detach().cpu())
        y_logits.append(output.logits.detach().cpu())

    if n_seen == 0:
        raise ValueError("train_dataloader yielded no batches")

    return (
        torch.cat(y_true).numpy(),
        torch.sigmoid(torch.cat(y_logits)).numpy(),
        loss_sum / n_seen,
    )


def eval(model, val_dataloader, labels):
    """Evaluate ``model`` on ``val_dataloader`` and return a rounded report.

    Note: this name shadows Python's built-in ``eval`` inside the notebook.

    Args:
        model: the classifier to evaluate (switched to eval mode here).
        val_dataloader: loader passed straight to ``predict``.
        labels: class names in target-column order, used as report columns.

    Returns:
        pandas.DataFrame built from the classification report dict, rounded to
        2 decimals.
    """
    model.eval()
    y_true, y_prob, _ = predict(model, val_dataloader)
    report = get_report_multilabel(y_true, y_prob, labels, True)
    return pd.DataFrame(report).round(2)

In [None]:
def train(model, train_dataloader, optimizer, epochs, val_dataloader, labels):
    """Fine-tune ``model`` for ``epochs`` epochs, validating after each one.

    The epoch progress bar shows the latest train/val loss. After the last epoch
    a classification report on ``val_dataloader`` is printed (rounded to 2 decimals).

    Args:
        model: the classifier to train.
        train_dataloader: loader for the training split.
        optimizer: torch optimizer over ``model``'s parameters.
        epochs: number of full passes over ``train_dataloader``; must be >= 1.
        val_dataloader: loader evaluated after every epoch.
        labels: class names in target-column order, used as report columns.

    Returns:
        ``(val_y_true, val_y_pred)`` from the final epoch's validation pass:
        targets and sigmoid probabilities as numpy arrays.

    Raises:
        ValueError: if ``epochs < 1``, because there would be no validation
            predictions to return.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    tq = tqdm(range(epochs))
    for _ in tq:
        model.train()
        _, _, train_loss = train_epoch(model, train_dataloader, optimizer)

        model.eval()
        val_y_true, val_y_pred, val_loss = predict(model, val_dataloader)

        tq.set_description(f"train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f}")

    # Build the final report from the last epoch's predictions instead of
    # running a second, identical validation pass.
    report = pd.DataFrame(
        get_report_multilabel(val_y_true, val_y_pred, labels, True)
    ).round(2)
    print(report)

    return val_y_true, val_y_pred

In [None]:
# Use the device picked in the setup cell. A hard-coded `.cuda()` would crash on
# CPU-only machines.
model.to(device)

learning_rate = 1e-5
optimizer = Adam(model.parameters(), lr=learning_rate)

val_y_true, val_y_pred = train(
    model=model,
    train_dataloader=train_dataloader,
    optimizer=optimizer,
    epochs=num_epochs,
    # Validate on the held-out split, not on the training data.
    val_dataloader=val_dataloader,
    # Genre names in the same column order as the multi-hot targets.
    labels=list(mlb.classes_),
)

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

  0%|          | 0/29 [00:00<?, ?it/s]

                0     1       2     3      4     5     6     7    8     9  \
precision    0.87   0.0    0.69   0.0    0.0   0.0   0.0   0.0  0.0   0.0   
recall       0.37   0.0    0.85   0.0    0.0   0.0   0.0   0.0  0.0   0.0   
f1-score     0.52   0.0    0.76   0.0    0.0   0.0   0.0   0.0  0.0   0.0   
support    146.00  64.0  238.00  52.0  143.0  95.0  69.0  37.0  6.0  94.0   

               10    11      12    13  micro avg  macro avg  weighted avg  \
precision    1.00   0.0    1.00   0.0       0.73       0.25          0.40   
recall       0.01   0.0    0.05   0.0       0.21       0.09          0.21   
f1-score     0.02   0.0    0.10   0.0       0.32       0.10          0.21   
support    119.00  45.0  100.00  68.0    1276.00    1276.00       1276.00   

           samples avg  
precision         0.56  
recall            0.23  
f1-score          0.31  
support        1276.00  


In [None]:
# Final per-genre report on the held-out split. `mlb.classes_` lists the genre
# names in target-column order and does not depend on how the label dicts are named.
# NOTE(review): this rebinds `df`, replacing the training DataFrame.
df = eval(model, val_dataloader, list(mlb.classes_))

  0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
# Display the validation report: one column per genre plus averages, and rows
# for precision / recall / f1-score / support.
df

Unnamed: 0,боевик,детектив,драма,история,комедия,криминал,мелодрама,мультфильм,мюзикл,приключения,триллер,ужасы,фантастика,фэнтези,micro avg,macro avg,weighted avg,samples avg
precision,0.33,0.11,0.4,0.1,0.37,0.18,0.13,0.11,0.0,0.27,0.28,0.14,0.21,0.2,0.22,0.2,0.26,0.22
recall,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.77,0.0,1.0,1.0,0.87,1.0,1.0,0.96,0.9,0.96,0.97
f1-score,0.5,0.19,0.57,0.18,0.54,0.31,0.23,0.19,0.0,0.43,0.44,0.25,0.35,0.34,0.36,0.32,0.4,0.35
support,38.0,12.0,46.0,11.0,42.0,21.0,15.0,13.0,7.0,31.0,32.0,15.0,24.0,23.0,330.0,330.0,330.0,330.0


In [None]:
# Multi-hot targets from the last validation pass inside `train()`, computed on
# whichever loader was passed as `val_dataloader`.
val_y_true

array([[1., 0., 1., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 1.],
       ...,
       [1., 0., 0., ..., 0., 0., 1.],
       [0., 0., 1., ..., 0., 0., 0.],
       [1., 0., 1., ..., 0., 1., 0.]], dtype=float32)

In [None]:
# Sigmoid probabilities from the same pass, one column per genre. Values at or
# above THRESHOLD count as predicted.
val_y_pred

array([[0.4493606 , 0.20008188, 0.46238035, ..., 0.15467286, 0.29485258,
        0.17997764],
       [0.5004476 , 0.22297907, 0.41778725, ..., 0.19219719, 0.3300212 ,
        0.19626734],
       [0.47234684, 0.2094067 , 0.42952403, ..., 0.1607923 , 0.34536922,
        0.27086797],
       ...,
       [0.52628845, 0.2193778 , 0.39998758, ..., 0.17677812, 0.38198584,
        0.2555664 ],
       [0.24974592, 0.1404527 , 0.5637468 , ..., 0.10028478, 0.16473237,
        0.11947478],
       [0.4581577 , 0.19072028, 0.4222496 , ..., 0.1653984 , 0.30586356,
        0.18352522]], dtype=float32)