# Transcript-to-Transformer Processing

This notebook pre-processes transcript data from the `Narratives` dataset to:  
1. Generate nuisance vectors for use in regression analysis (phonemes, etc).  
2. Produce a list-of-TRs output that can be fed into Transformer models.


It accepts the timestamped transcript in either `.csv` or `.json` format. The `.json` format is required to produce phoneme vectors.

Note that it requires a certain directory structure (e.g. the Narratives data should be downloaded into a `./data/stimuli/STORY_NAME` subfolder).

In [3]:
# Name of the story to process; it is substituted into the ./data/stimuli/<STORY>/ paths below.
STORY = "black" # black, slumlordreach

# When True, the transcript is read from align.json and phoneme features are built;
# when False, the transcript is read from align.csv and no phoneme features are produced.
PHONEMES = False

In [4]:
import json
import pandas as pd
import itertools

if PHONEMES:

    # The JSON alignment is the only source that carries per-word phoneme data.
    # Use a context manager so the file handle is closed as soon as parsing is done.
    with open("./data/stimuli/{}/align.json".format(STORY)) as align_file:
        original_json = json.load(align_file)

    # One record per word; rename the aligner's column names to the ones used
    # throughout the rest of the notebook.
    original_transcript = pd.DataFrame.from_records(original_json['words']).rename(
        columns={'start': 'start_ts', 'end': 'end_ts', 'word': 'cased', 'alignedWord': 'uncased'}
    )

else:

    # The CSV alignment has no header row, so the column names are supplied here.
    original_transcript = pd.read_csv("./data/stimuli/{}/align.csv".format(STORY), header=None,
                                      names=["cased", "uncased", "start_ts", "end_ts"])

## First pass: discovering / correcting some timestamp errors.

In [5]:
# Preview the first rows to sanity-check the column names and timestamps.
original_transcript.head()

Unnamed: 0,cased,uncased,start_ts,end_ts
0,So,so,0.24,0.63
1,I,i,0.68,1.26
2,was,was,1.96,2.3
3,a,a,2.3,2.45
4,junior,junior,2.46,3.14


In [6]:
# NB: a few words have null timestamps. Fill each gap with the next
# non-null value further down the same column (backfill).
for ts_column in ("end_ts", "start_ts"):
    original_transcript[ts_column] = original_transcript[ts_column].bfill()

In [7]:
def seconds_to_tr(seconds, tr_length=1.5):
    """
    Convert a timestamp in seconds into a zero-based TR index.

    The timestamp is divided by the TR length and truncated, so every
    timestamp in [k * tr_length, (k + 1) * tr_length) maps to TR k.
    Callers in this notebook pass each word's end timestamp, so words are
    assigned to TRs based on when they end. With the default TR length:
        1.0s = TR0
        1.6s = TR1
        3.1s = TR2
        etc

    Parameters:
        seconds: timestamp in seconds.
        tr_length: duration of one TR in seconds (defaults to 1.5).

    Returns:
        The TR index as an int.
    """
    return int(seconds / tr_length)

In [8]:
# Assign each word to a TR based on the time at which the word ends.
original_transcript["tr"] = original_transcript["end_ts"].apply(seconds_to_tr)

In [9]:
# Transform NaN phonemes into empty lists

if PHONEMES:

    def _phones_or_empty(entry):
        """Return `entry` unchanged when it is a list, otherwise an empty list."""
        if isinstance(entry, list):
            return entry
        return []

    # Words without phoneme data (e.g. NaN) get an empty list so later code can
    # treat every row uniformly.
    original_transcript['phones'] = original_transcript['phones'].apply(_phones_or_empty)

In [10]:
# Spot-check a slice of words alongside their assigned TR.
original_transcript.iloc[4:20]

Unnamed: 0,cased,uncased,start_ts,end_ts,tr
4,junior,junior,2.46,3.14,2
5,in,in,3.14,3.41,2
6,college,college,3.41,4.2,2
7,when,when,4.79,5.02,3
8,I,i,5.02,5.09,3
9,got,got,5.09,5.33,3
10,my,my,5.35,5.59,3
11,first,first,5.61,6.34,4
12,paying,paying,7.11,7.63,5
13,in,in,7.64,7.66,5


In [11]:
def n_phonemes(tr_group):
    """
    Count the phonemes in a group of words.

    Parameters:
        tr_group: object with a `phones` attribute that iterates over one
            phoneme list per word (e.g. the rows of a single TR).

    Returns:
        The sum of the lengths of all per-word phoneme lists.
    """
    return sum(len(word_phones) for word_phones in tr_group.phones)

def phoneme_set(tr_group):
    """
    Collect the distinct phoneme labels in a group of words.

    Each phoneme entry is a mapping with a "phone" key; only the part of the
    label before the first underscore is kept (e.g. "ah_B" -> "ah").

    Parameters:
        tr_group: object with a `phones` attribute that iterates over one
            phoneme list per word.

    Returns:
        A set of phoneme labels. If the phoneme data cannot be iterated
        (e.g. a word's `phones` value is NaN rather than a list), an empty
        set is returned, so the result is always a set.
    """
    try:
        all_phonemes = itertools.chain.from_iterable(tr_group.phones)
        return {p["phone"].split("_")[0] for p in all_phonemes}
    except TypeError:
        # Best-effort: malformed phoneme data yields "no phonemes" rather than
        # an error. Return a real set (not `{}`, which is a dict).
        return set()

In [12]:
if PHONEMES:
    # phoneme_set already strips the suffix after "_", so its result can be used
    # directly. Sorting makes the derived list deterministic across runs.
    derived_phoneme_list = sorted(phoneme_set(original_transcript))
    print(len(derived_phoneme_list))
    # A bare expression inside an `if` block is not displayed by the notebook,
    # so print the JSON explicitly.
    print(json.dumps(derived_phoneme_list))

In [13]:
# Fixed, ordered list of phoneme labels. Position i of every phoneme vector built
# below refers to PHONEME_LIST_FROZEN[i], so the order must not change between runs.
PHONEME_LIST_FROZEN = ["ao", "iy", "m", "dh", "ow", "k", "w", "ey", "s", "ch", "sh", "aw", "ay", "l", "jh", "v", "g", "r", "oy", "er", "ae", "d", "hh", "th", "ih", "uw", "aa", "z", "zh", "oov", "ng", "p", "f", "ah", "n", "b", "uh", "y", "t", "eh"]

def in_set(p, phoneme_set):
    """
    Membership indicator.

    Parameters:
        p: the item to look for.
        phoneme_set: the container to search.

    Returns:
        1 if `p` is contained in `phoneme_set`, otherwise 0.
    """
    return int(p in phoneme_set)

def phoneme_vector(tr_group):
    """
    Build a 0/1 presence vector of phonemes for a group of words.

    Parameters:
        tr_group: group of words passed through to `phoneme_set`.

    Returns:
        A list with one entry per label in PHONEME_LIST_FROZEN, in that order:
        1 if the label occurs in the group, otherwise 0.
    """
    present = phoneme_set(tr_group)

    vector = []
    for label in PHONEME_LIST_FROZEN:
        vector.append(in_set(label, present))
    return vector

In [14]:
# print(original_transcript.phones[4:10])
# phoneme_vector(original_transcript[4:10])
# all_phonemes = list(itertools.chain.from_iterable(original_transcript.phones))
# unique_phonemes = set([p["phone"] for p in all_phonemes])
# unique_phonemes

In [15]:
# Pick one TR and show the words assigned to it, as a manual sanity check.
TR_TO_CHECK = 16

tr_x = original_transcript.loc[original_transcript["tr"] == TR_TO_CHECK]
tr_x

Unnamed: 0,cased,uncased,start_ts,end_ts,tr
45,weekends,weekends,23.29,24.1,16
46,was,was,24.58,24.84,16
47,to,to,24.84,24.99,16


In [16]:
# Collapse the word-level transcript into one record per TR.
tr_grouped = []

for tr_index, words_in_tr in original_transcript.groupby("tr"):
    record = {
        "start_ts": words_in_tr["start_ts"].min(),
        "end_ts": words_in_tr["end_ts"].max(),
        "tr": tr_index,
        "tokens": " ".join(words_in_tr["cased"].values),
        "n_tokens": len(words_in_tr),
    }

    # Phoneme features only exist when the JSON alignment was loaded.
    if PHONEMES:
        record['phoneme_vector'] = phoneme_vector(words_in_tr)
        record['n_phonemes'] = n_phonemes(words_in_tr)

    tr_grouped.append(record)

df = pd.DataFrame.from_records(tr_grouped)
df.head()

Unnamed: 0,start_ts,end_ts,tr,tokens,n_tokens
0,0.24,1.26,0,So I,2
1,1.96,2.45,1,was a,2
2,2.46,4.2,2,junior in college,3
3,4.79,5.59,3,when I got my,4
4,5.61,6.34,4,first,1


In [17]:
# Distribution of the number of tokens per TR, ordered by token count.
df.n_tokens.value_counts().sort_index()

1     60
2    111
3    129
4     89
5     54
6     23
7     11
8      2
9      2
Name: n_tokens, dtype: int64

## Pad missing TRs

A `tr_shift` of greater than 1 indicates that we need to "pad" by inserting an additional, silent TR.

In [19]:
# TR index of the preceding row (NaN for the first row).
previous_tr = df["tr"].shift(1)

# Gap to the preceding row's TR; any value above 1 means TRs are missing in between.
df["tr_shift"] = df["tr"] - previous_tr
df["prev_tr"] = previous_tr
df["tr_shift"].value_counts()

1.0    446
2.0     25
3.0      5
4.0      1
7.0      1
6.0      1
5.0      1
Name: tr_shift, dtype: int64

In [20]:
# Rows whose TR is more than 2 past the previous row's TR, i.e. at least two TRs are missing before them.
df[df["tr_shift"] > 2]

Unnamed: 0,start_ts,end_ts,tr,tokens,n_tokens,tr_shift,prev_tr
46,75.13,76.34,50,Then I played commercials,4,4.0,46.0
99,164.37,164.86,109,Now keep in,3,7.0,102.0
187,306.23,306.77,204,Boom,1,3.0,201.0
202,337.099999,337.42,224,I could,2,3.0,221.0
243,403.96,404.56,269,Well I,2,3.0,266.0
245,409.889999,410.85,273,My father speaks,3,3.0,270.0
256,434.109999,434.89,289,But I'm still,3,6.0,283.0
421,695.719999,696.79,464,And I said,3,3.0,461.0
480,799.81,799.83,533,you,1,5.0,528.0


In [21]:
import itertools

def generate_missing_trs(row):
    """
    Build placeholder records for the TRs skipped just before `row`.

    Parameters:
        row: mapping with "tr_shift" (gap to the previous TR) and "prev_tr"
            (index of the previous TR).

    Returns:
        A list of {"tokens": "", "tr": <index>} dicts, one per missing TR
        strictly between the previous TR and this one, or None when no TR was
        skipped (including when "tr_shift" is NaN).
    """
    gap = row["tr_shift"]

    # NaN compares False here as well, so the first row falls into this branch.
    if not gap > 1:
        return None

    placeholders = []
    for offset in range(int(gap - 1)):
        placeholders.append({"tokens": "", "tr": int(row["prev_tr"] + offset + 1)})
    return placeholders

def pad_missing_trs(df):
    """
    Build a DataFrame of placeholder rows for every TR missing from `df`.

    Parameters:
        df: per-TR frame; each row is passed to `generate_missing_trs`.

    Returns:
        A DataFrame made from all placeholder records produced for the rows
        of `df`.
    """
    per_row_placeholders = df.apply(generate_missing_trs, axis=1)

    # Rows with no gap before them yield None; keep only the real placeholder lists.
    placeholder_lists = [entry for entry in per_row_placeholders if entry is not None]

    all_placeholders = itertools.chain.from_iterable(placeholder_lists)
    return pd.DataFrame.from_records(all_placeholders)

# Append the empty placeholder TRs to the real ones, then order by TR so each
# placeholder lands in the gap it fills.
padding_rows = pad_missing_trs(df)
final_df = pd.concat([df, padding_rows]).sort_values("tr")

final_df.tail(10)

Unnamed: 0,start_ts,end_ts,tr,tokens,n_tokens,tr_shift,prev_tr
475,785.44,787.419999,524,white I,2.0,1.0,523.0
476,787.42,788.519999,525,have this job,3.0,1.0,524.0
477,789.059999,790.28,526,because I am good,4.0,1.0,525.0
478,790.57,791.469999,527,at what I do,4.0,1.0,526.0
479,792.35,792.61,528,Thank,1.0,1.0,527.0
49,,,529,,,,
50,,,530,,,,
51,,,531,,,,
52,,,532,,,,
480,799.81,799.83,533,you,1.0,5.0,528.0


In [22]:
# Set index to TR (the "tr" column is kept as well)
final_df.index = final_df.tr

# Make sure no duplicates: fail loudly instead of relying on a visual check of
# the (truncated) value counts below.
assert final_df.tr.is_unique, "duplicate TR indices found after padding"

final_df.tr.value_counts()

0      1
351    1
365    1
364    1
363    1
      ..
172    1
171    1
170    1
169    1
533    1
Name: tr, Length: 534, dtype: int64

In [23]:
# Show the first 20 rows of the padded table, now indexed by TR.
final_df[:20]

Unnamed: 0_level_0,start_ts,end_ts,tr,tokens,n_tokens,tr_shift,prev_tr
tr,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,0.24,1.26,0,So I,2.0,,
1,1.96,2.45,1,was a,2.0,1.0,0.0
2,2.46,4.2,2,junior in college,3.0,1.0,1.0
3,4.79,5.59,3,when I got my,4.0,1.0,2.0
4,5.61,6.34,4,first,1.0,1.0,3.0
5,7.11,7.66,5,paying in,2.0,1.0,4.0
6,8.929999,9.99,6,my field,2.0,1.0,5.0
7,10.29,11.97,7,on the radio This,4.0,1.0,6.0
8,11.969999,13.389999,8,is not an internship,4.0,1.0,7.0
9,13.83,14.45,9,I'm getting a,3.0,1.0,8.0


In [24]:
# Write the per-TR table into the story's stimuli folder.
# NOTE(review): the index was set to the "tr" column earlier while the column itself was kept,
# so the CSV presumably contains the TR twice — confirm downstream readers expect that.
final_df.to_csv("data/stimuli/{}/tr_tokens.csv".format(STORY))

In [25]:
# Total number of TR rows, including the inserted empty ones.
len(final_df)

534

In [26]:
# Preview the start of the final table.
final_df.head()

Unnamed: 0_level_0,start_ts,end_ts,tr,tokens,n_tokens,tr_shift,prev_tr
tr,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,0.24,1.26,0,So I,2.0,,
1,1.96,2.45,1,was a,2.0,1.0,0.0
2,2.46,4.2,2,junior in college,3.0,1.0,1.0
3,4.79,5.59,3,when I got my,4.0,1.0,2.0
4,5.61,6.34,4,first,1.0,1.0,3.0


In [32]:
if PHONEMES:
    # The inserted empty TRs carry no phoneme data: give them an all-zero vector
    # with the same length as the real ones.
    final_df['phoneme_vector'] = final_df['phoneme_vector'].apply(lambda d: d if isinstance(d, list) else [0] * len(PHONEME_LIST_FROZEN))

    # Assign the filled columns back explicitly. Calling fillna(inplace=True) on a
    # column selection mutates an intermediate object and does not reliably update
    # the frame (notably under pandas Copy-on-Write).
    final_df['n_phonemes'] = final_df['n_phonemes'].fillna(0)
    final_df['n_tokens'] = final_df['n_tokens'].fillna(0)

    # To inspect the result, evaluate `final_df[500:]` as the last line of its own
    # cell; a bare expression inside this `if` block is never displayed.

In [28]:
import numpy as np

if PHONEMES:
    # Stack the per-TR vectors into one array (one row per TR). The result is
    # deliberately NOT named `phoneme_vector`: that name belongs to the function
    # used when grouping by TR, and rebinding it would break re-running that cell.
    phoneme_matrix = np.stack(final_df.phoneme_vector)
    np.save("{}_phoneme_vectors.npy".format(STORY), phoneme_matrix)
    np.save("{}_phoneme_counts.npy".format(STORY), final_df.n_phonemes)
    np.save("{}_word_counts.npy".format(STORY), final_df.n_tokens)
            

In [None]:
# Print (without running) an scp command for copying this story's saved .npy arrays to the remote data directory.
# Those arrays are only written when PHONEMES is True.
print("scp {}_phoneme*.npy {}_word*.npy tsumers@apps.pni.princeton.edu:/jukebox/griffiths/bert-brains/code/bert-brains/data/{}".format(STORY, STORY, STORY))

scp black_phoneme*.npy black_word*.npy tsumers@apps.pni.princeton.edu:/jukebox/griffiths/bert-brains/code/bert-brains/data/black
