In [None]:
# Copyright 2021 Vaibhav Singh (@vaibhav016)
# Copyright 2021 Dr Vinayak Abrol (_)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import argparse
import glob
import os
import pickle
import math

import cv2
import librosa.display
import matplotlib.pyplot as plt
import tensorflow as tf

from tqdm import tqdm

from tensorflow_asr.gradient_visualisation.plotting_utils import make_directories
from tensorflow_asr.utils import env_util




In [None]:
# Project-level environment initialisation (behaviour defined in tensorflow_asr.utils.env_util).
env_util.setup_environment()

# Path to the ContextNet YAML config. The historical default is an absolute path on the
# author's machine, so allow overriding it with the FILRCN_CONFIG environment variable
# instead of editing the notebook.
DEFAULT_YAML = os.environ.get(
    "FILRCN_CONFIG",
    "/Users/vaibhavsingh/Desktop/FILRCN/contextnet/config.yml",
)

# Folder (under the current working directory) that receives one pickle of gradients
# per model checkpoint.
directory_to_save_gradient_lists = make_directories(os.getcwd(), "gradient_lists")


In [None]:
from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets.asr_dataset import ASRSliceDataset
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.models.transducer.contextnet import ContextNet
from tensorflow_asr.optimizers.schedules import TransformerSchedule
from tensorflow_asr.utils import env_util
import tensorflow as tf

In [2]:
# Reset Keras' global state before any model is built in this kernel.
tf.keras.backend.clear_session()
# Explicitly switch off the graph optimizer's automatic mixed-precision rewrite.
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": False})
# Presumably selects device index 0 for the distribution strategy -- confirm in env_util.
# NOTE(review): `strategy` is not referenced by any later cell visible in this notebook.
strategy = env_util.setup_strategy([0])
# Parse the YAML file referenced by DEFAULT_YAML into the project Config object.
config = Config(DEFAULT_YAML)


In [None]:
# Directory holding the saved checkpoints. The historical default is a Colab path, so it
# can be overridden with the FILRCN_CHECKPOINT_DIR environment variable.
model_directory = os.environ.get("FILRCN_CHECKPOINT_DIR", "/content/FILRCN/examples/checkpoints")

# Only ".h5" files are checkpoints (the per-checkpoint loop further down skips everything
# else), so pick the last checkpoint among those rather than the last directory entry of
# any kind.
checkpoint_files = sorted(name for name in os.listdir(model_directory) if name.endswith(".h5"))
if not checkpoint_files:
    raise FileNotFoundError("no .h5 checkpoints found in " + model_directory)
last_trained_model = os.path.join(model_directory, checkpoint_files[-1])

speech_featurizer = TFSpeechFeaturizer(config.speech_config)

text_featurizer = CharFeaturizer(config.decoder_config)
tf.random.set_seed(0)

# Dataset of the samples used purely for gradient visualisation.
visualisation_dataset = ASRSliceDataset(
    speech_featurizer=speech_featurizer,
    text_featurizer=text_featurizer,
    **vars(config.learning_config.gradient_dataset_vis_config)
)

# One sample per step: node indices and gradients are stored per sample.
batch_size = 1
visualisation_gradient_loader = visualisation_dataset.create(batch_size)
contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
contextnet.make(speech_featurizer.shape)
contextnet.load_weights(last_trained_model, by_name=True)
contextnet.add_featurizers(speech_featurizer, text_featurizer)

# Work on a copy of the optimizer settings: popping "warmup_steps" from the shared config
# would make a second run of this cell silently fall back to the 10000 default.
optimizer_config = dict(config.learning_config.optimizer_config)
warmup_steps = optimizer_config.pop("warmup_steps", 10000)

optimizer = tf.keras.optimizers.Adam(
    TransformerSchedule(
        d_model=contextnet.dmodel,
        warmup_steps=warmup_steps,
        max_lr=(0.05 / math.sqrt(contextnet.dmodel))
    ),
    **optimizer_config
)

contextnet.compile(
    optimizer=optimizer,
    steps_per_execution=1,
    global_batch_size=1,
    blank=text_featurizer.blank
)
# The first layer of the model is treated as the encoder whose outputs are attributed.
encoder = contextnet.layers[0]

# Per-sample node indices, filled by the node-selection cell.
activated_node_list = []
random_activated_node_list = []


In [None]:
# For every visualisation sample, pick the encoder output node (last axis) whose response
# has the largest norm along axis 1, and pair it with a fixed comparison node.

# Index of the fixed comparison node stored alongside each most-activated node.
RANDOM_NODE_INDEX = 3

# Rebuild the lists here so that re-running this cell does not append duplicates to the
# results of a previous run.
activated_node_list = []
random_activated_node_list = []

for i, j in visualisation_gradient_loader:
    inputs = tf.Variable(i["inputs"])
    inputs_length = tf.Variable(i["inputs_length"])
    signal = tf.Variable(i["signal"])

    encoder_output = encoder.call_feature_output([inputs, inputs_length, signal])
    # Collapse axis 1 with a norm, then take the arg-max over the remaining node axis.
    activated_channels = tf.norm(encoder_output, axis=1)
    activated_node_index = tf.math.argmax(activated_channels, axis=1).numpy()

    # The loader yields one sample per step, so keep the first (only) batch entry.
    activated_node_list.append(activated_node_index[0])
    random_activated_node_list.append(RANDOM_NODE_INDEX)

In [None]:
@tf.function
def get_integrated_gradients(encoder, mel_spec, inputs_length, signal, activated_node_index, random_node_index,
                             m_steps=50):
    """Average the gradients of two encoder nodes along a zero-to-input interpolation path.

    The input is interpolated linearly from an all-zero baseline to ``mel_spec`` in
    ``m_steps + 1`` points. For each point the gradient of the selected encoder output
    node with respect to the interpolated input is taken; neighbouring points are averaged
    (trapezoidal rule) and the result is averaged over the path.

    Args:
        encoder: object exposing ``call_feature_output([images, inputs_length, signal])``.
        mel_spec: input feature map without batch/channel axes (the ``alphas`` broadcast
            below adds exactly two trailing axes, so a 2-D tensor is expected).
        inputs_length: forwarded unchanged to the encoder.
        signal: forwarded unchanged to the encoder.
        activated_node_index: index on the last axis of the encoder output for the
            node of interest.
        random_node_index: index on the last axis of the encoder output for the
            comparison node.
        m_steps: number of interpolation intervals (defaults to the previous fixed 50).

    Returns:
        Tuple ``(integrated_gradients, integrated_random_gradients)``, each with the
        shape of ``mel_spec``.

    NOTE(review): the averaged gradients are returned as-is and are not multiplied by
    ``(input - baseline)`` as in the textbook integrated-gradients formula -- confirm
    this is intentional before relying on the absolute scale.
    """
    baseline = tf.zeros_like(mel_spec)
    alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps + 1)
    # Shape (m_steps + 1, 1, 1) so that each alpha scales one whole interpolated input.
    alphas_x = alphas[:, tf.newaxis, tf.newaxis]
    baseline_x = tf.expand_dims(baseline, axis=0)
    input_x = tf.expand_dims(mel_spec, axis=0)
    delta = input_x - baseline_x
    interpolated_images = baseline_x + alphas_x * delta

    with tf.GradientTape(persistent=True) as tape:
        tape.watch(interpolated_images)
        # Append the trailing axis the encoder input carries.
        images = tf.expand_dims(interpolated_images, axis=-1)
        encoder_output = encoder.call_feature_output([images, inputs_length, signal])
        # The slices must be taken while the tape is recording, otherwise they would not
        # be connected to the watched input.
        activated_response = encoder_output[:, :, activated_node_index]
        random_response = encoder_output[:, :, random_node_index]

    # Differentiate outside the context: calling gradient() inside a persistent tape
    # makes the tape record the backward pass as well, which is wasteful.
    gradients = tape.gradient(activated_response, interpolated_images)
    random_gradients = tape.gradient(random_response, interpolated_images)
    # A persistent tape keeps its resources until it is dropped.
    del tape

    # Trapezoidal rule: mean of neighbouring points on the path.
    grads = (gradients[:-1] + gradients[1:]) / tf.constant(2.0)
    random_grads = (random_gradients[:-1] + random_gradients[1:]) / tf.constant(2.0)

    integrated_gradients = tf.math.reduce_mean(grads, axis=0)
    integrated_random_gradients = tf.math.reduce_mean(random_grads, axis=0)

    return integrated_gradients, integrated_random_gradients


In [None]:
# For every saved checkpoint: load its weights, compute the averaged path gradients of the
# pre-selected nodes for every visualisation sample, and pickle the results.
for filename in tqdm(sorted(os.listdir(model_directory))):
    if not filename.endswith(".h5"):
        # Not a checkpoint; report it and move on.
        print(filename)
        continue

    model_name = os.path.join(model_directory, filename)
    print("model being processed now: ", model_name)

    contextnet.load_weights(model_name, by_name=True)
    encoder = contextnet.layers[0]

    images_check = []
    gradients_check = []
    random_gradients_check = []
    # `m` indexes the node lists, which were filled in the same loader order.
    for m, (i, j) in enumerate(visualisation_gradient_loader):
        inputs = tf.Variable(i["inputs"])
        inputs_length = tf.Variable(i["inputs_length"])
        signal = tf.Variable(i["signal"])

        # The previous plain-gradient pass over `inputs` was removed: its results were
        # never used and it only cost an extra forward/backward computation per sample.
        interated_gradients, random_integrated_gradients = get_integrated_gradients(encoder, tf.squeeze(inputs),
                                                                                    inputs_length, signal,
                                                                                    activated_node_list[m],
                                                                                    random_activated_node_list[m])

        gradients_check.append(interated_gradients)
        random_gradients_check.append(random_integrated_gradients)

        images_check.append(tf.squeeze(inputs))

        print("integrated_gradients shape=========", interated_gradients.shape, random_integrated_gradients.shape)

    dd = {'input_image': images_check,
          'integrated_gradients': gradients_check,
          'random_integrated_gradients': random_gradients_check,
          'index_of_activated_node': activated_node_list,
          'index_of_random_node': random_activated_node_list
          }

    # Output name is "<checkpoint file name>.pkl" inside the gradient-list folder.
    file_path_to_save = os.path.join(directory_to_save_gradient_lists, filename)

    # The context manager closes the file; no explicit close() is needed afterwards.
    with open(file_path_to_save + ".pkl", 'wb') as f:
        pickle.dump(dd, f)

In [None]:
def obtain_cmap(color_map):
    """Return the matplotlib colormap that ``plt.get_cmap`` resolves for ``color_map``."""
    colormap = plt.get_cmap(color_map)
    return colormap


def make_directories(current_working_directory_abs, directory_name):
    """Ensure ``directory_name`` exists under ``current_working_directory_abs``.

    Args:
        current_working_directory_abs: parent directory path.
        directory_name: name of the directory to create inside the parent.

    Returns:
        The joined path of the (now existing) directory.

    Raises:
        OSError: if the directory cannot be created for a reason other than it already
            existing (e.g. permissions, or the path is occupied by a regular file).
    """
    _directory_abs = os.path.join(current_working_directory_abs, directory_name)
    try:
        # makedirs also creates a missing parent instead of failing on it.
        os.makedirs(_directory_abs)
    except FileExistsError:
        # Only a pre-existing *directory* is the benign case; anything else at that path
        # must not be reported as "already exists".
        if not os.path.isdir(_directory_abs):
            raise
        print("--------------", directory_name, "directory already exists-----------------")
        print("--------------The contents will be over-ridden-------------------")

    return _directory_abs


def normalize_gradients(gradient_directory):
    """Find the largest gradient norms across all pickled gradient lists.

    Every file in ``gradient_directory`` is loaded with pickle and is expected to hold
    the keys ``"integrated_gradients"`` and ``"random_integrated_gradients"``.

    Args:
        gradient_directory: directory containing the pickled gradient dictionaries.

    Returns:
        Tuple ``(norm_max_g, normm_max_r)``: the maximum ``tf.norm`` over all activated
        node gradients and over all random node gradients respectively. Both stay at
        ``-1`` when no gradients are found.
    """
    norm_max_g = -1
    normm_max_r = -1

    for file in sorted(os.listdir(gradient_directory)):
        filename = os.path.join(gradient_directory, file)
        print("filenmes ", filename)
        with open(filename, "rb") as f:
            # NOTE(security): pickle.load can execute arbitrary code; only point this at
            # files produced by this notebook.
            x_temp = pickle.load(f)

        gradients_check = x_temp["integrated_gradients"]
        random_gradients = x_temp["random_integrated_gradients"]

        for i, j in zip(gradients_check, random_gradients):
            norm_max_g = tf.maximum(tf.norm(i), norm_max_g)
            # Use the random-node gradient `j` here; the activated gradient `i` was
            # previously used for both maxima by mistake.
            normm_max_r = tf.maximum(tf.norm(j), normm_max_r)

            print("normm=======", norm_max_g, normm_max_r)

    return norm_max_g, normm_max_r

# Name of the matplotlib colormap used for the gradient panels (resolved via obtain_cmap).
cmap = "jet"
# When True, gradient panels get a fixed colour range (set_clim(-0.1, 0.1)) so that plots
# of different checkpoints share the same colour scale.
index_fixed = True

In [None]:
# Resolve the folders used by the plotting/video stage, all relative to the current
# working directory: the pickled gradient lists are read from "gradient_lists", while
# "gradient_plots" and "video" are created (in that order) to receive the outputs.
current_working_directory = os.getcwd()
gradient_directory = os.path.join(current_working_directory, "gradient_lists")
directory_to_save_plots, directory_to_save_video = [
    make_directories(current_working_directory, folder_name)
    for folder_name in ("gradient_plots", "video")
]
# Largest gradient norms over every stored checkpoint (activated node, random node).
norm_g, norm_r = normalize_gradients(gradient_directory)

def gradient_transformation(gradient, norm):
    """Convert a gradient tensor to a transposed NumPy array for plotting.

    Args:
        gradient: object exposing ``.numpy()`` (e.g. an eager tensor).
        norm: accepted for interface compatibility; currently not applied.

    Returns:
        The transpose of ``gradient.numpy()``.
    """
    as_array = gradient.numpy()
    return as_array.T



In [None]:
def _draw_specshow_panel(fig, axis, data, title=None, gradient_panel=False):
    """Draw one mel-axis ``specshow`` panel with its own colorbar on ``axis``.

    Gradient panels use the module-level colormap ``cmap`` and, when ``index_fixed`` is
    set, a fixed colour range of [-0.1, 0.1]; spectrogram panels use librosa's defaults.
    """
    if title is not None:
        axis.set(title=title)
    extra = {"cmap": obtain_cmap(cmap)} if gradient_panel else {}
    mesh = librosa.display.specshow(data, sr=16000, fmax=8000, x_axis='time', y_axis='mel',
                                    ax=axis, alpha=1, **extra)
    if gradient_panel and index_fixed:
        mesh.set_clim(vmin=-0.1, vmax=0.1)
    fig.colorbar(mesh, ax=axis)
    axis.label_outer()


def plot_gradients_images(directory_to_save_plots):
    """Render one 2x3 figure per pickled gradient list and save it as a PNG.

    Files are read from the module-level ``gradient_directory``. For each file, the first
    two stored samples are plotted: row 0 (labelled male) and row 1 (labelled female).
    Each row shows the input log-mel spectrogram, the gradients for the most activated
    node and the gradients for the comparison node. If a file holds fewer than two
    samples only the available rows are drawn.

    Args:
        directory_to_save_plots: directory receiving ``Grad<name>.png`` files, where
            ``<name>`` is the part of the file name after its last underscore and
            before its first dot.
    """
    # Only the first two samples are displayed on purpose (one male, one female voice).
    voices = ("male", "female")

    for index, file in enumerate(sorted(os.listdir(gradient_directory))):
        filename = os.path.join(gradient_directory, file)
        print("gradient list file =>  ",filename)
        with open(filename, "rb") as f:
            # NOTE(security): pickle.load can execute arbitrary code; only use it on
            # files produced by this notebook.
            x_temp = pickle.load(f)

        name = file.split('.')[0].split('_')[-1]

        images_check = x_temp['input_image']
        gradients_check = x_temp["integrated_gradients"]
        random_gradients = x_temp["random_integrated_gradients"]
        activated_node_list = x_temp["index_of_activated_node"]
        random_node_list = x_temp["index_of_random_node"]

        fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(25, 14), facecolor=(1, 1, 1))
        try:
            fig.suptitle(' Model checkpoint ' + str(index), fontsize=16)

            rows_available = min(len(voices), len(images_check), len(gradients_check))
            for row in range(rows_available):
                _draw_specshow_panel(
                    fig, ax[row][0], images_check[row].numpy().T,
                    title=" Log Mel spectrogram for " + voices[row] + " voice",
                )
                _draw_specshow_panel(
                    fig, ax[row][1], gradient_transformation(gradients_check[row], norm_g),
                    title=" Gradient Attribution for filter " + str(activated_node_list[row]),
                    gradient_panel=True,
                )
                _draw_specshow_panel(
                    fig, ax[row][2], gradient_transformation(random_gradients[row], norm_r),
                    title=" Gradient Attribution for filter " + str(random_node_list[row]),
                    gradient_panel=True,
                )

            fig.savefig(os.path.join(directory_to_save_plots, "Grad" + name + ".png"))
        finally:
            # One figure per checkpoint: release it so memory does not grow with the
            # number of checkpoints.
            plt.close(fig)



In [None]:
def make_videos_from_images(directory_to_save_video):
    """Assemble the PNGs in ``<cwd>/gradient_plots`` into a 1-fps AVI.

    Images are taken in sorted file-name order and written to
    ``<directory_to_save_video>/gradient_vis.avi``. Files that OpenCV cannot read are
    skipped; frames whose size differs from the first readable frame are resized to it
    so that no frame is silently dropped by the writer. Nothing is written when no
    readable image is found.

    Args:
        directory_to_save_video: directory receiving ``gradient_vis.avi``.
    """
    img_array = []
    size = None

    fname1 = os.path.join(os.getcwd(), "gradient_plots") + "/*.png"

    for filename1 in sorted(glob.glob(fname1)):
        print(filename1)
        image1 = cv2.imread(filename1)
        if image1 is None:
            # cv2.imread signals an unreadable/corrupt file by returning None.
            print("skipping unreadable image: ", filename1)
            continue
        height, width = image1.shape[:2]
        if size is None:
            # The first readable frame defines the video resolution.
            size = (width, height)
            print(size)
        elif (width, height) != size:
            image1 = cv2.resize(image1, size)
        img_array.append(image1)

    if not img_array:
        print("no readable images found for pattern: ", fname1)
        return

    out = cv2.VideoWriter(directory_to_save_video + "/gradient_vis.avi", cv2.VideoWriter_fourcc(*'DIVX'), 1, size)
    try:
        for frame in img_array:
            out.write(frame)
    finally:
        # Always finalise the container, even if writing a frame fails.
        out.release()

In [None]:
# Render one PNG per stored gradient list, then stitch the PNGs into a video.
# The order matters: the video step reads the PNG files written by the plotting step.
plot_gradients_images(directory_to_save_plots)
make_videos_from_images(directory_to_save_video)