## This notebook extracts sonifications for AudioSet models

In [1]:
import glob
import matplotlib.pyplot as plt
import torchvision
import librosa
import random
import json
import pickle
import numpy as np
import torch
import os
torch.multiprocessing.set_sharing_strategy('file_system')
import tqdm
import gc
import sys
sys.path.append("/home/user/Research/sonifications-paper1517/")
from src.utilities import fourier_analysis
from src.utilities import interpretability_utils
from src.models.contrastive_model import Model
from src.utilities.config_parser import get_config
from src.data.raw_transforms import get_raw_transforms_v2
from src.data.raw_dataset import RawWaveformDataset
from torch.utils.data import DataLoader
from src.utilities.fourier_analysis import apply_notch_filter
from src.data.utils import _collate_fn_raw, _collate_fn_raw_multiclass
import IPython.display as ipd

from src.models.cnn12_decoder import DeconvolutionalDecoder
from src.utilities.fourier_analysis import compute_fft, get_top_n_frequency_peaks
from sklearn.preprocessing import MinMaxScaler, minmax_scale
import soundfile as sf
import matplotlib.pyplot as plt
# plt.rcParams["figure.figsize"] = (30,40)
import random
import json
import tqdm
import numpy as np
from src.data.utils import _collate_fn_raw, _collate_fn_raw_multiclass
from src.data.raw_transforms import get_raw_transforms_v2, simple_supervised_transforms
from src.data.raw_dataset import RawWaveformDataset
from torch.utils.data import DataLoader
# Eight random RGB triples, each channel drawn uniformly from [0, 1).
cmaps = [tuple(random.random() for _channel in range(3)) for _color in range(8)]



In [2]:
def get_spec(x):
    """Return a dB-scaled spectrogram of the waveform ``x``.

    The STFT is taken with ``fourier_analysis.perform_stft`` (nfft=512,
    noverlap_ms=0.01, zero boundary); its third return value is squared in
    magnitude and passed through ``librosa.amplitude_to_db``.
    """
    stft_result = fourier_analysis.perform_stft(x, noverlap_ms=0.01, nfft=512, boundary="zeros")
    _, _, complex_spec, _ = stft_result
    squared_magnitude = np.abs(complex_spec) ** 2
    # NOTE(review): amplitude_to_db is applied to a *squared* magnitude here, so
    # the dB values are on a different scale than power_to_db would give.
    # Kept as-is to stay comparable with previously saved outputs - confirm intent.
    return librosa.amplitude_to_db(squared_magnitude)

In [3]:
def make_dataloader(meta_dir, audio_config, crop_size=5, csv_name=None, mode="multilabel", delim=";"):
    """Build the evaluation dataset, its DataLoader and the label maps.

    Args:
        meta_dir: directory holding ``lbl_map.json`` and the evaluation csv.
        audio_config: mapping that must provide ``'sample_rate'``; it is also
            forwarded unchanged to ``RawWaveformDataset``.
        crop_size: crop length, multiplied by the sample rate to get samples.
        csv_name: csv file inside ``meta_dir``; a falsy value selects
            ``"sonification_eval.csv"``.
        mode: ``"multilabel"`` selects ``_collate_fn_raw``; any other value
            selects ``_collate_fn_raw_multiclass``.
        delim: delimiter forwarded to ``RawWaveformDataset``.

    Returns:
        ``(val_loader, val_set, lbl_map, inv_lbl_map)`` where ``inv_lbl_map``
        is ``lbl_map`` with keys and values swapped.
    """
    label_map_file = os.path.join(meta_dir, "lbl_map.json")
    csv_file = os.path.join(meta_dir, csv_name if csv_name else "sonification_eval.csv")

    crop_samples = int(crop_size * audio_config['sample_rate'])
    eval_transforms = simple_supervised_transforms(
        False, crop_samples, sample_rate=audio_config['sample_rate'])
    val_set = RawWaveformDataset(
        csv_file, label_map_file, audio_config,
        mode=mode, transform=eval_transforms, is_val=True, delimiter=delim)

    if mode == "multilabel":
        collate = _collate_fn_raw
    else:
        collate = _collate_fn_raw_multiclass
    val_loader = DataLoader(
        val_set,
        batch_size=1,
        shuffle=False,
        sampler=None,
        num_workers=1,
        pin_memory=False,
        collate_fn=collate,
    )

    with open(label_map_file, "r") as fd:
        lbl_map = json.load(fd)
    inv_lbl_map = dict((index, name) for name, index in lbl_map.items())

    return val_loader, val_set, lbl_map, inv_lbl_map


def analyze_random_maps(layer_index, loader, dataset, inv_lbl_map,
                        net, deconv,
                        num_random_maps=0.1, top_per_map=9, to_exclude=None):
    """Find, for a random subset of feature maps, the inputs that excite them most.

    Pass 1 walks the whole ``dataset`` and records, per selected feature map of
    ``act{layer_index}``, the mean activation for every sample that is kept.
    Pass 2 re-loads the ``top_per_map`` strongest samples of each map and runs
    the deconvolutional decoder on them.

    Args:
        layer_index: layer to inspect; must be in ``[1, 11]``.
        loader: unused; kept for call-site compatibility.
        dataset: indexable dataset yielding ``(x, y)`` tensors.
        inv_lbl_map: label index -> label name mapping.
        net: model passed to ``interpretability_utils.infer_model`` and called
            directly as ``net(inp, True)``.
        deconv: decoder exposing ``visualize_specific_map``.
        num_random_maps: fraction of the layer's maps to analyse.
        top_per_map: number of top-activating samples to keep per map.
        to_exclude: label indices; a sample carrying any of them is skipped.
            ``None`` (the default) excludes nothing.

    Returns:
        ``(indices, outputs)``. ``indices[k]`` is a tensor of positions into
        the list of used samples; ``outputs[k]`` is a list of dicts with keys
        ``"data_idx"``, ``"vis"`` and ``"gt"``.
    """
    # Import the submodule explicitly: a bare `import tqdm` does not load
    # `tqdm.notebook`, so attribute access only works if something else did.
    from tqdm.notebook import tqdm_notebook

    assert 1 <= layer_index <= 11
    if to_exclude is None:
        to_exclude = ()

    features_of_interest = {}
    random_maps = None
    used_data_points = []
    skipped = 0
    for data_idx in tqdm_notebook(range(len(dataset)), position=1):
        x, y = dataset[data_idx]
        x = x.unsqueeze(0)
        min_ = x.min()
        max_ = x.max()
        # Report a violation of either bound (the previous `and` only fired
        # when both bounds were exceeded at once).
        if min_ < -1 or max_ > 1:
            print("IN ANALYZIE RANDOM MAX, INPUT MIN, MAX:", min_, max_)

        idxs = torch.where(y == 1)[0].tolist()
        if any(idx in to_exclude for idx in idxs):
            skipped += 1
            continue

        x = x.cuda()
        output_features, _ = interpretability_utils.infer_model(net, x)
        act_feats = output_features['act{}'.format(layer_index)]
        if random_maps is None:
            # The map subset is drawn once, from the first sample that is kept.
            num_maps = int(num_random_maps * act_feats.shape[1])
            random_maps = np.random.permutation(act_feats.shape[1])[:num_maps]
            # Uses its own loop variable so the dataset index is not clobbered.
            for map_idx in random_maps:
                features_of_interest[map_idx] = []
        for map_idx in random_maps:
            features_of_interest[map_idx].append(
                act_feats[0, map_idx, :].detach().cpu().mean())
        used_data_points.append(data_idx)

    print("Skipped:", skipped)

    # Rank the used samples per map by mean activation, strongest first.
    indices = {}
    for k, values in features_of_interest.items():
        mean_activations = torch.tensor(values)
        indices[k] = torch.argsort(mean_activations, descending=True)[:top_per_map]

    outputs = {k: [] for k in indices}
    for k, idxs in indices.items():
        for idx in idxs.tolist():
            data_idx = used_data_points[idx]
            inp, y = dataset[data_idx]
            label_indicators = torch.where(y == 1)[0].tolist()
            gt = ";".join([inv_lbl_map[lbl_idx] for lbl_idx in label_indicators])
            inp = inp.unsqueeze(0).cuda()
            with torch.no_grad():
                pred, output_features, switch_indices = net(inp, True)
                vis = deconv.visualize_specific_map(inp, output_features, switch_indices, layer_index, k)
            outputs[k].append({
                "data_idx": data_idx,
                "vis": interpretability_utils.process_vis(vis.squeeze(), inp.squeeze().cpu().numpy()),
                "gt": gt
            })

    return indices, outputs


def process_top_n(data, dataset, inv_lbl_map, save_dir, save_plots=False, sample_rate=8000):
    """Write spectrograms, audio, labels and (optionally) plots for one feature map.

    Args:
        data: list of dicts with keys ``"data_idx"``, ``"vis"`` and ``"gt"``
            (as produced by ``analyze_random_maps``).
        dataset: indexable dataset yielding ``(x, y)``; ``x`` is re-read by
            ``data_idx`` instead of being kept in memory.
        inv_lbl_map: unused; kept for call-site compatibility.
        save_dir: existing directory that receives all outputs.
        save_plots: when true, also save ``waveforms.png``,
            ``input_spectrograms.png`` and ``sonified_spectrograms.png``
            showing up to the first nine entries in a 3x3 layout.
        sample_rate: sample rate written into the ``.wav`` files.

    Outputs in ``save_dir``: ``input_spectrograms/``, ``deconv_spectrograms/``
    (``.npy``), ``input_audio/``, ``deconv_audio/`` (``.wav``) and ``gt.txt``
    with one label line per entry of ``data``.
    """
    name1 = os.path.join(save_dir, "waveforms.png")
    name2 = os.path.join(save_dir, "input_spectrograms.png")
    name3 = os.path.join(save_dir, "sonified_spectrograms.png")
    # The plots are 3x3 grids; never index past the available results.
    grid_count = min(9, len(data))

    def _input_waveform(entry):
        """Re-load the input waveform of ``entry`` as a numpy array."""
        x, _ = dataset[entry['data_idx']]
        return x.squeeze().numpy()

    def save_specs_and_audio():
        """Save per-entry spectrograms (.npy) and audio (.wav) for input and sonification."""
        input_spec_dir = os.path.join(save_dir, "input_spectrograms")
        deconv_spec_dir = os.path.join(save_dir, "deconv_spectrograms")
        input_audio_dir = os.path.join(save_dir, "input_audio")
        deconv_audio_dir = os.path.join(save_dir, "deconv_audio")

        for fld in (input_spec_dir, input_audio_dir, deconv_spec_dir, deconv_audio_dir):
            os.makedirs(fld, exist_ok=True)

        for i, d_i in enumerate(data):
            x_data = _input_waveform(d_i)
            vis_data = d_i['vis']
            np.save(os.path.join(input_spec_dir, "input_{:02d}.npy".format(i)), get_spec(x_data))
            np.save(os.path.join(deconv_spec_dir, "deconv_{:02d}.npy".format(i)), get_spec(vis_data))
            sf.write(os.path.join(input_audio_dir, "input_{:02d}.wav".format(i)), x_data,
                     samplerate=sample_rate)
            sf.write(os.path.join(deconv_audio_dir, "deconv_{:02d}.wav".format(i)), vis_data,
                     samplerate=sample_rate)

    def plot_waves():
        """Overlay each input waveform with its sonification in a 3x3 figure."""
        fig = plt.figure(figsize=(20, 20))
        for i in range(grid_count):
            d_i = data[i]
            ax = fig.add_subplot(3, 3, i + 1)
            ax.plot(_input_waveform(d_i))
            ax.plot(d_i['vis'])
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(name1)
        # Close the explicit handle; calling plt.clf() after closing would
        # silently create a fresh figure.
        plt.close(fig)
        gc.collect()

    def _save_spec_grid(waveforms, title, path):
        """Render the dB spectrograms of ``waveforms`` as one image grid and save it."""
        specs_data = []
        for wave in waveforms:
            _, _, data_spec, _ = fourier_analysis.perform_stft(wave, noverlap_ms=0.01, nfft=512)
            spec = np.abs(data_spec) ** 2
            spec = librosa.amplitude_to_db(spec)
            specs_data.append(torch.from_numpy(spec).unsqueeze(0).unsqueeze(0))
        grid = torchvision.utils.make_grid(torch.cat(specs_data), 3, 3)
        image = grid.permute(1, 2, 0)[:, :, 0].numpy()

        fig, ax = plt.subplots()
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(path, dpi=150)
        plt.close(fig)
        gc.collect()

    def plot_input_specs():
        """Save the grid of input spectrograms."""
        _save_spec_grid([_input_waveform(d_i) for d_i in data[:grid_count]], "Input", name2)

    def plot_vis_specs():
        """Save the grid of sonified spectrograms."""
        _save_spec_grid([d_i['vis'] for d_i in data[:grid_count]], "Sonified", name3)

    if save_plots and grid_count > 0:
        plot_waves()
        plot_input_specs()
        plot_vis_specs()

        plt.close("all")
        gc.collect()

    save_specs_and_audio()

    # One label line per saved entry so gt.txt always lines up with the
    # numbered audio/spectrogram files.
    label_lines = ["{}\n".format(d_i['gt']) for d_i in data]

    with open(os.path.join(save_dir, "gt.txt"), "w") as fd:
        fd.writelines(label_lines)

In [4]:
def evaluate(exp_dir, is_contrastive, num_random_maps=0.1,
             output_dir_name="featuremap_expection", start_layer=1,
             save_plots=False, last_epoch=None,
             meta_dir="/media/user/nvme/datasets/audioset/meta_8000/",
             csv_name="eval.csv"):
    """Sonify feature maps of every layer from ``start_layer`` to 11 for one experiment.

    Args:
        exp_dir: experiment directory; results go to ``exp_dir/output_dir_name``.
        is_contrastive: choose ``prep_contrastive_model_and_decoder`` (true) or
            ``prep_finetuned_model_and_decoder`` (false) to load the model.
        num_random_maps: fraction of each layer's maps to analyse.
        output_dir_name: name of the output folder inside ``exp_dir``.
        start_layer: first layer to process (inclusive).
        save_plots: forwarded to ``process_top_n``.
        last_epoch: checkpoint epoch for the non-contrastive loader; ``None``
            means 50. Ignored when ``is_contrastive`` is true.
        meta_dir: dataset metadata directory handed to ``make_dataloader``.
        csv_name: evaluation csv inside ``meta_dir``.

    Samples labelled 'Sine wave', 'Static', or any label whose name contains
    "noise" (case-insensitive) are excluded from the analysis.
    """
    # Import the submodule explicitly: a bare `import tqdm` does not load it.
    from tqdm.notebook import tqdm_notebook

    if is_contrastive:
        res = interpretability_utils.prep_contrastive_model_and_decoder(exp_dir)
    else:
        if last_epoch is None:
            last_epoch = 50
        res = interpretability_utils.prep_finetuned_model_and_decoder(exp_dir, last_epoch=last_epoch)

    model, net, deconv, hparams = res['full_model'], res['feature_extractor'], res['deconv_decoder'], res['hparams']

    loader, dset, lbl_map, inv_lbl_map = make_dataloader(meta_dir,
                                                         hparams.cfg['audio_config'],
                                                         csv_name=csv_name)

    excluded_names = ['Sine wave', 'Static']
    excluded_names.extend(k for k in lbl_map.keys() if "noise" in k.lower())
    to_exclude = [lbl_map[name] for name in excluded_names]

    base_output_dir = os.path.join(exp_dir, output_dir_name)
    os.makedirs(base_output_dir, exist_ok=True)

    for layer in tqdm_notebook(range(start_layer, 12), position=0):
        _, results = analyze_random_maps(layer, loader, dset, inv_lbl_map,
                                         net, deconv, num_random_maps, to_exclude=to_exclude)
        for k, v in tqdm_notebook(results.items(), position=2):
            # Layout: <output>/<layer:02d>/<map index:04d>/
            tgt_dir = os.path.join(base_output_dir, "{:02d}".format(layer), "{:04d}".format(k))
            os.makedirs(tgt_dir, exist_ok=True)
            process_top_n(v, dset, inv_lbl_map, tgt_dir, save_plots=save_plots)

In [5]:
# evaluate("/media/user/nvme/contrastive_experiments/experiments_audioset_v5_full/cnn12_1x_full_tr_8x128_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_full_ft_fullmodel_r2",
#          False, output_dir_name="inspection_all_maps_f", num_random_maps=1., start_layer=1)

In [13]:
# Contrastive checkpoint (rs8882): analyse all feature maps, starting at layer 1.
evaluate(
    "/media/user/nvme/contrastive_experiments/experiments_contrastive_v5/cnn12_1x_full_tr_8x256_Adam_1e-3_warmupcosine_0.5_wd1e-5_fixed_lr_scaling_randomgain_gaussiannoise_timemasking_bgnoise_nolineareval_rs8882",
    is_contrastive=True,
    num_random_maps=1.,
    start_layer=1,
    output_dir_name="inspection_all_maps_f",
)

Loading /media/user/nvme/contrastive_experiments/experiments_contrastive_v5/cnn12_1x_full_tr_8x256_Adam_1e-3_warmupcosine_0.5_wd1e-5_fixed_lr_scaling_randomgain_gaussiannoise_timemasking_bgnoise_nolineareval_rs8882/ckpts/epoch=100_tr_loss=1.122735_tr_acc=0.854891.pth


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/128 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/128 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/512 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/512 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/2048 [00:00<?, ?it/s]

In [14]:
# Contrastive checkpoint (rs8883): analyse all feature maps, starting at layer 1.
evaluate(
    "/media/user/nvme/contrastive_experiments/experiments_contrastive_v5/cnn12_1x_full_tr_8x256_Adam_1e-3_warmupcosine_0.5_wd1e-5_fixed_lr_scaling_randomgain_gaussiannoise_timemasking_bgnoise_nolineareval_rs8883",
    is_contrastive=True,
    num_random_maps=1.,
    start_layer=1,
    output_dir_name="inspection_all_maps_f",
)

Loading /media/user/nvme/contrastive_experiments/experiments_contrastive_v5/cnn12_1x_full_tr_8x256_Adam_1e-3_warmupcosine_0.5_wd1e-5_fixed_lr_scaling_randomgain_gaussiannoise_timemasking_bgnoise_nolineareval_rs8883/ckpts/epoch=100_tr_loss=1.120546_tr_acc=0.854796.pth


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/128 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/128 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/512 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/512 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/2048 [00:00<?, ?it/s]

In [15]:
# Non-contrastive baseline (rs8881), epoch-100 checkpoint: all feature maps, starting at layer 1.
s = "/media/user/nvme/contrastive_experiments/experiments_audioset_full_latest/cnn12_1x_full_tr_8x128_Adam_1e-3_warmupcosine_wd0._baseline_rs8881"
evaluate(
    s,
    is_contrastive=False,
    num_random_maps=1.,
    start_layer=1,
    last_epoch=100,
    output_dir_name="inspection_all_maps_f",
)

Loading /media/user/nvme/contrastive_experiments/experiments_audioset_full_latest/cnn12_1x_full_tr_8x128_Adam_1e-3_warmupcosine_wd0._baseline_rs8881/ckpts/epoch=100_tr_loss=0.010972_tr_acc=0.516805_val_acc=0.298665.pth


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/128 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/128 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/512 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/512 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/2048 [00:00<?, ?it/s]

In [16]:
# Non-contrastive baseline (rs8882_noagc), epoch-100 checkpoint: all feature maps, starting at layer 1.
s = "/media/user/nvme/contrastive_experiments/experiments_audioset_full_latest/cnn12_1x_full_tr_8x128_Adam_1e-3_warmupcosine_wd0._baseline_rs8882_noagc"
evaluate(
    s,
    is_contrastive=False,
    num_random_maps=1.,
    start_layer=1,
    last_epoch=100,
    output_dir_name="inspection_all_maps_f",
)

Loading /media/user/nvme/contrastive_experiments/experiments_audioset_full_latest/cnn12_1x_full_tr_8x128_Adam_1e-3_warmupcosine_wd0._baseline_rs8882_noagc/ckpts/epoch=100_tr_loss=0.010980_tr_acc=0.515919_val_acc=0.298523.pth


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/128 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/128 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/512 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/512 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/2048 [00:00<?, ?it/s]

In [6]:
# Non-contrastive baseline (rs8883_noagc), epoch-100 checkpoint: all feature maps of layer 11 only.
s = "/media/user/nvme/contrastive_experiments/experiments_audioset_full_latest/cnn12_1x_full_tr_8x128_Adam_1e-3_warmupcosine_wd0._baseline_rs8883_noagc"
evaluate(
    s,
    is_contrastive=False,
    num_random_maps=1.,
    start_layer=11,
    last_epoch=100,
    output_dir_name="inspection_all_maps_f",
)

Loading /media/user/nvme/contrastive_experiments/experiments_audioset_full_latest/cnn12_1x_full_tr_8x128_Adam_1e-3_warmupcosine_wd0._baseline_rs8883_noagc/ckpts/epoch=100_tr_loss=0.010980_tr_acc=0.515930_val_acc=0.299477.pth


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/17498 [00:00<?, ?it/s]

Skipped: 564


  0%|          | 0/2048 [00:00<?, ?it/s]

In [12]:
# evaluate("/media/user/nvme/contrastive_experiments/experiments_audioset_v5_full/cnn12_1x_full_tr_8x128_Adam_1e-3_warmupcosine_wd0._baseline/",
#          False, output_dir_name="inspection_all_maps_f", num_random_maps=1., start_layer=1)