# In [2]:
import torch
import torchaudio
from torchaudio.transforms import Resample

def noise_gen_gaussian_stereo(range_factor, frame_count, device):
    """Generate two channels of zero-mean Gaussian noise.

    The standard deviation is ``1 / range_factor``, so ``[-1, 1]`` spans
    ``range_factor`` standard deviations on either side of zero. By Chebyshev's
    inequality at least ``1 - 1 / range_factor**2`` of the samples fall inside
    that range. That is only a lower bound: for a Gaussian the true fraction is
    far higher (about 99.994% for ``range_factor=4``).

    Args:
        range_factor: Number of standard deviations that fit inside [-1, 1].
        frame_count: Number of samples to generate per channel.
        device: Device the noise tensor is allocated on.

    Returns:
        Tensor of shape ``(2, frame_count)``. Row 0 is drawn from the RNG
        before row 1.
    """
    sigma = 1.0 / range_factor
    # One independent draw per channel, first channel first, then stacked.
    channels = [
        torch.normal(mean=0.0, std=sigma, size=(frame_count,), device=device)
        for _ in range(2)
    ]
    return torch.stack(channels)


def load_audio(path, sample_rate=44100):
    """Load the audio file at *path* and return its waveform at *sample_rate* Hz.

    The file's native rate is taken from ``torchaudio.load``. When it differs
    from *sample_rate*, the waveform is resampled with
    ``torchaudio.transforms.Resample``. Only the waveform tensor is returned,
    in whatever layout ``torchaudio.load`` produced (channels-first by default).
    """
    waveform, native_rate = torchaudio.load(path)
    if native_rate == sample_rate:
        # Already at the requested rate; nothing to do.
        return waveform
    return Resample(orig_freq=native_rate, new_freq=sample_rate)(waveform)

def _to_stereo(audio, label):
    """Return *audio* (channels-first) as a 2-channel tensor for concatenation with stereo noise.

    Mono input is duplicated into both channels. Any channel count other than
    1 or 2 is rejected, because there is no unambiguous way to match it to the
    stereo noise.
    """
    channels = audio.shape[0]
    if channels == 2:
        return audio
    if channels == 1:
        return audio.repeat(2, 1)
    raise ValueError(
        f"{label} has {channels} channels; only mono or stereo prompts are supported"
    )


def prompts_concat(start_audio_path, end_audio_path, output_path, noise_duration, device, sample_rate=44100):
    """Write ``start audio + Gaussian noise + end audio`` to *output_path*.

    Both prompts are loaded (and resampled to *sample_rate*) with
    ``load_audio``. Mono prompts are upmixed to stereo so they match the
    2-channel noise. ``int(noise_duration * sample_rate)`` noise frames are
    inserted between the prompts.

    Args:
        start_audio_path: Audio placed before the noise.
        end_audio_path: Audio placed after the noise.
        output_path: Destination file passed to ``torchaudio.save``.
        noise_duration: Length of the noise gap in seconds.
        device: Device on which the prompts and noise are combined.
        sample_rate: Target sample rate of every segment and of the output.

    Returns:
        ``(start_time, end_time)`` in seconds of the noise gap within the
        output file. ``end_time`` is ``start_time + noise_duration``.

    Raises:
        ValueError: If a prompt has a channel count other than 1 or 2.
    """
    range_factor = 4  # noise std = 1/4, so samples stay almost entirely within [-1, 1]

    start_audio_data = _to_stereo(
        load_audio(start_audio_path, sample_rate).to(device), "start audio"
    )
    end_audio_data = _to_stereo(
        load_audio(end_audio_path, sample_rate).to(device), "end audio"
    )

    noise_data = noise_gen_gaussian_stereo(
        range_factor,
        int(noise_duration * sample_rate),
        device,
    )
    concat_data = torch.cat((start_audio_data, noise_data, end_audio_data), dim=1)
    # torchaudio's save backends write from CPU memory, so move the result off
    # any accelerator first. .cpu() is a no-op for tensors already on the CPU.
    torchaudio.save(output_path, concat_data.cpu(), sample_rate)

    # The start prompt's length gives the repaint window without re-reading the output file.
    start_audio_duration = start_audio_data.shape[-1] / sample_rate
    return (start_audio_duration, start_audio_duration + noise_duration)


# Example run: surround 10 s of noise with the start/end prompts and print the
# noise span (the region to repaint), in seconds.
start_audio_path = "/homes/al4624/Documents/YuE_finetune/inference_audio_prompts/start.mp3"
end_audio_path = "/homes/al4624/Documents/YuE_finetune/inference_audio_prompts/end.mp3"
output_path = "/homes/al4624/Documents/YuE_finetune/inference_audio_prompts/full.mp3"

# NOTE(review): `cuda_idx` is not defined in this cell. It is only evaluated
# when CUDA is available, so presumably an earlier cell sets it; confirm.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{cuda_idx}")
else:
    device = torch.device("cpu")

start, end = prompts_concat(start_audio_path, end_audio_path, output_path, 10, device)
print(start, end, sep="\n")

# Output:
# 10.057142857142857
# 20.057142857142857
