In [None]:
# Smoke test: confirms the kernel is alive and executing cells.
print(1)

In [None]:
import os
import math
import torch
import random
import torchaudio
from tqdm import tqdm
import torch.nn.functional as F
from dit_clap import DiffusionTransformer
from config.model.config import config
from train.dataset import SampleDataset
from condition.t5condition import T5Conditioner
from train.validation import validate_or_test
from train.utils import draw_plot, cleanup_memory, save_concatenated_mel_spectrogram, calculate_targets, prepare_batch_data, masked_loss
from vae.get_function import create_autoencoder_from_config
from torch.cuda.amp import autocast, GradScaler
from inference.inference import generation
from utils import get_span_mask, prob_mask_like
from accelerate import Accelerator, DistributedDataParallelKwargs
from safetensors.torch import load_file, save_file
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
from torch.utils.data import Dataset, DataLoader, RandomSampler
from train.dataset import Config
from audiotools import AudioSignal
from condition.clap_model import CLAPModule
import pandas as pd

In [None]:
# --- Model / checkpoint setup ------------------------------------------------
# Builds the diffusion transformer, the audio autoencoder, the T5 text
# conditioner and the CLAP module, loads the transformer and autoencoder
# weights from disk, and switches every network to eval mode for inference.
device = 'cuda'

model = DiffusionTransformer(
    d_model          = 1536,
    depth            = 24,
    num_heads        = 24,
    input_concat_dim = 0,
    global_cond_type = 'prepend',
    latent_channels  = 64,
    config           = config,
    device           = device,
    ada_cond_dim     = None,
    use_skip         = False,
    is_melody_prompt = True,
    is_audio_prompt  = True
)
model = model.to(device)

ae = create_autoencoder_from_config(config['auto_encoder_config']).to(device)
text_conditioner = T5Conditioner(output_dim=768).to(device)

ae.eval()
text_conditioner.eval()
model.eval()

num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_trainable_params)

# map_location makes the load independent of the device the checkpoint was
# saved from (e.g. a different GPU index).
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model_state_dict = torch.load('/workspace/mel_con_sample/total_model_350.pth', map_location=device)
# strict=False skips mismatched keys without raising, so report them instead of
# silently running a partially initialised model.
load_result = model.load_state_dict(model_state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"missing keys: {len(load_result.missing_keys)}, unexpected keys: {len(load_result.unexpected_keys)}")
del model_state_dict
# model.load_state_dict(torch.load('/workspace/mel_con_sample/main/weights_0205_clap_full/model_0.pth'))
print("-")

ae_state_dict = torch.load('/workspace/mel_con_sample/vae_weight.pth', map_location=device)
# The checkpoint keys carry a 'model.' component that the bare autoencoder does
# not use. NOTE(review): str.replace removes every occurrence of 'model.', not
# only a leading prefix — confirm no key contains it elsewhere.
ae_clean_state_dict = {layer_name.replace('model.', ''): weights for layer_name, weights in ae_state_dict.items()}
ae.load_state_dict(ae_clean_state_dict)

# Drop the temporary copies of the weights and release cached GPU memory.
del ae_clean_state_dict
del ae_state_dict
torch.cuda.empty_cache()

clap_model = CLAPModule().to(device)
clap_model.eval()

print("-")

In [None]:
import librosa
from condition.voice import make_voice_cond

# --- Melody-conditioned generation on a set of reference clips ---------------
# For each reference clip: extract a melody prompt from (at most) its first
# 10 seconds, zero out the CLAP audio embedding, generate audio for the text
# prompt, and show a playback widget trimmed to the requested duration.
cfg = {
    "sample_steps": 100,
    "sample_cfg": 6.5,
    "sample_rate": 44100,
    "diffusion_objective": "v"
}

# Fixed melody-prompt length; matches the (1, 16, 215) zero prompt used by the
# text-only generation cell of this notebook.
MELODY_FRAMES = 215
# Only the first N seconds of each reference clip feed the melody prompt.
MAX_MELODY_SECONDS = 10

audios = [
    {
        'audio_path': '/workspace/mel_con_sample/testsamples/bass_synth.wav',
        'prompt': 'prompts: electric guitar, loop',
        'duration': 3.7
    },
    {
        'audio_path': '/workspace/mel_con_sample/testsamples/guitar_melody.wav',
        'prompt': 'prompts: piano, loop',
        'duration': 13.0
    },
    {
        'audio_path': '/workspace/mel_con_sample/testsamples/hiphop_vocal.wav',
        'prompt': 'prompts: hip-hop, vocal, loop, wet',
        'duration': 6.9
    },
    {
        'audio_path': '/workspace/mel_con_sample/testsamples/percussion.wav',
        'prompt': 'prompts: electric drums, loop, wet',
        'duration': 9.5
    },
    {
        'audio_path': '/workspace/mel_con_sample/testsamples/acoustic_chords.wav',
        'prompt': 'prompts: electric guitar, dry',
        'duration': 4.1
    },
]

for i, d in enumerate(audios):
    ap = d['audio_path']
    prompt = d['prompt']
    duration = d['duration']

    audio, sr = librosa.load(ap, sr=cfg['sample_rate'])
    audio = torch.tensor(audio)
    audio = audio[:cfg['sample_rate'] * MAX_MELODY_SECONDS]

    # Truncate before padding so an over-long melody cannot produce a negative
    # pad width (torch.zeros would raise); shorter melodies are right-padded.
    melody = make_voice_cond(audio)[..., :MELODY_FRAMES]
    zeros = torch.zeros((16, MELODY_FRAMES - melody.shape[-1]))
    melody = torch.concat([melody, zeros], dim=-1)
    melody = melody.unsqueeze(dim=0).to(device)

    # The CLAP forward pass is only used to obtain a tensor of the right
    # shape/dtype/device; no autograd graph is needed for that.
    with torch.no_grad():
        audio_embed = clap_model(audio.unsqueeze(dim=0))
    # Zero the embedding so the audio prompt carries no information in this run.
    audio_embed = torch.zeros_like(audio_embed)

    output = generation(
        model,
        ae,
        text_conditioner,
        text=prompt,
        steps=cfg['sample_steps'],
        cfg_scale=cfg['sample_cfg'],
        duration=duration,
        sample_rate=cfg['sample_rate'],
        batch_size=1,
        device=device,
        disable=False,
        melody_prompt=melody,
        audio_embed=audio_embed
    )
    # To keep the result on disk:
    # torchaudio.save(f'/workspace/generateds/generated_{i}.wav', output.cpu()[0], 44100)
    # First batch item, trimmed to the requested duration in samples.
    AudioSignal(
        output.cpu().numpy()[0][:, :int(cfg['sample_rate'] * duration)],
        sample_rate=cfg['sample_rate'],
    ).widget()

In [None]:
# Inspection cell: shape of the `audio_embed` tensor currently held by the kernel.
audio_embed.shape

In [None]:
# Inspection cell: shape of the `melody` prompt tensor currently held by the kernel.
melody.shape

In [None]:
# Load the prompt CSV and show the generated prompt of its first row.
df = pd.read_csv("/workspace/for_test_rank.csv")

df['generated_prompt'].iloc[0]

In [None]:
import librosa
from condition.voice import make_voice_cond
import pandas as pd

# --- Text-only generation from the prompt CSV --------------------------------
# Generates SAMPLES_PER_PROMPT clips for each of the first NUM_PROMPTS rows,
# using all-zero melody and audio-embedding prompts so only the text prompt
# conditions the model.
df = pd.read_csv("/workspace/for_test_rank.csv")

NUM_PROMPTS = 10
SAMPLES_PER_PROMPT = 2

cfg = {
    "sample_steps": 100,
    "sample_cfg": 7.0,
    "sample_rate": 44100,
    "diffusion_objective": "v"
}

# The zero prompts do not depend on the row, so build them once.
melody = torch.zeros((1, 16, 215)).to(device)
audio_embed = torch.zeros((1, 512)).to(device)

# Clamp to the number of rows so a CSV shorter than NUM_PROMPTS does not raise
# IndexError from df.iloc.
for i in range(min(NUM_PROMPTS, len(df))):
    d = df.iloc[i]

    prompt = d['generated_prompt']
    # Plain Python float, like the hand-written durations in the other cells.
    duration = float(d['duration'])
    cfg["sample_duration"] = duration

    for j in range(SAMPLES_PER_PROMPT):
        output = generation(
            model,
            ae,
            text_conditioner,
            text=prompt,
            steps=cfg['sample_steps'],
            cfg_scale=cfg['sample_cfg'],
            duration=duration,
            sample_rate=cfg['sample_rate'],
            batch_size=1,
            device=device,
            disable=False,
            melody_prompt=melody,
            audio_embed=audio_embed
        )
        # To keep the result on disk:
        # torchaudio.save(f'/workspace/generateds/generated_{i}_{j}.wav', output.cpu()[0], 44100)
        # First batch item, trimmed to the requested duration in samples.
        AudioSignal(
            output.cpu().numpy()[0][:, :int(cfg['sample_rate'] * duration)],
            sample_rate=cfg['sample_rate'],
        ).widget()

In [None]:
import librosa
from condition.voice import make_voice_cond

# --- Seeded, melody-conditioned generation for a single clip -----------------
# Extracts a melody prompt from the reference clip, pairs it with an all-zero
# audio embedding, and generates one sample under a fixed torch seed.
cfg = {
    "sample_steps": 100,
    "sample_cfg": 6.0,
    "sample_rate": 44100,
    "diffusion_objective": "v"
}
seed = 12
torch.manual_seed(seed)

audios = [
    {
        'audio_path': './FMT_-_Bm_-_Analog_Phazer_Chords.wav',
        'prompt': 'prompts: guitar, blues, rock, electric, funk, live sounds, high, loop, key: b major, bpm: 105',
        # 'prompt': 'prompts: guitar, blues, loop',
        'duration': 4.5
    }
]

# Only the first entry is generated; drop the slice to run every entry.
for i, d in enumerate(audios[:1]):
    # Read from the loop variable (the original indexed audios[0], which would
    # silently repeat the first clip if more entries were iterated).
    inst_ap = d['audio_path']
    melody_ap = d['audio_path']
    prompt = d['prompt']
    duration = d['duration']

    audio, sr = librosa.load(melody_ap, sr=cfg['sample_rate'])
    audio = torch.tensor(audio)
    audio = audio[:cfg['sample_rate'] * 10]
    melody = make_voice_cond(audio)
    # Right-pad the melody prompt with zeros to 215 frames.
    zeros = torch.zeros((16, 215 - melody.shape[-1]))
    melody = torch.concat([melody, zeros], dim=-1)
    melody = melody.unsqueeze(dim=0).to(device)

    # Define the audio prompt here instead of relying on whatever an earlier
    # cell left in the kernel. Zeros disable it; to condition on the clip's
    # CLAP embedding use `clap_model(audio.unsqueeze(dim=0))` instead.
    audio_embed = torch.zeros((1, 512)).to(device)

    output = generation(
        model,
        ae,
        text_conditioner,
        text=prompt,
        steps=cfg['sample_steps'],
        cfg_scale=cfg['sample_cfg'],
        duration=duration,
        sample_rate=cfg['sample_rate'],
        batch_size=1,
        device=device,
        disable=False,
        melody_prompt=melody,
        audio_embed=audio_embed
    )
    # Select the first batch item before slicing so the trim applies to the
    # sample axis, consistent with the other generation cells.
    AudioSignal(
        output.cpu().numpy()[0][:, :int(cfg['sample_rate'] * duration)],
        sample_rate=cfg['sample_rate'],
    ).widget()
    # AudioSignal(inst_ap).widget()

In [None]:
# Inspection cell: displays the `model` attribute of the DiffusionTransformer
# instance (presumably the wrapped network — confirm in dit_clap).
model.model

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_chromagram(chromagram, sample_rate):
    """
    Plot a chromagram tensor as a heatmap.

    Args:
        chromagram: torch tensor to draw. Singleton dimensions are squeezed
            away before plotting, so it is expected to reduce to 2-D
            (rows on the y-axis, frames on the x-axis).
        sample_rate: accepted for call-site compatibility; not used here.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    # Detach from autograd and move to host memory before converting to NumPy.
    heatmap_data = chromagram.detach().cpu().squeeze().numpy()
    # Draw the matrix with the first row at the bottom.
    image = ax.imshow(
        heatmap_data,
        aspect='auto',
        origin='lower',
        cmap='coolwarm',
        interpolation='nearest',
    )
    fig.colorbar(image, ax=ax, format='%+2.0f dB')
    ax.set_xlabel('Frames')
    ax.set_ylabel('Chroma Bins')
    ax.set_title('Chromagram')
    plt.show()

In [None]:
# Compare the melody prompt fed to the model (first plot) with the melody
# extracted from the generated audio (second plot).
visualize_chromagram(melody, sample_rate=44100)

# Keep the extracted melody under its own name: overwriting `melody` would make
# a re-run of this cell plot the generated melody twice instead of
# prompt-vs-generated.
generated_melody = make_voice_cond(output[0][:, :int(44100*duration)].cpu().detach())
# Right-pad with zeros to the same 215 frames as the prompt.
zeros = torch.zeros((16, 215-generated_melody.shape[-1]))
generated_melody = torch.concat([generated_melody, zeros], dim=-1)
visualize_chromagram(generated_melody, sample_rate=44100)

In [None]:
import librosa
import torchaudio

# --- Autoencoder round trip ---------------------------------------------------
# Encodes a reference clip to latents and decodes it back, printing the tensor
# shapes at each stage so reconstruction can be inspected.
ap = './FMT_-_Bm_-_Analog_Phazer_Chords.wav'

audio, sr = librosa.load(ap, sr=44100)
audio = torch.tensor(audio)
# Duplicate the mono signal into two identical channels.
audio = torch.stack([audio, audio], dim=0).to(device)
print(audio.shape, sr)

# Inference only: skip autograd so no graph is kept on the GPU and the decoded
# tensor can be converted to NumPy directly.
with torch.no_grad():
    codes = ae.encode(audio)
    print(codes.shape)

    decodes = ae.decode(codes)
    print(decodes.shape)

In [None]:
# Playback widgets: the original clip, then its autoencoder reconstruction.
AudioSignal(ap).widget()
# detach() first: Tensor.numpy() raises on a tensor that still requires grad.
AudioSignal(decodes.detach().cpu().numpy().squeeze(), sample_rate=44100).widget()