## Here we will use zero-shot learning method as baseline for genre classification


In [1]:
import sys
import os

sys.path.append(os.path.abspath('../'))

In [2]:
import numpy as np
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from src.utils import logger, DatasetTypes
from src.data import init_data
from src.metrics import GenrePredictorInterface, evaluate_model
from src.model import get_pretrained

import gc
import json
import re
from typing import List, Tuple


device

2025-06-18 01:58:34,693 - numexpr.utils - INFO - Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2025-06-18 01:58:34,693 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.


'cuda'

## Get model

In [3]:
model_name = "Qwen/Qwen3-0.6B"
tokenizer, model = get_pretrained(model_name, device)
model.eval()

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): Qwe

## Get dataset with all genres and 1,294,054 examples 

In [4]:
path_to_csv = '../data/all_genres_downsampled.csv'
train_dataset, val_dataset, test_dataset , idx2genre, genres, traid_loader, val_loader, test_loader = init_data(path_to_csv=path_to_csv, batch_size=16, tokenizer=tokenizer, dataset_type=DatasetTypes.hundred, shuffle=True)

In [5]:
PROMPT = '''You are a music genre expert. You will determine whether a song belongs to a specific genre based on its lyrics. You will be provided with a JSON input containing the lyrics and the target genre. Respond with 1 if the song likely belongs to the specified genre, and 0 if it does not.

**Input format:**
```json
{
    "lyrics": "Lyrics of the song",
    "genre": "Target genre"
}
```

**Output format:**
```json
{
    "predict": 1  // if the song belongs to the genre
    // or
    "predict": 0  // if it does not
}
```

**Lyrics with genre for classification:**
```json
{
    "lyrics": "%s",
    "genre": "%s"
}
```

**Your output**:
'''

## Main mechanic and functions

In [6]:
# Main functions
def get_input_text(lyrics, genre, enable_thinking=False):
    instruct = PROMPT % (lyrics, genre)

    messages = [
        {"role": "user", "content": instruct}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
        max_thinking_tokens=50,
        do_sample=False
    )
    
    return text

def parse_model_response(generated_ids, model_inputs_len):
    assert generated_ids.ndim == 1
    
    output_ids = generated_ids[model_inputs_len:].tolist() 
    try:
        index = len(output_ids) - output_ids[::-1].index(151668)
    except ValueError:
        index = 0
        
    thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
    assert len(content) != 0, 'Error. Output content is empty!'
    return thinking_content, content

def parse_output_json(response: str) -> int:
    try:
        match = re.search(r'\{[^}]*"predict"\s*:\s*(0|1)[^}]*\}', response)
        if match:
            data = json.loads(match.group(0))
            return int(data["predict"])
    except Exception as e:
        print(f"Parse error: {e}")
    return 0  # fallback to 0 if anything goes wrong

In [7]:
lyrics = val_dataset[0]['features']['lyrics']
target_genre = val_dataset[0]['features']['genre_list'][0]

truncated = lyrics[:300]

input_text = get_input_text(truncated, target_genre, enable_thinking=True)
print(input_text)

model_inputs = tokenizer([input_text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=1337
)

thinking_content, content = parse_model_response(generated_ids[0], len(model_inputs.input_ids[0]))

print("thinking content:", thinking_content)
print("content:", content)

result = parse_output_json(content)
print(f"result answer: {result}")

<|im_start|>user
You are a music genre expert. You will determine whether a song belongs to a specific genre based on its lyrics. You will be provided with a JSON input containing the lyrics and the target genre. Respond with 1 if the song likely belongs to the specified genre, and 0 if it does not.

**Input format:**
```json
{
    "lyrics": "Lyrics of the song",
    "genre": "Target genre"
}
```

**Output format:**
```json
{
    "predict": 1  // if the song belongs to the genre
    // or
    "predict": 0  // if it does not
}
```

**Lyrics with genre for classification:**
```json
{
    "lyrics": "[Verse 1] Well, I'm standing here, freezing, outside your golden garden Uh got my ladder, leaned up against your wall Tonight's the night we planned to run away together Come on Dolly Mae, there's no time to stall But now you're telling me [Chorus] I think I better wait until tomorrow I think I bett",
    "genre": "jazz"
}
```

**Your output**:
<|im_end|>
<|im_start|>assistant

thinking conten

In [8]:
# Also functions for batched generation
def build_prompts_and_map(
    self, lyrics_list: List[str]
) -> Tuple[List[str], List[int]]:
    prompts, idx_map = [], []
    for idx, txt in enumerate(lyrics_list):
        trunc = txt[: self.max_lyrics_length]
        for _g in self.genres:
            prompts.append(self.prompt_template % (trunc, _g))
            idx_map.append(idx)
    return prompts, idx_map


def get_input_texts_for_each_genre(sample: str, genres: List, enable_thinking=False):
    input_texts = []
    
    for genre in genres:
        input_text = get_input_text(sample, genre, enable_thinking=enable_thinking)
        input_texts.append(input_text) 
    
    return input_texts 


def proccess_sample(sample: str, genres: List, truncation_len: int = 300, max_new_tokens: int = 1337, enable_thinking=False) -> np.array:
    ''' Make predictions for one sample: whether it belongs to each genre. 
        Returns np array with 1 in corresponding places if lyrics belongs to genre.'''

    truncated = sample[:truncation_len]
    
    input_texts = get_input_texts_for_each_genre(truncated, genres, enable_thinking=enable_thinking)
    
    model_inputs_all_genres = tokenizer(input_texts, return_tensors="pt", padding=True, padding_side='left').to(model.device)
    
    generated_ids_all_genres = model.generate(
        **model_inputs_all_genres,
        max_new_tokens=max_new_tokens
    )
    
    preds = np.zeros(len(genres), dtype=np.int32)
    for i, (generated_ids, model_inputs) in enumerate(zip(generated_ids_all_genres, model_inputs_all_genres['input_ids'])):
        thinking_content, content = parse_model_response(generated_ids, len(model_inputs))
        result = parse_output_json(content)
        preds[i] = result
        
    return preds

def one_hot_encoded_to_genre_list(predictions, idx2genre: dict = None):
    ''' Predictions is array on n_genres size, where 1 if lyrics belongs to that genre and 0 if not'''    
    genre_list = []
    for i, value in enumerate(predictions):
        if value == 1:
            genre_list.append(idx2genre[i])
    
    return genre_list

In [9]:
sample = val_dataset[1]['features']['lyrics']
print(sample)

preds = proccess_sample(sample, genres)
print(preds)

genre_list = one_hot_encoded_to_genre_list(preds, idx2genre)
print(genre_list)

A relapse of my body Sends my mind into multiple seizures Psychologically a new human being One that has never been Cursed by the shamen his voodoo spell has my soul My limbs go numb I can't control my own thought Are his now his evil consuming me ever telling me begin the clit carving Slowly turning me, into a flesh eating zombie Knowing this spell can only be broken by the vaginal skins of young women I proceed to find the meat their bleeding cunts will set me free Warmth seeping from this Body Rotted After I sucked the blood from her ass I feel more alive more alive than I've ever been Even though now I'm dead within My mouth drools As I slice your perinium My body smeared With the guts I've extracted through her hole, came swollen organs cunnilingus with the mutilated My spirit returned from the dead Released by the priest but I felt more real when I was dead The curse is broken I have a dependence on vaginal skin It's become my sexual addiction I must slit, the twitching clit Rott

## Metrics evaluation

In [10]:
# Function for evaluation through dataset

class ZeroShotClassifier(GenrePredictorInterface):
    def predict(self, batch_features: dict) -> np.array:
        list_pred = []
        
        for features in batch_features['features']:
            lyrics = features['lyrics']
            
            pred = proccess_sample(lyrics, genres)
            list_pred.append(pred)

        return np.stack(list_pred)

In [11]:
zero_shot_classifier = ZeroShotClassifier()

batch = next(iter(val_loader))

result = zero_shot_classifier.predict(batch)

print(result)

[[1 1 0 ... 1 1 0]
 [0 1 0 ... 0 0 0]
 [0 0 0 ... 1 0 0]
 ...
 [1 1 1 ... 1 1 0]
 [0 1 1 ... 1 1 0]
 [1 1 1 ... 1 1 1]]


In [12]:
metrics = evaluate_model(zero_shot_classifier, test_loader)

print("Precision:", metrics['precision'])
print("Recall:", metrics['recall'])
print("F1-score:", metrics['f1'])

Evaluating: 100%|██████████| 7/7 [05:33<00:00, 47.64s/it]

Precision: 0.01956986669338165
Recall: 0.4130279442779443
F1-score: 0.0349194116910554



