In [1]:
import argparse
import os
import time
import typing as tp
import warnings

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torchaudio
from audiocraft.models import MusicGen, MultiBandDiffusion
from audiocraft.data.audio_utils import normalize_audio

from config import CFG
from utils import set_seed

objc[10323]: Class AVFFrameReceiver is implemented in both /usr/local/Cellar/ffmpeg/6.0_1/lib/libavdevice.60.1.100.dylib (0x11866a378) and /Users/macos/miniconda3/envs/zaic2023/lib/python3.10/site-packages/av/.dylibs/libavdevice.59.7.100.dylib (0x11c8c0118). One of the two will be used. Which one is undefined.
objc[10323]: Class AVFAudioReceiver is implemented in both /usr/local/Cellar/ffmpeg/6.0_1/lib/libavdevice.60.1.100.dylib (0x11866a3c8) and /Users/macos/miniconda3/envs/zaic2023/lib/python3.10/site-packages/av/.dylibs/libavdevice.59.7.100.dylib (0x11c8c0168). One of the two will be used. Which one is undefined.


In [2]:
# Instantiate the pretrained MusicGen variant named in the config on the configured device.
model = MusicGen.get_pretrained(CFG.AUDIO_CRAFT_MODEL, device=CFG.DEVICE)

# Replace the language-model weights with the checkpoint stored on disk.
# Kept as the cell's last expression so the key-matching report is displayed.
model.lm.load_state_dict(torch.load("model/musicgen_4000.0.pt", map_location=CFG.DEVICE))



<All keys matched successfully>

In [3]:
# Number of rows of df_inference grouped into one loop iteration.
# The inference loops only write filenames[0] per batch, so values other than 1
# would leave the remaining rows of each batch without an output file.
BATCH_SIZE = 1
# NOTE(review): not read by any visible cell -- normalisation and saving use
# CFG.SAMPLE_RATE instead. Confirm the two agree, or remove this constant.
SAMPLE_RATE = 16000
# Read by generate(), which budgets int(DURATION/2 * model.frame_rate) tokens.
# Presumably seconds -- TODO confirm.
DURATION = 10
# JSON file loaded into df_inference with orient="index".
JSON_PATH = "private/private.json"

In [4]:
# Each top-level JSON key becomes a row label; reset_index() turns that label
# into a regular column, after which the two columns get descriptive names.
df_inference = (
    pd.read_json(JSON_PATH, orient="index")
    .reset_index()
    .set_axis(["filename", "description"], axis=1)
)

In [5]:
def generate(model, descriptions, duration=None):
    """Sample audio tokens for ``descriptions`` and decode them to a waveform.

    Args:
        model: Object exposing ``_prepare_tokens_and_attributes``,
            ``frame_rate``, ``autocast``, ``lm.generate`` and
            ``compression_model.decode`` (the MusicGen instance in this notebook).
        descriptions: List of text prompts, one per clip.
        duration: Optional override for the module-level ``DURATION``. The
            token budget passed to the language model is
            ``int(duration / 2 * model.frame_rate)``.
            NOTE(review): only half of ``duration`` is budgeted -- confirm
            this halving is intentional.

    Returns:
        The decoded waveform of the LAST clip in the batch, with at most two
        dimensions (a 1-D waveform gets a leading axis added). Callers in this
        notebook pass a single description per call.

    Raises:
        ValueError: If decoding yields no clip, or a clip has more than two
            dimensions.
        AssertionError: If the generated tokens are not 3-D, or a clip is not
            floating point or contains non-finite values.

    Side effects:
        Overwrites ``model.generation_params``.
    """
    if duration is None:
        # Resolved at call time so edits to the config cell are picked up.
        duration = DURATION

    attributes, prompt_tokens = model._prepare_tokens_and_attributes(descriptions, None)
    model.generation_params = {
        'max_gen_len': int(duration / 2 * model.frame_rate),
        'use_sampling': True,
        'temp': 1.0,
        'top_k': 250,
        'top_p': 0.5,
        'cfg_coef': 3.0,
        'two_step_cfg': False,
    }

    total = []
    # Single generation round. The prompt hand-over at the end of the body only
    # matters if the number of rounds is raised: each further round would then
    # be prompted with the trailing half of the previous round's tokens.
    for _ in range(1):
        with model.autocast:
            gen_tokens = model.lm.generate(
                prompt_tokens, attributes, callback=None, **model.generation_params)
        # Keep only the newly generated part, dropping any prompt prefix.
        prompt_len = 0 if prompt_tokens is None else prompt_tokens.shape[-1]
        total.append(gen_tokens[..., prompt_len:])
        prompt_tokens = gen_tokens[..., -gen_tokens.shape[-1] // 2:]
    gen_tokens = torch.cat(total, -1)

    assert gen_tokens.dim() == 3
    with torch.no_grad():
        gen_audio = model.compression_model.decode(gen_tokens, None)

    clip_count = len(gen_audio)
    if clip_count == 0:
        # Previously this surfaced as an UnboundLocalError on the return line.
        raise ValueError(
            "generate() decoded no audio; `descriptions` must contain at least one prompt.")
    if clip_count > 1:
        # Only one waveform is returned; make the dropped clips visible.
        warnings.warn(
            f"generate() decoded {clip_count} clips but returns only the last one.",
            stacklevel=2,
        )

    # Validate every clip; the last one (after optional reshaping) is returned.
    for one_wav in gen_audio:
        assert one_wav.dtype.is_floating_point, "wav is not floating point"
        if one_wav.dim() == 1:
            one_wav = one_wav[None]
        elif one_wav.dim() > 2:
            raise ValueError("Input wav should be at most 2 dimension.")
        assert one_wav.isfinite().all()
    return one_wav

In [None]:
# Inference pass 1: generate, normalise and save one clip per description,
# recording how long each batch takes.
all_predicted_time = []
output_path1 = "results/jupyter_submission1"

# exist_ok makes re-running the cell a no-op and avoids the check-then-create race.
os.makedirs(output_path1, exist_ok=True)

# Consecutive rows are grouped into chunks of BATCH_SIZE.
for _, batch in tqdm(df_inference.groupby(np.arange(len(df_inference)) // BATCH_SIZE)):
    print('Processing batch', batch)
    print("---------------------")
    filenames = batch["filename"].tolist()
    descriptions = batch["description"].tolist()
    # perf_counter is monotonic, so the measured interval cannot be distorted
    # by system clock adjustments the way time.time() can.
    t1 = time.perf_counter()
    # ***************Start model prediction******************
    one_wav = generate(model, descriptions)
    one_wav = normalize_audio(
        wav=one_wav.cpu(),
        strategy=CFG.STRATEGY,
        peak_clip_headroom_db=CFG.PEAK,
        loudness_compressor=CFG.LOUDNESS_COMPRESSOR,
        sample_rate=CFG.SAMPLE_RATE,
    )
    # Only the first filename of the batch is written (BATCH_SIZE is 1 here).
    path_submission1 = os.path.join(output_path1, filenames[0])
    torchaudio.save(path_submission1, one_wav, CFG.SAMPLE_RATE)
    # ***************End model prediction******************
    t2 = time.perf_counter()
    predicted_time = t2 - t1
    all_predicted_time.append((filenames, predicted_time))

# One row per batch: the batch's filename list and its elapsed seconds.
df = pd.DataFrame(all_predicted_time, columns=["fname", "time"])
df.to_csv("results/time_submission1.csv", index=False)

  0%|                                                                                                                                                    | 0/1000 [00:00<?, ?it/s]

Processing batch                 filename                                        description
0  1699168496.395952.mp3  The recording features a widely spread electri...
---------------------


  0%|▏                                                                                                                                        | 1/1000 [01:13<20:15:47, 73.02s/it]

Processing batch                 filename                                        description
1  1699168495.217152.mp3  The recording features a cover of a rock song ...
---------------------


  0%|▎                                                                                                                                        | 2/1000 [02:28<20:35:14, 74.26s/it]

Processing batch                  filename                                        description
2  1699168495.1176987.mp3  The recording features an arpeggiated acoustic...
---------------------


  0%|▍                                                                                                                                        | 3/1000 [03:31<19:07:52, 69.08s/it]

Processing batch                  filename                                        description
3  1699168498.4178677.mp3  The recording features a cover of a rock song ...
---------------------


  0%|▌                                                                                                                                        | 4/1000 [04:38<18:55:20, 68.39s/it]

Processing batch                  filename                                        description
4  1699168495.6089337.mp3  The recording features an arpeggiated acoustic...
---------------------


  0%|▋                                                                                                                                        | 5/1000 [05:49<19:09:44, 69.33s/it]

Processing batch                 filename                                        description
5  1699168495.505732.mp3  The recording features a cover of a rock song ...
---------------------


In [None]:
# Inference pass 2: same pipeline as pass 1, written to a second output folder.
all_predicted_time = []
output_path2 = "results/jupyter_submission2"

# exist_ok makes re-running the cell a no-op and avoids the check-then-create race.
os.makedirs(output_path2, exist_ok=True)

# Consecutive rows are grouped into chunks of BATCH_SIZE.
for _, batch in tqdm(df_inference.groupby(np.arange(len(df_inference)) // BATCH_SIZE)):
    print('Processing batch', batch)
    print("---------------------")
    filenames = batch["filename"].tolist()
    descriptions = batch["description"].tolist()
    # perf_counter is monotonic, so the measured interval cannot be distorted
    # by system clock adjustments the way time.time() can.
    t1 = time.perf_counter()
    # ***************Start model prediction******************
    forward = generate(model, descriptions)
    # Fix: normalise the waveform generated in THIS iteration. The previous code
    # normalised a stale `one_wav` left over from the preceding cell (NameError
    # on a fresh kernel), discarded the result, and saved the raw `forward`
    # tensor without moving it to the CPU.
    # NOTE(review): if this submission was meant to be the un-normalised
    # output, drop the normalize_audio call and save `forward.cpu()` instead.
    one_wav = normalize_audio(
        wav=forward.cpu(),
        strategy=CFG.STRATEGY,
        peak_clip_headroom_db=CFG.PEAK,
        loudness_compressor=CFG.LOUDNESS_COMPRESSOR,
        sample_rate=CFG.SAMPLE_RATE,
    )
    # Only the first filename of the batch is written (BATCH_SIZE is 1 here).
    path_submission2 = os.path.join(output_path2, filenames[0])
    torchaudio.save(path_submission2, one_wav, CFG.SAMPLE_RATE)
    # ***************End model prediction******************
    t2 = time.perf_counter()
    predicted_time = t2 - t1
    all_predicted_time.append((filenames, predicted_time))

# One row per batch: the batch's filename list and its elapsed seconds.
df = pd.DataFrame(all_predicted_time, columns=["fname", "time"])
df.to_csv("results/time_submission2.csv", index=False)