In [1]:
from transformers import AutoProcessor, MusicgenForConditionalGeneration, MusicgenProcessor, MusicgenForCausalLM
from peft.config import PeftConfig
from peft import get_peft_config, get_peft_model, PeftModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from audiocraft.utils.notebook import display_audio

In [3]:
# Load the saved PEFT (LoRA) adapter configuration from the local fine-tuning directory.
# NOTE(review): `config` is not referenced by any later cell in this notebook.
config = PeftConfig.from_pretrained(
    pretrained_model_name_or_path="tuned-musicgen-small-lora",
)

In [4]:

# Base checkpoint and the local LoRA adapter directory, named once so the processor
# and the model are always loaded from the same checkpoint.
BASE_MODEL_ID = "facebook/musicgen-small"
ADAPTER_DIR = "tuned-musicgen-small-lora"

processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)

full_model = MusicgenForConditionalGeneration.from_pretrained(BASE_MODEL_ID)
# The adapter is applied to the decoder sub-module only; the wrapped decoder is put
# back into the full model, so the text encoder and audio encoder keep base weights.
decoder = PeftModel.from_pretrained(full_model.decoder, ADAPTER_DIR)
full_model.decoder = decoder


In [5]:
# Inspect the loaded model's MusicgenConfig (includes the nested `audio_encoder` settings).
full_model.config

MusicgenConfig {
  "_name_or_path": "facebook/musicgen-small",
  "architectures": [
    "MusicgenForConditionalGeneration"
  ],
  "audio_encoder": {
    "_name_or_path": "facebook/encodec_32khz",
    "add_cross_attention": false,
    "architectures": [
      "EncodecModel"
    ],
    "audio_channels": 1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_length_s": null,
    "chunk_size_feed_forward": 0,
    "codebook_dim": 128,
    "codebook_size": 2048,
    "compress": 2,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "dilation_growth_rate": 2,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_size": 128,
    "id2label": {
      "0": "LABEL_0",
      "1": "

In [6]:
# List every parameter name in the PEFT-wrapped decoder to confirm where the LoRA
# weights (lora_A / lora_B) were injected.
[name for name, _ in full_model.decoder.named_parameters()]

['base_model.model.model.decoder.embed_tokens.0.weight',
 'base_model.model.model.decoder.embed_tokens.1.weight',
 'base_model.model.model.decoder.embed_tokens.2.weight',
 'base_model.model.model.decoder.embed_tokens.3.weight',
 'base_model.model.model.decoder.embed_positions.weights',
 'base_model.model.model.decoder.layers.0.self_attn.k_proj.weight',
 'base_model.model.model.decoder.layers.0.self_attn.v_proj.base_layer.weight',
 'base_model.model.model.decoder.layers.0.self_attn.v_proj.lora_A.default.weight',
 'base_model.model.model.decoder.layers.0.self_attn.v_proj.lora_B.default.weight',
 'base_model.model.model.decoder.layers.0.self_attn.q_proj.base_layer.weight',
 'base_model.model.model.decoder.layers.0.self_attn.q_proj.lora_A.default.weight',
 'base_model.model.model.decoder.layers.0.self_attn.q_proj.lora_B.default.weight',
 'base_model.model.model.decoder.layers.0.self_attn.out_proj.weight',
 'base_model.model.model.decoder.layers.0.self_attn_layer_norm.weight',
 'base_model.

In [7]:
import torch

# Fall back to CPU when no GPU is available instead of crashing on `.to("cuda")`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

full_model = full_model.to(DEVICE)
# Inference only: disable dropout layers (the model repr shows Dropout modules).
full_model.eval()

MusicgenForConditionalGeneration(
  (text_encoder): T5EncoderModel(
    (shared): Embedding(32128, 768)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 768)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=768, out_features=768, bias=False)
                (k): Linear(in_features=768, out_features=768, bias=False)
                (v): Linear(in_features=768, out_features=768, bias=False)
                (o): Linear(in_features=768, out_features=768, bias=False)
                (relative_attention_bias): Embedding(32, 12)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseActDense(
                (wi): Linear(in_features=768, out_features=3072, bias=False)
                (wo): L

In [8]:
# Artist / tag vocabulary (single-word entries, already in sorted order).
"""
BLADE Crystals DJ DVRST Dxrk GLICHERY GUDOG Hensonn INTERWORLD KIIXSHI
KORDHELL MANO Memory NEON Onimxru Phonk Reboot SCXR XDXVIL g3ox_em yatashigang
""".split()

['BLADE',
 'Crystals',
 'DJ',
 'DVRST',
 'Dxrk',
 'GLICHERY',
 'GUDOG',
 'Hensonn',
 'INTERWORLD',
 'KIIXSHI',
 'KORDHELL',
 'MANO',
 'Memory',
 'NEON',
 'Onimxru',
 'Phonk',
 'Reboot',
 'SCXR',
 'XDXVIL',
 'g3ox_em',
 'yatashigang']

In [9]:
# Full set of prompt tokens/tags seen in the training descriptions.
{
    # The only multi-word entry, so it cannot come from the whitespace split below.
    "brazilian phonk",
    # Note: "tempo," (with the trailing comma) is a distinct entry from "tempo".
    *"""
    96bpm 98bpm 100bpm 102bpm 106bpm 120bpm 126bpm 128bpm 131bpm 133bpm 138bpm
    BLADE Brazilian Crystals DJ DVRST Dxrk GLICHERY GUDOG Hensonn INTERWORLD
    KIIXSHI KORDHELL MANO Memory NEON Onimxru Phonk Reboot SCXR XDXVIL
    fast g3ox_em medium not phonk sad slowly style synthwave tempo tempo, track
    with yatashigang
    """.split(),
}

{'100bpm',
 '102bpm',
 '106bpm',
 '120bpm',
 '126bpm',
 '128bpm',
 '131bpm',
 '133bpm',
 '138bpm',
 '96bpm',
 '98bpm',
 'BLADE',
 'Brazilian',
 'Crystals',
 'DJ',
 'DVRST',
 'Dxrk',
 'GLICHERY',
 'GUDOG',
 'Hensonn',
 'INTERWORLD',
 'KIIXSHI',
 'KORDHELL',
 'MANO',
 'Memory',
 'NEON',
 'Onimxru',
 'Phonk',
 'Reboot',
 'SCXR',
 'XDXVIL',
 'brazilian phonk',
 'fast',
 'g3ox_em',
 'medium',
 'not',
 'phonk',
 'sad',
 'slowly',
 'style',
 'synthwave',
 'tempo',
 'tempo,',
 'track',
 'with',
 'yatashigang'}

In [10]:

import torch

# Generation samples (do_sample=True), so seed it to make the clips reproducible.
SEED = 42
# Classifier-free guidance strength.
GUIDANCE_SCALE = 3
# Length of each clip in decoder steps.
# NOTE(review): presumably one step per audio-codec frame; confirm the encoder's frame
# rate before converting this to seconds.
MAX_NEW_TOKENS = 1000

prompts = [
    # "phonk",
    # "phonk",
    # "phonk",
    # "slow tempo phonk XDXVIL 98bpm",
    "phonk track 125bpm",
    "fast cool phonk track",
    "brazilian sad phonk track",
    "modern track with bassy drums in style in medium tempo",
    # Same prompt as above with the genre word added.
    "modern phonk track with bassy drums in style in medium tempo",
    # "Generate a phonk track with a blend of old-school samples, gritty beats, and smooth, jazzy undertones. Aim for that nostalgic '90s vibe.",
    # "Create a phonk-inspired composition featuring vintage vinyl crackles, soulful melodies, and a laid-back groove. Think Memphis rap meets lo-fi vibes.",
    # "Generate a phonk track that captures the essence of underground hip-hop with chopped-up soul samples, deep basslines, and a distinct, gritty atmosphere.",
    # "phonk-infused piece combining eerie, haunting melodies with hard-hitting drums and elements reminiscent of classic Southern rap.",
    # "a chill, late-night phonk track with hypnotic loops, mellow chords, and a relaxed yet head-nodding rhythm. Imagine cruising through the city at midnight.",
    # "a phonk-style beat with atmospheric textures, vintage samples, and a touch of distortion for that raw, authentic feel.",
    # "a phonk-inspired composition that pays homage to the '90s era of hip-hop, blending funk-infused rhythms, soulful samples, and a touch of dark ambiance.",
    # "a nostalgic phonk track using pitched-down vocals, old-school samples, and a combination of smooth jazz elements with heavy basslines.",
    # "Generate a lo-fi phonk track with a dreamy atmosphere, warped samples, and laid-back drum patterns that transport the listener to a bygone era.",
    # "Craft a phonk-style composition that fuses elements of funk, jazz, and hip-hop, incorporating dusty samples and a distinct urban vibe.",
]

inputs = processor(
    text=prompts,
    padding=True,
    return_tensors="pt",
)
# Follow whatever device the model lives on instead of hard-coding "cuda".
inputs = {name: tensor.to(full_model.device) for name, tensor in inputs.items()}

torch.manual_seed(SEED)
with torch.inference_mode():
    audio_values = full_model.generate(
        **inputs,
        do_sample=True,
        guidance_scale=GUIDANCE_SCALE,
        max_new_tokens=MAX_NEW_TOKENS,
    )

torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 1, 64])
torch.Size([10, 16, 1, 64]) torch.Size([10, 16, 

In [11]:
# Disabled: would post-process the generated audio through the processor using a padding mask.
# NOTE(review): `inputs` here comes from a text-only processor call, so it presumably has no
# "padding_mask" key (that mask is produced for audio inputs) — confirm before re-enabling.
# audio_values = processor.batch_decode(audio_values, padding_mask=inputs["padding_mask"])

In [16]:
# Play one generated clip. Index 2 is the "brazilian sad phonk track" prompt in the
# generation cell (commented-out prompts are not list elements, so they don't count).
CLIP_INDEX = 2
# Read the rate from the model's audio codec config instead of hard-coding 32000, so
# playback speed stays correct if a different MusicGen checkpoint is loaded.
sampling_rate = full_model.config.audio_encoder.sampling_rate
display_audio(audio_values[CLIP_INDEX], sample_rate=sampling_rate)

In [13]:
# Inspect the model's default generation settings. Keyword arguments passed directly to
# `generate()` (as in the generation cell) take precedence over these defaults.
full_model.generation_config

GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 2048,
  "decoder_start_token_id": 2048,
  "do_sample": true,
  "guidance_scale": 3.0,
  "max_length": 1500,
  "pad_token_id": 2048,
  "transformers_version": "4.33.2"
}