In [1]:
import sys
sys.path.append('/home/kvu/erc/libs')

import csv
import os
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

import pandas
import torch
import yaml
import seaborn as sns
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
from PIL import Image
from torchvision import transforms as tvtf
from torch.utils.data import DataLoader
from utils.device import detach, move_to

from utils.getter import get_instance
from datasets import *

import warnings
warnings.filterwarnings("ignore")

In [2]:
dev_id = 'cuda:0' \
    if torch.cuda.is_available() else 'cpu'
device = torch.device(dev_id)

In [5]:
dataset = ContextAwareDataset(csv_path='/home/kvu/erc/data/raw-audios-wav/train.csv', 
                              audio_feat_dir='/home/kvu/erc/Datasets/MELD/audio-features/train', 
                              text_feat_dir='/home/kvu/erc/Datasets/MELD/text-features/train', 
                              ordered_json_list='/home/kvu/erc/Datasets/MELD/utterance-ordered.json', 
                              num_utt=8, dataset='train')

In [6]:
dataset_loader = DataLoader(dataset=dataset, batch_size=8, shuffle=False, pin_memory=False) 

In [7]:
x, y = next(iter(dataset_loader))
# x = move_to(x, device)

In [8]:
audio, text = x

In [9]:
audio[0]

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0218, -0.0711,  0.0420,  ...,  0.0138,  0.0438,  0.0177],
        [-0.0052, -0.1040,  0.0576,  ..., -0.0075,  0.0242, -0.0240],
        [ 0.0040, -0.0343,  0.0428,  ..., -0.0112, -0.0089,  0.0491]],
       dtype=torch.float64)

In [10]:
text.shape

torch.Size([8, 8, 1024])

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusionModel(nn.Module):
    def __init__(self, num_classes, AUDIO_FEAT_DIM=1280, TEXT_FEAT_DIM=1024):
        super().__init__()
        self.AUDIO_FEAT_DIM = AUDIO_FEAT_DIM
        self.TEXT_FEAT_DIM = TEXT_FEAT_DIM
        self.mbp = MBP(AUDIO_FEAT_DIM, TEXT_FEAT_DIM)
        self.classifier = nn.Sequential(
            nn.Linear(1000, 512),
            # nn.Linear(AUDIO_FEAT_DIM + TEXT_FEAT_DIM, 512),
            nn.Dropout(0.2),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        )
        self.fc = nn.Linear(1000, num_classes)

    def forward(self, inp):
        audio_feats, text_feats = inp
        batch_size, num_utt, _ = audio_feats.shape
        res = torch.zeros(size=(batch_size, num_utt, 1000))
        for i in range(batch_size):
            audio_feat = audio_feats[i].float()
            text_feat = text_feats[i].float()
            fused = self.mbp(audio_feat, text_feat)
            print(fused.shape)
            res[i] = fused
        return res

class MBP(nn.Module):
    """
        Multi-modal Factorized Bilinear Pooling - https://arxiv.org/pdf/1708.01471.pdf
    """
    def __init__(self, AUDIO_FEAT_DIM, TEXT_FEAT_DIM, SUM_POOLING_WINDOW=3, OUTPUT_DIM=1000):
        super().__init__()
        self.AUDIO_FEAT_DIM = AUDIO_FEAT_DIM
        self.TEXT_FEAT_DIM = TEXT_FEAT_DIM
        self.SUM_POOLING_WINDOW = SUM_POOLING_WINDOW
        self.OUTPUT_DIM = OUTPUT_DIM
        self.FUSED_DIM = SUM_POOLING_WINDOW * OUTPUT_DIM

        self.audio_linear_projection = nn.Linear(AUDIO_FEAT_DIM, self.FUSED_DIM)
        self.text_linear_projection = nn.Linear(TEXT_FEAT_DIM, self.FUSED_DIM)
        self.dropout = nn.Dropout(0.1)

    def forward(self, audio, text):
        x = self.audio_linear_projection(audio)
        y = self.text_linear_projection(text)
        z = torch.mul(x, y)
        z = self.dropout(z)
        z = z.view(-1, 1, self.OUTPUT_DIM, self.SUM_POOLING_WINDOW)
        z = torch.sum(z, dim=3)
        z = torch.squeeze(z)
        z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
        z = F.normalize(z)
        return z

In [12]:
fusion = FusionModel(num_classes=7)

In [16]:
res = fusion(x)

torch.Size([8, 1000])
torch.Size([8, 1000])
torch.Size([8, 1000])
torch.Size([8, 1000])
torch.Size([8, 1000])
torch.Size([8, 1000])
torch.Size([8, 1000])
torch.Size([8, 1000])


In [17]:
res.shape

torch.Size([8, 8, 1000])

In [24]:
res[:, -1, :] # Pick the last vector

tensor([[-0.0495,  0.0227, -0.0216,  ..., -0.0345, -0.0199,  0.0252],
        [-0.0409,  0.0411,  0.0243,  ...,  0.0335, -0.0117,  0.0170],
        [ 0.0202,  0.0652,  0.0254,  ...,  0.0331,  0.0056, -0.0263],
        ...,
        [-0.0132, -0.0082, -0.0351,  ...,  0.0141, -0.0455, -0.0116],
        [ 0.0148,  0.0291, -0.0170,  ...,  0.0412, -0.0304, -0.0283],
        [ 0.0380,  0.0259,  0.0147,  ...,  0.0206,  0.0020,  0.0341]],
       grad_fn=<SliceBackward>)