In [1]:
import os
import json
import math
import sys
import argparse
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import wandb

from accelerate import Accelerator
from transformers import get_scheduler

from audiocraft.modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition, ConditioningAttributes
from config import Config
from audiomodel import AudioProcessing
from audiodataset_seperation import SeperationDataset, TestDataset

def make_dir(path):
    """Create directory ``path`` (with any missing parents).

    Does nothing when ``path`` already exists, whether it is a directory or a file.
    """
    if os.path.exists(path):
        return
    os.makedirs(path)

def wandb_init(cfg):
    """Start a Weights & Biases run for this training job.

    The run is logged to project ``cfg.wandb_project_name``. Its config records
    the learning rate, the number of epochs and the batch size taken from ``cfg``.
    """
    tracked_hparams = {
        "learning_rate": cfg.learning_rate,
        "epochs": cfg.num_train_epochs,
        "batch_size": cfg.batch_size,
    }
    wandb.init(project=cfg.wandb_project_name, config=tracked_hparams)
    
def save_checkpoint(cfg, model, result, best_loss, epoch=0):
    """Log one epoch's metrics and write model checkpoints.

    Appends ``result`` as a single JSON object per line to
    ``<cfg.output_dir>/summary.jsonl``. Always saves ``<cfg.output_dir>/<epoch>.pth``.
    Also saves ``<cfg.output_dir>/best.pth`` when ``result["valid_loss"]`` is
    lower than ``best_loss`` and ``cfg.checkpointing_steps == "best"``.

    Args:
        cfg: config object providing ``output_dir`` and ``checkpointing_steps``.
        model: object exposing ``state_dict()``; the state dict is written with ``torch.save``.
        result: JSON-serialisable dict containing at least ``"valid_loss"``.
        best_loss: best validation loss seen so far.
        epoch: epoch index used as the per-epoch checkpoint filename.

    Returns:
        The updated best validation loss.
    """
    summary_path = os.path.join(cfg.output_dir, "summary.jsonl")
    with open(summary_path, "a") as f:
        # JSON Lines: exactly one record per line and no blank separator lines,
        # so the file can be parsed line by line.
        f.write(json.dumps(result) + "\n")

    # Named differently from the function so the function name is not shadowed.
    is_new_best = result["valid_loss"] < best_loss
    if is_new_best:
        best_loss = result["valid_loss"]

    if is_new_best and cfg.checkpointing_steps == "best":
        torch.save(model.state_dict(), os.path.join(cfg.output_dir, "best.pth"))

    # NOTE(review): if `model` is the accelerate-prepared (possibly DDP-wrapped) module,
    # its state_dict keys may carry a "module." prefix; consider saving
    # accelerator.unwrap_model(model).state_dict() instead. Confirm against the loader.
    torch.save(model.state_dict(), os.path.join(cfg.output_dir, f"{epoch}.pth"))

    return best_loss

    PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.1.0+cu118)
    Python  3.10.13 (you have 3.10.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


In [2]:
def build_model(cfg, model_name='facebook/audiogen-medium'):
    """Load the pretrained compression model and language model.

    Args:
        cfg: config object providing ``device``; both models are loaded onto it.
        model_name: pretrained checkpoint name passed to both audiocraft loaders.
            Defaults to ``'facebook/audiogen-medium'``.

    Returns:
        ``(compression_model, lm)`` as returned by ``load_compression_model``
        and ``load_lm_model``.
    """
    # Imported inside the function, so audiocraft's loaders are only imported
    # when models are actually built.
    from audiocraft.models.loaders import load_compression_model, load_lm_model

    compression_model = load_compression_model(model_name, device=cfg.device)
    lm = load_lm_model(model_name, device=cfg.device)
    return compression_model, lm

def process_audio_tokenizer(wav, compression_model):
    """Encode a waveform into discrete audio tokens.

    Calls ``compression_model.encode`` with autograd disabled. Returns only the
    codes and drops the second value (the scale) that ``encode`` returns.
    """
    with torch.no_grad():
        encoded = compression_model.encode(wav)
    codes, _scale = encoded
    return codes

def post_process_audio_tokenizer(audio_tokens, audio_lengths=None, compression_model=None, lm=None, cfg=None):
    """Mask out token frames that lie beyond each item's valid audio length.

    For batch item ``i`` the number of valid token frames is
    ``floor(audio_lengths[i] / cfg.sample_rate * compression_model.frame_rate)``.
    Every frame from that index onward is overwritten with
    ``lm.special_token_id`` and marked ``False`` in the padding mask.

    Args:
        audio_tokens: token tensor of shape ``(B, K, T)`` (batch, codebooks,
            frames). It is not modified; a masked copy is returned.
        audio_lengths: per-item lengths in samples at ``cfg.sample_rate``
            (a sequence of numbers or a 1-D tensor). ``None`` means every frame
            is valid and nothing is masked.
        compression_model: provides ``frame_rate`` (token frames per second).
            Only needed when ``audio_lengths`` is given.
        lm: provides ``special_token_id``, the fill value for masked frames.
            Only needed when ``audio_lengths`` is given.
        cfg: provides ``sample_rate``, the rate ``audio_lengths`` is measured at.
            Only needed when ``audio_lengths`` is given.

    Returns:
        ``(masked_tokens, padding_mask)``. ``padding_mask`` is a bool tensor
        shaped like ``audio_tokens``, ``True`` for valid frames.
    """
    # Unpacking enforces the expected (batch, codebooks, frames) rank.
    batch_size, _num_codebooks, _num_frames = audio_tokens.shape

    audio_tokens = audio_tokens.clone()
    padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool)

    if audio_lengths is None:
        return audio_tokens, padding_mask

    token_sample_rate = compression_model.frame_rate
    for i in range(batch_size):
        # samples -> seconds -> token frames. Slicing past the end is a no-op,
        # so lengths longer than the token sequence leave the item unmasked.
        valid_tokens = math.floor(audio_lengths[i] / cfg.sample_rate * token_sample_rate)
        audio_tokens[i, :, valid_tokens:] = lm.special_token_id
        padding_mask[i, :, valid_tokens:] = False

    return audio_tokens, padding_mask

In [3]:
# Dataset CSV locations. os.path.join avoids the "./csv_files//..." double slash
# that string-formatting "{base}/file" onto a base ending in "/" produces.
base_path = "./csv_files/"
train_data_path = os.path.join(base_path, "train_dataset_epidemic_sub.csv")
eval_data_path = os.path.join(base_path, "eval_dataset_epidemic_sub.csv")

cfg = Config()

cfg.update(train_data_path=train_data_path, eval_data_path=eval_data_path, batch_size=1)

# The Accelerator owns device placement and gradient accumulation for the run.
accelerator = Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
cfg.update(device=accelerator.device)
make_dir(cfg.output_dir)
make_dir(cfg.generated_dir)

# Set to True to log this run to Weights & Biases (main process only).
USE_WANDB = False
if USE_WANDB and accelerator.is_main_process:
    wandb_init(cfg)

# Build the models and datasets. `main_process_first` runs this block on the main
# process before the other processes enter it (presumably so any model download /
# dataset preprocessing happens once and is then reused; TODO confirm).
with accelerator.main_process_first():  
    compression_model, lm = build_model(cfg)
    audio_dataset = SeperationDataset(cfg, train=True)
    eval_dataset = SeperationDataset(cfg, train=False)
# The compression model is only used to encode audio into tokens (under
# torch.no_grad) and its parameters are not given to the optimizer; keep it in eval mode.
compression_model.eval()

# AudioProcessing is built around the loaded LM; the optimizer below trains `model.lm`.
model = AudioProcessing(cfg, lm)
test_dataset = TestDataset(cfg)

# Train: shuffled. Eval: fixed order. Test: one item per batch.
# NOTE: test_dataloader is not passed to accelerator.prepare(), so accelerate
# neither shards it across processes nor moves its batches to accelerator.device.
audio_dataloader = DataLoader(audio_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=8)
eval_dataloader = DataLoader(eval_dataset, batch_size=cfg.eval_batch_size, shuffle=False, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size=1)

# Optimize only the LM's trainable parameters; the compression model is not included.
optimizer_parameters = [param for param in model.lm.parameters() if param.requires_grad]

optimizer = torch.optim.AdamW(
    optimizer_parameters, 
    lr=cfg.learning_rate,
    betas=(cfg.adam_beta1, cfg.adam_beta2),
    weight_decay=cfg.adam_weight_decay,
    eps=cfg.adam_epsilon,
)

# Optimizer updates per epoch = micro-batches / accumulation steps, rounded up.
num_update_steps_per_epoch = math.ceil(len(audio_dataloader) / cfg.gradient_accumulation_steps)
if cfg.max_train_steps is None:
  cfg.max_train_steps = cfg.num_train_epochs * num_update_steps_per_epoch

# NOTE(review): warmup and total steps are multiplied by gradient_accumulation_steps,
# while the training loop calls lr_scheduler.step() on every micro-batch through the
# accelerate-prepared scheduler. If the prepared scheduler only advances on
# gradient-sync steps (accelerate-version dependent), this schedule would be stretched
# by the accumulation factor. Confirm the intended schedule length.
lr_scheduler = get_scheduler(
      name=cfg.lr_scheduler_type,
      optimizer=optimizer,
      num_warmup_steps=cfg.num_warmup_steps * cfg.gradient_accumulation_steps,
      num_training_steps=cfg.max_train_steps * cfg.gradient_accumulation_steps,
  )

# Optionally resume from a saved state dict, loaded into the not-yet-prepared model.
# NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
# NOTE(review): if checkpoints are saved from the accelerate-prepared model under
# multi-process DDP, their keys may carry a "module." prefix that will not match this
# unwrapped model. Confirm when resuming multi-GPU checkpoints.
with accelerator.main_process_first():
    if cfg.resume_from_checkpoint is not None:
        accelerator.print(f"Resumed from local checkpoint: {cfg.resume_from_checkpoint}")
        model.load_state_dict(torch.load(cfg.resume_from_checkpoint, map_location=accelerator.device))
        #accelerator.load_state(cfg.resume_from_checkpoint)


# Hand everything to accelerate for device placement, distributed wrapping and
# dataloader sharding.
audio_dataloader, eval_dataloader, model, compression_model, optimizer, lr_scheduler = accelerator.prepare(
    audio_dataloader, eval_dataloader, model, compression_model, optimizer, lr_scheduler
)

# Training-loop state. best_loss starts at +inf, so the first validation loss
# counts as an improvement.
starting_epoch, completed_steps, best_loss, save_epoch = 0, 0, np.inf, 0
# The bar is shown only on the local main process.
progress_bar = tqdm(range(cfg.max_train_steps), disable=not accelerator.is_local_main_process)

  0%|          | 0/41210000 [00:00<?, ?it/s]

In [4]:
from audiotools import AudioSignal
wav_path = "0rMqlwSi9u.wav"

wav = AudioSignal(wav_path)
print(wav.duration)  # duration of the file as loaded, before any processing

# Bring the clip to mono, 16 kHz, and at most 48000 samples.
wav.to_mono()
wav.resample(16000)
wav.truncate_samples(48000)

# Read the length only after resampling/truncation. post_process_audio_tokenizer
# treats lengths as sample counts at cfg.sample_rate, so a length taken at the
# file's native rate overstates the number of valid token frames.
# NOTE(review): assumes cfg.sample_rate == 16000 to match the resample above; confirm.
sl = wav.signal_length

unwrapped_vae = accelerator.unwrap_model(compression_model)

audio_tokens = process_audio_tokenizer(torch.tensor(wav.numpy()).to(accelerator.device), unwrapped_vae)

print(audio_tokens.shape)

audio_tokens2, padding_mask = post_process_audio_tokenizer(audio_tokens, [sl], unwrapped_vae, lm, cfg)
print(audio_tokens2.shape)
print(padding_mask.shape)
print(padding_mask.shape)

2.815986394557823
torch.Size([1, 4, 141])
torch.Size([1, 4, 141])
torch.Size([1, 4, 141])


In [6]:
# Sanity-check one training batch end to end: tokenize, mask, and run a forward pass.
synthesized_audio, prompt, ground_truth, length = next(iter(audio_dataloader))

print(synthesized_audio.shape)
print(prompt)
print(ground_truth.shape)
print(length)

# NOTE(review): the recorded output shows length 74793 for a 48000-sample clip, so the
# dataset's length is probably measured before resampling/truncation. If
# cfg.sample_rate is 16 kHz, floor(74793 / 16000 * 50) = 233 exceeds the 150 token
# frames and nothing gets masked. Confirm what unit SeperationDataset returns.
# `unwrapped_vae` comes from an earlier cell.
synthesized_audio_tokens = process_audio_tokenizer(synthesized_audio, unwrapped_vae)
synthesized_audio_tokens, synthesized_padding_mask = post_process_audio_tokenizer(synthesized_audio_tokens, length, unwrapped_vae, lm, cfg)

print(synthesized_audio_tokens.shape)
print(synthesized_padding_mask.shape)

# for batch_idx, (synthesized_wav, prompts, ground_truth, lengths) in enumerate(audio_dataloader):
#     print(batch_idx)
#     if batch_idx == 4:
#         break


# One text-conditioning description per batch item.
attributes = [
    ConditioningAttributes(text={'description': str(description)})
    for description in prompt
]
# NOTE(review): this unpacks two values from model(...), but other code in this notebook
# may treat the result as the loss alone. Confirm AudioProcessing.forward's return type.
loss, model_output = model(synthesized_audio_tokens, synthesized_padding_mask, attributes)

torch.Size([1, 1, 48000])
("Remove 'The clear and crisp sound of a door being closed, captured in mono with no background noise. Perfect for use in video games, movies, and tutorials.'",)
torch.Size([1, 1, 48000])
tensor([74793], device='cuda:0')
torch.Size([1, 4, 150])
torch.Size([1, 4, 150])
logits :  tensor([[[[-1.9453, -8.3282, -3.4557,  ..., -7.7003, -2.2204, -2.5620],
          [-2.1539, -6.8257, -4.4074,  ..., -6.9874, -2.8900, -3.6284],
          [-1.9611, -5.7711, -3.5935,  ..., -4.3697, -1.2969, -3.5035],
          ...,
          [-0.6441, -3.9716, -1.6830,  ..., -2.5556, -1.4842, -1.7736],
          [-0.6564, -3.9135, -1.6645,  ..., -2.4994, -1.4727, -1.7453],
          [-0.6395, -3.8795, -1.6337,  ..., -2.4380, -1.4316, -1.7322]],

         [[-3.9193, -5.1663, -6.0390,  ...,  0.8620, -0.2039, -8.5069],
          [-4.0252, -2.4035, -5.9598,  ..., -2.2706, -0.9123, -5.7665],
          [-1.5002, -2.4417, -3.9796,  ..., -0.8868, -1.6147, -5.3618],
          ...,
          [-2.4

In [5]:
# Short debugging run: 3 epochs, independent of cfg.num_train_epochs.
# Unwrapping is loop-invariant, so it is done once up front.
unwrapped_vae = accelerator.unwrap_model(compression_model)

for epoch in range(starting_epoch, 3):
    accelerator.print(f"-------------------EPOCH{epoch}-------------------------" )
    total_loss, total_val_loss = 0, 0
    model.train()

    for batch_idx, (synthesized_wav, prompts, ground_truth, lengths) in enumerate(audio_dataloader):
        # Gradient accumulation is handled by accelerate: the prepared optimizer
        # only applies an update when gradients are synchronized.
        with accelerator.accumulate(model):
            with torch.no_grad():
                synthesized_audio_tokens = process_audio_tokenizer(synthesized_wav, unwrapped_vae)
                synthesized_audio_tokens, synthesized_padding_mask = post_process_audio_tokenizer(
                    synthesized_audio_tokens, lengths, unwrapped_vae, lm, cfg
                )

                # Target tokens come from this batch's ground-truth audio.
                # NOTE(review): the same `lengths` is used for both waveforms; confirm
                # the dataset guarantees they share a length.
                audio_tokens = process_audio_tokenizer(ground_truth, unwrapped_vae)
                audio_tokens, padding_mask = post_process_audio_tokenizer(
                    audio_tokens, lengths, unwrapped_vae, lm, cfg
                )

                attributes = [
                    ConditioningAttributes(text={'description': str(description)})
                    for description in prompts
                ]

            # One separator frame per batch item and codebook. It takes the token
            # tensor's batch size, codebook count, dtype and device; its padding
            # entry is valid (True) and bool like the other masks.
            batch_size, num_codebooks, _ = synthesized_audio_tokens.shape
            mask_token_seperation = synthesized_audio_tokens.new_full(
                (batch_size, num_codebooks, 1), lm.special_token_id
            )
            mask_padding_token_seperation = synthesized_padding_mask.new_ones((batch_size, num_codebooks, 1))

            made_up = torch.cat((synthesized_audio_tokens, mask_token_seperation, audio_tokens), dim=2)
            made_padding = torch.cat((synthesized_padding_mask, mask_padding_token_seperation, padding_mask), dim=2)
            if batch_idx == 0:
                print(made_up.shape, made_padding.shape, attributes)

            # TODO(review): `made_up` / `made_padding` are built but not yet fed to the model.
            outputs = model(audio_tokens, padding_mask, attributes)
            # Accept either a bare loss or a (loss, ...) tuple; the exploratory cell
            # unpacks `loss, model_output = model(...)`.
            loss = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
            ppl = torch.exp(loss)
            total_loss += loss.detach().float()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            if accelerator.sync_gradients:
                progress_bar.update(1)
                completed_steps += 1

-------------------EPOCH0-------------------------


In [68]:

# Disabled one-off utility: adds a `synthesized_index` column to the eval CSV that
# pairs each row with a random *different* row, then overwrites the CSV in place.
# NOTE(review): random.randint includes both endpoints, so the upper bound must be
# len(dff) - 1 to stay a valid row index. With a single-row CSV the while loop never
# terminates. No seed is set, so the pairing is not reproducible.
# import pandas as pd

# dff = pd.read_csv(eval_data_path)
# import random

# randoms = []
# for i in range(len(dff)):
#     random_integer = random.randint(0, len(dff) - 1)
#     while random_integer == i:
#         random_integer = random.randint(0, len(dff) - 1)
#     randoms.append(random_integer)

# dff['synthesized_index'] = randoms

# dff[:10]

# dff.to_csv(eval_data_path, index=False)  # overwrites eval_data_path in place, now including the new column