In [None]:
# ライブラリインポート
import torch, torchaudio
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pickle


from tqdm.notebook import tqdm
from torch.utils.data import random_split, DataLoader
from torchemotion.datasets.EmodbDataset import EmodbDataset
from torchemotion.datasets.IemocapDataset import IemocapDataset
from transformers import Wav2Vec2Processor, WavLMModel, WavLMForSequenceClassification, WavLMConfig #, WavLMForCTC
from IPython.display import Audio, display
from tqdm.notebook import tqdm

from sklearn.model_selection import train_test_split, KFold

# Random seed reused as `random_state` by the stratified train/validation split.
seed = 42

In [None]:
# utils

# model
class SERwithWavLM(nn.Module):
    """Speech-emotion classifier wrapping ``WavLMForSequenceClassification``.

    Args:
        pretrained_model: Hub name or local path of the pretrained WavLM checkpoint.
        num_labels: Number of emotion classes for the classification head.
    """

    def __init__(self, pretrained_model, num_labels):
        super().__init__()
        # Load the checkpoint's own configuration. ``WavLMConfig(pretrained_model)``
        # would build a default config and store the checkpoint name in the
        # constructor's first positional field (``vocab_size``) instead.
        self.wavlm_config = WavLMConfig.from_pretrained(pretrained_model)
        self.wavlm_config.update({'num_labels': num_labels, 'use_weighted_layer_sum': True})
        self.wavlm = WavLMForSequenceClassification.from_pretrained(
            pretrained_model, config=self.wavlm_config
        )

    def forward(self, inputs):
        """Run the wrapped classifier on one batch.

        Args:
            inputs: Pair ``(model_kwargs, labels)``. ``model_kwargs`` is a mapping
                unpacked as keyword arguments of the wrapped model and ``labels``
                is forwarded as its ``labels`` argument.

        Returns:
            Tuple ``(logits, loss)`` read from the wrapped model's output.
        """
        model_kwargs, labels = inputs[0], inputs[1]
        outputs = self.wavlm(**model_kwargs, labels=labels)
        return outputs.logits, outputs.loss

# initialization of weight    
def weights_init(m):
    """Re-initialise ``m`` in place if its class name contains ``'Linear'``.

    The weight gets Kaiming-normal initialisation and the bias, when present,
    is set to zero. Modules whose class name does not contain ``'Linear'`` are
    left untouched. Intended to be passed to ``nn.Module.apply``.
    """
    if 'Linear' not in type(m).__name__:
        return

    nn.init.kaiming_normal_(m.weight)
    # Linear layers built with bias=False expose ``bias`` as None.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)

# count parameters of model
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# plot dataset's info about emotion and length of waveform
def plot_data(dataset):
    """Plot the emotion-label counts and waveform-length histogram of ``dataset``.

    Every item is read once; its ``'emotion'`` entry is converted to ``int`` and
    the size of the last dimension of its ``'waveform'`` entry is recorded.

    Args:
        dataset: Indexable collection whose items are mappings with
            ``'emotion'`` and ``'waveform'`` entries.

    Returns:
        Tuple ``(emotions, lengths)`` of two lists in dataset order.
    """
    emotions, lengths = [], []
    for idx in range(len(dataset)):
        sample = dataset[idx]
        emotions.append(int(sample['emotion']))
        lengths.append(sample['waveform'].size(-1))

    emotion_frame = pd.DataFrame(emotions, columns=['emotion'])

    # Top panel: label counts; bottom panel: waveform lengths.
    _, (count_ax, length_ax) = plt.subplots(2, 1)
    sns.countplot(x='emotion', data=emotion_frame, ax=count_ax)
    length_ax.hist(lengths)
    plt.show()

    return emotions, lengths

# stratified split by emotion
def make_tar_val_dataset(dataset, test_size=0.2, random_state=42):
    """Split ``dataset`` into train/validation subsets stratified by emotion.

    Reads the ``'emotion'`` entry of every item once, performs a stratified
    split of the item indices, and shows overlaid histograms of the labels of
    the whole dataset and of the training subset.

    Args:
        dataset: Indexable collection whose items are mappings with an
            ``'emotion'`` entry convertible to ``int``.
        test_size: Fraction (or count) of items assigned to the validation
            subset, forwarded to ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        Tuple ``(train_dataset, val_dataset)`` of ``torch.utils.data.Subset``.
    """
    emotions = [int(dataset[i]['emotion']) for i in range(len(dataset))]

    train_indices, val_indices = train_test_split(
        list(range(len(dataset))),
        test_size=test_size,
        stratify=emotions,
        random_state=random_state,
    )
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    # Reuse the labels collected above instead of loading every training
    # sample a second time through the Subset.
    emotions_train = [emotions[i] for i in train_indices]
    plt.hist(emotions, label='all')
    plt.hist(emotions_train, label='train')
    plt.legend()
    plt.show()

    return train_dataset, val_dataset

# collate_fn used in dataloader
def collate_fn(batch):
    """Collate dataset items into ``(waveforms, targets)`` for a DataLoader.

    Args:
        batch: Iterable of mappings with a ``'waveform'`` tensor and an
            ``'emotion'`` entry convertible to ``int``.

    Returns:
        ``waveforms``: list of flattened NumPy arrays, one per item (lengths may
        differ; no padding is applied here).
        ``targets``: 1-D tensor of the integer emotion labels.
    """
    waveforms = [sample['waveform'].numpy().flatten() for sample in batch]
    label_tensors = [torch.tensor(int(sample['emotion'])) for sample in batch]
    targets = torch.stack(label_tensors)
    return waveforms, targets

## Load and Split Dataset
[EmoDB](https://www.kaggle.com/datasets/piyushagni5/berlin-database-of-emotional-speech-emodb)や[IEMOCAP](https://sail.usc.edu/iemocap/)のデータセットをダウンロードし、[torchemotion](https://github.com/alanwuha/torchemotion)を使って読み込む。

データの分割は感情ラベルが均衡になるように

`train_size : test_size = 8 : 2`

で分割する

In [None]:
# IEMOCAP
# Root directory handed to IemocapDataset.
# NOTE(review): the value looks like a placeholder — replace with the real local path.
data_dir = 'Iemocap_data_dir'
dataset = IemocapDataset(root=data_dir)

In [None]:
# EmoDB
data_dir = 'Emodb_data_dir'
# EmoDB must be read with EmodbDataset, not IemocapDataset.
emodb = EmodbDataset(root=data_dir)
# Bind the generic `dataset` name as well, so code written against either name
# sees the EmoDB data after this cell runs.
dataset = emodb

In [None]:
# Stratified 80/20 train/validation split by emotion label
emotions = [int(dataset[i]['emotion']) for i in range(len(dataset))]
train_indices, val_indices = train_test_split(
    list(range(len(emotions))), test_size=0.2, stratify=emotions, random_state=seed
)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

# Check that the training subset keeps the overall label distribution.
# Labels are taken from `emotions` (same label space as the 'all' histogram)
# rather than re-reading every training sample.
emotions_train = [emotions[i] for i in train_indices]
plt.hist(emotions, label='all')
plt.hist(emotions_train, label='train')
plt.legend()
plt.show()

## Model
- 感情分類用：[WavLMForSequenceClassification](https://huggingface.co/docs/transformers/model_doc/wavlm#transformers.WavLMForSequenceClassification)
- 音声データの前処理： [Wav2Vec2Processor](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Processor)

`WavLM`のAttention機構を利用するために、バッチデータを`Wav2Vec2Processor`で処理してからモデルに入力する。

これらの事前学習モデルとして[patrickvonplaten/wavlm-libri-clean-100h-base-plus](https://huggingface.co/patrickvonplaten/wavlm-libri-clean-100h-base-plus)を使用した。

In [None]:
# Example: loading the processor and the classification model
pretrained_name = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
processor = Wav2Vec2Processor.from_pretrained(pretrained_name)  # audio preprocessing

# Build the configuration explicitly; the model call below needs it.
# NOTE(review): 4 labels mirrors the `num_labels` used in the training cell — confirm
# it matches the dataset actually loaded.
wavlm_config = WavLMConfig.from_pretrained(pretrained_name)
wavlm_config.update({'num_labels': 4, 'use_weighted_layer_sum': True})
model = WavLMForSequenceClassification.from_pretrained(pretrained_name, config=wavlm_config)

## Train

In [None]:
def train_model(model, dataloaders_dict, optimizer, scheduler, num_epochs, log_interval=10,
                processor=None, sampling_rate=16000):
    """Train ``model`` and evaluate it on the validation set once per epoch.

    Args:
        model: Module called as ``model((model_kwargs, labels))`` and returning
            ``(logits, loss)``. Must already be on its target device.
        dataloaders_dict: Mapping with ``'train'`` and ``'val'`` DataLoaders that
            yield ``(waveforms, targets)`` batches.
        optimizer: Optimizer stepped after every training batch.
        scheduler: Learning-rate scheduler stepped after every training batch,
            or ``None`` to disable scheduling.
        num_epochs: Number of passes over both phases.
        log_interval: Print the training loss every ``log_interval`` batches.
        processor: Callable turning a list of waveforms into padded model inputs.
            Defaults to the ``Wav2Vec2Processor`` of
            ``patrickvonplaten/wavlm-libri-clean-100h-base-plus``.
        sampling_rate: Sampling rate passed to ``processor``.

    Returns:
        The trained ``model``.
    """
    # Let cuDNN pick the fastest kernels.
    torch.backends.cudnn.benchmark = True

    if processor is None:
        processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wavlm-libri-clean-100h-base-plus")

    # Send every batch to the device the model's parameters live on, so the
    # function does not depend on a notebook-level `device` variable.
    model_device = next(model.parameters()).device

    # The progress bar counts epochs; each batch advances it by a fraction.
    pbar_update = 1 / sum(len(v) for v in dataloaders_dict.values())

    with tqdm(total=num_epochs) as pbar:
        for epoch in range(num_epochs):
            for phase in ['train', 'val']:
                is_train = phase == 'train'
                loader = dataloaders_dict[phase]
                if is_train:
                    model.train()
                else:
                    model.eval()
                epoch_loss = 0.0
                epoch_corrects = 0

                for step, (data, target) in enumerate(loader):
                    # Pad the variable-length waveforms into batch tensors.
                    inputs = processor(data, sampling_rate=sampling_rate, return_tensors='pt', padding=True)

                    # Move each returned tensor under its own key. This keeps every
                    # tensor's dtype and works whether or not a mask is returned.
                    input_dict = {key: value.to(model_device) for key, value in inputs.items()}
                    target = target.to(model_device)

                    optimizer.zero_grad()

                    # Gradients are only tracked during the training phase.
                    with torch.set_grad_enabled(is_train):
                        logits, loss = model((input_dict, target))
                        # Under DataParallel there is one loss per device; average them.
                        loss = loss.mean()
                        preds = torch.argmax(logits, dim=-1)

                        if is_train:
                            loss.backward()
                            optimizer.step()
                            if scheduler is not None:
                                scheduler.step()

                    # Record the loss in BOTH phases so the validation epoch loss is
                    # computed from validation batches.
                    loss_log = loss.item()
                    if is_train and step % log_interval == 0:
                        print(f"Train Epoch: {epoch} [{step * len(data)}/{len(loader.dataset)} ({100. * step / len(loader):.0f}%)]\tLoss: {loss_log:.6f}")

                    epoch_loss += loss_log * len(data)
                    epoch_corrects += preds.eq(target).sum().item()

                    pbar.update(pbar_update)

                # Per-epoch loss and accuracy, averaged over the phase's dataset.
                epoch_loss = epoch_loss / len(loader.dataset)
                epoch_acc = epoch_corrects / len(loader.dataset)

                print('Epoch {}/{} | {:^5} |  Loss: {:.4f} Acc: {:.4f}'.format(epoch+1, num_epochs,
                                                                               phase, epoch_loss, epoch_acc))

    return model


In [None]:
# Device selection
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Settings
num_epochs = 10
batch_size = 8
learning_rate = 0.0001
pretrained_model = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
# NOTE(review): must equal the number of distinct emotion labels in the loaded
# dataset — confirm for the dataset in use.
num_labels = 4

# Seed torch's RNG so batch shuffling and the classifier re-initialisation below
# are repeatable across runs.
torch.manual_seed(seed)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
dataloaders_dict = {'train': train_loader, 'val': test_loader}

# Model: re-initialise only the classification head's linear layer(s)
model = SERwithWavLM(pretrained_model, num_labels)
model.wavlm.classifier.apply(weights_init)

# Multi-GPU data parallelism
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(device)

# nn.DataParallel exposes the wrapped module's parameters, so a single
# expression covers both the single- and multi-GPU cases.
params = model.parameters()

# Optimizer and scheduler (the scheduler is stepped once per training batch)
optimizer = torch.optim.Adam(params, lr=learning_rate, weight_decay=0.0001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=300, gamma=0.9)

# Train and evaluate
model = train_model(model, dataloaders_dict, optimizer, scheduler, num_epochs, log_interval=10)