In [1]:
import os
import json
import math
import sys
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

from accelerate import Accelerator

from config import Config
from audiomodel_seperation import AudioProcessing
from audiodataset_seperation import SeperationDataset

def make_dir(path):
    """Create directory ``path`` (including missing parents) if it does not exist.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` already exists but is not a directory.
    """
    # exist_ok=True makes the call idempotent and removes the window between a
    # separate existence check and the creation in which another process could
    # create the directory and make os.makedirs fail.
    os.makedirs(path, exist_ok=True)

def build_model(cfg, model_name='facebook/audiogen-medium'):
    """Load the pretrained compression model and language model.

    Args:
        cfg: Configuration object; only ``cfg.device`` is read here.
        model_name: Checkpoint identifier passed to both audiocraft loaders.
            Defaults to ``'facebook/audiogen-medium'``.

    Returns:
        Tuple ``(compression_model, lm)`` as returned by
        ``load_compression_model`` and ``load_lm_model`` respectively.
    """
    # Imported at call time, so audiocraft is only required when the models
    # are actually built.
    from audiocraft.models.loaders import load_compression_model, load_lm_model

    compression_model = load_compression_model(model_name, device=cfg.device)
    lm = load_lm_model(model_name, device=cfg.device)
    return compression_model, lm

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
    PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.3.0.dev20240125+cu118)
    Python  3.10.13 (you have 3.10.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
# Evaluation setup: configuration, dataset/loader, pretrained models and the
# fine-tuned separation checkpoint.
cfg = Config()

model_path = "./sep_models/lr_mod_model_epochs_6_.pth"

# os.path.join avoids the doubled separator ("./csv_files//...") that plain
# string formatting produces because base_path already ends with a slash.
base_path = "./csv_files/"
train_data_path = os.path.join(base_path, "train_dataset_epidemic_sub.csv")
eval_data_path = os.path.join(base_path, "eval_dataset_epidemic_sub.csv")

cfg.update(train_data_path=train_data_path, eval_data_path=eval_data_path, batch_size=4)

save_path = "./test"
make_dir(save_path)
cfg.update(save_path=save_path)

eval_dataset = SeperationDataset(cfg, train=False)
# One sample per batch; this is set independently of cfg.batch_size above.
eval_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=False, num_workers=8)

compression_model, lm = build_model(cfg)
model = AudioProcessing(cfg, lm)

# Deserialize the checkpoint on CPU: load_state_dict copies the values into the
# model's existing parameters, so no second copy of the weights has to live on
# the GPU, and the file loads regardless of the device it was saved from.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True if the file contains nothing but tensors).
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
del state_dict

torch.cuda.empty_cache()



In [4]:
from IPython.display import Audio, display

# Print the prompt of the first few evaluation samples, then play back the
# last sample that was read.
NUM_PREVIEW_SAMPLES = 10

preview_wav = preview_gt = None
# Clamp to the dataset size so a short evaluation set does not raise IndexError.
for i in range(min(NUM_PREVIEW_SAMPLES, len(eval_dataset))):
    preview_wav, preview_prompt, preview_gt, _preview_len = eval_dataset[i]
    print(preview_prompt)

if preview_wav is not None:
    display(Audio(data=preview_wav[0].numpy(), rate=cfg.sample_rate))
    display(Audio(data=preview_gt[0].numpy(), rate=cfg.sample_rate))
    # NOTE(review): this cell previously also played `w2[0]`, but `w2` is not
    # defined anywhere in the notebook, so it raised NameError on a fresh
    # kernel. Re-add the playback once the intended tensor is identified.

cfg.update(duration=3, batch_size=1)

Remove The sound of 'The audio consists of a metallic object being dropped onto a hard surface. This creates a metallic clanking sound. The audio is in mono. There is no background noise. The audio is clear and crisp. The audio can be used as a sound effect in a movie or a video game. The audio can also'
Remove The sound of 'The audio is of a wooden door being opened and closed with a loud and clear sound of metal handle.'
Remove The sound of 'A short burst of electronic distortion.'
Remove The sound of 'A hard object is being scraped against a hard surface.'
Remove The sound of 'A whoosh or swoosh sound effect.'
Remove The sound of 'A bell is ringing. The ringing of the bell is echoing. The ringing of the bell is loud. The ringing of the bell is echoing. The ringing of the bell is loud. The ringing of the bell is echoing. The ringing of the bell is loud. The ringing of the bell'
Remove The sound of 'A hard object is dropped onto a wooden surface. the object hitting the surface is loud

In [7]:
from IPython.display import Audio, display

# Number of evaluation samples to run through the model and play back.
MAX_CLIPS = 5

model.eval()
compression_model.eval()

for test_step, (synthesized_wav, prompts, ground_truth, lengths) in enumerate(eval_dataloader):
    if test_step >= MAX_CLIPS:
        break
    with torch.no_grad():
        # Only synthesized_wav and prompts are fed to the model.
        prompt = prompts[0]
        print(prompt)

        # Use cfg.device (the device the models were loaded on in build_model)
        # instead of a hard-coded "cuda".
        audio_tokens, _ignored = compression_model.encode(synthesized_wav.to(cfg.device))
        print("audio_tokens : ", audio_tokens.shape)

        gen_tokens, gen_audio = model.inference(audio_tokens, prompts, compression_model)
        # audio_filename = f"{prompt}_{audio_num}.wav"
        # model.save_audio(gen_audio, audio_filename, cfg)

        # Compare all three: the mixed input, the ground truth, and the generated audio.
        display(Audio(data=synthesized_wav[0].numpy(), rate=cfg.sample_rate))
        display(Audio(data=ground_truth[0].numpy(), rate=cfg.sample_rate))
        display(Audio(data=gen_audio[0].detach().cpu().numpy(), rate=cfg.sample_rate))

        # Drop per-sample tensors before the next iteration and release cached
        # GPU memory.
        del gen_tokens
        del gen_audio
        del synthesized_wav
        del audio_tokens

        torch.cuda.empty_cache()

Remove The sound of 'The audio consists of a metallic object being dropped onto a hard surface. This creates a metallic clanking sound. The audio is in mono. There is no background noise. The audio is clear and crisp. The audio can be used as a sound effect in a movie or a video game. The audio can also'
audio_tokens :  torch.Size([1, 4, 150])
initial_tokens :  torch.Size([1, 4, 150])


Remove The sound of 'The audio is of a wooden door being opened and closed with a loud and clear sound of metal handle.'
audio_tokens :  torch.Size([1, 4, 150])
initial_tokens :  torch.Size([1, 4, 150])


Remove The sound of 'A short burst of electronic distortion.'
audio_tokens :  torch.Size([1, 4, 150])
initial_tokens :  torch.Size([1, 4, 150])


Remove The sound of 'A hard object is being scraped against a hard surface.'
audio_tokens :  torch.Size([1, 4, 150])
initial_tokens :  torch.Size([1, 4, 150])


Remove The sound of 'A whoosh or swoosh sound effect.'
audio_tokens :  torch.Size([1, 4, 150])
initial_tokens :  torch.Size([1, 4, 150])


In [25]:
# Release any inference tensors that are still bound in the notebook namespace.
# The inference loop already deletes these names on every iteration, so a plain
# `del` raises NameError when the notebook is run top to bottom; pop() with a
# default makes this cell safe to run in any state (and more than once).
for _stale_name in ("audio_tokens", "gen_tokens", "gen_audio"):
    globals().pop(_stale_name, None)
del _stale_name

torch.cuda.empty_cache()

In [8]:
# Count the trainable parameters of the LM's "description" conditioner.
# (The bare `...parameters()` expression that used to open this cell only built
# a generator and discarded it, so it has been dropped.)
description_conditioner = lm.condition_provider.conditioners.description

num_params = sum(
    param.numel()
    for param in description_conditioner.parameters()
    if param.requires_grad
)
num_params

1574400

In [9]:
from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer

# Load three T5 encoder sizes, smallest first, to compare their parameter counts.
# Author's size notes: t5-3b checkpoint is 11.4G (4x), t5-11b checkpoint is 45.2G (4x).
t5_large, t5_3b, t5_11b = [
    T5EncoderModel.from_pretrained(checkpoint_name)
    for checkpoint_name in ("t5-large", "t5-3b", "t5-11b")
]


config.json:   0%|          | 0.00/1.20k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/11.4G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.20k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/45.2G [00:00<?, ?B/s]

In [11]:
# Print the number of trainable parameters of each T5 encoder, one per line,
# in the order large -> 3b -> 11b.
for encoder in (t5_large, t5_3b, t5_11b):
    trainable_count = sum(
        weight.numel() for weight in encoder.parameters() if weight.requires_grad
    )
    print(trainable_count)

334939648
1240909824
4864791552
