In [None]:
# Imports
import numpy as np
import torch
import torchaudio
import matplotlib.pyplot as plt
import IPython.display as ipd
import gin
from functools import partial
import copy
from gamadhani.utils.utils import download_models, download_data
from gamadhani.utils.generate_utils import load_audio_fns
from gamadhani.utils.utils import get_device
import gamadhani.utils.pitch_to_audio_utils as p2a


# Load Model
# download_models returns four values; only the last two (audio checkpoint path
# and its companion `audio_qt` path) are used here.
_, _, audio_path, audio_qt = download_models('kmaneeshad/GaMaDHaNi', pitch_model_type="diffusion")
# NOTE: `audio_qt` is rebound here from the value passed in as `qt_path` to the
# second value returned by load_audio_fns.
base_model, audio_qt, audio_seq_len, invert_audio_fn = load_audio_fns(audio_path=audio_path, qt_path=audio_qt, config_path='../configs/pitch_to_audio_config.gin')

# Parse pitch config
# The config path is relative, so this cell depends on the notebook's working directory.
gin.parse_config_file('../configs/diffusion_pitch_config.gin')
Task_ = gin.get_configurable('src.dataset.Task')
task_obj = Task_()
# partial() binds no arguments here; these are plain aliases for the bound methods.
pitch_task_fn = partial(task_obj.read_)
invert_pitch_task_fn = partial(task_obj.invert_)


# Prepare Test Set
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects; this is
# only safe if the downloaded archive is trusted — confirm.
test_set = np.load(download_data('kmaneeshad/GaMaDHaNi-db'), allow_pickle=True)['concatenated_array']
n_test = test_set.shape[0]
# Column 0 is used as the pitch sequence; only its first 2400 entries are kept.
test_pitches = np.array([test_set[i,0][:2400] for i in range(n_test)])
# Column 1 is used as the raw audio.
test_audios = np.array([test_set[i,1] for i in range(n_test)])
# Audio is treated as 44.1 kHz, resampled to 16 kHz, then cut to the first
# 191744 samples of each clip.
test_audios_resampled_truncated = torch.stack([torchaudio.functional.resample(torch.tensor(test_audios[i]), orig_freq=44100, new_freq=16000) for i in range(n_test)])[...,:191744]
# Column 4 is used as the singer id.
test_singer_ids = np.array([test_set[i,4] for i in range(n_test)])

In [2]:
def get_f0_input(pitch_val):
    """Turn a raw pitch sequence into the f0 conditioning tensor for the audio model.

    The pitch is passed through ``pitch_task_fn``, reshaped to ``(1, 1, length)``,
    interpolated with ``p2a.interpolate_pitch`` using the module-level
    ``audio_seq_len``, has its NaNs replaced by 196, and finally loses
    dimension 1 (if it has size 1) before being cast to float32.

    Args:
        pitch_val: Pitch data handed to the pitch task as
            ``inputs["pitch"]["data"]``.

    Returns:
        The interpolated pitch as a float tensor.
    """
    task_inputs = {"pitch": {"data": pitch_val}}
    sampled = pitch_task_fn(inputs=task_inputs)['sampled_sequence']

    # Add leading batch and channel axes of size 1.
    pitch_3d = torch.Tensor(sampled).reshape(1, 1, sampled.shape[0])

    pitch_interp = p2a.interpolate_pitch(pitch=pitch_3d, audio_seq_len=audio_seq_len)
    # NaN entries are filled with the constant 196.
    pitch_interp = torch.nan_to_num(pitch_interp, nan=196)
    return pitch_interp.squeeze(1).float()

def generate_audio(model, pitch_val, singer_id, num_steps=100, audio_only=True):
    """Sample audio from ``model`` conditioned on a pitch contour and a singer id.

    Args:
        model: Model exposing ``sample_cfg``. It is moved to the GPU when CUDA
            is available and to the CPU otherwise.
        pitch_val: Raw pitch sequence, converted with ``get_f0_input``.
        singer_id: Singer identifier, repeated once per batch element of the
            f0 tensor.
        num_steps: Number of sampling steps forwarded to ``sample_cfg``.
        audio_only: When True only the inverted audio is returned; when False
            interim samples and forward activations are logged and returned too.

    Returns:
        ``invert_audio_fn(out_spec)`` when ``audio_only`` is True, otherwise the
        tuple ``(out_spec, out_pitch, singer, all_activations, audio)`` where
        the first four items are exactly what ``sample_cfg`` returned.
    """
    # Prefer CUDA but fall back to CPU, so the notebook does not crash on
    # machines without a GPU (the previous unconditional .cuda() did).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    f0 = get_f0_input(pitch_val).to(device)
    batch_size = f0.shape[0]
    singer_tensor = torch.tensor(np.repeat([singer_id], repeats=batch_size)).to(device)

    if audio_only:
        out_spec, _, _, _ = model.sample_cfg(batch_size, f0=f0, num_steps=num_steps, singer=singer_tensor, strength=3, invert_audio_fn=invert_audio_fn)
        return invert_audio_fn(out_spec)

    # Use a distinct name for the sampler's singer output so the `singer_id`
    # argument is not shadowed.
    out_spec, out_pitch, out_singer, all_activations = model.sample_cfg(
        batch_size,
        f0=f0,
        num_steps=num_steps,
        singer=singer_tensor,
        strength=3,
        invert_audio_fn=invert_audio_fn,
        log_interim_samples=True,
        log_interim_forward_activations=True,
    )
    audio = invert_audio_fn(out_spec)
    return out_spec, out_pitch, out_singer, all_activations, audio
  
def generate_all_audios(model, num_steps):
    """Generate audio for every test example and concatenate the results.

    Iterates over the module-level ``test_pitches`` / ``test_singer_ids``
    (``n_test`` entries) and calls ``generate_audio`` once per example.

    Args:
        model: Model forwarded to ``generate_audio``.
        num_steps: Number of sampling steps for each generation.

    Returns:
        All generated clips concatenated along dimension 0.
    """
    clips = [
        generate_audio(model, test_pitches[idx], test_singer_ids[idx], num_steps=num_steps)
        for idx in range(n_test)
    ]
    return torch.cat(clips, dim=0)
    

def quantize_per_tensor(x, bits=4, min_val=None, max_val=None):
    """Snap every element of ``x`` to the nearest of ``2**bits`` evenly spaced levels.

    The levels are ``torch.linspace(min_val, max_val, 2**bits)``. Because the
    grid is uniform, the index of the nearest level is computed in closed form
    instead of building an ``x.numel() * 2**bits`` tensor of distances, which
    keeps memory proportional to ``x`` itself. Exact ties between two levels may
    resolve differently than a brute-force nearest-neighbour search would.

    Args:
        x: Tensor to quantize.
        bits: Number of bits; the grid has ``2**bits`` levels.
        min_val: Lowest grid value. Defaults to ``x.min()``.
        max_val: Highest grid value. Defaults to ``x.max()``.

    Returns:
        A tensor with the shape of ``x`` whose values are all grid levels.
        Values outside ``[min_val, max_val]`` map to the nearest end of the grid.
    """
    if min_val is None:
        min_val = x.min()
    if max_val is None:
        max_val = x.max()

    lo, hi = float(min_val), float(max_val)
    num_levels = 2 ** bits
    targets = torch.linspace(lo, hi, num_levels).to(x.device)

    # A degenerate grid (all levels equal) maps everything to index 0.
    span = hi - lo
    scale = (num_levels - 1) / span if span != 0 else 0.0

    # Work in the dtype a subtraction against `targets` would have produced.
    work = x.to(torch.promote_types(x.dtype, targets.dtype))
    positions = torch.round((work - lo) * scale)
    # NaNs cannot be cast to an index; send them to level 0.
    positions = torch.nan_to_num(positions, nan=0.0)
    nearest_indices = positions.clamp(0, num_levels - 1).long()
    return targets[nearest_indices]


# Boundaries of the input-channel groups for each known layer width, keyed by
# layer type and then by number of input channels. Consecutive boundaries
# delimit one group.
_INPUT_GROUP_BOUNDARIES = {
    "conv": {
        512: (0, 256, 384, 512),
        768: (0, 512, 640, 768),
        896: (0, 640, 768, 896),
    },
    "convtranspose": {
        2432: (0, 1024, 1152, 1280, 2304, 2432),
        1664: (0, 640, 768, 896, 1536, 1664),
        1408: (0, 512, 640, 768, 1280, 1408),
        640: (0, 256, 512, 640),
    },
}


def get_input_ranges(in_dim, layer_type):
    """Return the input-channel groups of a layer as a list of ``range`` objects.

    Args:
        in_dim: Number of input channels of the layer.
        layer_type: Either ``"conv"`` or ``"convtranspose"``.

    Returns:
        A list of contiguous, non-overlapping ranges covering ``0..in_dim``.
        Widths without a known layout yield the single group ``range(0, in_dim)``.

    Raises:
        ValueError: If ``layer_type`` is not one of the supported types
            (previously this silently returned ``None``).
    """
    try:
        known_layouts = _INPUT_GROUP_BOUNDARIES[layer_type]
    except (KeyError, TypeError):
        raise ValueError(
            f"Unknown layer_type {layer_type!r}; expected 'conv' or 'convtranspose'"
        ) from None

    boundaries = known_layouts.get(in_dim, (0, in_dim))
    return [range(lo, hi) for lo, hi in zip(boundaries[:-1], boundaries[1:])]



def quantize_model(net, bits, mode='global'):
  """Return a deep copy of ``net`` whose Conv1d/ConvTranspose1d/Linear weights are quantized.

  Args:
    net: Model to quantize; it is deep-copied and left untouched.
    bits: Bit width forwarded to ``quantize_per_tensor``.
    mode: Granularity at which the quantization grid is fitted:
      * ``'global'``: one grid per weight tensor.
      * ``'per_channel'``: one grid per output channel for Conv1d (dim 0) and
        ConvTranspose1d (dim 1); Linear layers use one grid per tensor. In
        this mode every quantized layer's bias is quantized as well.
      * ``'grouped_input_channels'``: one grid per (output channel, input
        channel group) pair, with groups given by ``get_input_ranges``;
        Linear layers use one grid per tensor.

  Returns:
    The quantized copy of ``net``.

  Raises:
    ValueError: If ``mode`` is not one of the three supported values.
  """
  valid_modes = ('global', 'per_channel', 'grouped_input_channels')
  if mode not in valid_modes:
    raise ValueError(f"Unknown quantization mode {mode!r}; expected one of {valid_modes}")

  quantizable_types = (torch.nn.Conv1d, torch.nn.Linear, torch.nn.ConvTranspose1d)
  quantized_net = copy.deepcopy(net)

  for module in quantized_net.modules():
    if not isinstance(module, quantizable_types):
      continue

    # `.data` shares storage with the parameter, so indexed assignments below
    # update the module's weight in place.
    weight = module.weight.data
    is_conv = isinstance(module, torch.nn.Conv1d)
    is_convtranspose = isinstance(module, torch.nn.ConvTranspose1d)

    if mode == 'per_channel' and is_conv:
      for i in range(weight.shape[0]):
        weight[i] = quantize_per_tensor(weight[i], bits=bits)

    elif mode == 'per_channel' and is_convtranspose:
      # Output channels of a transposed convolution live on dim 1.
      for i in range(weight.shape[1]):
        weight[:, i] = quantize_per_tensor(weight[:, i], bits=bits)

    elif mode == 'grouped_input_channels' and is_conv:
      input_ranges = get_input_ranges(weight.shape[1], "conv")
      for i in range(weight.shape[0]):
        for input_range in input_ranges:
          weight[i, input_range] = quantize_per_tensor(weight[i, input_range], bits=bits)

    elif mode == 'grouped_input_channels' and is_convtranspose:
      # Input channels are dim 0 and output channels dim 1 for ConvTranspose1d.
      input_ranges = get_input_ranges(weight.shape[0], "convtranspose")
      for i in range(weight.shape[1]):
        for input_range in input_ranges:
          weight[input_range, i] = quantize_per_tensor(weight[input_range, i], bits=bits)

    else:
      # 'global' mode, and Linear layers in every mode: one grid for the whole tensor.
      module.weight.data = quantize_per_tensor(weight, bits=bits)

    # Bias quantization must run per module. It used to sit after the loop and
    # therefore only ever touched the last module that was iterated.
    if mode == 'per_channel' and getattr(module, 'bias', None) is not None:
      module.bias.data = quantize_per_tensor(module.bias.data, bits=bits)

  return quantized_net

# Quantization Generation Experiments

In [None]:
# Reference generation with the unquantized model (100 sampling steps).
base_generated = generate_all_audios(base_model, 100)
# NaNs are replaced by 0 before saving. NOTE: the examples/ directory must already exist.
torch.save(torch.nan_to_num(base_generated,0), "examples/base_model.pt")
# Ground-truth clips resampled to 16 kHz and truncated, for comparison.
torch.save(test_audios_resampled_truncated, "examples/ground_truth.pt")

In [None]:
# 8-bit per-channel quantization, 100 sampling steps.
name = "per_chan_8_bit"
q_model = quantize_model(base_model, bits=8, mode='per_channel')
generated = generate_all_audios(q_model, 100)
torch.save(torch.nan_to_num(generated,0), f"examples/{name}.pt")

# Play at most the first 16 clips. Slicing instead of indexing 0..15 avoids an
# IndexError when fewer than 16 clips were generated.
for clip in generated[:16]:
    display(ipd.Audio(clip.cpu(), rate=16000))

In [None]:
# 8-bit quantization with one grid per input-channel group, 100 sampling steps.
name = "per_group_8_bit"
q_model = quantize_model(base_model, bits=8, mode='grouped_input_channels')
generated = generate_all_audios(q_model, 100)
torch.save(torch.nan_to_num(generated,0), f"examples/{name}.pt")

# Play at most the first 16 clips. Slicing instead of indexing 0..15 avoids an
# IndexError when fewer than 16 clips were generated.
for clip in generated[:16]:
    display(ipd.Audio(clip.cpu(), rate=16000))

In [None]:
# 8-bit per-tensor ('global') quantization.
# NOTE(review): this run uses 10 sampling steps while the other 8-bit cells use
# 100, and its save/playback below are commented out, so `name` is unused and
# no examples/per_tensor_8_bit.pt is written — confirm this is intentional.
name = "per_tensor_8_bit"
q_model = quantize_model(base_model, bits=8, mode='global')
generated = generate_all_audios(q_model, 10)
#torch.save(torch.nan_to_num(generated,0), f"examples/{name}.pt")

# for i in range(16):
#     display(ipd.Audio(generated[i].cpu(), rate=16000))

In [None]:
# 4-bit per-tensor ('global') quantization.
# NOTE(review): only 3 sampling steps are used here, versus 100 for the 8-bit
# per-channel and grouped runs — confirm this is intentional before comparing outputs.
name = "per_tensor_4_bit"
q_model = quantize_model(base_model, bits=4, mode='global')
generated = generate_all_audios(q_model, 3)
torch.save(torch.nan_to_num(generated,0), f"examples/{name}.pt")

# Play at most the first 16 clips. Slicing instead of indexing 0..15 avoids an
# IndexError when fewer than 16 clips were generated.
for clip in generated[:16]:
    display(ipd.Audio(clip.cpu(), rate=16000))

In [None]:
# 4-bit per-channel quantization.
# NOTE(review): only 4 sampling steps are used here, versus 100 for the 8-bit
# per-channel run — confirm this is intentional before comparing outputs.
name = "per_chan_4_bit"
q_model = quantize_model(base_model, bits=4, mode='per_channel')
generated = generate_all_audios(q_model, 4)
torch.save(torch.nan_to_num(generated,0), f"examples/{name}.pt")

# Play at most the first 16 clips. Slicing instead of indexing 0..15 avoids an
# IndexError when fewer than 16 clips were generated.
for clip in generated[:16]:
    display(ipd.Audio(clip.cpu(), rate=16000))

In [None]:
# 4-bit quantization with one grid per input-channel group.
# NOTE(review): only 4 sampling steps are used here, versus 100 for the 8-bit
# grouped run — confirm this is intentional before comparing outputs.
name = "per_group_4_bit"
q_model = quantize_model(base_model, bits=4, mode='grouped_input_channels')
generated = generate_all_audios(q_model, 4)
torch.save(torch.nan_to_num(generated,0), f"examples/{name}.pt")

# Play at most the first 16 clips. Slicing instead of indexing 0..15 avoids an
# IndexError when fewer than 16 clips were generated.
for clip in generated[:16]:
    display(ipd.Audio(clip.cpu(), rate=16000))

# Visualizations

In [None]:
# Run the unquantized model on the first test example with activation logging on.
out = generate_audio(base_model, pitch_val=test_pitches[0], singer_id=test_singer_ids[0], num_steps=100, audio_only=False)
# out[3] is the `all_activations` value returned by generate_audio; its element 2
# is the one later cells unpack as `all_activation` and index by time step and layer name.
activations = out[3][2]

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import torch

# Absolute values of the logged 'upsample_layer_input_0' entry at key 0.50, first
# batch element, transposed so that rows come first and channels second.
# NOTE(review): the lookup uses a float key — presumably keys are time steps; verify
# that 0.50 matches a stored key exactly.
mags = torch.abs(activations[0.50]['upsample_layer_input_0'])[0].cpu().T

def plot_activations(mags, title="", figsize=(12, 12), dpi=300, mode='weight'):
    """Draw a 3-D surface of magnitudes, one coloured surface per input-channel group.

    Channel groups come from ``get_input_ranges(num_channels, "convtranspose")``.
    The figure is saved to ``f'{title}_activations.png'`` and then shown.

    Args:
        mags: 2-D data indexed as ``[row, channel]``; rows are plotted on the
            y axis and channels on the x axis.
        title: Figure title; also used as the prefix of the output file name.
        figsize: Figure size in inches.
        dpi: Figure resolution.
        mode: ``'weight'`` labels the axes for a weight matrix; any other value
            labels them for activations over time frames.
    """
    num_timeframes = mags.shape[0]
    num_channels = mags.shape[1]

    ranges = get_input_ranges(num_channels, "convtranspose")

    # Create figure with specific layout parameters
    fig = plt.figure(figsize=figsize, dpi=dpi)
    plt.subplots_adjust(top=0.9, bottom=0.15, hspace=0.3, wspace=0.3)
    ax = fig.add_subplot(111, projection="3d")

    names = ['Prev. Activation','Time-Step Condition', 'Singer ID Condition', 'Skip Connection', 'Pitch Condition']

    # One discrete colour per group. pyplot.get_cmap is used because
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9. The colormap is
    # built once rather than on every loop iteration.
    color_map = plt.get_cmap("viridis", len(ranges))

    # Custom legend handles
    legend_handles = []

    # Plot each channel group as its own surface with its own colour.
    for i, curr_range in enumerate(ranges):
        z = mags[:, curr_range]
        x_range, y_range = np.meshgrid(
            np.arange(curr_range.start, curr_range.stop),
            np.arange(num_timeframes),
        )
        color = color_map(i)

        ax.plot_surface(
            x_range, y_range, z,
            color=color,
            alpha=0.7,  # Slightly reduced alpha for depth perception
            edgecolor="none",
            rstride=1,
            cstride=1,
        )

        # Add a legend entry for the current range
        legend_handles.append(mpatches.Patch(color=color, label=names[i]))

    # Axis labels depend on whether a weight matrix or activations are shown.
    if mode == 'weight':
        ax.set_xlabel("Input Channel", fontweight='bold', fontsize=24)
        ax.set_ylabel("Output Channel", fontweight='bold', fontsize=24)
        ax.set_zlabel("Avg. Weight Magnitude", fontweight='bold', fontsize=24)
    else:
        ax.set_xlabel("Channel", fontweight='bold', fontsize=24)
        ax.set_ylabel("Time Frame", fontweight='bold', fontsize=24)
        ax.set_zlabel("Activation Magnitude", fontweight='bold', fontsize=24)
    ax.view_init(elev=20, azim=260)

    # Add title if provided, with reduced spacing
    if title:
        fig.suptitle(title, fontweight='bold', fontsize=30, y=0.87)

    # Add a legend closer to the plot
    plt.legend(
        handles=legend_handles,
        loc='upper center',
        bbox_to_anchor=(0.5, 0.05),  # Moved closer to the plot
        frameon=False,
        ncol=2,
        fontsize=24
    )

    # Reduce whitespace and adjust layout
    plt.tight_layout()

    # Save as transparent PNG
    plt.savefig(f'{title}_activations.png',
                transparent=True,
                bbox_inches='tight',
                pad_inches=0)

    plt.show()

# Plot the activation magnitudes. 'asdf' is a placeholder: plot_activations only
# checks for mode == 'weight', so any other value selects the activation axis labels.
plot_activations(mags, 'Inputs to First Upsampling Layer', mode='asdf')


In [241]:
# Mean absolute weight over the last dimension of the first upsample layer's first
# conv, transposed so the original dim 1 becomes rows and dim 0 becomes columns.
weights = torch.mean(torch.abs(base_model.upsample_layers[0].process_layer.convs[0].conv.weight.data.cpu()), dim=-1).T

In [None]:
# Title typo fixed ("ConvTranpose" -> "ConvTranspose"). The title is also the
# prefix of the saved file, which becomes 'First ConvTranspose_activations.png'.
plot_activations(weights, "First ConvTranspose", mode='weight')

# Box and Whisker

In [None]:
out = generate_audio(base_model, pitch_val=test_pitches[0], singer_id=test_singer_ids[0], num_steps=100, audio_only=False)

# Take the interim activations from this run. out[3] is the activations tuple
# and its first element is the interim-activation mapping. Previously this cell
# relied on `interim_activations` leaking in from a later cell, which fails
# under Restart & Run All.
interim_activations = out[3][0]

# Plot every 5th logged step.
step_keys = list(interim_activations.keys())[::5]

plt.figure(figsize=(14, 8), dpi=300)
# `tick_labels` requires Matplotlib >= 3.9.
plt.boxplot([interim_activations[k].cpu().numpy().flatten() for k in step_keys],
            vert=True,
            tick_labels=[int(k*100) for k in step_keys],
            flierprops=dict(alpha=0.1, marker='o'))

plt.title("Activations Over Diffusion Steps", fontsize=30)
plt.xlabel("Diffusion Step", fontsize=24)
plt.ylabel("Activation Values", fontsize=24)
plt.tick_params(axis='both', which='major', labelsize=20)  # For major ticks

plt.tight_layout()

# Save with transparency
plt.savefig('activations_boxplot_nb.svg', dpi=300, transparent=True, bbox_inches='tight')

# Divergence of Quantization

In [None]:
# Mean absolute activation per (time step, layer) for each quantized variant.
trajectories_dict = {}

# Layers whose activations are tracked; each time step contributes one value per key, in this order.
relevant_keys = ['downsample_layer_0', 'downsample_layer_1', 'downsample_layer_2', 'upsample_layer_0', 'upsample_layer_1', 'upsample_layer_2']

for num_bits in [8,4]:
    for quantization_method in ['global', 'per_channel', 'grouped_input_channels']:
        baseline_name = f"{quantization_method}_{num_bits}"
        q_model = quantize_model(base_model, bits=num_bits, mode=quantization_method)
        out = generate_audio(q_model, pitch_val=test_pitches[0], singer_id=test_singer_ids[0], num_steps=100, audio_only=False)
        samples, _, singers, (interim_activations, _, all_activation), _ = out

        # Flattened as [t0/key0, t0/key1, ..., t0/key5, t1/key0, ...].
        activation_list = [
            torch.mean(torch.abs(all_activation[t][relkey])).cpu()
            for t in np.linspace(0.0, 0.99, 100)
            for relkey in relevant_keys
        ]

        trajectories_dict[baseline_name] = torch.stack(activation_list)

In [None]:
# Same trajectory as for the quantized variants, computed for the unquantized model.
out = generate_audio(base_model, pitch_val=test_pitches[0], singer_id=test_singer_ids[0], num_steps=100, audio_only=False)
samples, _, singers, (interim_activations, _, all_activation), _ = out

relevant_keys = ['downsample_layer_0', 'downsample_layer_1', 'downsample_layer_2', 'upsample_layer_0', 'upsample_layer_1', 'upsample_layer_2']

# One mean absolute activation per (time step, layer), time step varying slowest.
activation_list = [
    torch.mean(torch.abs(all_activation[t][relkey])).cpu()
    for t in np.linspace(0.0, 0.99, 100)
    for relkey in relevant_keys
]

trajectories_dict['base_model'] = torch.stack(activation_list)

In [None]:
plt.figure(figsize=(7, 5), dpi=300)

# Trajectory key -> legend label, in plotting order.
legend_labels = {
    'base_model': "Base",
    'global_8': "Per-Tensor",
    'grouped_input_channels_8': "Ours",
}
diffusion_steps = np.linspace(0, 99, 100)

for name, realname in legend_labels.items():
    traj = trajectories_dict[name]
    # Each time step stores 6 layer values; a stride of 6 from index 0 keeps
    # only the first layer ('downsample_layer_0') at every step.
    plt.plot(diffusion_steps, traj[::6], label=realname)

plt.legend(fontsize=20)
#plt.yscale('log')
#plt.xscale('log')
plt.ylim(0,5)
plt.title("Activations Over Time", fontsize=30)
plt.xlabel("Diffusion Step", fontsize=24)
plt.ylabel("Activation Mag.", fontsize=24)
plt.tight_layout()

plt.tick_params(axis='both', which='major', labelsize=16)  # For major ticks


# Save with transparency
plt.savefig('activations_over_time_nb.svg', dpi=300, transparent=True, bbox_inches='tight')
