# Visualizing RST structures in GeM corpora

## 1. Import the necessary packages.

In [1]:
# For parsing XML
import xml.etree.ElementTree as et

# For drawing the graphs
import pygraphviz as pgv

# TO DO: format the RST segments
# import textwrap

## 2. Parse the GeM XML files.

In [2]:
# Load both annotation layers of the same document: the base layer holds the
# text units, the RST layer holds the segments and relations that refer to them.
base_xml_path = 'test_xml/2002-she-base-1.xml'
rst_xml_path = 'test_xml/2002-she-rst-1.xml'

basefile = et.parse(base_xml_path)  # Base layer (ElementTree)
rstfile = et.parse(rst_xml_path)  # RST layer (ElementTree)

Get the root elements.

In [3]:
# Root <Element> of each parsed layer; all later cells navigate from these.
baseroot, rstroot = basefile.getroot(), rstfile.getroot()

Parse the base units.

In [4]:
# Map each base-unit id to its display text. An 'alt' attribute, when present,
# takes precedence over the element's own text content.
base_units = {}

for unit in baseroot:
    unit_id = unit.attrib['id']
    if 'alt' in unit.attrib:
        unit_content = unit.attrib['alt']
    else:
        # itertext() also yields the text of embedded child elements, so a unit
        # whose content is interrupted by nested markup is captured in full
        # (unit.text stopped at the first child). It also yields '' instead of
        # None for a unit without leading text. strip() drops the surrounding
        # whitespace left by XML indentation.
        unit_content = ''.join(unit.itertext()).strip()
    base_units[unit_id] = unit_content  # key: unit id, value: unit content

# NOTE(review): only direct children of the base root are registered. If nested
# units carry their own ids that RST segments point at, they are not keys here.
# TODO confirm against the GeM base-layer files.

Parse the RST units.

In [5]:
# Map each RST segment id to the text of the base unit it references.
# rstroot[0] is the first child of the RST root (the segment container); each
# of its children carries an 'id' and an 'xref' that is a key of base_units.
rst_units = {
    segment.attrib['id']: base_units[segment.attrib['xref']]
    for segment in rstroot[0]
}

Parse the RST relations.

In [6]:
# Map each span id to its relation name. rstroot[1] is the second child of the
# RST root (the <rst-structure> element); every child must carry 'id' and
# 'relation' attributes.
rst_relations = {node.attrib['id']: node.attrib['relation'] for node in rstroot[1]}

Build the graph nodes for the RST relations and segments, then parse the RST spans into edges.

In [19]:
# Build an undirected, non-strict graph with one node per RST relation and one
# per segment, then style the two kinds of node differently.
# Relation labels come from rst_relations, which is defined in an earlier cell,
# so this cell also runs on a fresh kernel.
rst_graph = pgv.AGraph(strict=False, directed=False, ranksep='1.0')

rst_graph.add_nodes_from(rst_relations)  # relation nodes
rst_graph.add_nodes_from(rst_units)  # segment nodes

segment_style = {'shape': 'box', 'fontsize': '8.0'}
relation_style = {
    'shape': 'none',
    'style': 'filled',
    'fillcolor': 'gray82',
    'fontcolor': 'crimson',
}

for node in rst_graph.nodes():
    node_attrs = rst_graph.get_node(node).attr
    if node in rst_units:
        node_attrs['label'] = rst_units[node]  # segment text
        node_attrs.update(segment_style)
    if node in rst_relations:
        node_attrs['label'] = rst_relations[node]  # relation name
        node_attrs.update(relation_style)

In [20]:
# Walk the RST structure (rstroot[1]) again, this time collecting the graph's
# edges. `relations` maps each span / multi-span id to its relation name;
# `edges` holds (id, id) pairs that are added to rst_graph in the next cell.
relations = {}
edges = []

for span in rstroot[1]:
    if span.tag == 'multi-span':
        # Multinuclear relation: one edge from the relation id to each nucleus.
        span_id = span.attrib['id']
        nuclei = span.attrib['nuclei'].split()
        relations[span_id] = span.attrib['relation']
        edges.extend((span_id, nucleus) for nucleus in nuclei)
    if span.tag == 'span':
        # Nucleus-satellite relation: (relation, nucleus) edges for each nucleus
        # and (satellite, relation) edges for each satellite.
        span_id = span.attrib['id']
        nuclei = span.attrib['nucleus'].split()
        satellites = span.attrib['satellites'].split()
        relations[span_id] = span.attrib['relation']
        edges.extend((span_id, nucleus) for nucleus in nuclei)
        edges.extend((satellite, span_id) for satellite in satellites)
    # span.iter('title') yields the element itself if it is a <title>, plus
    # every <title> nested inside it; each is linked to this span via its xref.
    for title in span.iter('title'):
        edges.append((title.attrib['xref'], span.attrib['id']))

In [21]:
# Connect the styled nodes, then lay the graph out with Graphviz's `dot`
# program and write the rendering to a PNG file.
# NOTE: rst_graph was created with strict=False, so re-running only this cell
# adds every edge a second time; re-run the node cell first.
rst_graph.add_edges_from(edges)
rst_graph.draw('test1.png', prog='dot', format='png')