### 加载模型

In [1]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

model_name = "facebook/musicgen-medium"  # 可选：small, medium, large
# 初次使用记得去掉local_files_only=True
processor = AutoProcessor.from_pretrained(model_name, local_files_only=True)
model = MusicgenForConditionalGeneration.from_pretrained(model_name, local_files_only=True).half().to(device)
# model.half()解决精度问题报错

  from .autonotebook import tqdm as notebook_tqdm


device: cuda


### 加载数据集

In [2]:
import librosa
from datasets import Dataset
import os

def process_data(batch):
    # 加载音频并标准化
    audio, sr = librosa.load(os.path.join("./data/datashare/", batch["location"]), sr=32000)
    
    # 使用 processor 处理文本和音频
    inputs = processor(
        text=[batch["main_caption"]],
        audio=[audio],
        sampling_rate=32000,
        padding=True,
        return_tensors="pt",
    ).to(device)
    
    return inputs

from datasets import load_dataset
dataset = load_dataset("amaai-lab/MusicBench")


Using the latest cached version of the dataset since amaai-lab/MusicBench couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at C:\Users\XHZGenius\.cache\huggingface\datasets\amaai-lab___music_bench\default\0.0.0\b141e962aacc19ffd51c15732738040377989203 (last modified on Thu May 22 18:25:15 2025).


In [3]:
dataset = dataset.map(process_data, batched=False)
# dataset = dataset["train"].train_test_split(test_size=0.2)
dataset

Map: 100%|██████████| 52768/52768 [15:43<00:00, 55.94 examples/s]  
Map: 100%|██████████| 800/800 [00:13<00:00, 60.94 examples/s] 


DatasetDict({
    train: Dataset({
        features: ['dataset', 'location', 'main_caption', 'alt_caption', 'prompt_aug', 'prompt_ch', 'prompt_bt', 'prompt_bpm', 'prompt_key', 'beats', 'bpm', 'chords', 'chords_time', 'key', 'keyprob', 'is_audioset_eval_mcaps', 'input_ids', 'attention_mask', 'input_values', 'padding_mask'],
        num_rows: 52768
    })
    test: Dataset({
        features: ['dataset', 'location', 'main_caption', 'alt_caption', 'prompt_aug', 'prompt_ch', 'prompt_bt', 'prompt_bpm', 'prompt_key', 'beats', 'bpm', 'chords', 'chords_time', 'key', 'keyprob', 'is_audioset_eval_mcaps', 'input_ids', 'attention_mask', 'input_values', 'padding_mask'],
        num_rows: 800
    })
})

### 设置训练参数

In [4]:
from peft import LoraConfig, get_peft_model

# 定义 LoRA 配置
lora_config = LoraConfig(
    r=8,                  # LoRA 的秩（Rank）
    lora_alpha=32,        # 缩放因子
    target_modules=["q_proj", "v_proj"],  # 目标模块（MusicGen 的注意力层）
    lora_dropout=0.05,    # Dropout 率
    bias="none",          # 不调整偏置
    task_type="CAUSAL_LM", # 因果语言模型任务
)

# 应用 LoRA
lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()  # 查看可训练参数（应远小于原始模型）
lora_model.save_pretrained("./outputs/musicgen-lora/initial_lora")

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./outputs/musicgen-lora",
    per_device_train_batch_size=2,
    num_train_epochs=3,
    learning_rate=1e-4,  # LoRA 需要更高的学习率
    fp16=True,           # 混合精度训练
    logging_steps=100,
    save_steps=500,
)

trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

trainable params: 4,718,592 || all params: 2,009,969,730 || trainable%: 0.2348


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [14]:
dataset["train"][0]

{'dataset': 'MusicBench',
 'location': 'data_aug2/-0SdAVK79lg_1.wav',
 'main_caption': 'This mellow instrumental track showcases a dominant electric guitar that opens with a descending riff, followed by arpeggiated chords, hammer-ons, and a slide. The percussion section keeps it simple with rim shots and a common time count, while the bass adds a single note on the first beat of every bar. Minimalist piano chords round out the song while leaving space for the guitar to shine. There are no vocals, making it perfect for a coffee shop or some chill background music. The key is in E major, with a chord progression that centers around that key and a straightforward 4/4 time signature.',
 'alt_caption': 'This song features an electric guitar as the main instrument. The guitar plays a descending run in the beginning then plays an arpeggiated chord followed by a double stop hammer on to a higher note and a descending slide followed by a descending chord run. The percussion plays a simple beat 

### 训练

In [6]:
trainer.train()

ValueError: expected sequence of length 88 at dim 2 (got 153)