In [2]:
import librosa
import torch
import types
import torch.nn as nn
from transformers import AutoModelForAudioClassification
from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Model,
                                                  Wav2Vec2PreTrainedModel)


signal = torch.from_numpy(
    librosa.load('/kaggle/input/speech-dataset/UrbanSound8K/audio/fold3/102105-3-0-0.wav', sr=16000)[0])[None, :]
device = 'cpu'

class ADV(nn.Module):

    def __init__(self, config):

        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, x):

        x = self.dense(x)
        x = torch.tanh(x)

        return self.out_proj(x)


class Dawn(Wav2Vec2PreTrainedModel):
  

    def __init__(self, config):

        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ADV(config)

    def forward(self, x):
        x -= x.mean(1, keepdim=True)
        variance = (x * x).mean(1, keepdim=True) + 1e-7
        x = self.wav2vec2(x / variance.sqrt())
        return self.classifier(x.last_hidden_state.mean(1))


def _forward(self, x):
    '''x: (batch, audio-samples-16KHz)'''
    x = (x + self.config.mean) / self.config.std  # sgn
    x = self.ssl_model(x, attention_mask=None).last_hidden_state
    # pool
    h = self.pool_model.sap_linear(x).tanh()
    w = torch.matmul(h, self.pool_model.attention).softmax(1)
    mu = (x * w).sum(1)
    x = torch.cat(
        [
            mu,
            ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
        ], 1)
    return self.ser_model(x)


# WavLM

base = AutoModelForAudioClassification.from_pretrained(
        '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
        trust_remote_code=True).to(device).eval()
base.forward = types.MethodType(_forward, base)

# Wav2Vec2

dawn = Dawn.from_pretrained(
    'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
).to(device).eval()


def wav2small(x):
    return .5 * dawn(x) + .5 * base(x)

pred = wav2small(signal.to(device))
print(f'Arousal={pred[0, 0]} '
      f'Dominance={pred[0, 1]} ',
      f'Valence={pred[0, 2]}')


config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

pipeline_utils.py:   0%|          | 0.00/5.43k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes:
- pipeline_utils.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.22k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.34k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/661M [00:00<?, ?B/s]

Arousal=0.7052952647209167 Dominance=0.5967236757278442  Valence=0.5533943772315979
