# Real-time auto-tagging inference with essentia-tensorflow and soundcard

In [None]:
%matplotlib nbagg

import numpy as np
import matplotlib.pyplot as plt
import soundcard as sc

from struct import unpack
from IPython import display

from essentia.streaming import *
from essentia import Pool, run, array, reset

In [None]:
# --- Model -----------------------------------------------------------------
# Frozen TensorFlow graph file and the names of its input / output nodes.
modelName = 'msd-musicnn.pb'
input_layer = 'model/Placeholder'
output_layer = 'model/Sigmoid'

# Tag names (50). Row i of the activations plot is labelled msd_labels[i];
# presumably this is the order of the model's output units -- confirm against
# the model's metadata.
msd_labels = [
    'rock', 'pop', 'alternative', 'indie', 'electronic',
    'female vocalists', 'dance', '00s', 'alternative rock', 'jazz',
    'beautiful', 'metal', 'chillout', 'male vocalists', 'classic rock',
    'soul', 'indie rock', 'Mellow', 'electronica', '80s',
    'folk', '90s', 'chill', 'instrumental', 'punk',
    'oldies', 'blues', 'hard rock', 'ambient', 'acoustic',
    'experimental', 'female vocalist', 'guitar', 'Hip-Hop', '70s',
    'party', 'country', 'easy listening', 'sexy', 'catchy',
    'funk', 'electro', 'heavy metal', 'Progressive rock', '60s',
    'rnb', 'indie pop', 'sad', 'House', 'happy',
]

# --- Analysis ---------------------------------------------------------------
sampleRate = 16000  # audio sample rate in Hz

frameSize = 512     # samples per analysis frame
hopSize = 256       # samples between consecutive frames
patchSize = 64      # frames per model input patch
numberBands = 96    # mel bands per frame

# Samples per captured chunk: 2**14 plus one extra frame's worth.
chuncksize = (1 << 14) + frameSize
channels = 1        # interleaved channels in the data handed to the callback

## Instantiate and connect the algorithms

In [None]:
# Cut the incoming signal into overlapping frames; the first frame starts at
# sample 0 and no frame is created past the end of the input.
fc = FrameCutter(
    frameSize=frameSize,
    hopSize=hopSize,
    startFromZero=True,
    lastFrameToEndOfFile=False,
)

# Per-frame mel-band features in the form the MusiCNN model consumes.
tim = TensorflowInputMusiCNN()

# Stack `patchSize` consecutive band frames into a [1, 1, patchSize,
# numberBands] tensor; an incomplete trailing patch is discarded.
vtt = VectorRealToTensor(
    shape=[1, 1, patchSize, numberBands],
    lastPatchMode='discard',
)

# Wrap each tensor in a Pool keyed by the graph's input-node name.
ttp = TensorToPool(namespace=input_layer)

# Run the frozen graph; `isTrainingName` names its training-flag placeholder.
tfp = TensorflowPredict(
    graphFilename=modelName,
    inputs=[input_layer],
    outputs=[output_layer],
    isTrainingName="model/Placeholder_1",
)

# Pull the output node back out of the result Pool and flatten it to vectors.
ptt = PoolToTensor(namespace=output_layer)
ttv = TensorToVectorReal()

# Accumulates mel bands and model outputs across all processed chunks.
pool = Pool()

In [None]:
# Wire the streaming network. Data flows top to bottom:
#   FrameCutter -> TensorflowInputMusiCNN -> VectorRealToTensor -> TensorToPool
#   -> TensorflowPredict -> PoolToTensor -> TensorToVectorReal -> pool
# The mel bands are additionally stored in `pool` under "melbands" (used for
# plotting), and every model output frame is stored under `output_layer`.
# NOTE: fc.signal is intentionally left unconnected here; the callback attaches
# a fresh VectorInput to it for each audio chunk and disconnects it afterwards.
fc.frame       >>  tim.frame
tim.bands      >>  vtt.frame
tim.bands      >>  (pool, "melbands")
vtt.tensor     >>  ttp.tensor
ttp.pool       >>  tfp.poolIn
tfp.poolOut    >>  ptt.pool
ptt.tensor     >>  ttv.tensor
ttv.frame      >>  (pool, output_layer)

## Callback: run the analysis on each audio chunk and update the plots

In [None]:
_ACTIVATION_HISTORY = 20  # activation frames (columns) shown in the activations plot


def _to_mono_float32(in_data, n_channels):
    """Return one audio chunk as a new, writable 1-D float32 NumPy array.

    Args:
        in_data: raw native-endian float32 bytes, or any array-like of samples
            (e.g. a NumPy array). Multichannel data must be interleaved, or
            shaped ``(frames, n_channels)``.
        n_channels: number of interleaved channels in ``in_data``; when greater
            than 1 the channels are averaged to mono.
    """
    if isinstance(in_data, (bytes, bytearray, memoryview)):
        # Same interpretation as struct.unpack('f' * n, in_data): native float32.
        samples = np.frombuffer(in_data, dtype=np.float32)
    else:
        # Convert arrays by value, so e.g. float64 input is handled correctly
        # instead of having its bytes reinterpreted as float32.
        samples = np.asarray(in_data, dtype=np.float32).reshape(-1)
    if n_channels > 1:
        samples = samples.reshape(-1, n_channels).mean(axis=1)
    # np.array copies: the result is writable and contiguous even when
    # np.frombuffer returned a read-only view of the caller's bytes.
    return np.array(samples, dtype=np.float32)


def _latest_window(history, width):
    """Return the last ``width`` rows of ``history`` as a ``(features, width)`` image.

    ``history`` is a 2-D array with one row per time step. If it has fewer than
    ``width`` rows, the image is left-padded with zeros so the newest step is
    always the right-most column.
    """
    history = np.asarray(history)
    # max(..., 0) avoids the history[-0:] pitfall, which would return everything.
    recent = history[max(history.shape[0] - width, 0):]
    window = np.zeros((history.shape[1], width))
    window[:, width - recent.shape[0]:] = recent.T
    return window


def callback(in_data, frame_count, time_info, status):
    """Run the analysis network on one audio chunk and refresh both plots.

    Only ``in_data`` is used; ``frame_count``, ``time_info`` and ``status`` are
    accepted for signature compatibility and ignored. Returns None.

    Args:
        in_data: one chunk of audio -- raw native-endian float32 bytes or a
            NumPy array of samples, interleaved when ``channels > 1``.

    Side effects:
        Streams the chunk through ``fc`` and the connected network, which
        appends new frames to ``pool``; updates ``tokens``, ``img_mel`` and
        ``img_act``; redraws the figure ``f``.
    """
    buffer = _to_mono_float32(in_data, channels)

    vimp = VectorInput(buffer)
    vimp.data >> fc.signal
    try:
        run(vimp)

        # Only redraw a panel when the pool gained new frames since last time.
        if pool.containsKey('melbands'):
            melbands = pool['melbands']
            if melbands.shape[0] > tokens['mel']:
                tokens['mel'] = melbands.shape[0]
                img_mel.set_data(_latest_window(melbands, patchSize))
                img_mel.autoscale()

        if pool.containsKey(output_layer):
            activations = pool[output_layer]
            if activations.shape[0] > tokens['activations']:
                tokens['activations'] = activations.shape[0]
                img_act.set_data(_latest_window(activations, _ACTIVATION_HISTORY))

        f.canvas.draw()
        f.canvas.flush_events()
    finally:
        # Always detach this chunk's source: fc.signal can have only one
        # source, so a leftover connection would break every later call.
        reset(vimp)
        vimp.data.disconnect(fc.signal)

## Process live audio captured with soundcard

In [None]:
# Initialize the figure: mel-band patch on the left, tag activations on the right.
f, ax = plt.subplots(1, 2, figsize=[9.6, 7])
f.canvas.draw()

ax[0].set_title('Mel Bands')
img_mel = ax[0].imshow(np.zeros([numberBands, patchSize]),
                       aspect='auto', origin='lower')

ax[1].set_title('Activations')
# One row per tag and 20 columns of history; the column count must match the
# window width used by `callback`. `aspect` takes a number (or 'auto'/'equal'),
# so pass 0.7 rather than the string '0.7'.
img_act = ax[1].matshow(np.zeros([len(msd_labels), 20]), aspect=0.7,
                        vmin=0, vmax=1)
ax[1].xaxis.set_ticks_position('bottom')
ax[1].yaxis.set_ticks_position('right')
# Label rows on ax[1] explicitly; plt.yticks would act on whichever axes is
# "current", which is fragile.
ax[1].set_yticks(np.arange(len(msd_labels)))
ax[1].set_yticklabels(msd_labels, fontsize=6)

# Reset storage and counters so re-running this cell starts from scratch.
pool.clear()
tokens = {'mel': 0, 'activations': 0}

# Capture from the first device soundcard reports (loopback devices included).
devices = sc.all_microphones(include_loopback=True)
if not devices:
    raise RuntimeError('soundcard found no recording or loopback devices')

# NOTE(review): `pool` keeps growing for as long as this loop runs.
with devices[0].recorder(samplerate=sampleRate) as mic:
    try:
        while True:
            # record() returns a (frames, channels) array; average the channels
            # to mono and make sure the samples are float32 for the callback.
            buffer = np.asarray(mic.record(numframes=chuncksize).mean(axis=1),
                                dtype=np.float32)
            callback(buffer, len(buffer), None, None)
    except KeyboardInterrupt:
        # Interrupting the kernel is the intended way to stop capturing.
        pass