# 1. import data and packages

In [None]:
# import necessary packages
import benepar, spacy
# treeswift package https://niemasd.github.io/TreeSwift/
import treeswift
import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt
import glob
import os
import re

In [None]:
# Load the spaCy pipeline and attach the benepar constituency parser to it.
nlp = spacy.load('en_core_web_md')

# spaCy 2.x is handed a component instance; other versions register the
# component by name and receive the model through `config`.
uses_spacy_v2 = spacy.__version__.startswith('2')
if not uses_spacy_v2:
    nlp.add_pipe("benepar", config={"model": "benepar_en3"})
else:
    nlp.add_pipe(benepar.BeneparComponent("benepar_en3"))

In [None]:
# Worked example: parse one utterance and print the bracketed parse string
# of its first sentence.
doc = nlp("So that should be like exactly the same thing")

sent = next(iter(doc.sents))
print(sent._.parse_string)

In [None]:
# File imports
# Placeholder: replace the string below with the conversation DataFrame.
# The cells that follow read the columns 'Text', 'other_speech', 'group'
# and 'Speaker' from it.
df_ms = 'your dataframe'

In [None]:
# Strip leading whitespace from every string cell (some turns start with a
# " "; there are some issues with the Dec2019 data).
# Patterns are raw strings so the backslashes reach the regex engine without
# triggering invalid-escape warnings.
df_ms = df_ms.replace(to_replace=r"^\s+", value='', regex=True)
# Remove the double parentheses and quotation marks found in some conversations.
df_ms = df_ms.replace(to_replace=r"\(\(", value='', regex=True)
df_ms = df_ms.replace(to_replace=r"\)\)", value='', regex=True)
df_ms = df_ms.replace(to_replace='"', value='', regex=True)

# Drop turns whose text is empty after cleaning. The column is reassigned
# rather than modified with inplace=True on a column selection, because that
# chained in-place write is not guaranteed to reach df_ms itself.
df_ms['Text'] = df_ms['Text'].replace('', np.nan)
df_ms = df_ms[df_ms['Text'].notna()]

In [None]:
# Drop utterances flagged as third-party speech (other_speech == 1).
is_third_party = df_ms['other_speech'] == 1
df_ms = df_ms[~is_third_party]

In [None]:
# Within each group, merge consecutive rows of the same speaker's utterances
# into one row. A new run starts wherever the group or the speaker differs
# from the previous row; the running count of run starts is the merge key.
group_changed = df_ms['group'].ne(df_ms['group'].shift(1))
speaker_changed = df_ms['Speaker'].ne(df_ms['Speaker'].shift(1))
df_ms['key'] = (group_changed | speaker_changed).astype(int).cumsum()
# Join the texts of each run with a single space.
df_show = df_ms.groupby(['key', 'group', 'Speaker'])['Text'].apply(' '.join)
df_ms = df_show.to_frame().reset_index()

In [None]:
# Number the turns within each group, starting at 1.
df_ms['group_turn_id'] = df_ms.groupby('group').cumcount() + 1
# Number each speaker's own turns within a group, starting at 1.
turns_by_speaker = df_ms.groupby(['group', 'Speaker'])
df_ms['group_speaker_turn_id'] = turns_by_speaker.cumcount() + 1

# 2. Get syntax tree for each utterance, calculate SILLA

In [None]:
# Replace NaN values in the 'Text' column with '' so every row can be treated
# uniformly. The check is a value test rather than `x is np.NaN`: the np.NaN
# alias no longer exists in NumPy 2.0, and an identity test misses NaN objects
# that are not the np.nan singleton.
df_ms['Text'] = ['' if isinstance(x, float) and np.isnan(x) else x for x in df_ms['Text']]

## 2.1 for each full syntax tree of a sentence, get the list of subtrees (as target)

In [None]:
def get_subtree(full_syntax_tree):
    """Split one bracketed parse string into a list of two-level subtrees.

    The parse string is rewritten as Newick text (quotes, ';' and ':' are
    removed, spaces become commas) and read with treeswift. The tree is then
    walked in level order; each emitted subtree is a flat list whose first
    item is treated as the parent tag and the remaining items as its children,
    e.g. ['INTJ', 'UH'].

    Args:
        full_syntax_tree: bracketed parse string of one sentence.

    Returns:
        A list of tag lists with '' and '.' tags removed.
    """
    # Handle a few special cases that fail to parse as a Newick tree.
    newick_tree = (full_syntax_tree.replace("'", "").replace(" ", ",")
                   .replace(";", "").replace(":", "").replace('"', ""))
    tree_read = treeswift.read_tree_newick(newick_tree)
    tree_list = []
    previous_subtree = []
    two_level_subtree = []
    for node in tree_read.traverse_levelorder(leaves=True, internal=True):
        subtree_list = list(tree_read.extract_subtree(node).labels())
        # The single label of the previous node is taken as the parent, and the
        # last label of the current subtree as a child.
        # NOTE(review): this relies on the order in which treeswift yields
        # labels -- confirm against the installed treeswift version.
        if len(subtree_list) > 1 and len(previous_subtree) == 1:
            # Start a new two-level subtree: parent plus first child.
            two_level_subtree = [previous_subtree[0], subtree_list[-1]]
        elif len(subtree_list) > 1 and len(previous_subtree) > 1:
            # Still the same parent: add another child.
            two_level_subtree.append(subtree_list[-1])
        elif len(subtree_list) == 1 and len(previous_subtree) > 1:
            # Moved on to a different tree: store the finished subtree.
            if two_level_subtree != []:  # exclude the special first '[]' case
                tree_list.append(two_level_subtree)
        previous_subtree = subtree_list
    # Remove non-meaningful tags once, after the traversal, instead of
    # re-cleaning the whole list on every visited node.
    for tree in tree_list:
        remove_bad_tags(tree)
    return tree_list


def remove_bad_tags(tree):  # tree looks like: ['INTJ', 'UH', '.']
    """Remove every '' and '.' tag from *tree* in place and return it.

    The list is rebuilt through slice assignment instead of calling
    ``tree.remove`` while iterating, which skipped the element following each
    removal and left adjacent bad tags behind.
    """
    tree[:] = [tag for tag in tree if tag not in ('', '.')]
    return tree

In [None]:
# get subtrees for all utterances in a conversation, output a dictionary, key is the index of df_ms
# this takes ~15 mins to run (for middle school data)
def get_subtree_dict(df_ms):
    """Parse every utterance in ``df_ms['Text']`` and collect its subtrees.

    Each text is run through ``nlp``; every sentence's parse string is passed
    to ``get_subtree`` and the resulting subtree lists are concatenated per
    utterance.

    Args:
        df_ms: DataFrame with a 'Text' column of utterance strings.

    Returns:
        A dict mapping each index label of ``df_ms`` to the list of subtrees
        of that utterance. Keys are the index labels (not a running counter),
        so the result can be mapped onto ``df_ms.index`` whatever the index is.
    """
    subtree_dict = {}
    for index, text in df_ms['Text'].items():
        utterance_subtree_list = []  # subtrees of all sentences in this utterance
        for sentence in nlp(text).sents:
            utterance_subtree_list.extend(get_subtree(sentence._.parse_string))
        subtree_dict[index] = utterance_subtree_list
    return subtree_dict

# subtree_dict = get_subtree_dict(df_ms)

This step will take ~15 minutes. You can skip this step if you already have the syntax trees generated.

In [None]:
# Parse every utterance, then attach the subtree lists as a new column by
# looking up each index label of df_ms in the resulting dictionary.
subtree_dict = get_subtree_dict(df_ms)
df_ms['syntax_tree_current'] = df_ms.index.map(subtree_dict)

## 2.2 Get the prime syntax subtree (window size customizable, now is 5 and 10 turns)

In [None]:
# set syntax_tree_prime to be n turns prior to the target turn 
def get_prime_subtree(df_ms, n):
    """Collect, for every turn, the subtrees of up to *n* turns by other speakers.

    For a turn with ``group_turn_id`` t, the rows of the same group whose
    'Speaker' differs from the turn's speaker are taken in row order, and the
    slice ``[max(t // 2 - n, 0) : t // 2]`` of those rows forms the prime.
    NOTE(review): ``t // 2`` equals the number of earlier other-speaker turns
    only when two speakers strictly alternate -- confirm for groups with more
    speakers.

    Rows are addressed by index label throughout, so the function does not
    require a 0..N-1 index, and each group is scanned once instead of
    re-filtering the whole frame for every row.

    Args:
        df_ms: DataFrame with columns 'group', 'group_turn_id', 'Speaker' and
            'syntax_tree_current' (a list of subtrees per row).
        n: maximum number of other-speaker turns in the prime window.

    Returns:
        A dict mapping each index label to the flattened list of prime
        subtrees, e.g. [['S', 'NP', 'VP'], ['INTJ', 'UH']].
    """
    prime_dict = {}
    for _, turns in df_ms.groupby('group', sort=False):
        speakers = turns['Speaker'].tolist()
        trees = turns['syntax_tree_current'].tolist()
        turn_ids = turns['group_turn_id'].tolist()
        # Per speaker: the subtree lists of all rows spoken by someone else.
        others_by_speaker = {}
        for label, speaker, turn_id in zip(turns.index, speakers, turn_ids):
            if speaker not in others_by_speaker:
                others_by_speaker[speaker] = [
                    tree for other, tree in zip(speakers, trees) if other != speaker
                ]
            prime_count = turn_id // 2  # number of rows priming, e.g. 5 // 2 = 2
            window = others_by_speaker[speaker][max(prime_count - n, 0):prime_count]
            # Flatten [[rule, ...], [rule, ...]] (one list per turn) into one list of rules.
            prime_dict[label] = [rule for turn_rules in window for rule in turn_rules]
    return prime_dict


# list.remove('rabbit')

In [None]:
# add the prime subtree column (10-turn window)
df_ms['syntax_tree_prime_10'] = df_ms.index.map(get_prime_subtree(df_ms, 10))
# Replace any NaN prime with [] so it can be treated universally. A value test
# is used because the np.NaN alias no longer exists in NumPy 2.0 and an
# identity test misses NaN objects that are not the np.nan singleton.
df_ms['syntax_tree_prime_10'] = [
    [] if isinstance(x, float) and np.isnan(x) else x
    for x in df_ms['syntax_tree_prime_10']
]
# df_ms.head()

In [None]:
# add the prime subtree column (5-turn window)
df_ms['syntax_tree_prime_5'] = df_ms.index.map(get_prime_subtree(df_ms, 5))
# Replace any NaN prime with [] so it can be treated universally. A value test
# is used because the np.NaN alias no longer exists in NumPy 2.0 and an
# identity test misses NaN objects that are not the np.nan singleton.
df_ms['syntax_tree_prime_5'] = [
    [] if isinstance(x, float) and np.isnan(x) else x
    for x in df_ms['syntax_tree_prime_5']
]
df_ms.head()

## 2.3 Calculate the SILLA

In [None]:
# find the overlap items in two lists (target and prime)
# calculate the length of utterances
def find_overlap(df_ms, prime_window):
    """Count prime/target subtree overlap and record both list lengths.

    Reads the prime list from the column 'syntax_tree_prime_<prime_window>'
    and the target list from 'syntax_tree_current', then writes three columns
    into ``df_ms`` in place: 'overlap_count' (number of prime items that also
    occur in the target, repeats counted), 'len_target' and 'len_prime'.

    Returns:
        The string 'len_prime_<prime_window>'. Note that this function does
        not create a column of that name; the lengths go to 'len_prime'.
    """
    prime_column = f'syntax_tree_prime_{prime_window}'
    len_prime_column = f'len_prime_{prime_window}'

    print(prime_column)
    for counter_column in ('overlap_count', 'len_target', 'len_prime'):
        df_ms[counter_column] = 0

    rows = zip(df_ms.index, df_ms[prime_column], df_ms['syntax_tree_current'])
    for label, prime_rules, target_rules in rows:
        # One hit per prime item that is present in the target list.
        shared = sum(1 for rule in prime_rules if rule in target_rules)
        df_ms.at[label, 'overlap_count'] = shared
        df_ms.at[label, 'len_target'] = len(target_rules)
        df_ms.at[label, 'len_prime'] = len(prime_rules)

    return len_prime_column

In [None]:
# Choose the prime window here (currently 10); this fills 'overlap_count',
# 'len_target' and 'len_prime' in df_ms.
find_overlap(df_ms, prime_window=10)
# df_ms[:10]

In [None]:
# SILLA = p(target|prime)/p(target)
#       = number of overlapping elements / (len_target * len_prime)
# NOTE(review): rows whose length product is 0 divide by zero here and do not
# produce a finite score -- confirm how they should be treated downstream.
df_ms['len_prime_target'] = df_ms['len_target'].mul(df_ms['len_prime'])
df_ms['lla'] = df_ms['overlap_count'].div(df_ms['len_prime_target'])

# 3. Calculate Normalized LLA (nLLA)

## 3.1 compute $\bar {LLA}$
i.e., the average LLA over all pairs that have the same product of lengths, computed for every possible product value n.
The outcome of D1 should be a dictionary whose keys are the product values (n) and whose values are the average LLA of all pairs that have that product value.

In [None]:
# Reitter paper idea: normalize by the average LLA of all pairs that share the
# same product of lengths, len(prime) * len(target), stored in the
# 'len_prime_target' column.
grouped_element_length = df_ms.groupby('len_prime_target')
# Number of distinct length products (the author's run reported 1657).
print(len(grouped_element_length))

# key = length product, value = mean LLA of the rows with that product;
# used below to add the average as a column.
avg_lla_list = {
    product: pairs["lla"].mean()
    for product, pairs in grouped_element_length
}
print(list(avg_lla_list.items())[:5])

In [None]:
# Look up each row's length product in the avg_lla dictionary and store the
# matching average as a new column.
df_ms['avg_lla'] = df_ms['len_prime_target'].map(avg_lla_list)
df_ms.head()

## 3.2 compute nLLA

In [None]:
# nLLA = LLA divided by the average LLA for the same length product.
df_ms['nlla'] = df_ms['lla'].div(df_ms['avg_lla'])
df_ms.head()

## 3.3 (one time run, can be skipped) Post hoc experiment - add a baseline lla for each turn
### shuffle the corpus, pick any n turns prior to the *target* turn as *prime*

In [None]:
# Shuffle all rows (no seed is set, so the order differs between runs) and
# renumber the index from 0.
shuffled_rows = df_ms.sample(frac=1)
df_shuffle = shuffled_rows.reset_index(drop=True)
df_shuffle.head(2)

In [None]:
# Drop the columns at zero-based positions 7 through 15.
# NOTE(review): the selection is positional, so it depends on the exact column
# order of df_shuffle at this point -- confirm before rerunning.
cols = list(range(7, 16))
df_shuffle.drop(columns=df_shuffle.columns[cols], inplace=True)
df_shuffle.head(2)

In [None]:
# set syntax_tree_prime to be n turns prior to the target turn 
def get_baseline_prime_subtree(df_ms, n):
    """Collect, for every row, the subtrees of the up-to-*n* rows before it.

    The window is taken by row position (the rows immediately preceding the
    current one in ``df_ms``), while the result is keyed by index label. The
    function therefore also works when the index is not 0..N-1, where using
    the label as a position would select the wrong rows.

    Args:
        df_ms: DataFrame with a 'syntax_tree_current' column holding a list of
            subtrees per row.
        n: maximum number of preceding rows in the window.

    Returns:
        A dict mapping each index label to the flattened list of subtrees of
        the preceding rows, e.g. [['S', 'NP', 'VP'], ['INTJ', 'UH']].
    """
    trees = df_ms['syntax_tree_current'].tolist()
    prime_dict = {}
    for position, label in enumerate(df_ms.index):
        # Fewer than n predecessors: take everything from the first row on.
        window = trees[max(position - n, 0):position]
        # Flatten one-list-per-row into a single list of rules.
        prime_dict[label] = [rule for row_rules in window for rule in row_rules]
    return prime_dict


# list.remove('rabbit')
# get_baseline_prime_subtree(df_shuffle[:30], 10)

In [None]:
# Get the random prime from the shuffled df. The lookup uses df_shuffle's own
# index (the frame the dictionary was built from) rather than df_ms.index.
df_shuffle['prime_10_random'] = df_shuffle.index.map(get_baseline_prime_subtree(df_shuffle, 10))
# Replace any NaN prime with [] so it can be treated universally. A value test
# is used because the np.NaN alias no longer exists in NumPy 2.0 and an
# identity test misses NaN objects that are not the np.nan singleton.
df_shuffle['prime_10_random'] = [
    [] if isinstance(x, float) and np.isnan(x) else x
    for x in df_shuffle['prime_10_random']
]
df_shuffle.head()

In [None]:
# find the overlap items in two lists (target and prime)
# calculate the length of utterances
def find_overlap_random(df_ms, prime_window):
    """Count overlap between the random prime and the target subtrees.

    Same bookkeeping as the non-random variant, but the prime list is always
    read from the 'prime_10_random' column; ``prime_window`` only affects the
    returned name. Writes 'overlap_count', 'len_target' and 'len_prime' into
    ``df_ms`` in place.

    Returns:
        The string 'len_prime_<prime_window>'. No column of that name is
        created here; the lengths go to 'len_prime'.
    """
    prime_column = 'prime_10_random'
    len_prime_column = 'len_prime_' + str(prime_window)

    print(prime_column)
    df_ms['overlap_count'] = 0
    df_ms['len_target'] = 0
    df_ms['len_prime'] = 0

    for index, row in df_ms.iterrows():
        prime_rules, target_rules = row['prime_10_random'], row['syntax_tree_current']
        # Keep every prime item that also occurs in the target list.
        matches = [rule for rule in prime_rules if rule in target_rules]

        df_ms.at[index, 'overlap_count'] = len(matches)
        df_ms.at[index, 'len_target'] = len(target_rules)
        df_ms.at[index, 'len_prime'] = len(prime_rules)

    return len_prime_column

In [None]:
# Fill 'overlap_count', 'len_target' and 'len_prime' for the shuffled baseline.
find_overlap_random(df_shuffle, prime_window=10)

In [None]:
# Preview the first rows of the shuffled baseline frame.
df_shuffle.head()

# 4. Visual inspection 

In [None]:
# Load the processed data and discard its first column.
# NOTE(review): the first column is presumably the index written when the
# frame was saved -- confirm against the file.
df_ms = pd.read_csv(r'your file path').iloc[:, 1:]
df_ms.head()

## 4.1 Distribution of SILLA

In [None]:
# Replace +inf and -inf anywhere in the frame with 0.
# NOTE(review): presumably these come from divisions by zero in the score
# columns -- confirm that 0 is the intended stand-in.
df_ms = df_ms.replace([np.inf, -np.inf], 0)

In [None]:
# Histogram of the SILLA scores below 0.25.
lla_below_cap = df_ms.loc[df_ms['lla'] < 0.25, 'lla']
plt.hist(lla_below_cap, bins=100)
plt.xlim(-0.01, 0.25)
plt.title('SILLA distribution - middle school')
plt.show()

In [None]:
# Relative frequency of every distinct SILLA value (normalize=True returns
# proportions, not counts); the share of zero scores is the entry for 0.
df_ms['lla'].value_counts(normalize=True)

In [None]:
# Same histogram restricted to non-zero SILLA scores below 0.25.
is_nonzero = df_ms['lla'] != 0
is_below_cap = df_ms['lla'] < 0.25
plt.hist(df_ms.loc[is_nonzero & is_below_cap, 'lla'], bins=100)
plt.xlim(-0.01, 0.25)
plt.title('SILLA distribution - middle school - nonzero only')
plt.show()

## 4.2 Distribution of Normalized LLA (nLLA)
The distribution shape changed a bit, the scale changed, general trends look similar.

In [None]:
# Histogram of the normalized SILLA (nLLA) scores; the x-axis is cut at 8.
nlla_scores = df_ms['nlla']
plt.hist(nlla_scores, bins=120)
plt.xlim(-0.2, 8)
plt.title('nLLA distribution - middle school')
plt.show()

In [None]:
# Histogram of the non-zero nLLA scores only.
nonzero_nlla = df_ms.loc[df_ms['nlla'] != 0, 'nlla']
plt.hist(nonzero_nlla, bins=120)
plt.xlim(-0.2, 8)
plt.title('nLLA distribution - middle school - nonzero only')
plt.show()

## 4.3 Distribution of LLA by groups
for more consistent figures see my ppt 

### 4.3.1  Distribution of SILLA by groups

In [None]:
grouped = df_ms.groupby('group')

# One small SILLA histogram per group, preceded by the group's name.
for group_name, group_rows in grouped:
    print(group_name)
    group_rows['lla'].plot.hist(bins=200, xlim=(-0.01, 0.2), figsize=(3, 3))
    plt.show()

### 4.3.2  Distribution of nLLA by groups

In [None]:
# One small nLLA histogram per group, preceded by the group's name.
for group_name, group_rows in grouped:
    print(group_name)
    group_rows['nlla'].plot.hist(bins=80, xlim=(-0.5, 6), figsize=(3, 3))
    plt.show()

## 4.4 distribution of sentence length, relationship between sentence length and LLA 

In [None]:
# Histogram of the target lengths ('len_target'); the x-axis is cut at 80.
target_lengths = df_ms['len_target']
plt.hist(target_lengths, bins=200)
plt.xlim(-1, 80)
plt.show()

In [None]:
# Scatter of target length against SILLA score to look for patterns.
x, y = df_ms['len_target'], df_ms['lla']
plt.scatter(x, y)

In [None]:
# Scatter of target length against overlap count, with a degree-1 polynomial
# (straight line) fitted by np.polyfit drawn in green.
lengths = df_ms['len_target']
overlaps = df_ms['overlap_count']
plt.scatter(lengths, overlaps)
m, b = np.polyfit(lengths, overlaps, 1)
plt.plot(lengths, m * lengths + b, color='green')
print('slop: ', m, 'intercept: ', b)

In [None]:
# Same scatter, zoomed in to lengths 0-80 and overlap counts up to 120.
lengths = df_ms['len_target']
overlaps = df_ms['overlap_count']
plt.scatter(lengths, overlaps)
plt.xlim(0, 80)
plt.ylim(-1, 120)

## 4.5 Distribution of syntax subtrees
### Among 5400 syntax rules, 18 are really popular (54% of total syntax rules); we can easily find the distortion elbow from the plot.

In [None]:
# Count how often each syntax rule (subtree) occurs across all utterances.
import ast

dict_freq = {}
# NOTE(review): intended to track special tags ("." and ""), but nothing below
# ever increments it, so the printed count is always 0.
special_char = 0
for serialized_rules in df_ms['syntax_tree_current']:
    # The column holds the string form of a list of lists; literal_eval turns
    # it back into Python lists.
    for r in ast.literal_eval(serialized_rules):
        rule = tuple(r)  # lists are unhashable, tuples can be dict keys
        dict_freq[rule] = dict_freq.get(rule, 0) + 1
print('special_char count: ', special_char)
# (count, rule) pairs, sorted from most to least frequent.
word_freq_list = [(count, rule) for rule, count in dict_freq.items()]
freq_list_sorted = sorted(word_freq_list, reverse=True)
print(list(dict_freq.items())[:2])

In [None]:
from nltk.book import *

In [None]:
# Summary statistics of the syntax-rule frequency dictionary.
print('unique syntax rules: ',len(dict_freq))
print('total syntax rules freq (for all utterances):', sum(dict_freq.values()))
listr = list(dict_freq.values())

print('mean syntax rules freq:', np.mean(listr))
print('std syntax rules freq:', np.std(listr))


In [None]:
# Wrap the rule counts in a FreqDist (presumably made available by the
# wildcard nltk.book import above) and list its 50 most common entries.
fdist = FreqDist(dict_freq)
fdist.most_common(50)

In [None]:
# Plot the frequency distribution, limited to 30 entries.
fdist.plot(30)