<a href="https://colab.research.google.com/github/shangeth/wavencoder/blob/master/examples/notebooks/wavencoder_demo_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Demo Notebook
## [wavencoder](https://pypi.org/project/wavencoder/) Models

---


Author: Shangeth Rajaa

![Twitter Follow](https://img.shields.io/twitter/follow/shangethr?style=social)

[GitHub](https://github.com/shangeth) [LinkedIn](https://www.linkedin.com/in/shangeth/)

# Installing wavencoder

In [None]:
!pip install fairseq
!pip install wavencoder

Collecting fairseq
[?25l  Downloading https://files.pythonhosted.org/packages/67/bf/de299e082e7af010d35162cb9a185dc6c17db71624590f2f379aeb2519ff/fairseq-0.9.0.tar.gz (306kB)
[K     |████████████████████████████████| 307kB 4.7MB/s 
Collecting sacrebleu
[?25l  Downloading https://files.pythonhosted.org/packages/a3/c4/8e948f601a4f9609e8b2b58f31966cb13cf17b940b82aa3e767f01c42c52/sacrebleu-1.4.14-py3-none-any.whl (64kB)
[K     |████████████████████████████████| 71kB 7.4MB/s 
Collecting portalocker
  Downloading https://files.pythonhosted.org/packages/89/a6/3814b7107e0788040870e8825eebf214d72166adf656ba7d4bf14759a06a/portalocker-2.0.0-py2.py3-none-any.whl
Building wheels for collected packages: fairseq
  Building wheel for fairseq (setup.py) ... [?25l[?25hdone
  Created wheel for fairseq: filename=fairseq-0.9.0-cp36-cp36m-linux_x86_64.whl size=2046423 sha256=f68a8ad6a5ce93540273cc72dd2d81ba74683c60a9aeb5aad1b93c187a2bb6e5
  Stored in directory: /root/.cache/pip/wheels/37/3e/1b/0fa30695

In [1]:
# Environment setup shared by the cells below.
import sys
# Makes a wavencoder checkout two directories up importable (presumably the
# repository root when this notebook runs from examples/notebooks/).
# append() adds it at the END of sys.path, so an already-installed wavencoder
# (e.g. from the pip cell above) is found first.
# NOTE(review): use sys.path.insert(0, ...) if the local checkout should win.
sys.path.append("../../")

import torchaudio
# Global torchaudio I/O configuration: the legacy-interface flag is set before
# the "soundfile" backend is selected. None of the visible cells load audio
# files, so this only affects later experiments.
# NOTE(review): USE_SOUNDFILE_LEGACY_INTERFACE and set_audio_backend() are tied
# to older torchaudio releases — confirm they still exist/apply in the installed version.
torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
torchaudio.set_audio_backend("soundfile")



# Wav2Vec pretrained feature extractor

In [2]:
import torch
import wavencoder

x = torch.randn(1, 16000)  # [1, 16000] raw waveform
# This is pure feature extraction: eval() switches train-only layers (if any)
# to inference behaviour, and no_grad() skips building an autograd graph.
encoder = wavencoder.models.Wav2Vec(pretrained=False).eval()
with torch.no_grad():
    z = encoder(x)  # [1, 512, 98]
z.shape

torch.Size([1, 512, 98])

# SincNet pretrained feature extractor

In [3]:
# Import torch here too so this cell does not rely on an earlier cell's state.
import torch
from wavencoder.models import SincNet

encoder = SincNet(pretrained=False).eval()
x = torch.randn(1, 3200)  # [1, 3200] raw waveform
# Inference only: no autograd graph needed.
with torch.no_grad():
    z = encoder(x)  # [1, 2048]
z.shape

torch.Size([1, 2048])


# RawNet

In [4]:
import torch
import wavencoder

# 59049 = 3**10 samples (about 3.7 s if the audio is 16 kHz — NOTE(review):
# confirm the sample rate RawNet2Model expects).
x = torch.randn(1, 59049)  # [1, 59049]
# NOTE(review): with return_code=True the model presumably returns the
# embedding (the [1, 1024] output below) rather than class_dim=100 logits.
rawnet_encoder = wavencoder.models.RawNet2Model(pretrained=False, return_code=True, class_dim=100).eval()
# Inference only, consistent with the other feature-extractor cells.
with torch.no_grad():
    z = rawnet_encoder(x)  # [1, 1024]
z.shape

torch.Size([1, 1024])

# Audio Classifier
- wav2vec encoder `[1, 16000] -> [1, 512, 98]`
- mean of features along time axis `[1, 512, 98] -> [1, 512]`
- ANN Classifier `[1, 512] -> [1, 2]`


In [5]:
import torch
import torch.nn as nn
import wavencoder

class AudioClassifier(nn.Module):
    """Wav2Vec encoder -> mean over time -> linear classification head.

    Args:
        num_classes: number of output logits (default 2, as in this demo).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.encoder = wavencoder.models.Wav2Vec(pretrained=False)
        # 512 = channel dimension of the Wav2Vec features ([1, 512, 98] above).
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map raw waveforms [batch, samples] to logits [batch, num_classes]."""
        z = self.encoder(x)        # [batch, 512, frames]
        z = torch.mean(z, dim=2)   # average over the time axis -> [batch, 512]
        out = self.classifier(z)   # [batch, num_classes]
        return out

model = AudioClassifier()
x = torch.randn(1, 16000)
y_hat = model(x)
print(y_hat.shape)

torch.Size([1, 2])


- SincNet encoder `[2, 3200] -> [2, 2048]`
- ANN Classifier `[2, 2048] -> [2, 2]`

In [6]:
import torch
import torch.nn as nn
import wavencoder

class SincNetAudioClassifier(nn.Module):
    """SincNet encoder followed by a linear classification head.

    Args:
        num_classes: number of output logits (default 2, as in this demo).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Qualified name, so this cell does not depend on another cell having
        # run `from wavencoder.models import SincNet` first.
        self.encoder = wavencoder.models.SincNet(pretrained=False)
        # SincNet emits a flat 2048-d vector per example (see the SincNet output above).
        self.classifier = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Map raw waveforms [batch, samples] to logits [batch, num_classes]."""
        z = self.encoder(x)       # [batch, 2048]
        out = self.classifier(z)  # [batch, num_classes]
        return out

model = SincNetAudioClassifier()
x = torch.randn(2, 3200)
y_hat = model(x)
print(y_hat.shape)

torch.Size([2, 2])


# LSTM Attention Classifier

In [2]:
import torch
import torch.nn as nn
import wavencoder

# Wav2Vec front-end feeding an LSTM + soft-attention head. Because the head is
# built with return_attn_weights=True, the pipeline returns (outputs, attention).
# NOTE(review): the positional args are presumably (input_dim=512 Wav2Vec
# channels, hidden_dim=64, num_classes=2) — confirm against the library.
wav2vec_frontend = wavencoder.models.Wav2Vec(pretrained=False)
lstm_attn_head = wavencoder.models.LSTM_Attn_Classifier(
    512, 64, 2,
    return_attn_weights=True,
    attn_type='soft',
)
model = nn.Sequential(wav2vec_frontend, lstm_attn_head)

x = torch.randn(5, 16000)  # a batch of 5 waveforms
y_hat, attn_weights = model(x)

print(y_hat.shape, attn_weights.shape)

torch.Size([5, 2]) torch.Size([5, 98])
