## Here we use a zero-shot learning method as a baseline for genre classification


In [None]:
import sys
import os

# Make the parent directory importable so the project package (`src`, imported
# in the next cell) can be found. `abspath('../')` is resolved against the
# kernel's current working directory — assumed to be this notebook's folder (TODO confirm).
sys.path.append(os.path.abspath('../'))

In [34]:
import numpy as np
import torch
# Use the GPU when one is available; this device is used for model placement
# and passed to the classifier below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from transformers import AutoModelForCausalLM, AutoTokenizer

from src.utils import logger, DatasetTypes
from src.data import get_datasets, get_dataloaders, one_hot_encoded_to_genre_list
from src.metrics import GenrePredictorInterface, evaluate_model

import json
import re

## Load the model

In [8]:
# Simplest (smallest) model, used for a demonstration scenario.
model_name = "Qwen/Qwen3-0.6B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",  # let transformers choose the dtype recorded in the checkpoint config
    device_map=device
)

tokenizer_config.json:   0%|          | 0.00/9.73k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

## Load the dataset with all genres and 1,294,054 examples

In [None]:
path_to_csv = '../data/all_genres_downsampled.csv'
data_dict = get_datasets(path_to_csv, tokenizer, dataset_type=DatasetTypes.small)

train_dataset, val_dataset, test_dataset = data_dict['train_dataset'], data_dict['val_dataset'], data_dict['test_dataset']
idx2genre, genre2idx = data_dict['idx2genre'], data_dict['genre2idx']
# Order genres by their label index (not by dict insertion order) so that
# column j of a prediction matrix is the genre with label index j. This keeps
# predictions aligned with the one-hot ground truth and with `idx2genre`.
# Assumes genre2idx values are the integer label indices.
genres = sorted(genre2idx, key=genre2idx.get)

batch_size = 16
train_loader, val_loader, test_loader = get_dataloaders(train_dataset, val_dataset, test_dataset, batch_size)
# Backward-compatible alias for the previously misspelled name.
traid_loader = train_loader

## The model itself

In [22]:
# Prompt template for a single binary (lyrics, genre) question.
# It is filled with `%`-formatting: the first %s receives the (truncated) lyrics,
# the second %s the target genre. Any literal percent sign added to this text
# would have to be written as `%%`.
# The lyrics slot sits inside double quotes, so callers replace `"` in the
# lyrics before formatting.
# NOTE(review): the "Output format" example contains `//` comments, which are
# not valid JSON; a model that copies them produces text `json.loads` rejects.
prompt_v1 = '''You are a music genre expert. You will determine whether a song belongs to a specific genre based on its lyrics. You will be provided with a JSON input containing the lyrics and the target genre. Respond with 1 if the song likely belongs to the specified genre, and 0 if it does not.

**Input format:**
```json
{
    "lyrics": "Lyrics of the song",
    "genre": "Target genre"
}
```

**Output format:**
```json
{
    "predict": 1  // if the song belongs to the genre
    // or
    "predict": 0  // if it does not
}
```

**Lyrics with genre for classification:**
```json
{
    "lyrics": "%s",
    "genre": "%s"
}
```

**Your output**:
'''

def parse_model_response(response: str) -> int:
    """Extract the binary ``predict`` value from a raw model response.

    Finds a ``{...}`` object containing ``"predict": 0`` or ``"predict": 1``.
    Surrounding text (e.g. Markdown code fences) is tolerated, as is non-JSON
    content inside the object (e.g. ``//`` comments echoed from the prompt).

    Args:
        response: Decoded model output.

    Returns:
        0 or 1.

    Raises:
        ValueError: If ``response`` is not a string or contains no
            ``"predict": 0`` / ``"predict": 1`` pair.
    """
    if not isinstance(response, str):
        raise ValueError("Could not parse prediction from model response.")
    # Read the value from the capture group instead of json.loads on the whole
    # match: the prompt's example output uses `//` comments, and a model that
    # copies them yields an object that is not valid JSON.
    # `\b` rejects multi-digit values such as "predict": 10.
    match = re.search(r'\{[^}]*"predict"\s*:\s*(0|1)\b[^}]*\}', response)
    if match:
        return int(match.group(1))
    raise ValueError("Could not parse prediction from model response.")


class ZeroShotClassifier(GenrePredictorInterface):
    """Shared state and helpers for zero-shot genre classifiers built on a causal LM.

    A song is classified by asking the model one yes/no question per candidate
    genre (see ``prompt_template``) and parsing a 0/1 answer. The generation
    loop itself (``predict``) is provided by subclasses such as
    ``ZeroShotClassifierV1``.
    """

    def __init__(self, model, tokenizer, genres, prompt_template, device="cuda", max_lyrics_length=300):
        """
        Args:
            model: Causal language model used for generation.
            tokenizer: Tokenizer matching ``model``.
            genres: All candidate genres. The position of a genre in this list
                is the column of its prediction.
            prompt_template: ``%``-style template with two ``%s`` slots:
                lyrics first, then genre.
            device: Device that tokenized inputs are moved to.
            max_lyrics_length: Lyrics are truncated to this many characters.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.genres = genres  # list of all possible genres
        self.device = device
        self.max_lyrics_length = max_lyrics_length
        self.prompt_template = prompt_template

    def _make_prompts(self, lyrics: str) -> list[str]:
        """Build one prompt per genre (in ``self.genres`` order) for the given lyrics.

        The lyrics are truncated. Newlines become spaces and double quotes become
        single quotes, so the lyrics cannot break the JSON-looking
        ``"lyrics": "%s"`` field of the template.
        """
        truncated = lyrics[:self.max_lyrics_length].replace('\n', ' ').replace('"', "'")
        return [self.prompt_template % (truncated, genre) for genre in self.genres]

    def _parse_response(self, response: str) -> int:
        """Return the 0/1 ``predict`` value from a model response, or 0 if none is found."""
        if not isinstance(response, str):
            return 0
        # Take the value from the capture group rather than json.loads: answers
        # that echo the prompt's `//` comments are not valid JSON.
        # `\b` rejects multi-digit values such as "predict": 10.
        match = re.search(r'\{[^}]*"predict"\s*:\s*(0|1)\b[^}]*\}', response)
        return int(match.group(1)) if match else 0  # fallback to 0 if nothing usable
    

def make_prompts(lyrics: str, genres, max_lyrics_length: int = 300, prompt_template=None) -> list[str]:
    """Build one classification prompt per genre for a single song.

    This is the function counterpart of ``ZeroShotClassifier._make_prompts``.

    Args:
        lyrics: Raw song lyrics.
        genres: Iterable of candidate genre names.
        max_lyrics_length: Lyrics are truncated to this many characters.
        prompt_template: ``%``-style template with two ``%s`` slots (lyrics,
            then genre). Defaults to the module-level ``prompt_v1``.

    Returns:
        One formatted prompt per genre, in the order of ``genres``.
    """
    if prompt_template is None:
        prompt_template = prompt_v1
    # Flatten newlines and swap double quotes so the lyrics cannot break the
    # JSON-looking `"lyrics": "%s"` field of the template.
    truncated = lyrics[:max_lyrics_length].replace('\n', ' ').replace('"', "'")
    return [prompt_template % (truncated, genre) for genre in genres]

def parse_response(response: str) -> int:
    """Return the 0/1 ``predict`` value found in a model response.

    This is the lenient variant of ``parse_model_response``: it returns 0
    instead of raising when no answer can be found.

    Args:
        response: Decoded model output.

    Returns:
        0 or 1 (0 when the response has no usable ``"predict"`` value).
    """
    if not isinstance(response, str):
        return 0
    # Take the value from the capture group rather than json.loads: answers
    # that echo the prompt's `//` comments are not valid JSON.
    # `\b` rejects multi-digit values such as "predict": 10.
    match = re.search(r'\{[^}]*"predict"\s*:\s*(0|1)\b[^}]*\}', response)
    return int(match.group(1)) if match else 0  # fallback to 0 if nothing usable

## Main mechanic: asking the model about a single (lyrics, genre) pair

In [None]:
# Suppose we have lyrics and a target genre.
lyrics = val_dataset[0]['features']['lyrics']
target_genre = val_dataset[0]['features']['genre_list'][0]
print(lyrics[:100])
print(target_genre)

# Let's ask the model whether the song belongs to the target genre.
# `make_prompts` truncates the lyrics to 300 characters. It also flattens
# newlines and double quotes so the lyrics cannot break the JSON-looking prompt field.
instruct = make_prompts(lyrics, [target_genre])[0]

print(instruct)

# Prepare the model input as a single user turn of the chat template.
messages = [
    {"role": "user", "content": instruct}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,  # Qwen3 chat-template switch: no reasoning block
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# `do_sample` is a generation option. If passed to `apply_chat_template`, it
# only reaches the template renderer and is ignored. Generation then follows
# the model's default generation config (which may sample) instead of greedy decoding.
with torch.no_grad():
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=1024,  # ample for a short JSON answer
        do_sample=False,
    )
# Keep only the newly generated tokens.
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# Qwen3 closes its reasoning block with the `</think>` token (id 151668).
# The answer is everything after its last occurrence.
THINK_END_TOKEN_ID = 151668
try:
    index = len(output_ids) - output_ids[::-1].index(THINK_END_TOKEN_ID)
except ValueError:
    index = 0  # no `</think>` token: the whole output is the answer

thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
print("thinking content:", thinking_content)
print("content:", content)

[Verse 1] Well, I'm standing here, freezing, outside your golden garden Uh got my ladder, leaned up 
jazz
You are a music genre expert. You will determine whether a song belongs to a specific genre based on its lyrics. You will be provided with a JSON input containing the lyrics and the target genre. Respond with 1 if the song likely belongs to the specified genre, and 0 if it does not.

**Input format:**
```json
{
    "lyrics": "Lyrics of the song",
    "genre": "Target genre"
}
```

**Output format:**
```json
{
    "predict": 1  // if the song belongs to the genre
    // or
    "predict": 0  // if it does not
}
```

**Lyrics with genre for classification:**
```json
{
    "lyrics": "[Verse 1] Well, I'm standing here, freezing, outside your golden garden Uh got my ladder, leaned up against your wall Tonight's the night we planned to run away together Come on Dolly Mae, there's no time to stall But now you're telling me [Chorus] I think I better wait until tomorrow I think I bett",
    

In [None]:
class ZeroShotClassifierV1(ZeroShotClassifier):
    """Zero-shot genre classifier that asks one 0/1 question per (song, genre) pair.

    For every song in a batch, ``_make_prompts`` builds one prompt per genre.
    Each prompt is wrapped in the tokenizer's chat template and answered by the
    model with greedy decoding. Each parsed answer fills one cell of a
    ``(n_songs, n_genres)`` prediction matrix.
    """

    # Token id of `</think>`, which closes Qwen3's reasoning block.
    THINK_END_TOKEN_ID = 151668

    def __init__(self, model, tokenizer, genres, prompt_template, device="cuda", max_lyrics_length=300, batch_size=2,
                 max_new_tokens=1024):
        """
        Args:
            model, tokenizer, genres, prompt_template, device, max_lyrics_length:
                See ``ZeroShotClassifier``.
            batch_size: Number of prompts per ``generate`` call. This is the
                mini-batch size for generation, not the number of songs in
                ``batch['features']``.
            max_new_tokens: Generation budget per prompt.
        """
        super().__init__(model, tokenizer, genres, prompt_template, device, max_lyrics_length)
        self.batch_size = batch_size
        self.max_new_tokens = max_new_tokens

    def predict(self, batch: dict, enable_thinking=False, debug: bool = False) -> tuple[np.ndarray, list[str], list[str]]:
        """Predict a multi-hot genre matrix for a batch of songs.

        Args:
            batch: Dict whose ``'features'`` entry is a sequence of rows, each
                with a ``'lyrics'`` string.
            enable_thinking: Forwarded to the chat template (Qwen3 reasoning mode).
            debug: Log prompts, sample outputs and decoded predictions.

        Returns:
            A tuple ``(predictions, full_generated, instruct_texts)``:
            - ``predictions``: int32 array of shape ``(n_songs, n_genres)`` with
              0/1 entries; column j corresponds to ``self.genres[j]``.
            - ``full_generated``: reasoning + answer text for every prompt.
            - ``instruct_texts``: the chat-formatted prompts.
            Both lists are in song-major, genre-minor order.
        """
        lyrics_list = [row['lyrics'] for row in batch['features']]
        n_songs, n_genres = len(lyrics_list), len(self.genres)

        # Song-major, genre-minor order: prompt k asks about song k // n_genres
        # and genre self.genres[k % n_genres].
        all_prompts = [prompt for lyrics in lyrics_list for prompt in self._make_prompts(lyrics)]

        if debug:
            logger.info(f"Total prompts: {len(all_prompts)}")
            if all_prompts:
                logger.info(f"Example prompt:\n{all_prompts[0]}")

        instruct_texts = self._build_instruct_texts(all_prompts, enable_thinking)
        full_generated, generated_texts = self._generate(instruct_texts)

        if debug:
            logger.info("Sample model outputs:\n" + "\n---\n".join(generated_texts[:3]))

        predictions = np.zeros((n_songs, n_genres), dtype=np.int32)
        for k, raw_output in enumerate(generated_texts):
            song_idx, genre_idx = divmod(k, n_genres)
            try:
                predictions[song_idx, genre_idx] = self._parse_response(raw_output)
            except Exception as e:
                logger.warning(f"Failed to parse output: {raw_output}, error: {e}")

        if debug:
            for i, pred in enumerate(predictions):
                predicted_genres = [g for g, flag in zip(self.genres, pred) if flag]
                logger.info(f"Sample {i} predicted genres: {predicted_genres}")

        return predictions, full_generated, instruct_texts

    def _build_instruct_texts(self, prompts, enable_thinking):
        """Wrap each prompt as a single user turn using the tokenizer's chat template."""
        return [
            self.tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=enable_thinking,
            )
            for prompt in prompts
        ]

    def _split_thinking(self, output_ids):
        """Split generated token ids at the last `</think>` into (reasoning text, answer text)."""
        try:
            split_idx = len(output_ids) - output_ids[::-1].index(self.THINK_END_TOKEN_ID)
        except ValueError:
            split_idx = 0  # no reasoning block: everything is the answer
        thinking = self.tokenizer.decode(output_ids[:split_idx], skip_special_tokens=True).strip("\n")
        answer = self.tokenizer.decode(output_ids[split_idx:], skip_special_tokens=True).strip()
        return thinking, answer

    def _generate(self, instruct_texts):
        """Generate greedily in mini-batches of ``self.batch_size``; return (full texts, answer texts)."""
        full_generated, answers = [], []
        self.model.eval()
        # Decoder-only models must be left-padded for batched generation.
        # Otherwise, new tokens are appended after the padding of shorter
        # prompts. The tokenizer's setting is restored afterwards.
        original_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            for start in range(0, len(instruct_texts), self.batch_size):
                chunk = instruct_texts[start:start + self.batch_size]
                model_inputs = self.tokenizer(chunk, return_tensors="pt", padding=True, truncation=True).to(self.device)
                # All rows share the padded prompt length, so generated tokens
                # start at the same offset in every row.
                prompt_len = model_inputs["input_ids"].shape[1]

                with torch.no_grad():
                    # `do_sample` belongs here, not in apply_chat_template
                    # (where it was silently ignored).
                    outputs = self.model.generate(
                        **model_inputs,
                        max_new_tokens=self.max_new_tokens,
                        do_sample=False,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )

                for row in outputs:
                    thinking, answer = self._split_thinking(row[prompt_len:].tolist())
                    full_generated.append(thinking + answer)
                    answers.append(answer)
        finally:
            self.tokenizer.padding_side = original_padding_side
        return full_generated, answers

In [None]:
def test_classifier(enable_thinking, example_idx=10):
    """Run ZeroShotClassifierV1 on the first validation song and print the results.

    Args:
        enable_thinking: Whether to enable Qwen3's reasoning mode in the chat template.
        example_idx: Index of the (song, genre) prompt whose instruct text and
            model answer are printed. It is clamped to the number of available
            prompts (one per genre).
    """
    classifier_v1 = ZeroShotClassifierV1(model, tokenizer, genres, prompt_v1, device=device, max_lyrics_length=250, batch_size=128)

    batch = next(iter(val_loader))
    # Keep a single song to limit the number of prompts (one per genre).
    batch['features'] = batch['features'][:1]
    ground_truth = batch['labels'][:1]

    predictions, generated_texts, instruct_texts = classifier_v1.predict(batch, enable_thinking=enable_thinking)
    print('Ground truth labels:', ground_truth)
    print('Predicted labels:', predictions)

    if instruct_texts:
        # Guard against genre lists shorter than example_idx + 1.
        shown_idx = min(example_idx, len(instruct_texts) - 1)
        print("\nLet's take a look at specific instruct:")
        print(instruct_texts[shown_idx])
        print("\nAnd here is the answer:")
        print(generated_texts[shown_idx])
    print('\nActual genre was:', one_hot_encoded_to_genre_list(ground_truth[0], idx2genre))
    print('\nPredicted genre is:', one_hot_encoded_to_genre_list(predictions[0], idx2genre))


# Compare the classifier with the model's reasoning mode on and off.
print('Test classifier with turned on thinking mode')
test_classifier(True)

print('\n\n\n\nTest classifier with turned off thinking mode')
test_classifier(False)

Test classifier with turned on thinking mode
