In [1]:
import os
import os.path as osp
import pandas as pd
import librosa
from PIL import Image
import matplotlib.pyplot as plt
import random
import IPython
import torch
import torchaudio

In [None]:
# MELD dataset, https://affective-meld.github.io/

In [None]:
# Choose one utterance by its dialogue index (d) and utterance index (u).
d, u = 0, 3
video_id = f'dia{d}_utt{u}'

# Dataset partition to inspect: one of 'train', 'dev', 'test'.
split = 'train'
# Dataset root directory. The trailing slash matters: the paths below are built
# by plain string concatenation rather than with osp.join.
data_dir = '/mnt/ff1f01b3-85e2-407c-8f5d-cdcee532daa5/emodet_cache/MELD.Raw/'

# Annotation CSV of the chosen split.
anno_text = pd.read_csv(
    osp.join(data_dir, f'{split}_sent_emo.csv')
)
# Frame directory (trailing slash kept on purpose) and audio file of the utterance.
images_path = f'{data_dir}{split}_splits/frames/{video_id}/'
audio_path = f'{data_dir}{split}_splits/audio/{video_id}.mp3'

anno_text.head()

In [None]:
# display text and labels
# Narrow the annotation table to the selected dialogue, then to the selected
# utterance. The second mask is built from `f_d` itself (not from the full
# `anno_text`) so that the mask and the frame it filters share the same index
# and pandas does not have to reindex the mask (which emits a UserWarning).
f_d = anno_text[anno_text['Dialogue_ID'] == d]
f_u = f_d[f_d['Utterance_ID'] == u]

# `.item()` below only works for exactly one match; fail with a message that
# names the utterance instead of a generic size error.
if len(f_u) != 1:
    raise ValueError(
        f'expected exactly one annotation row for dia{d}_utt{u}, found {len(f_u)}'
    )

text_gt = f_u['Utterance'].item()
print(f_u['Sr No.'])
# Show the full annotation row of the selected utterance (previously a
# hard-coded, unrelated row position was printed here).
print(f_u.iloc[0])
print('=======================')
print('Text: ', text_gt)

In [None]:
# display video frames
# Report how many entries the utterance's frame directory holds.
img_list = os.listdir(images_path)
print(len(img_list))

# Preview one fixed frame; `images_path` already ends with a slash, so the
# file name is appended directly.
frame_file = images_path + '00000003.jpg'
img = Image.open(frame_file)

# Fresh figure with a single axes, axis decorations hidden.
fig, ax = plt.subplots()
ax.set_axis_off()
ax.imshow(img)

In [None]:
# display audio
# `librosa.display` is a submodule; a bare `import librosa` does not load it on
# every librosa version, so import it explicitly before plotting.
import librosa.display

# Load the utterance audio as mono, resampled to `audio_sample_rate`.
audio_sample_rate = 22050
_wav, sr = librosa.load(audio_path, sr=audio_sample_rate, mono=True)

# Wide figure showing only the waveform (axes hidden).
plt.figure(figsize=(20, 5))
librosa.display.waveshow(_wav, sr=sr)
plt.axis('off')
plt.show()

In [None]:
# Render an inline audio player for the selected utterance's audio file.
IPython.display.Audio(audio_path)

In [7]:
# loading trained wav2vec provided by torch
# Prefer the first GPU, but fall back to the CPU so this cell also runs on
# machines without CUDA instead of failing when the model is moved.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# NOTE(review): `fixed_len` is not used in this cell — confirm whether other
# cells rely on it before removing.
fixed_len = 200000

# Pretrained wav2vec2 ASR pipeline bundled with torchaudio.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)

# load audio files
waveform, sample_rate = torchaudio.load(audio_path)
waveform = waveform.to(device)
# Resample to the rate the bundle expects, then keep only the first channel and
# shape it as a batch of one: (1, num_samples).
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)[0].reshape(1, -1)

In [None]:
# Forward the waveform through the model with inference mode enabled (no
# autograd tracking); the second return value is not needed here.
with torch.inference_mode():
    emission, _ = model(waveform)
print(emission.shape)

# Plot the first item of the batch, transposed so that frames run along the
# horizontal axis and classes along the vertical axis.
ax = plt.gca()
ax.imshow(emission[0].cpu().T, interpolation="nearest")
ax.set_title("Classification result")
ax.set_xlabel("Frame (time-axis)")
ax.set_ylabel("Class")
plt.show()

In [None]:
from model.pre_audio import GreedyCTCDecoder

# Build the project's greedy CTC decoder from the bundle's label set and apply
# it to the first item of the emission batch.
ctc_labels = bundle.get_labels()
decoder = GreedyCTCDecoder(labels=ctc_labels)
transcript = decoder(emission[0])

# Print the decoder output followed by the annotated utterance text.
print(transcript)
print('gt: ', text_gt)

In [13]:
from data.meld_data import MELD

# Instantiate the project's MELD dataset wrapper for the
# 'multimodal_finetune' target and report its size.
dataset = MELD(target='multimodal_finetune')
print(len(dataset))

# Inspect the first sample: the frame tensor's shape, the keys of the target
# dict, and the shape of its 'audio_wav' entry.
frame, target = dataset[0]
print(frame.shape)
print(target.keys())
print(target['audio_wav'].shape)

9368
torch.Size([8, 3, 224, 224])
dict_keys(['utt_token', 'emotion_idx', 'sentiment_idx', 'audio_wav'])
torch.Size([283, 768])


In [1]:
from model.clip import EmotionCLIP

# Machine-specific absolute checkpoint locations.
ckpt_path = '/home/minxiao/workspace/emo_det/EmotionCLIP/checkpoints_download/emotionclip_latest.pt'
backbone_ckpt_path = '/home/minxiao/workspace/emo_det/EmotionCLIP/src/pretrained/vit_b_32-laion2b_e16-af8dbd0c.pth'

# Construct EmotionCLIP on top of the backbone checkpoint.
# NOTE: this rebinds the notebook-level name `model`.
model = EmotionCLIP(
    temporal_fusion='transformer',  # alternative noted by the author: 'mean'
    video_len=8,
    backbone_checkpoint=backbone_ckpt_path,
    reset_logit_scale=False,
)