In [None]:
import os
import pandas as pd
from tqdm import tqdm
import librosa
from collections import defaultdict

# File names of the video latents; each one is a candidate sample for the manifest.
latent_paths = os.listdir("/workspace/AVE_Datasetori/AVE_latents")
print(latent_paths[:5])
print(os.listdir('/workspace/AVE_Datasetori/AVE_audio/')[0])

with open('/workspace/AVE_Datasetori/Annotations.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()  # read every line of the file into a list

datas = []
# Maps the second "&"-separated field of each annotation line to a one-element
# list holding the first field (which is later used as the caption).
infos = defaultdict(list)

print(lines[:5])
for line in tqdm(lines):
    # Strip first so the trailing newline never ends up inside the last field.
    d = line.strip().split("&")
    if len(d) < 2:
        # Blank or malformed line: there is no second field to key on, and
        # indexing d[1] would raise IndexError.
        continue
    infos[d[1]] = [d[0]]
# The manifest rows themselves (latent paths, caption, duration) are assembled
# in the next cell, once the files on disk have been checked.

In [None]:
import re
import os

# Build one manifest row per video latent that has a caption, an audio clip
# and an audio latent. Incomplete samples are reported and skipped so the
# manifest never points at files that are not on disk.
datas = []
latent_suffix = "_126res_24fps.npy"
for latent_path in tqdm(latent_paths):
    if not latent_path.endswith(latent_suffix):
        # Stray directory entry: slicing the suffix off would give a bogus id.
        continue
    latent_id = latent_path[:-len(latent_suffix)]

    video_latent_path = "/workspace/AVE_Datasetori/AVE_latents/" + latent_path
    audio_path = f'/workspace/AVE_Datasetori/AVE_audio/{latent_id}_126res_24fps.wav'
    audio_latent_path = f'/workspace/AVE_Datasetori/AVE_audio_latent/{latent_id}_126res_24fps.npy'

    # Check every file before doing any decoding work.
    missing = [
        path
        for path in (video_latent_path, audio_path, audio_latent_path)
        if not os.path.exists(path)
    ]
    if missing:
        print("missing, skipping:", *missing)
        continue

    # .get() avoids inserting an empty list into the defaultdict for unknown ids.
    caption = infos.get(latent_id)
    if not caption:
        print("no annotation, skipping:", latent_id)
        continue

    # sr=None keeps the file's native sampling rate: no resampling cost, and
    # the duration is computed from the real sample count.
    audio, sr = librosa.load(audio_path, sr=None)
    duration = audio.shape[-1] / sr

    datas.append({
        "video_latent_path": video_latent_path,
        "caption": caption[0],
        "duration": duration,
        # NOTE: this column holds the audio *latent* path, not the .wav path.
        "audio_path": audio_latent_path
    })

In [None]:
# Display the first manifest row as a quick sanity check.
datas[0]

In [None]:
import csv

# Persist the manifest rows to CSV: a header line first, then one line per row.
csv_columns = ['video_latent_path', 'caption', 'duration', 'audio_path']
with open('output_data.csv', 'w', newline='', encoding='utf-8') as csv_file:
    manifest_writer = csv.DictWriter(csv_file, fieldnames=csv_columns)
    manifest_writer.writeheader()  # write the column names
    for record in datas:
        manifest_writer.writerow(record)

In [None]:
import pandas as pd
from data.datamodule import AudioDataModule
from lightning import Callback, LightningModule, Trainer, seed_everything

# Seed the run up front so repeated executions behave the same way.
seed_everything(42, workers=True)

# Settings for the datamodule that reads the manifest written above.
datamodule_kwargs = {
    "dataset_path": 'output_data.csv',
    "batch_size": 4,
    "num_workers": 4,
    "sampling_rate": 44100,
    "channel": 2,
    "max_audio_len": 215,
    "max_txt_len": 128,
    "max_video_len": 30,
}
datamodule = AudioDataModule(**datamodule_kwargs)

In [None]:
# Run the datamodule's setup step before any dataloader is requested.
datamodule.setup()

In [None]:
# Validation dataloader; a single batch is pulled from it below.
tdt = datamodule.val_dataloader()

In [None]:
# Pull a single batch from the loader and show its first element.
batch_iterator = iter(tdt)
data = next(batch_iterator)
print(data[0])

# NOTE(review): the lines below look like output pasted from an earlier run
# (interleaved "does not exist" messages for audio-latent files); kept verbatim.
# %s does not exist.%s does not exist. %s does not exist./workspace/AVE_Datasetori/AVE_audio_latent/92sTzXubUgQ_126res_24fps.npy
# /workspace/AVE_Datasetori/AVE_audio_latent/-j9TbKCJMwI_126res_24fps.npy
# /workspace/AVE_Datasetori/AVE_audio_latent/92sTzXubUgQ_126res_24fps.npy does not exist.
# %s does not exist.
# /workspace/AVE_Datasetori/AVE_audio_latent/-h9REoRVzYY_126res_24fps.npy
# /workspace/AVE_Datasetori/AVE_audio_latent/-j9TbKCJMwI_126res_24fps.npy does not exist.

# %s does not exist./workspace/AVE_Datasetori/AVE_audio_latent/05ExLJ7xRis_126res_24fps.npy%s does not exist.

In [None]:
# Unpack the first five batch entries by position.
audio_embed, video_latents, audio_mask, text, text_mask = (
    data[position] for position in range(5)
)

# Quick look at the shapes of the audio and video entries.
for batch_entry in (audio_embed, video_latents):
    print(batch_entry.shape)

In [None]:
from model.audiobox_video import AudioBox
from model.module_video import AudioBoxModule

# pip uninstall torchvision
# pip install torchvision --no-cache-dir

# Hyper-parameters that are passed identically to both constructors.
backbone_kwargs = dict(
    depth=24,
    heads=16,
    attn_dropout=0.0,
    ff_dropout=0.1,
    kernel_size=31,
)

# Stand-alone network, called directly in the forward-pass cell further down.
audiobox = AudioBox(audio_dim=64, text_dim=768, dim=1024, **backbone_kwargs)

# Module wrapper; later cells read its t5, sigma, drop_prob, device and
# get_span_mask attributes.
model = AudioBoxModule(
    dim=64,
    **backbone_kwargs,
    voco_type="oobleck",
    optimizer="AdamW",
    lr=0.0001,
    scheduler="linear_warmup_decay",
    max_audio_len=215,
    max_steps=1000,
    text_repo_id="google/flan-t5-base",
)

# Move both modules to the GPU.
device = 'cuda'
for module in (audiobox, model):
    module.to(device)
print("-")

import torch

def count_parameters(mode):
    """Return the total element count of ``mode``'s trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in mode.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Report the trainable-parameter count of the stand-alone network, with thousands separators.
print(f"Trainable parameters: {count_parameters(audiobox):,}")
# 433M  (presumably the figure printed by an earlier run)

In [None]:
from utils.mask import min_span_mask, prob_mask_like
import torch
from einops import rearrange

# Prepare the inputs for one forward pass: move the batch to the device, encode
# the text, build a noisy audio sample and zero out parts of the conditioning.

# Move every batch tensor onto the target device.
audio_embed = audio_embed.to(device)
text = text.to(device)
text_mask = text_mask.to(device)
video_latents = video_latents.to(device)
audio_mask = audio_mask.to(device)

bs = audio_embed.shape[0]  # batch size
# The span mask and the text encoding are computed without gradient tracking.
with torch.no_grad():
    span_mask = model.get_span_mask(audio_mask)
    # Autocast is explicitly disabled around the text-encoder call.
    # text and text_mask were already moved above, so the extra .to(device)
    # calls here should be no-ops.
    with torch.autocast(device_type=model.device.type, enabled=False):
        text_emb = model.t5(
            input_ids=text.to(device), attention_mask=text_mask.to(device)
        ).last_hidden_state

# Gaussian noise with the same shape as the audio embedding.
audio_x0 = torch.randn_like(audio_embed)
# One time value per batch element.
times = torch.rand((bs,), dtype=audio_embed.dtype, device=model.device) # torch.rand samples uniformly from [0, 1)
# Add two singleton axes so t broadcasts against the audio tensors.
t = rearrange(times, "b -> b () ()")
# Blend of noise and data controlled by t and model.sigma
# (looks like a flow-matching interpolation; confirm against the training module).
w = (1 - (1 - model.sigma) * t) * audio_x0 + t * audio_embed

# One flag per sample, broadcast across the length axis when OR-ed with the span mask.
# NOTE(review): presumably True with probability model.drop_prob; verify prob_mask_like.
cond_drop_mask = prob_mask_like((bs, 1), model.drop_prob, model.device)
audio_cond_mask = span_mask | cond_drop_mask

# Audio context: zero wherever the combined mask is True, audio_embed elsewhere.
audio_context = torch.where(
    rearrange(audio_cond_mask, "b l -> b l ()"), 0, audio_embed
)

# Per-sample flag that zeroes the whole text embedding for the flagged samples.
text_drop_mask = prob_mask_like((bs,), model.drop_prob, model.device)
text_emb = torch.where(
    rearrange(text_drop_mask, "b -> b () ()"), 0, text_emb
)

In [None]:
# Collect the forward-pass inputs by keyword, then run the network once.
flow_inputs = {
    "w": w,
    "times": times,
    "audio_mask": audio_mask,
    "context": audio_context,
    "text_emb": text_emb,
    "text_mask": text_mask,
    "video_latent": video_latents,
}
pred_audio_flow = audiobox(**flow_inputs)

In [None]:
# Inspect the shape of the network output.
pred_audio_flow.shape

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# Shape experiment: repeat every step along dim 1 eight times, then append
# five zero steps at the end of that axis.
frames = torch.randn(4, 30, 64)
stretched = torch.repeat_interleave(frames, 8, dim=1)  # [B, 240, 64]
# Pad spec runs from the last dim backwards: (C_left, C_right, T_left, T_right).
x = F.pad(stretched, (0, 0, 0, 5))
print(x.shape)

In [None]:
# F.pad expects a sequence of per-side pad sizes; a bare int is rejected.
# An all-zero spec, in the same (C_left, C_right, T_left, T_right) layout,
# is the no-op padding and leaves the shape of x unchanged.
F.pad(x, (0, 0, 0, 0))