# RST-Tree Word Relation Analyser

#### Prerequisites:
* **Python 3** *(tested with Python 3.9)*
* **nltk** *(pip package)*:
~~~
    pip install nltk
~~~ 
* **punkt** *(nltk package)*:
~~~
    python
    >>> import nltk
    >>> nltk.download('punkt')
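~~~
* **RST-Tace** *(provides the `rsttace` package)*: the import cell below adds the relative path `RST-Tace` to `sys.path`, so a copy of RST-Tace must be available in a directory of that name inside the working directory.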

#### Usage:
 - Set the parameters in the cell below to the desired values
 - Click on *Run All Cells*

In [None]:
# Input file; handed to RstTreeParser in the next cell
filePath = "input.rs3"
# Destination of the generated CSV table (opened for writing in the last cell)
outputFile = "output.csv"

# A relation only gets a level number in the output if its position in a
# word's relation list is below this value; other cells are left empty
maximumRelationLevelToShow = 3

In [None]:
import sys
# Add the relative "RST-Tace" directory to the module search path so that the
# rsttace imports below can be resolved
sys.path.append("RST-Tace")

from rsttace.input.parser import RstTreeParser
from rsttace.core.rsttree import MultiNucRelation

# Read the input file; later cells start the traversal at rstTree.root
rstTree = RstTreeParser(filePath).read()

In [None]:
def extractRelations(rstNode):
    """Collect the relations found while walking away from rstNode.

    At every step exactly one of the following happens:
      * the node is the ``start`` of its sibling link: the link's relation
        is recorded and the walk continues at the link's ``end``;
      * otherwise the parent link is a MultiNucRelation: its relation is
        recorded and the walk continues at that link's ``parent``;
      * otherwise a parent link exists (not a MultiNucRelation): nothing is
        recorded and the walk continues at that link's ``parent``;
      * otherwise the walk stops.

    Returns a new list of relations in the order they were encountered
    (closest to rstNode first).
    """
    collected = []
    current = rstNode
    while True:
        siblingLink = current.toSibling
        parentLink = current.toParent
        startsSiblingLink = (siblingLink is not None) and (siblingLink.start is current)
        parentIsMultiNuc = isinstance(parentLink, MultiNucRelation)

        if startsSiblingLink:
            collected.append(siblingLink.relation)
            current = siblingLink.end
        elif parentIsMultiNuc:
            collected.append(parentLink.relation)
            current = parentLink.parent
        elif parentLink is not None:
            # Parent link carries no relation of its own -> look one level above
            current = parentLink.parent
        else:
            return collected

def extractRelationsTextPairs(rstNode):
    """Return ``[relations, text]`` entries for rstNode and its descendants.

    An entry is produced for every visited node whose ``text`` is not None;
    ``relations`` is the result of extractRelations for that node. The node
    itself comes first, followed by the entries of its children in the order
    given by ``toChildren.children``. A None node yields an empty list.
    """
    if rstNode is None:
        return []

    pairs = []

    # Entry for the node itself
    if rstNode.text is not None:
        pairs.append([extractRelations(rstNode), rstNode.text])

    # Entries of all children, depth first
    childLink = rstNode.toChildren
    if childLink is not None:
        for child in childLink.children:
            pairs.extend(extractRelationsTextPairs(child))

    return pairs

In [None]:
def splitIntoTokens(text, language="english"):
    """Split text into tokens using NLTK's word_tokenize.

    text:     the string to tokenise.
    language: language name passed on to word_tokenize. The default
              "english" is word_tokenize's own default, so calls that do not
              pass it behave exactly as before.

    Returns the list of tokens produced by word_tokenize.
    """
    # Imported here so nltk is only needed once tokenising actually happens
    from nltk.tokenize import word_tokenize
    return word_tokenize(text, language=language)

def removePunctuation(tokens):
    """Return only the tokens that consist purely of alphabetic characters.

    Besides punctuation this also drops empty tokens and tokens containing
    digits, hyphens or any other non-letter character.
    """
    return [token for token in tokens if token.isalpha()]

def convertToLowerCase(words):
    """Return a new list with every word converted to lower case."""
    return [word.lower() for word in words]

def removeStopWords(words, language="german"):
    """Return the words that are not in NLTK's stop word list for language.

    words:    iterable of word strings.
    language: name of the NLTK stop word list to use; defaults to "german",
              which is what this function always used before.

    The comparison is an exact, case-sensitive membership test, so words
    only match if they are spelled exactly like an entry of the list.

    Requires the NLTK "stopwords" corpus (nltk.download('stopwords')), which
    is separate from the "punkt" package used for tokenising.
    """
    # Imported here so the corpus is only needed when this step is enabled
    from nltk.corpus import stopwords
    stopWords = frozenset(stopwords.words(language))
    return [w for w in words if w not in stopWords]

def extractWords(text):
    """Tokenise text and return only its purely alphabetic tokens.

    Lower-casing and stop word removal exist as optional extra steps; both
    are switched off below and can be enabled by uncommenting them.
    """
    words = removePunctuation(splitIntoTokens(text))
    # Optional post-processing (disabled):
    #words = convertToLowerCase(words)
    #words = removeStopWords(words)
    return words

In [None]:
# Build the word relation table: one [word, relations] row per extracted word,
# and collect every distinct relation in order of first appearance
relationTextPairs = extractRelationsTextPairs(rstTree.root)

wordRelationsTable = []
relationList = []
for relations, text in relationTextPairs:
    # Register relations that have not been seen before
    for rel in relations:
        if rel in relationList:
            continue
        relationList.append(rel)

    # Every word of the text shares the relation list of its node
    wordRelationsTable.extend([word, relations] for word in extractWords(text))

In [None]:
# Generate output: one row per word, one column per relation (sorted)
relationList.sort()

csvRows = []
for word, relations in wordRelationsTable:
    csvRow = [word]
    for rel in relationList:
        # Position of the relation in this word's list; None if it is absent
        level = relations.index(rel) if rel in relations else None
        if level is not None and level < maximumRelationLevelToShow:
            # Levels are shown 1-based
            csvRow.append(str(level + 1))
        else:
            csvRow.append("")

    csvRows.append(csvRow)

In [None]:
# Write to CSV file
from csv import writer

# newline='' hands control over line endings to the csv module; without it,
# platforms that translate "\n" on write end up with an extra blank line
# after every row.
# NOTE(review): the file is written in the platform's default text encoding;
# pass an explicit encoding to open() if the output must be portable.
with open(outputFile, mode='w', newline='') as csvFile:
    csvWriter = writer(csvFile, delimiter=',')

    # Header row: empty corner cell above the word column, then one column
    # per relation
    csvWriter.writerow([""] + relationList)
    csvWriter.writerows(csvRows)

In [None]:
# def generateWordRelationPairKey(word, relation):
#     return relation + "_" + word
#
# relationTextPairs = extractRelationTextPairs(rstTree.root)
#
# # Generate word relation scatter plot
# relationList = []
# wordList = []
# wordRelTable = {}
# for pair in relationTextPairs:
#     relation = pair[0]
#     text = pair[1]
#    
#     # Add relation to relation list, if not already contained 
#     if relation not in relationList:
#         relationList.append(relation)
# 
#     words = extractWords(text)
#     for word in words:
#         # Add word to word list, if not already contained
#         if word not in wordList:
#             wordList.append(word)
#        
#         # Generate word-relation pair key
#         pairKey = generateWordRelationPairKey(word, relation)
#         if pairKey in wordRelTable:
#             wordRelTable[pairKey] += 1
#         else:
#             wordRelTable[pairKey] = 1
#
# # Prepare output
# relationList.sort()