In [1]:
# Notebook for fine-tuning BERT related models

In [160]:
from math import sqrt
import regex as re
import os
from glob import glob
import numpy as np
import pandas as pd
import torch
import torch.nn
from torch.utils.data import DataLoader, Dataset
import transformers as ppb
import warnings

warnings.filterwarnings('ignore')

BERT EMBEDDER
+ TIME OF EACH SENT
+ Demographics? Sex / Age?

LSTM (WITH ATTENTION) OVER SEQUENCE ??? 

In [3]:
os.getcwd()

'/Users/tom/phd/ADReSS_Challenge/ADReSS_Challenge'

In [8]:
pitt_cookie_cc = '../data/Pitt/Control/cookie/**'
pitt_fluency_cc = '../data/Pitt/Control/fluency/**'
pitt_cookie_ad = '../data/Pitt/Dementia/cookie/**'
pitt_fluency_ad = '../data/Pitt/Dementia/fluency/**'
pitt_recall_ad = '../data/Pitt/Dementia/recall/**'
pitt_sentence_ad = '../data/Pitt/Dementia/sentence/**'
ad_datas = [pitt_cookie_ad, pitt_fluency_ad, pitt_recall_ad, pitt_sentence_ad]
ctrl_datas = [pitt_cookie_cc, pitt_fluency_cc]

In [75]:
def extract_data(file_name):
    par = {}
    par['id'] = file_name.split('/')[-1].split('.cha')[0]
    f = iter(open(file_name))
    l = next(f)
    speech = []
    try:
        curr_speech = ''
        while (True):
            if l.startswith('@ID'):
                participant = [i.strip() for i in l.split('|')]
                if participant[2] == 'PAR':
                    par['mmse'] = '' if len(participant[8]) == 0 else float(participant[8])
                    par['sex'] = participant[4][0] if len(participant[4]) else 'n/a'
                    age = participant[3].replace(';', '')
                    
                    try:
                        par['age'] = int(float(age)) if len(age) > 0 else 'n/a'
                    except:
                        print(participant)
                        print(age)
            if l.startswith('*PAR:') or l.startswith('*INV'):
                curr_speech = l
            elif len(curr_speech) != 0 and not(l.startswith('%') or l.startswith('*')):
                curr_speech += l
            elif len(curr_speech) > 0:
                speech.append(curr_speech)
                curr_speech = ''
            l = next(f)
    except StopIteration:
        pass

    clean_par_speech = []
    clean_all_speech = []
    speech_time_segments = []
    is_par = False
    for s in speech:
        def _clean(s):
            try:
                speech_time_segments.append([*map(int, re.search('\x15(\d*_\d*)\x15', s).groups()[0].split('_'))])
            except:
                speech_time_segments.append([])
            s = re.sub('\x15\d*_\d*\x15', '', s) # remove time block 
            s = re.sub('\[.*\]', '', s) # remove other speech artifacts [.*]
            s = s.strip()
            s = re.sub('\t|\n|<|>', '', s) # remove tab, new lines, inferred speech??, ampersand, &
            return s
        
        if s.startswith('*PAR:'):
            is_par = True
        elif s.startswith('*INV:'):
            is_par = False
            s = re.sub('\*INV:\t', '', s) # remove prefix
        if is_par:
            s = re.sub('\*PAR:\t', '', s) # remove prefix    
            clean_par_speech.append(_clean(s))
        clean_all_speech.append(_clean(s))
    
    par['speech'] = speech
    par['clean_speech'] = clean_all_speech
    par['clean_par_speech'] = clean_par_speech
    par['joined_all_speech'] = ' '.join(clean_all_speech)
    par['joined_all_par_speech'] = ' '.join(clean_par_speech)
    
    # sentence times
#     par['per_sent_times'] = [speech_time_segments[i][1] - speech_time_segments[i][0] if len(speech_time_segments[i]) else -1
#                              for i in range(len(speech_time_segments))]
#     par['total_time'] =  speech_time_segments[-1][1] - speech_time_segments[0][0]
#     par['time_before_par_speech'] = speech_time_segments[0][0]
#     par['time_between_sents'] = [0 if i == 0 else max(0, speech_time_segments[i][0] - speech_time_segments[i-1][1]) 
#                                  if len(speech_time_segments[i]) else -1
#                                  for i in range(len(speech_time_segments))]
    return par

In [76]:
def parse_data(data_dir, ad=True):    
    df = pd.DataFrame([extract_data(fn) for fn in glob(data_dir)])
    df['ad'] = 1 if ad else 0
    df = df.sample(frac=1).reset_index(drop=True)
    return df

In [77]:
ad_df = pd.concat([parse_data(ad_dir) for ad_dir in ad_datas])

In [179]:
ctrl_df = pd.concat([parse_data(ctrl_dir, ad=False) for ctrl_dir in ctrl_datas])

In [95]:
random_state = 42

In [180]:
df = pd.concat([ctrl_df, ad_df]).sample(frac=1, random_state=random_state)

In [181]:
# explode out each segment into AD / control segments
# Do not shuffle, as parent level segments have already been shuffled
segmented_speech = df.apply(lambda r: pd.DataFrame({'id': r.id, 'speech_sent': r.clean_par_speech, 'ad': r.ad, 'mmse': r.mmse}), axis=1).tolist()
df_segments = pd.concat(segmented_speech).reset_index(drop=True)

In [186]:
df_segments.ad.value_counts()

1    18538
0     3156
Name: ad, dtype: int64

In [185]:
df_segments.shape

(21694, 4)

In [190]:
# Embedding function
def bert_embed(text: pd.Series, tokenizer, model):
    tokenized = text.apply((lambda x: tokenizer.encode(x, add_special_tokens=True, max_length=512)))

    # pad so can be treated as one batch
    max_len = max([len(i) for i in tokenized.values])
    padded = np.array([i + [0]*(max_len-len(i)) for i in tokenized.values])

    # attention mask - zero out attention scores where there is no input to be processed (i.e. is padding)
    attention_mask = np.where(padded != 0, 1, 0)
    input_ids = torch.tensor(padded)  
    attention_mask = torch.tensor(attention_mask)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # check if multiple GPUs are available
    multi_gpu = torch.cuda.device_count() > 1

    if torch.cuda.is_available():
        model = model.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        last_hidden_states = model(input_ids, attention_mask=attention_mask)
    last_hidden_states = last_hidden_states[0]
    if device.type == 'cuda':
        last_hidden_states = last_hidden_states.cpu()
    features = last_hidden_states[:,0,:].numpy()
    return features

In [196]:
# begin with Distil roBERTa
model_class, tokenizer_class, pretrained_weights = (ppb.DistilBertModel, ppb.DistilBertTokenizer, 'distilbert-base-uncased')
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
model = model_class.from_pretrained(pretrained_weights)

In [None]:
sent_features = bert_embed(df_segments.speech_sent, tokenizer, model)

0                                    let's see what's this
1                                           well the kæl@u
2                                            lay it down .
3                               pants and clothes and +...
4        &s the little boy's  the little boy's (j)ust s...
                               ...                        
21689                our child was taken to our hospital .
21690    we don't know but we think this may be a very ...
21691                                                 oh .
21692    when I came in this office the doctor told me ...
21693    when I came in the bedroom (.) I see the burea...
Name: speech_sent, Length: 21694, dtype: object

In [None]:
## ideally fine-tuning a BERT base on these segments? Or actually ideally fine-tune on spoken speech corpus, like 

In [200]:
df_segments

Unnamed: 0,id,speech_sent,ad,mmse
0,173-1,let's see what's this,1,5
1,173-1,well the kæl@u,1,5
2,173-1,lay it down .,1,5
3,173-1,pants and clothes and +...,1,5
4,173-1,&s the little boy's the little boy's (j)ust s...,1,5
...,...,...,...,...
21689,061-0,our child was taken to our hospital .,1,26
21690,061-0,we don't know but we think this may be a very ...,1,26
21691,061-0,oh .,1,26
21692,061-0,when I came in this office the doctor told me ...,1,26


In [163]:
class BERTAdDataset(Dataset):
    def __init__(self, text, labels, tokenizer):
        self.text = text
        self.labels = labels
        
        tokenized = [tokenizer.encode(x, add_special_tokens=True, max_length=512) for x in text]
        # pad so can be treated as one batch
        max_len = 0
        for i in tokenized:
            if len(i) > max_len:
                max_len = len(i)

        padded = np.array([i + [0]*(max_len-len(i)) for i in tokenized])

        # attention mask - zero out attention scores where there is no input to be processed (i.e. is padding)
        attention_mask = np.where(padded != 0, 1, 0)        
        self.input_ids = padded  
        self.attention_mask = attention_mask
        
    def __len__(self):
        return len(self.text)
    
    def __getitem__(self, i):
        return (self.input_ids[i], self.attention_mask[i]), self.labels[i]

In [164]:
ds = BERTAdDataset(df.joined_all_par_speech.tolist(), df.ad.tolist(), tokenizer)
trainloader = DataLoader(ds, batch_size=128, shuffle=True)

In [167]:
len(ds)

1293

In [128]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# check if multiple GPUs are available
multi_gpu = torch.cuda.device_count() > 1

if torch.cuda.is_available():
    model = model.to(device)

In [None]:
running_loss = 0
criterion = torch.nn.CrossEntropyLoss()

# Only train the classifier parameters, feature parameters are frozen
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.003)

batch 

for epoch in range(5):
    for (input_ids, attn_masks), labels in trainloader:
        input_ids = input_ids.to(device)
        attn_masks = attn_masks.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        preds = model.forward(input_ids, attention_mask=attn_masks)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        
    
        # model eval
        # every so many batches...

In [None]:
with torch.no_grad():
    last_hidden_states = model(input_ids, attention_mask=attention_mask)