In [1]:
import glob
import matplotlib.pyplot as plt
import torchvision
import librosa
import random
import json
import pickle
import numpy as np
import torch
import os
torch.multiprocessing.set_sharing_strategy('file_system')
import tqdm
import gc
import sys
sys.path.append("/home/user/Research/sonifications-paper1517/")
from src.utilities import fourier_analysis
from src.utilities import interpretability_utils
from src.models.contrastive_model import Model
from src.utilities.config_parser import get_config
from src.data.raw_transforms import get_raw_transforms_v2
from src.data.raw_dataset import RawWaveformDataset
from torch.utils.data import DataLoader
from src.utilities.fourier_analysis import apply_notch_filter
from src.data.utils import _collate_fn_raw, _collate_fn_raw_multiclass
import IPython.display as ipd

from src.models.cnn12_decoder import DeconvolutionalDecoder
from src.utilities.fourier_analysis import compute_fft, get_top_n_frequency_peaks
from sklearn.preprocessing import MinMaxScaler, minmax_scale
import soundfile as sf
import matplotlib.pyplot as plt
# plt.rcParams["figure.figsize"] = (30,40)
import random
import json
import tqdm
import numpy as np
from src.data.utils import _collate_fn_raw, _collate_fn_raw_multiclass
from src.data.raw_transforms import get_raw_transforms_v2, simple_supervised_transforms
from src.data.raw_dataset import RawWaveformDataset
from src.data.raw_transforms import get_raw_transforms_v2, simple_supervised_transforms, PadToSize, PeakNormalization, Compose
from torch.utils.data import DataLoader
# Eight random RGB triples; every channel is an independent random.random() draw.
cmaps = [tuple(random.random() for _channel in range(3)) for _slot in range(8)]



In [2]:
def make_dataloader(meta_dir, audio_config, crop_size=5, csv_name=None, mode="multilabel", delim=";"):
    """Build an unshuffled, batch-size-1 evaluation loader over raw waveforms.

    Args:
        meta_dir: directory containing ``lbl_map.json`` and the evaluation CSV.
        audio_config: mapping providing ``'sample_rate'``; it is also forwarded
            unchanged to ``RawWaveformDataset``.
        crop_size: target length in seconds; multiplied by the sample rate and
            truncated to an int before being handed to ``PadToSize``.
        csv_name: CSV file name inside ``meta_dir``. A falsy value selects
            ``"sonification_eval.csv"``.
        mode: ``"multilabel"`` picks ``_collate_fn_raw``; any other value picks
            ``_collate_fn_raw_multiclass``. Also forwarded to the dataset.
        delim: delimiter forwarded to ``RawWaveformDataset``.

    Returns:
        ``(val_loader, val_set, lbl_map, inv_lbl_map)``: the loader, the dataset
        behind it, the parsed label-map JSON, and its value -> key inversion.
    """
    label_map_file = os.path.join(meta_dir, "lbl_map.json")
    csv_file = os.path.join(meta_dir, csv_name if csv_name else "sonification_eval.csv")

    # Seconds -> samples.
    crop_size = int(crop_size * audio_config['sample_rate'])

    # Evaluation-time transforms: wrap-pad to the target length, then peak-normalise.
    pad_step = PadToSize(crop_size, 'wrap')
    norm_step = PeakNormalization(sr=audio_config['sample_rate'])
    eval_transforms = Compose([pad_step, norm_step])

    val_set = RawWaveformDataset(
        csv_file,
        label_map_file,
        audio_config,
        mode=mode,
        transform=eval_transforms,
        is_val=True,
        delimiter=delim,
    )

    if mode == "multilabel":
        collate = _collate_fn_raw
    else:
        collate = _collate_fn_raw_multiclass

    val_loader = DataLoader(
        val_set,
        sampler=None,
        num_workers=1,
        collate_fn=collate,
        shuffle=False,
        batch_size=1,
        pin_memory=False,
    )

    with open(label_map_file, "r") as handle:
        lbl_map = json.load(handle)
    inv_lbl_map = dict((index, label) for label, index in lbl_map.items())

    return val_loader, val_set, lbl_map, inv_lbl_map

In [3]:
def get_model_and_data(exp_dir, is_contrastive, num_random_maps=0.1,
                       output_dir_name="featuremap_expection", last_epoch=None,
                       meta_dir="/media/user/nvme/datasets/fsd50k/fsd50k_8000/meta/",
                       csv_name="eval.csv", delim=","):
    """Load a trained model with its deconvolutional decoder plus the evaluation data.

    Args:
        exp_dir: experiment directory handed to the ``interpretability_utils``
            ``prep_*_model_and_decoder`` helpers.
        is_contrastive: when true, ``prep_contrastive_model_and_decoder`` is used;
            otherwise ``prep_finetuned_model_and_decoder``.
        num_random_maps: not used by this function; kept so existing calls keep working.
        output_dir_name: not used by this function; kept so existing calls keep working.
        last_epoch: checkpoint epoch for the fine-tuned path. A falsy value
            (``None`` or ``0``) falls back to epoch 50. Ignored when
            ``is_contrastive`` is true.
        meta_dir: directory with ``lbl_map.json`` and the evaluation CSV
            (previously hard-coded; the default is the former value).
        csv_name: evaluation CSV inside ``meta_dir`` (default is the former
            hard-coded value).
        delim: CSV delimiter (default is the former hard-coded value).

    Returns:
        dict with keys ``"model"``, ``"net"``, ``"deconv"``, ``"hparams"``,
        ``"loader"``, ``"dset"``, ``"lbl_map"`` and ``"inv_lbl_map"``.
    """
    if is_contrastive:
        res = interpretability_utils.prep_contrastive_model_and_decoder(exp_dir)
    else:
        # Single call site; a falsy last_epoch keeps the historical default of 50.
        epoch = last_epoch if last_epoch else 50
        res = interpretability_utils.prep_finetuned_model_and_decoder(exp_dir, last_epoch=epoch)
    model, net, deconv, hparams = res['full_model'], res['feature_extractor'], res['deconv_decoder'], res['hparams']

    loader, dset, lbl_map, inv_lbl_map = make_dataloader(meta_dir,
                                                         hparams.cfg['audio_config'],
                                                         csv_name=csv_name, delim=delim)
    results = {
        "model": model,
        "net": net,
        "deconv": deconv,
        "hparams": hparams,
        "loader": loader,
        "dset": dset,
        "lbl_map": lbl_map,
        # Index -> label inversion; previously computed by make_dataloader and discarded here.
        "inv_lbl_map": inv_lbl_map,
    }
    return results

In [4]:
import math

In [5]:
def process_data_record(meta, x, y, layer_index, top_k_perc=0.01, gaussian_perturb=False, baseline=False):
    """Sonify one record at ``layer_index`` and classify what remains after removing it.

    Args:
        meta: dict providing ``'net'``, ``'model'`` and ``'deconv'``.
        x: input waveform tensor; a leading dimension is added before inference.
            # NOTE(review): exact shape is not visible here — confirm against the dataset.
        y: label for the record; not used inside this function.
        layer_index: selects the ``'act<layer_index>'`` feature entry and the layer
            passed to ``visualize_top_k_maps``.
        top_k_perc: currently has no effect; the number of maps is fixed to 1
            (the percentage-based expression is commented out).
        gaussian_perturb: when true and ``baseline`` is true, the baseline input is
            passed through the module-level ``noiser`` first.
        baseline: when true, also return the model's predictions on the
            (optionally perturbed) input itself.

    Returns:
        ``(residual_sig, signal_deconved, residual_supervised_preds, supervised_preds)``;
        ``supervised_preds`` is ``None`` unless ``baseline`` is true.

    Requires CUDA: inputs are moved to the GPU with ``.cuda()``.
    """
    x = x.unsqueeze(0)
    signal_input = x.squeeze().cpu().numpy()

    # Input for the baseline prediction, prepared before x itself is moved to the GPU.
    x_super = None
    if baseline:
        x_super = (noiser(x) if gaussian_perturb else x).cuda()

    x = x.cuda()
    output_features, switch_indices = interpretability_utils.infer_model(meta['net'], x)

    supervised_preds = None
    if baseline:
        supervised_preds = interpretability_utils.get_supervised_predictions(meta['model'], x_super)

    # Looked up but otherwise unused while the map count is fixed below.
    act_feats = output_features['act{}'.format(layer_index)]
    top_k_num_maps = 1 #int(math.ceil(top_k_perc * act_feats.shape[1]))

    with torch.no_grad():
        raw_maps = meta['deconv'].visualize_top_k_maps(
            x, output_features, switch_indices, layer_index, top_k_num_maps
        )
        # Post-process each map against the original signal, then average across maps.
        processed_maps = []
        for single_map in raw_maps:
            processed_maps.append(interpretability_utils.process_vis(single_map, signal_input))
        sonifications = np.asarray(processed_maps)
        signal_deconved = sonifications.mean(axis=0)

    # What is left of the input once the sonified component is subtracted.
    residual_sig = signal_input - signal_deconved

    residual_input = interpretability_utils.prep_input(residual_sig)
    residual_supervised_preds = interpretability_utils.get_supervised_predictions(meta['model'], residual_input)

    return residual_sig, signal_deconved, residual_supervised_preds, supervised_preds

In [6]:
# meta1 = get_model_and_data("/media/user/nvme/contrastive_experiments/experiments_fsd50k_v3/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fullmodel/", False)

In [7]:
meta2 = get_model_and_data("/media/user/nvme/contrastive_experiments/experiments_fsd50k_latest/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fullmodel_rs8882/", False, last_epoch=95)

Loading /media/user/nvme/contrastive_experiments/experiments_fsd50k_latest/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fullmodel_rs8882/ckpts/epoch=095_tr_loss=0.031408_tr_acc=0.685658_val_acc=0.418466.pth


In [8]:
meta3 = get_model_and_data("/media/user/nvme/contrastive_experiments/experiments_fsd50k_latest/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fullmodel_rs8883/", False, last_epoch=98)

Loading /media/user/nvme/contrastive_experiments/experiments_fsd50k_latest/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fullmodel_rs8883/ckpts/epoch=098_tr_loss=0.031482_tr_acc=0.681768_val_acc=0.422190.pth


In [9]:
meta4 = get_model_and_data("/media/user/nvme/contrastive_experiments/experiments_fsd50k_v3/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fullmodel_r4/", False, last_epoch=98)

Loading /media/user/nvme/contrastive_experiments/experiments_fsd50k_v3/cnn12_1x_full_tr_1x64_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_ft_fullmodel_r4/ckpts/epoch=098_tr_loss=0.031548_tr_acc=0.683315_val_acc=0.422022.pth


In [10]:
from sklearn.metrics import average_precision_score

In [11]:
def calculate_mAP(preds, gts, mode="macro"):
    """Average precision over concatenated prediction / ground-truth tensors.

    Args:
        preds: sequence of prediction tensors, concatenated along dim 0.
        gts: sequence of ground-truth tensors, concatenated along dim 0.
        mode: forwarded as ``average=`` to ``average_precision_score``.

    Returns:
        The value returned by ``average_precision_score``.
    """
    scores, targets = (torch.cat(chunks, 0).cpu().numpy() for chunks in (preds, gts))
    return average_precision_score(targets, scores, average=mode)

In [12]:
from src.data.raw_transforms import AddGaussianNoise

In [13]:
# Shared AddGaussianNoise transform built with its default arguments. process_data_record
# reads this name as a global when called with baseline=True and gaussian_perturb=True.
noiser = AddGaussianNoise()

In [14]:
def process_layer(meta, layer_index, top_k_perc=0.1, gaussian=False, baseline=False):
    """Compute the residual-signal mAP (and optionally the baseline mAP) for one layer.

    Every item of ``meta['dset']`` is run through ``process_data_record``; the
    predictions on the residual signal are scored against the labels with
    ``calculate_mAP``.

    Args:
        meta: dict providing ``'dset'`` plus whatever ``process_data_record`` needs.
        layer_index: layer forwarded to ``process_data_record``.
        top_k_perc: forwarded to ``process_data_record``.
        gaussian: forwarded as ``gaussian_perturb``.
        baseline: when true, predictions on the input itself are also collected
            and scored.

    Returns:
        ``(residual_mAP, baseline_map)``; ``baseline_map`` is ``None`` unless
        ``baseline`` is true. Both values are also printed.
    """
    # Local import: `import tqdm` on its own is not guaranteed to load the
    # `tqdm.notebook` submodule, so the function no longer depends on a later
    # cell having imported it.
    from tqdm.notebook import tqdm_notebook

    dset = meta['dset']
    gts = []
    res_preds = []
    inp_preds = []
    for ix in tqdm_notebook(range(len(dset)), position=2):
        x, y = dset[ix]
        _, _, residual_preds, input_preds = process_data_record(
            meta, x, y, layer_index, top_k_perc,
            baseline=baseline, gaussian_perturb=gaussian)
        gts.append(y.unsqueeze(0))
        res_preds.append(residual_preds.detach().cpu())
        if baseline:
            # Only collected when it exists (it is None otherwise); detached and moved
            # to the CPU like the residual predictions so nothing accumulates on the GPU.
            inp_preds.append(input_preds.detach().cpu())

    baseline_map = calculate_mAP(inp_preds, gts) if baseline else None
    residual_mAP = calculate_mAP(res_preds, gts)
    print(residual_mAP, baseline_map)
    return residual_mAP, baseline_map

In [15]:
from tqdm import notebook

In [16]:
# results = {}
# for layer_idx in notebook.tqdm(range(1, 12), position=1):
#     if layer_idx == 1:
#         baseline = True
#     else:
#         baseline = False
#     res_mAP, baseline_mAP = process_layer(meta, layer_idx, top_k_perc=1., baseline=baseline)
#     results[layer_idx] = res_mAP

In [17]:
# with new data transforms

In [18]:
# meta_results = {0: {1: 0.21898723119808686,
#   2: 0.2075165284444755,
#   3: 0.2169938626862424,
#   4: 0.2232939336995754,
#   5: 0.25147119901573556,
#   6: 0.27084707045355466,
#   7: 0.27310423734162564,
#   8: 0.3126478078935492,
#   9: 0.31253935617985484,
#   10: 0.3534935464858034,
#   11: 0.3164541589543721}}

In [19]:
# Sweep layers 1-11 for each loaded model and record the residual mAP per layer.
# meta_results[model_index][layer_index] -> residual mAP
# baseline_results[model_index]          -> mAP on the unmodified input
meta_results = {}
baseline_results = {}
meta_count = 0
for meta in [meta2, meta3, meta4]:
    results = {}
    for layer_idx in notebook.tqdm(range(1, 12), position=1):
        # The baseline prediction does not involve the layer index, so it is only
        # requested on the first layer's pass.
        baseline = layer_idx == 1
        res_mAP, baseline_mAP = process_layer(meta, layer_idx, top_k_perc=1., baseline=baseline)
        results[layer_idx] = res_mAP
        if baseline:
            # Keep the value: later layers return None and would otherwise overwrite it,
            # leaving the printed output as the only record.
            baseline_results[meta_count] = baseline_mAP
    meta_results[meta_count] = results
    meta_count += 1

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/10231 [00:00<?, ?it/s]

0.1837950956654363 0.4243540602566276


  0%|          | 0/10231 [00:00<?, ?it/s]

0.19671473705687795 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.20108673916061362 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.22839356873791183 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.24803364962186142 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.2801172312563643 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.27110793890395973 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.3110693194328128 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.316840358987352 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.3215811625128609 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.3271420230040678 None


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/10231 [00:00<?, ?it/s]

0.22985252242840487 0.4274145287715402


  0%|          | 0/10231 [00:00<?, ?it/s]

0.19597478469046467 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.21783907609314937 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.23930789814136008 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.26865411317923715 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.28090549465691333 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.288099491039961 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.29758997141753923 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.31679773694370217 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.35268393548779886 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.3378161897150029 None


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/10231 [00:00<?, ?it/s]

0.20279004818311505 0.42345384894634636


  0%|          | 0/10231 [00:00<?, ?it/s]

0.1887408070434549 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.21838904490771974 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.2158818242469624 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.27319212312492397 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.28078952116189515 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.28692273536359414 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.3179243595785271 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.3185604812317797 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.3533927597841378 None


  0%|          | 0/10231 [00:00<?, ?it/s]

0.3262227478317309 None


In [20]:
# Baseline mAP (layer-1 pass, unmodified input) for each model, in sweep order
# meta2, meta3, meta4. Values are the non-None second column printed by the sweep cell.
baselines = [
    0.4243540602566276,
    0.4274145287715402,
    0.42345384894634636,
]

In [22]:
# Display the collected residual mAP values, keyed by model index and then by layer index.
meta_results

{0: {1: 0.1837950956654363,
  2: 0.19671473705687795,
  3: 0.20108673916061362,
  4: 0.22839356873791183,
  5: 0.24803364962186142,
  6: 0.2801172312563643,
  7: 0.27110793890395973,
  8: 0.3110693194328128,
  9: 0.316840358987352,
  10: 0.3215811625128609,
  11: 0.3271420230040678},
 1: {1: 0.22985252242840487,
  2: 0.19597478469046467,
  3: 0.21783907609314937,
  4: 0.23930789814136008,
  5: 0.26865411317923715,
  6: 0.28090549465691333,
  7: 0.288099491039961,
  8: 0.29758997141753923,
  9: 0.31679773694370217,
  10: 0.35268393548779886,
  11: 0.3378161897150029},
 2: {1: 0.20279004818311505,
  2: 0.1887408070434549,
  3: 0.21838904490771974,
  4: 0.2158818242469624,
  5: 0.27319212312492397,
  6: 0.28078952116189515,
  7: 0.28692273536359414,
  8: 0.3179243595785271,
  9: 0.3185604812317797,
  10: 0.3533927597841378,
  11: 0.3262227478317309}}

In [28]:
# baseline: 0.42486969684209214

In [2]:
import pygal

In [4]:
from pygal.style import Style
# pygal style with both the outer background and the plot area set to transparent.
custom_style = Style(background='transparent', plot_background='transparent')

In [None]:
# Bar chart with one x label per layer (1-11) and one series per training regime,
# rendered to a PNG file.
# NOTE(review): `contrastive_mean` and `supervised_mean` are not defined in any cell
# visible in this notebook, and this cell has no execution count — confirm where
# they are meant to come from before relying on this figure.
line_chart = pygal.Bar(style=custom_style)
# line_chart.title = 'Mean Top-5 magnitude-squared coherence'
line_chart.x_labels = [str(layer) for layer in range(1, 12)]
line_chart.x_title = "Layer"
# line_chart.y_title = "Coherence coefficient"
line_chart.add('Contrastive', contrastive_mean)
line_chart.add('Supervised', supervised_mean)
line_chart.render_to_png("/home/user/Desktop/sonifications_stimulate.png")