## Here we will apply further improvements to the previous zero-shot approach.

Changes are:
- Ask the model to choose one or a few options when classifying lyrics. The previous model generated too many false positives (FP).
- Add few-shot examples of classification.
- Add a reasoning field to the JSON output.

I used the [ChatGPT](https://chatgpt.com/g/g-6769db0aa91c8191bf46eeac95f5e055-system-prompt-generator-for-reasoning-models) System Prompt Generator to improve the prompt. Link to the dialog: https://chatgpt.com/share/6851f905-1230-8012-90c5-534a8f75a958



In [1]:
import sys
import os

sys.path.append(os.path.abspath('../'))

In [2]:
import numpy as np
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from src.utils import logger, DatasetTypes
from src.data import get_datasets, get_dataloaders
from src.metrics import GenrePredictorInterface, evaluate_model
from src.model import get_pretrained

import json
import re
from typing import List, Tuple

import random
device

2025-06-18 03:08:29,441 - numexpr.utils - INFO - Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2025-06-18 03:08:29,442 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.


'cuda'

## Get model

In [3]:
# Load the tokenizer/model pair through the project helper and put the model
# into evaluation mode (only inference is performed in this notebook).
model_name = "Qwen/Qwen3-0.6B"
tokenizer, model = get_pretrained(model_name, device)
model.eval()

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): Qwe

## Get dataset with all genres and 1,294,054 examples 

In [4]:
# Build the train/val/test datasets and their loaders from the CSV file.
path_to_csv = '../data/all_genres_downsampled.csv'
# NOTE(review): `DatasetTypes.hundred` presumably selects a small subset of the
# data rather than the full corpus - confirm against src.utils.
dataset_type = DatasetTypes.hundred
batch_size = 8

data_dict = get_datasets(path_to_csv, tokenizer, dataset_type=dataset_type)
train_dataset, val_dataset, test_dataset = data_dict['train_dataset'], data_dict['val_dataset'], data_dict['test_dataset']

# Mappings between genre names and integer indices, plus the list of genre names.
# NOTE(review): the code below sizes prediction vectors with len(genres) and
# indexes them with genre2idx values, so it assumes the indices are 0..len-1 - confirm.
idx2genre, genre2idx = data_dict['idx2genre'], data_dict['genre2idx']
genres = [key for key, _ in genre2idx.items()]

# NOTE(review): `traid_loader` looks like a typo for `train_loader`; it is not
# used in the visible part of the notebook.
traid_loader, val_loader, test_loader = get_dataloaders(train_dataset, val_dataset, test_dataset, batch_size)

In [5]:
# Few-shot prompt template.
# - `<genres>` is a placeholder that is replaced below with the genre list.
# - The single `%s` is filled with the lyrics via %-formatting in get_input_text,
#   so the template must not contain any other `%` character.
PROMPT = '''You are a music genre classification expert. Your task is to analyze song lyrics and decide which single music genre from the given list they most likely belong to. You must provide a short explanation for your choice in the "reasoning" field and then output the selected genre using the "predict" field.

**Available genres:**
<genres>

**Input format:**
```json
{
    "lyrics": "Text of the song lyrics"
}
```

**Output format:**
```json
{
    "reasoning": "Explain briefly why the lyrics fit the selected genre",
    "predict": "genre_name"
}
```

**Few-shot examples:**

Example 1:
```json
Input:
{
    "lyrics": "I got my hands up, they're playing my song, I know I'm gonna be OK..."
}
Output:
{
    "reasoning": "The lyrics talk about dancing, feeling good, and have a carefree theme typical of pop songs.",
    "predict": "pop"
}
```

Example 2:
```json
Input:
{
    "lyrics": "Gunshots echo through the night, I’ve seen too many die young..."
}
Output:
{
    "reasoning": "The lyrics are gritty and socially conscious, with a storytelling style that is common in hip hop.",
    "predict": "hip-hop"
}
```

Example 3:
```json
Input:
{
    "lyrics": "Rolling down that old dirt road, truck tires kickin' dust in the air..."
}
Output:
{
    "reasoning": "Mentions of trucks, dirt roads, and rural imagery are strong indicators of country music.",
    "predict": "country"
}
```

**Lyrics for classification:**
```json
{
    "lyrics": "%s"
}
```

**Your output:**
'''

# Insert the Python list repr of the available genres into the template.
# NOTE(review): a genre name containing '%' would break the later %-formatting.
PROMPT = PROMPT.replace('<genres>', str(genres))


## Main mechanic and functions

In [6]:
# Main functions
def get_input_text(lyrics, enable_thinking=False):
    """Build the chat-formatted prompt for a single lyrics snippet.

    The lyrics are inserted into the module-level ``PROMPT`` template and the
    result is wrapped into a single user message rendered with the tokenizer's
    chat template.

    Args:
        lyrics: Lyrics text (already truncated by the caller if needed).
        enable_thinking: Forwarded to the chat template to switch the model's
            thinking mode on or off.

    Returns:
        str: The prompt text, not tokenized, with the generation prompt appended.
    """
    # The template places the lyrics inside a JSON string literal, so quotes,
    # backslashes and control characters must be escaped to keep it valid JSON.
    # json.dumps adds surrounding quotes, which the template already provides.
    escaped_lyrics = json.dumps(str(lyrics), ensure_ascii=False)[1:-1]
    instruct = PROMPT % escaped_lyrics

    messages = [
        {"role": "user", "content": instruct}
    ]
    # `do_sample` is a generation option, not a chat-template option, so it is
    # no longer passed here; sampling is controlled in `model.generate`.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    )

    return text

def parse_model_response(generated_ids, model_inputs_len):
    """Split one generated sequence into thinking text and answer text.

    Args:
        generated_ids: 1-D sequence of token ids (prompt followed by the
            generated continuation); must expose ``ndim`` and ``tolist()``.
        model_inputs_len: Number of leading prompt tokens to drop.

    Returns:
        Tuple ``(thinking_content, content)`` of decoded strings with leading
        and trailing newlines stripped.

    Raises:
        AssertionError: If the input is not 1-D or the decoded answer is empty.
    """
    assert generated_ids.ndim == 1

    # presumably the id of the `</think>` token in the Qwen3 vocabulary - confirm
    think_end_id = 151668

    new_tokens = generated_ids[model_inputs_len:].tolist()

    # Find the position right after the LAST end-of-thinking marker;
    # stay at 0 when the marker is absent (no thinking part at all).
    split_at = 0
    for pos in range(len(new_tokens) - 1, -1, -1):
        if new_tokens[pos] == think_end_id:
            split_at = pos + 1
            break

    thinking_ids, answer_ids = new_tokens[:split_at], new_tokens[split_at:]
    thinking_content = tokenizer.decode(thinking_ids, skip_special_tokens=True).strip("\n")
    content = tokenizer.decode(answer_ids, skip_special_tokens=True).strip("\n")

    assert len(content) != 0, 'Error. Output content is empty!'
    return thinking_content, content

def parse_output_json(response: str) -> Tuple[str, str]:
    """
    Parses the model output to extract the 'reasoning' and 'predict' fields.

    The response is scanned for the first JSON object that has a string
    "predict" field. A real JSON decoder is used instead of a regular
    expression, so escaped quotes inside the reasoning, any key order, extra
    keys and surrounding text (e.g. Markdown code fences) are all handled.

    Args:
        response (str): Raw text response from the model.

    Returns:
        Tuple[str, str]: A tuple containing (reasoning, predicted_genre), both
                         stripped of surrounding whitespace. The reasoning is
                         "" when it is missing or not a string.
                         Returns ("", "") if parsing fails.
    """
    if not response:
        return "", ""

    # strict=False accepts raw newlines/tabs inside string values, which
    # language models frequently emit inside the reasoning text.
    decoder = json.JSONDecoder(strict=False)

    start = response.find("{")
    while start != -1:
        try:
            data, _ = decoder.raw_decode(response, start)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            data = None

        if isinstance(data, dict):
            genre = data.get("predict")
            if isinstance(genre, str):
                reasoning = data.get("reasoning", "")
                if not isinstance(reasoning, str):
                    reasoning = ""
                return reasoning.strip(), genre.strip()

        # Not a usable object at this brace; try the next candidate.
        start = response.find("{", start + 1)

    return "", ""  # fallback if parsing fails

In [7]:
# Smoke test of the whole pipeline on the first validation example.
lyrics = val_dataset[0]['features']['lyrics']
genre = val_dataset[0]['features']['genre']

# Only the beginning of the lyrics is sent to the model to keep the prompt short.
truncated = lyrics[:300]

input_text = get_input_text(truncated, enable_thinking=False)
print(input_text)

model_inputs = tokenizer([input_text], return_tensors="pt").to(model.device)

# Greedy decoding so that the demo output is reproducible.
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=1337,
    do_sample=False
)

thinking_content, content = parse_model_response(generated_ids[0], len(model_inputs.input_ids[0]))

print("thinking content:", thinking_content)
print("content:", content)

reasoning, predicted_genre = parse_output_json(content)
print(f'Reasoning: {reasoning}')
print(f'Predicted genre: {predicted_genre}')
print(f'Ground truth genre: {genre}')
# `.get` prints None instead of raising KeyError when the answer could not be
# parsed or the model produced a genre that is not in the list.
print(f'Genre idx: {genre2idx.get(predicted_genre)}')

<|im_start|>user
You are a music genre classification expert. Your task is to analyze song lyrics and decide which single music genre from the given list they most likely belong to. You must provide a short explanation for your choice in the "reasoning" field and then output the selected genre using the "predict" field.

**Available genres:**
['alt-country', 'alt-rock', 'alternative', 'ambient', 'axé', 'black-metal', 'blues', 'bossa-nova', 'chillwave', 'classic-rock', 'classical', 'cloud-rap', 'country', 'dance', 'dancehall', 'death-metal', 'deathcore', 'disco', 'doom-metal', 'dream-pop', 'drum&bass', 'dub', 'electro-pop', 'electronic', 'electronica', 'emo', 'emo-rap', 'folk', 'forró', 'funk', 'funk-carioca', 'garage-rock', 'gothic', 'grunge', 'hard-rock', 'hardcore', 'heavy-metal', 'hip-hop', 'house', 'indie', 'indie-pop', 'indie-rock', 'j-pop', 'j-rock', 'jazz', 'jovem-guarda', 'k-pop', 'math-rock', 'melodic-death-metal', 'metal', 'metalcore', 'mpb', 'new-wave', 'pagode', 'pop', 'pop



thinking content: 
content: ```json
{
    "reasoning": "The lyrics describe a reflective and introspective tone, with references to a garden, a ladder, and a future plan, which are common in soul music. The mention of 'outside your golden garden' and 'running away together' also aligns with a romantic or emotional theme typical of soul.",
    "predict": "soul"
}
```
Reasoning: The lyrics describe a reflective and introspective tone, with references to a garden, a ladder, and a future plan, which are common in soul music. The mention of 'outside your golden garden' and 'running away together' also aligns with a romantic or emotional theme typical of soul.
Predicted genre: soul
Ground truth genre: jazz,pop,rock
Genre idx: 79


In [8]:
def proccess_sample(sample: str, truncation_len: int = 300, max_new_tokens: int = 1337, enable_thinking=False) -> np.ndarray:
    """Predict the genre of a single lyrics text.

    Args:
        sample: Full lyrics text.
        truncation_len: Number of leading characters of the lyrics to send
            to the model.
        max_new_tokens: Generation budget passed to ``model.generate``.
        enable_thinking: Forwarded to the chat template.

    Returns:
        np.ndarray: int32 vector of length ``len(genres)`` with a single 1 at
        the index of the predicted genre. When the answer cannot be parsed or
        names an unknown genre, a random genre is chosen, mirroring the
        fallback in ``proccess_batch``.
    """
    truncated = sample[:truncation_len]

    input_text = get_input_text(truncated, enable_thinking=enable_thinking)

    model_inputs = tokenizer([input_text], return_tensors="pt", padding=True, padding_side='left').to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens
    )

    _, content = parse_model_response(generated_ids[0], len(model_inputs['input_ids'][0]))
    _, genre = parse_output_json(content)

    # A failed parse yields "" and the model may invent a genre; both would
    # raise KeyError with direct indexing, so fall back to a random valid index.
    genre_idx = genre2idx.get(genre)
    if genre_idx is None:
        genre_idx = random.randrange(len(genres))

    preds = np.zeros(len(genres), dtype=np.int32)
    preds[genre_idx] = 1

    return preds

def get_input_texts(batch: List[str], enable_thinking=False, truncation_len=300):
    """Build one chat-formatted prompt per lyrics text.

    Args:
        batch: Lyrics texts, one per sample.
        enable_thinking: Forwarded to ``get_input_text`` for every sample.
        truncation_len: Number of leading characters of each lyrics text to keep.

    Returns:
        list: Prompt strings in the same order as ``batch``.
    """
    return [
        get_input_text(lyrics[:truncation_len], enable_thinking=enable_thinking)
        for lyrics in batch
    ]

def proccess_batch(batch: List[str], truncation_len: int = 300, max_new_tokens: int = 1337, enable_thinking=False) -> np.ndarray:
    """Predict one genre for every lyrics text in a batch.

    Args:
        batch: Lyrics texts, one per sample.
        truncation_len: Number of leading characters of each lyrics text to
            send to the model.
        max_new_tokens: Generation budget passed to ``model.generate``.
        enable_thinking: Forwarded to the chat template.

    Returns:
        np.ndarray: int32 array of shape ``(len(batch), len(genres))`` with
        exactly one 1 per row at the index of the predicted genre. When the
        answer cannot be parsed or names an unknown genre, a random genre is
        chosen for that row.
    """
    input_texts = get_input_texts(batch, enable_thinking=enable_thinking, truncation_len=truncation_len)

    # Prompts are padded on the left so that all of them end at the same
    # position, which is where generation starts.
    model_inputs_batch = tokenizer(input_texts, return_tensors="pt", padding=True, padding_side='left').to(model.device)

    generated_ids_batch = model.generate(
        **model_inputs_batch,
        max_new_tokens=max_new_tokens
    )

    n_genres = len(genres)
    list_preds = []
    for generated_ids, input_ids in zip(generated_ids_batch, model_inputs_batch['input_ids']):
        _, content = parse_model_response(generated_ids, len(input_ids))
        _, genre = parse_output_json(content)

        genre_idx = genre2idx.get(genre)
        if genre_idx is None:
            # randrange excludes the upper bound; randint(0, n_genres) could
            # return n_genres itself and index past the end of `preds`.
            genre_idx = random.randrange(n_genres)

        preds = np.zeros(n_genres, dtype=np.int32)
        preds[genre_idx] = 1
        list_preds.append(preds)

    return np.stack(list_preds)


def one_hot_encoded_to_genre_list(predictions, idx2genre: dict = None):
    """Convert a 0/1 prediction vector into the list of selected genre names.

    Args:
        predictions: Iterable of length n_genres; an entry equal to 1 means
            the lyrics belong to the genre at that position.
        idx2genre: Mapping from position to genre name. Although it defaults
            to None, it must be supplied whenever any entry equals 1, since
            it is subscripted for every selected position.

    Returns:
        list: Genre names for all positions whose value is 1, in index order.
    """
    return [idx2genre[position] for position, flag in enumerate(predictions) if flag == 1]

In [9]:
# Run the single-sample pipeline on the second validation example and show
# both the one-hot prediction vector and the genre names it decodes to.
sample = val_dataset[1]['features']['lyrics']
print(sample)

preds = proccess_sample(sample)
print(preds)

genre_list = one_hot_encoded_to_genre_list(preds, idx2genre)
print(genre_list)

A relapse of my body Sends my mind into multiple seizures Psychologically a new human being One that has never been Cursed by the shamen his voodoo spell has my soul My limbs go numb I can't control my own thought Are his now his evil consuming me ever telling me begin the clit carving Slowly turning me, into a flesh eating zombie Knowing this spell can only be broken by the vaginal skins of young women I proceed to find the meat their bleeding cunts will set me free Warmth seeping from this Body Rotted After I sucked the blood from her ass I feel more alive more alive than I've ever been Even though now I'm dead within My mouth drools As I slice your perinium My body smeared With the guts I've extracted through her hole, came swollen organs cunnilingus with the mutilated My spirit returned from the dead Released by the priest but I felt more real when I was dead The curse is broken I have a dependence on vaginal skin It's become my sexual addiction I must slit, the twitching clit Rott

## Metrics evaluation

In [10]:
# Function for evaluation through dataset

class ZeroShotClassifier(GenrePredictorInterface):
    """Genre predictor that prompts the language model through ``proccess_batch``."""

    def predict(self, batch_features: dict) -> np.array:
        """Predict genres for one batch.

        Args:
            batch_features: Batch dict whose 'features' entry is an iterable
                of per-sample dicts, each carrying a 'lyrics' text.

        Returns:
            The prediction array produced by ``proccess_batch`` for the
            collected lyrics, using its default settings.
        """
        lyrics_list = [features['lyrics'] for features in batch_features['features']]
        return proccess_batch(lyrics_list)

In [11]:
# Instantiate the classifier and try it on the first validation batch.
zero_shot_classifier = ZeroShotClassifier()

batch = next(iter(val_loader))

result = zero_shot_classifier.predict(batch)

# One row per sample, one column per genre.
print(result)

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0

In [12]:
# Evaluate the classifier on the test loader and report the aggregated metrics.
metrics = evaluate_model(zero_shot_classifier, test_loader)

print("Precision:", metrics['precision'])
print("Recall:", metrics['recall'])
print("F1-score:", metrics['f1'])

Evaluating: 100%|██████████| 13/13 [01:05<00:00,  5.02s/it]

Precision: 0.014204545454545454
Recall: 0.008116883116883118
F1-score: 0.007843521421107629



