# Here we will use ptune learning method as baseline for genre classification


In [1]:
import sys
import os

sys.path.append(os.path.abspath('../'))

In [2]:
import numpy as np

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from transformers import AutoModelForCausalLM, AutoTokenizer

from src.utils import logger, DatasetTypes
from src.data import init_data
from src.ptune import prepare_ptune, train
from src.metrics import GenrePredictorInterface, evaluate_model
from src.model import get_pretrained
import json
import re
# supported files in spython
device

  from .autonotebook import tqdm as notebook_tqdm


'cuda'

In [3]:
model_name = "Qwen/Qwen3-0.6B"
tokenizer, model = get_pretrained(model_name, device)

## Get dataset with all genres and 1,294,054 examples

In [4]:
path_to_csv = '../data/all_genres_downsampled.csv'
train_dataset, val_dataset, test_dataset , idx2genre, genres, train_loader, val_loader, test_loader = init_data(path_to_csv=path_to_csv, batch_size=16, tokenizer=tokenizer)

In [5]:
model = prepare_ptune(model=model, model_name=model_name, genres=genres, device=device)

trainable params: 30,720 || all params: 596,080,640 || trainable%: 0.0052


In [6]:
num_epochs = 100
warmup_steps = 500
learning_rate = 5e-5

In [7]:
model = train(model=model,
      idx2genre=idx2genre,
      learning_rate=learning_rate,
      num_epochs=num_epochs,
      tokenizer=tokenizer,
      train_loader=train_loader,
      val_loader=val_loader,
      device=device)

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Epoch 1/100 - Train loss: 14.2247
Epoch 1/100 - Val accuracy: 0.0000
Epoch 2/100 - Train loss: 14.1602
Epoch 2/100 - Val accuracy: 0.0000
Epoch 3/100 - Train loss: 14.0889
Epoch 3/100 - Val accuracy: 0.0000
Epoch 4/100 - Train loss: 13.9824
Epoch 4/100 - Val accuracy: 0.0000
Epoch 5/100 - Train loss: 13.8319
Epoch 5/100 - Val accuracy: 0.0000
Epoch 6/100 - Train loss: 13.7105
Epoch 6/100 - Val accuracy: 0.0000
Epoch 7/100 - Train loss: 13.5596
Epoch 7/100 - Val accuracy: 0.0000
Epoch 8/100 - Train loss: 13.4254
Epoch 8/100 - Val accuracy: 0.0000
Epoch 9/100 - Train loss: 13.1516
Epoch 9/100 - Val accuracy: 0.0000
Epoch 10/100 - Train loss: 12.8953
Epoch 10/100 - Val accuracy: 0.0000
Epoch 11/100 - Train loss: 12.5061
Epoch 11/100 - Val accuracy: 0.0000
Epoch 12/100 - Train loss: 11.8187
Epoch 12/100 - Val accuracy: 0.0000
Epoch 13/100 - Train loss: 10.0447
Epoch 13/100 - Val accuracy: 0.1800
Epoch 14/100 - Train loss: 6.0792
Epoch 14/100 - Val accuracy: 0.5400
Epoch 15/100 - Train loss

In [8]:
label_token_ids = {g: tokenizer.encode(' ' + g, add_special_tokens=False)[0] for g in idx2genre.values()}
id2label_token = {v: k for k, v in label_token_ids.items()}

model.eval()
print("\nTesting on test set:")
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        logits = outputs.logits[:, -1, :]
        preds = logits.argmax(dim=-1)

        for i in range(len(labels)):
            pred_token = preds[i].item()
            label_idx = labels[i].item() if labels[i].ndim == 0 else labels[i][0].item()
            true_label = idx2genre[label_idx]
            pred_label = id2label_token.get(pred_token, 'UNKNOWN')
            print(f"Predicted genre is: {pred_label:<15} | Actual genre is: {true_label}")


Testing on test set:
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Actual genre is: alt-country
Predicted genre is: alt-rock        | Ac