In [1]:
import os
import json
import math
import sys
import copy
import argparse
import pandas as pd
import numpy as np
from tqdm import tqdm

import librosa
import soundfile as sf
from audiotools import AudioSignal

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import wandb

from accelerate import Accelerator
from transformers import get_scheduler

from beats.BEATs import BEATsConfig, BEATs

from config import Config
from audiomodel_inpainting import AudioProcessing
from audiocraft.modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition, ConditioningAttributes

def make_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)

def wandb_init(cfg):
    wandb.init(
            # set the wandb project where this run will be logged
            project=cfg.wandb_project_name,
            
            # track hyperparameters and run metadata
            config={
            "learning_rate": cfg.learning_rate,
            "epochs": cfg.num_train_epochs,
            "batch_size": cfg.batch_size,
            }
    )
    
def save_checkpoint(cfg, model, result, best_loss, epoch=0):
    save_checkpoint = False
    with open("{}/summary.jsonl".format(cfg.output_dir), "a") as f:
        f.write(json.dumps(result) + "\n\n")
        
    if result["train_loss"] < best_loss:
      best_loss = result["train_loss"]
      save_checkpoint = True
      
    # 모델 상태 저장
    if save_checkpoint and cfg.checkpointing_steps == "best":
        torch.save(model.state_dict(), os.path.join(cfg.output_dir, "best.pth"))

    torch.save(model.state_dict(), os.path.join(cfg.output_dir, "last.pth"))
    torch.save(model.state_dict(), os.path.join(cfg.output_dir, f"epoch_{epoch}.pth"))

    return best_loss



    PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.1.0+cu118)
    Python  3.10.13 (you have 3.10.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


In [2]:
def build_model(cfg):
        from audiocraft.models.loaders import load_compression_model, load_lm_model
        """Instantiate models and optimizer."""     
        compression_model = load_compression_model('facebook/audiogen-medium', device=cfg.device)
        lm = load_lm_model('facebook/audiogen-medium', custom_cfg=cfg)
        return compression_model, lm

In [3]:
def process_audio_tokenizer(wav, compression_model):
        with torch.no_grad():
            audio_tokens, scale = compression_model.encode(wav)
        return audio_tokens

def post_process_audio_tokenizer(audio_tokens, audio_lengths=None, compression_model=None, lm=None, cfg=None):
    padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
    audio_tokens = audio_tokens.clone()
    padding_mask = padding_mask.clone()
    token_sample_rate = compression_model.frame_rate
    B, K, T_s = audio_tokens.shape
    
    for i in range(B):
        valid_tokens = math.floor(audio_lengths[i] / cfg.sample_rate * token_sample_rate)
        audio_tokens[i, :, valid_tokens:] = lm.special_token_id
        padding_mask[i, :, valid_tokens:] = 0

    return audio_tokens, padding_mask

In [4]:
class TestDataset(Dataset):
    def __init__(self, cfg):
        
        self.target_sample_rate = cfg.sample_rate
        self.duration = cfg.duration
        self.device = cfg.device
        self.audio_paths = cfg.eval_data_path

        self.df = pd.read_csv(self.audio_paths)[:20]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):

        data = self.df.iloc[idx] #self.audio_files_list[idx]
        
        import random
        audio_path = data['audio_path']
        total_duration = data['duration']
        description = "<RANDOM>"
        
        # Set duration
        duration = self.duration if total_duration >= 3 else total_duration  # Duration is 3 seconds or total_duration if less than 3
        
        offset = 0.0   
        # Load audio signal file
        wav = AudioSignal(audio_path, offset=offset, duration=duration)
        length = wav.signal_length

        # Encode audio signal as one long file
        wav.to_mono()
        wav.resample(self.target_sample_rate)

        if wav.duration < self.duration:
          pad_len = int(self.duration * self.target_sample_rate) - wav.signal_length
          wav.zero_pad(0, pad_len)
        elif wav.duration > self.duration:
          wav.truncate_samples(self.duration * self.target_sample_rate)


        return wav.audio_data.squeeze(1), description

In [5]:
class AudioDataset(Dataset):
    def __init__(self, cfg, train=True):
        self.train = train
        
        self.target_sample_rate = cfg.sample_rate
        self.duration = cfg.duration
        self.device = cfg.device

        if self.train:
            self.audio_paths = cfg.train_data_path
        else:
            self.audio_paths = cfg.eval_data_path

        self.df = pd.read_csv(self.audio_paths)[:100]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        
        data = self.df.iloc[idx] #self.audio_files_list[idx]
        
        import random
        audio_path = data['audio_path']
        total_duration = data['duration']
        description = "<RANDOM>"
        
        # Set duration
        duration = self.duration if total_duration >= 3 else total_duration  # Duration is 3 seconds or total_duration if less than 3
        
        # Set offset based on conditions
        if total_duration < self.duration or self.train == False:
            offset = 0.0 
        else:
            max_offset = total_duration - duration  # Calculate the maximum possible offset
            offset = random.uniform(0, max_offset)  # Choose a random offset within the possible range
        
        # Load audio signal file
        wav = AudioSignal(audio_path, offset=offset, duration=duration)
        length = wav.signal_length

        # Encode audio signal as one long file
        wav.to_mono()
        wav.resample(self.target_sample_rate)

        if wav.duration < self.duration:
          pad_len = int(self.duration * self.target_sample_rate) - wav.signal_length
          wav.zero_pad(0, pad_len)
        elif wav.duration > self.duration:
          wav.truncate_samples(self.duration * self.target_sample_rate)

        return wav.audio_data.squeeze(1), description, length

In [61]:
# SoundConditioner 클래스 정의
class SoundConditionerEncodec(nn.Module):
    def __init__(self, compression_model):
        super(SoundConditionerEncodec, self).__init__()

        # 비트 모델 로드
       
        self.device = cfg.device
        self.compression_model = compression_model

    def forward(self, wav):
        # 오디오 토큰 길이 설정
        emb = self.compression_model.encoder(wav)

        return emb.permute(0,2,1), None


In [None]:
# SoundConditioner 클래스 정의
class SoundConditioner(nn.Module):
    def __init__(self, cfg):
        super(SoundConditioner, self).__init__()

         beats_ckpt = "beats/weights.pt"
        self.device = cfg.device
        self.beats_model = self.load_beats(beats_ckpt)
        self.beats_model.eval()

    def forward(self, wav):
        # 오디오 토큰 길이 설정
        audio_token_length = 180
        # 오디오 임베딩 처리
        audio_embeds = self.process_audio_embedding(wav.squeeze(1).to(self.device), self.beats_model, audio_token_length, self.device)
        return audio_embeds, None

    def load_beats(self, beats_ckpt):
        beats_checkpoint = torch.load(beats_ckpt, map_location='cpu')
        beats_cfg = BEATsConfig(beats_checkpoint['cfg'])
        beats = BEATs(beats_cfg)
        beats.load_state_dict(beats_checkpoint['model'])
        for name, param in beats.named_parameters():
            param.requires_grad = False
        return beats

    def process_audio_embedding(self, wav, beats, audio_token_length, device):
        # 오디오 패딩 마스크 생성
        audio_padding_mask = torch.zeros(wav.shape, device=wav.device).bool()
    
        # 오디오 특징 추출
        audio_embeds, _ = beats.extract_features(wav, padding_mask=audio_padding_mask, feature_only=True)
    
        # 현재 길이 확인
        current_length = audio_embeds.size(1)
    
        if current_length > audio_token_length:
            # 오디오 임베딩 자르기
            audio_embeds = audio_embeds.narrow(1, 0, audio_token_length)
        elif current_length < audio_token_length:
            # 필요한 패딩 길이 계산 및 적용
            padding_length = audio_token_length - current_length
            audio_embeds = F.pad(audio_embeds, (0, 0, 0, padding_length))
    
        return audio_embeds

In [11]:
def main():
    cfg = Config()
    accelerator = Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
    device = accelerator.device
    cfg.update(device=accelerator.device)
    make_dir(cfg.output_dir)
    make_dir(cfg.generated_dir)
    
    base_path = "./csv_files/"
    train_data_path = f"{base_path}/eval_epidemic_dataset.csv"
    eval_data_path = f"{base_path}/eval_epidemic_dataset.csv"
    cfg.update(train_data_path=train_data_path, eval_data_path=eval_data_path)
    
    # 'sound'를 'cross' 키의 리스트에 추가
    cfg.fuser['cross'].append('sound')
    if accelerator.is_main_process: 
        wandb_init(cfg)
    
    with accelerator.main_process_first():  
        compression_model, lm = build_model(cfg)
        model = AudioProcessing(cfg, lm)  
        t5conditioner = copy.deepcopy(lm.condition_provider.conditioners.description)
        soundconditioner = SoundConditioner(cfg)
        audio_dataset = AudioDataset(cfg, train=True) 
        eval_dataset = AudioDataset(cfg, train=False)
    test_dataset = TestDataset(cfg)
    
    audio_dataloader = DataLoader(audio_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=8)
    eval_dataloader = DataLoader(eval_dataset, batch_size=cfg.eval_batch_size, shuffle=False, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=1)
        
    optimizer_parameters = [param for param in model.parameters() if param.requires_grad]
        
    optimizer = torch.optim.AdamW(
        optimizer_parameters, lr=cfg.learning_rate,
        betas=(cfg.adam_beta1, cfg.adam_beta2),
        weight_decay=cfg.adam_weight_decay,
        eps=cfg.adam_epsilon,
    )
    
    
    num_update_steps_per_epoch = math.ceil(len(audio_dataloader) / cfg.gradient_accumulation_steps)
    if cfg.max_train_steps is None:
      cfg.max_train_steps = cfg.num_train_epochs * num_update_steps_per_epoch
    
    lr_scheduler = get_scheduler(
          name=cfg.lr_scheduler_type,
          optimizer=optimizer,
          num_warmup_steps=cfg.num_warmup_steps * cfg.gradient_accumulation_steps,
          num_training_steps=cfg.max_train_steps * cfg.gradient_accumulation_steps,
      )

    with accelerator.main_process_first():
        if cfg.resume_from_checkpoint is not None:
            accelerator.print(f"Resumed from local checkpoint: {cfg.resume_from_checkpoint}")
            model.load_state_dict(torch.load(cfg.resume_from_checkpoint, map_location=accelerator.device))
            #accelerator.load_state(cfg.resume_from_checkpoint)

    audio_dataloader, eval_dataloader, model, compression_model, t5conditioner, soundconditioner, optimizer, lr_scheduler = accelerator.prepare(
audio_dataloader, eval_dataloader, model, compression_model, t5conditioner, soundconditioner, optimizer, lr_scheduler
)

    starting_epoch, completed_steps, best_loss, save_epoch = 0, 0, np.inf, 0
    progress_bar = tqdm(range(cfg.max_train_steps), disable=not accelerator.is_local_main_process)

    
    for epoch in range(starting_epoch, cfg.num_train_epochs):
        accelerator.print(f"-------------------EPOCH{epoch}-------------------------" )
        total_loss, total_val_loss = 0, 0
        model.eval()
        for batch_idx, (wav, descriptions, lengths) in enumerate(audio_dataloader):
            with accelerator.accumulate(model):
                with torch.no_grad():
                    unwrapped_textconditioner = accelerator.unwrap_model(t5conditioner)
                    unwrapped_soundconditioner = accelerator.unwrap_model(soundconditioner)
                    
                    tokenized = {}
                    tokenized["description"] =  unwrapped_textconditioner.tokenize(descriptions)
                    tokenized["sound"] = wav
                    
                    # conditioning
                    output = {}
                    for attribute, inputs in tokenized.items():
                        if attribute == "description":   
                            condition, mask = unwrapped_textconditioner(inputs)
                        elif attribute == "sound":
                            condition, mask = unwrapped_soundconditioner(inputs)
                        output[attribute] = (condition, mask)

                    unwrapped_vae = accelerator.unwrap_model(compression_model)
                    audio_tokens = process_audio_tokenizer(wav, unwrapped_vae)
                    audio_tokens, padding_mask = post_process_audio_tokenizer(audio_tokens, lengths, unwrapped_vae, lm, cfg) 

                loss = model(audio_tokens, padding_mask, attributes=None, condition_tensors=output)
                ppl =  torch.exp(loss)
                total_loss += loss.detach().float()
                accelerator.backward(loss)     
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    completed_steps += 1
        
        model.eval()
        for batch_idx, (wav, descriptions, lengths) in enumerate(eval_dataloader):
            with accelerator.accumulate(model):
                with torch.no_grad():
                    unwrapped_textconditioner = accelerator.unwrap_model(t5conditioner)
                    unwrapped_soundconditioner = accelerator.unwrap_model(soundconditioner)
                    
                    tokenized = {}
                    tokenized["description"] =  unwrapped_textconditioner.tokenize(descriptions)
                    tokenized["sound"] = wav
                    
                    # conditioning
                    output = {}
                    for attribute, inputs in tokenized.items():
                        if attribute == "description":   
                            condition, mask = unwrapped_textconditioner(inputs)
                        elif attribute == "sound":
                            condition, mask = unwrapped_soundconditioner(inputs)
                        output[attribute] = (condition, mask)

                    unwrapped_vae = accelerator.unwrap_model(compression_model)
                    audio_tokens = process_audio_tokenizer(wav, unwrapped_vae)
                    audio_tokens, padding_mask = post_process_audio_tokenizer(audio_tokens, lengths, unwrapped_vae, lm, cfg) 

                    loss = model(audio_tokens, padding_mask, attributes=None, condition_tensors=output)
                    total_val_loss += loss  
    
        if accelerator.is_main_process:         
            result = {}
            result["epoch"] = save_epoch + 1,
            result["step"] = completed_steps
            result["train_loss"] = round(total_loss.item()/cfg.save_steps, 4)
            result["valid_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
            
            wandb.log(result)
            result_string = "Epoch: {}, Loss Train: {}, Valid: {}\n".format(save_epoch + 1, result["train_loss"], result["valid_loss"])    
            accelerator.print(result_string) 
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_vae = accelerator.unwrap_model(compression_model)
            unwrapped_soundconditioner = accelerator.unwrap_model(soundconditioner)
            best_loss = save_checkpoint(cfg, unwrapped_model, result, best_loss, save_epoch)
            for test_step, (wav, descriptions) in enumerate(test_dataloader):
                audio_conditions = unwrapped_soundconditioner(wav)
                gen_token, gen_audio = unwrapped_model.inference(descriptions, audio_conditions, unwrapped_vae)
                audio_filename = f"epoch_{save_epoch}_{test_step}.wav"
                unwrapped_model.save_audio(gen_audio, audio_filename, cfg)
            save_epoch += 1 

In [12]:
from accelerate import notebook_launcher
args = ()
notebook_launcher(main, args, num_processes=1)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Launching training on one GPU.


[34m[1mwandb[0m: Currently logged in as: [33moptimizerai[0m ([33moptimizer_ai[0m). Use [1m`wandb login --relogin`[0m to force relogin




FUSER:  {'cross_attention_pos_emb': False, 'cross_attention_pos_emb_scale': 1, 'sum': [], 'prepend': [], 'cross': ['description', 'sound'], 'input_interpolate': []}


  0%|          | 0/50000 [00:00<?, ?it/s]

-------------------EPOCH0-------------------------


  0%|          | 50/50000 [00:31<8:30:43,  1.63it/s]

Epoch: 1, Loss Train: 6.198, Valid: 2.3788





-------------------EPOCH1-------------------------


  0%|          | 100/50000 [03:14<8:26:20,  1.64it/s] 

Epoch: 2, Loss Train: 5.5799, Valid: 2.3259

-------------------EPOCH2-------------------------


  0%|          | 150/50000 [05:48<8:23:57,  1.65it/s]  

Epoch: 3, Loss Train: 5.5875, Valid: 2.2814

-------------------EPOCH3-------------------------


  0%|          | 200/50000 [08:18<8:20:27,  1.66it/s]  

Epoch: 4, Loss Train: 5.3535, Valid: 2.2375

-------------------EPOCH4-------------------------


  0%|          | 250/50000 [11:00<8:34:01,  1.61it/s]  

Epoch: 5, Loss Train: 5.2808, Valid: 2.1868

-------------------EPOCH5-------------------------


  1%|          | 300/50000 [13:33<8:24:00,  1.64it/s]  

Epoch: 6, Loss Train: 5.288, Valid: 2.1317

-------------------EPOCH6-------------------------


  1%|          | 350/50000 [16:00<8:23:46,  1.64it/s]  

Epoch: 7, Loss Train: 5.1227, Valid: 2.0706

-------------------EPOCH7-------------------------


  1%|          | 400/50000 [18:38<8:19:38,  1.65it/s]  

Epoch: 8, Loss Train: 4.9106, Valid: 2.0166

-------------------EPOCH8-------------------------


  1%|          | 450/50000 [21:16<8:23:38,  1.64it/s]  

Epoch: 9, Loss Train: 4.7773, Valid: 1.9531

-------------------EPOCH9-------------------------


  1%|          | 500/50000 [23:50<8:28:02,  1.62it/s]  

Epoch: 10, Loss Train: 4.7438, Valid: 1.8975

-------------------EPOCH10-------------------------


  1%|          | 550/50000 [26:27<8:45:40,  1.57it/s]  

Epoch: 11, Loss Train: 4.4945, Valid: 1.8513

-------------------EPOCH11-------------------------


  1%|          | 600/50000 [29:06<8:35:41,  1.60it/s]  

Epoch: 12, Loss Train: 4.3462, Valid: 1.8046

-------------------EPOCH12-------------------------


  1%|▏         | 650/50000 [31:46<8:31:28,  1.61it/s]  

Epoch: 13, Loss Train: 4.1591, Valid: 1.7687

-------------------EPOCH13-------------------------


  1%|▏         | 700/50000 [34:26<8:36:02,  1.59it/s]  

Epoch: 14, Loss Train: 4.1328, Valid: 1.7437

-------------------EPOCH14-------------------------


  2%|▏         | 750/50000 [37:06<8:40:56,  1.58it/s]  

Epoch: 15, Loss Train: 4.2902, Valid: 1.7124

-------------------EPOCH15-------------------------


  2%|▏         | 800/50000 [39:43<8:35:37,  1.59it/s]  

Epoch: 16, Loss Train: 3.9247, Valid: 1.6881

-------------------EPOCH16-------------------------


  2%|▏         | 850/50000 [42:17<8:40:49,  1.57it/s]  

Epoch: 17, Loss Train: 4.0537, Valid: 1.6716

-------------------EPOCH17-------------------------


  2%|▏         | 900/50000 [44:48<8:17:38,  1.64it/s]  

Epoch: 18, Loss Train: 3.8975, Valid: 1.6594

-------------------EPOCH18-------------------------


  2%|▏         | 950/50000 [47:20<8:12:24,  1.66it/s]  

Epoch: 19, Loss Train: 3.9747, Valid: 1.6542

-------------------EPOCH19-------------------------


  2%|▏         | 1000/50000 [49:43<8:15:19,  1.65it/s] 

Epoch: 20, Loss Train: 3.9652, Valid: 1.6298

-------------------EPOCH20-------------------------


  2%|▏         | 1050/50000 [52:09<8:28:57,  1.60it/s]  

Epoch: 21, Loss Train: 3.8466, Valid: 1.6174

-------------------EPOCH21-------------------------


  2%|▏         | 1100/50000 [54:41<8:12:59,  1.65it/s]  

Epoch: 22, Loss Train: 3.9264, Valid: 1.6105

-------------------EPOCH22-------------------------


  2%|▏         | 1150/50000 [57:11<8:22:15,  1.62it/s]  

Epoch: 23, Loss Train: 4.0448, Valid: 1.6017

-------------------EPOCH23-------------------------


  2%|▏         | 1200/50000 [59:38<8:16:36,  1.64it/s]  

Epoch: 24, Loss Train: 3.8256, Valid: 1.5957

-------------------EPOCH24-------------------------


  2%|▎         | 1250/50000 [1:02:19<8:34:38,  1.58it/s]  

Epoch: 25, Loss Train: 3.7045, Valid: 1.5889

-------------------EPOCH25-------------------------


  3%|▎         | 1300/50000 [1:04:58<8:29:31,  1.59it/s]  

Epoch: 26, Loss Train: 3.5998, Valid: 1.5842

-------------------EPOCH26-------------------------


  3%|▎         | 1350/50000 [1:07:35<8:31:20,  1.59it/s]  

Epoch: 27, Loss Train: 3.7249, Valid: 1.583

-------------------EPOCH27-------------------------


  3%|▎         | 1400/50000 [1:10:14<8:28:07,  1.59it/s]  

Epoch: 28, Loss Train: 4.104, Valid: 1.5795

-------------------EPOCH28-------------------------


  3%|▎         | 1450/50000 [1:12:40<8:28:01,  1.59it/s]  

Epoch: 29, Loss Train: 3.7691, Valid: 1.572

-------------------EPOCH29-------------------------


  3%|▎         | 1500/50000 [1:15:10<8:20:39,  1.61it/s]  

Epoch: 30, Loss Train: 3.3003, Valid: 1.5677

-------------------EPOCH30-------------------------


  3%|▎         | 1550/50000 [1:17:49<8:23:21,  1.60it/s]  

Epoch: 31, Loss Train: 3.8439, Valid: 1.5669

-------------------EPOCH31-------------------------


  3%|▎         | 1600/50000 [1:20:23<8:28:44,  1.59it/s]  

Epoch: 32, Loss Train: 3.6988, Valid: 1.5645

-------------------EPOCH32-------------------------


  3%|▎         | 1650/50000 [1:22:52<8:23:17,  1.60it/s]  

Epoch: 33, Loss Train: 3.6993, Valid: 1.5609

-------------------EPOCH33-------------------------


  3%|▎         | 1700/50000 [1:25:20<8:22:56,  1.60it/s]  

Epoch: 34, Loss Train: 3.8863, Valid: 1.5582

-------------------EPOCH34-------------------------


  4%|▎         | 1750/50000 [1:27:48<8:09:10,  1.64it/s]  

Epoch: 35, Loss Train: 3.8211, Valid: 1.5567

-------------------EPOCH35-------------------------


  4%|▎         | 1800/50000 [1:30:15<8:13:01,  1.63it/s]  

Epoch: 36, Loss Train: 3.8264, Valid: 1.5563

-------------------EPOCH36-------------------------


  4%|▎         | 1850/50000 [1:32:42<8:08:20,  1.64it/s]  

Epoch: 37, Loss Train: 3.6404, Valid: 1.5512

-------------------EPOCH37-------------------------


  4%|▍         | 1900/50000 [1:35:08<8:08:17,  1.64it/s]  

Epoch: 38, Loss Train: 3.7977, Valid: 1.5466

-------------------EPOCH38-------------------------


  4%|▍         | 1950/50000 [1:37:32<8:06:31,  1.65it/s]  

Epoch: 39, Loss Train: 3.5955, Valid: 1.5459

-------------------EPOCH39-------------------------


  4%|▍         | 2000/50000 [1:39:58<8:04:05,  1.65it/s]  

Epoch: 40, Loss Train: 3.6488, Valid: 1.543

-------------------EPOCH40-------------------------


  4%|▍         | 2050/50000 [1:42:23<8:06:30,  1.64it/s]  

Epoch: 41, Loss Train: 3.7306, Valid: 1.5413

-------------------EPOCH41-------------------------


  4%|▍         | 2100/50000 [1:44:50<8:20:58,  1.59it/s]  

Epoch: 42, Loss Train: 3.7485, Valid: 1.5383

-------------------EPOCH42-------------------------


  4%|▍         | 2150/50000 [1:47:18<8:21:31,  1.59it/s]  

Epoch: 43, Loss Train: 3.4957, Valid: 1.5366

-------------------EPOCH43-------------------------


  4%|▍         | 2200/50000 [1:49:44<8:05:42,  1.64it/s]  

Epoch: 44, Loss Train: 3.7325, Valid: 1.5355

-------------------EPOCH44-------------------------


  4%|▍         | 2250/50000 [1:52:11<8:07:40,  1.63it/s]  

Epoch: 45, Loss Train: 3.7423, Valid: 1.5325

-------------------EPOCH45-------------------------


  5%|▍         | 2300/50000 [1:54:38<7:59:47,  1.66it/s]  

Epoch: 46, Loss Train: 3.5762, Valid: 1.5342

-------------------EPOCH46-------------------------


  5%|▍         | 2350/50000 [1:57:05<7:59:31,  1.66it/s]  

Epoch: 47, Loss Train: 3.6612, Valid: 1.5291

-------------------EPOCH47-------------------------


  5%|▍         | 2400/50000 [1:59:30<8:01:36,  1.65it/s]  

Epoch: 48, Loss Train: 3.5878, Valid: 1.5289

-------------------EPOCH48-------------------------


  5%|▍         | 2450/50000 [2:01:56<8:00:30,  1.65it/s]  

Epoch: 49, Loss Train: 3.8204, Valid: 1.5253

-------------------EPOCH49-------------------------


  5%|▌         | 2500/50000 [2:04:25<7:58:30,  1.65it/s]  

Epoch: 50, Loss Train: 3.544, Valid: 1.5238

-------------------EPOCH50-------------------------


  5%|▌         | 2550/50000 [2:06:54<8:06:08,  1.63it/s]  

Epoch: 51, Loss Train: 3.6089, Valid: 1.5207

-------------------EPOCH51-------------------------


  5%|▌         | 2600/50000 [2:09:23<8:17:09,  1.59it/s]  

Epoch: 52, Loss Train: 3.7189, Valid: 1.5174

-------------------EPOCH52-------------------------


  5%|▌         | 2650/50000 [2:11:52<8:01:56,  1.64it/s]  

Epoch: 53, Loss Train: 3.5306, Valid: 1.5155

-------------------EPOCH53-------------------------


  5%|▌         | 2700/50000 [2:14:22<8:14:21,  1.59it/s]  

Epoch: 54, Loss Train: 3.5496, Valid: 1.5149

-------------------EPOCH54-------------------------


  6%|▌         | 2750/50000 [2:16:47<8:09:29,  1.61it/s]  

Epoch: 55, Loss Train: 3.5042, Valid: 1.5148

-------------------EPOCH55-------------------------


  6%|▌         | 2800/50000 [2:19:12<8:00:28,  1.64it/s]  

Epoch: 56, Loss Train: 3.5516, Valid: 1.5118

-------------------EPOCH56-------------------------


  6%|▌         | 2850/50000 [2:21:38<8:12:53,  1.59it/s]  

Epoch: 57, Loss Train: 3.6427, Valid: 1.5113

-------------------EPOCH57-------------------------


  6%|▌         | 2900/50000 [2:24:07<8:11:47,  1.60it/s]  

Epoch: 58, Loss Train: 3.4072, Valid: 1.5098

-------------------EPOCH58-------------------------


  6%|▌         | 2950/50000 [2:26:36<7:58:16,  1.64it/s]  

Epoch: 59, Loss Train: 3.4861, Valid: 1.508

-------------------EPOCH59-------------------------


  6%|▌         | 3000/50000 [2:29:05<7:57:19,  1.64it/s]  

Epoch: 60, Loss Train: 3.671, Valid: 1.5063

-------------------EPOCH60-------------------------


  6%|▌         | 3050/50000 [2:31:31<7:51:55,  1.66it/s]  

Epoch: 61, Loss Train: 3.4194, Valid: 1.5016

-------------------EPOCH61-------------------------


  6%|▌         | 3100/50000 [2:34:00<7:51:29,  1.66it/s]  

Epoch: 62, Loss Train: 3.6481, Valid: 1.5017

-------------------EPOCH62-------------------------


  6%|▋         | 3150/50000 [2:36:28<7:52:26,  1.65it/s]  

Epoch: 63, Loss Train: 3.6663, Valid: 1.4978

-------------------EPOCH63-------------------------


  6%|▋         | 3200/50000 [2:38:58<8:04:12,  1.61it/s]  

Epoch: 64, Loss Train: 3.695, Valid: 1.4976

-------------------EPOCH64-------------------------


  6%|▋         | 3250/50000 [2:41:25<7:54:59,  1.64it/s]  

Epoch: 65, Loss Train: 3.2115, Valid: 1.4968

-------------------EPOCH65-------------------------


  7%|▋         | 3300/50000 [2:43:59<7:52:45,  1.65it/s]  

Epoch: 66, Loss Train: 3.5234, Valid: 1.4943

-------------------EPOCH66-------------------------


  7%|▋         | 3350/50000 [2:46:29<7:54:09,  1.64it/s]  

Epoch: 67, Loss Train: 3.3293, Valid: 1.4965

-------------------EPOCH67-------------------------


  7%|▋         | 3400/50000 [2:48:56<8:14:10,  1.57it/s]  

Epoch: 68, Loss Train: 3.5285, Valid: 1.4938

-------------------EPOCH68-------------------------


  7%|▋         | 3450/50000 [2:51:31<7:47:26,  1.66it/s]  

Epoch: 69, Loss Train: 3.6151, Valid: 1.4903

-------------------EPOCH69-------------------------


  7%|▋         | 3500/50000 [2:53:59<7:51:19,  1.64it/s]  

Epoch: 70, Loss Train: 3.5177, Valid: 1.4886

-------------------EPOCH70-------------------------


  7%|▋         | 3550/50000 [2:56:34<8:05:58,  1.59it/s]  

Epoch: 71, Loss Train: 3.5864, Valid: 1.4937

-------------------EPOCH71-------------------------


  7%|▋         | 3600/50000 [2:59:09<8:02:14,  1.60it/s]  

Epoch: 72, Loss Train: 3.4531, Valid: 1.4869

-------------------EPOCH72-------------------------


  7%|▋         | 3650/50000 [3:01:43<8:04:51,  1.59it/s]  

Epoch: 73, Loss Train: 3.2977, Valid: 1.488

-------------------EPOCH73-------------------------


  7%|▋         | 3700/50000 [3:04:18<8:00:15,  1.61it/s]  

Epoch: 74, Loss Train: 3.5963, Valid: 1.4828

-------------------EPOCH74-------------------------


  8%|▊         | 3750/50000 [3:06:49<8:02:45,  1.60it/s]  

Epoch: 75, Loss Train: 3.5288, Valid: 1.4816

-------------------EPOCH75-------------------------


  8%|▊         | 3800/50000 [3:09:22<8:05:06,  1.59it/s]  

Epoch: 76, Loss Train: 3.4222, Valid: 1.4816

-------------------EPOCH76-------------------------


  8%|▊         | 3850/50000 [3:12:01<8:11:00,  1.57it/s]  

Epoch: 77, Loss Train: 3.4431, Valid: 1.4817

-------------------EPOCH77-------------------------


  8%|▊         | 3900/50000 [3:14:40<8:02:53,  1.59it/s]  

Epoch: 78, Loss Train: 3.5238, Valid: 1.4806

-------------------EPOCH78-------------------------


  8%|▊         | 3950/50000 [3:17:21<8:02:13,  1.59it/s]  

Epoch: 79, Loss Train: 3.5032, Valid: 1.477

-------------------EPOCH79-------------------------


  8%|▊         | 4000/50000 [3:19:59<8:03:40,  1.59it/s]  

Epoch: 80, Loss Train: 3.421, Valid: 1.4766

-------------------EPOCH80-------------------------


  8%|▊         | 4050/50000 [3:22:35<7:46:29,  1.64it/s]  

Epoch: 81, Loss Train: 3.4176, Valid: 1.4789

-------------------EPOCH81-------------------------


  8%|▊         | 4100/50000 [3:25:10<7:58:13,  1.60it/s]  

Epoch: 82, Loss Train: 3.2793, Valid: 1.4763

-------------------EPOCH82-------------------------


  8%|▊         | 4150/50000 [3:27:45<7:55:03,  1.61it/s]  

Epoch: 83, Loss Train: 3.4833, Valid: 1.4791

-------------------EPOCH83-------------------------


  8%|▊         | 4200/50000 [3:30:14<8:02:15,  1.58it/s]  

Epoch: 84, Loss Train: 3.5871, Valid: 1.4773

-------------------EPOCH84-------------------------


  8%|▊         | 4250/50000 [3:32:41<7:49:51,  1.62it/s]  

Epoch: 85, Loss Train: 3.3707, Valid: 1.4736

-------------------EPOCH85-------------------------


  9%|▊         | 4300/50000 [3:35:10<7:54:33,  1.60it/s]  

Epoch: 86, Loss Train: 3.4753, Valid: 1.4726

-------------------EPOCH86-------------------------


  9%|▊         | 4350/50000 [3:37:41<7:57:27,  1.59it/s]  

Epoch: 87, Loss Train: 3.1493, Valid: 1.4723

-------------------EPOCH87-------------------------


  9%|▉         | 4400/50000 [3:40:17<7:41:54,  1.65it/s]  

Epoch: 88, Loss Train: 3.5135, Valid: 1.4702

-------------------EPOCH88-------------------------


  9%|▉         | 4450/50000 [3:42:51<7:59:38,  1.58it/s]  

Epoch: 89, Loss Train: 3.3592, Valid: 1.4693

-------------------EPOCH89-------------------------


  9%|▉         | 4500/50000 [3:45:23<8:00:01,  1.58it/s]  

Epoch: 90, Loss Train: 3.5216, Valid: 1.4697

-------------------EPOCH90-------------------------


  9%|▉         | 4550/50000 [3:47:57<8:02:18,  1.57it/s]  

Epoch: 91, Loss Train: 3.4419, Valid: 1.4658

-------------------EPOCH91-------------------------


  9%|▉         | 4600/50000 [3:50:33<7:51:01,  1.61it/s]  

Epoch: 92, Loss Train: 3.3656, Valid: 1.464

-------------------EPOCH92-------------------------


  9%|▉         | 4650/50000 [3:53:06<7:51:57,  1.60it/s]  

Epoch: 93, Loss Train: 3.4589, Valid: 1.4604

-------------------EPOCH93-------------------------


  9%|▉         | 4700/50000 [3:55:36<7:52:42,  1.60it/s]  

Epoch: 94, Loss Train: 3.357, Valid: 1.4575

-------------------EPOCH94-------------------------


 10%|▉         | 4750/50000 [3:58:09<7:57:55,  1.58it/s]  

Epoch: 95, Loss Train: 3.4654, Valid: 1.4595

-------------------EPOCH95-------------------------


 10%|▉         | 4800/50000 [4:00:36<7:53:12,  1.59it/s]  

Epoch: 96, Loss Train: 3.3507, Valid: 1.4569

-------------------EPOCH96-------------------------


 10%|▉         | 4850/50000 [4:03:07<7:52:03,  1.59it/s]  

Epoch: 97, Loss Train: 3.4681, Valid: 1.4583

-------------------EPOCH97-------------------------


 10%|▉         | 4900/50000 [4:05:39<7:59:26,  1.57it/s]  

Epoch: 98, Loss Train: 3.1649, Valid: 1.4603

-------------------EPOCH98-------------------------


 10%|▉         | 4950/50000 [4:08:05<7:41:52,  1.63it/s]  

Epoch: 99, Loss Train: 3.3946, Valid: 1.4606

-------------------EPOCH99-------------------------


 10%|█         | 5000/50000 [4:10:39<7:56:33,  1.57it/s]  

Epoch: 100, Loss Train: 3.4465, Valid: 1.4566

-------------------EPOCH100-------------------------


 10%|█         | 5050/50000 [4:13:13<7:52:27,  1.59it/s]  

Epoch: 101, Loss Train: 3.3593, Valid: 1.4553

-------------------EPOCH101-------------------------


 10%|█         | 5100/50000 [4:15:46<7:48:19,  1.60it/s]  

Epoch: 102, Loss Train: 3.5119, Valid: 1.4519

-------------------EPOCH102-------------------------


 10%|█         | 5150/50000 [4:18:17<7:41:28,  1.62it/s]  

Epoch: 103, Loss Train: 3.3127, Valid: 1.4526

-------------------EPOCH103-------------------------


 10%|█         | 5200/50000 [4:20:48<7:50:04,  1.59it/s]  

Epoch: 104, Loss Train: 3.3957, Valid: 1.45

-------------------EPOCH104-------------------------


 10%|█         | 5250/50000 [4:23:21<7:49:22,  1.59it/s]  

Epoch: 105, Loss Train: 3.5011, Valid: 1.4485

-------------------EPOCH105-------------------------


 11%|█         | 5300/50000 [4:25:52<7:50:50,  1.58it/s]  

Epoch: 106, Loss Train: 3.4975, Valid: 1.4483

-------------------EPOCH106-------------------------


 11%|█         | 5350/50000 [4:28:25<7:43:58,  1.60it/s]  

Epoch: 107, Loss Train: 3.3989, Valid: 1.4498

-------------------EPOCH107-------------------------


 11%|█         | 5400/50000 [4:30:57<7:44:27,  1.60it/s]  

Epoch: 108, Loss Train: 3.5222, Valid: 1.4461

-------------------EPOCH108-------------------------


 11%|█         | 5450/50000 [4:33:33<7:47:40,  1.59it/s]  

Epoch: 109, Loss Train: 3.281, Valid: 1.4455

-------------------EPOCH109-------------------------


 11%|█         | 5500/50000 [4:36:06<7:43:26,  1.60it/s]  

Epoch: 110, Loss Train: 3.4003, Valid: 1.4465

-------------------EPOCH110-------------------------


 11%|█         | 5550/50000 [4:38:39<7:41:01,  1.61it/s]  

Epoch: 111, Loss Train: 3.3039, Valid: 1.4464

-------------------EPOCH111-------------------------


 11%|█         | 5600/50000 [4:41:13<7:39:32,  1.61it/s]  

Epoch: 112, Loss Train: 3.2873, Valid: 1.4435

-------------------EPOCH112-------------------------


 11%|█▏        | 5650/50000 [4:43:47<7:40:14,  1.61it/s]  

Epoch: 113, Loss Train: 3.1575, Valid: 1.4428

-------------------EPOCH113-------------------------


 11%|█▏        | 5700/50000 [4:46:21<7:40:33,  1.60it/s]  

Epoch: 114, Loss Train: 3.405, Valid: 1.44

-------------------EPOCH114-------------------------


 12%|█▏        | 5750/50000 [4:48:53<7:41:29,  1.60it/s]  

Epoch: 115, Loss Train: 3.2915, Valid: 1.4395

-------------------EPOCH115-------------------------


 12%|█▏        | 5800/50000 [4:51:28<7:41:46,  1.60it/s]  

Epoch: 116, Loss Train: 3.3832, Valid: 1.4419

-------------------EPOCH116-------------------------


 12%|█▏        | 5850/50000 [4:54:01<7:37:14,  1.61it/s]  

Epoch: 117, Loss Train: 3.3942, Valid: 1.4398

-------------------EPOCH117-------------------------


 12%|█▏        | 5900/50000 [4:56:35<7:49:06,  1.57it/s]  

Epoch: 118, Loss Train: 3.1458, Valid: 1.4446

-------------------EPOCH118-------------------------


 12%|█▏        | 5950/50000 [4:59:18<7:34:28,  1.62it/s]  

Epoch: 119, Loss Train: 3.4171, Valid: 1.4497

-------------------EPOCH119-------------------------


 12%|█▏        | 6000/50000 [5:01:49<7:38:19,  1.60it/s]  

Epoch: 120, Loss Train: 3.3776, Valid: 1.4496

-------------------EPOCH120-------------------------


 12%|█▏        | 6050/50000 [5:04:22<7:38:40,  1.60it/s]  

Epoch: 121, Loss Train: 3.3732, Valid: 1.4395

-------------------EPOCH121-------------------------


 12%|█▏        | 6100/50000 [5:06:55<7:35:01,  1.61it/s]  

Epoch: 122, Loss Train: 3.1889, Valid: 1.4374

-------------------EPOCH122-------------------------


 12%|█▏        | 6150/50000 [5:09:28<7:38:13,  1.59it/s]  

Epoch: 123, Loss Train: 3.313, Valid: 1.4373

-------------------EPOCH123-------------------------


 12%|█▏        | 6200/50000 [5:12:01<7:34:19,  1.61it/s]  

Epoch: 124, Loss Train: 3.4514, Valid: 1.4421

-------------------EPOCH124-------------------------


 12%|█▎        | 6250/50000 [5:14:33<7:37:07,  1.60it/s]  

Epoch: 125, Loss Train: 3.4727, Valid: 1.441

-------------------EPOCH125-------------------------


 13%|█▎        | 6300/50000 [5:17:06<7:35:04,  1.60it/s]  

Epoch: 126, Loss Train: 3.3577, Valid: 1.4415



RuntimeError: [enforce fail at inline_container.cc:424] . unexpected pos 6692818368 vs 6692818264

In [7]:
cfg = Config()
accelerator = Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
device = accelerator.device
cfg.update(device=accelerator.device)

base_path = "./csv_files/"
train_data_path = f"{base_path}/eval_epidemic_dataset.csv"
eval_data_path = f"{base_path}/eval_epidemic_dataset.csv"
cfg.update(train_data_path=train_data_path, eval_data_path=eval_data_path)

# 'sound'를 'cross' 키의 리스트에 추가
cfg.fuser['cross'].append('sound')

compression_model, lm = build_model(cfg)
model = AudioProcessing(cfg, lm)

audio_dataset = AudioDataset(cfg, train=True) 
eval_dataset = AudioDataset(cfg, train=False)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


compression_state_dict.bin:   0%|          | 0.00/236M [00:00<?, ?B/s]



state_dict.bin:   0%|          | 0.00/3.68G [00:00<?, ?B/s]

FUSER:  {'cross_attention_pos_emb': False, 'cross_attention_pos_emb_scale': 1, 'sum': [], 'prepend': [], 'cross': ['description', 'sound'], 'input_interpolate': []}


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.95G [00:00<?, ?B/s]

In [62]:
t5conditioner = copy.deepcopy(lm.condition_provider.conditioners.description)
soundconditioner = SoundConditionerEncodec(compression_model)

In [63]:
audio_dataloader = DataLoader(audio_dataset, batch_size=1, shuffle=False, num_workers=1)
eval_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=False, num_workers=1)

optimizer_parameters = [param for param in model.lm.parameters() if param.requires_grad]

optimizer = torch.optim.AdamW(
    optimizer_parameters, lr=cfg.learning_rate,
    betas=(cfg.adam_beta1, cfg.adam_beta2),
    weight_decay=cfg.adam_weight_decay,
    eps=cfg.adam_epsilon,
)

num_update_steps_per_epoch = math.ceil(len(audio_dataloader) / cfg.gradient_accumulation_steps)
if cfg.max_train_steps is None:
  cfg.max_train_steps = cfg.num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
      name=cfg.lr_scheduler_type,
      optimizer=optimizer,
      num_warmup_steps=cfg.num_warmup_steps * cfg.gradient_accumulation_steps,
      num_training_steps=cfg.max_train_steps * cfg.gradient_accumulation_steps,
  )


audio_dataloader, eval_dataloader, model, compression_model, t5conditioner, soundconditioner, optimizer, lr_scheduler = accelerator.prepare(
audio_dataloader, eval_dataloader, model, compression_model, t5conditioner, soundconditioner, optimizer, lr_scheduler
)

In [64]:
wav, description, lengths = next(iter(audio_dataloader))

with torch.no_grad():
    unwrapped_textconditioner = accelerator.unwrap_model(t5conditioner)
    unwrapped_soundconditioner = accelerator.unwrap_model(soundconditioner)
    
    tokenized = {}
    tokenized["description"] =  unwrapped_textconditioner.tokenize(description)
    tokenized["sound"] = wav
    
    # conditioning
    output = {}
    for attribute, inputs in tokenized.items():
        if attribute == "description":   
            condition, mask = unwrapped_textconditioner(inputs)
        elif attribute == "sound":
            condition, mask = unwrapped_soundconditioner(inputs)
        output[attribute] = (condition, mask)

In [69]:
condition.permute(0,2,1).shape

torch.Size([1, 150, 128])

In [12]:
with torch.no_grad():
    unwrapped_vae = accelerator.unwrap_model(compression_model)
    audio_tokens = process_audio_tokenizer(wav, unwrapped_vae)
    audio_tokens, padding_mask = post_process_audio_tokenizer(audio_tokens, lengths, unwrapped_vae, lm, cfg) 

In [13]:
loss = model(audio_tokens, padding_mask, attributes=None, condition_tensors=output)

description torch.Size([1, 7, 1536])
sound torch.Size([1, 180, 1536])
torch.Size([1, 7, 1536])
torch.Size([1, 180, 1536])


In [14]:
loss.backward()

In [15]:
model.eval()

AudioProcessing(
  (lm): LMModel(
    (cfg_dropout): ClassifierFreeGuidanceDropout(p=0.1)
    (att_dropout): AttributeDropout({})
    (condition_provider): ConditioningProvider(
      (conditioners): ModuleDict(
        (description): T5Conditioner(
          (output_proj): Linear(in_features=1024, out_features=1536, bias=True)
        )
      )
    )
    (fuser): ConditionFuser()
    (emb): ModuleList(
      (0-3): 4 x ScaledEmbedding(2049, 1536)
    )
    (transformer): StreamingTransformer(
      (layers): ModuleList(
        (0-47): 48 x StreamingTransformerLayer(
          (self_attn): StreamingMultiheadAttention(
            (out_proj): Linear(in_features=1536, out_features=1536, bias=False)
          )
          (linear1): Linear(in_features=1536, out_features=6144, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
          (linear2): Linear(in_features=6144, out_features=1536, bias=False)
          (norm1): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
  

In [18]:
test_dataset = TestDataset(cfg)
test_dataloader = DataLoader(test_dataset, batch_size=2)

In [20]:
descriptions

('<Random>', '<Random>')

In [19]:
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_vae = accelerator.unwrap_model(compression_model)
unwrapped_soundconditioner = accelerator.unwrap_model(soundconditioner)
#best_loss = save_checkpoint(cfg, unwrapped_model, result, best_loss, save_epoch)
for test_step, (wav, descriptions) in enumerate(test_dataloader):
    audio_conditions = unwrapped_soundconditioner(wav)
    gen_token, gen_audio = unwrapped_model.inference(descriptions, audio_conditions, unwrapped_vae)

[ConditioningAttributes(text={'description': '<Random>'}, wav={}, joint_embed={}), ConditioningAttributes(text={'description': '<Random>'}, wav={}, joint_embed={}), ConditioningAttributes(text={'description': None}, wav={}, joint_embed={}), ConditioningAttributes(text={'description': None}, wav={}, joint_embed={})]
cfg_conditions {'description': (tensor([[[-0.1169, -0.0594, -0.0393,  ...,  0.0848, -0.1661,  0.0396],
         [-0.0717,  0.0090, -0.0641,  ..., -0.1875,  0.1627, -0.0029],
         [-0.0853, -0.0236, -0.1074,  ..., -0.0780,  0.4068,  0.0394],
         ...,
         [-0.1232,  0.0309,  0.0335,  ...,  0.0234, -0.0042,  0.1150],
         [ 0.3359, -0.1430, -0.1474,  ..., -0.2348, -0.2121,  0.1714],
         [ 0.0261, -0.0256, -0.0012,  ...,  0.0121, -0.0851,  0.0096]],

        [[-0.1169, -0.0594, -0.0393,  ...,  0.0848, -0.1661,  0.0396],
         [-0.0717,  0.0090, -0.0641,  ..., -0.1875,  0.1627, -0.0029],
         [-0.0853, -0.0236, -0.1074,  ..., -0.0780,  0.4068,  0.039

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 8 for tensor number 1 in the list.

### run train

Launching training on one GPU.



  0%|          | 0/14893000 [00:00<?, ?it/s][A

-------------------EPOCH0-------------------------



  0%|          | 1/14893000 [00:07<30143:36:38,  7.29s/it][A
  0%|          | 2/14893000 [00:11<21975:37:12,  5.31s/it][A
  0%|          | 3/14893000 [00:15<19416:09:57,  4.69s/it][A
  0%|          | 4/14893000 [00:19<18229:25:53,  4.41s/it][A
  0%|          | 5/14893000 [00:22<17288:14:16,  4.18s/it][A
  0%|          | 6/14893000 [00:26<16950:58:43,  4.10s/it][A
  0%|          | 7/14893000 [00:30<16643:03:51,  4.02s/it][A
  0%|          | 8/14893000 [00:34<16558:05:44,  4.00s/it][A
  0%|          | 9/14893000 [00:38<16448:22:03,  3.98s/it][A
  0%|          | 10/14893000 [00:42<16383:33:38,  3.96s/it][A
  0%|          | 11/14893000 [00:46<16338:35:22,  3.95s/it][A
  0%|          | 12/14893000 [00:50<16323:30:47,  3.95s/it][A
  0%|          | 13/14893000 [00:54<16315:28:48,  3.94s/it][A
  0%|          | 14/14893000 [00:58<16269:39:58,  3.93s/it][A
  0%|          | 15/14893000 [01:02<16331:42:04,  3.95s/it][A
  0%|          | 16/14893000 [01:06<16288:19:17,  3.94s/it][A


### captioning unit test

In [None]:
def main():

    cfg = Config()
    
    accelerator = Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
    make_dir(cfg.output_dir)
    make_dir(cfg.generated_dir)
    wandb_init(cfg)
    
    #compression_model, lm = build_model(cfg)
    model = AudioProcessing(cfg)
    
    audio_dataset = AudioDataset(cfg, train=True) 
    eval_dataset = AudioDataset(cfg, train=False)
    test_dataset = TestDataset(cfg)

    audio_dataloader = DataLoader(audio_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=12)
    eval_dataloader = DataLoader(eval_dataset, batch_size=cfg.eval_batch_size, shuffle=False, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=1)

    optimizer_parameters = [param for param in model.lm.parameters() if param.requires_grad]
    
    optimizer = torch.optim.AdamW(
        optimizer_parameters, lr=cfg.learning_rate,
        betas=(cfg.adam_beta1, cfg.adam_beta2),
        weight_decay=cfg.adam_weight_decay,
        eps=cfg.adam_epsilon,
    )
    
    
    num_update_steps_per_epoch = math.ceil(len(audio_dataloader) / cfg.gradient_accumulation_steps)
    if cfg.max_train_steps is None:
      cfg.max_train_steps = cfg.num_train_epochs * num_update_steps_per_epoch
    
    lr_scheduler = get_scheduler(
          name=cfg.lr_scheduler_type,
          optimizer=optimizer,
          num_warmup_steps=cfg.num_warmup_steps * cfg.gradient_accumulation_steps,
          num_training_steps=cfg.max_train_steps * cfg.gradient_accumulation_steps,
      )


    audio_dataloader, eval_dataloader, model, optimizer, lr_scheduler = accelerator.prepare(
        audio_dataloader, eval_dataloader, model, optimizer, lr_scheduler
    )

    starting_epoch, completed_steps, best_loss = 0, 0, np.inf
    progress_bar = tqdm(range(cfg.max_train_steps), disable=not accelerator.is_local_main_process)
    
    for epoch in range(starting_epoch, cfg.num_train_epochs):
        print(f"-------------------EPOCH{epoch}-------------------------" )
        total_loss, total_val_loss = 0, 0
        model.train()
        for batch_idx, (wav, descriptions, lengths) in enumerate(audio_dataloader):
          with accelerator.accumulate(model):
              loss = model(wav, descriptions, lengths)
              ppl =  torch.exp(loss)
              total_loss += loss.detach().float()
              accelerator.backward(loss)     
              optimizer.step()
              lr_scheduler.step()
              optimizer.zero_grad()
              
          if accelerator.sync_gradients:
              progress_bar.update(1)
              completed_steps += 1
            
        model.eval()
        for batch_idx, (wav, descriptions, lengths) in enumerate(eval_dataloader):
              loss = model(wav, descriptions, lengths)
              total_val_loss += loss  
    
        if accelerator.is_main_process:         
            result = {}
            result["epoch"] = epoch + 1,
            result["step"] = completed_steps
            result["train_loss"] = round(total_loss.item()/len(audio_dataloader), 4)
            result["valid_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
            
            wandb.log(result)
            result_string = "Epoch: {}, Loss Train: {}, Valid: {}\n".format(epoch + 1, result["train_loss"], result["valid_loss"])    
            accelerator.print(result_string) 
            best_loss = save_checkpoint(cfg, model, result, best_loss, epoch)
            
            unwrapped_model = accelerator.unwrap_model(model)
            for test_step, batch in enumerate(test_dataloader):
                gen_audio = unwrapped_model.inference(batch)
                audio_filename = f"epoch_{epoch}_{test_step}.wav"
                unwrapped_model.save_audio(gen_audio, audio_filename, cfg)
             
    #wandb.finish()

In [6]:
#main()
from accelerate import notebook_launcher
notebook_launcher(main, num_processes=1)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Launching training on one GPU.


VBox(children=(Label(value='0.014 MB of 0.014 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
step,▁█
train_loss,█▁
valid_loss,▁█

0,1
step,92.0
train_loss,2.151
valid_loss,2.4137


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112272594538, max=1.0)))


  0%|          | 0/46000 [00:00<?, ?it/s][A

-------------------EPOCH0-------------------------



  0%|          | 1/46000 [00:03<40:52:21,  3.20s/it][A
  0%|          | 2/46000 [00:03<20:06:14,  1.57s/it][A
  0%|          | 3/46000 [00:04<13:44:09,  1.08s/it][A
  0%|          | 4/46000 [00:04<10:43:04,  1.19it/s][A
  0%|          | 5/46000 [00:05<9:00:29,  1.42it/s] [A
  0%|          | 6/46000 [00:05<8:01:04,  1.59it/s][A
  0%|          | 7/46000 [00:06<7:22:41,  1.73it/s][A
  0%|          | 8/46000 [00:06<6:57:42,  1.84it/s][A
  0%|          | 9/46000 [00:06<6:41:39,  1.91it/s][A
  0%|          | 10/46000 [00:07<6:31:49,  1.96it/s][A
  0%|          | 11/46000 [00:07<6:23:56,  2.00it/s][A
  0%|          | 12/46000 [00:08<6:19:03,  2.02it/s][A
  0%|          | 13/46000 [00:08<6:16:05,  2.04it/s][A
  0%|          | 14/46000 [00:09<6:13:20,  2.05it/s][A
  0%|          | 15/46000 [00:09<6:12:47,  2.06it/s][A
  0%|          | 16/46000 [00:10<6:09:37,  2.07it/s][A
  0%|          | 17/46000 [00:10<6:10:40,  2.07it/s][A
  0%|          | 18/46000 [00:11<6:09:34,  2.07it/s

Epoch: 1, Loss Train: 2.457, Valid: 2.4067

-------------------EPOCH1-------------------------



  0%|          | 47/46000 [01:00<140:45:57, 11.03s/it][A
  0%|          | 48/46000 [01:00<100:20:00,  7.86s/it][A
  0%|          | 49/46000 [01:01<72:03:39,  5.65s/it] [A
  0%|          | 50/46000 [01:01<52:17:27,  4.10s/it][A
  0%|          | 51/46000 [01:02<38:27:35,  3.01s/it][A
  0%|          | 52/46000 [01:02<28:46:16,  2.25s/it][A
  0%|          | 53/46000 [01:03<21:59:09,  1.72s/it][A
  0%|          | 54/46000 [01:03<17:13:31,  1.35s/it][A
  0%|          | 55/46000 [01:04<13:53:30,  1.09s/it][A
  0%|          | 56/46000 [01:04<11:34:12,  1.10it/s][A
  0%|          | 57/46000 [01:05<9:56:37,  1.28it/s] [A
  0%|          | 58/46000 [01:05<8:49:51,  1.45it/s][A
  0%|          | 59/46000 [01:06<8:01:54,  1.59it/s][A
  0%|          | 60/46000 [01:06<7:27:37,  1.71it/s][A
  0%|          | 61/46000 [01:07<7:03:48,  1.81it/s][A
  0%|          | 62/46000 [01:07<6:47:33,  1.88it/s][A
  0%|          | 63/46000 [01:08<6:35:16,  1.94it/s][A
  0%|          | 64/46000 [01:08<

Epoch: 2, Loss Train: 2.151, Valid: 2.4137

-------------------EPOCH2-------------------------



  0%|          | 93/46000 [01:57<139:14:47, 10.92s/it][A
  0%|          | 94/46000 [01:57<99:18:36,  7.79s/it] [A
  0%|          | 95/46000 [01:58<71:21:33,  5.60s/it][A
  0%|          | 96/46000 [01:58<51:45:49,  4.06s/it][A
  0%|          | 97/46000 [01:59<38:04:44,  2.99s/it][A

KeyboardInterrupt: 

### Other

In [2]:
from pathlib import Path
import time
import tqdm
import json
import typing as tp
import pandas as pd
import glob2
import math
import omegaconf
import torch
from torch.nn import functional as F
import torch.nn as nn
import sys
import torchaudio
from torch.utils.data import Dataset, DataLoader

class Config:
    def __init__(self):
        self.sample_rate = 16000
        self.is_training = True
        self.duration = 1
        self.total_updates = 10000
        self.eval_steps = 4
        self.device = 'cuda'  # 'cuda' 또는 'cpu'
        self.batch_size = 24
        self.eval_batch_size = 4
        self.train_data_path = "/workspace/train_dataset.csv"
        self.eval_data_path = "/workspace/eval_dataset.csv"
        self.output_dir = "./output_dir"
        self.checkpointing_steps = "best"
        self.save_every = 10
        self.with_tracking = False
        self.text_encoder_name = None  # 나중에 설정
        self.snr_gamma = 5.0
        self.freeze_text_encoder = True
        self.uncondition = False
        self.learning_rate = 3e-5
        self.adam_beta1 = 0.9
        self.adam_beta2 = 0.999
        self.adam_weight_decay = 1e-2
        self.adam_epsilon = 1e-08
        self.gradient_accumulation_steps = 1
        self.num_train_epochs = 1000
        self.num_warmup_steps = 0
        self.max_train_steps = None
        self.lr_scheduler_type = "linear"
        self.resume_from_checkpoint = None #"/workspace/output_dir_batch48/last/" #None
        self.wandb_project_name = "audiogen-finetune-init-test1"
        self.wandb_id = None #"earnest-pond-52"
        self.resume_epoch = 0 #127
        self.dtype = "float32"
        
        self.update_audiocraft_config()
        

    def update(self, **kwargs):
        for key, value in kwargs.items():
            # 기존 속성에 값 할당하거나 새 속성 생성
            if not hasattr(self, key):
                print(key)
            setattr(self, key, value)
            
    def update_audiocraft_config(self):
        self.solver = None
        self.fsdp = None
        self.profiler = None
        self.deadlock = None
        self.dataset = None
        self.checkpoint = None
        self.generate = None
        self.evaluate = None
        self.optim = None
        self.schedule = None
        self.default = None
        self.defaults = None
        self.autocast = None
        self.autocast_dtype = None

        self.compression_model_checkpoint = None
        self.channels = None
        self.logging = None
        self.lm_model = None
        self.codebooks_pattern = None
        self.transformer_lm = None
        self.classifier_free_guidance = None
        self.attribute_dropout = None
        self.fuser = None
        self.conditioners = None
        self.datasource = None


In [3]:
class AudioDataset(Dataset):
    def __init__(self, audio_paths, device, target_sample_rate=44100, duration=3):
        import pandas as pd
        """
        Args:
            audio_files_list (list): List of paths to audio files.
            target_sample_rate (int): The sample rate to which audio should be resampled.
            frame_length (int): The frame length for slicing or padding audio.
        """
        self.audio_paths = audio_paths
        self.target_sample_rate = target_sample_rate
        self.duration = duration
        self.device = device

        self.df = pd.read_csv(self.audio_paths)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        
        data = self.df.iloc[idx] #self.audio_files_list[idx]
        audio_path = data['sliced_audio_path']
        description = data['description']

        # Load audio signal file
        from audiotools import AudioSignal
        wav = AudioSignal(audio_path)
        length = wav.signal_length

        # Encode audio signal as one long file
        wav.to_mono()
        wav.resample(self.target_sample_rate)

        if wav.duration < self.duration:
          pad_len = int(self.duration * self.target_sample_rate) - wav.signal_length
          wav.zero_pad(0, pad_len)
        elif wav.duration > self.duration:
          wav.truncate_samples(self.duration * self.target_sample_rate)


        return wav.audio_data.squeeze(1), description, length

class TestDataset(Dataset):
    def __init__(self, prompts):

        self.prompts = prompts


    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):

        return self.prompts[idx]

In [None]:
from audiocraft.solvers import base, builders
from audiocraft.solvers.compression import CompressionSolver
from audiocraft import metrics as eval_metrics
from audiocraft import models
from audiocraft.data.audio_utils import normalize_audio
from audiocraft.modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition, ConditioningAttributes
from audiocraft.utils.utils import get_dataset_from_loader, is_jsonable, warn_once
from audiocraft.models.loaders import load_compression_model, load_lm_model

In [7]:
class AudioProcessing(nn.Module):
    
    def __init__(self, cfg):
        super().__init__()  # 부모 클래스 초기화 호출
        self.cfg = cfg
        self.compression_model, self.lm  = self.build_model(self.cfg)
        self.to_float32()
        self.freeze_layers()

    def forward(self, wav, descriptions, lengths):
        from audiocraft.modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition, ConditioningAttributes
        audio_tokens = self.process_audio_tokenizer(wav.to(self.cfg.device))
        audio_tokens, padding_mask = self.post_process_audio_tokenizer(audio_tokens, audio_lengths=lengths)
        
        attributes = [
            ConditioningAttributes(text={'description': description})
            for description in descriptions]
    
        model_output = self.lm.compute_predictions(audio_tokens, conditions=attributes, condition_tensors=None)  # type: ignore
        logits = model_output.logits
    
        mask = padding_mask & model_output.mask
        ce, ce_per_codebook = self.compute_cross_entropy(logits, audio_tokens, mask)
        
        return ce

    def build_model(self, cfg):
        from audiocraft.models.loaders import load_compression_model, load_lm_model
        """Instantiate models and optimizer."""
        
        compression_model = load_compression_model('facebook/audiogen-medium', device=cfg.device)
        lm = load_lm_model('facebook/audiogen-medium', device=cfg.device)
    
        return compression_model, lm

    def process_audio_tokenizer(self, wav):
        with torch.no_grad():
            audio_tokens, scale = self.compression_model.encode(wav)
        return audio_tokens

    def post_process_audio_tokenizer(self, audio_tokens, audio_lengths=None):
        padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)

        audio_tokens = audio_tokens.clone()
        padding_mask = padding_mask.clone()
        token_sample_rate = self.compression_model.frame_rate
        B, K, T_s = audio_tokens.shape
        for i in range(B):
            valid_tokens = math.floor(audio_lengths[i] / self.cfg.sample_rate * token_sample_rate)
            audio_tokens[i, :, valid_tokens:] = self.lm.special_token_id
            padding_mask[i, :, valid_tokens:] = 0

        return audio_tokens, padding_mask

    def compute_cross_entropy(self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:

        B, K, T = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
            ce_targets = targets_k[mask_k]
            ce_logits = logits_k[mask_k]
            q_ce = F.cross_entropy(ce_logits, ce_targets)
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())
        # average cross entropy across codebooks
        ce = ce / K
        return ce, ce_per_codebook

    def audio_generate(self, condition_tensors, gen_duration=5):
        with torch.no_grad():
            total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
            gen_tokens = self.lm.generate(
                None, condition_tensors, max_gen_len=total_gen_len,
                num_samples=1)
            gen_audio = self.compression_model.decode(gen_tokens, None)

        return gen_tokens, gen_audio

    def inference(self, descriptions):
        #with torch.no_grad():
        from audiocraft.modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition, ConditioningAttributes
        attributes = [
        ConditioningAttributes(text={'description': description})
        for description in descriptions]
        _, gen_audio = self.audio_generate(attributes, gen_duration=self.cfg.duration)
        
        return gen_audio

    def to_float32(self):
        # 모든 가중치를 FP32로 변환
        for param in self.lm.parameters():
            param.data = param.data.to(dtype=torch.float32)

    def freeze_layers(self, train_layers=12):
        for param in self.lm.parameters():
            param.requires_grad = False
            
        if train_layers > 0 :
            num_layers = len(self.lm.transformer.layers)
            
            for i in range(num_layers - train_layers, num_layers):
                for param in self.lm.transformer.layers[i].parameters():
                    param.requires_grad = True
                    
            for name, param in self.lm.named_parameters():
                if 'out_norm' in name or 'linears' in name:
                    param.requires_grad = True
                
            
        
        
            

In [4]:
# Extract discrete codes from EnCodec
def process_audio_tokenizer(wav):
  with torch.no_grad():
      audio_tokens, scale = compression_model.encode(wav)
  return audio_tokens

def post_process_audio_tokenizer(audio_tokens, audio_lengths=None, cfg=None):
  padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
  # replace encodec tokens from padded audio with special_token_id

  audio_tokens = audio_tokens.clone()
  padding_mask = padding_mask.clone()
  token_sample_rate = compression_model.frame_rate
  B, K, T_s = audio_tokens.shape
  for i in range(B):
      # take the last token generated from actual audio frames (non-padded audio)
      #math.floor(float(n_frames[i]) / sr[i] * token_sample_rate)
      valid_tokens = math.floor(audio_lengths[i] / cfg.sample_rate * token_sample_rate)
      audio_tokens[i, :, valid_tokens:] = lm.special_token_id
      padding_mask[i, :, valid_tokens:] = 0

  return audio_tokens, padding_mask

def _compute_cross_entropy(
      logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Compute cross entropy between multi-codebook targets and model's logits.
        The cross entropy is computed per codebook to provide codebook-level cross entropy.
        Valid timesteps for each of the codebook are pulled from the mask, where invalid
        timesteps are set to 0.

        Args:
            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
            targets (torch.Tensor): Target codes, of shape [B, K, T].
            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
        Returns:
            ce (torch.Tensor): Cross entropy averaged over the codebooks
            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
        """
        B, K, T = targets.shape
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
            ce_targets = targets_k[mask_k]
            ce_logits = logits_k[mask_k]
            q_ce = F.cross_entropy(ce_logits, ce_targets)
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())
        # average cross entropy across codebooks
        ce = ce / K
        return ce, ce_per_codebook

def audio_generate(condition_tensors, gen_duration=5):
    with torch.no_grad():
      total_gen_len = math.ceil(gen_duration * compression_model.frame_rate)
      gen_tokens = lm.generate(
          None, condition_tensors, max_gen_len=total_gen_len,
          num_samples=1)
      gen_audio = compression_model.decode(gen_tokens, None)

    return gen_tokens, gen_audio

In [None]:
wav, text, length = next(iter(audio_dataloader))
audio_tokens = process_audio_tokenizer(wav.to(cfg.device))
print("Wav shape: ", wav.shape)
print("Token shape: ", audio_tokens.shape)
post_process_audio_tokenizer(audio_tokens, audio_lengths=length)

import torch

# 가정: model이라는 이름의 PyTorch 모델이 이미 정의되어 있음
for name, param in lm.named_parameters():
    print(f"Layer {name} has data type {param.dtype}")
    #break  # 모든 레이어를 표시하지 않고 첫 레이어에서 루프 중단

def check_requires_grad(model: torch.nn.Module):
    for name, module in model.named_children():
        for param_name, param in module.named_parameters():
            print(f"{name}.{param_name}: requires_grad = {param.requires_grad}")

# DAC 모델의 인스턴스를 생성한 후에 아래와 같이 사용할 수 있습니다:
# dac_instance = DAC(...)
check_requires_grad(lm)