In [1]:
import glob
import matplotlib.pyplot as plt
import torchvision
import librosa
import random
import json
import pickle
import numpy as np
import torch
import os
torch.multiprocessing.set_sharing_strategy('file_system')
import tqdm
import gc
import sys
sys.path.append("/home/user/Research/sonifications-paper1517/")
from src.utilities import fourier_analysis
from src.utilities import interpretability_utils
from src.models.contrastive_model import Model
from src.utilities.config_parser import get_config
from src.data.raw_transforms import get_raw_transforms_v2
from src.data.raw_dataset import RawWaveformDataset
from torch.utils.data import DataLoader
from src.utilities.fourier_analysis import apply_notch_filter
from src.data.utils import _collate_fn_raw, _collate_fn_raw_multiclass
import IPython.display as ipd

from src.models.cnn12_decoder import DeconvolutionalDecoder
from src.utilities.fourier_analysis import compute_fft, get_top_n_frequency_peaks
from sklearn.preprocessing import MinMaxScaler, minmax_scale
import soundfile as sf
import matplotlib.pyplot as plt
# plt.rcParams["figure.figsize"] = (30,40)
import random
import json
import tqdm
import numpy as np
from src.data.utils import _collate_fn_raw, _collate_fn_raw_multiclass
from src.data.raw_transforms import get_raw_transforms_v2, simple_supervised_transforms
from src.data.raw_transforms import get_raw_transforms_v2, simple_supervised_transforms, PadToSize, PeakNormalization, Compose
from src.data.raw_dataset import RawWaveformDataset
from torch.utils.data import DataLoader
# Eight random colours, each an (r, g, b) tuple of floats drawn with random.random().
cmaps = [tuple(random.random() for _channel in range(3)) for _ in range(8)]



In [2]:
def make_dataloader(meta_dir, audio_config, crop_size=5, csv_name=None, mode="multilabel", delim=";"):
    """Build an unshuffled, batch-size-1 evaluation loader over raw waveforms.

    Args:
        meta_dir: directory containing ``lbl_map.json`` and the evaluation CSV.
        audio_config: mapping; ``audio_config['sample_rate']`` is read here and
            the whole mapping is forwarded to ``RawWaveformDataset``.
        crop_size: clip length in seconds; multiplied by the sample rate and
            truncated to an int to obtain the padded length in samples.
        csv_name: CSV file name inside ``meta_dir``. A falsy value selects
            ``"sonification_eval.csv"``.
        mode: forwarded to the dataset. ``"multilabel"`` selects
            ``_collate_fn_raw``; any other value selects
            ``_collate_fn_raw_multiclass``.
        delim: forwarded to the dataset as ``delimiter``.

    Returns:
        ``(val_loader, val_set, lbl_map, inv_lbl_map)``, where ``lbl_map`` is
        the parsed ``lbl_map.json`` and ``inv_lbl_map`` maps its values back
        to its keys.
    """
    label_json = os.path.join(meta_dir, "lbl_map.json")
    csv_file = csv_name if csv_name else "sonification_eval.csv"
    csv_path = os.path.join(meta_dir, csv_file)

    sample_rate = audio_config['sample_rate']
    num_samples = int(crop_size * sample_rate)

    # Evaluation-time transforms only: pad (wrapping) to a fixed length, then
    # peak-normalise.
    eval_transforms = Compose([
        PadToSize(num_samples, 'wrap'),
        PeakNormalization(sr=sample_rate),
    ])
    val_set = RawWaveformDataset(
        csv_path,
        label_json,
        audio_config,
        mode=mode,
        transform=eval_transforms,
        is_val=True,
        delimiter=delim,
    )

    if mode == "multilabel":
        collate = _collate_fn_raw
    else:
        collate = _collate_fn_raw_multiclass
    val_loader = DataLoader(
        val_set,
        sampler=None,
        num_workers=1,
        collate_fn=collate,
        shuffle=False,
        batch_size=1,
        pin_memory=False,
    )

    with open(label_json, "r") as handle:
        lbl_map = json.load(handle)
    # Reverse lookup: label value -> label key.
    inv_lbl_map = dict(zip(lbl_map.values(), lbl_map.keys()))

    return val_loader, val_set, lbl_map, inv_lbl_map

In [3]:
def get_model_and_data(exp_dir, is_contrastive, num_random_maps=0.1,
                       output_dir_name="featuremap_expection", last_epoch=None):
    """Load a trained model with its deconvolutional decoder and the eval data.

    Args:
        exp_dir: experiment directory handed to the ``interpretability_utils``
            loader.
        is_contrastive: if truthy, load via
            ``prep_contrastive_model_and_decoder``; otherwise via
            ``prep_finetuned_model_and_decoder``.
        num_random_maps: accepted for backward compatibility; not used here.
        output_dir_name: accepted for backward compatibility; not used here.
        last_epoch: checkpoint epoch for the fine-tuned loader. A falsy value
            (``None`` or ``0``) falls back to epoch 50. Ignored when
            ``is_contrastive`` is truthy.

    Returns:
        dict with keys ``"model"``, ``"net"``, ``"deconv"``, ``"hparams"``,
        ``"loader"``, ``"dset"``, ``"lbl_map"`` and ``"inv_lbl_map"``.

    Note:
        The evaluation metadata directory is hard-coded below.
    """
    if is_contrastive:
        res = interpretability_utils.prep_contrastive_model_and_decoder(exp_dir)
    else:
        # Single call site; the fallback epoch matches the previous behaviour.
        epoch = last_epoch if last_epoch else 50
        res = interpretability_utils.prep_finetuned_model_and_decoder(exp_dir, last_epoch=epoch)
    model, net, deconv, hparams = res['full_model'], res['feature_extractor'], res['deconv_decoder'], res['hparams']

    loader, dset, lbl_map, inv_lbl_map = make_dataloader("/media/user/nvme/datasets/fsd50k/fsd50k_8000/meta/",
                                                         hparams.cfg['audio_config'], csv_name="eval.csv", delim=",")
    results = {
        "model": model,
        "net": net,
        "deconv": deconv,
        "hparams": hparams,
        "loader": loader,
        "dset": dset,
        "lbl_map": lbl_map,
        # Previously computed and discarded; exposed so callers can map
        # label values back to label keys.
        "inv_lbl_map": inv_lbl_map,
    }
    return results

In [4]:
import math

In [5]:
def process_data_record(meta, x, y, layer_index, top_k_perc=0.01, gaussian_perturb=False, baseline=False):
    """Sonify one record at a given layer and classify what is left over.

    The input is run through ``meta['net']``, the top map of layer
    ``layer_index`` is projected back to the input by ``meta['deconv']``, and
    the result is subtracted from the input. ``meta['model']`` is then asked
    for predictions on that residual.

    Args:
        meta: mapping providing ``'net'``, ``'deconv'`` and ``'model'``.
        x: input tensor for a single record; a leading dimension is added
            with ``unsqueeze(0)`` (presumably a raw waveform -- confirm
            against the dataset).
        y: unused.
        layer_index: layer to visualise; also used to look up
            ``'act<layer_index>'`` in the extracted features.
        top_k_perc: currently ignored -- the number of maps is fixed to 1.
        gaussian_perturb: when ``baseline`` is true, pass the input through
            the module-level ``noiser`` before computing baseline predictions.
        baseline: also compute predictions of ``meta['model']`` on the
            (optionally noised) input.

    Returns:
        ``(residual_sig, signal_deconved, residual_supervised_preds,
        supervised_preds)``; ``supervised_preds`` is ``None`` unless
        ``baseline`` is true.

    Note:
        Calls ``.cuda()`` unconditionally, so a CUDA device is required.
    """
    batch = x.unsqueeze(0)
    signal_input = batch.squeeze().cpu().numpy()

    # The baseline input is prepared before `batch` is moved to the GPU, so
    # the noise transform receives the same tensor it always did.
    baseline_input = None
    if baseline:
        baseline_input = noiser(batch).cuda() if gaussian_perturb else batch.cuda()

    batch = batch.cuda()
    feats, switches = interpretability_utils.infer_model(meta['net'], batch)

    if baseline:
        supervised_preds = interpretability_utils.get_supervised_predictions(meta['model'], baseline_input)
    else:
        supervised_preds = None

    # The looked-up activations are not used further because the map count is
    # fixed below; the lookup is kept so an unknown layer index fails here.
    act_feats = feats['act{}'.format(layer_index)]
    # NOTE(review): `top_k_perc` is ignored; formerly
    # int(math.ceil(top_k_perc * act_feats.shape[1])).
    top_k_num_maps = 1

    with torch.no_grad():
        raw_maps = meta['deconv'].visualize_top_k_maps(batch, feats, switches,
                                                       layer_index, top_k_num_maps)
        # Rescale every projected map against the original signal ...
        scaled_maps = []
        for one_map in raw_maps:
            scaled_maps.append(interpretability_utils.process_vis(one_map, signal_input))
        # ... and average them along the first axis.
        signal_deconved = np.asarray(scaled_maps).mean(0)

    residual_sig = signal_input - signal_deconved

    residual_input = interpretability_utils.prep_input(residual_sig)
    residual_supervised_preds = interpretability_utils.get_supervised_predictions(meta['model'], residual_input)

    return residual_sig, signal_deconved, residual_supervised_preds, supervised_preds

In [6]:
# meta1 = get_model_and_data("/media/user/nvme/contrastive_experiments/experiments_fsd50k_v3/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fconly/", False)

In [7]:
meta2 = get_model_and_data("/media/user/nvme/contrastive_experiments/experiments_fsd50k_latest/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fconly_rs8882", False, last_epoch=96)

Loading /media/user/nvme/contrastive_experiments/experiments_fsd50k_latest/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fconly_rs8882/ckpts/epoch=096_tr_loss=0.042350_tr_acc=0.475538_val_acc=0.315252.pth


In [8]:
meta3 = get_model_and_data("/media/user/nvme/contrastive_experiments/experiments_fsd50k_latest/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fconly_rs8883", False, last_epoch=97)

Loading /media/user/nvme/contrastive_experiments/experiments_fsd50k_latest/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fconly_rs8883/ckpts/epoch=097_tr_loss=0.042332_tr_acc=0.473105_val_acc=0.320551.pth


In [9]:
meta4 = get_model_and_data("/media/user/nvme/contrastive_experiments/experiments_fsd50k_v3/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fconly_r4/", False, last_epoch=100)

Loading /media/user/nvme/contrastive_experiments/experiments_fsd50k_v3/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fconly_r4/ckpts/epoch=100_tr_loss=0.042172_tr_acc=0.478655_val_acc=0.325465.pth


In [10]:
from sklearn.metrics import average_precision_score

In [11]:
def calculate_mAP(preds, gts, mode="macro"):
    """Compute average precision over all collected predictions.

    Args:
        preds: sequence of prediction tensors, concatenated along dim 0.
        gts: sequence of ground-truth tensors, concatenated along dim 0.
        mode: forwarded to ``average_precision_score`` as ``average``.

    Returns:
        Whatever ``average_precision_score`` returns for the stacked arrays.
    """
    y_score = torch.cat(preds, dim=0).cpu().numpy()
    y_true = torch.cat(gts, dim=0).cpu().numpy()
    return average_precision_score(y_true, y_score, average=mode)

In [12]:
from src.data.raw_transforms import AddGaussianNoise

In [13]:
# Module-level noise transform; process_data_record applies it to the input
# when called with baseline=True and gaussian_perturb=True.
noiser = AddGaussianNoise()

In [14]:
def process_layer(meta, layer_index, top_k_perc=0.1, gaussian=False, baseline=False):
    """Evaluate residual mAP for one layer over the whole dataset.

    Every record of ``meta['dset']`` is passed through
    ``process_data_record``; the predictions on the residual signals are
    scored against the ground truth with ``calculate_mAP``.

    Args:
        meta: mapping as returned by ``get_model_and_data``; ``meta['dset']``
            is indexed here and ``meta`` is forwarded to
            ``process_data_record``.
        layer_index: layer whose sonification is subtracted from the input.
        top_k_perc: forwarded to ``process_data_record``.
        gaussian: forwarded as ``gaussian_perturb``.
        baseline: if true, also score the predictions on the baseline input.

    Returns:
        ``(residual_mAP, baseline_map)``; ``baseline_map`` is ``None`` when
        ``baseline`` is false.
    """
    # Imported locally: `tqdm.notebook` is a submodule and is only reachable
    # as an attribute of `tqdm` after it has been imported somewhere.
    from tqdm.notebook import tqdm_notebook

    dset = meta['dset']
    gts = []
    res_preds = []
    inp_preds = []
    for ix in tqdm_notebook(range(len(dset)), position=2):
        x, y = dset[ix]
        _, _, residual_preds, input_preds = process_data_record(
            meta, x, y, layer_index, top_k_perc,
            baseline=baseline, gaussian_perturb=gaussian)
        gts.append(y.unsqueeze(0))
        res_preds.append(residual_preds.detach().cpu())
        if baseline:
            # Stored the same way as the residual predictions (detached, on
            # CPU); nothing is collected when no baseline is requested,
            # instead of a list of None.
            inp_preds.append(input_preds.detach().cpu())

    baseline_map = calculate_mAP(inp_preds, gts) if baseline else None
    residual_mAP = calculate_mAP(res_preds, gts)
    print(residual_mAP, baseline_map)
    return residual_mAP, baseline_map

In [15]:

# orig map: 0.28305179
# gaussian noise (0.15 max): 0.26651813197676033

In [16]:
from tqdm import notebook

In [17]:
# meta_results = {0: {1: 0.112363904929497,
#   2: 0.16124149684864902,
#   3: 0.17087513362364473,
#   4: 0.13637552450176288,
#   5: 0.1516938459454982,
#   6: 0.13516237455953947,
#   7: 0.13348355100642734,
#   8: 0.15114770606482256,
#   9: 0.16080071509563712,
#   10: 0.16559746885370352,
#   11: 0.18547917782726564},}

In [18]:
# Residual mAP per model and layer: meta_results[model_idx][layer_idx].
meta_results = {}
# Baseline mAP per model, in the same order as the models below. Recorded
# here so the numbers do not have to be copied from the printed output.
baselines = []
for model_idx, meta in enumerate([meta2, meta3, meta4]):
    results = {}
    for layer_idx in notebook.tqdm(range(1, 12), position=1):
        # The baseline predictions do not depend on the layer, so they are
        # only computed during the first layer's pass.
        baseline = layer_idx == 1
        res_mAP, baseline_mAP = process_layer(meta, layer_idx, top_k_perc=1., baseline=baseline)
        results[layer_idx] = res_mAP
        if baseline:
            baselines.append(baseline_mAP)
    meta_results[model_idx] = results
# Number of models processed (same final value the manual counter had).
meta_count = len(meta_results)

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/10231 [00:00<?, ?it/s]

0.12566471177609206 0.303938489322368


  0%|          | 0/10231 [00:00<?, ?it/s]

0.1647873842473217 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.16019018325134254 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.15784904774401828 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.14319607804204895 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.16323147087577017 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.13900931630018876 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.1666123977618358 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.17142070211346172 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.17521368748782357 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.1897644526447548 None


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/10231 [00:00<?, ?it/s]

0.128077341496854 0.3093833789899899


  0%|          | 0/10231 [00:00<?, ?it/s]

0.1563530354432886 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.15620128503875821 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.14879247434643703 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.1393481996769682 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.14218285483776774 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.13139365072408943 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.15342667145954111 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.1572712177826516 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.17504515113215752 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.17780279074768604 None


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/10231 [00:00<?, ?it/s]

0.1137792364142563 0.30591501256690773


  0%|          | 0/10231 [00:00<?, ?it/s]

0.1663257130724236 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.17699919104957346 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.14259340262057157 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.1587097415129478 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.1404819869242055 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.13896763261105718 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.15722643945899098 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.16610890440179626 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.17226367902502862 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.19119348947204767 None


In [19]:
# Baseline mAP (predictions on the unmodified input) for meta2, meta3 and
# meta4, in that order. The values match the second number printed on the
# layer-1 line of each model's run above.
baselines = [0.303938489322368, 0.3093833789899899, 0.30591501256690773]

In [21]:
# Show the residual mAP for every model (outer key) and layer (inner key).
meta_results

{0: {1: 0.12566471177609206,
  2: 0.1647873842473217,
  3: 0.16019018325134254,
  4: 0.15784904774401828,
  5: 0.14319607804204895,
  6: 0.16323147087577017,
  7: 0.13900931630018876,
  8: 0.1666123977618358,
  9: 0.17142070211346172,
  10: 0.17521368748782357,
  11: 0.1897644526447548},
 1: {1: 0.128077341496854,
  2: 0.1563530354432886,
  3: 0.15620128503875821,
  4: 0.14879247434643703,
  5: 0.1393481996769682,
  6: 0.14218285483776774,
  7: 0.13139365072408943,
  8: 0.15342667145954111,
  9: 0.1572712177826516,
  10: 0.17504515113215752,
  11: 0.17780279074768604},
 2: {1: 0.1137792364142563,
  2: 0.1663257130724236,
  3: 0.17699919104957346,
  4: 0.14259340262057157,
  5: 0.1587097415129478,
  6: 0.1404819869242055,
  7: 0.13896763261105718,
  8: 0.15722643945899098,
  9: 0.16610890440179626,
  10: 0.17226367902502862,
  11: 0.19119348947204767}}