In [20]:
import torchaudio
import torch
import os, csv, argparse, wget
from AST.models import ASTModel
import numpy as np
from torch.cuda.amp import autocast
import IPython

 10-second audio waveform -> sequence of 128-dimensional long filterbank 
 -> 1024(time) x 128(frequency) spectogram
 split into 512(64 time) x 8(frequency)) square patches of shape 16x16 fed into AST

In [11]:

# Filterbank
def load_audio(filename):
    waveform, sr = torchaudio.load(filename)
    assert sr == 16000, 'input audio sampling rate must be 16kHz'
    waveform = waveform - waveform.mean()
    
    # 128 dimensional long filterbank
    fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr,
                                              use_energy=False, window_type='hanning',
                                              num_mel_bins=128, dither=0.0, frame_shift=10)
    target_length = 1024
    n_frames = fbank.shape[0]
    p = target_length - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]
    # normalize the fbank
    fbank = (fbank + 5.081) / 4.4849

    # returns a 1024x128 spectogram
    return fbank

def load_label(label_csv):
    with open(label_csv, 'r') as f:
        reader = csv.reader(f, delimiter=',')
        lines = list(reader)
    labels = []
    ids = []  # Each label has a unique id such as "/m/068hy"
    for i1 in range(1, len(lines)):
        id = lines[i1][1]
        label = lines[i1][2]
        ids.append(id)
        labels.append(label)
    return labels

In [3]:
# Load AST model with AudioSet pretrained weights

checkpoint_path = "./data/audioset_0.4593.pth"
ast = ASTModel(label_dim=527, input_tdim=1024, imagenet_pretrain=False, audioset_pretrain=False)
checkpoint = torch.load(checkpoint_path, map_location="cuda")
audio_model = torch.nn.DataParallel(ast, device_ids=[0])
audio_model.load_state_dict(checkpoint)
audio_model = audio_model.to(torch.device("cuda:0"))
audio_model.eval()

---------------AST Model Summary---------------
ImageNet pretraining: False, AudioSet pretraining: False
frequncey stride=10, time stride=10
number of patches=1212


DataParallel(
  (module): ASTModel(
    (v): DistilledVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (blocks): ModuleList(
        (0-11): 12 x Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (drop): Dro

In [8]:
label_csv = "./data/class_labels_indices.csv"
labels = load_label(label_csv)

In [10]:
labels[:5]

['Speech',
 'Male speech, man speaking',
 'Female speech, woman speaking',
 'Child speech, kid speaking',
 'Conversation']

In [26]:
file = "./data/sample/audio_samples/audio/5FTf2UXOjd8_000160.flac"
fb = load_audio(file)

In [27]:
fb_data = fb.expand(1, 1024, 128)
fb_data = fb_data.to(torch.device("cuda:0"))

In [37]:
with torch.no_grad():
  with autocast():
    output = audio_model.forward(fb_data)
    output = torch.sigmoid(output)
result_output = output.data.cpu().numpy()[0]
sorted_indexes = np.argsort(result_output)[::-1]

# Print audio tagging top probabilities
print('Predice results:')
for k in range(10):
    print('- {}: {:.4f}'.format(np.array(labels)[sorted_indexes[k]], result_output[sorted_indexes[k]]))
print('Listen to this sample: ')
IPython.display.Audio('./data/sample/audio_samples/audio/5FTf2UXOjd8_000160.flac', rate=16000)

Predice results:
- Music: 0.9717
- Guitar: 0.6226
- Musical instrument: 0.5547
- Plucked string instrument: 0.5020
- Steel guitar, slide guitar: 0.0830
- Bass guitar: 0.0827
- Tapping (guitar technique): 0.0817
- Strum: 0.0570
- Acoustic guitar: 0.0519
- Electric guitar: 0.0503
Listen to this sample: 


### The audio contains a musical melody made from guitar and other musical instruments

In [35]:
output.shape

torch.Size([1, 527])