# Here we use the LoRA fine-tuning method as a baseline for genre classification

In [1]:
import sys
import os

# Put the parent directory on the import path so the `src` package imported by
# later cells can be found (assumes `src` lives one directory above this notebook).
# Guarded so that re-running the cell does not keep appending duplicate entries.
PROJECT_ROOT = os.path.abspath('../')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import numpy as np

# Configure the CUDA caching allocator. This is deliberately done before
# `import torch` so the setting is already in the environment when PyTorch
# initialises CUDA — do not reorder these lines.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
# Release cached, unused GPU memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from transformers import AutoModelForCausalLM, AutoTokenizer

# Project-local helpers (resolved through the sys.path entry added in the first cell).
from src.utils import logger, DatasetTypes
from src.data import init_data
from src.ptune import prepare_ptune, train, MultiLabelClassifier
from src.lora import prepare_lora, validate_model
from src.model import get_pretrained
from sklearn.metrics import classification_report, f1_score, precision_score, recall_score
import json
import re
# supported files in spython
# Last expression of the cell: displays the selected device as the cell output.
device

  from .autonotebook import tqdm as notebook_tqdm


'cuda'

In [3]:
# Hugging Face model identifier of the backbone used in this experiment.
model_name = "Qwen/Qwen3-0.6B"
# Project helper that returns the tokenizer and the base model for `model_name`;
# `device` is passed through (presumably the model is placed on it — confirm in src.model).
tokenizer, base_model = get_pretrained(model_name, device)

In [4]:
# Experiment configuration: everything a reader might want to tune lives here.
BATCH_SIZE = 16
PATH_TO_CSV = '../data/all_genres_downsampled.csv'
OUTPUT_DIR = "../experiments/qwen_lora_genre"
LR = 2e-4
EPOCHS = 10
LOG_STEPS = 50
SAVE_STEPS = 200
GRADIENT_ACCUM_STEPS = 4
# Backward-compatible alias for the original, misspelled constant name.
GRDIENT_ACCUM_STEPS = GRADIENT_ACCUM_STEPS


# LoRA params (passed to prepare_lora as r / alpha / dropout)
LORA_RANK = 8
# Original name of the rank constant, kept so existing references keep working.
RANG = LORA_RANK
LORA_ALPHA = 16
DROPOUT = 0.05

In [5]:
# Build the datasets, label vocabulary and data loaders in one call.
# The batch size is taken from the BATCH_SIZE config constant instead of a
# second hard-coded literal, so changing the config actually takes effect here.
# NOTE(review): this CSV is not the file named by PATH_TO_CSV in the config
# cell — confirm which dataset is the intended one.
path_to_csv = '../data/top_genres.csv'
(
    train_dataset, val_dataset, test_dataset,
    idx2genre, genres,
    train_loader, val_loader, test_loader,
) = init_data(path_to_csv=path_to_csv, batch_size=BATCH_SIZE, tokenizer=tokenizer)

In [6]:
# Attach LoRA adapters to the base model using the config constants and move
# the result to the target device in the same expression.
peft_model = prepare_lora(
    base_model,
    r=RANG,
    alpha=LORA_ALPHA,
    dropout=DROPOUT,
).to(device)
# Last expression of the cell: show the adapted model's module tree.
peft_model

trainable params: 1,146,880 || all params: 597,196,800 || trainable%: 0.1920


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Qwen3ForCausalLM(
      (model): Qwen3Model(
        (embed_tokens): Embedding(151936, 1024)
        (layers): ModuleList(
          (0-27): 28 x Qwen3DecoderLayer(
            (self_attn): Qwen3Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=1024, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1024, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): Linear(in_fea

In [7]:
# Training/evaluation hyperparameters.
# NOTE(review): the config cell also defines EPOCHS and LR with different
# values — confirm which set is meant to be authoritative.
num_epochs = 100
warmup_steps = 500
learning_rate = 1e-4
threshold = 0.3
# Backward-compatible alias for the original, misspelled name.
treshold = threshold

In [8]:
# Wrap the LoRA-adapted backbone in MultiLabelClassifier, passing the backbone's
# hidden size and the number of genres, then move the wrapper to the device.
num_labels = len(genres)
hidden_size = base_model.config.hidden_size
model = MultiLabelClassifier(peft_model, hidden_size, num_labels)
model = model.to(device)


In [9]:
# Freeze the backbone so that only the LoRA adapter weights inside the wrapped
# PEFT model remain trainable. The check is restricted to parameters under the
# `peft` submodule: the previous `"lora" not in name` test disabled gradients for
# every non-LoRA parameter of the wrapper, which would also freeze the
# classification head.
# NOTE(review): assumes the head's parameters live outside `peft.` — confirm
# against MultiLabelClassifier in src.ptune.
for name, param in model.named_parameters():
    if name.startswith("peft.") and "lora" not in name:
        param.requires_grad = False

# NOTE(review): no training step runs in this notebook before the save, so this
# checkpoint holds the weights exactly as they were constructed above.
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, 'qwen_lora.pt'))

In [10]:
# Restore the weights from the checkpoint file in OUTPUT_DIR.
# - map_location=device keeps the load working when the checkpoint was written
#   on a different device (e.g. a GPU checkpoint opened on a CPU-only machine).
# - weights_only=True restricts unpickling to tensors and plain containers,
#   which is all a state_dict needs.
# The loaded dict is passed straight through (not bound to a name) so the extra
# copy of the weights can be released as soon as loading finishes.
checkpoint_path = os.path.join(OUTPUT_DIR, 'qwen_lora.pt')
model.load_state_dict(
    torch.load(checkpoint_path, map_location=device, weights_only=True)
)
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

MultiLabelClassifier(
  (peft): PeftModelForCausalLM(
    (base_model): LoraModel(
      (model): Qwen3ForCausalLM(
        (model): Qwen3Model(
          (embed_tokens): Embedding(151936, 1024)
          (layers): ModuleList(
            (0-27): 28 x Qwen3DecoderLayer(
              (self_attn): Qwen3Attention(
                (q_proj): lora.Linear(
                  (base_layer): Linear(in_features=1024, out_features=2048, bias=False)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=1024, out_features=8, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=8, out_features=2048, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
                  (lora_magnitude_v

In [11]:
# Print predicted vs. actual genre lists for every sample of the test loader.
model.eval()
threshold = 0.3  # a genre counts as predicted when its sigmoid probability exceeds this

print("\nTesting on test set (multi‑label):")
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        true_multilabels = batch['labels']  # [B, num_labels]

        # Forward pass through the MultiLabelClassifier
        logits = model(input_ids, attention_mask)       # [B, num_labels]
        probs  = torch.sigmoid(logits)                  # [B, num_labels]
        preds  = (probs > threshold).long()             # [B, num_labels]

        # Convert each batch to plain Python lists once. Comparing individual
        # elements of a device tensor (as the per-element loop did before)
        # forces a separate device synchronisation for every label.
        pred_rows = preds.tolist()
        true_rows = torch.as_tensor(true_multilabels).tolist()

        for pred_row, true_row in zip(pred_rows, true_rows):
            # Collect the genre names whose flag is set
            pred_genres = [genres[j] for j, p in enumerate(pred_row) if p == 1]
            true_genres = [genres[j] for j, t in enumerate(true_row) if t == 1]

            pred_str = ", ".join(pred_genres) if pred_genres else "None"
            true_str = ", ".join(true_genres) if true_genres else "None"
            print(f"Predicted genres: {pred_str:<30} | Actual genres: {true_str}")




Testing on test set (multi‑label):
Predicted genres: country, hip-hop, metal, other, pop, rock | Actual genres: r&b
Predicted genres: electronic, hip-hop, metal, other, pop, rock | Actual genres: country
Predicted genres: country, electronic, hip-hop, other, pop, r&b, rock | Actual genres: hip-hop
Predicted genres: country, electronic, hip-hop, metal, other, pop, rock | Actual genres: rock
Predicted genres: electronic, hip-hop, other, pop, rock | Actual genres: pop
Predicted genres: country, electronic, hip-hop, metal, other, pop, rock | Actual genres: pop
Predicted genres: electronic, hip-hop, metal, other, pop, r&b, rock | Actual genres: hip-hop
Predicted genres: electronic, hip-hop, metal, other, pop, rock | Actual genres: hip-hop
Predicted genres: country, electronic, hip-hop, metal, other, pop, r&b, rock | Actual genres: electronic, pop
Predicted genres: electronic, hip-hop, metal, other, pop, rock | Actual genres: pop
Predicted genres: electronic, hip-hop, metal, other, pop, r&b

In [12]:
# Evaluate the model on the validation loader with the project helper from
# src.lora; its return value is kept in `metrics`. The helper appears to print
# the F1 / precision / recall summary shown in the cell output.
# NOTE(review): the previous cell prints predictions for the *test* loader,
# while these metrics are computed on the *validation* loader.
metrics = validate_model(model, val_loader, device)


Validation Metrics → F1: 0.3285 | Precision: 0.2021 | Recall: 0.8766
