### 加载模型

In [1]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

model_name = "facebook/musicgen-medium"  # 可选：small, medium, large
# 初次使用记得去掉local_files_only=True
processor = AutoProcessor.from_pretrained(model_name, local_files_only=True)
model = MusicgenForConditionalGeneration.from_pretrained(model_name, local_files_only=True).half().to(device)
# model.half()解决精度问题报错

  from .autonotebook import tqdm as notebook_tqdm


device: cuda


### 加载数据集

In [23]:
import librosa
from datasets import Dataset
import os
import numpy as np

def process_data(batch):
    # 加载音频并标准化
    audio, sr = librosa.load(os.path.join("./data/datashare/", batch["location"]), sr=32000)
    current_length = len(audio)
    target_length = 320000 # 10s*32000/s
    if current_length >= target_length:
        padded_audio = audio[:target_length]  # 如果音频过长，截断
    else:
        # 计算需要补零的数量
        padding_length = target_length - current_length
        # 在末尾补零（静音）
        padded_audio = np.pad(audio, (0, padding_length), mode='constant').flatten()
    
    # 使用 processor 处理文本和音频
    inputs = processor(
        text=[batch["main_caption"]],
        audio=[padded_audio],
        sampling_rate=32000,
        padding="max_length",
        max_length=512,
        return_tensors="pt",
    ).to(device)
    inputs["input_ids"] = inputs["input_ids"].flatten()
    # inputs["input_values"] = inputs["input_values"].flatten()
    # print(inputs["input_ids"].shape, inputs["input_values"].shape)
    
    return inputs

In [24]:
from datasets import load_dataset
dataset = load_dataset("amaai-lab/MusicBench", split="test")

In [25]:
dataset = dataset.map(process_data, batched=False)
# dataset = dataset["train"].train_test_split(test_size=0.2)
dataset = dataset.train_test_split(0.2)

Map: 100%|██████████| 800/800 [00:09<00:00, 84.63 examples/s] 


### 设置lora参数

In [26]:
from peft import LoraConfig, get_peft_model

# 定义 LoRA 配置
lora_config = LoraConfig(
    r=8,                  # LoRA 的秩（Rank）
    lora_alpha=32,        # 缩放因子
    target_modules=["q_proj", "v_proj"],  # 目标模块（MusicGen 的注意力层）
    lora_dropout=0.05,    # Dropout 率
    bias="none",          # 不调整偏置
    task_type="CAUSAL_LM", # 因果语言模型任务
)

# 应用 LoRA
lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()  # 查看可训练参数（应远小于原始模型）
lora_model.save_pretrained("./outputs/musicgen-lora/initial_lora")

trainable params: 4,718,592 || all params: 2,009,969,730 || trainable%: 0.2348


### 设置训练参数

In [27]:

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./outputs/musicgen-lora",
    per_device_train_batch_size=4,
    num_train_epochs=3,
    learning_rate=1e-4,  # LoRA 需要更高的学习率
    fp16=True,           # 混合精度训练
    logging_steps=100,
    save_steps=500,
)

trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [28]:
dataset["train"][0]

{'dataset': 'MusicBench',
 'location': 'data/tmabzx6yxqs.wav',
 'main_caption': 'This kids song features groovy synth keys, chords, simple piano melody, echoing and wide pig oink sound effects, shimmering shakers, muffled claps, soft kick hits and groovy bass. It sounds fun and happy - like something that children would listen to in their favorite tv shows.',
 'alt_caption': 'This kids song features groovy synth keys, chords, simple piano melody, echoing and wide pig oink sound effects, shimmering shakers, muffled claps, soft kick hits and groovy bass. It sounds fun and happy - like something that children would listen to in their favorite tv shows.',
 'prompt_aug': '',
 'prompt_ch': 'The chord sequence is Bbmaj7, F7.',
 'prompt_bt': 'The time signature is 4/4.',
 'prompt_bpm': 'This song is played in Moderato.',
 'prompt_key': 'The key of this song is Bb major.',
 'beats': [[0.06,
   0.62,
   1.18,
   1.76,
   2.32,
   2.88,
   3.44,
   4.0,
   4.58,
   5.14,
   5.7,
   6.26,
   6.84,

### 训练

In [29]:
trainer.train()

torch.Size([4, 512])


RuntimeError: output with shape [4, 12, 512, 512] doesn't match the broadcast shape [4, 4, 12, 512, 512]