In [6]:
# Imports
import numpy as np
import torch
import torchaudio
import matplotlib.pyplot as plt
import IPython.display as ipd
import gin
from functools import partial
import copy
from gamadhani.utils.utils import download_models, download_data
from gamadhani.utils.generate_utils import load_audio_fns
from gamadhani.utils.utils import get_device
import gamadhani.utils.pitch_to_audio_utils as p2a


# Load Model
# Only the audio checkpoint path and its quantizer path are needed; the first two return values are unused.
_, _, audio_path, audio_qt_path = download_models('kmaneeshad/GaMaDHaNi', pitch_model_type="diffusion")
# The quantizer *path* and the loaded quantizer object (`audio_qt`) get distinct names.
base_model, audio_qt, audio_seq_len, invert_audio_fn = load_audio_fns(
    audio_path=audio_path,
    qt_path=audio_qt_path,
    config_path='../configs/pitch_to_audio_config.gin',
)

# Parse pitch config
gin.parse_config_file('../configs/diffusion_pitch_config.gin')
Task_ = gin.get_configurable('src.dataset.Task')
task_obj = Task_()
# Bound methods are already callables; wrapping them in a no-argument functools.partial added nothing.
pitch_task_fn = task_obj.read_
invert_pitch_task_fn = task_obj.invert_

# Prepare Test Set
PITCH_LEN = 2400          # each test pitch contour is truncated to this many values
AUDIO_ORIG_SR = 44100     # sample rate the stored test audio is resampled from
AUDIO_TARGET_SR = 16000   # sample rate the test audio is resampled to
AUDIO_LEN = 191744        # each resampled test audio is truncated to this many samples

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects (arbitrary code
# execution on load) — keep it only as long as the downloaded archive is trusted.
test_archive = np.load(download_data('kmaneeshad/GaMaDHaNi-db'), allow_pickle=True)
test_set = test_archive['concatenated_array']
if hasattr(test_archive, 'close'):
    # For .npz files np.load returns an NpzFile that holds the file handle open until closed;
    # the array extracted above is already fully in memory.
    test_archive.close()

n_test = test_set.shape[0]
# Per-row columns used below: 0 = pitch contour, 1 = audio, 4 = singer id.
test_pitches = np.array([test_set[i, 0][:PITCH_LEN] for i in range(n_test)])
test_audios = np.array([test_set[i, 1] for i in range(n_test)])
test_audios_resampled_truncated = torch.stack([
    torchaudio.functional.resample(torch.tensor(audio), orig_freq=AUDIO_ORIG_SR, new_freq=AUDIO_TARGET_SR)
    for audio in test_audios
])[..., :AUDIO_LEN]
test_singer_ids = np.array([test_set[i, 4] for i in range(n_test)])

In [7]:
def get_f0_input(pitch_val, nan_fill=196):
    """Convert one raw pitch contour into the f0 conditioning tensor for the audio model.

    The contour goes through the gin-configured pitch task (``pitch_task_fn``). It is then
    reshaped to ``(1, 1, T)`` and interpolated to ``audio_seq_len`` with
    ``p2a.interpolate_pitch``. Any NaNs in the interpolated result are replaced with
    ``nan_fill``.

    Args:
        pitch_val: one raw pitch contour (e.g. a row of ``test_pitches``).
        nan_fill: value substituted for NaN entries after interpolation. Defaults to 196,
            the value previously hard-coded here. NOTE(review): presumably a placeholder
            pitch in the same units as the processed contour — confirm.

    Returns:
        float32 tensor: the interpolated pitch with dim 1 squeezed out.
    """
    processed = pitch_task_fn(inputs={"pitch": {"data": pitch_val}})['sampled_sequence']
    # Add leading batch and channel axes: -> (1, 1, T) with T = processed.shape[0].
    processed = torch.Tensor(processed).reshape(1, 1, processed.shape[0])
    interpolated_pitch = p2a.interpolate_pitch(pitch=processed, audio_seq_len=audio_seq_len)
    interpolated_pitch = torch.nan_to_num(interpolated_pitch, nan=nan_fill)
    return interpolated_pitch.squeeze(1).float()

def generate_audio(model, pitch_val, singer_id, num_steps=100, audio_only=True, strength=3, device=None):
    """Sample audio from the pitch-to-audio model for one pitch contour and singer.

    Args:
        model: pitch-to-audio model exposing ``sample_cfg`` and ``device`` (e.g. ``base_model``).
            It is moved to ``device`` in place.
        pitch_val: raw pitch contour, converted to f0 with ``get_f0_input``.
        singer_id: singer identifier, repeated once per batch row of ``f0``.
        num_steps: sampling steps forwarded to ``sample_cfg``.
        audio_only: if True return only the audio; otherwise also return the interim
            outputs ``sample_cfg`` produces when its logging flags are enabled.
        strength: ``strength`` forwarded to ``sample_cfg`` (presumably the classifier-free
            guidance scale). Defaults to 3, the value previously hard-coded.
        device: device to run on. Defaults to ``'cuda'`` when available, else ``'cpu'``,
            so the function also works on CPU-only machines.

    Returns:
        ``audio`` when ``audio_only`` is True, otherwise
        ``(out_spec, out_pitch, singer_id, all_activations, audio)`` where the middle
        values are those returned by ``sample_cfg``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    f0 = get_f0_input(pitch_val).to(model.device)
    singer_tensor = torch.tensor(np.repeat([singer_id], repeats=f0.shape[0])).to(model.device)

    sample_kwargs = dict(
        f0=f0,
        num_steps=num_steps,
        singer=singer_tensor,
        strength=strength,
        invert_audio_fn=invert_audio_fn,
    )
    if not audio_only:
        # Ask the sampler for its interim samples and forward activations as well.
        sample_kwargs.update(log_interim_samples=True, log_interim_forward_activations=True)

    # Output names differ from the arguments so `singer_id` is not shadowed.
    out_spec, out_pitch, out_singer_id, all_activations = model.sample_cfg(f0.shape[0], **sample_kwargs)
    audio = invert_audio_fn(out_spec)
    if audio_only:
        return audio
    return out_spec, out_pitch, out_singer_id, all_activations, audio
  
def generate_all_audios(model, num_steps):
    """Generate audio for every test example and concatenate the results along dim 0.

    Walks the module-level ``test_pitches`` / ``test_singer_ids`` for ``n_test`` examples,
    calling ``generate_audio`` on each with the given ``num_steps``.
    """
    per_example_audio = [
        generate_audio(model, test_pitches[idx], test_singer_ids[idx], num_steps=num_steps)
        for idx in range(n_test)
    ]
    return torch.cat(per_example_audio, dim=0)
    

def quantize_per_tensor(x, bits=4, min_val=None, max_val=None):
    """Round every element of ``x`` to the nearest of ``2**bits`` evenly spaced levels.

    The levels span ``[min_val, max_val]``, defaulting to ``x.min()`` / ``x.max()``. Values
    outside that range snap to the closest endpoint. A value exactly halfway between two
    levels rounds to the lower one, matching ``argmin``'s first-index rule.

    Each element's neighbouring levels are found with ``torch.bucketize``, so memory stays
    O(x.numel()). A dense ``|x[..., None] - levels|`` search would allocate
    ``x.numel() * 2**bits`` elements (256x the tensor at 8 bits).

    Args:
        x: tensor to quantize (any shape).
        bits: bit width; ``2**bits`` levels are used.
        min_val, max_val: optional range of the levels (Python numbers or 0-d tensors).

    Returns:
        Tensor with the same shape as ``x``. Its dtype matches ``x`` when ``x`` is
        floating point, so it can be written back into a parameter's ``.data``.
    """
    if x.numel() == 0:
        return x.clone()
    if min_val is None:
        min_val = x.min()
    if max_val is None:
        max_val = x.max()
    # Compare in at least the default float precision (float32 for half inputs).
    compute_dtype = torch.promote_types(x.dtype, torch.get_default_dtype())
    levels = torch.linspace(float(min_val), float(max_val), 2**bits, dtype=compute_dtype, device=x.device)
    # bucketize needs ascending boundaries; only matters if min_val > max_val was passed.
    levels = torch.sort(levels).values
    xc = x.to(compute_dtype)
    # hi is the first level with levels[hi-1] < x <= levels[hi]; the nearest level is hi or hi-1.
    hi = torch.bucketize(xc, levels).clamp_(max=levels.numel() - 1)
    lo = (hi - 1).clamp_(min=0)
    take_lo = (xc - levels[lo]).abs() <= (levels[hi] - xc).abs()
    rounded = levels[torch.where(take_lo, lo, hi)]
    return rounded.to(x.dtype) if x.is_floating_point() else rounded

def quantize_model(net, bits, mode='global'):
    """Return a deep copy of ``net`` whose Conv1d / Linear / ConvTranspose1d weights are quantized.

    ``net`` itself is not modified.

    Args:
        net: module to copy and quantize.
        bits: bit width passed to ``quantize_per_tensor`` (``2**bits`` levels).
        mode: ``'global'`` quantizes each layer's weight with a single range.
            ``'per_channel'`` quantizes every slice ``weight[i]`` along dim 0 with its own
            range, and also quantizes that layer's bias (if any) with one range.

    Raises:
        ValueError: if ``mode`` is not ``'global'`` or ``'per_channel'``.
    """
    if mode not in ('global', 'per_channel'):
        raise ValueError(f"Unknown quantization mode {mode!r}; expected 'global' or 'per_channel'")

    quantized_net = copy.deepcopy(net)
    for module in quantized_net.modules():
        if not isinstance(module, (torch.nn.Conv1d, torch.nn.Linear, torch.nn.ConvTranspose1d)):
            continue
        if mode == 'global':
            module.weight.data = quantize_per_tensor(module.weight.data, bits=bits)
            continue
        # per_channel. NOTE(review): per PyTorch docs ConvTranspose1d weights are laid out as
        # (in_channels, out_channels // groups, kernel_size), so dim 0 there indexes input
        # channels, not output channels — confirm this granularity is intended.
        weight = module.weight.data
        for i in range(weight.shape[0]):
            weight[i] = quantize_per_tensor(weight[i], bits=bits)
        # Bias handling belongs inside the per-layer loop so every quantized layer's bias is covered.
        bias = getattr(module, 'bias', None)
        if bias is not None:
            bias.data = quantize_per_tensor(bias.data, bits=bits)
    return quantized_net

In [None]:
# Reference outputs from the unquantized model.
# Sampling is presumably stochastic; seeding immediately before each model's run makes the
# cell reproducible and the base vs. quantized outputs comparable (reuse SEED for those runs).
SEED = 0
NUM_SAMPLING_STEPS = 100
torch.manual_seed(SEED)
base_generated = generate_all_audios(base_model, NUM_SAMPLING_STEPS)


100%|██████████| 200/200 [00:00<00:00, 1495.63it/s]
100%|██████████| 200/200 [00:00<00:00, 1447.17it/s]
100%|██████████| 200/200 [00:00<00:00, 1444.26it/s]
100%|██████████| 200/200 [00:00<00:00, 1500.56it/s]
100%|██████████| 200/200 [00:00<00:00, 1486.76it/s]
100%|██████████| 200/200 [00:00<00:00, 1549.65it/s]
100%|██████████| 200/200 [00:00<00:00, 1577.66it/s]
100%|██████████| 200/200 [00:00<00:00, 1599.44it/s]
100%|██████████| 200/200 [00:00<00:00, 1526.61it/s]
100%|██████████| 200/200 [00:00<00:00, 1500.14it/s]
100%|██████████| 200/200 [00:00<00:00, 1448.26it/s]
100%|██████████| 200/200 [00:00<00:00, 1499.73it/s]
100%|██████████| 200/200 [00:00<00:00, 1457.11it/s]
100%|██████████| 200/200 [00:00<00:00, 421.99it/s]
100%|██████████| 200/200 [00:00<00:00, 1335.37it/s]
100%|██████████| 200/200 [00:00<00:00, 1548.41it/s]
