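# Multi-modal MLP-Mixer

Speech (Wav2Vec2, `facebook/wav2vec2-base-960h`) and text (KcELECTRA, `beomi/KcELECTRA-base-v2022`) encoders are fused to classify 7 emotions on the KEMDy20 dataset; four fusion heads are compared below.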

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# cd drive/MyDrive/sonny

/content/drive/MyDrive/sonny


In [None]:
!pip install transformers
!pip install datasets
!pip install einops

# Prepare dataset

In [None]:
import soundfile as sf
import torch
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import random

# Fixed audio length fed to Wav2Vec2: 480000 samples (30 s if the audio is 16 kHz).
N_SAMPLES = 30 * 16_000
# Feature extractor paired with the wav2vec2 checkpoint; CustomDataset uses it to
# turn raw waveforms into model input values.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

In [None]:
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad (with zeros, at the end) or trim `array` along `axis` to exactly `length`.

    Works on both torch tensors and NumPy arrays and returns the same kind.

    Args:
        array: torch.Tensor or numpy.ndarray holding audio samples.
        length: target size along `axis` (defaults to N_SAMPLES).
        axis: axis to pad/trim (keyword-only, defaults to the last axis).

    Returns:
        The padded/trimmed array; `array.shape[axis] == length`.
    """
    # Local imports: numpy and torch.nn.functional are only imported by later
    # cells, so without these the function raises NameError if it is called
    # before those cells have run.
    import numpy as np
    import torch.nn.functional as F

    if torch.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(
                dim=axis, index=torch.arange(length, device=array.device)
            )

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            # F.pad takes (before, after) pairs starting from the LAST dimension.
            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array

In [None]:
class CustomDataset(Dataset):
    """Utterance dataset yielding (text, wav2vec2 input values, emotion id).

    The CSV is read with a two-row header; the columns used are
    ('text_data', ' '), ('wav_dir', ' ') and ('Total Evaluation', 'Emotion').
    """

    def __init__(self, csv_file):
        self.data = pd.read_csv(csv_file, header=[0, 1])
        self.text_data = self.data['text_data'][' '].values
        self.wav_dir = self.data['wav_dir'][' '].values
        # Emotion label -> class id.
        # NOTE(review): the key is spelled 'disqust'; confirm it matches the CSV.
        self.dic = {'happy': 0, 'surprise': 1, 'angry': 2, 'neutral': 3, 'disqust': 4, 'fear': 5, 'sad': 6}
        self.labels = self.data['Total Evaluation']['Emotion'].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        label = self._resolve_label(idx)

        audio_input, sample_rate = sf.read(self.wav_dir[idx])
        # NOTE(review): padding happens before the processor, so any waveform
        # normalisation the processor applies also sees the zero padding.
        audio_input = pad_or_trim(audio_input)
        audio_input_values = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values.squeeze(0)

        return self.text_data[idx], audio_input_values, self.dic[label]

    def _resolve_label(self, idx):
        """Return the emotion label string for row `idx`.

        Rows annotated with several emotions ('a;b[;c...]') get one of them
        chosen uniformly at random on every access. `self.labels` is never
        modified, so the choice is re-drawn each time regardless of whether the
        DataLoader uses worker processes.
        """
        label = self.labels[idx]
        if ';' in label:
            label = random.choice(label.split(';'))
        return label.strip()

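# Modeling

Four fusion heads are defined below. Each variant cell redefines `Classify`, `Model` and `Project`, so `Model()` builds whichever variant's cell ran last.

1. MLP-Mixer
2. Concatenation
3. Cross-attention
4. 3-way concat (concatenation + element-wise product + absolute difference)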

In [None]:
from functools import partial
from einops import rearrange, reduce, asnumpy, parse_shape
from einops.layers.torch import Rearrange, Reduce

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
from transformers import AutoTokenizer, AutoModel, Wav2Vec2ForCTC

## 1. MLP-Mixer

In [None]:
def pair(x):
    """Return `x` unchanged if it is a tuple, otherwise the 2-tuple `(x, x)`."""
    if isinstance(x, tuple):
        return x
    return (x, x)

# mlp_mixer_block
class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: computes `fn(LayerNorm(x)) + x`.

    Args:
        dim: size of the last axis of `x` (normalised by LayerNorm).
        fn: module applied to the normalised input; its output must match `x`'s shape.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        residual = x
        normalised = self.norm(x)
        return self.fn(normalised) + residual

class FeedForward(nn.Module):
    """Two-layer MLP on the last axis: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        expansion_factor: hidden size is int(dim * expansion_factor); values
            below 1 (e.g. 0.5) give a bottleneck.
        dropout: drop probability used after both linear layers.
    """

    def __init__(self, dim, expansion_factor = 4, dropout = 0.):
        super().__init__()
        hidden_size = int(dim * expansion_factor)
        self.inner_dim = hidden_size
        self.dropout = dropout
        self.dense_1 = nn.Linear(dim, hidden_size)
        self.gelu = nn.GELU()
        self.dropout_1 = nn.Dropout(dropout)
        self.dense_2 = nn.Linear(hidden_size, dim)

    def forward(self, x):
        # One Dropout module is applied twice; it holds no parameters, so each
        # call draws an independent mask.
        hidden = self.dropout_1(self.gelu(self.dense_1(x)))
        return self.dropout_1(self.dense_2(hidden))

class MLPMixer(nn.Module):
    """Channel-mixing MLP block used as a fusion layer (MLP-Mixer inspired).

    Unlike the original MLP-Mixer, BOTH residual sub-blocks operate on the
    feature (`dim`) axis; nothing mixes across tokens.
    `expansion_factor_token` therefore only sets the hidden width of the second
    channel MLP. Because no layer acts on the token axis, any sequence length is
    accepted.

    Input:  (batch, dim, seq) -- rearranged internally to (batch, seq, dim).
    Output: (batch, dim)      -- LayerNorm, then mean over the sequence axis.
    """

    def __init__(self, dim, num_classes, expansion_factor = 4, expansion_factor_token = 0.5, dropout = 0.):
        super().__init__()
        self.dim = dim
        self.num_classes = num_classes
        self.expansion_factor = expansion_factor
        self.expansion_factor_token = expansion_factor_token
        self.dropout = dropout
        # (batch, dim, seq) -> (batch, seq, dim) so the MLPs below act on features.
        self.rearrange = Rearrange('b c d -> b d c')
        # Channel MLP, hidden width int(dim * expansion_factor).
        self.preNormResidual_1 = PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout))
        # Second channel MLP, hidden width int(dim * expansion_factor_token).
        self.preNormResidual_2 = PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout))
        self.norm = nn.LayerNorm(dim)
        # Mean over the sequence axis: (batch, seq, dim) -> (batch, dim).
        self.reduce_ = Reduce('b n c -> b c', 'mean')
        # Classification layer; NOT applied in forward(), which returns pooled
        # features. Kept so the state_dict layout is unchanged.
        self.linear_2 = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.rearrange(x)
        x = self.preNormResidual_1(x)
        x = self.preNormResidual_2(x)
        x = self.norm(x)
        return self.reduce_(x)

In [None]:
# Sanity check: build a stand-alone mixer with the fusion settings and show its layers.
mlpMixer_model = MLPMixer(num_classes=7, dim=768)
mlpMixer_model

MLPMixer(
  (rearrange): Rearrange('b c d -> b d c')
  (preNormResidual_1): PreNormResidual(
    (fn): FeedForward(
      (dense_1): Linear(in_features=768, out_features=3072, bias=True)
      (gelu): GELU(approximate='none')
      (dropout_1): Dropout(p=0.0, inplace=False)
      (dense_2): Linear(in_features=3072, out_features=768, bias=True)
    )
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (preNormResidual_2): PreNormResidual(
    (fn): FeedForward(
      (dense_1): Linear(in_features=768, out_features=384, bias=True)
      (gelu): GELU(approximate='none')
      (dropout_1): Dropout(p=0.0, inplace=False)
      (dense_2): Linear(in_features=384, out_features=768, bias=True)
    )
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (reduce_): Reduce('b n c -> b c', 'mean')
  (linear_2): Linear(in_features=768, out_features=7, bias=True)
)

In [None]:
# mlp_mixer_model
class Classify(nn.Module):
    """Linear classification head mapping pooled feature vectors to class logits."""

    def __init__(self, input_size, class_num):
        super().__init__()
        self.fc = nn.Linear(in_features=input_size, out_features=class_num)

    def forward(self, avg_vecs):
        """Return unnormalised logits with last dimension `class_num`."""
        return self.fc(avg_vecs)

class Model(nn.Module):
    """Audio + text emotion classifier fused with the MLPMixer block (variant 1).

    Audio goes through Wav2Vec2ForCTC (last hidden state) and text through
    KcELECTRA (last hidden state). Each is projected by a 768->768 linear layer,
    the two token sequences are concatenated along the sequence axis, passed
    through MLPMixer (which mean-pools them) and classified into 7 classes.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.wav_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", output_hidden_states=True)
        self.txt_model = AutoModel.from_pretrained("beomi/KcELECTRA-base-v2022")
        self.tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base-v2022")
        self.project1 = Project(768)
        self.project2 = Project(768)
        self.classification_model = Classify(768, 7)
        self.mlp_mixer_model = MLPMixer(dim = 768, num_classes = 7)

    def wav_encoder(self, wav_dir):
        """Return the last Wav2Vec2 hidden layer.

        Despite its name, `wav_dir` is the batch of processor input values
        (the training loop passes `input_values`), not a path.
        """
        return self.wav_model(wav_dir)['hidden_states'][-1]

    def txt_encoder(self, text_tensor):
        """Return KcELECTRA's last hidden state for tokenizer output `text_tensor`."""
        outputs = self.txt_model(**text_tensor)
        return outputs.last_hidden_state

    def forward(self, text_data, wav_dir):
        """Return (logits, softmax probabilities, predicted class ids)."""
        encoder_layer_1 = self.wav_encoder(wav_dir)
        encoder_layer_2 = self.txt_encoder(text_data)

        out1 = self.project1(encoder_layer_1)
        out2 = self.project2(encoder_layer_2)

        # Concatenate audio frames and text tokens along the sequence axis.
        # NOTE(review): text padding tokens are not masked out here.
        concat = torch.cat([out1, out2], dim = 1)

        # MLPMixer expects (batch, features, seq) and rearranges back internally.
        concat = concat.transpose(1, 2)
        result = self.mlp_mixer_model(concat)

        logit = self.classification_model(result)
        softmax = F.softmax(logit, dim=1)
        prediction = torch.argmax(softmax, dim=1)

        return logit, softmax, prediction

class Project(nn.Module):
    """Square linear projection (dim -> dim) applied over the last axis."""

    def __init__(self, dim):
        super(Project, self).__init__()
        self.layer = nn.Linear(in_features=dim, out_features=dim)

    def forward(self, x):
        projected = self.layer(x)
        return projected

## 2. Concatenation

In [None]:
# concat
class Classify(nn.Module):
    """Linear classification head mapping pooled feature vectors to class logits."""

    def __init__(self, input_size, class_num):
        super().__init__()
        self.fc = nn.Linear(in_features=input_size, out_features=class_num)

    def forward(self, avg_vecs):
        """Return unnormalised logits with last dimension `class_num`."""
        return self.fc(avg_vecs)

class Model(nn.Module):
    """Audio + text emotion classifier using plain concatenation (variant 2).

    Projected Wav2Vec2 frames and KcELECTRA tokens are concatenated along the
    sequence axis, mean-pooled over that axis and classified into 7 classes.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.wav_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", output_hidden_states=True)
        self.txt_model = AutoModel.from_pretrained("beomi/KcELECTRA-base-v2022")
        self.tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base-v2022")
        self.project1 = Project(768)
        self.project2 = Project(768)
        self.classification_model = Classify(768, 7)

    def wav_encoder(self, wav_dir):
        """Return the last Wav2Vec2 hidden layer.

        Despite its name, `wav_dir` is the batch of processor input values
        (the training loop passes `input_values`), not a path.
        """
        return self.wav_model(wav_dir)['hidden_states'][-1]

    def txt_encoder(self, text_tensor):
        """Return KcELECTRA's last hidden state for tokenizer output `text_tensor`."""
        outputs = self.txt_model(**text_tensor)
        return outputs.last_hidden_state

    def forward(self, text_data, wav_dir):
        """Return (logits, softmax probabilities, predicted class ids)."""
        encoder_layer_1 = self.wav_encoder(wav_dir)
        encoder_layer_2 = self.txt_encoder(text_data)

        out1 = self.project1(encoder_layer_1)
        out2 = self.project2(encoder_layer_2)

        # (batch, audio_len + text_len, 768)
        concat = torch.cat([out1, out2], dim = 1)
        # Mean over the sequence axis; identical to AdaptiveAvgPool2d((1, 768))
        # followed by dropping the singleton axis.
        # NOTE(review): text padding tokens are included in this mean.
        result = concat.mean(dim=1)

        logit = self.classification_model(result)
        softmax = F.softmax(logit, dim=1)
        prediction = torch.argmax(softmax, dim=1)

        return logit, softmax, prediction

class Project(nn.Module):
    """Square linear projection (dim -> dim) applied over the last axis."""

    def __init__(self, dim):
        super(Project, self).__init__()
        self.layer = nn.Linear(in_features=dim, out_features=dim)

    def forward(self, x):
        projected = self.layer(x)
        return projected

## 3. Cross-attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention: queries attend over keys/values.

    Args:
        d_model: feature size of q, k, v and of the output.
        n_head: number of heads; each head uses d_k = d_model // n_head features.
    """

    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_k = self.d_model // self.n_head
        proj_width = n_head * self.d_k

        # Input projections for queries, keys and values.
        self.w_qs = nn.Linear(d_model, proj_width)
        self.w_ks = nn.Linear(d_model, proj_width)
        self.w_vs = nn.Linear(d_model, proj_width)

        # Output projection applied after the heads are concatenated again.
        self.fc = nn.Linear(proj_width, d_model)

    def _split_heads(self, x, batch_size):
        """(batch, len, n_head*d_k) -> (batch*n_head, len, d_k), batch-major."""
        x = x.view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        return x.contiguous().view(batch_size * self.n_head, -1, self.d_k)

    def forward(self, q, k, v, mask=None):
        """
        Args:
            q: expected (batch, len_q, d_model); k, v: (batch, len_k, d_model).
            mask: optional tensor broadcastable to (batch*n_head, len_q, len_k);
                scores where mask == 0 are set to -1e9 before the softmax.

        Returns:
            Tensor of shape (batch, len_q, d_model).
        """
        batch_size = q.size(0)

        heads_q = self._split_heads(self.w_qs(q), batch_size)
        heads_k = self._split_heads(self.w_ks(k), batch_size)
        heads_v = self._split_heads(self.w_vs(v), batch_size)

        scores = torch.bmm(heads_q, heads_k.transpose(1, 2)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, heads_v)

        # Merge heads: (batch*n_head, len_q, d_k) -> (batch, len_q, n_head*d_k).
        context = context.view(batch_size, self.n_head, -1, self.d_k).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.n_head * self.d_k)

        return self.fc(context)

In [None]:
# transformer cross_attention
class Classify(nn.Module):
    """Linear classification head mapping pooled feature vectors to class logits."""

    def __init__(self, input_size, class_num):
        super().__init__()
        self.fc = nn.Linear(in_features=input_size, out_features=class_num)

    def forward(self, avg_vecs):
        """Return unnormalised logits with last dimension `class_num`."""
        return self.fc(avg_vecs)

class Model(nn.Module):
    """Audio + text emotion classifier using cross-attention (variant 3).

    Projected Wav2Vec2 frames are the queries and projected KcELECTRA tokens
    are the keys/values of an 8-head attention layer. The attended sequence is
    mean-pooled over its length and classified into 7 classes.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.wav_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", output_hidden_states=True)
        self.txt_model = AutoModel.from_pretrained("beomi/KcELECTRA-base-v2022")
        self.tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base-v2022")
        self.project1 = Project(768)
        self.project2 = Project(768)
        self.classification_model = Classify(768, 7)
        self.attention = MultiHeadAttention(768, 8)

    def wav_encoder(self, wav_dir):
        """Return the last Wav2Vec2 hidden layer.

        Despite its name, `wav_dir` is the batch of processor input values
        (the training loop passes `input_values`), not a path.
        """
        return self.wav_model(wav_dir)['hidden_states'][-1]

    def txt_encoder(self, text_tensor):
        """Return KcELECTRA's last hidden state for tokenizer output `text_tensor`."""
        outputs = self.txt_model(**text_tensor)
        return outputs.last_hidden_state

    def _text_key_mask(self, text_data):
        """Build a key mask for self.attention that hides text padding tokens.

        MultiHeadAttention flattens (batch, head) into one leading axis in
        batch-major order, so each sample's mask is repeated n_head times along
        dim 0; unsqueeze(1) lets it broadcast over the query axis.
        Returns None when `text_data` carries no 'attention_mask'.
        """
        attention_mask = text_data.get('attention_mask') if hasattr(text_data, 'get') else None
        if attention_mask is None:
            return None
        return attention_mask.repeat_interleave(self.attention.n_head, dim=0).unsqueeze(1)

    def forward(self, text_data, wav_dir):
        """Return (logits, softmax probabilities, predicted class ids)."""
        encoder_layer_1 = self.wav_encoder(wav_dir)
        encoder_layer_2 = self.txt_encoder(text_data)

        out1 = self.project1(encoder_layer_1)
        out2 = self.project2(encoder_layer_2)

        # Audio frames attend over text tokens; padded text positions are masked.
        concat = self.attention(out1, out2, out2, mask=self._text_key_mask(text_data))
        # Mean over the audio-frame axis -> (batch, 768).
        result = concat.mean(dim=1)

        logit = self.classification_model(result)

        softmax = F.softmax(logit, dim=1)
        prediction = torch.argmax(softmax, dim=1)

        return logit, softmax, prediction

class Project(nn.Module):
    """Square linear projection (dim -> dim) applied over the last axis."""

    def __init__(self, dim):
        super(Project, self).__init__()
        self.layer = nn.Linear(in_features=dim, out_features=dim)

    def forward(self, x):
        projected = self.layer(x)
        return projected

## 4. 3-way concat

In [None]:
# 3-way
class Classify(nn.Module):
    """Linear classification head mapping pooled feature vectors to class logits."""

    def __init__(self, input_size, class_num):
        super().__init__()
        self.fc = nn.Linear(in_features=input_size, out_features=class_num)

    def forward(self, avg_vecs):
        """Return unnormalised logits with last dimension `class_num`."""
        return self.fc(avg_vecs)

class Model(nn.Module):
    """Audio + text emotion classifier with a "3-way" fusion head (variant 4).

    The projected audio and text sequences are fused by stacking, along the
    sequence axis:
      (1) both sequences concatenated,
      (2) their element-wise product, and
      (3) their absolute difference.
    Parts (2) and (3) are taken over the first `min_length` positions. The
    stack is mean-pooled and classified into 7 classes.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.wav_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h", output_hidden_states=True)
        self.txt_model = AutoModel.from_pretrained("beomi/KcELECTRA-base-v2022")
        self.tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base-v2022")
        self.project1 = Project(768)
        self.project2 = Project(768)
        self.classification_model = Classify(768, 7)
        # Not used by forward(); kept so the parameter / state_dict layout is unchanged.
        self.attention = MultiHeadAttention(768, 8)

    def wav_encoder(self, wav_dir):
        """Return the last Wav2Vec2 hidden layer.

        Despite its name, `wav_dir` is the batch of processor input values
        (the training loop passes `input_values`), not a path.
        """
        return self.wav_model(wav_dir)['hidden_states'][-1]

    def txt_encoder(self, text_tensor):
        """Return KcELECTRA's last hidden state for tokenizer output `text_tensor`."""
        outputs = self.txt_model(**text_tensor)
        return outputs.last_hidden_state

    def forward(self, text_data, wav_dir):
        """Return (logits, softmax probabilities, predicted class ids)."""
        encoder_layer_1 = self.wav_encoder(wav_dir)
        encoder_layer_2 = self.txt_encoder(text_data)

        out1 = self.project1(encoder_layer_1)
        out2 = self.project2(encoder_layer_2)

        concat = torch.cat([out1, out2], dim = 1)
        min_length = min(out1.size(1), out2.size(1))
        # NOTE(review): position i of the audio sequence is paired with text
        # token i; the two sequences are not time-aligned.
        audio_part = out1[:, :min_length, :]
        text_part = out2[:, :min_length, :]

        # Element-wise product
        concat_1 = audio_part * text_part

        # Absolute element-wise difference
        concat_2 = torch.abs(audio_part - text_part)

        # 3-way concat along the sequence axis
        concat_3 = torch.cat([concat, concat_1, concat_2], dim = 1)

        # Pool the 3-way features (not just `concat`) over the sequence axis.
        result = concat_3.mean(dim=1)

        logit = self.classification_model(result)

        softmax = F.softmax(logit, dim=1)
        prediction = torch.argmax(softmax, dim=1)

        return logit, softmax, prediction

class Project(nn.Module):
    """Square linear projection (dim -> dim) applied over the last axis."""

    def __init__(self, dim):
        super(Project, self).__init__()
        self.layer = nn.Linear(in_features=dim, out_features=dim)

    def forward(self, x):
        projected = self.layer(x)
        return projected

In [None]:
# Builds whichever `Model` class was defined most recently. Each variant cell
# above redefines `Model`, so a top-to-bottom run gives the "4. 3-way" variant.
# Model.__init__ loads the pretrained backbones and tokenizer via from_pretrained.
model = Model()

# Training and validation

In [None]:
from sklearn.metrics import classification_report, f1_score

In [None]:
# Freeze both pretrained backbones. Assigning `module.requires_grad = False`
# only sets a plain Python attribute on the Module object; requires_grad_(False)
# turns off gradients for every parameter it contains. The trailing `;`
# suppresses display of the returned module.
model.txt_model.requires_grad_(False)
model.wav_model.requires_grad_(False);

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score

# One CustomDataset per split, read from the KEMDy20 CSV files.
train_dataset, valid_dataset, test_dataset = (
    CustomDataset(csv_path)
    for csv_path in ('KEMDy20_train_data.csv', 'KEMDy20_val_data.csv', 'KEMDy20_test_data.csv')
)

In [None]:
import torch.optim as optim
import numpy as np
import random
from tqdm.notebook import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers = 3)
# Validation does not benefit from shuffling; a fixed order keeps epochs comparable.
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers = 3)

# Per-class loss weights, indexed by CustomDataset.dic ids
# (happy, surprise, angry, neutral, disqust, fear, sad); presumably 1 / class count.
class_weights = [1/1274,1/180,1/197,1/9098,1/94,1/53,1/182]
criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights)).to(device)
model.train()
# Freeze the pretrained backbones. `module.requires_grad = False` would only set
# a plain attribute on the Module object; requires_grad_(False) turns off
# gradients for every parameter it contains, so the filter below keeps only the
# fusion / classification layers.
model.txt_model.requires_grad_(False)
model.wav_model.requires_grad_(False)
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001)

In [None]:
model = model.to(device)
for epoch in range(10):
    # ---- training ----
    model.train()
    train_loss = 0.0
    for i, data in enumerate(t := tqdm(train_loader)):
        txt_data, input_values, labels = data

        # The global name `txt_tensor` is kept on purpose: a later cell deletes it.
        txt_tensor = model.tokenizer(txt_data, return_tensors="pt", padding=True, truncation=True).to(device)
        input_values = input_values.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logit, softmax, outputs = model(txt_tensor, input_values)
        # CrossEntropyLoss applies log-softmax itself, so it must get raw logits.
        loss = criterion(logit, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        t.set_postfix_str(f' Loss: {train_loss/(i+1)}')

    # ---- validation ----
    model.eval()
    val_loss = 0.0
    pred_label = []
    true_label = []

    # No gradients needed; avoids building autograd graphs during evaluation.
    with torch.no_grad():
        for i, data in enumerate(t := tqdm(valid_loader)):
            txt_data, input_values, labels = data

            txt_tensor = model.tokenizer(txt_data, return_tensors="pt", padding=True, truncation=True).to(device)
            input_values = input_values.to(device)
            labels = labels.to(device)

            logit, softmax, outputs = model(txt_tensor, input_values)
            loss = criterion(logit, labels)
            val_loss += loss.item()

            pred_label += outputs.cpu().tolist()
            true_label += labels.cpu().tolist()

            t.set_postfix_str(f' Val_Loss: {val_loss/(i+1)}')

    model.train()
    f1score = f1_score(true_label, pred_label, average='weighted')

    print(f'Epoch {epoch+1}, Loss: {train_loss/len(train_loader)}, Val_Loss: {val_loss/len(valid_loader)}, F1_score: {f1score}')
    print(classification_report(true_label, pred_label, zero_division=0))

## GPU 메모리 초기화 (free GPU memory)

In [None]:
# Show (allocated, reserved) bytes together. Only a cell's last expression is
# displayed, so a separate first line would be evaluated but never shown.
torch.cuda.memory_allocated(), torch.cuda.memory_reserved()

14600372224

In [None]:
import gc
import torch as th

# Drop every global that may still reference GPU tensors. The optimizer keeps
# references to the model's parameters, so deleting `model` alone does not
# release them. globals().pop(name, None) tolerates names that were never
# created, so this cell can be re-run without NameError.
for _name in ('input_values', 'labels', 'txt_tensor', 'outputs', 'logit', 'softmax',
              'loss', 'data', 'optimizer', 'criterion', 'model'):
    globals().pop(_name, None)
del _name

# Collect any reference cycles first so the CUDA caching allocator can return
# the freed blocks.
gc.collect()
th.cuda.empty_cache()