In [1]:
import os
import json
import math
import sys
import argparse
import pandas as pd
import numpy as np
from tqdm import tqdm

import soundfile as sf
import librosa

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import wandb

from accelerate import Accelerator
from transformers import get_scheduler

from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from beats.BEATs import BEATsConfig, BEATs

from config import Config
from captioning_config import CaptionConfig
from audiomodel import AudioProcessing
from audiocraft.modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition, ConditioningAttributes

def make_dir(path):
    """Create directory ``path`` (and any missing parents) if it is absent.

    ``exist_ok=True`` replaces the previous exists-then-create sequence, which
    could raise ``FileExistsError`` when several processes created the same
    directory at the same time.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` already exists and is not a directory.
    """
    os.makedirs(path, exist_ok=True)

def wandb_init(cfg):
    """Start a Weights & Biases run for this training job.

    Args:
        cfg: Config object; reads ``wandb_project_name``, ``learning_rate``,
            ``num_train_epochs`` and ``batch_size``.
    """
    # Project under which the run is logged.
    project_name = cfg.wandb_project_name

    # Hyperparameters and run metadata tracked with the run.
    tracked_hparams = {
        "learning_rate": cfg.learning_rate,
        "epochs": cfg.num_train_epochs,
        "batch_size": cfg.batch_size,
    }

    wandb.init(project=project_name, config=tracked_hparams)
    
def save_checkpoint(cfg, model, result, best_loss, epoch=0):
    """Append the epoch summary to disk and write model checkpoints.

    Args:
        cfg: Config object; reads ``output_dir`` and ``checkpointing_steps``.
        model: Module whose ``state_dict()`` is saved.
        result: JSON-serialisable dict containing at least ``"train_loss"``.
        best_loss: Lowest training loss seen so far.
        epoch: Epoch index used in the per-epoch checkpoint file name.

    Returns:
        The updated best loss (``result["train_loss"]`` if it improved on
        ``best_loss``, otherwise ``best_loss`` unchanged).
    """
    out_dir = cfg.output_dir

    # One JSON record per call, each followed by a blank line.
    summary_path = "{}/summary.jsonl".format(out_dir)
    with open(summary_path, "a") as summary_file:
        summary_file.write(json.dumps(result) + "\n\n")

    improved = result["train_loss"] < best_loss
    if improved:
        best_loss = result["train_loss"]

    # "best.pth" is only refreshed on improvement and only in "best" mode.
    if improved and cfg.checkpointing_steps == "best":
        torch.save(model.state_dict(), os.path.join(out_dir, "best.pth"))

    # The latest and the per-epoch snapshots are always written.
    for filename in ("last.pth", f"epoch_{epoch}.pth"):
        torch.save(model.state_dict(), os.path.join(out_dir, filename))

    return best_loss



    PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.1.0+cu118)
    Python  3.10.13 (you have 3.10.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


In [2]:
# beats
def load_beats(beats_ckpt, device):
    """Build a frozen BEATs model from a checkpoint file.

    Args:
        beats_ckpt: Path to a checkpoint holding ``'cfg'`` and ``'model'`` entries.
        device: Not used here; the checkpoint is loaded onto the CPU and the
            caller is responsible for moving the returned model.

    Returns:
        The BEATs model with ``requires_grad`` disabled on every parameter.
    """
    checkpoint = torch.load(beats_ckpt, map_location='cpu')

    model = BEATs(BEATsConfig(checkpoint['cfg']))
    model.load_state_dict(checkpoint['model'])

    # Freeze all weights.
    for weight in model.parameters():
        weight.requires_grad = False

    return model

In [17]:
class CaptionModel(nn.Module):
    """GPT-2 language model conditioned on audio embeddings.

    Audio features are layer-normalised and linearly projected to
    ``cfg.text_embedding_size``, then placed between a BOS embedding and the
    embedded text tokens: ``[BOS, audio..., text...]``.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: Config providing ``lm_model_name``, ``audio_embedding_size``
                and ``text_embedding_size``.
        """
        super().__init__()
        self.model = GPT2LMHeadModel.from_pretrained(cfg.lm_model_name)
        self.ln_audio = nn.LayerNorm(cfg.audio_embedding_size)
        self.lm_proj_audio = nn.Linear(cfg.audio_embedding_size, cfg.text_embedding_size)

    def _build_inputs_embeds(self, audio_embeds, input_ids, tokenizer, device):
        """Assemble the ``[BOS, audio, text]`` embedding sequence shared by
        ``forward`` and ``generate``."""
        # Normalise the audio embeddings and project them to the text width.
        audio_embeds = self.lm_proj_audio(self.ln_audio(audio_embeds))

        # Embed the text tokens with the language model's token embedding.
        embed_tokens = self.model.transformer.wte
        text_embeds = embed_tokens(input_ids)

        # One BOS embedding, repeated for every item of the batch.
        bsz = input_ids.size(0)
        bos_ids = torch.ones([1], dtype=torch.long, device=device) * tokenizer.bos_token_id
        bos_embeds = embed_tokens(bos_ids).repeat(bsz, 1, 1)

        return torch.cat([bos_embeds, audio_embeds, text_embeds], dim=1)

    def forward(self, audio_embeds, input_ids, labels, tokenizer, device):
        """Compute the language-modelling loss for a batch.

        Args:
            audio_embeds: Audio features whose last dimension is
                ``cfg.audio_embedding_size``.
            input_ids: Text token ids, batch first.
            labels: Passed unchanged to the language model.
                NOTE(review): presumably must match the length of the
                concatenated ``[BOS, audio, text]`` sequence - confirm with callers.
            tokenizer: Tokenizer supplying ``bos_token_id``.
            device: Device on which the BOS token id tensor is created.

        Returns:
            The loss reported by the language model.
        """
        inputs_embeds = self._build_inputs_embeds(audio_embeds, input_ids, tokenizer, device)
        output = self.model(inputs_embeds=inputs_embeds, labels=labels)
        return output.loss

    def generate(self, audio_embeds, input_ids, tokenizer, max_length=50, num_return_sequences=1):
        """Generate text conditioned on audio embeddings and a text prefix.

        ``max_length`` and ``num_return_sequences`` are forwarded to the
        language model's ``generate``; previously they were accepted but
        ignored in favour of hard-coded values equal to the defaults.

        Args:
            audio_embeds: Audio features whose last dimension is
                ``cfg.audio_embedding_size``.
            input_ids: Text prefix token ids, batch first.
            tokenizer: Tokenizer supplying ``bos_token_id``, ``eos_token_id``
                (used as the pad token id) and ``decode``.
            max_length: Forwarded to ``self.model.generate``.
            num_return_sequences: Forwarded to ``self.model.generate``.

        Returns:
            Tuple ``(generated_texts, output)``: the decoded strings and the
            raw object returned by ``self.model.generate``.
        """
        # The BOS ids are created on the device of the embedding weights so
        # that the lookup does not cross devices.
        embed_device = self.model.transformer.wte.weight.device
        inputs_embeds = self._build_inputs_embeds(audio_embeds, input_ids, tokenizer, embed_device)

        # Every position of the assembled sequence is attended to.
        atts = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(inputs_embeds.device)

        output = self.model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=atts,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode each generated sequence.
        generated_texts = [
            tokenizer.decode(output[i], skip_special_tokens=True)
            for i in range(len(output))
        ]

        return generated_texts, output

In [18]:
def process_audio_embedding(wav, beats, audio_token_length, device):
    """Extract BEATs features and force them to a fixed length along dim 1.

    Args:
        wav: Waveform tensor passed to ``beats.extract_features``.
        beats: Object exposing ``extract_features(wav, padding_mask=..., feature_only=True)``.
        audio_token_length: Required size of dimension 1 of the result.
        device: Device on which the all-False padding mask is created.

    Returns:
        The features, truncated or zero-padded at the end of dimension 1 so
        that its size equals ``audio_token_length``.
    """
    # All-False mask with the same shape as the waveform.
    no_padding = torch.zeros(wav.shape, device=device).bool()

    # Feature extraction.
    features, _ = beats.extract_features(wav, padding_mask=no_padding, feature_only=True)

    shortfall = audio_token_length - features.size(1)
    if shortfall < 0:
        # Too long: keep the first ``audio_token_length`` steps.
        return features.narrow(1, 0, audio_token_length)
    if shortfall > 0:
        # Too short: append zero steps.
        return F.pad(features, (0, 0, 0, shortfall))
    return features

In [19]:
class CaptionDataset(Dataset):
    """Audio clips paired with a fixed caption prompt.

    Each item is ``(wav, input_ids, lengths)``:

    * ``wav`` - mono waveform at ``cfg.sample_rate``, cut or zero-padded to
      ``cfg.duration`` seconds;
    * ``input_ids`` - token ids of the prompt string ``"<TEXT>"``;
    * ``lengths`` - number of valid (un-padded) samples in ``wav``.
    """

    def __init__(self, cfg, tokenizer: GPT2Tokenizer, train=True):
        """
        Args:
            cfg: Config providing ``train_data_path`` / ``eval_data_path``,
                ``audio_token_length``, ``text_max_length``, ``sample_rate``
                and ``duration``.
            tokenizer: Tokenizer used for the prompt. Its ``pad_token`` is set
                to its ``eos_token`` as a side effect.
            train: Selects ``cfg.train_data_path`` when true, otherwise
                ``cfg.eval_data_path``.
        """
        self.data_path = cfg.train_data_path if train else cfg.eval_data_path
        self.dataframe = pd.read_csv(self.data_path)
        self.tokenizer = tokenizer
        self.tokenizer.pad_token = tokenizer.eos_token
        self.audio_token_length = cfg.audio_token_length
        self.text_max_length = cfg.text_max_length
        # Total sequence length: audio tokens + text tokens + 1.
        self.max_length = self.audio_token_length + self.text_max_length + 1
        self.sample_rate = cfg.sample_rate
        self.duration = cfg.duration

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load, resample and length-normalise the clip at row ``idx``."""
        row = self.dataframe.iloc[idx]
        wav_path = row['audio_path']

        caption = "<TEXT>"

        # Read at most ``duration`` seconds, measured at the file's own sample
        # rate (the window used to be hard-coded to 3 s regardless of
        # ``cfg.duration``). soundfile returns a shorter array when the file
        # holds fewer frames than requested.
        info = sf.info(wav_path)
        max_source_frames = int(info.samplerate * self.duration)
        wav, sr = sf.read(wav_path, frames=max_source_frames)

        # Keep only the first channel of multi-channel audio.
        if len(wav.shape) == 2:
            wav = wav[:, 0]

        # Convert to the target sample rate.
        if sr != self.sample_rate:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=self.sample_rate, res_type="fft")

        # Cut to the target length; int() keeps the slice valid for
        # fractional durations.
        target_length = int(self.duration * self.sample_rate)
        wav = wav[:target_length]

        # Valid samples are counted AFTER resampling and truncation, i.e. in
        # units of ``self.sample_rate``. The previous value (total frames of
        # the source file at its own rate) produced a wrong padding mask for
        # short clips whose sample rate differs from the target.
        lengths = len(wav)

        # Zero-pad short clips up to the target length.
        if lengths < target_length:
            wav = np.pad(wav, (0, target_length - lengths), 'constant')

        # Tokenize the prompt.
        batch_encoding = self.tokenizer(caption, return_tensors='pt')
        input_ids = batch_encoding['input_ids'].squeeze(0)

        return wav, input_ids, lengths

In [20]:
def build_model(cfg):
    """Load the pretrained compression model and language model.

    Both are loaded from the ``facebook/audiogen-medium`` checkpoint onto
    ``cfg.device``.

    Args:
        cfg: Config object; only ``cfg.device`` is read.

    Returns:
        Tuple ``(compression_model, lm)``.
    """
    # Imported at call time, as in the original cell.
    from audiocraft.models.loaders import load_compression_model, load_lm_model

    compression_model = load_compression_model('facebook/audiogen-medium', device=cfg.device)
    lm = load_lm_model('facebook/audiogen-medium', device=cfg.device)
    return compression_model, lm

In [21]:
def process_audio_tokenizer(wav, compression_model):
    """Encode a waveform batch into audio tokens with gradients disabled.

    Args:
        wav: Waveform tensor passed to ``compression_model.encode``.
        compression_model: Object whose ``encode(wav)`` returns a
            ``(tokens, scale)`` pair.

    Returns:
        The tokens; the scale is discarded.
    """
    with torch.no_grad():
        codes, _scale = compression_model.encode(wav)
    return codes

def post_process_audio_tokenizer(audio_tokens, audio_lengths=None, compression_model=None, lm=None, cfg=None):
    """Replace tokens past each clip's valid length and build the matching mask.

    For item ``i`` the number of valid tokens is
    ``floor(audio_lengths[i] / cfg.sample_rate * compression_model.frame_rate)``.

    Args:
        audio_tokens: 3-D token tensor; the first dimension is the batch and
            the last one is indexed by token position. Not modified.
        audio_lengths: Per-item lengths, indexable by batch position.
        compression_model: Object exposing ``frame_rate``.
        lm: Object exposing ``special_token_id``.
        cfg: Object exposing ``sample_rate``.

    Returns:
        Tuple ``(tokens, padding_mask)``: a copy of ``audio_tokens`` with
        positions past the valid length set to ``lm.special_token_id``, and a
        boolean mask that is False at exactly those positions.
    """
    valid_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
    masked_tokens = audio_tokens.clone()
    frame_rate = compression_model.frame_rate

    batch_size, _num_codebooks, _num_steps = masked_tokens.shape
    for item in range(batch_size):
        n_valid = math.floor(audio_lengths[item] / cfg.sample_rate * frame_rate)
        masked_tokens[item, :, n_valid:] = lm.special_token_id
        valid_mask[item, :, n_valid:] = False

    return masked_tokens, valid_mask

In [22]:
class TestDataset(Dataset):
    """Text prompts served one at a time by index.

    When ``cfg.prompts`` is ``None`` the prompts are the first
    ``num_prompts`` values of the ``caption`` column of the CSV at
    ``cfg.test_data_path``; otherwise ``cfg.prompts`` is used as given.
    """

    def __init__(self, cfg, num_prompts=8):
        """
        Args:
            cfg: Config providing ``prompts`` and, when that is ``None``,
                ``test_data_path``.
            num_prompts: Maximum number of captions taken from the CSV.
                A CSV with fewer rows yields fewer prompts instead of
                raising ``IndexError`` as the per-row look-ups used to.
        """
        if cfg.prompts is None:
            test_df = pd.read_csv(cfg.test_data_path)
            self.prompts = list(test_df['caption'].iloc[:num_prompts])
        else:
            self.prompts = cfg.prompts

    def __len__(self):
        """Number of prompts."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Prompt at position ``idx``."""
        return self.prompts[idx]

In [23]:
def make_dir(path):
    """Create directory ``path`` (and any missing parents) if it is absent.

    NOTE(review): this cell redefines a function with the same name that
    already exists earlier in the notebook; consider keeping a single copy.

    ``exist_ok=True`` replaces the exists-then-create sequence, which could
    raise ``FileExistsError`` when several processes created the same
    directory at the same time.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` already exists and is not a directory.
    """
    os.makedirs(path, exist_ok=True)

In [24]:
# Smoke test: caption a single batch with the captioning model and print the text.
# NOTE(review): `audio_dataloader`, `device`, `beats_model`, `caption_cfg`,
# `caption_model` and `tokenizer` are only created as locals inside `main()` in
# the visible notebook, so this cell presumably depends on leftover kernel state
# and fails after Restart & Run All - confirm, then define them here or drop the cell.
wav, input_ids, lengths = next(iter(audio_dataloader))
audio_embeds = process_audio_embedding(wav.to(device), beats_model, caption_cfg.audio_token_length, device)

# `generate` returns the decoded captions and the raw generation output.
descriptions, output = caption_model.generate(audio_embeds.to(device), input_ids.to(device), tokenizer)
for d in descriptions:
    print(d)

### Run training

In [25]:
def main():
    """Fine-tune the audio language model on captions generated from the audio.

    For every batch the waveform is (a) encoded into audio tokens by the
    compression model and (b) captioned by the frozen BEATs + caption model;
    the generated captions are used as the text conditions for ``model``.
    Only parameters of ``model.lm`` with ``requires_grad`` are optimised.
    After each epoch the main process writes a summary and checkpoints and
    generates one audio file per test prompt.

    Fixes relative to the previous revision:
      * ``save_epoch`` was read before ever being assigned, which raised
        ``UnboundLocalError`` at the first epoch summary; the loop's
        ``epoch`` is used instead.
      * ``result["epoch"]`` carried a trailing comma and was stored as a tuple.
      * The training loss is averaged over the number of training batches
        (it was divided by ``cfg.save_steps``), matching the validation loss.
      * ``wandb.log`` is only called when a W&B run is active, since
        ``wandb_init`` is not called in this function.
      * Validation no longer runs inside ``accelerator.accumulate``.
    """
    cfg = Config()
    caption_cfg = CaptionConfig()

    accelerator = Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
    device = accelerator.device
    cfg.update(device=accelerator.device)
    make_dir(cfg.output_dir)
    make_dir(cfg.generated_dir)
    # W&B is not initialised here; logging below is skipped when no run is active.

    with accelerator.main_process_first():
        compression_model, lm = build_model(cfg)
        model = AudioProcessing(cfg, lm)

        beats_model = load_beats(caption_cfg.beats_ckpt, device).to(device)
        tokenizer = GPT2Tokenizer.from_pretrained(caption_cfg.lm_model_name)
        tokenizer.pad_token_id = tokenizer.eos_token_id

        caption_model = CaptionModel(caption_cfg).to(device)
        # map_location lets a checkpoint saved on another device load here.
        caption_model.load_state_dict(torch.load("caption_weight/6.pth", map_location=device))

        audio_dataset = CaptionDataset(caption_cfg, tokenizer, train=True)
        eval_dataset = CaptionDataset(caption_cfg, tokenizer, train=False)
    test_dataset = TestDataset(caption_cfg)

    audio_dataloader = DataLoader(audio_dataset, batch_size=caption_cfg.batch_size, shuffle=True, num_workers=8)
    eval_dataloader = DataLoader(eval_dataset, batch_size=caption_cfg.eval_batch_size, shuffle=False, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=1)

    optimizer_parameters = [param for param in model.lm.parameters() if param.requires_grad]

    optimizer = torch.optim.AdamW(
        optimizer_parameters, lr=cfg.learning_rate,
        betas=(cfg.adam_beta1, cfg.adam_beta2),
        weight_decay=cfg.adam_weight_decay,
        eps=cfg.adam_epsilon,
    )

    num_update_steps_per_epoch = math.ceil(len(audio_dataloader) / cfg.gradient_accumulation_steps)
    if cfg.max_train_steps is None:
        cfg.max_train_steps = cfg.num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        name=cfg.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=cfg.num_warmup_steps * cfg.gradient_accumulation_steps,
        num_training_steps=cfg.max_train_steps * cfg.gradient_accumulation_steps,
    )

    (audio_dataloader, eval_dataloader, model, compression_model,
     caption_model, beats_model, optimizer, lr_scheduler) = accelerator.prepare(
        audio_dataloader, eval_dataloader, model, compression_model,
        caption_model, beats_model, optimizer, lr_scheduler
    )
    # The tokenizer, the feature extractor and the captioner are inference-only.
    compression_model.eval()
    beats_model.eval()
    caption_model.eval()

    def prepare_batch(wav, input_ids, lengths):
        """Turn one raw batch into ``(audio_tokens, padding_mask, attributes)``.

        Everything here runs without gradients: the audio tokens come from
        the compression model and the text conditions from the captioner.
        """
        with torch.no_grad():
            unwrapped_vae = accelerator.unwrap_model(compression_model)
            audio_tokens = process_audio_tokenizer(wav.unsqueeze(1).to(torch.float32), unwrapped_vae)
            audio_tokens, padding_mask = post_process_audio_tokenizer(audio_tokens, lengths, unwrapped_vae, lm, cfg)

            unwrapped_beats = accelerator.unwrap_model(beats_model)
            unwrapped_gpt = accelerator.unwrap_model(caption_model)
            audio_embeds = process_audio_embedding(wav.to(device), unwrapped_beats, caption_cfg.audio_token_length, device)
            descriptions, _ = unwrapped_gpt.generate(audio_embeds.to(device), input_ids.to(device), tokenizer)

        attributes = [
            ConditioningAttributes(text={'description': str(description)})
            for description in descriptions]
        return audio_tokens, padding_mask, attributes

    starting_epoch, completed_steps, best_loss = 0, 0, np.inf
    progress_bar = tqdm(range(cfg.max_train_steps), disable=not accelerator.is_local_main_process)

    for epoch in range(starting_epoch, cfg.num_train_epochs):
        print(f"-------------------EPOCH{epoch}-------------------------" )
        total_loss, total_val_loss = 0, 0

        model.train()
        for wav, input_ids, lengths in audio_dataloader:
            with accelerator.accumulate(model):
                audio_tokens, padding_mask, attributes = prepare_batch(wav, input_ids, lengths)

                loss = model(audio_tokens, padding_mask, attributes)
                total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    completed_steps += 1

        model.eval()
        with torch.no_grad():
            for wav, input_ids, lengths in eval_dataloader:
                audio_tokens, padding_mask, attributes = prepare_batch(wav, input_ids, lengths)
                loss = model(audio_tokens, padding_mask, attributes)
                total_val_loss += loss.detach().float()

        if accelerator.is_main_process:
            # max(..., 1) keeps the summary well-defined for an empty loader.
            result = {
                "epoch": epoch + 1,
                "step": completed_steps,
                "train_loss": round(float(total_loss) / max(len(audio_dataloader), 1), 4),
                "valid_loss": round(float(total_val_loss) / max(len(eval_dataloader), 1), 4),
            }

            if wandb.run is not None:
                wandb.log(result)
            result_string = "Epoch: {}, Loss Train: {}, Valid: {}\n".format(epoch + 1, result["train_loss"], result["valid_loss"])
            accelerator.print(result_string)

            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_vae = accelerator.unwrap_model(compression_model)
            best_loss = save_checkpoint(cfg, unwrapped_model, result, best_loss, epoch)

            # Generate one audio sample per test prompt.
            for test_step, batch in enumerate(test_dataloader):
                gen_token, gen_audio = unwrapped_model.inference(batch, unwrapped_vae)
                audio_filename = f"epoch_{epoch}_{test_step}.wav"
                unwrapped_model.save_audio(gen_audio, audio_filename, cfg)

    

In [None]:
# Run `main` through Accelerate's notebook launcher in a single process.
from accelerate import notebook_launcher
args = ()  # `main` takes no arguments
notebook_launcher(main, args, num_processes=1)

Launching training on one GPU.



  0%|          | 0/14893000 [00:00<?, ?it/s][A

-------------------EPOCH0-------------------------



  0%|          | 1/14893000 [00:07<30143:36:38,  7.29s/it][A
  0%|          | 2/14893000 [00:11<21975:37:12,  5.31s/it][A
  0%|          | 3/14893000 [00:15<19416:09:57,  4.69s/it][A
  0%|          | 4/14893000 [00:19<18229:25:53,  4.41s/it][A
  0%|          | 5/14893000 [00:22<17288:14:16,  4.18s/it][A
  0%|          | 6/14893000 [00:26<16950:58:43,  4.10s/it][A
  0%|          | 7/14893000 [00:30<16643:03:51,  4.02s/it][A
  0%|          | 8/14893000 [00:34<16558:05:44,  4.00s/it][A
  0%|          | 9/14893000 [00:38<16448:22:03,  3.98s/it][A
  0%|          | 10/14893000 [00:42<16383:33:38,  3.96s/it][A
  0%|          | 11/14893000 [00:46<16338:35:22,  3.95s/it][A
  0%|          | 12/14893000 [00:50<16323:30:47,  3.95s/it][A
  0%|          | 13/14893000 [00:54<16315:28:48,  3.94s/it][A
  0%|          | 14/14893000 [00:58<16269:39:58,  3.93s/it][A
  0%|          | 15/14893000 [01:02<16331:42:04,  3.95s/it][A
  0%|          | 16/14893000 [01:06<16288:19:17,  3.94s/it][A


### captioning unit test

In [None]:
def main():
    """Train ``AudioProcessing`` on (audio, description) batches.

    NOTE(review): this definition reuses the name ``main`` and therefore
    replaces the earlier ``main`` once the cell runs.

    After each epoch the main process logs the averaged losses, writes
    checkpoints and generates one audio file per test prompt.

    Fixes relative to the previous revision:
      * Validation ran with gradients enabled and summed losses that still
        carried their autograd graphs, so memory grew with every validation
        batch; it now runs under ``torch.no_grad()`` with detached losses.
      * ``result["epoch"]`` carried a trailing comma and was stored as a tuple.
      * Checkpoints are written from the unwrapped model rather than from the
        accelerator-wrapped one.
    """
    cfg = Config()

    accelerator = Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
    make_dir(cfg.output_dir)
    make_dir(cfg.generated_dir)
    wandb_init(cfg)

    model = AudioProcessing(cfg)

    # NOTE(review): the dataset classes are called with ``cfg`` here; confirm
    # that the definitions in scope when this cell runs accept these arguments.
    audio_dataset = AudioDataset(cfg, train=True)
    eval_dataset = AudioDataset(cfg, train=False)
    test_dataset = TestDataset(cfg)

    audio_dataloader = DataLoader(audio_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=12)
    eval_dataloader = DataLoader(eval_dataset, batch_size=cfg.eval_batch_size, shuffle=False, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=1)

    optimizer_parameters = [param for param in model.lm.parameters() if param.requires_grad]

    optimizer = torch.optim.AdamW(
        optimizer_parameters, lr=cfg.learning_rate,
        betas=(cfg.adam_beta1, cfg.adam_beta2),
        weight_decay=cfg.adam_weight_decay,
        eps=cfg.adam_epsilon,
    )

    num_update_steps_per_epoch = math.ceil(len(audio_dataloader) / cfg.gradient_accumulation_steps)
    if cfg.max_train_steps is None:
        cfg.max_train_steps = cfg.num_train_epochs * num_update_steps_per_epoch

    lr_scheduler = get_scheduler(
        name=cfg.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=cfg.num_warmup_steps * cfg.gradient_accumulation_steps,
        num_training_steps=cfg.max_train_steps * cfg.gradient_accumulation_steps,
    )

    audio_dataloader, eval_dataloader, model, optimizer, lr_scheduler = accelerator.prepare(
        audio_dataloader, eval_dataloader, model, optimizer, lr_scheduler
    )

    starting_epoch, completed_steps, best_loss = 0, 0, np.inf
    progress_bar = tqdm(range(cfg.max_train_steps), disable=not accelerator.is_local_main_process)

    for epoch in range(starting_epoch, cfg.num_train_epochs):
        print(f"-------------------EPOCH{epoch}-------------------------" )
        total_loss, total_val_loss = 0, 0

        model.train()
        for wav, descriptions, lengths in audio_dataloader:
            with accelerator.accumulate(model):
                loss = model(wav, descriptions, lengths)
                total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                progress_bar.update(1)
                completed_steps += 1

        model.eval()
        with torch.no_grad():
            for wav, descriptions, lengths in eval_dataloader:
                loss = model(wav, descriptions, lengths)
                total_val_loss += loss.detach().float()

        if accelerator.is_main_process:
            # max(..., 1) keeps the summary well-defined for an empty loader.
            result = {
                "epoch": epoch + 1,
                "step": completed_steps,
                "train_loss": round(float(total_loss) / max(len(audio_dataloader), 1), 4),
                "valid_loss": round(float(total_val_loss) / max(len(eval_dataloader), 1), 4),
            }

            wandb.log(result)
            result_string = "Epoch: {}, Loss Train: {}, Valid: {}\n".format(epoch + 1, result["train_loss"], result["valid_loss"])
            accelerator.print(result_string)

            unwrapped_model = accelerator.unwrap_model(model)
            best_loss = save_checkpoint(cfg, unwrapped_model, result, best_loss, epoch)

            # Generate one audio sample per test prompt.
            for test_step, batch in enumerate(test_dataloader):
                gen_audio = unwrapped_model.inference(batch)
                audio_filename = f"epoch_{epoch}_{test_step}.wav"
                unwrapped_model.save_audio(gen_audio, audio_filename, cfg)

    # NOTE(review): the W&B run is left open (no wandb.finish()); confirm this is intended.

In [6]:
# `main` is run through Accelerate's notebook launcher (single process)
# instead of being called directly:
#main()
from accelerate import notebook_launcher
notebook_launcher(main, num_processes=1)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Launching training on one GPU.


VBox(children=(Label(value='0.014 MB of 0.014 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
step,▁█
train_loss,█▁
valid_loss,▁█

0,1
step,92.0
train_loss,2.151
valid_loss,2.4137


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112272594538, max=1.0)))


  0%|          | 0/46000 [00:00<?, ?it/s][A

-------------------EPOCH0-------------------------



  0%|          | 1/46000 [00:03<40:52:21,  3.20s/it][A
  0%|          | 2/46000 [00:03<20:06:14,  1.57s/it][A
  0%|          | 3/46000 [00:04<13:44:09,  1.08s/it][A
  0%|          | 4/46000 [00:04<10:43:04,  1.19it/s][A
  0%|          | 5/46000 [00:05<9:00:29,  1.42it/s] [A
  0%|          | 6/46000 [00:05<8:01:04,  1.59it/s][A
  0%|          | 7/46000 [00:06<7:22:41,  1.73it/s][A
  0%|          | 8/46000 [00:06<6:57:42,  1.84it/s][A
  0%|          | 9/46000 [00:06<6:41:39,  1.91it/s][A
  0%|          | 10/46000 [00:07<6:31:49,  1.96it/s][A
  0%|          | 11/46000 [00:07<6:23:56,  2.00it/s][A
  0%|          | 12/46000 [00:08<6:19:03,  2.02it/s][A
  0%|          | 13/46000 [00:08<6:16:05,  2.04it/s][A
  0%|          | 14/46000 [00:09<6:13:20,  2.05it/s][A
  0%|          | 15/46000 [00:09<6:12:47,  2.06it/s][A
  0%|          | 16/46000 [00:10<6:09:37,  2.07it/s][A
  0%|          | 17/46000 [00:10<6:10:40,  2.07it/s][A
  0%|          | 18/46000 [00:11<6:09:34,  2.07it/s

Epoch: 1, Loss Train: 2.457, Valid: 2.4067

-------------------EPOCH1-------------------------



  0%|          | 47/46000 [01:00<140:45:57, 11.03s/it][A
  0%|          | 48/46000 [01:00<100:20:00,  7.86s/it][A
  0%|          | 49/46000 [01:01<72:03:39,  5.65s/it] [A
  0%|          | 50/46000 [01:01<52:17:27,  4.10s/it][A
  0%|          | 51/46000 [01:02<38:27:35,  3.01s/it][A
  0%|          | 52/46000 [01:02<28:46:16,  2.25s/it][A
  0%|          | 53/46000 [01:03<21:59:09,  1.72s/it][A
  0%|          | 54/46000 [01:03<17:13:31,  1.35s/it][A
  0%|          | 55/46000 [01:04<13:53:30,  1.09s/it][A
  0%|          | 56/46000 [01:04<11:34:12,  1.10it/s][A
  0%|          | 57/46000 [01:05<9:56:37,  1.28it/s] [A
  0%|          | 58/46000 [01:05<8:49:51,  1.45it/s][A
  0%|          | 59/46000 [01:06<8:01:54,  1.59it/s][A
  0%|          | 60/46000 [01:06<7:27:37,  1.71it/s][A
  0%|          | 61/46000 [01:07<7:03:48,  1.81it/s][A
  0%|          | 62/46000 [01:07<6:47:33,  1.88it/s][A
  0%|          | 63/46000 [01:08<6:35:16,  1.94it/s][A
  0%|          | 64/46000 [01:08<

Epoch: 2, Loss Train: 2.151, Valid: 2.4137

-------------------EPOCH2-------------------------



  0%|          | 93/46000 [01:57<139:14:47, 10.92s/it][A
  0%|          | 94/46000 [01:57<99:18:36,  7.79s/it] [A
  0%|          | 95/46000 [01:58<71:21:33,  5.60s/it][A
  0%|          | 96/46000 [01:58<51:45:49,  4.06s/it][A
  0%|          | 97/46000 [01:59<38:04:44,  2.99s/it][A

KeyboardInterrupt: 

### Other

In [2]:
from pathlib import Path
import time
import tqdm
import json
import typing as tp
import pandas as pd
import glob2
import math
import omegaconf
import torch
from torch.nn import functional as F
import torch.nn as nn
import sys
import torchaudio
from torch.utils.data import Dataset, DataLoader

class Config:
    """Flat bag of settings for the fine-tuning run.

    All values are plain instance attributes; ``update`` overrides or adds
    attributes after construction.
    """

    def __init__(self):
        """Populate the default settings, then the ``None`` placeholders."""
        defaults = {
            "sample_rate": 16000,
            "is_training": True,
            "duration": 1,
            "total_updates": 10000,
            "eval_steps": 4,
            "device": 'cuda',  # 'cuda' or 'cpu'
            "batch_size": 24,
            "eval_batch_size": 4,
            "train_data_path": "/workspace/train_dataset.csv",
            "eval_data_path": "/workspace/eval_dataset.csv",
            "output_dir": "./output_dir",
            "checkpointing_steps": "best",
            "save_every": 10,
            "with_tracking": False,
            "text_encoder_name": None,  # set later
            "snr_gamma": 5.0,
            "freeze_text_encoder": True,
            "uncondition": False,
            "learning_rate": 3e-5,
            "adam_beta1": 0.9,
            "adam_beta2": 0.999,
            "adam_weight_decay": 1e-2,
            "adam_epsilon": 1e-08,
            "gradient_accumulation_steps": 1,
            "num_train_epochs": 1000,
            "num_warmup_steps": 0,
            "max_train_steps": None,
            "lr_scheduler_type": "linear",
            "resume_from_checkpoint": None,  # e.g. "/workspace/output_dir_batch48/last/"
            "wandb_project_name": "audiogen-finetune-init-test1",
            "wandb_id": None,  # e.g. "earnest-pond-52"
            "resume_epoch": 0,  # e.g. 127
            "dtype": "float32",
        }
        for name, value in defaults.items():
            setattr(self, name, value)

        self.update_audiocraft_config()

    def update(self, **kwargs):
        """Set every keyword as an attribute.

        The name of any attribute that did not exist before the call is
        printed, then the value is assigned regardless.
        """
        for key, value in kwargs.items():
            is_new_attribute = not hasattr(self, key)
            if is_new_attribute:
                print(key)
            setattr(self, key, value)

    def update_audiocraft_config(self):
        """Reset the audiocraft-related placeholder attributes to ``None``."""
        placeholder_names = (
            "solver", "fsdp", "profiler", "deadlock", "dataset", "checkpoint",
            "generate", "evaluate", "optim", "schedule", "default", "defaults",
            "autocast", "autocast_dtype",
            "compression_model_checkpoint", "channels", "logging", "lm_model",
            "codebooks_pattern", "transformer_lm", "classifier_free_guidance",
            "attribute_dropout", "fuser", "conditioners", "datasource",
        )
        for name in placeholder_names:
            setattr(self, name, None)


In [3]:
class AudioDataset(Dataset):
    """Audio clips and their text descriptions, listed in a CSV file.

    Each item is ``(audio, description, length)`` where ``audio`` is the mono
    clip resampled to ``target_sample_rate`` and padded or truncated to
    ``duration`` seconds.
    """

    def __init__(self, audio_paths, device, target_sample_rate=44100, duration=3):
        """
        Args:
            audio_paths (str): Path to a CSV file with ``sliced_audio_path``
                and ``description`` columns.
            device: Stored on the instance; not used by this class.
            target_sample_rate (int): Sample rate the audio is resampled to.
            duration: Clip length in seconds after padding / truncation.
        """
        self.audio_paths = audio_paths
        self.target_sample_rate = target_sample_rate
        self.duration = duration
        self.device = device

        self.df = pd.read_csv(self.audio_paths)

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the clip at row ``idx`` and normalise its length."""
        # Imported at call time, as in the original cell.
        from audiotools import AudioSignal

        row = self.df.iloc[idx]
        audio_path = row['sliced_audio_path']
        description = row['description']

        # Load the audio file.
        wav = AudioSignal(audio_path)
        # NOTE(review): this length is taken before resampling and before
        # padding / truncation, so it is in source-rate samples and may exceed
        # the returned clip - confirm this is what the consumer expects.
        length = wav.signal_length

        wav.to_mono()
        wav.resample(self.target_sample_rate)

        # A single integer sample count is used for both padding and
        # truncation (truncation used to receive the uncast product).
        target_samples = int(self.duration * self.target_sample_rate)
        if wav.signal_length < target_samples:
            wav.zero_pad(0, target_samples - wav.signal_length)
        elif wav.signal_length > target_samples:
            wav.truncate_samples(target_samples)

        return wav.audio_data.squeeze(1), description, length

class TestDataset(Dataset):
    """Dataset view over a sequence of prompts.

    NOTE(review): a class with this name but a different constructor
    signature is defined earlier in the notebook; whichever cell ran last wins.
    """

    def __init__(self, prompts):
        """Keep a reference to ``prompts`` (no copy is made)."""
        self.prompts = prompts

    def __len__(self):
        """Number of prompts."""
        count = len(self.prompts)
        return count

    def __getitem__(self, idx):
        """Prompt stored at position ``idx``."""
        prompt = self.prompts[idx]
        return prompt

In [None]:
from audiocraft.solvers import base, builders
from audiocraft.solvers.compression import CompressionSolver
from audiocraft import metrics as eval_metrics
from audiocraft import models
from audiocraft.data.audio_utils import normalize_audio
from audiocraft.modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition, ConditioningAttributes
from audiocraft.utils.utils import get_dataset_from_loader, is_jsonable, warn_once
from audiocraft.models.loaders import load_compression_model, load_lm_model

In [7]:
class AudioProcessing(nn.Module):
    """Pretrained compression model + language model wrapped for fine-tuning.

    The compression model turns waveforms into discrete codes; the language
    model is trained to predict those codes conditioned on text descriptions.
    On construction the language-model weights are cast to float32 and all but
    the last transformer layers (plus the ``out_norm`` / ``linears``
    parameters) are frozen.
    """

    def __init__(self, cfg):
        """Build both models from ``cfg`` (must expose ``device``)."""
        super().__init__()  # initialise the parent nn.Module
        self.cfg = cfg
        self.compression_model, self.lm = self.build_model(self.cfg)
        self.to_float32()
        self.freeze_layers()

    def forward(self, wav, descriptions, lengths):
        """Compute the training loss for one batch.

        Args:
            wav (torch.Tensor): Batch of waveforms, moved to ``cfg.device`` here.
            descriptions: Iterable with one text description per batch item.
            lengths: Per-item number of valid audio samples (before padding).

        Returns:
            torch.Tensor: Cross entropy averaged over the codebooks.
        """
        from audiocraft.modules.conditioners import ConditioningAttributes

        audio_tokens = self.process_audio_tokenizer(wav.to(self.cfg.device))
        audio_tokens, padding_mask = self.post_process_audio_tokenizer(audio_tokens, audio_lengths=lengths)

        attributes = [
            ConditioningAttributes(text={'description': description})
            for description in descriptions
        ]

        model_output = self.lm.compute_predictions(audio_tokens, conditions=attributes, condition_tensors=None)  # type: ignore

        # A position counts only if it is real audio AND the model marks it valid.
        mask = padding_mask & model_output.mask
        ce, _ = self.compute_cross_entropy(model_output.logits, audio_tokens, mask)
        return ce

    def build_model(self, cfg, pretrained_name='facebook/audiogen-medium'):
        """Load the pretrained compression model and language model.

        Args:
            cfg: Configuration object; only ``cfg.device`` is read.
            pretrained_name (str): Checkpoint name handed to both loaders.

        Returns:
            tuple: ``(compression_model, lm)``.
        """
        from audiocraft.models.loaders import load_compression_model, load_lm_model

        compression_model = load_compression_model(pretrained_name, device=cfg.device)
        lm = load_lm_model(pretrained_name, device=cfg.device)
        return compression_model, lm

    def process_audio_tokenizer(self, wav):
        """Encode ``wav`` into discrete codes without tracking gradients."""
        with torch.no_grad():
            audio_tokens, _ = self.compression_model.encode(wav)
        return audio_tokens

    def post_process_audio_tokenizer(self, audio_tokens, audio_lengths=None):
        """Replace codes that come from padded audio and build the padding mask.

        Args:
            audio_tokens (torch.Tensor): Codes of shape ``[B, K, T]``.
            audio_lengths: Optional per-item valid length in samples at
                ``cfg.sample_rate``. When ``None`` every position is valid.

        Returns:
            tuple: ``(audio_tokens, padding_mask)``; both are new tensors, the
            input is left untouched. ``padding_mask`` is ``True`` on valid
            positions.
        """
        audio_tokens = audio_tokens.clone()
        padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool)
        if audio_lengths is None:
            # No length information: nothing to mask out.
            return audio_tokens, padding_mask

        token_sample_rate = self.compression_model.frame_rate
        B, _, T_s = audio_tokens.shape
        for i in range(B):
            # float() makes this work for ints and single-element tensors alike
            # and guarantees a plain Python int for the slice below.
            n_valid = math.floor(float(audio_lengths[i]) / self.cfg.sample_rate * token_sample_rate)
            n_valid = max(0, min(n_valid, T_s))
            audio_tokens[i, :, n_valid:] = self.lm.special_token_id
            padding_mask[i, :, n_valid:] = False

        return audio_tokens, padding_mask

    def compute_cross_entropy(self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Cross entropy between multi-codebook targets and the model's logits.

        Args:
            logits (torch.Tensor): Shape ``[B, K, T, card]``.
            targets (torch.Tensor): Target codes, shape ``[B, K, T]``.
            mask (torch.Tensor): Boolean mask of valid targets, shape ``[B, K, T]``.

        Returns:
            tuple: ``(ce, ce_per_codebook)`` -- the cross entropy averaged over
            the ``K`` codebooks, and the detached per-codebook values.
        """
        _, K, _ = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: list[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
            # NOTE(review): a codebook with no valid position feeds empty
            # tensors to F.cross_entropy -- confirm that cannot happen upstream.
            q_ce = F.cross_entropy(logits_k[mask_k], targets_k[mask_k])
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())
        # average cross entropy across codebooks
        ce = ce / K
        return ce, ce_per_codebook

    def audio_generate(self, condition_tensors, gen_duration=5):
        """Sample codes from the language model and decode them to audio.

        Args:
            condition_tensors: Conditions forwarded to ``lm.generate``; one
                sample is generated per condition.
            gen_duration: Length to generate, in seconds.

        Returns:
            tuple: ``(gen_tokens, gen_audio)``.
        """
        with torch.no_grad():
            total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
            # Keep the batch size in step with the number of conditions
            # instead of pinning it to a single sample.
            num_samples = len(condition_tensors) if condition_tensors else 1
            gen_tokens = self.lm.generate(
                None, condition_tensors, max_gen_len=total_gen_len,
                num_samples=num_samples)
            gen_audio = self.compression_model.decode(gen_tokens, None)

        return gen_tokens, gen_audio

    def inference(self, descriptions):
        """Generate ``cfg.duration`` seconds of audio for each description."""
        from audiocraft.modules.conditioners import ConditioningAttributes

        attributes = [
            ConditioningAttributes(text={'description': description})
            for description in descriptions
        ]
        _, gen_audio = self.audio_generate(attributes, gen_duration=self.cfg.duration)
        return gen_audio

    def to_float32(self):
        """Convert every language-model weight to float32."""
        for param in self.lm.parameters():
            param.data = param.data.to(dtype=torch.float32)

    def freeze_layers(self, train_layers=12):
        """Freeze the language model except for its last ``train_layers`` layers.

        When ``train_layers`` is positive, parameters whose name contains
        ``out_norm`` or ``linears`` are made trainable as well. With
        ``train_layers <= 0`` the whole language model stays frozen.
        """
        for param in self.lm.parameters():
            param.requires_grad = False

        if train_layers > 0:
            num_layers = len(self.lm.transformer.layers)
            # Clamp so that asking for more layers than exist trains all of them.
            first_trainable = max(0, num_layers - train_layers)

            for i in range(first_trainable, num_layers):
                for param in self.lm.transformer.layers[i].parameters():
                    param.requires_grad = True

            for name, param in self.lm.named_parameters():
                if 'out_norm' in name or 'linears' in name:
                    param.requires_grad = True
                
            
        
        
            

In [4]:
# Extract discrete codes from EnCodec
def process_audio_tokenizer(wav):
    """Encode ``wav`` into discrete codes with the notebook-level ``compression_model``.

    The encoder runs with gradient tracking disabled; the second value returned
    by ``encode`` is discarded.
    """
    with torch.no_grad():
        codes, _scale = compression_model.encode(wav)
    return codes

def post_process_audio_tokenizer(audio_tokens, audio_lengths=None, cfg=None):
    """Replace codes that come from padded audio and build the padding mask.

    Reads the notebook-level ``compression_model`` (for ``frame_rate``) and
    ``lm`` (for ``special_token_id``).

    Args:
        audio_tokens (torch.Tensor): Codes of shape ``[B, K, T]``.
        audio_lengths: Optional per-item valid length in samples. When ``None``
            every position is treated as valid.
        cfg: Object exposing ``sample_rate``. When ``None`` the notebook-level
            ``cfg`` is used instead.

    Returns:
        tuple: ``(audio_tokens, padding_mask)``; both are new tensors and
        ``padding_mask`` is ``True`` on valid positions.

    Raises:
        ValueError: If lengths are given but no configuration can be found.
    """
    audio_tokens = audio_tokens.clone()
    padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool)
    if audio_lengths is None:
        # No length information: nothing to mask out.
        return audio_tokens, padding_mask

    if cfg is None:
        # The parameter shadows the notebook global, so look it up explicitly.
        cfg = globals().get("cfg")
        if cfg is None:
            raise ValueError("post_process_audio_tokenizer needs a cfg with a sample_rate attribute")

    token_sample_rate = compression_model.frame_rate
    B, _, T_s = audio_tokens.shape
    for i in range(B):
        # Index of the first code that no longer comes from real audio frames.
        # float() accepts ints and single-element tensors and guarantees that
        # the slice bound below is a plain Python int.
        n_valid = math.floor(float(audio_lengths[i]) / cfg.sample_rate * token_sample_rate)
        n_valid = max(0, min(n_valid, T_s))
        audio_tokens[i, :, n_valid:] = lm.special_token_id
        padding_mask[i, :, n_valid:] = False

    return audio_tokens, padding_mask

def _compute_cross_entropy(
    logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
    """Cross entropy between multi-codebook targets and the model's logits.

    Each codebook is scored separately, using only the timesteps its mask
    marks as valid, and the per-codebook values are then averaged.

    Args:
        logits (torch.Tensor): Model's logits of shape [B, K, T, card].
        targets (torch.Tensor): Target codes, of shape [B, K, T].
        mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
    Returns:
        ce (torch.Tensor): Cross entropy averaged over the codebooks.
        ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
    """
    _, num_codebooks, _ = targets.shape
    assert logits.shape[:-1] == targets.shape
    assert mask.shape == targets.shape

    card = logits.size(-1)
    ce = torch.zeros([], device=targets.device)
    ce_per_codebook: tp.List[torch.Tensor] = []
    for codebook in range(num_codebooks):
        # Flatten batch and time so that one boolean mask selects the rows.
        flat_mask = mask[:, codebook, ...].contiguous().view(-1)  # [B x T]
        flat_logits = logits[:, codebook, ...].contiguous().view(-1, card)  # [B x T, card]
        flat_targets = targets[:, codebook, ...].contiguous().view(-1)  # [B x T]
        codebook_ce = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
        ce += codebook_ce
        ce_per_codebook.append(codebook_ce.detach())
    # Mean over codebooks.
    return ce / num_codebooks, ce_per_codebook

def audio_generate(condition_tensors, gen_duration=5):
    """Sample codes from the notebook-level ``lm`` and decode them to audio.

    Args:
        condition_tensors: Conditions forwarded to ``lm.generate``; one sample
            is generated per condition.
        gen_duration: Length to generate, in seconds.

    Returns:
        tuple: ``(gen_tokens, gen_audio)`` -- the sampled codes and the output
        of ``compression_model.decode`` for them.
    """
    with torch.no_grad():
        total_gen_len = math.ceil(gen_duration * compression_model.frame_rate)
        # Keep the batch size in step with the number of conditions instead of
        # pinning it to a single sample.
        num_samples = len(condition_tensors) if condition_tensors else 1
        gen_tokens = lm.generate(
            None, condition_tensors, max_gen_len=total_gen_len,
            num_samples=num_samples)
        gen_audio = compression_model.decode(gen_tokens, None)

    return gen_tokens, gen_audio

In [None]:
# Smoke test: pull one batch and run it through the tokenizer helpers.
wav, text, length = next(iter(audio_dataloader))
audio_tokens = process_audio_tokenizer(wav.to(cfg.device))
print("Wav shape: ", wav.shape)
print("Token shape: ", audio_tokens.shape)
# Pass cfg explicitly so the length-to-token conversion uses the notebook
# config's sample rate.
post_process_audio_tokenizer(audio_tokens, audio_lengths=length, cfg=cfg)

import torch

# Print the dtype of every parameter of the language model.
# Assumes `lm` has already been created in an earlier cell.
for name, param in lm.named_parameters():
    print(f"Layer {name} has data type {param.dtype}")
    #break  # stop after the first layer instead of printing every layer

def check_requires_grad(model: torch.nn.Module) -> None:
    """Print the ``requires_grad`` flag of every parameter of ``model``.

    Parameters registered directly on ``model`` are listed first under their
    own name; parameters of child modules follow as ``<child>.<param>``.

    Args:
        model: Module whose parameters are reported.
    """
    # Parameters owned by `model` itself belong to no child module, so the
    # loop over children below would never reach them.
    for param_name, param in model.named_parameters(recurse=False):
        print(f"{param_name}: requires_grad = {param.requires_grad}")

    for name, module in model.named_children():
        for param_name, param in module.named_parameters():
            print(f"{name}.{param_name}: requires_grad = {param.requires_grad}")

# Once a model instance exists it can be inspected as shown below
# (the original note used a DAC model as its example):
# dac_instance = DAC(...)
check_requires_grad(lm)