In [1]:
import sys
sys.path.append('..')
from gamadhani.utils.utils import download_models, download_data
from gamadhani.src.dataset import SequenceDataset
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd

from gamadhani.utils.generate_utils import load_audio_fns
from gamadhani.utils.utils import get_device
import gin
from functools import partial
import torch
import gamadhani.utils.pitch_to_audio_utils as p2a
from tqdm import tqdm
import pdb

  from .autonotebook import tqdm as notebook_tqdm
  @autocast(enabled = False)
  @autocast(enabled = False)


In [2]:
# Load the pitch-to-audio model together with its helpers (qt, sequence length, inverse-audio fn).
# NOTE(review): config_path is an absolute, user-specific path — adjust for your checkout.
device = get_device()
_, _, audio_model_path, audio_qt_path = download_models(
    'kmaneeshad/GaMaDHaNi', pitch_model_type="diffusion"
)
audio_model, audio_qt, audio_seq_len, invert_audio_fn = load_audio_fns(
    audio_path=audio_model_path,
    qt_path=audio_qt_path,
    config_path='/home/mila/n/nithya.shikarpur/GaMaDHaNi-dev/configs/pitch_to_audio_config.gin',
    device=device,
)

Script is running on: GPU


  ckpt = torch.load(ckpt, map_location=device)


In [3]:
# Calibration set: `n_samples` *distinct* training-set indices, drawn reproducibly.
# (Previously unseeded and sampled with replacement, so examples could repeat.)
n_samples = 128
n_train_examples = 7174  # NOTE(review): presumably len(db) for the train split loaded below — confirm
rng = np.random.default_rng(0)
rand_inds = rng.choice(n_train_examples, size=n_samples, replace=False)


In [4]:
# Bind the (updated) pitch-to-audio gin config, then open the cached training split used for calibration.
gin.parse_config_file('/home/mila/n/nithya.shikarpur/GaMaDHaNi-dev/configs/pitch_to_audio_config_updated.gin')
train_db_path = '/home/mila/n/nithya.shikarpur/scratch/final-data-ismir/data/merged_data-finalest/cached-audio-pitch-16k/train'
db = SequenceDataset(train_db_path)

In [5]:
def get_calib_data(db, time_range = (0, 1), group_size = 16, indices = None):
    """Draw calibration examples from `db` and stack them into batches of `group_size`.

    Args:
        db: indexable dataset; `db[i]` must unpack into `(spec, pitch, singer)` tensors.
        time_range: `(low, high)` interval; one time value per example is drawn uniformly
            from `[low, high)`. This argument used to be ignored. The default `(0, 1)`
            reproduces the previous `torch.rand` draw exactly.
        group_size: number of examples per returned batch (must be >= 1).
        indices: iterable of dataset indices to draw. Defaults to the notebook-level
            `rand_inds`, so existing calls keep working.

    Returns:
        List of `(spec, pitch, singer, time)` tuples. Each element stacks the per-example
        tensors along a new leading batch dimension; `time` has shape `(batch, 1)`.
        If the number of examples is not a multiple of `group_size`, the last batch is
        smaller (the old reshape raised instead). Empty `indices` gives an empty list.

    Raises:
        ValueError: if `group_size < 1`.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be >= 1, got {group_size}")
    if indices is None:
        indices = rand_inds  # notebook-global calibration indices
    low, high = time_range

    calib_data = []
    for i in indices:
        spec, pitch, singer = db[i]
        time = low + (high - low) * torch.rand((1,))
        calib_data.append((spec, pitch, singer, time))
    if not calib_data:
        return []

    # Stack each field across examples, then cut every stacked tensor into consecutive
    # chunks of `group_size` along the batch dimension (works for any per-field rank).
    stacked = [torch.stack(field) for field in zip(*calib_data)]
    chunked = [torch.split(t, group_size) for t in stacked]
    return list(zip(*chunked))

In [6]:
# Inspect the calibration batches. NOTE(review): `data` is not used by later cells —
# get_calib_feat draws its own batches — so this cell only serves as a sanity check.
data = get_calib_data(db, group_size=16)

In [7]:
def process_input(batch, model = None):
    """Turn a calibration batch into the noisy model input and the padded pitch conditioning.

    Args:
        batch: `(spec, pitch, singer, time)` tuple as produced by `get_calib_data`;
            `singer` is not used here.
        model: object providing `pad_to(tensor, multiple)` and `strides_prod`, used to pad
            the last axis. Defaults to the notebook-global `audio_model`.

    Returns:
        `(x_t, padded_f0)`, where `x_t = time * padded_spec + (1 - time) * noise`.
        `noise` is standard-normal with the padded spectrogram's shape, and `time` is
        broadcast per example over all remaining dimensions.
    """
    if model is None:
        model = audio_model
    spec, pitch, _singer, time = batch
    padded_x, _padding = model.pad_to(spec, model.strides_prod)
    padded_f0, _ = model.pad_to(pitch, model.strides_prod)
    noise = torch.normal(0, 1, padded_x.shape).to(padded_x)
    # One time value per example: reshape to (batch, 1, ..., 1) so it broadcasts over every
    # non-batch axis. For the (batch, 1) times from get_calib_data this equals time[:, None];
    # unlike time[:, None], it also broadcasts correctly for 1-D times.
    t = time.reshape(time.shape[0], *([1] * (padded_x.ndim - 1)))
    x_t = t * padded_x + (1 - t) * noise
    return x_t, padded_f0

In [None]:
def get_calib_feat(model, tokenizer): # doesn't handle groups rn
    """Collect per-input-channel activation statistics for calibration.

    The function does three things:
    - Registers forward hooks on every Conv1d / ConvTranspose1d / Linear module of `model`.
    - Runs the calibration batches from `get_calib_data(tokenizer)` through `model` without
      gradient tracking. Previously the global `audio_model` was run instead of `model`.
    - Removes the hooks afterwards, including when the forward pass raises.

    Args:
        model: network to calibrate; the hooks are attached to it and it is the model run.
        tokenizer: dataset forwarded to `get_calib_data` (name kept for compatibility).

    Returns:
        dict mapping module names from `model.named_modules()` to 1-D CPU tensors with one
        entry per input channel. Each entry is the mean |input activation| of that channel
        for one forward call, summed over all calls made during calibration.
    """
    input_dict = dict()

    def accumulate(name, x, channel_dim):
        # Forward hooks receive the module's positional inputs as a tuple.
        if isinstance(x, tuple):
            x = x[0]
        n_channels = x.shape[channel_dim]
        # Move the channel axis last, flatten the rest, and average |x| per channel.
        # reshape (not view) so non-contiguous inputs are handled too.
        channel_mean = (
            x.detach().movedim(channel_dim, -1).reshape(-1, n_channels).abs().mean(dim=0).cpu()
        )
        if name in input_dict:
            input_dict[name] += channel_mean
        else:
            input_dict[name] = channel_mean

    def stat_input_max_hook_conv(m, x, y, name):
        # Conv inputs are expected as (batch, channels, frames): channels on dim 1.
        accumulate(name, x, channel_dim=1)

    def stat_input_max_hook_linear(m, x, y, name):
        # Linear inputs carry their features on the last dim.
        accumulate(name, x, channel_dim=-1)

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):  # add more layers here
            hooks.append(m.register_forward_hook(partial(stat_input_max_hook_conv, name=name)))
        elif isinstance(m, torch.nn.Linear):
            hooks.append(m.register_forward_hook(partial(stat_input_max_hook_linear, name=name)))

    print("Collecting activation scales...")
    # Feed inputs to the device the model's weights live on; fall back for parameter-less models.
    param = next(model.parameters(), None)
    if param is not None:
        device = param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        samples = get_calib_data(tokenizer)
        with torch.no_grad():  # statistics only; no autograd graph needed
            for input_vals in tqdm(samples):
                # NOTE: process_input pads using the notebook-global audio_model's settings.
                input_spec, f0 = process_input(input_vals)
                input_spec = input_spec.to(device)
                f0 = f0.to(device)
                time = input_vals[3].to(device).reshape(-1)
                singer = input_vals[2].to(device).int().reshape(-1)
                model(x=input_spec, f0=f0, singer=singer, time=time, drop_tokens=False)
    finally:
        for hook in hooks:
            hook.remove()
    return input_dict

In [9]:
# Per-layer input-activation statistics of the (still unquantized) audio model.
input_feat = get_calib_feat(model=audio_model, tokenizer=db)

Collecting activation scales...


  0%|          | 0/8 [00:00<?, ?it/s]

100%|██████████| 8/8 [00:00<00:00, 20.71it/s]


# Calibrate weights: activation-aware scaling and per-tensor quantization

In [46]:
import copy

# Start from a freshly loaded model: the scaling/quantization cell below overwrites
# audio_model's weights, so this cell must be re-run before quantizing again.
del audio_model
torch.cuda.empty_cache()

audio_model, audio_qt, audio_seq_len, invert_audio_fn = load_audio_fns(
    audio_path=audio_model_path,
    qt_path=audio_qt_path,
    config_path='/home/mila/n/nithya.shikarpur/GaMaDHaNi-dev/configs/pitch_to_audio_config.gin',
    device=device,
)

# Independent copy of the freshly loaded model; no later cell modifies it.
orig_model = copy.deepcopy(audio_model)

  ckpt = torch.load(ckpt, map_location=device)


In [47]:
# Calibration statistics collected for the module named 'initial_projection'.
x = input_feat["initial_projection"]

In [48]:
# Exponent that splits the scaling between activations (act ** ratio) and
# weights (max|w| ** (1 - ratio)) in the quantization cell below.
ratio = 0.5
s_x = x


In [49]:
def quantize_per_tensor(x, bits=6, min_val=None, max_val=None):
    """Fake-quantize `x` to the nearest of `2**bits` evenly spaced levels.

    The levels are `torch.linspace(min_val, max_val, 2**bits)`. `min_val` / `max_val`
    default to the tensor's own min / max, and values outside the range snap to the
    nearest end level. The result has the same shape as `x`, the dtype of the level
    grid (torch's default float dtype), and lives on `x.device`.

    Nearest levels are found with `torch.bucketize` against the midpoints between
    levels. This uses O(x.numel()) memory instead of materializing an
    `x.numel() x 2**bits` difference tensor. A value exactly on a midpoint goes to the
    lower level, matching the earlier argmin-based version.

    Raises:
        ValueError: if `bits` is negative.
    """
    if bits < 0:
        raise ValueError(f"bits must be >= 0, got {bits}")
    if min_val is None:
        min_val = x.min()
    if max_val is None:
        max_val = x.max()
    levels = torch.linspace(float(min_val), float(max_val), 2**bits, device=x.device)
    # bucketize needs ascending boundaries; sorting also tolerates min_val > max_val.
    levels, _ = torch.sort(levels)
    if levels.numel() == 1:
        return levels[0].expand(x.shape).clone()
    midpoints = (levels[:-1] + levels[1:]) / 2
    # With right=False, bucketize returns, for each value, how many midpoints lie strictly
    # below it, which is exactly the index of its nearest level.
    nearest_indices = torch.bucketize(x.to(levels.dtype), midpoints)
    return levels[nearest_indices]

In [50]:
# Activation-aware scaling + per-tensor fake quantization of every Conv1d layer that has
# calibration statistics. For input channel c:
#   s_c = act_c ** ratio / max|W[:, c, :]| ** (1 - ratio)
# and the weight becomes Q(W * s) / s, with Q = quantize_per_tensor.
# NOTE: this mutates audio_model in place; reload the model (cell above) before re-running.
min_scale = 1e-4  # floor that keeps the scales (and the divisions by them) finite
for n, m in audio_model.named_modules():
    if n not in input_feat or not isinstance(m, torch.nn.Conv1d):
        continue
    if m.groups != 1:
        # Grouped conv weights hold in_channels // groups inputs per filter, so the
        # per-input-channel activation stats do not line up with the weight's dim 1.
        continue
    weight = m.weight.data
    # Stats were collected on CPU; move them next to the weight before mixing the two.
    activations = input_feat[n].to(device=weight.device, dtype=weight.dtype)
    # Conv1d weight is (out_channels, in_channels, kernel). Move in_channels last and take
    # max |w| over everything else, giving one value per input channel.
    # (The old code reshaped the untransposed weight, which scrambled channels.)
    weight_in_last = weight.transpose(1, 2)
    max_weight_per_channel = weight_in_last.reshape(-1, weight_in_last.shape[-1]).abs().max(dim=0)[0]
    scales = activations ** ratio / max_weight_per_channel.clamp(min=min_scale) ** (1 - ratio)
    scales = scales.clamp(min=min_scale)[None, :, None]
    m.weight.data = quantize_per_tensor(weight * scales).div_(scales).to(weight.dtype)
    

In [51]:
# Pitch preprocessing for the spectrogram generator's conditioning.
# Parse the pitch-model gin config and instantiate the configured Task; its read_ is used
# below to preprocess raw pitch, and invert_ presumably undoes that transformation.
gin.parse_config_file('/home/mila/n/nithya.shikarpur/GaMaDHaNi-dev/configs/diffusion_pitch_config.gin')
Task_ = gin.get_configurable('src.dataset.Task')
task_obj = Task_()
pitch_task_fn, invert_pitch_task_fn = partial(task_obj.read_), partial(task_obj.invert_)

In [52]:
# Pitch primes used for the listening study.
# NOTE: allow_pickle=True lets np.load unpickle object arrays — only use with trusted files.
primes_path = '/home/mila/n/nithya.shikarpur/scratch/final-data-ismir/data/merged_data-finalest/listening_study_primes.npz'
primes = np.load(primes_path, allow_pickle=True)

In [53]:
# Run the first 2400 values of each prime through the pitch task, then resample to the
# audio model's sequence length to obtain the f0 conditioning.
raw_pitch_seqs = [
    pitch_task_fn(inputs={"pitch": {"data": prime[0][:2400]}})['sampled_sequence']
    for prime in primes['concatenated_array']
]
processed_pitch_val = torch.Tensor(np.stack(raw_pitch_seqs)).reshape(
    len(raw_pitch_seqs), 1, raw_pitch_seqs[0].shape[0]
)
interpolated_pitch = p2a.interpolate_pitch(pitch=processed_pitch_val, audio_seq_len=audio_seq_len)
# Replace NaNs with 196 (the silent token).
interpolated_pitch = torch.nan_to_num(interpolated_pitch, nan=196)
# Drop the singleton channel axis to match the model's expected input size.
f0 = interpolated_pitch.squeeze(1).float()
audio_model = audio_model.to(device)
print(f0.shape)

torch.Size([16, 750])


In [54]:
# Condition every generated sample on singer id 3.
singer_tensor = torch.tensor(np.full(f0.shape[0], 3)).to(audio_model.device)

In [55]:
# Sample one output per f0 row. NOTE(review): `strength` is presumably the
# classifier-free-guidance scale of sample_cfg — confirm.
samples, _, singers, _ = audio_model.sample_cfg(
    f0.shape[0],
    f0=f0,
    num_steps=100,
    singer=singer_tensor,
    strength=3,
    invert_audio_fn=invert_audio_fn,
)

In [56]:
# Convert the sampled model outputs back to audio (played at 16 kHz in the next cell).
audio = invert_audio_fn(samples)

100%|██████████| 200/200 [00:00<00:00, 758.79it/s]


In [57]:
# Listen to the second generated sample at 16 kHz.
waveform = audio[1].detach().cpu().numpy()
ipd.Audio(waveform, rate=16000)