In [None]:
%pip install captum
%matplotlib inline
%pip install grad-cam
%pip install resampy

In [None]:
import numpy as np
import json
import os
from captum.attr import IntegratedGradients
from captum.attr import LayerConductance
from captum.attr import NeuronConductance
import captum.attr
import matplotlib
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, tqdm_notebook
from torch.autograd import Variable
from torch.autograd import Function
from fastprogress.fastprogress import format_time, master_bar, progress_bar
from sklearn.metrics import f1_score, jaccard_score
from sklearn import preprocessing
from scipy import stats
import pandas as pd
import resampy

In [None]:
# Mount Google Drive at /content/drive; the CSV, class map, spectrogram files and
# model checkpoints used by the cells below are all read from paths under this mount.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
class InstrumentDataset(Dataset):
  """Multi-label instrument dataset backed by pre-computed spectrogram ``.npy`` files.

  One row is kept per unique ``sample_key`` found in the labels CSV. For every
  row the dataset stores the list of instruments tagged on that clip, the path
  of its spectrogram file and one 0/1 column per instrument class.

  The internal frame ``self.audio_file`` has the column layout
  ``[sample_key, labels, <spec_type>, <class 0>, <class 1>, ...]``;
  ``__getitem__`` relies on these positions.
  """

  def __init__(self, csv_file, json_file, root_dir, spec_type):
    """Read the label table and class map and derive per-clip paths and targets.

    Args:
      csv_file: CSV with at least the columns ``sample_key`` and ``instrument``.
      json_file: JSON object whose keys are the instrument class names.
      root_dir: Directory containing one sub-directory per spectrogram type.
      spec_type: One of ``"mel_spectrogram"``, ``"chroma_stft"`` or ``"cqt"``.

    Raises:
      ValueError: If ``spec_type`` is not one of the supported values.
    """
    self.audio_frame = pd.read_csv(csv_file)
    with open(json_file, 'r') as f:
      self.instrument_classes = json.load(f)

    self.json_file = json_file
    self.csv_file = csv_file
    self.root_dir = root_dir

    # Add chroma_cqt if desired
    if spec_type not in ["mel_spectrogram", "chroma_stft", "cqt"]:
      # ValueError is a subclass of Exception, so existing handlers keep working.
      raise ValueError("Not valid spectrogram type")
    self.spec_type = spec_type

    # List of unique sample_keys (aka audio file names)
    self.unique_audio_files = self.audio_frame.sample_key.unique()

    # Dataframe specifying the sample_key of the audio file and the instrument labels
    self.audio_file_labels = (
        self.audio_frame.groupby('sample_key')['instrument'].apply(list).reset_index(name='labels')
    )
    self.audio_file = self.audio_file_labels.copy()

    # <root_dir>/<spec_type>/<first 3 chars of key>/<key>_<spec_type>.npy
    # A plain comprehension also works for an empty table, unlike np.vectorize.
    self.audio_file[spec_type] = [
        os.path.join(self.root_dir, spec_type, key[:3], key + '_' + spec_type + '.npy')
        for key in self.audio_file.sample_key
    ]

    # Matrix of instrument labels ordered by audio file number (increasing sample_key value)
    self.label_matrix = self.audio_file.labels.tolist()

    # Pin the binarizer to the class-map order. Left to itself it sorts the
    # classes it happens to see, so the columns could disagree with the names
    # taken from the JSON file, or be fewer than them when a class has no
    # positive example. For an alphabetically ordered class map that covers all
    # labels this produces exactly the same matrix as the default behaviour.
    # Labels missing from the class map are ignored by scikit-learn (with a warning).
    class_names = list(self.instrument_classes.keys())
    binarizer = preprocessing.MultiLabelBinarizer(classes=class_names)

    self.binary_label_matrix = binarizer.fit_transform(self.label_matrix)
    self.label_df = pd.DataFrame(self.binary_label_matrix, columns=class_names)

    self.audio_file = pd.concat([self.audio_file, self.label_df], axis=1)

    # Built once instead of on every item. The same mean/std constants are
    # applied to all three channels and to every spec_type.
    # NOTE(review): presumably these are mel-spectrogram statistics - confirm
    # before using them for chroma_stft / cqt.
    self.spec_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([-47.3835, -47.3835, -47.3835], [18.5056, 18.5056, 18.5056])
    ])

  @staticmethod
  def _normalize_index(idx):
    """Turn a tensor / scalar index into something ``DataFrame.iloc`` treats as row-wise.

    Scalars (Python or NumPy integers) are wrapped in a list so that ``iloc``
    always returns a 2-D selection; lists and slices are passed through.
    """
    if torch.is_tensor(idx):
      idx = idx.tolist()
    if isinstance(idx, (int, np.integer)):
      idx = [int(idx)]
    return idx

  def get_instrument_class_dict(self):
    """Return the class map exactly as loaded from ``json_file``."""
    return self.instrument_classes

  def __len__(self):
    """Number of unique audio clips."""
    return len(self.audio_file.index)

  def __getitem__(self, idx):
    """Load the spectrogram(s) and one-hot targets for ``idx``.

    Args:
      idx: Integer, list of integers, slice or index tensor.

    Returns:
      dict with
        ``'specs'``: NumPy array of the normalised 3-channel spectrograms with
          all singleton dimensions squeezed out,
        ``'instrument(s)'``: float64 tensor of 0/1 targets, one row per clip,
        ``'sample_key'``: list of the selected sample keys,
        ``'spec_type'``: the spectrogram type of this dataset.
    """
    idx = self._normalize_index(idx)

    # Columns 3+ hold the one-hot encoded instruments, column 2 the file paths.
    instrument_types = np.array(self.audio_file.iloc[idx, 3:]).astype(float)
    spec_paths = self.audio_file.iloc[idx, 2]

    spec_arrays = []
    for file_name in spec_paths:
      # Replicate the single-channel spectrogram into 3 identical channels
      # (H, W) -> (H, W, 3); ToTensor then moves channels first.
      spec = np.repeat(np.expand_dims(np.load(file_name).astype(float), -1), 3, -1)
      spec_arrays.append(self.spec_transforms(spec).numpy())

    specs = np.squeeze(np.stack(spec_arrays))

    instrument_types = torch.from_numpy(instrument_types)

    sample = {'specs': specs,
              'instrument(s)': instrument_types,
              'sample_key': self.audio_file.iloc[idx, 0].tolist(),
              'spec_type': self.spec_type}

    return sample

In [None]:
class VGGishInstrumentDataset(Dataset):
  """Multi-label instrument dataset backed by pre-computed ``.npy`` inputs for VGGish.

  One row is kept per unique ``sample_key`` found in the labels CSV. The
  internal frame ``self.audio_file`` has the column layout
  ``[sample_key, labels, <spec_type>, <class 0>, <class 1>, ...]``;
  ``__getitem__`` relies on these positions.
  """

  def __init__(self, csv_file, json_file, root_dir, spec_type):
    """Read the label table and class map and derive per-clip paths and targets.

    Args:
      csv_file: CSV with at least the columns ``sample_key`` and ``instrument``.
      json_file: JSON object whose keys are the instrument class names.
      root_dir: Directory containing a ``<spec_type>`` sub-directory of ``.npy`` files.
      spec_type: ``"vgg"`` or ``"audio"``.

    Raises:
      ValueError: If ``spec_type`` is not one of the supported values.
    """
    self.audio_frame = pd.read_csv(csv_file)
    with open(json_file, 'r') as f:
      self.instrument_classes = json.load(f)

    self.json_file = json_file
    self.csv_file = csv_file
    self.root_dir = root_dir

    if spec_type not in ["vgg", "audio"]:
      # ValueError is a subclass of Exception, so existing handlers keep working.
      raise ValueError("Not valid spectrogram type")
    self.spec_type = spec_type

    # List of unique sample_keys (aka audio file names)
    self.unique_audio_files = self.audio_frame.sample_key.unique()

    # Dataframe specifying the sample_key of the audio file and the instrument labels
    self.audio_file_labels = (
        self.audio_frame.groupby('sample_key')['instrument'].apply(list).reset_index(name='labels')
    )
    self.audio_file = self.audio_file_labels.copy()

    # <root_dir>/<spec_type>/<first 3 chars of key>/<key>_<spec_type>.npy
    # A plain comprehension also works for an empty table, unlike np.vectorize.
    self.audio_file[spec_type] = [
        os.path.join(self.root_dir, self.spec_type, key[:3], key + '_' + spec_type + '.npy')
        for key in self.audio_file.sample_key
    ]

    # Matrix of instrument labels ordered by audio file number (increasing sample_key value)
    self.label_matrix = self.audio_file.labels.tolist()

    # Pin the binarizer to the class-map order. Left to itself it sorts the
    # classes it happens to see, so the columns could disagree with the names
    # taken from the JSON file, or be fewer than them when a class has no
    # positive example. For an alphabetically ordered class map that covers all
    # labels this produces exactly the same matrix as the default behaviour.
    # Labels missing from the class map are ignored by scikit-learn (with a warning).
    class_names = list(self.instrument_classes.keys())
    binarizer = preprocessing.MultiLabelBinarizer(classes=class_names)

    self.binary_label_matrix = binarizer.fit_transform(self.label_matrix)
    self.label_df = pd.DataFrame(self.binary_label_matrix, columns=class_names)

    self.audio_file = pd.concat([self.audio_file, self.label_df], axis=1)

  @staticmethod
  def _normalize_index(idx):
    """Turn a tensor / scalar index into something ``DataFrame.iloc`` treats as row-wise.

    Scalars (Python or NumPy integers) are wrapped in a list so that ``iloc``
    always returns a 2-D selection; lists and slices are passed through.
    """
    if torch.is_tensor(idx):
      idx = idx.tolist()
    if isinstance(idx, (int, np.integer)):
      idx = [int(idx)]
    return idx

  def get_instrument_class_dict(self):
    """Return the class map exactly as loaded from ``json_file``."""
    return self.instrument_classes

  def __len__(self):
    """Number of unique audio clips."""
    return len(self.audio_file.index)

  def __getitem__(self, idx):
    """Load the stored array(s) and one-hot targets for ``idx``.

    No normalisation is applied to the loaded arrays.

    Args:
      idx: Integer, list of integers, slice or index tensor.

    Returns:
      dict with
        ``'specs'``: float64 tensor of the loaded arrays with all singleton
          dimensions squeezed out,
        ``'instrument(s)'``: float64 tensor of 0/1 targets, one row per clip,
        ``'sample_key'``: list of the selected sample keys,
        ``'spec_type'``: the input type of this dataset,
        ``'fs'``: the constant 22050.
    """
    idx = self._normalize_index(idx)

    # Columns 3+ hold the one-hot encoded instruments, column 2 the file paths.
    instrument_types = np.array(self.audio_file.iloc[idx, 3:]).astype(float)
    file_names = self.audio_file.iloc[idx, 2]

    # SECURITY NOTE: allow_pickle=True lets np.load execute pickled objects;
    # only point this dataset at files you produced yourself.
    spec_array = [
        np.expand_dims(np.load(file_name, allow_pickle=True).astype(float), -1)
        for file_name in file_names
    ]
    specs = torch.from_numpy(np.stack(spec_array)).squeeze().double()

    instrument_types = torch.from_numpy(instrument_types)

    # NOTE(review): presumably the sampling rate (Hz) of the source audio - confirm.
    fs = 22050

    sample = {'specs': specs,
              'instrument(s)': instrument_types,
              'sample_key': self.audio_file.iloc[idx, 0].tolist(),
              'spec_type': self.spec_type,
              'fs': fs}

    return sample

In [None]:
# Mel-spectrogram dataset: labels CSV, class map and spectrogram files all live in the Drive project folder.
mel_dataset_kwargs = dict(
    csv_file='/content/drive/MyDrive/ECE6255_Project/openmic-2018-aggregated-labels.csv',
    json_file='/content/drive/MyDrive/ECE6255_Project/class-map.json',
    root_dir='/content/drive/MyDrive/ECE6255_Project',
    spec_type='mel_spectrogram',
)
mel_spec_dataset = InstrumentDataset(**mel_dataset_kwargs)

In [None]:
# VGGish dataset: same labels CSV and class map as the mel dataset, inputs read from the 'vgg' folder.
vgg_dataset_kwargs = dict(
    csv_file='/content/drive/MyDrive/ECE6255_Project/openmic-2018-aggregated-labels.csv',
    json_file='/content/drive/MyDrive/ECE6255_Project/class-map.json',
    root_dir='/content/drive/MyDrive/ECE6255_Project',
    spec_type='vgg',
)
vgg_dataset = VGGishInstrumentDataset(**vgg_dataset_kwargs)

In [None]:
# Mel-spectrogram data splits: 80% train; the remaining 20% is divided 25/75
# into validation and test (roughly 5% / 15% of the whole dataset).
batch_size = 1

# Fixed seed so the same clips land in each split on every run.
# NOTE: this makes the split reproducible from now on; it cannot recreate the
# split that was in effect when the saved checkpoint was trained.
SPLIT_SEED = 0

mel_train_size = int(0.8*len(mel_spec_dataset))
mel_holdout_size = len(mel_spec_dataset) - mel_train_size
mel_val_size = int(0.25*mel_holdout_size)
mel_test_size = mel_holdout_size - mel_val_size
print(mel_train_size)

# One seeded three-way split replaces the two chained unseeded splits; the
# resulting subset sizes are unchanged.
mel_train_set, mel_val_set, mel_test_set = torch.utils.data.random_split(
    mel_spec_dataset,
    [mel_train_size, mel_val_size, mel_test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

mel_train_loader = DataLoader(mel_train_set, batch_size=batch_size, shuffle=True, num_workers=2)
mel_val_loader = DataLoader(mel_val_set, batch_size=batch_size, shuffle=True, num_workers=2)
mel_test_loader = DataLoader(mel_test_set, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:
# VGGish data splits: 80% train; the remaining 20% is divided 25/75
# into validation and test (roughly 5% / 15% of the whole dataset).
batch_size = 1

# Fixed seed so the same clips land in each split on every run.
# NOTE: this makes the split reproducible from now on; it cannot recreate the
# split that was in effect when the saved checkpoint was trained.
SPLIT_SEED = 0

vgg_train_size = int(0.8*len(vgg_dataset))
vgg_holdout_size = len(vgg_dataset) - vgg_train_size
vgg_val_size = int(0.25*vgg_holdout_size)
vgg_test_size = vgg_holdout_size - vgg_val_size
print(vgg_train_size)

# One seeded three-way split replaces the two chained unseeded splits; the
# resulting subset sizes are unchanged.
vgg_train_set, vgg_val_set, vgg_test_set = torch.utils.data.random_split(
    vgg_dataset,
    [vgg_train_size, vgg_val_size, vgg_test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

vgg_train_loader = DataLoader(vgg_train_set, batch_size=batch_size, shuffle=True, num_workers=2)
vgg_val_loader = DataLoader(vgg_val_set, batch_size=batch_size, shuffle=True, num_workers=2)
vgg_test_loader = DataLoader(vgg_test_set, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:

class MelSpecNetwork(nn.Module):
  """ResNet-50 with its classifier replaced by a 20-output linear layer followed by a sigmoid."""

  def __init__(self):
    super().__init__()

    backbone = models.resnet50(pretrained=True)
    head_in_features = backbone.fc.in_features

    # Freeze every backbone parameter, then re-enable gradients for layer4 only.
    backbone.requires_grad_(False)
    backbone.layer4.requires_grad_(True)

    # Replace fully connected layer to fit 20 instrument classes
    # (a freshly created nn.Linear is trainable by default).
    backbone.fc = nn.Linear(head_in_features, 20)

    self.resnet = backbone

  def forward(self, input):
    """Return per-class sigmoid scores for ``input``."""
    return torch.sigmoid(self.resnet(input))

# VGGish Arch
class VggNetwork(nn.Module):
  """VGGish feature extractor followed by a 20-output linear layer and a sigmoid.

  Every frame of a clip is embedded independently by VGGish; the per-frame
  embeddings are concatenated and classified jointly.
  """

  def __init__(self):
    super(VggNetwork, self).__init__()

    # Loaded through torch.hub with VGGish's own pre- and post-processing disabled.
    self.vggish_model = torch.hub.load('harritaylor/torchvggish', 'vggish', preprocess=False, postprocess=False)

    # 10 seconds * 128 embedding size to 20 instrument classes
    self.classify = nn.Sequential(
            nn.Linear(10 * 128, 20),
        )

  def forward(self, input):
    """Return per-class sigmoid scores.

    Args:
      input: 4-D tensor ``(batch, num_frames, height, width)``.

    Returns:
      Tensor ``(batch, 20)`` of scores in (0, 1).
    """
    bs, num_frames, height, width = input.size()

    # Fold frames into the batch axis and add the single channel VGGish expects.
    # reshape (rather than view) also accepts non-contiguous inputs.
    frames = input.reshape(bs * num_frames, 1, height, width)
    embeddings = self.vggish_model(frames)  # [bs*num_frames, 128]

    # Concatenate the embeddings of all frames belonging to the same clip.
    embeddings = embeddings.reshape(bs, num_frames * embeddings.size(1))

    return torch.sigmoid(self.classify(embeddings))

In [None]:
# Build the mel-spectrogram classifier in eval mode and double precision
# (the dataset yields float64 spectrograms), then restore the trained weights.
mel_spec_model = MelSpecNetwork().eval().double()
print(mel_spec_model)

device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

state_dict_name = 'mel_spectrogram_full.pt'
checkpoint_dir = '/content/drive/MyDrive/ECE6255_Project/models/'
# The checkpoint is a dict whose 'model' entry holds the state dict.
model_state_dict = torch.load(checkpoint_dir + state_dict_name, map_location=device)
mel_spec_model.load_state_dict(model_state_dict['model'])

In [None]:
# Build the VGGish classifier in eval mode and double precision
# (the dataset yields float64 tensors), then restore the trained weights.
vgg_model = VggNetwork().eval().double()
print(vgg_model)

device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

state_dict_name = 'vgg_full.pt'
checkpoint_dir = '/content/drive/MyDrive/ECE6255_Project/models/'
# The checkpoint is a dict whose 'model' entry holds the state dict.
model_state_dict = torch.load(checkpoint_dir + state_dict_name, map_location=device)
vgg_model.load_state_dict(model_state_dict['model'])

In [None]:
# GradCAM class
# NOTE: a later cell runs `from pytorch_grad_cam import GradCAM`, which rebinds
# this name; instantiate this class before that import if you want to use it.
class GradCAM:
  """Minimal Grad-CAM: class-specific heatmaps from one layer's activations and gradients.

  Usage: ``cam.forward(x)``, then ``cam.backward(class_index)``, then
  ``cam.generate_heatmap(class_index)``.
  """

  # init
  def __init__(self, model, target_layer):
    """
    Args:
      model: Network to explain; it is switched to eval mode.
      target_layer: Name of the layer to inspect, as listed by
        ``model.named_modules()``. The layer must output a 4-D tensor
        ``(batch, channels, height, width)``.

    Raises:
      ValueError: If no module called ``target_layer`` exists in ``model``.
    """
    self.model = model.eval()
    self.target_layer = target_layer
    self.feature_maps = None
    self.gradients = None
    self.output = None
    self.hooks = []

    # extract gradients
    def save_gradients(grad):
      self.gradients = grad.detach()

    # extract feature maps
    def save_feature_maps(module, input, output):
      self.feature_maps = output.detach()
      # Capture d(score)/d(activation) with a tensor hook on the live output.
      # This replaces the deprecated Module.register_backward_hook.
      if output.requires_grad:
        output.register_hook(save_gradients)

    # register the forward hook on the module whose name matches target_layer
    for name, module in self.model.named_modules():
      if name == self.target_layer:
        self.hooks.append(module.register_forward_hook(save_feature_maps))

    if not self.hooks:
      raise ValueError(f"target_layer {target_layer!r} not found in model.named_modules()")

  def remove_hooks(self):
    """Detach the hooks installed on the model."""
    for handle in self.hooks:
      handle.remove()
    self.hooks = []

  # forward pass
  def forward(self, x):
    """Run the model and keep its output so ``backward`` can differentiate it."""
    self.output = self.model(x)
    return self.output

  # backward pass
  def backward(self, index):
    """Back-propagate the score of class ``index`` down to the target layer.

    Raises:
      RuntimeError: If ``forward`` has not been called yet.
    """
    if self.output is None:
      raise RuntimeError("call forward() before backward()")
    # The gradient must start from the model output: the stored feature maps
    # are detached and cannot be back-propagated. zeros_like keeps the
    # output's dtype and device, so double-precision models work too.
    one_hot = torch.zeros_like(self.output)
    one_hot[..., index] = 1.0
    self.model.zero_grad()
    self.output.backward(gradient=one_hot, retain_graph=True)

  # heatmap
  def generate_heatmap(self, index):
    """Return the ReLU-ed, max-normalised Grad-CAM map ``(batch, 1, height, width)``.

    ``index`` is accepted for interface compatibility; the class is selected
    by the preceding ``backward(index)`` call.

    Raises:
      RuntimeError: If ``forward``/``backward`` have not populated the hooks.
    """
    if self.feature_maps is None or self.gradients is None:
      raise RuntimeError("call forward() and backward() before generate_heatmap()")
    # Channel weights = spatially averaged gradients.
    weights = self.gradients.mean(dim=[2, 3], keepdim=True)
    heatmap = (weights * self.feature_maps).sum(dim=1, keepdim=True)
    heatmap = torch.relu(heatmap)
    peak = torch.max(heatmap)
    # An all-zero map (no positive evidence) stays zero instead of becoming NaN.
    if peak > 0:
      heatmap = heatmap / peak
    return heatmap

In [None]:
# Instrument names in class-map key order; the analysis cells index this list
# with label / prediction positions to print readable names.
with open('/content/drive/MyDrive/ECE6255_Project/class-map.json', 'r') as f:
    instrument_classes = list(json.load(f).keys())

# Sigmoid score above which an instrument counts as predicted.
threshold = 0.5

In [None]:
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, RawScoresOutputTarget, ClassifierOutputSoftmaxTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

from pytorch_grad_cam import DeepFeatureFactorization
from pytorch_grad_cam.utils.image import show_factorization_on_image

# Get layer to analyze: conv3 of the third child block of resnet.layer4
layer_to_analyze = mel_spec_model.resnet.layer4.children()
layer_list = list(layer_to_analyze)
print(layer_list[2].conv3)
layer_target = layer_list[2].conv3

# Get data: one batch (batch_size is 1) from the mel test loader
data = next(iter(mel_test_loader))
input_spec = data['specs']
print(data['sample_key'])
plt.imshow(data['specs'].squeeze().numpy()[0,:,:])
plt.show()

# Transform spectrogram into an image with float values between 0 and 1
input_image = input_spec.squeeze().detach().numpy().transpose(1,2,0)
value_range = input_image.max() - input_image.min()
# A constant spectrogram would otherwise divide by zero and fill the image with NaNs.
input_image = (input_image - input_image.min()) / (value_range if value_range > 0 else 1.0)

# Create GradCAM object
gc = GradCAM(model=mel_spec_model, target_layers=[layer_target])

# Create Deep Feature Factorization
dff = DeepFeatureFactorization(model=mel_spec_model, target_layer=layer_target, computation_on_concepts=mel_spec_model.resnet.fc)

# Get true label indices (column 1 of each argwhere row is the class position)
label_indices = torch.argwhere(data['instrument(s)'].squeeze(0))
label_indices_list = label_indices[:, 1].detach().numpy()
print([instrument_classes[i] for i in label_indices_list])

# Get prediction label indices: classes whose sigmoid score exceeds the threshold.
# No gradients are needed for this plain inference pass.
with torch.no_grad():
  pred_scores = mel_spec_model(input_spec).squeeze(0).numpy()
pred_indices_list = np.flatnonzero(pred_scores > threshold)
print([instrument_classes[i] for i in pred_indices_list])

input_spec.requires_grad_()
input_spec.retain_grad()

class ClassifyTarget:
    """CAM target that selects the score of one class from the model output."""

    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        # Handles both a single sample (1-D output) and a batch (2-D output).
        if len(model_output.shape) == 1:
            return model_output[self.category]
        return model_output[:, self.category]

plt.rcParams['figure.figsize'] = (10, 5)

# One Grad-CAM overlay per predicted instrument
for label in pred_indices_list:
  classifier_targets = ClassifyTarget(label)
  result = gc(input_tensor=input_spec, targets=[classifier_targets])
  result = result[0, :]
  print("Why label: ", instrument_classes[label])
  visualization = show_cam_on_image(input_image, result, use_rgb=True)
  plt.imshow(visualization)
  plt.show()

plt.rcParams['figure.figsize'] = (20, 10)
top_k = 2
# DFF is run in single precision; always switch the model back to double,
# even if the factorization raises, so later cells see a consistent model.
mel_spec_model.float()
try:
  concepts, batch_explanations, concept_scores = dff(input_spec.float(), 4)
finally:
  mel_spec_model.double()
concept_scores = torch.sigmoid(torch.from_numpy(concept_scores)).numpy()
# Indices of the top_k highest-scoring classes for each concept
concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
concept_labels_topk = []
for concept_index in range(concept_categories.shape[0]):
    categories = concept_categories[concept_index, :]
    concept_labels = []
    for category in categories:
        score = concept_scores[concept_index, category]
        concept_label = f"{instrument_classes[category].split(',')[0]}:{score:.2f}"
        concept_labels.append(concept_label)
    concept_labels_topk.append("\n".join(concept_labels))
visualization = show_factorization_on_image(input_image,
                                            batch_explanations[0],
                                            image_weight=0.3,
                                            concept_labels=concept_labels_topk)
plt.imshow(visualization)
plt.show()
plt.rcParams['figure.figsize'] = (10, 5)


In [None]:
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam import DeepFeatureFactorization  # used below; previously only imported by another cell
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, RawScoresOutputTarget, ClassifierOutputSoftmaxTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

plt.rcParams['figure.figsize'] = (20, 10)
vgg_model.double()

# Get layer to analyze: child 13 of the VGGish feature extractor
vgg_layer_to_analyze = vgg_model.vggish_model.features.children()
vgg_layer_list = list(vgg_layer_to_analyze)
print(vgg_layer_list[13])
vgg_layer_target = vgg_layer_list[13]

# Get data: one batch (batch_size is 1) from the VGGish test loader
data = next(iter(vgg_test_loader))
input_spec = data['specs']
print(data['sample_key'])

# Lay the frames side by side: merge the first two axes, then transpose so the
# merged (frame, row) axis runs horizontally.
input_image = input_spec.squeeze().detach().numpy()
input_image = input_image.reshape(input_image.shape[0] * input_image.shape[1], input_image.shape[2]).T
plt.imshow(input_image)
plt.show()

# Transform spectrogram into an image with float values between 0 and 1
input_image = np.expand_dims(input_image, -1)
value_range = input_image.max() - input_image.min()
# A constant input would otherwise divide by zero and fill the image with NaNs.
input_image = (input_image - input_image.min()) / (value_range if value_range > 0 else 1.0)

# Create GradCAM object
gc = GradCAM(model=vgg_model, target_layers=[vgg_layer_target])

# Create Deep Feature Factorization
dff = DeepFeatureFactorization(model=vgg_model, target_layer=vgg_model.vggish_model.embeddings,
                               computation_on_concepts=vgg_model.classify)

# Get true label indices (column 1 of each argwhere row is the class position)
label_indices = torch.argwhere(data['instrument(s)'].squeeze(0))
label_indices_list = label_indices[:, 1].detach().numpy()
print([instrument_classes[i] for i in label_indices_list])

# Get prediction label indices: classes whose sigmoid score exceeds the threshold.
# No gradients are needed for this plain inference pass.
with torch.no_grad():
  pred_scores = vgg_model(input_spec).squeeze(0).numpy()
pred_indices_list = np.flatnonzero(pred_scores > threshold)
print([instrument_classes[i] for i in pred_indices_list])

input_spec.requires_grad_()
input_spec.retain_grad()

class ClassifyTarget:
    """CAM target that selects the score of one class from the model output."""

    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        # Handles both a single sample (1-D output) and a batch (2-D output).
        if len(model_output.shape) == 1:
            return model_output[self.category]
        return model_output[:, self.category]

# One Grad-CAM overlay per predicted instrument
for label in pred_indices_list:
  classifier_targets = ClassifyTarget(label)
  result = gc(input_tensor=input_spec, targets=[classifier_targets])
  # Same merge-and-transpose as applied to input_image, so the CAM lines up with it.
  result = result.reshape(result.shape[0] * result.shape[1], result.shape[2]).T
  print("Why label: ", instrument_classes[label])
  visualization = show_cam_on_image(input_image, result, use_rgb=True)
  plt.imshow(visualization)
  plt.show()

# top_k = 2
# vgg_model.float()
# print(input_spec.shape)
# concepts, batch_explanations, concept_scores = dff(input_spec.float(), 5)
# vgg_model.double()
# concept_scores = torch.sigmoid(torch.from_numpy(concept_scores)).numpy()
# concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
# concept_labels_topk = []
# for concept_index in range(concept_categories.shape[0]):
#     categories = concept_categories[concept_index, :]    
#     concept_labels = []
#     for category in categories:
#         score = concept_scores[concept_index, category]
#         label = f"{instrument_classes[category].split(',')[0]}:{score:.2f}"
#         concept_labels.append(label)
#     concept_labels_topk.append("\n".join(concept_labels))
# print(len(batch_explanations))
# print(batch_explanations[0].shape)
# batch_explanations = np.stack(batch_explanations).transpose(0,1,3,2).reshape(5, 64, 960)
# print(batch_explanations.shape)
# print(batch_explanations)
# visualization = show_factorization_on_image(input_image, 
#                                             batch_explanations,
#                                             image_weight=0.3,
#                                             concept_labels=concept_labels_topk)
# plt.imshow(visualization)
# plt.show()
# plt.rcParams['figure.figsize'] = (10, 5)