In [6]:
import os
subs_dir = 'subtitles'
vids_dir = 'videos'


def _by_numeric_stem(filename):
    """Sort key: order '<n>.ext' files by the integer n (so '10.vtt' follows '9.vtt'); others go last."""
    stem = os.path.splitext(filename)[0]
    return (0, int(stem), filename) if stem.isdigit() else (1, 0, filename)


# Later cells index these lists by VideoID (subs[int(video_id)]), so the order must be numeric,
# not lexicographic ('10.vtt' < '2.vtt' as strings).
subs_files = sorted((x for x in os.listdir(subs_dir) if x.endswith('.vtt')), key=_by_numeric_stem)
vids_files = sorted((x for x in os.listdir(vids_dir) if x.endswith('.m4v')), key=_by_numeric_stem)
print(subs_files)
print(vids_files)

['0.vtt', '1.vtt', '2.vtt', '3.vtt', '4.vtt', '5.vtt', '6.vtt', '7.vtt', '8.vtt', '9.vtt']
['0.m4v', '1.m4v', '2.m4v', '3.m4v', '4.m4v', '5.m4v', '6.m4v', '7.m4v', '8.m4v', '9.m4v']


In [7]:
import webvtt

# Parse every subtitle file; subs[k] holds the captions of subs_files[k].
subs = []
for file_idx, vtt_name in enumerate(subs_files):
    print(file_idx)
    vtt_path = os.path.join(subs_dir, vtt_name)
    subs.append(webvtt.read(vtt_path))


def _print_caption(cap):
    # Dump one caption as start / text / end followed by a separator line.
    print('start:', cap.start)
    print('caption:', cap.text)
    print('end:', cap.end)
    print('----------------')


# Eyeball the captions of the first file and the first caption of file 6.
for cap in subs[0]:
    _print_caption(cap)

print(subs[6][0].text)

0
1
2
3
4
5
6
7
8
9
start: 00:00:00.000
caption:  
a nucleus in the nucleus or some things
end: 00:00:04.460
----------------
start: 00:00:04.460
caption: a nucleus in the nucleus or some things
 
end: 00:00:04.470
----------------
start: 00:00:04.470
caption: a nucleus in the nucleus or some things
called protons
end: 00:00:05.559
----------------
start: 00:00:05.559
caption: called protons
 
end: 00:00:05.569
----------------
start: 00:00:05.569
caption: called protons
something's called neutrons and outside
end: 00:00:09.169
----------------
start: 00:00:09.169
caption: something's called neutrons and outside
 
end: 00:00:09.179
----------------
start: 00:00:09.179
caption: something's called neutrons and outside
or something called electrons that's all
end: 00:00:11.660
----------------
start: 00:00:11.660
caption: or something called electrons that's all
 
end: 00:00:11.670
----------------
start: 00:00:11.670
caption: or something called electrons that's all
the atomic structure 

In [3]:
"""
The EEG data covers only the middle minute of each video, so captions outside that window must be
removed. There should also be one caption per 0.5 s of video so that captions line up with the EEG data rows.
"""
from moviepy.editor import VideoFileClip  # This library is complete overkill for this purpose
import time
     
# Read video file durations (seconds). Each clip is opened only long enough to read its
# duration; the `with` block closes it so the reader opened for the file is released
# instead of staying open for every video.
video_durations = []
for videofile in vids_files:
    with VideoFileClip(os.path.join(vids_dir, videofile)) as clip:
        print(videofile, clip.duration)
        video_durations.append(clip.duration)

0.m4v 141.88
1.m4v 143.8
2.m4v 123.8
3.m4v 117.52
4.m4v 146.24
5.m4v 124.6
6.m4v 117.4
7.m4v 114.32
8.m4v 126.12
9.m4v 123.76


In [4]:
import pandas as pd
eeg_df = pd.read_csv('EEG_data.csv')
print(eeg_df.iloc[0])
# Item assignment is the reliable way to set a column; attribute assignment only updates
# columns that already exist (and otherwise just creates a plain attribute).
eeg_df['VideoID'] = pd.to_numeric(eeg_df['VideoID'])
eeg_df['SubjectID'] = pd.to_numeric(eeg_df['SubjectID'])
print(eeg_df['VideoID'].value_counts().sort_index())
print(eeg_df['SubjectID'].value_counts().sort_index())

# Users seem to have different amounts of watch time, assuming the 0.5 s measurement interval is correct.
# groupby visits subjects in sorted order (like sorted(unique()) + query) but in a single pass.
for subject_id, subject_rows in eeg_df.groupby('SubjectID'):
    print('SubjectID', subject_id)
    print(subject_rows['VideoID'].value_counts().sort_index())

SubjectID                  0.0
VideoID                    0.0
Attention                 56.0
Mediation                 43.0
Raw                      278.0
Delta                 301963.0
Theta                  90612.0
Alpha1                 33735.0
Alpha2                 23991.0
Beta1                  27946.0
Beta2                  45097.0
Gamma1                 33228.0
Gamma2                  8293.0
predefinedlabel            0.0
user-definedlabeln         0.0
Name: 0, dtype: float64
0.0    1412
1.0    1414
2.0    1274
3.0    1206
4.0    1356
5.0    1230
6.0    1181
7.0    1177
8.0    1280
9.0    1281
Name: VideoID, dtype: int64
0.0    1261
1.0    1301
2.0    1284
3.0    1314
4.0    1295
5.0    1262
6.0    1275
7.0    1276
8.0    1282
9.0    1261
Name: SubjectID, dtype: int64
SubjecttId 0.0
0.0    144
1.0    140
2.0    142
3.0    122
4.0    116
5.0    123
6.0    116
7.0    112
8.0    124
9.0    122
Name: VideoID, dtype: int64
SubjecttId 1.0
0.0    140
1.0    142
2.0    122
3.0    116
4

In [10]:
# filter captions from subs that are not in the middle minute
import numpy as np


def strf_seconds(seconds):
    """
    Format a number of seconds as ``HH:MM:SS.<fraction>``.

    The fractional digits are taken from the shortest decimal representation of
    ``float(seconds)`` (e.g. ``119.5 -> '00:01:59.5'``, ``60 -> '00:01:00.0'``).
    Negative values get a leading ``'-'``.

    :param seconds: int, float or numpy float number of seconds.
    :return: str formatted as HH:MM:SS.fraction
    """
    sign = '-' if seconds < 0 else ''
    magnitude = abs(float(seconds))
    # Exact integer arithmetic; rounding seconds / 3600 reported 1 hour for anything past ~30 min.
    hours, remainder = divmod(int(magnitude), 3600)
    minutes, secs = divmod(remainder, 60)
    text = repr(magnitude)
    if 'e' in text or 'E' in text:
        # repr switches to scientific notation for tiny/huge values; force positional digits.
        text = '{:.6f}'.format(magnitude)
    fraction = text.partition('.')[2] or '0'
    return '{}{:02d}:{:02d}:{:02d}.{}'.format(sign, hours, minutes, secs, fraction)

def get_video_intervals(vid_duration, n_rows, row_period=0.5):
    """
    Return one timestamp per EEG row, for a window centred on the middle of the video.

    The ``n_rows`` rows are taken to be ``row_period`` seconds apart (0.5 s for this
    dataset), so the window spans ``n_rows * row_period`` seconds around ``vid_duration / 2``.

    :param vid_duration: video length in seconds.
    :param n_rows: number of EEG rows for this (subject, video) pair.
    :param row_period: seconds between consecutive rows (default 0.5).
    :return: list of ``n_rows`` timestamps formatted by ``strf_seconds``.
    """
    window = n_rows * row_period
    start_time = vid_duration / 2 - window / 2
    # NOTE(review): start_time goes negative if the recording is longer than the video — confirm this cannot happen.
    # start + k * period yields exactly n_rows points; np.arange over a float range could
    # return one extra point due to rounding and had to be trimmed.
    intervals = start_time + row_period * np.arange(n_rows)
    return [strf_seconds(interval) for interval in intervals]

def _timestamp_to_seconds(timestamp):
    """Convert a ``[-][HH:]MM:SS.fff`` string (or a plain number) to float seconds."""
    if isinstance(timestamp, (int, float)):
        return float(timestamp)
    text = str(timestamp).strip()
    sign = -1.0 if text.startswith('-') else 1.0
    seconds = 0.0
    for field in text.lstrip('-').split(':'):
        seconds = seconds * 60 + float(field)
    return sign * seconds


def get_subs_for_rows(vid_subs, vid_intervals):
    """
    Pick the caption text shown at each interval timestamp.

    :param vid_subs: captions of one video; each has ``start`` / ``end`` timestamps and
        ``text``. Assumed to be in time order (as read from the .vtt file).
    :param vid_intervals: timestamps in increasing order, either strings such as
        ``'00:01:59.5'`` or numbers of seconds.
    :return: one entry per interval: the text of the caption covering it
        (``start <= t < end``), or ``'<empty>'`` when no caption does.
    """
    # Compare numerically: string comparison misorders equal times written with different
    # precision, e.g. '00:00:05.5' < '00:00:05.500'.
    spans = [(_timestamp_to_seconds(sub.start), _timestamp_to_seconds(sub.end), sub.text)
             for sub in vid_subs]
    row_subs = []
    a = 0
    for interval in vid_intervals:
        t = _timestamp_to_seconds(interval)
        # Skip captions that have already ended. Half-open [start, end): a time on the shared
        # boundary of two adjacent captions belongs to the later caption.
        while a < len(spans) and spans[a][1] <= t:
            a += 1
        if a < len(spans) and spans[a][0] <= t:
            row_subs.append(spans[a][2])
        else:
            row_subs.append('<empty>')
    return row_subs

# Test strf_seconds: show the formatted value and check it, so a regression fails the run.
print(strf_seconds(119.5))
assert strf_seconds(119.5) == '00:01:59.5'

00:01:59.5


In [9]:
# Make one caption per EEG row, i.e. per 0.5 s period (as if a sample is taken every 0.5 seconds).
# Results are keyed by the eeg_df index so that subs_per_row follows eeg_df's row order even
# if the CSV is not sorted by (SubjectID, VideoID); within a group, rows keep their file order.
subs_by_index = {}
for (subject_id, video_id), rows in eeg_df.groupby(['SubjectID', 'VideoID']):
    video_intervals = get_video_intervals(video_durations[int(video_id)], len(rows))
    row_subs = get_subs_for_rows(subs[int(video_id)], video_intervals)
    subs_by_index.update(zip(rows.index, row_subs))
subs_per_row = [subs_by_index[idx] for idx in eeg_df.index]

In [None]:
# elmo embed subs, requires https://github.com/HIT-SCIR/ELMoForManyLangs
from elmoformanylangs import Embedder
import numpy as np
# Load the pretrained English ELMo model (path is relative to this notebook).
elmo_model_dir = '../../../text_embedding_repos/ELMoForManyLangs/Elmo_english_pre'
e = Embedder(elmo_model_dir)


In [None]:
# Tokenise each caption on whitespace ('<empty>' stays a single placeholder token), embed the
# token lists with ELMo, then average each result along axis 0 (presumably the token axis —
# confirm against the ELMoForManyLangs output shape) to get one vector per EEG row.
subs_sents = list(map(str.split, subs_per_row))
elmo_sents = e.sents2elmo(subs_sents)
elmo_avg_sents = [np.mean(token_vectors, axis=0) for token_vectors in elmo_sents]

In [None]:
import pandas as pd

# One float32 row per EEG sample: [SubjectID, VideoID, <averaged ELMo vector...>].
# elmo_avg_sents is expected to be in eeg_df row order; check the lengths so a mismatch fails
# loudly instead of pairing rows with the wrong captions.
assert len(elmo_avg_sents) == len(eeg_df), 'expected one averaged embedding per EEG row'
embedding_matrix = np.vstack(elmo_avg_sents).astype(np.float32)
id_matrix = eeg_df[['SubjectID', 'VideoID']].to_numpy(dtype=np.float32)
data = np.hstack([id_matrix, embedding_matrix])

df = pd.DataFrame(data, columns=['SubjectID', 'VideoID', *[str(x) for x in range(embedding_matrix.shape[1])]])
# Uncomment to (re)write the 'elmo_embedded_subs.csv' cache that is read back further down.
#df.to_csv('elmo_embedded_subs.csv', index=False)

In [None]:
import pandas as pd
# Reload the cached embedding table (replaces any in-memory `df` built above).
df = pd.read_csv(filepath_or_buffer='elmo_embedded_subs.csv')

In [None]:
# The whole table takes ~257M on disk, so write one CSV per VideoID instead.
for video_id, vid_df in df.groupby('VideoID'):
    out_name = 'vid_{}_elmo_embedded_subs.csv'.format(int(video_id))
    vid_df.to_csv(out_name, index=False)

In [None]:
# Combining vid_dfs
import numpy as np
# Reassemble the per-video CSVs and restore (SubjectID, VideoID) row order.
# NOTE(review): range(10) assumes VideoIDs 0-9, matching the ten videos listed at the top.
vid_dfs = (
    pd.concat([pd.read_csv('vid_{}_elmo_embedded_subs.csv'.format(i)) for i in range(10)],
              ignore_index=True)
    .sort_values(['SubjectID', 'VideoID'])
    .reset_index(drop=True)
)

# Embedding columns are everything except the two id columns (instead of a hard-coded 0..1023),
# so a different embedding width does not raise a KeyError here.
vec_cols = [col for col in vid_dfs.columns if col not in ('SubjectID', 'VideoID')]
vid_dfs[vec_cols] = vid_dfs[vec_cols].astype(np.float32)  # for speed

In [None]:
def round5(x):
    """Round ``x`` to 5 decimal places."""
    return round(x, 5)

# vid_dfs holds its embedding columns as float32 (cast in the previous cell) while df was read
# from CSV (float64 by default). Rounding both to 5 decimals and demanding exact equality fails
# whenever the two representations straddle a rounding boundary, so compare with a tolerance
# and ignore the float32/float64 dtype difference.
pd.testing.assert_frame_equal(vid_dfs, df, check_dtype=False, check_exact=False, rtol=1e-5, atol=1e-5)