In [1]:
# Notebook for fine-tuning BERT related models

In [160]:
from math import sqrt
import regex as re
import os
from glob import glob
import numpy as np
import pandas as pd
import torch
import torch.nn
from torch.utils.data import DataLoader, Dataset
import transformers as ppb
import warnings

warnings.filterwarnings('ignore')

Planned model inputs: BERT embedder
+ time (duration) of each sentence
+ demographics? (sex / age)

Open question: LSTM (with attention) over the sentence sequence?

In [3]:
# Show the working directory; the relative '../data/...' paths defined below resolve against it.
os.getcwd()

'/Users/tom/phd/ADReSS_Challenge/ADReSS_Challenge'

In [8]:
# Glob patterns for the Pitt data folders, one per (group, task) pair.
# All of them share the same layout, so they are built from a single template.
_PITT_GLOB = '../data/Pitt/{}/{}/**'

# Control group
pitt_cookie_cc = _PITT_GLOB.format('Control', 'cookie')
pitt_fluency_cc = _PITT_GLOB.format('Control', 'fluency')

# Dementia group
pitt_cookie_ad = _PITT_GLOB.format('Dementia', 'cookie')
pitt_fluency_ad = _PITT_GLOB.format('Dementia', 'fluency')
pitt_recall_ad = _PITT_GLOB.format('Dementia', 'recall')
pitt_sentence_ad = _PITT_GLOB.format('Dementia', 'sentence')

# Pattern lists consumed when building the per-group data frames.
ctrl_datas = [pitt_cookie_cc, pitt_fluency_cc]
ad_datas = [pitt_cookie_ad, pitt_fluency_ad, pitt_recall_ad, pitt_sentence_ad]

In [75]:
def _parse_participant_id(id_line, par):
    """Copy MMSE, sex and age from an ``@ID`` header line into ``par``.

    Only the line whose third ``|``-separated field is ``PAR`` is used; any
    other ``@ID`` line (or one with too few fields) leaves ``par`` untouched.
    """
    participant = [field.strip() for field in id_line.split('|')]
    if len(participant) < 9 or participant[2] != 'PAR':
        return

    par['mmse'] = '' if len(participant[8]) == 0 else float(participant[8])
    par['sex'] = participant[4][0] if len(participant[4]) else 'n/a'

    # CHAT writes ages as "years;months.days". Keep only the years part:
    # simply deleting the ';' would turn "57;03." into 5703.
    years = participant[3].split(';')[0]
    try:
        par['age'] = int(float(years)) if len(years) > 0 else 'n/a'
    except ValueError:
        print(participant)
        print(participant[3])
        par['age'] = 'n/a'


def _clean_utterance(utterance):
    """Strip transcript markup from one utterance.

    Returns ``(text, segment)`` where ``segment`` is ``[start, end]`` taken from
    the ``\\x15<start>_<end>\\x15`` time block, or ``[]`` when there is none.
    """
    match = re.search(r'\x15(\d+)_(\d+)\x15', utterance)
    segment = [int(match.group(1)), int(match.group(2))] if match else []

    text = re.sub(r'\x15\d*_\d*\x15', '', utterance)  # remove time block
    # Remove each [...] annotation on its own. A greedy '\[.*\]' would also
    # delete the words lying between the first '[' and the last ']'.
    text = re.sub(r'\[[^\]]*\]', '', text)
    text = text.strip()
    # Continuation lines are joined with a space so that the last word of one
    # line and the first word of the next are not fused together.
    text = re.sub(r'[\t\n]+', ' ', text)
    text = re.sub(r'[<>]', '', text)  # remove angle-bracket markers
    return text, segment


def extract_data(file_name):
    """Parse one ``.cha`` transcript into a flat dict.

    Parameters
    ----------
    file_name : str
        Path to the transcript file.

    Returns
    -------
    dict
        ``id`` (file name without ``.cha``), ``mmse`` / ``sex`` / ``age`` (only
        when a participant ``@ID`` line is present), ``speech`` (raw ``*PAR`` /
        ``*INV`` utterances), ``clean_speech`` and ``clean_par_speech`` (cleaned
        utterances of everyone / of the participant only), and the two
        space-joined variants ``joined_all_speech`` / ``joined_all_par_speech``.
    """
    par = {'id': os.path.basename(file_name).split('.cha')[0]}

    speech = []
    curr_speech = ''
    # `with` guarantees the handle is closed; iterating the file directly also
    # makes an empty file a normal case instead of a StopIteration error.
    with open(file_name, encoding='utf-8') as f:
        for l in f:
            if l.startswith('@ID'):
                _parse_participant_id(l, par)

            if l.startswith('*PAR:') or l.startswith('*INV'):
                # A new speaker line ends the pending utterance; keep it
                # instead of overwriting it.
                if curr_speech:
                    speech.append(curr_speech)
                curr_speech = l
            elif curr_speech and not l.startswith(('%', '*', '@')):
                # continuation line of the current utterance
                curr_speech += l
            elif curr_speech:
                # dependent tier ('%'), other speaker ('*') or header ('@')
                speech.append(curr_speech)
                curr_speech = ''
    if curr_speech:
        # the file ended in the middle of an utterance
        speech.append(curr_speech)

    clean_par_speech = []
    clean_all_speech = []
    speech_time_segments = []  # one entry per utterance, aligned with `speech`
    is_par = False
    for s in speech:
        if s.startswith('*PAR:'):
            is_par = True
        elif s.startswith('*INV:'):
            is_par = False
            s = re.sub(r'\*INV:\t', '', s)  # remove prefix
        if is_par:
            s = re.sub(r'\*PAR:\t', '', s)  # remove prefix

        # Clean once and reuse the result for both lists.
        text, segment = _clean_utterance(s)
        speech_time_segments.append(segment)
        if is_par:
            clean_par_speech.append(text)
        clean_all_speech.append(text)

    par['speech'] = speech
    par['clean_speech'] = clean_all_speech
    par['clean_par_speech'] = clean_par_speech
    par['joined_all_speech'] = ' '.join(clean_all_speech)
    par['joined_all_par_speech'] = ' '.join(clean_par_speech)

    # TODO: sentence-time features (not exported yet)
#     par['per_sent_times'] = [speech_time_segments[i][1] - speech_time_segments[i][0] if len(speech_time_segments[i]) else -1
#                              for i in range(len(speech_time_segments))]
#     par['total_time'] =  speech_time_segments[-1][1] - speech_time_segments[0][0]
#     par['time_before_par_speech'] = speech_time_segments[0][0]
#     par['time_between_sents'] = [0 if i == 0 else max(0, speech_time_segments[i][0] - speech_time_segments[i-1][1]) 
#                                  if len(speech_time_segments[i]) else -1
#                                  for i in range(len(speech_time_segments))]
    return par

In [76]:
def parse_data(data_dir, ad=True, random_state=None):
    """Parse every transcript matching a glob pattern into one shuffled DataFrame.

    Parameters
    ----------
    data_dir : str
        Glob pattern selecting the transcript files. ``glob`` is called without
        ``recursive=True``, so a trailing ``**`` matches like ``*``.
    ad : bool, default True
        Label for all rows: the ``ad`` column is 1 when true, 0 otherwise.
    random_state : int or None, default None
        Seed forwarded to ``DataFrame.sample``. ``None`` keeps the previous
        behaviour of a different shuffle on every call.

    Returns
    -------
    pandas.DataFrame
        One row per file (the dict returned by ``extract_data``) plus the
        ``ad`` column, shuffled, with a fresh 0..n-1 index.
    """
    # Sort because glob order depends on the file system, which would make even
    # a seeded shuffle irreproducible. Directories matched by the pattern are
    # skipped, since they cannot be opened as transcripts.
    file_names = sorted(fn for fn in glob(data_dir) if os.path.isfile(fn))
    df = pd.DataFrame([extract_data(fn) for fn in file_names])
    df['ad'] = 1 if ad else 0
    return df.sample(frac=1, random_state=random_state).reset_index(drop=True)

In [77]:
# Parse each dementia-group folder (rows labelled ad=1) and stack the resulting frames.
ad_frames = [parse_data(path, ad=True) for path in ad_datas]
ad_df = pd.concat(ad_frames)

In [81]:
# Control transcripts must be labelled ad=0. parse_data defaults to ad=True,
# so the flag has to be passed explicitly or every row ends up labelled 1.
ctrl_df = pd.concat([parse_data(ctrl_dir, ad=False) for ctrl_dir in ctrl_datas])

In [95]:
# Seed passed as `random_state` when the combined data frame is shuffled.
random_state = 42

In [96]:
# Combine both groups and shuffle them reproducibly. Each per-folder frame
# carries its own 0..n-1 index, so the concatenated index contains duplicate
# labels; reset it after shuffling so that every row has a unique label.
df = (
    pd.concat([ctrl_df, ad_df])
    .sample(frac=1, random_state=random_state)
    .reset_index(drop=True)
)

In [92]:
# Start from DistilRoBERTa: a sequence-classification model and its matching tokenizer.
pretrained_weights = 'distilroberta-base'
model_class = ppb.RobertaForSequenceClassification
tokenizer_class = ppb.RobertaTokenizer

# Both objects are loaded from the same pretrained checkpoint name.
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
model = model_class.from_pretrained(pretrained_weights)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=480.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=331070498.0, style=ProgressStyle(descri…




In [97]:
# Only Participant Speech

# All (INV + PAR) speech
# tokenized = train_df.joined_all_speech.apply((lambda x: tokenizer.encode(x, add_special_tokens=True, max_length=512)))

In [163]:
class BERTAdDataset(Dataset):
    """Torch dataset of tokenized texts padded to a common length.

    Parameters
    ----------
    text : list of str
        Texts to encode.
    labels : sequence
        One label per text, returned unchanged by ``__getitem__``.
    tokenizer : object
        Must provide ``encode(text, add_special_tokens=..., max_length=...)``
        returning a list of token ids. Its ``pad_token_id`` attribute, when
        present and not ``None``, is used as the padding id (otherwise 0).
    max_length : int, default 512
        Forwarded to ``tokenizer.encode``.

    Each item is ``((input_ids, attention_mask), label)`` where both arrays
    have the length of the longest encoded text.
    """

    def __init__(self, text, labels, tokenizer, max_length=512):
        self.text = text
        self.labels = labels

        tokenized = [tokenizer.encode(x, add_special_tokens=True, max_length=max_length) for x in text]

        # Pad with the tokenizer's own padding id. Id 0 is not the padding
        # token in every vocabulary, and deriving the mask from `ids != 0`
        # would also hide any real token that happens to have id 0.
        pad_id = getattr(tokenizer, 'pad_token_id', None)
        if pad_id is None:
            pad_id = 0

        # pad so can be treated as one batch
        max_len = max((len(ids) for ids in tokenized), default=0)
        input_ids = np.full((len(tokenized), max_len), pad_id, dtype=np.int64)

        # attention mask - 1 for real tokens, 0 for padding. It is built from
        # the sequence lengths, not from the token values.
        attention_mask = np.zeros((len(tokenized), max_len), dtype=np.int64)
        for row, ids in enumerate(tokenized):
            input_ids[row, :len(ids)] = ids
            attention_mask[row, :len(ids)] = 1

        self.input_ids = input_ids
        self.attention_mask = attention_mask

    def __len__(self):
        """Number of texts in the dataset."""
        return len(self.text)

    def __getitem__(self, i):
        """Return ``((input_ids, attention_mask), label)`` for example ``i``."""
        return (self.input_ids[i], self.attention_mask[i]), self.labels[i]

In [164]:
# Build the dataset from the participant-only text and the ad labels,
# then wrap it in a shuffling loader that yields batches of 128 examples.
ds = BERTAdDataset(
    text=df['joined_all_par_speech'].tolist(),
    labels=df['ad'].tolist(),
    tokenizer=tokenizer,
)
trainloader = DataLoader(dataset=ds, batch_size=128, shuffle=True)

In [167]:
# Number of examples in the dataset (one per input text).
len(ds)

1293

In [128]:
# Use the GPU when one is present, otherwise stay on the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

# check if multiple GPUs are available
multi_gpu = torch.cuda.device_count() > 1

# The model is only moved when a GPU is available.
if cuda_available:
    model = model.to(device)

In [None]:
running_loss = 0  # loss summed over every batch of every epoch
criterion = torch.nn.CrossEntropyLoss()

# Only train the classifier parameters, feature parameters are frozen.
# Restricting the optimizer to the head is not enough to freeze the encoder:
# without this, gradients are still computed for every encoder weight.
for param in model.base_model.parameters():
    param.requires_grad = False

optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.003)

# Dropout stays off in the frozen encoder and is switched on in the trained head.
model.eval()
model.classifier.train()

for epoch in range(5):
    epoch_loss = 0.0
    for (input_ids, attn_masks), labels in trainloader:
        input_ids = input_ids.to(device)
        attn_masks = attn_masks.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # The model returns a tuple-like output; with no labels passed in,
        # element 0 holds the classification logits.
        logits = model(input_ids, attention_mask=attn_masks)[0]
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        epoch_loss += loss.item()

    print('epoch', epoch, 'mean loss', epoch_loss / max(len(trainloader), 1))

    # TODO: model eval
    # every so many batches...

In [None]:
# Forward pass without gradient tracking on the last batch left over from the
# training loop. The mask variable produced by that loop is `attn_masks`;
# there is no top-level `attention_mask`.
# NOTE(review): `model` is a sequence-classification model, so despite the
# variable name this output presumably starts with logits, not hidden states.
model.eval()
with torch.no_grad():
    last_hidden_states = model(input_ids, attention_mask=attn_masks)