In [1]:
import gc
import json
import math
import os
import sys

import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator

from config import Config
from audiomodel import AudioProcessing
from audiodataset import AudioDataset, TestDataset

def make_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist yet.

    Uses ``exist_ok=True`` instead of a separate ``os.path.exists`` check, so there is
    no window in which another process (e.g. a second accelerate worker) can create the
    directory between the check and the ``makedirs`` call.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

def build_model(cfg, pretrained_name='facebook/audiogen-medium'):
    """Load the pretrained compression model and language model from audiocraft.

    Args:
        cfg: configuration object; only ``cfg.device`` is read, and both models are
            loaded onto that device.
        pretrained_name: checkpoint identifier passed to both audiocraft loaders.
            Defaults to ``'facebook/audiogen-medium'``, the checkpoint this notebook uses.

    Returns:
        tuple: ``(compression_model, lm)`` as returned by audiocraft's
        ``load_compression_model`` and ``load_lm_model``.
    """
    # Deferred import: audiocraft is only loaded when a model is actually built.
    from audiocraft.models.loaders import load_compression_model, load_lm_model

    compression_model = load_compression_model(pretrained_name, device=cfg.device)
    lm = load_lm_model(pretrained_name, device=cfg.device)
    return compression_model, lm

    PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.1.0+cu118)
    Python  3.10.13 (you have 3.10.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


In [2]:
# Evaluation prompts, one per line. The block is split into a list, so generation
# (and output-file) order follows the listed order exactly.
# NOTE(review): "interal" (-> "internal") and "in a oil" look like typos. They are kept
# verbatim because each prompt becomes part of an output file name and may also be
# matched elsewhere (e.g. by TestDataset); confirm before correcting.
prompts = """
toilet water, music
horn sound, crash sound
horn sound followed by crashing cars
man gregorian chanting, inside knight's helm
large metal spikeball, hitting ground
heavy footsteps wrapped in chains
robotic interal mechanism moving
frog followed by woosh
rain, frog
fire, wood building falling
gun sound and then child crying
big, church bell
morning alarm
machine exploding, parts falling
missile firing and exploding
whoosh, ice
dragon wings flapping
biting an apple
walking on the shallow water
engine starting up
violin, concert hall
rolling dice
wine glass falling
printer printing paper
laser gun
multimedia, notification
truck accelerating
frying chicken in a oil
running, basketball court
lightning hits tree
rifle reloading
fart sound machine gun
crickets chirping
chainsaw cutting tree
girl, whispering
""".strip().splitlines()

In [None]:
# Evaluate checkpoint MODEL_PATH: generate N_TAKES clips per prompt and write them to
# SAVE_PATH as "<prompt>_<take>.wav".
SAVE_PATH = "./test_base22"
MODEL_PATH = "./compare/base_19.pth"
N_TAKES = 3  # generations per prompt

# Release the previous run's models first. torch.cuda.empty_cache() cannot free memory
# that is still referenced: the previous Accelerator holds references to the models it
# prepared, and so do the notebook globals listed below.
if "accelerator" in globals():
    accelerator.free_memory()
for _stale_name in ("model", "compression_model", "lm", "unwrapped_model",
                    "unwrapped_vae", "gen_tokens", "gen_audio", "batch"):
    globals().pop(_stale_name, None)
gc.collect()
torch.cuda.empty_cache()

cfg = Config()
cfg.update(prompts=[p for p in prompts for _ in range(N_TAKES)])

accelerator = Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
make_dir(SAVE_PATH)
cfg.update(save_path=SAVE_PATH)

compression_model, lm = build_model(cfg)
model = AudioProcessing(cfg, lm)

test_dataset = TestDataset(cfg)
test_dataloader = DataLoader(test_dataset, batch_size=1)

model, compression_model = accelerator.prepare(model, compression_model)
# load_state_dict copies each tensor onto the parameter's own device, so loading the
# checkpoint on CPU avoids holding a second copy of the weights on the GPU.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))

model.eval()
compression_model.eval()
if accelerator.is_main_process:
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_vae = accelerator.unwrap_model(compression_model)
    # Count takes per prompt rather than with one counter wrapping at a fixed value, so
    # each prompt's files are numbered 1..N_TAKES regardless of how many takes there are.
    takes_done = {}
    with torch.no_grad():  # inference only; no autograd graph needed
        for batch in tqdm(test_dataloader, desc="generating"):
            gen_tokens, gen_audio = unwrapped_model.inference(batch, unwrapped_vae)
            prompt = batch[0]
            audio_num = takes_done.get(prompt, 0) + 1
            takes_done[prompt] = audio_num
            tqdm.write(prompt)  # prints without corrupting the progress bar
            audio_filename = f"{prompt}_{audio_num}.wav"
            unwrapped_model.save_audio(gen_audio, audio_filename, cfg)

1it [00:05,  5.46s/it]

toilet water, music


2it [00:10,  5.15s/it]

toilet water, music


3it [00:15,  5.14s/it]

toilet water, music


4it [00:20,  5.04s/it]

horn sound, crash sound


5it [00:25,  4.99s/it]

horn sound, crash sound


6it [00:30,  4.93s/it]

horn sound, crash sound


7it [00:34,  4.89s/it]

horn sound followed by crashing cars


8it [00:39,  4.89s/it]

horn sound followed by crashing cars


9it [00:44,  4.87s/it]

horn sound followed by crashing cars


10it [00:49,  4.87s/it]

man gregorian chanting, inside knight's helm


11it [00:54,  4.86s/it]

man gregorian chanting, inside knight's helm


12it [00:59,  4.85s/it]

man gregorian chanting, inside knight's helm


13it [01:03,  4.84s/it]

large metal spikeball, hitting ground


14it [01:08,  4.88s/it]

large metal spikeball, hitting ground


15it [01:14,  5.02s/it]

large metal spikeball, hitting ground


16it [01:19,  4.96s/it]

heavy footsteps wrapped in chains


17it [01:24,  5.02s/it]

heavy footsteps wrapped in chains


18it [01:29,  5.05s/it]

heavy footsteps wrapped in chains


19it [01:34,  5.06s/it]

robotic interal mechanism moving


20it [01:39,  4.99s/it]

robotic interal mechanism moving


21it [01:44,  4.95s/it]

robotic interal mechanism moving


In [3]:
# Evaluate checkpoint MODEL_PATH: generate N_TAKES clips per prompt and write them to
# SAVE_PATH as "<prompt>_<take>.wav".
SAVE_PATH = "./test_total_444"
MODEL_PATH = "./output_dir_total22/44.pth"
N_TAKES = 3  # generations per prompt

# Release the previous run's models first. torch.cuda.empty_cache() cannot free memory
# that is still referenced: the previous Accelerator holds references to the models it
# prepared, and so do the notebook globals listed below.
if "accelerator" in globals():
    accelerator.free_memory()
for _stale_name in ("model", "compression_model", "lm", "unwrapped_model",
                    "unwrapped_vae", "gen_tokens", "gen_audio", "batch"):
    globals().pop(_stale_name, None)
gc.collect()
torch.cuda.empty_cache()

cfg = Config()
cfg.update(prompts=[p for p in prompts for _ in range(N_TAKES)])

accelerator = Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
make_dir(SAVE_PATH)
cfg.update(save_path=SAVE_PATH)

compression_model, lm = build_model(cfg)
model = AudioProcessing(cfg, lm)

test_dataset = TestDataset(cfg)
test_dataloader = DataLoader(test_dataset, batch_size=1)

model, compression_model = accelerator.prepare(model, compression_model)
# load_state_dict copies each tensor onto the parameter's own device, so loading the
# checkpoint on CPU avoids holding a second copy of the weights on the GPU.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))

model.eval()
compression_model.eval()
if accelerator.is_main_process:
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_vae = accelerator.unwrap_model(compression_model)
    # Count takes per prompt rather than with one counter wrapping at a fixed value, so
    # each prompt's files are numbered 1..N_TAKES regardless of how many takes there are.
    takes_done = {}
    with torch.no_grad():  # inference only; no autograd graph needed
        for batch in tqdm(test_dataloader, desc="generating"):
            gen_tokens, gen_audio = unwrapped_model.inference(batch, unwrapped_vae)
            prompt = batch[0]
            audio_num = takes_done.get(prompt, 0) + 1
            takes_done[prompt] = audio_num
            tqdm.write(prompt)  # prints without corrupting the progress bar
            audio_filename = f"{prompt}_{audio_num}.wav"
            unwrapped_model.save_audio(gen_audio, audio_filename, cfg)

1it [00:06,  6.13s/it]

toilet water, music


2it [00:12,  6.09s/it]

toilet water, music


3it [00:17,  5.71s/it]

toilet water, music


4it [00:22,  5.41s/it]

horn sound, crash sound


5it [00:27,  5.24s/it]

horn sound, crash sound


6it [00:32,  5.14s/it]

horn sound, crash sound


7it [00:37,  5.09s/it]

horn sound followed by crashing cars


8it [00:42,  5.08s/it]

horn sound followed by crashing cars


9it [00:47,  5.05s/it]

horn sound followed by crashing cars


10it [00:52,  5.01s/it]

man gregorian chanting, inside knight's helm


11it [00:57,  5.00s/it]

man gregorian chanting, inside knight's helm


12it [01:02,  5.04s/it]

man gregorian chanting, inside knight's helm


13it [01:07,  5.09s/it]

large metal spikeball, hitting ground


14it [01:12,  5.13s/it]

large metal spikeball, hitting ground


15it [01:18,  5.16s/it]

large metal spikeball, hitting ground


16it [01:23,  5.17s/it]

heavy footsteps wrapped in chains


17it [01:28,  5.18s/it]

heavy footsteps wrapped in chains


18it [01:33,  5.18s/it]

heavy footsteps wrapped in chains


19it [01:38,  5.20s/it]

robotic interal mechanism moving


20it [01:43,  5.16s/it]

robotic interal mechanism moving


21it [01:48,  5.08s/it]

robotic interal mechanism moving


22it [01:53,  5.04s/it]

frog followed by woosh


23it [01:58,  5.01s/it]

frog followed by woosh


23it [02:03,  5.35s/it]


KeyboardInterrupt: 

In [4]:
# Evaluate checkpoint MODEL_PATH: generate N_TAKES clips per prompt and write them to
# SAVE_PATH as "<prompt>_<take>.wav".
SAVE_PATH = "./test_total_noise"
MODEL_PATH = "./output_dir_total22/19.pth"
N_TAKES = 3  # generations per prompt

# Release the previous run's models first. torch.cuda.empty_cache() cannot free memory
# that is still referenced: the previous Accelerator holds references to the models it
# prepared, and so do the notebook globals listed below.
if "accelerator" in globals():
    accelerator.free_memory()
for _stale_name in ("model", "compression_model", "lm", "unwrapped_model",
                    "unwrapped_vae", "gen_tokens", "gen_audio", "batch"):
    globals().pop(_stale_name, None)
gc.collect()
torch.cuda.empty_cache()

cfg = Config()
cfg.update(prompts=[p for p in prompts for _ in range(N_TAKES)])

accelerator = Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
make_dir(SAVE_PATH)
cfg.update(save_path=SAVE_PATH)

compression_model, lm = build_model(cfg)
model = AudioProcessing(cfg, lm)

test_dataset = TestDataset(cfg)
test_dataloader = DataLoader(test_dataset, batch_size=1)

model, compression_model = accelerator.prepare(model, compression_model)
# load_state_dict copies each tensor onto the parameter's own device, so loading the
# checkpoint on CPU avoids holding a second copy of the weights on the GPU.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))

model.eval()
compression_model.eval()
if accelerator.is_main_process:
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_vae = accelerator.unwrap_model(compression_model)
    # Count takes per prompt rather than with one counter wrapping at a fixed value, so
    # each prompt's files are numbered 1..N_TAKES regardless of how many takes there are.
    takes_done = {}
    with torch.no_grad():  # inference only; no autograd graph needed
        for batch in tqdm(test_dataloader, desc="generating"):
            gen_tokens, gen_audio = unwrapped_model.inference(batch, unwrapped_vae)
            prompt = batch[0]
            audio_num = takes_done.get(prompt, 0) + 1
            takes_done[prompt] = audio_num
            tqdm.write(prompt)  # prints without corrupting the progress bar
            audio_filename = f"{prompt}_{audio_num}.wav"
            unwrapped_model.save_audio(gen_audio, audio_filename, cfg)

1it [00:05,  5.27s/it]

frog followed by woosh


2it [00:10,  5.07s/it]

frog followed by woosh


3it [00:15,  5.00s/it]

frog followed by woosh


4it [00:20,  4.96s/it]

rain, frog


5it [00:24,  4.94s/it]

rain, frog


6it [00:29,  4.93s/it]

rain, frog


7it [00:34,  4.93s/it]

fire, wood building falling


8it [00:39,  4.95s/it]

fire, wood building falling


9it [00:44,  4.94s/it]

fire, wood building falling


10it [00:49,  4.93s/it]

gun sound and then child crying


11it [00:54,  4.93s/it]

gun sound and then child crying


12it [00:59,  4.93s/it]

gun sound and then child crying


13it [01:04,  4.97s/it]

big, church bell


14it [01:09,  4.99s/it]

big, church bell


15it [01:15,  5.25s/it]

big, church bell


16it [01:21,  5.40s/it]

morning alarm


17it [01:26,  5.50s/it]

morning alarm


18it [01:32,  5.61s/it]

morning alarm


19it [01:38,  5.69s/it]

machine exploding, parts falling


20it [01:44,  5.73s/it]

machine exploding, parts falling


21it [01:50,  5.75s/it]

machine exploding, parts falling


22it [01:55,  5.75s/it]

missile firing and exploding


23it [02:01,  5.81s/it]

missile firing and exploding


24it [02:07,  5.79s/it]

missile firing and exploding


25it [02:12,  5.60s/it]

whoosh, ice


26it [02:17,  5.38s/it]

whoosh, ice


27it [02:22,  5.24s/it]

whoosh, ice


28it [02:27,  5.16s/it]

dragon wings flapping


29it [02:32,  5.08s/it]

dragon wings flapping


30it [02:37,  5.03s/it]

dragon wings flapping


31it [02:42,  5.00s/it]

biting an apple


32it [02:48,  5.45s/it]

biting an apple


33it [02:54,  5.59s/it]

biting an apple


34it [03:00,  5.70s/it]

walking on the shallow water


35it [03:06,  5.75s/it]

walking on the shallow water


36it [03:12,  5.81s/it]

walking on the shallow water


37it [03:17,  5.55s/it]

engine starting up


38it [03:22,  5.36s/it]

engine starting up


39it [03:27,  5.24s/it]

engine starting up


40it [03:32,  5.16s/it]

violin, concert hall


41it [03:37,  5.08s/it]

violin, concert hall


42it [03:42,  5.03s/it]

violin, concert hall


43it [03:47,  5.01s/it]

rolling dice


44it [03:51,  4.97s/it]

rolling dice


45it [03:56,  4.94s/it]

rolling dice


46it [04:01,  4.92s/it]

wine glass falling


47it [04:06,  4.91s/it]

wine glass falling


48it [04:11,  4.94s/it]

wine glass falling


49it [04:16,  4.93s/it]

printer printing paper


50it [04:21,  4.91s/it]

printer printing paper


51it [04:26,  4.90s/it]

printer printing paper


52it [04:31,  4.90s/it]

laser gun


53it [04:36,  4.91s/it]

laser gun


54it [04:40,  4.90s/it]

laser gun


55it [04:45,  4.92s/it]

multimedia, notification


56it [04:50,  4.91s/it]

multimedia, notification


57it [04:55,  4.91s/it]

multimedia, notification


58it [05:00,  4.92s/it]

truck accelerating


59it [05:05,  4.90s/it]

truck accelerating


60it [05:10,  4.89s/it]

truck accelerating


61it [05:15,  4.89s/it]

frying chicken in a oil


62it [05:20,  4.88s/it]

frying chicken in a oil


63it [05:24,  4.88s/it]

frying chicken in a oil


64it [05:29,  4.87s/it]

running, basketball court


65it [05:34,  4.88s/it]

running, basketball court


66it [05:39,  4.88s/it]

running, basketball court


67it [05:44,  4.91s/it]

lightning hits tree


68it [05:49,  4.94s/it]

lightning hits tree


69it [05:54,  5.04s/it]

lightning hits tree


70it [06:00,  5.09s/it]

rifle reloading


71it [06:05,  5.05s/it]

rifle reloading


72it [06:09,  5.01s/it]

rifle reloading


73it [06:14,  5.00s/it]

fart sound machine gun


74it [06:19,  4.97s/it]

fart sound machine gun


75it [06:24,  4.97s/it]

fart sound machine gun


76it [06:29,  4.95s/it]

crickets chirping


77it [06:34,  4.95s/it]

crickets chirping


78it [06:39,  4.93s/it]

crickets chirping


79it [06:44,  4.92s/it]

chainsaw cutting tree


80it [06:49,  4.92s/it]

chainsaw cutting tree


81it [06:54,  4.91s/it]

chainsaw cutting tree


82it [06:59,  4.92s/it]

girl, whispering


83it [07:04,  4.99s/it]

girl, whispering


84it [07:09,  5.11s/it]

girl, whispering



