# Fine-tuning Mistral-7B-Instruct to classify musical genres from playlist descriptions.

GitHub: https://github.com/thefffilo/PlaylistCreator

### Imports

In [None]:
!pip install auto-gptq
!pip install optimum
!pip install bitsandbytes

In [None]:
# resolving "No inf checks were recorded for this optimizer." issue
!pip uninstall torch -y
!pip install torch==2.1

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model
import transformers

### Load model

In [None]:
# Base model to fine-tune: a GPTQ build of Mistral-7B-Instruct-v0.2 (per the repo name).
model_name = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             device_map="auto", # automatically figures out how to best use CPU + GPU for loading model
                                             trust_remote_code=False, # prevents running custom model files on your machine
                                             revision="main") # which version of model to use in repo

# Load the tokenizer that belongs to the same base model repo
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

### Prepare Model for Training

In [None]:
model.train() # put the model in training mode (dropout modules are active)

# Enable gradient checkpointing: activations are recomputed during the
# backward pass instead of stored, trading extra compute for less GPU memory.
model.gradient_checkpointing_enable()

# Prepare the quantized model so that adapters can be trained on top of it.
# NOTE(review): prepare_model_for_kbit_training may itself enable gradient
# checkpointing by default, making the call above redundant — confirm against
# the installed PEFT version.
model = prepare_model_for_kbit_training(model)

In [None]:
# LoRA config
config = LoraConfig(
    r=8,  # rank of the low-rank update matrices
    lora_alpha=32,  # LoRA scaling numerator (update scaled by lora_alpha / r)
    target_modules=["q_proj"],  # adapters only on modules named q_proj (attention query projections)
    lora_dropout=0.05,  # dropout applied inside the LoRA branch
    bias="none",  # bias terms are not trained
    task_type="CAUSAL_LM"
)

# Wrap the model so that only the LoRA parameters are trainable
model = get_peft_model(model, config)

# Print the trainable vs. total parameter counts
model.print_trainable_parameters()

trainable params: 2,097,152 || all params: 264,507,392 || trainable%: 0.7928519441906561


### Preparing Training Dataset

In [None]:
# Load the fine-tuning dataset from the Hugging Face Hub. Later cells read its
# "example" text column and its "train" / "test" splits.
import transformers
from datasets import load_dataset
data = load_dataset("fffilo/genre-classifier-2")

In [None]:
# create tokenize function
def tokenize_function(examples):
    """Tokenize a batch of dataset rows for causal-LM fine-tuning.

    Meant to be passed to ``Dataset.map(..., batched=True)``; only the
    ``"example"`` column of ``examples`` is read (presumably a list of
    strings per batch — confirm against the dataset schema).

    Side effect: sets the global ``tokenizer.truncation_side`` to ``"left"``
    before tokenizing, so texts longer than 512 tokens lose their beginning
    rather than their end.

    Returns:
        The tokenizer's output for the batch, as NumPy arrays, truncated to
        at most 512 tokens per row.
    """
    batch_text = examples["example"]

    # Drop overflow from the start so the end of each text is kept.
    tokenizer.truncation_side = "left"

    encoding_options = dict(truncation=True, max_length=512, return_tensors="np")
    return tokenizer(batch_text, **encoding_options)

# Tokenize every split of the dataset (the "train" and "test" splits are used
# below) in batches using tokenize_function.
tokenized_data = data.map(tokenize_function, batched=True)

In [None]:
# Reuse EOS as the padding token (presumably because the base tokenizer
# defines none) so the collator can pad variable-length batches.
tokenizer.pad_token = tokenizer.eos_token
# Causal-LM data collator: mlm=False selects next-token prediction instead of
# masked language modelling.
# NOTE(review): the collator masks pad-token positions out of the labels; with
# pad == eos, any genuine EOS tokens in the training text are masked as well,
# so the model may never learn to stop generating. Consider a distinct pad
# token — confirm against the installed transformers version.
data_collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)


### Fine-tuning Model

In [None]:
# hyperparameters
lr = 2e-4
batch_size = 4
num_epochs = 10

# define training arguments
training_args = transformers.TrainingArguments(
    output_dir= "mistral-genre-classificator",  # local directory for checkpoints
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    logging_strategy="epoch",
    # NOTE(review): this argument is named `eval_strategy` in newer transformers
    # releases; keep it in sync with the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",  # kept equal to the eval strategy, as load_best_model_at_end requires
    load_best_model_at_end=True,  # restore the best checkpoint (by default: lowest eval loss) after training
    gradient_accumulation_steps=4,  # 4 x batch_size=4 -> 16 examples per optimizer step per device
    warmup_steps=2,
    fp16=True,  # mixed-precision training
    optim="paged_adamw_8bit",  # paged 8-bit AdamW (bitsandbytes)

)

In [None]:
# configure trainer
trainer = transformers.Trainer(
    model=model,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],  # the "test" split is used for the per-epoch evaluation
    args=training_args,
    data_collator=data_collator
)


# train model
model.config.use_cache = False  # the KV cache is not needed while training; disabling it silences the warnings. Please re-enable for inference!
trainer.train()

# re-enable the KV cache for generation
model.config.use_cache = True

### Push model to hub

In [None]:
# Log in to the Hugging Face Hub (interactive prompt) so the model can be pushed.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Upload the trained LoRA adapter (model is a PeftModel) to the target Hub repo.
model.push_to_hub('fffilo/mistral-genre-classificator')
# Trainer.push_to_hub()'s first positional parameter is `commit_message`, not a
# repository id: passing the repo name only turned it into the commit message.
# Per the transformers docs, the Trainer pushes to the repo named by
# TrainingArguments (hub_model_id, otherwise the output_dir name under the
# logged-in account) and also uploads a generated model card.
trainer.push_to_hub(commit_message="End of training")

In [None]:
# Upload the tokenizer to the same repo: the "Load Fine-tuned Model" cell loads
# the tokenizer from this repo id, so it has to be present there.
tokenizer.push_to_hub('fffilo/mistral-genre-classificator')

### Load Fine-tuned Model

In [None]:
# Load the fine-tuned model from the Hub.
# The repo was populated from a PeftModel (presumably adapter weights/config
# plus the tokenizer), so load the base model it was trained on explicitly and
# attach the LoRA adapter exactly once, instead of loading the adapter repo
# through AutoModelForCausalLM and then wrapping it with the same adapter again.
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "fffilo/mistral-genre-classificator"

# The adapter config records the base checkpoint used during fine-tuning.
config = PeftConfig.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,
                                             device_map="auto",
                                             trust_remote_code=False,
                                             revision="main")

# Attach the trained LoRA weights on top of the base model.
model = PeftModel.from_pretrained(model, model_name)

# load tokenizer (pushed to the same repo as the adapter)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)



tokenizer_config.json:   0%|          | 0.00/1.46k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/437 [00:00<?, ?B/s]

### Use Fine-tuned Model

In [None]:
# Inference prompt in Mistral-Instruct form ([INST] ... [/INST]): the allowed
# genre list followed by the playlist description to classify.
# NOTE(review): the layout presumably has to match the training examples —
# confirm against the dataset. The f-prefix is unnecessary (no placeholders).
prompt = f"""[INST] genres: [acoustic,afrobeat,alt-rock,alternative,ambient,anime,black-metal,bluegrass,blues,bossanova,brazil,breakbeat,british,cantopop,chicago-house,children,chill,classical,club,comedy,country,dance,dancehall,death-metal,deep-house,detroit-techno,disco,disney,drum-and-bass,dub,dubstep,edm,electro,electronic,emo,folk,forro,french,funk,garage,german,gospel,goth,grindcore,groove,grunge,guitar,happy,hard-rock,hardcore,hardstyle,heavy-metal,hip-hop,holidays,honky-tonk,house,idm,indian,indie,indie-pop,industrial,j-dance,j-idol,j-pop,j-rock,jazz,k-pop,kids,latin,latino,malay,mandopop,metal,metal-misc,metalcore,minimal-techno,movies,mpb,new-age,new-release,opera,pagode,party,piano,pop,pop-film,post-dubstep,power-pop,progressive-house,psych-rock,punk,punk-rock,r-n-b,rainy-day,reggae,reggaeton,road-trip,rock,rock-n-roll,rockabilly,romance,sad,salsa,samba,sertanejo,show-tunes,singer-songwriter,ska,sleep,songwriter,soul,soundtracks,spanish,study,summer,swedish,synth-pop,tango,techno,trance,trip-hop,work-out,world-music]
As an assistant, identify the three most suitable genres from the list above for the following description: 'Some energetic music to help me workout this evening' [/INST]"""
print(prompt)

[INST] genres: [acoustic,afrobeat,alt-rock,alternative,ambient,anime,black-metal,bluegrass,blues,bossanova,brazil,breakbeat,british,cantopop,chicago-house,children,chill,classical,club,comedy,country,dance,dancehall,death-metal,deep-house,detroit-techno,disco,disney,drum-and-bass,dub,dubstep,edm,electro,electronic,emo,folk,forro,french,funk,garage,german,gospel,goth,grindcore,groove,grunge,guitar,happy,hard-rock,hardcore,hardstyle,heavy-metal,hip-hop,holidays,honky-tonk,house,idm,indian,indie,indie-pop,industrial,j-dance,j-idol,j-pop,j-rock,jazz,k-pop,kids,latin,latino,malay,mandopop,metal,metal-misc,metalcore,minimal-techno,movies,mpb,new-age,new-release,opera,pagode,party,piano,pop,pop-film,post-dubstep,power-pop,progressive-house,psych-rock,punk,punk-rock,r-n-b,rainy-day,reggae,reggaeton,road-trip,rock,rock-n-roll,rockabilly,romance,sad,salsa,samba,sertanejo,show-tunes,singer-songwriter,ska,sleep,songwriter,soul,soundtracks,spanish,study,summer,swedish,synth-pop,tango,techno,trance,

In [None]:
# Inference mode (dropout disabled).
model.eval()

# Tokenize the prompt and move *all* encoder outputs (input_ids and
# attention_mask) to the device holding the model's first parameters,
# rather than a hard-coded "cuda".
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Pass the attention mask and an explicit pad_token_id. Without them,
# generate() warns that results may be unreliable and falls back to
# eos_token_id for padding on its own.
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=25,
)

print(tokenizer.batch_decode(outputs)[0])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


<s> [INST] genres: [acoustic,afrobeat,alt-rock,alternative,ambient,anime,black-metal,bluegrass,blues,bossanova,brazil,breakbeat,british,cantopop,chicago-house,children,chill,classical,club,comedy,country,dance,dancehall,death-metal,deep-house,detroit-techno,disco,disney,drum-and-bass,dub,dubstep,edm,electro,electronic,emo,folk,forro,french,funk,garage,german,gospel,goth,grindcore,groove,grunge,guitar,happy,hard-rock,hardcore,hardstyle,heavy-metal,hip-hop,holidays,honky-tonk,house,idm,indian,indie,indie-pop,industrial,j-dance,j-idol,j-pop,j-rock,jazz,k-pop,kids,latin,latino,malay,mandopop,metal,metal-misc,metalcore,minimal-techno,movies,mpb,new-age,new-release,opera,pagode,party,piano,pop,pop-film,post-dubstep,power-pop,progressive-house,psych-rock,punk,punk-rock,r-n-b,rainy-day,reggae,reggaeton,road-trip,rock,rock-n-roll,rockabilly,romance,sad,salsa,samba,sertanejo,show-tunes,singer-songwriter,ska,sleep,songwriter,soul,soundtracks,spanish,study,summer,swedish,synth-pop,tango,techno,tra

In [None]:
# Second test prompt: a reduced genre list and a different description.
# NOTE(review): unlike the first prompt, this string starts with a space and
# ends with a newline after [/INST]; confirm which layout matches the
# training examples.
prompt = f""" [INST] genres: [acoustic,afrobeat,alt-rock,alternative,ambient,anime,black-metal,blues,bossanova,brazil,breakbeat,british,children,chill,classical,club,comedy,country,dance,dancehall,death-metal,deep-house,disco,disney,drum-and-bass,dub,dubstep,edm,electro,electronic,emo,folk,french,funk,garage,german,gospel,goth,groove,grunge,guitar,happy,hard-rock,hardcore,hardstyle,heavy-metal,hip-hop,holidays,house,idm,indian,indie,indie-pop,industrial,j-pop,jazz,k-pop,kids,latin,latino,metal,metal-misc,metalcore,minimal-techno,movies,new-age,opera,party,piano,pop,pop-film,post-dubstep,power-pop,progressive-house,psych-rock,punk,punk-rock,r-n-b,rainy-day,reggae,reggaeton,road-trip,rock,rock-n-roll,rockabilly,romance,sad,salsa,samba,show-tunes,singer-songwriter,sleep,songwriter,soul,soundtracks,spanish,study,summer,swedish,synth-pop,tango,techno,trance,work-out,world-music]
As an assistant, identify the three most suitable genres from the list above for the following description: 'Nostalgic 80s hits that make you wish you were a teenager back then.' [/INST]
"""

# Inference mode (dropout disabled).
model.eval()
# Move input_ids and attention_mask together to the model's device.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Pass the attention mask and an explicit pad_token_id so generate() does not
# have to guess them (it otherwise warns and falls back to eos_token_id).
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=25,
)
print(tokenizer.batch_decode(outputs)[0])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


<s>  [INST] genres: [acoustic,afrobeat,alt-rock,alternative,ambient,anime,black-metal,blues,bossanova,brazil,breakbeat,british,children,chill,classical,club,comedy,country,dance,dancehall,death-metal,deep-house,disco,disney,drum-and-bass,dub,dubstep,edm,electro,electronic,emo,folk,french,funk,garage,german,gospel,goth,groove,grunge,guitar,happy,hard-rock,hardcore,hardstyle,heavy-metal,hip-hop,holidays,house,idm,indian,indie,indie-pop,industrial,j-pop,jazz,k-pop,kids,latin,latino,metal,metal-misc,metalcore,minimal-techno,movies,new-age,opera,party,piano,pop,pop-film,post-dubstep,power-pop,progressive-house,psych-rock,punk,punk-rock,r-n-b,rainy-day,reggae,reggaeton,road-trip,rock,rock-n-roll,rockabilly,romance,sad,salsa,samba,show-tunes,singer-songwriter,sleep,songwriter,soul,soundtracks,spanish,study,summer,swedish,synth-pop,tango,techno,trance,work-out,world-music] 
As an assistant, identify the three most suitable genres from the list above for the following description: 'Nostalgic 80

In [None]:
# Code to extract only the first 3 useful genres (note: the 'genres' list below
# does not contain every genre, only a subset).
# Keep only the text generated after the prompt's closing "[/INST]" tag.
# partition() splits at the first tag only, so it also works when the model
# writes "[/INST]" again (or the tag is missing: part2 is then ""), where the
# original two-way unpack of split() raised ValueError.
part1, _, part2 = tokenizer.batch_decode(outputs)[0].partition('[/INST]')
print(part2)

import re

def estrai_prime_tre_parole(genres, stringa, n=3):
    """Return the first ``n`` genres from ``genres`` that are named in ``stringa``.

    (The name means "extract the first three words".)

    Matching is on whole genre names, case-insensitive, scanning ``stringa``
    left to right:

    - Inside a name, a hyphen also matches whitespace, so "hip hop" finds
      "hip-hop".
    - Fragments never match a longer genre: "and" does not select
      "drum-and-bass", and the "metal" in "heavy-metal" does not select
      "black-metal". The previous word-by-word substring search did both.
    - Where names overlap, the longest one wins ("deep house" gives
      "deep-house", not "house").

    Args:
        genres: iterable of allowed genre names.
        stringa: text to scan, e.g. the model's generated answer.
        n: maximum number of genres to return (default 3).

    Returns:
        A list of at most ``n`` distinct genres, spelled as in ``genres``,
        in order of first appearance in ``stringa``.
    """
    def _canonical(name):
        # lower-case, each run of hyphens/whitespace collapsed to a single "-"
        return re.sub(r'[-\s]+', '-', name.strip()).lower()

    # canonical form -> genre as spelled in `genres` (first spelling wins)
    by_key = {}
    for genere in genres:
        key = _canonical(genere)
        if key.strip('-') and key not in by_key:
            by_key[key] = genere
    if n <= 0 or not by_key:
        return []

    # A single alternation over all names, longest first, so that at any
    # position the longest genre is tried first. The look-arounds treat "-"
    # like a word character, so a match can neither start nor end inside a
    # longer hyphenated word.
    alternatives = '|'.join(
        r'[-\s]+'.join(re.escape(part) for part in key.split('-'))
        for key in sorted(by_key, key=len, reverse=True)
    )
    pattern = re.compile(r'(?<![\w-])(?:' + alternatives + r')(?![\w-])', re.IGNORECASE)

    parole_trovate = []
    for match in pattern.finditer(stringa):
        genere = by_key.get(_canonical(match.group(0)))
        if genere is not None and genere not in parole_trovate:
            parole_trovate.append(genere)
            if len(parole_trovate) == n:
                break  # stop as soon as n distinct genres are found

    return parole_trovate

# Example usage
# NOTE(review): this hand-written list is only a partial test list. It places
# "salsa" early and also contains "salas" (possibly a typo of, or a deliberate
# decoy for, "salsa" — confirm); many genres offered in the prompt are missing.
genres = ["acoustic", "afrobeat", "alt-rock", "salsa", "ambient", "anime", "black-metal", "bluegrass", "blues", "bossanova", "brazil", "breakbeat", "british", "cantopop", "salas", "children", "chill", "classical", "club", "country", "dance", "dancehall", "death-metal", "deep-house", "disco", "disney", "drum-and-bass", "dub", "dubstep", "edm", "electro", "electronic", "emo", "folk", "forro", "french", "funk", "garage", "german", "gospel", "goth", "grindcore", "groove", "grunge", "guitar", "happy", "hard-rock", "hardcore", "hardstyle", "heavy-metal", "hip-hop", "holidays", "honky-tonk", "house", "idm", "indian", "indie", "indie-pop", "industrial", "j-dance", "j-idol", "j-pop", "j-rock", "jazz", "k-pop", "kids", "latin", "latino"]
# Full genre list offered in the first prompt, kept for reference:
# genres: [acoustic,afrobeat,alt-rock,alternative,ambient,anime,black-metal,bluegrass,blues,bossanova,brazil,breakbeat,british,cantopop,chicago-house,children,chill,classical,club,comedy,country,dance,dancehall,death-metal,deep-house,detroit-techno,disco,disney,drum-and-bass,dub,dubstep,edm,electro,electronic,emo,folk,forro,french,funk,garage,german,gospel,goth,grindcore,groove,grunge,guitar,happy,hard-rock,hardcore,hardstyle,heavy-metal,hip-hop,holidays,honky-tonk,house,idm,indian,indie,indie-pop,industrial,j-dance,j-idol,j-pop,j-rock,jazz,k-pop,kids,latin,latino,malay,mandopop,metal,metal-misc,metalcore,minimal-techno,movies,mpb,new-age,new-release,opera,pagode,party,piano,pop,pop-film,post-dubstep,power-pop,progressive-house,psych-rock,punk,punk-rock,r-n-b,rainy-day,reggae,reggaeton,road-trip,rock,rock-n-roll,rockabilly,romance,sad,salsa,samba,sertanejo,show-tunes,singer-songwriter,ska,sleep,songwriter,soul,soundtracks,spanish,study,summer,swedish,synth-pop,tango,techno,trance,trip-hop,work-out,world-music]
# Extract up to three allowed genres from the generated answer (part2).
parole_trovate = estrai_prime_tre_parole(genres, part2)
print(parole_trovate)



salsa, latin, latino 

[acoustic, afrobeat, alternative, ambient, anime, black-metal, bluegrass, blues, bossanova, brasil, breakbeat, british, cantopop, cha-cha-cha, children, choir, classical, club, comedy, country, deep-house, detroit-techno, disco, disco-funk, disco-soul, drum-and-bass, dub, dubstep, edm, electronic, emo, folk, forro, french, funk, garage, german, gospel, goth, grindcore, groove, grunge, guitar, hard-rock, hardcore, hardstyle, heavy-metal, metal, metal-misc, metalcore, minimal-techno, movies, new-age, new-release, opera, pagode, reggae, reggaeton, salsa, samba, singer-songwriter, ska, sleep, summer, swedish, synth-pop, tango, techno, trance, trip-hop, work-out, world-music] 

Genres: salsa, latin, latino 

[The genres salsa, lat
['salsa', 'latin', 'latino']
