# Here we use P-tuning as a baseline for multi-label genre classification


In [1]:
import sys
import os

# Make the parent directory importable so the `src.*` modules used by the
# next cell resolve. The membership check keeps the cell idempotent: re-running
# it in the same kernel does not append duplicate entries to sys.path.
PROJECT_ROOT = os.path.abspath('../')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import json
import random
import re

import numpy as np

# Configure the CUDA caching allocator before torch (and transformers, which
# imports torch) is loaded, so the setting is in place before any allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
# Release cached-but-unused GPU blocks (relevant when this cell is re-run in a live kernel).
torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import classification_report, f1_score, precision_score, recall_score

from src.utils import logger, DatasetTypes
from src.data import init_data
from src.ptune import prepare_ptune, train, MultiLabelClassifier
from src.model import get_pretrained

# Seed every RNG the pipeline may draw from (parameter initialisation, and any
# shuffling done with python/numpy/torch RNGs) so runs are repeatable.
# NOTE: bit-exact GPU determinism additionally depends on cuDNN/kernel settings.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA device generators

device

  from .autonotebook import tqdm as notebook_tqdm


'cuda'

In [3]:
model_name = "Qwen/Qwen3-0.6B"  # Hugging Face Hub id of the backbone

# Load the tokenizer and base model onto `device`; the P-tuning cell below
# adds the (small) set of trainable parameters on top of this model.
tokenizer, base_model = get_pretrained(
    model_name,
    device,
)

## Load the dataset (all genres, 1,294,054 examples)

In [4]:
path_to_csv = '../data/top_genres.csv'

# init_data returns the three dataset splits, the label vocabulary
# (presumably an index -> genre mapping plus the genre list) and one batch
# loader per split.
(
    train_dataset, val_dataset, test_dataset,
    idx2genre, genres,
    train_loader, val_loader, test_loader,
) = init_data(path_to_csv=path_to_csv, batch_size=16, tokenizer=tokenizer)

In [5]:
# Wrap the base model with the P-tuning adapter, then attach a classifier head
# sized by the backbone's hidden size with one output per genre.
peft_model = prepare_ptune(model=base_model, model_name=model_name, genres=genres, device=device)
hidden_size = base_model.config.hidden_size
num_labels = len(genres)
# Move to the `device` chosen in the setup cell instead of hard-coding "cuda",
# so the notebook also runs on CPU-only machines.
model = MultiLabelClassifier(peft_model, hidden_size, num_labels).to(device)

trainable params: 30,720 || all params: 596,080,640 || trainable%: 0.0052


In [6]:
# Training hyper-parameters.
# NOTE(review): `warmup_steps` is not passed to `train(...)` in the training
# cell, so it currently has no effect. Either wire it in (if `train` supports a
# warmup schedule) or remove it.
num_epochs, warmup_steps, learning_rate = 100, 500, 5e-5

In [7]:
# Fit the P-tuning parameters and classifier head. `train` prints the per-epoch
# train loss plus validation macro-F1 / label accuracy, and its return value
# replaces `model` for the evaluation cells that follow.
# NOTE(review): `warmup_steps` (defined above) is not forwarded here.
model = train(
    model=model,
    idx2genre=idx2genre,
    learning_rate=learning_rate,
    num_epochs=num_epochs,
    tokenizer=tokenizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
)

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Epoch 1/100 - Train loss: 1.0756
Epoch 1/100 - Val macro F1: 0.1123, Label acc: 0.5350
Epoch 2/100 - Train loss: 1.0617
Epoch 2/100 - Val macro F1: 0.1206, Label acc: 0.5400
Epoch 3/100 - Train loss: 1.0317
Epoch 3/100 - Val macro F1: 0.1216, Label acc: 0.5525
Epoch 4/100 - Train loss: 0.9797
Epoch 4/100 - Val macro F1: 0.1155, Label acc: 0.5733
Epoch 5/100 - Train loss: 0.8948
Epoch 5/100 - Val macro F1: 0.1402, Label acc: 0.6100
Epoch 6/100 - Train loss: 0.8016
Epoch 6/100 - Val macro F1: 0.1409, Label acc: 0.6433
Epoch 7/100 - Train loss: 0.7209
Epoch 7/100 - Val macro F1: 0.1331, Label acc: 0.6575
Epoch 8/100 - Train loss: 0.6438
Epoch 8/100 - Val macro F1: 0.1353, Label acc: 0.6892
Epoch 9/100 - Train loss: 0.5605
Epoch 9/100 - Val macro F1: 0.1322, Label acc: 0.7283
Epoch 10/100 - Train loss: 0.4886
Epoch 10/100 - Val macro F1: 0.1313, Label acc: 0.7783
Epoch 11/100 - Train loss: 0.4416
Epoch 11/100 - Val macro F1: 0.1192, Label acc: 0.8033
Epoch 12/100 - Train loss: 0.4069
Epoch

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 36/100 - Val macro F1: 0.1486, Label acc: 0.8567
Epoch 37/100 - Train loss: 0.2692
Epoch 37/100 - Val macro F1: 0.1561, Label acc: 0.8508
Epoch 38/100 - Train loss: 0.2709


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 38/100 - Val macro F1: 0.1506, Label acc: 0.8558
Epoch 39/100 - Train loss: 0.2650


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 39/100 - Val macro F1: 0.1428, Label acc: 0.8517
Epoch 40/100 - Train loss: 0.2641
Epoch 40/100 - Val macro F1: 0.1538, Label acc: 0.8525
Epoch 41/100 - Train loss: 0.2646


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 41/100 - Val macro F1: 0.1403, Label acc: 0.8550
Epoch 42/100 - Train loss: 0.2564
Epoch 42/100 - Val macro F1: 0.1559, Label acc: 0.8625
Epoch 43/100 - Train loss: 0.2528
Epoch 43/100 - Val macro F1: 0.1381, Label acc: 0.8500
Epoch 44/100 - Train loss: 0.2493
Epoch 44/100 - Val macro F1: 0.1406, Label acc: 0.8525
Epoch 45/100 - Train loss: 0.2488


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 45/100 - Val macro F1: 0.1435, Label acc: 0.8550
Epoch 46/100 - Train loss: 0.2436


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 46/100 - Val macro F1: 0.1559, Label acc: 0.8575
Epoch 47/100 - Train loss: 0.2414


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 47/100 - Val macro F1: 0.1520, Label acc: 0.8558
Epoch 48/100 - Train loss: 0.2397


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 48/100 - Val macro F1: 0.1574, Label acc: 0.8567
Epoch 49/100 - Train loss: 0.2386


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 49/100 - Val macro F1: 0.1402, Label acc: 0.8508
Epoch 50/100 - Train loss: 0.2357


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 50/100 - Val macro F1: 0.1361, Label acc: 0.8508
Epoch 51/100 - Train loss: 0.2357


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 51/100 - Val macro F1: 0.1518, Label acc: 0.8600
Epoch 52/100 - Train loss: 0.2322


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 52/100 - Val macro F1: 0.1447, Label acc: 0.8550
Epoch 53/100 - Train loss: 0.2303
Epoch 53/100 - Val macro F1: 0.1486, Label acc: 0.8550
Epoch 54/100 - Train loss: 0.2292


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 54/100 - Val macro F1: 0.1493, Label acc: 0.8600
Epoch 55/100 - Train loss: 0.2281


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 55/100 - Val macro F1: 0.1512, Label acc: 0.8567
Epoch 56/100 - Train loss: 0.2248


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 56/100 - Val macro F1: 0.1489, Label acc: 0.8550
Epoch 57/100 - Train loss: 0.2260


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 57/100 - Val macro F1: 0.1493, Label acc: 0.8550
Epoch 58/100 - Train loss: 0.2226


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 58/100 - Val macro F1: 0.1551, Label acc: 0.8583
Epoch 59/100 - Train loss: 0.2218


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 59/100 - Val macro F1: 0.1523, Label acc: 0.8592
Epoch 60/100 - Train loss: 0.2206


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 60/100 - Val macro F1: 0.1554, Label acc: 0.8558
Epoch 61/100 - Train loss: 0.2193


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 61/100 - Val macro F1: 0.1549, Label acc: 0.8583
Epoch 62/100 - Train loss: 0.2186


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 62/100 - Val macro F1: 0.1535, Label acc: 0.8592
Epoch 63/100 - Train loss: 0.2167


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 63/100 - Val macro F1: 0.1481, Label acc: 0.8550
Epoch 64/100 - Train loss: 0.2150


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 64/100 - Val macro F1: 0.1497, Label acc: 0.8633
Epoch 65/100 - Train loss: 0.2141


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 65/100 - Val macro F1: 0.1574, Label acc: 0.8600
Epoch 66/100 - Train loss: 0.2135


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 66/100 - Val macro F1: 0.1437, Label acc: 0.8567
Epoch 67/100 - Train loss: 0.2125


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 67/100 - Val macro F1: 0.1490, Label acc: 0.8600
Epoch 68/100 - Train loss: 0.2120


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 68/100 - Val macro F1: 0.1493, Label acc: 0.8575
Epoch 69/100 - Train loss: 0.2124
Epoch 69/100 - Val macro F1: 0.1604, Label acc: 0.8583
Epoch 70/100 - Train loss: 0.2105


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 70/100 - Val macro F1: 0.1537, Label acc: 0.8592
Epoch 71/100 - Train loss: 0.2098


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 71/100 - Val macro F1: 0.1533, Label acc: 0.8600
Epoch 72/100 - Train loss: 0.2081


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 72/100 - Val macro F1: 0.1469, Label acc: 0.8550
Epoch 73/100 - Train loss: 0.2095


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 73/100 - Val macro F1: 0.1568, Label acc: 0.8617
Epoch 74/100 - Train loss: 0.2090


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 74/100 - Val macro F1: 0.1491, Label acc: 0.8575
Epoch 75/100 - Train loss: 0.2073
Epoch 75/100 - Val macro F1: 0.1433, Label acc: 0.8533
Epoch 76/100 - Train loss: 0.2064


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 76/100 - Val macro F1: 0.1537, Label acc: 0.8600
Epoch 77/100 - Train loss: 0.2064


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 77/100 - Val macro F1: 0.1568, Label acc: 0.8592
Epoch 78/100 - Train loss: 0.2049
Epoch 78/100 - Val macro F1: 0.1472, Label acc: 0.8575
Epoch 79/100 - Train loss: 0.2048


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 79/100 - Val macro F1: 0.1586, Label acc: 0.8567
Epoch 80/100 - Train loss: 0.2048


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 80/100 - Val macro F1: 0.1548, Label acc: 0.8625
Epoch 81/100 - Train loss: 0.2048


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 81/100 - Val macro F1: 0.1521, Label acc: 0.8567
Epoch 82/100 - Train loss: 0.2035
Epoch 82/100 - Val macro F1: 0.1483, Label acc: 0.8575
Epoch 83/100 - Train loss: 0.2038


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 83/100 - Val macro F1: 0.1550, Label acc: 0.8642
Epoch 84/100 - Train loss: 0.2030


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 84/100 - Val macro F1: 0.1583, Label acc: 0.8617
Epoch 85/100 - Train loss: 0.2025


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 85/100 - Val macro F1: 0.1514, Label acc: 0.8550
Epoch 86/100 - Train loss: 0.2026


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 86/100 - Val macro F1: 0.1628, Label acc: 0.8575
Epoch 87/100 - Train loss: 0.2025


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 87/100 - Val macro F1: 0.1695, Label acc: 0.8650
Epoch 88/100 - Train loss: 0.2022


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 88/100 - Val macro F1: 0.1562, Label acc: 0.8600
Epoch 89/100 - Train loss: 0.2021


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 89/100 - Val macro F1: 0.1572, Label acc: 0.8633
Epoch 90/100 - Train loss: 0.2017


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 90/100 - Val macro F1: 0.1583, Label acc: 0.8600
Epoch 91/100 - Train loss: 0.2022


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 91/100 - Val macro F1: 0.1555, Label acc: 0.8575
Epoch 92/100 - Train loss: 0.2018


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 92/100 - Val macro F1: 0.1494, Label acc: 0.8575
Epoch 93/100 - Train loss: 0.2015


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 93/100 - Val macro F1: 0.1494, Label acc: 0.8592
Epoch 94/100 - Train loss: 0.2012


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 94/100 - Val macro F1: 0.1661, Label acc: 0.8625
Epoch 95/100 - Train loss: 0.2011


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 95/100 - Val macro F1: 0.1664, Label acc: 0.8608
Epoch 96/100 - Train loss: 0.2014


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 96/100 - Val macro F1: 0.1668, Label acc: 0.8650
Epoch 97/100 - Train loss: 0.2010


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 97/100 - Val macro F1: 0.1421, Label acc: 0.8533
Epoch 98/100 - Train loss: 0.2011


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 98/100 - Val macro F1: 0.1544, Label acc: 0.8608
Epoch 99/100 - Train loss: 0.2009


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 99/100 - Val macro F1: 0.1477, Label acc: 0.8575
Epoch 100/100 - Train loss: 0.2015
Epoch 100/100 - Val macro F1: 0.1553, Label acc: 0.8617


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [10]:
# Tune a single global decision threshold for the multi-label sigmoid outputs.
# NOTE(review): the threshold is both selected and scored on the validation
# split, so `best_f1` is optimistically biased; report final numbers on
# `test_loader` using `best_threshold`.
# (numpy, torch and f1_score are imported in the setup cell.)

model.eval()

# Candidate thresholds 0.10, 0.15, ..., 0.90; rounding removes arange float drift.
thresholds = np.round(np.arange(0.1, 0.91, 0.05), 2)

all_val_preds = []   # per-batch sigmoid probabilities (CPU tensors)
all_val_labels = []  # per-batch targets (CPU tensors)

# Run the model once over the validation set; the sweep below only re-thresholds.
with torch.no_grad():
    for batch in val_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Labels are only consumed by sklearn on the CPU, so skip the GPU round-trip.
        labels = batch['labels'].cpu()

        logits = model(input_ids, attention_mask)
        # Upcast before the sigmoid so probabilities are computed in float32
        # even if the model runs in reduced precision.
        probs = torch.sigmoid(logits.float())

        all_val_preds.append(probs.cpu())
        all_val_labels.append(labels)

# Concatenate all batches into (num_examples, num_labels) arrays.
all_val_probs = torch.cat(all_val_preds, dim=0).numpy()
all_val_labels = torch.cat(all_val_labels, dim=0).to(torch.float32).numpy()

# Macro F1 for every candidate threshold.
threshold_f1_scores = []
for threshold in thresholds:
    preds = (all_val_probs > threshold).astype(int)
    f1 = f1_score(all_val_labels, preds, average='macro', zero_division=0)
    threshold_f1_scores.append((threshold, f1))

# max() returns the first (lowest) threshold among ties, matching a strict `>`
# scan, and always picks a threshold that was actually evaluated.
best_threshold, best_f1 = max(threshold_f1_scores, key=lambda pair: pair[1])

# Report
print(f"\nBest threshold: {best_threshold:.2f} with macro F1-score: {best_f1:.4f}")
print("Threshold sweep results:")
for t, f1 in threshold_f1_scores:
    print(f"Threshold = {t:.2f} --> Macro F1 = {f1:.4f}")



Best threshold: 0.15 with macro F1-score: 0.1911
Threshold sweep results:
Threshold = 0.10 --> Macro F1 = 0.1853
Threshold = 0.15 --> Macro F1 = 0.1911
Threshold = 0.20 --> Macro F1 = 0.1909
Threshold = 0.25 --> Macro F1 = 0.1836
Threshold = 0.30 --> Macro F1 = 0.1807
Threshold = 0.35 --> Macro F1 = 0.1689
Threshold = 0.40 --> Macro F1 = 0.1654
Threshold = 0.45 --> Macro F1 = 0.1590
Threshold = 0.50 --> Macro F1 = 0.1553
Threshold = 0.55 --> Macro F1 = 0.1432
Threshold = 0.60 --> Macro F1 = 0.1273
Threshold = 0.65 --> Macro F1 = 0.1265
Threshold = 0.70 --> Macro F1 = 0.1169
Threshold = 0.75 --> Macro F1 = 0.1179
Threshold = 0.80 --> Macro F1 = 0.0966
Threshold = 0.85 --> Macro F1 = 0.0858
Threshold = 0.90 --> Macro F1 = 0.0529
