In [1]:
import torch
import torch.nn as nn
from torch.nn import LayerNorm
import torchaudio.compliance.kaldi as ta_kaldi

from backbone import (
    TransformerEncoder,
)
from quantizer import (
    NormEMAVectorQuantizer,
)

import logging
from typing import Optional

## Tokenizer

In [2]:
logger = logging.getLogger(__name__)


class TokenizersConfig:
    def __init__(self, cfg=None):
        self.input_patch_size: int = -1  # path size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep_norm first in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a tarnsformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # quantizer
        self.quant_n: int = 1024 # codebook number in quantizer
        self.quant_dim: int = 256    # codebook dimension in quantizer

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        self.__dict__.update(cfg)


class Tokenizers(nn.Module):
    def __init__(
            self,
            cfg: TokenizersConfig,
    ) -> None:
        super().__init__()
        logger.info(f"Tokenizers Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
                                         bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.quantize = NormEMAVectorQuantizer(
            n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
        )
        self.quant_n = cfg.quant_n
        self.quantize_layer = nn.Sequential(
            nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
            nn.Tanh(),
            nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim)  # for quantize
        )

    def forward_padding_mask(
            self,
            features: torch.Tensor,
            padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
            self,
            source: torch.Tensor,
            fbank_mean: float = 15.41663,
            fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        fbanks = []
        for waveform in source:
            waveform = waveform.unsqueeze(0) * 2 ** 15
            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_labels(
            self,
            source: torch.Tensor,
            padding_mask: Optional[torch.Tensor] = None,
            fbank_mean: float = 15.41663,
            fbank_std: float = 6.55582,
    ):
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        quantize_input = self.quantize_layer(x)
        quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)

        return embed_ind

## BEATs Models

In [3]:
logger = logging.getLogger(__name__)


class BEATsConfig:
    def __init__(self, cfg=None):
        self.input_patch_size: int = -1  # path size of patch embedding
        self.embed_dim: int = 512  # patch embedding dimension
        self.conv_bias: bool = False  # include bias in conv encoder

        self.encoder_layers: int = 12  # num encoder layers in the transformer
        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_wise_gradient_decay_ratio: float = 1.0  # ratio for layer-wise gradient decay
        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.deep_norm: bool = False  # apply deep_norm first in the transformer

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a tarnsformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        # label predictor
        self.finetuned_model: bool = False  # whether the model is a fine-tuned model.
        self.predictor_dropout: float = 0.1  # dropout probability for the predictor
        self.predictor_class: int = 527  # target class number for the predictor

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        self.__dict__.update(cfg)


class BEATs(nn.Module):
    def __init__(
            self,
            cfg: BEATsConfig,
    ) -> None:
        super().__init__()
        logger.info(f"BEATs Config: {cfg.__dict__}")

        self.cfg = cfg

        self.embed = cfg.embed_dim
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
                                         bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        if cfg.finetuned_model:
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
        else:
            self.predictor = None

    def forward_padding_mask(
            self,
            features: torch.Tensor,
            padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
            self,
            source: torch.Tensor,
            fbank_mean: float = 15.41663,
            fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        fbanks = []
        for waveform in source:
            waveform = waveform.unsqueeze(0) * 2 ** 15
            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_features(
            self,
            source: torch.Tensor,
            padding_mask: Optional[torch.Tensor] = None,
            fbank_mean: float = 15.41663,
            fbank_std: float = 6.55582,
    ):
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        if self.predictor is not None:
            x = self.predictor_dropout(x)
            logits = self.predictor(x)

            if padding_mask is not None and padding_mask.any():
                logits[padding_mask] = 0
                logits = logits.sum(dim=1)
                logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
            else:
                logits = logits.mean(dim=1)

            lprobs = torch.sigmoid(logits)

            return lprobs, padding_mask
        else:
            return x, padding_mask


## Iteration 1

In [4]:
#tokenize
checkpoint = torch.load(r"models\Tokenizer_iter1.pt") 
cfg = TokenizersConfig(checkpoint['cfg'])
BEATs_tokenizer = Tokenizers(cfg)
BEATs_tokenizer.load_state_dict(checkpoint['model'])
BEATs_tokenizer.eval()

# tokenize the audio and generate the labels
audio_input_16khz = torch.randn(10, 10000)
padding_mask = torch.zeros(10, 10000).bool()

labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
print(labels)



tensor([  41,  798,  809,  809,  883,  375,  375,  375,  579,  809, 1008,  916,
         883,  496,  375,  375, 1002,  280,  280,  387,  916,  883,  496,  496,
         894,  128,  916,  809,  883,  883,  375,  496,  894,  490,  809,  809,
         809,  375,  496,  496,  579,  490,  261,  916,   33,  496,  496,  496,
          41,  128,  809,  916,  916,  496,  375,  496,  579,  809,  809,  916,
         883,  496,  375,  496,  579,  809,  809,  916,  916,  597,  496,  496,
         609,  809,  916,  951,  916,  375,  375,  375, 1002,  217,  775,  809,
         916,  375,  496,  496,  403,  951,  809,  809,   33,  496,  496,  496,
         609,  687,  809,  809,  916,  375,  375,  375,  403,    0,  809,  916,
         916,  496,  496,  496,  579,  774,  280,  916,  883,  496,  496,  496,
         798,  809,  277,  900,  883,  883,  375,  375,  579,  482,  809,  916,
         916,  269,  496,  496,  270,  490,  477,  128,  809,  883,  496,  496,
         609,  482,  809,  883,  883,  3

In [5]:
# load the pre-trained model
checkpoint = torch.load(r"models\BEATs_iter1.pt")

cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

BEATs(
  (post_extract_proj): Linear(in_features=512, out_features=768, bias=True)
  (patch_embedding): Conv2d(1, 512, kernel_size=(16, 16), stride=(16, 16), bias=False)
  (dropout_input): Dropout(p=0.0, inplace=False)
  (encoder): TransformerEncoder(
    (pos_conv): Sequential(
      (0): Conv1d(768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
      (1): SamePad()
      (2): GELU(approximate='none')
    )
    (layers): ModuleList(
      (0): TransformerSentenceEncoderLayer(
        (self_attn): MultiheadAttention(
          (dropout_module): Dropout(p=0.0, inplace=False)
          (relative_attention_bias): Embedding(320, 12)
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
          (grep_linear): Linear(in_features=64, out

In [6]:
# predict the classification probability of each class
audio_input_16khz = torch.randn(100, 10000)
padding_mask = torch.zeros(100, 10000).bool()

probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]

for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
    top5_label_iter1 = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
    print(f'Top 5 predicted labels of the {i}th audio are {top5_label_iter1} with probability of {top5_label_prob}')

Top 5 predicted labels of the 0th audio are ['/m/07rgkc5', '/m/09x0r', '/m/04rlf', '/m/0chx_', '/m/07yv9'] with probability of tensor([0.2311, 0.1505, 0.0888, 0.0757, 0.0380], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 1th audio are ['/m/07rgkc5', '/m/0chx_', '/m/09x0r', '/m/04rlf', '/m/096m7z'] with probability of tensor([0.4373, 0.1031, 0.0974, 0.0880, 0.0665], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 2th audio are ['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/04rlf', '/m/096m7z'] with probability of tensor([0.3790, 0.1702, 0.0966, 0.0910, 0.0432], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 3th audio are ['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/04rlf', '/m/096m7z'] with probability of tensor([0.3540, 0.1306, 0.1185, 0.0912, 0.0530], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 4th audio are ['/m/07rgkc5', '/m/09x0r', '/m/04rlf', '/m/0chx_', '/m/07yv9'] with probability of tensor([0.2206, 0.1511, 0.0742, 0.0552, 0.0406], grad_fn=<

In [7]:
top5_label_iter1

['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/04rlf', '/m/0cj0r']

## Iteration 2 

In [8]:
#tokenize
checkpoint = torch.load(r"models\Tokenizer_iter2.pt") 
cfg = TokenizersConfig(checkpoint['cfg'])
BEATs_tokenizer = Tokenizers(cfg)
BEATs_tokenizer.load_state_dict(checkpoint['model'])
BEATs_tokenizer.eval()

# tokenize the audio and generate the labels
audio_input_16khz = torch.randn(10, 10000)
padding_mask = torch.zeros(10, 10000).bool()

labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
print(labels)

tensor([653, 375, 692, 945, 990, 330, 705, 143, 653, 462, 278, 391, 816, 705,
        330, 801, 653, 511, 570, 748, 227, 705, 771, 771, 653, 692, 656, 556,
        391, 391, 104, 143, 653, 391, 377, 758, 771, 705, 143, 143, 462, 439,
        692, 288, 801, 219, 143, 771, 653, 391, 453, 653, 932, 472, 801, 801,
        653, 746, 375, 945, 391, 932, 472, 362, 271, 391, 444, 312, 288, 771,
        771, 771, 271, 375, 375, 146, 104, 104, 104, 143, 653, 375, 113, 391,
        439, 104, 472, 771, 653, 146, 439, 767, 769, 771, 771, 771, 653, 652,
        189, 375, 391, 288, 705, 472, 653, 891, 841, 288, 104, 104, 104, 472,
        653, 989, 147, 122, 990, 104, 104, 771, 653, 716, 692, 977, 553, 143,
        246, 143, 653, 503, 150, 833, 391, 705, 472, 801, 653, 656, 243, 653,
        104, 705, 801, 801, 653, 989,  22, 391, 104, 104, 104, 143, 653, 900,
        511, 375, 439, 472, 705, 143, 524,  31,  15, 391, 972, 705, 143, 143,
        653, 391, 841, 990, 104, 104, 104, 104, 653, 344, 444, 5

In [9]:
# load the pre-trained model
checkpoint = torch.load(r"models\BEATs_iter2.pt")

cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

BEATs(
  (post_extract_proj): Linear(in_features=512, out_features=768, bias=True)
  (patch_embedding): Conv2d(1, 512, kernel_size=(16, 16), stride=(16, 16), bias=False)
  (dropout_input): Dropout(p=0.0, inplace=False)
  (encoder): TransformerEncoder(
    (pos_conv): Sequential(
      (0): Conv1d(768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
      (1): SamePad()
      (2): GELU(approximate='none')
    )
    (layers): ModuleList(
      (0): TransformerSentenceEncoderLayer(
        (self_attn): MultiheadAttention(
          (dropout_module): Dropout(p=0.0, inplace=False)
          (relative_attention_bias): Embedding(320, 12)
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
          (grep_linear): Linear(in_features=64, out

In [10]:
# predict the classification probability of each class
audio_input_16khz = torch.randn(100, 10000)
padding_mask = torch.zeros(100, 10000).bool()

probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]

for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
    top5_label_iter2 = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
    print(f'Top 5 predicted labels of the {i}th audio are {top5_label_iter2} with probability of {top5_label_prob}')

Top 5 predicted labels of the 0th audio are ['/m/0j2kx', '/m/09x0r', '/m/07rgkc5', '/m/0cj0r', '/m/0chx_'] with probability of tensor([0.1457, 0.1354, 0.1311, 0.0880, 0.0628], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 1th audio are ['/m/07rgkc5', '/m/0chx_', '/m/096m7z', '/m/09x0r', '/m/04rlf'] with probability of tensor([0.3900, 0.1612, 0.1080, 0.1010, 0.0691], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 2th audio are ['/m/0j2kx', '/m/09x0r', '/m/07yv9', '/m/07rgkc5', '/m/07qlf79'] with probability of tensor([0.2426, 0.1322, 0.1055, 0.0696, 0.0639], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 3th audio are ['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/04rlf', '/m/07qlf79'] with probability of tensor([0.3157, 0.1403, 0.0726, 0.0710, 0.0518], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 4th audio are ['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/04rlf', '/m/07qlf79'] with probability of tensor([0.3362, 0.1351, 0.0762, 0.0593, 0.0555], grad_

In [11]:
top5_label_iter2

['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/04229', '/m/07yv9']

## Iteration 3

In [12]:
#tokenize
checkpoint = torch.load(r"models\Tokenizer_iter3.pt") 
cfg = TokenizersConfig(checkpoint['cfg'])
BEATs_tokenizer = Tokenizers(cfg)
BEATs_tokenizer.load_state_dict(checkpoint['model'])
BEATs_tokenizer.eval()

# tokenize the audio and generate the labels
audio_input_16khz = torch.randn(10, 10000)
padding_mask = torch.zeros(10, 10000).bool()

labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
print(labels)

tensor([549, 224, 757, 372, 433, 962, 757, 103, 466, 372, 372, 605,  30, 962,
        433, 713, 466, 757, 372, 372, 409,  30, 395, 713, 137, 757, 453, 889,
        185, 962, 962, 103, 466, 629, 241, 813, 962, 962, 433, 713, 466, 241,
        813, 813, 965, 962, 757, 713, 108, 757, 372, 372, 103, 962, 395, 554,
        400, 372, 605, 372, 898, 898, 962, 554, 108, 267, 372, 372, 372, 395,
        395, 321, 466, 890, 372, 241, 151, 898, 395, 103, 466, 629, 915, 813,
        898, 898, 395, 713, 466, 372, 978, 813, 241, 151, 185, 713, 466, 372,
        978, 151, 615, 757, 433, 185, 466, 629, 372, 372,  30,  85, 395, 713,
        466, 757, 372, 372,  30, 757, 757, 321, 108, 224, 224, 405, 395, 433,
        395, 321, 108, 629, 241, 241, 141, 237, 321, 704, 108, 224, 241, 241,
         30, 237,  30, 321, 549, 372, 978, 372, 151, 202, 433, 554, 466, 757,
        372, 372, 750, 202, 395, 554, 137, 629, 372, 241, 898, 433, 962, 103,
        137, 757, 372, 241,  30,  30, 321, 321, 137, 372, 629, 2

In [13]:
# load the pre-trained model
checkpoint = torch.load(r"models\BEATs_iter3.pt")

cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

BEATs(
  (post_extract_proj): Linear(in_features=512, out_features=768, bias=True)
  (patch_embedding): Conv2d(1, 512, kernel_size=(16, 16), stride=(16, 16), bias=False)
  (dropout_input): Dropout(p=0.0, inplace=False)
  (encoder): TransformerEncoder(
    (pos_conv): Sequential(
      (0): Conv1d(768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
      (1): SamePad()
      (2): GELU(approximate='none')
    )
    (layers): ModuleList(
      (0): TransformerSentenceEncoderLayer(
        (self_attn): MultiheadAttention(
          (dropout_module): Dropout(p=0.0, inplace=False)
          (relative_attention_bias): Embedding(320, 12)
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
          (grep_linear): Linear(in_features=64, out

In [14]:
# predict the classification probability of each class
audio_input_16khz = torch.randn(100, 10000)
padding_mask = torch.zeros(100, 10000).bool()

probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]

for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
    top5_label_iter3 = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
    print(f'Top 5 predicted labels of the {i}th audio are {top5_label_iter3} with probability of {top5_label_prob}')

Top 5 predicted labels of the 0th audio are ['/m/07rgkc5', '/m/0chx_', '/m/096m7z', '/m/09x0r', '/m/07qlf79'] with probability of tensor([0.3836, 0.2575, 0.2486, 0.1529, 0.1204], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 1th audio are ['/m/07rgkc5', '/m/0chx_', '/m/096m7z', '/m/09x0r', '/m/07qlf79'] with probability of tensor([0.6755, 0.2237, 0.1100, 0.0938, 0.0644], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 2th audio are ['/m/07rgkc5', '/m/0chx_', '/m/096m7z', '/m/07qlf79', '/m/09x0r'] with probability of tensor([0.7371, 0.2330, 0.1053, 0.0920, 0.0770], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 3th audio are ['/m/07rgkc5', '/m/0chx_', '/m/096m7z', '/m/09x0r', '/m/04rlf'] with probability of tensor([0.6488, 0.2286, 0.1161, 0.1039, 0.0445], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 4th audio are ['/m/07rgkc5', '/m/0chx_', '/m/096m7z', '/m/09x0r', '/m/07qlf79'] with probability of tensor([0.6795, 0.2601, 0.1426, 0.1219, 0.0595],

In [15]:
top5_label_iter3

['/m/07rgkc5', '/m/0chx_', '/m/096m7z', '/m/09x0r', '/m/04rlf']

## Iteration 3+ 

In [16]:
#tokenize
checkpoint = torch.load(r"models\Tokenizer_iter3_plus.pt") 
cfg = TokenizersConfig(checkpoint['cfg'])
BEATs_tokenizer = Tokenizers(cfg)
BEATs_tokenizer.load_state_dict(checkpoint['model'])
BEATs_tokenizer.eval()

# tokenize the audio and generate the labels
audio_input_16khz = torch.randn(10, 10000)
padding_mask = torch.zeros(10, 10000).bool()

labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
print(labels)

tensor([215, 160,  36,  36,  36, 350,  36, 364, 883, 400,  36,  36, 350, 967,
        350, 364, 250, 400,  36,  36, 562, 323, 350,  36, 953, 566,  36,  36,
        967,  36,  36, 364, 883, 243, 206,  36, 350, 713, 350, 364, 250, 400,
         36, 339,  36, 350, 350, 890, 838,  36,  36, 244, 379, 319, 350,  36,
        767,  36, 244, 788,  36, 350, 603, 738, 694, 500, 244,  36, 627, 889,
        603,  36, 670, 500,  36,  36, 967,  36, 267, 757, 362, 566,  36,  36,
        267, 350, 350, 199, 767, 243,  36,  36, 313,  36, 350, 364, 767, 565,
         36,  36, 967,  36,  36, 364, 767, 890,  36,  36, 350, 381, 350, 915,
        694, 400,  36,  36, 716, 662, 350,  36, 767,  36,  36, 313, 612, 838,
        603, 199, 767, 400,  36,  36, 379, 350, 319, 364, 767, 733,  36,  36,
        185, 425, 603,  36, 767, 890,  36, 733, 296, 350,  36, 350, 362, 881,
         36,  36, 350, 967, 350, 505, 767, 890,  36, 339, 627,  36, 350, 603,
        767, 392,  36,  36,  36, 319, 603,  36, 767, 400,  36,  

In [18]:
# load the pre-trained model
checkpoint = torch.load(r"models\BEATs_iter3_plus.pt")

cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

BEATs(
  (post_extract_proj): Linear(in_features=512, out_features=768, bias=True)
  (patch_embedding): Conv2d(1, 512, kernel_size=(16, 16), stride=(16, 16), bias=False)
  (dropout_input): Dropout(p=0.0, inplace=False)
  (encoder): TransformerEncoder(
    (pos_conv): Sequential(
      (0): Conv1d(768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
      (1): SamePad()
      (2): GELU(approximate='none')
    )
    (layers): ModuleList(
      (0): TransformerSentenceEncoderLayer(
        (self_attn): MultiheadAttention(
          (dropout_module): Dropout(p=0.0, inplace=False)
          (relative_attention_bias): Embedding(320, 12)
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
          (grep_linear): Linear(in_features=64, out

In [19]:
# predict the classification probability of each class
audio_input_16khz = torch.randn(100, 10000)
padding_mask = torch.zeros(100, 10000).bool()

probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]

for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
    top5_label_iter3_plus = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
    print(f'Top 5 predicted labels of the {i}th audio are {top5_label_iter3_plus} with probability of {top5_label_prob}')

Top 5 predicted labels of the 0th audio are ['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/07qlf79', '/m/096m7z'] with probability of tensor([0.1929, 0.1440, 0.1187, 0.0829, 0.0721], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 1th audio are ['/m/07rgkc5', '/m/0chx_', '/m/09x0r', '/m/096m7z', '/m/07qlf79'] with probability of tensor([0.2254, 0.1394, 0.1382, 0.1179, 0.0914], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 2th audio are ['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/096m7z', '/m/07qlf79'] with probability of tensor([0.2745, 0.1639, 0.1563, 0.1397, 0.0390], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 3th audio are ['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/096m7z', '/m/07qlf79'] with probability of tensor([0.2780, 0.1538, 0.1450, 0.1041, 0.0509], grad_fn=<UnbindBackward0>)
Top 5 predicted labels of the 4th audio are ['/m/07rgkc5', '/m/0chx_', '/m/09x0r', '/m/096m7z', '/m/0cj0r'] with probability of tensor([0.2845, 0.1937, 0.1445, 0.1422, 0.0374],

In [20]:
top5_label_iter3_plus

['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/096m7z', '/m/0cj0r']

## Results

In [21]:
print(f'Top 5 predicted labels of the 1st iteration audio are {top5_label_iter1}')
print(f'Top 5 predicted labels of the 2nd iteration audio are {top5_label_iter2}')
print(f'Top 5 predicted labels of the 3rd iteration audio are {top5_label_iter3}')
print(f'Top 5 predicted labels of the 3_plus iteration audio are {top5_label_iter3_plus}')

Top 5 predicted labels of the 1st iteration audio are ['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/04rlf', '/m/0cj0r']
Top 5 predicted labels of the 2nd iteration audio are ['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/04229', '/m/07yv9']
Top 5 predicted labels of the 3rd iteration audio are ['/m/07rgkc5', '/m/0chx_', '/m/096m7z', '/m/09x0r', '/m/04rlf']
Top 5 predicted labels of the 3_plus iteration audio are ['/m/07rgkc5', '/m/09x0r', '/m/0chx_', '/m/096m7z', '/m/0cj0r']
