### 加载模型

In [1]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

model_name = "facebook/musicgen-medium"  # options: facebook/musicgen-small, -medium, -large
# local_files_only=True loads from the local Hugging Face cache only;
# remove it on the first run so the checkpoint can be downloaded.
processor = AutoProcessor.from_pretrained(model_name, local_files_only=True)
model = MusicgenForConditionalGeneration.from_pretrained(model_name, local_files_only=True)
# Casting to fp16 (.half()) was added to work around a precision-related error.
# Only cast on GPU: on the CPU fallback, many fp16 kernels have historically been
# missing or very slow, so the model stays in fp32 there.
if device == "cuda":
    model = model.half()
model = model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


device: cuda


### 加载数据集

In [25]:
import librosa
from datasets import Dataset
import os

def process_data(batch, data_dir="./data/datashare/", sampling_rate=32000):
    """Load, resample and encode one batch of MusicBench rows for MusicGen.

    Intended to be used with ``Dataset.map(process_data, batched=True)``.

    Args:
        batch: Mapping of column name -> list of values. Must contain
            ``"location"`` (audio file path relative to ``data_dir``) and
            ``"main_caption"`` (the text prompt for each clip).
        data_dir: Directory that the ``location`` paths are relative to.
        sampling_rate: Rate librosa resamples each clip to. It is also the rate
            declared to the processor; using one value keeps the two in sync.

    Returns:
        The processor's output (``BatchFeature`` with text and audio features)
        as CPU tensors.
    """
    # Load every clip, resampled to `sampling_rate`; keep only the waveform.
    audios = [
        librosa.load(os.path.join(data_dir, path), sr=sampling_rate)[0]
        for path in batch["location"]
    ]

    # Tokenize captions and extract audio features in one call.
    # padding=True pads only to the longest item inside this map batch, so
    # sequence lengths still differ between batches.
    # Tensors deliberately stay on CPU: datasets.map serializes the result to
    # Arrow anyway, so moving them to the GPU here only adds a round-trip (and
    # can trigger CUDA errors in forked map workers).
    # NOTE(review): no "labels" key is produced. MusicgenForConditionalGeneration
    # only returns a loss when labels (audio codes) are supplied, so Trainer
    # likely cannot train on this output as-is — TODO confirm and add labels.
    return processor(
        text=batch["main_caption"],
        audio=audios,
        sampling_rate=sampling_rate,
        padding=True,
        return_tensors="pt",
    )

In [26]:
from datasets import load_dataset
# Fetch MusicBench from the Hugging Face Hub (served from the local cache on later runs).
dataset = load_dataset(path="amaai-lab/MusicBench")

In [27]:
# Preprocess every split with process_data, two rows per call.
# This rebinds `dataset`, so re-running the cell maps the already-processed
# dataset a second time.
dataset = dataset.map(
    process_data,
    batch_size=2,
    batched=True,
)
# Alternative: carve a held-out split from train instead of relying on a provided test split:
# dataset = dataset["train"].train_test_split(test_size=0.2)
dataset

Map:  17%|█▋        | 8942/52768 [02:32<10:48, 67.61 examples/s]

### 设置 LoRA 参数

In [None]:
from peft import LoraConfig, get_peft_model

# 定义 LoRA 配置
# LoRA adapter configuration.
# NOTE(review): MusicGen pairs a text encoder with an audio decoder; confirm that
# CAUSAL_LM (rather than another PEFT task type) is the intended wrapper.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",                # PEFT task type used to wrap the model
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt (modules named q_proj / v_proj)
    r=8,                                  # rank of the low-rank update matrices
    lora_alpha=32,                        # scaling factor applied to the LoRA update
    lora_dropout=0.05,                    # dropout on the LoRA branch
    bias="none",                          # bias terms are not trained
)

# Wrap the base model with LoRA adapters and report the trainable-parameter
# count (expected to be a small fraction of the full model).
# NOTE(review): get_peft_model injects adapter layers into `model` itself, so
# re-running this cell wraps an already-modified model.
lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()

# Save the freshly initialised (untrained) adapter weights.
lora_model.save_pretrained("./outputs/musicgen-lora/initial_lora")

### 设置训练参数

In [None]:

from transformers import TrainingArguments, Trainer

# Training hyperparameters.
# fp16 mixed precision is enabled only when running on a GPU: the model-loading
# cell falls back to CPU, where fp16 AMP is generally not supported.
# NOTE(review): the base model was cast with .half(); AMP gradient unscaling needs
# fp32 trainable params — confirm the installed PEFT upcasts LoRA adapter weights.
training_args = TrainingArguments(
    output_dir="./outputs/musicgen-lora",
    per_device_train_batch_size=2,
    num_train_epochs=3,
    learning_rate=1e-4,  # LoRA typically needs a higher learning rate than full fine-tuning
    fp16=(device == "cuda"),  # mixed-precision training (GPU only)
    logging_steps=100,
    save_steps=500,
)

# NOTE(review): no data_collator is passed, so the default collator has to stack
# examples whose padded lengths differ between map batches, and the dataset has
# no "labels" column for the loss — TODO confirm both before training.
trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=dataset["train"],
    # .get avoids a KeyError when the dataset has no "test" split. No evaluation
    # strategy is configured here, so None is a safe fallback.
    eval_dataset=dataset.get("test"),
)

### 训练

In [None]:
# Run the fine-tuning loop; the returned training summary is the cell's output.
train_result = trainer.train()
train_result