# Imports

In [None]:
import os
import json
from collections import defaultdict, Counter
import pandas as pd

In [None]:
DATA_DIR = '/mnt/data/radgraph/data'

In [None]:
%run ../radgraph.py

# Load data

In [None]:
filename = os.path.join(DATA_DIR, 'dev.json')

In [None]:
# Read the whole annotation file into memory. JSON text is UTF-8, so decode it
# explicitly instead of relying on the platform's default encoding.
with open(filename, encoding='utf-8') as f:
    all_samples = json.load(f)
len(all_samples)

In [None]:
list(all_samples.keys())[-1]

# Process graphs

In [None]:
from networkx.algorithms.components import node_connected_component

In [None]:
%run ../radgraph.py

## One example

### Choose sample

In [None]:
# Pick the report to inspect: `key` indexes `all_samples`. The commented-out
# assignments are alternative samples; the trailing notes record what each one
# shows and, where known, the node id of interest ("Target").
# Cardiomegaly examples
# key = 'p10/p10003412/s59172281.txt' # Unremarkable cardiac silhouette # Target: 9
# key = 'p15/p15003878/s55991257.txt' # Cardiomegaly stable # Target: 26??
# Heart size, etc

# Opacities
# key = 'p18/p18012429/s50784640.txt' # opacity # Target: 14
key = 'p15/p15005501/s54469606.txt' # Focal infiltrate # Target: 21

# Other
# key = 'p15/p15003878/s57167019.txt'
# key = 'p18/p18001816/s54309228.txt'
# key = 'p15/p15003878/s57380048.txt'
# key = 'p15/p15003878/s58677239.txt' # with uncertain
# key = 'p15/p15001233/s54924087.txt'

In [None]:
# Pull the selected report out of the loaded annotations.
d = all_samples[key]
# Whitespace-split tokens of the report text; `text` is not referenced again
# in the visible cells.
text = d['text'].split()
# The raw report text is the cell output.
d['text']

In [None]:
# Entity annotations of the selected report; the cell output is their count.
entities = d['entities']
len(entities.keys())

### Plot

In [None]:
# Build the graph for the selected report. `create_report_radgraph` is not
# defined in the visible cells (presumably it comes from `%run ../radgraph.py`).
graph = create_report_radgraph(entities)
graph

In [None]:
print_id_to_tokens(graph)

In [None]:
plot_radgraph(graph, n_cols=3, labels=True)

### Plot subset

In [None]:
# Node id whose neighbourhood is inspected.
# NOTE(review): the active `key` assignment is annotated "Target: 21", while
# '14' is the target noted for the commented-out opacity sample. Confirm which
# id is intended for the currently selected report.
target = '14'
# Connectivity is computed on the undirected view of the graph, so edge
# direction is ignored when collecting the component that contains `target`.
subset = list(node_connected_component(graph.to_undirected(), target))
print_id_to_tokens(graph, subset)

In [None]:
# Plot only the nodes that belong to the selected component.
gg = graph.subgraph(subset)
plot_radgraph(gg, figsize=(15,8), layout='planar', labels=True)

In [None]:
d['text']

## Group all findings

TODO: check edge cases:

* p10/p10003412/s59172281.txt: unremarkable cardiac and mediastinal silhouettes

In [None]:
from collections import namedtuple
from tqdm import tqdm

In [None]:
_keys = ['id', 'findings', 'f_labels', 'locations', 'l_labels']


class CoreFinding(namedtuple('CoreFinding', _keys)):
    """Immutable record describing one finding and where it is located.

    Fields: ``id``, ``findings``, ``f_labels``, ``locations``, ``l_labels``.
    ``<`` compares the lower-cased "findings locations" text, so sorting a
    list of records orders them alphabetically by that text.
    """

    def __str__(self):
        finding_part = f'{self.findings} ({self.f_labels})'
        location_part = f'{self.locations} ({self.l_labels})'
        return f'{self.id}: {finding_part} - {location_part}'

    def __repr__(self):
        # Same rendering as str() so records are readable in cell outputs.
        return str(self)

    def to_text(self):
        """Return the findings and locations as a single lower-cased string."""
        combined = self.findings + ' ' + self.locations
        return combined.lower()

    def __lt__(self, other):
        mine = self.to_text()
        theirs = other.to_text()
        return mine < theirs

In [None]:
def entities_to_findings(entities):
    """Group the entities of one report into core findings.

    The report graph is built with ``create_report_radgraph``. Every node
    whose label contains 'OBS' and that has no outgoing 'modify' or
    'suggestive_of' edge becomes one ``CoreFinding``:

    * its finding text is the node plus all of its ancestors along
      'modify'/'suggestive_of' edges;
    * its location text is every node whose label contains 'ANAT' that is
      reached through a 'located_at' edge from any of those nodes, plus the
      'modify'/'suggestive_of' ancestors of each such location node.

    Within each group the nodes are ordered by their ``start`` attribute and
    their ``tokens`` are joined with spaces. Each node contributes at most
    once, and the distinct labels of the group are joined in sorted order so
    that the result does not depend on set iteration order.

    Args:
        entities: entity annotations of one report, passed unchanged to
            ``create_report_radgraph``.

    Returns:
        list[CoreFinding]: one entry per core finding, in graph node order.
    """
    graph = create_report_radgraph(entities)
    nodes_data = graph.nodes.data()

    def keep_relations(relations):
        # Copy of the graph that keeps every node but only the edges whose
        # 'relation' attribute is in `relations`.
        filtered = graph.copy()
        for a, b, info in graph.edges.data():
            if info['relation'] not in relations:
                filtered.remove_edge(a, b)
        return filtered

    modifiers_subgraph = keep_relations(('modify', 'suggestive_of'))
    located_at_subgraph = keep_relations(('located_at',))

    def is_finding(node):
        return 'OBS' in nodes_data[node]['label']

    def is_location(node):
        return 'ANAT' in nodes_data[node]['label']

    def get_order(node):
        return nodes_data[node]['start']

    def get_tokens(node):
        return nodes_data[node]['tokens']

    def get_label(node):
        return nodes_data[node]['label']

    def group_to_string(group):
        # Returns (tokens, labels) for a group of node ids; ('', '') if empty.
        if not group:
            return '', ''
        # A node can be reached through several paths (e.g. two nodes of the
        # branch located at the same anatomy); keep its first occurrence only
        # so its tokens are not repeated in the text.
        unique_nodes = sorted(dict.fromkeys(group), key=get_order)
        # Sorted so the label string is stable across runs.
        labels = ' '.join(sorted({str(get_label(node)) for node in unique_nodes}))
        tokens = ' '.join(str(get_tokens(node)) for node in unique_nodes)
        return tokens, labels

    core_findings = []
    for node in graph.nodes:
        if not is_finding(node):
            continue

        # A node that still modifies / suggests another node is not the head
        # of its branch; it is picked up as an ancestor of that other node.
        if list(modifiers_subgraph.successors(node)):
            continue

        # Findings: the head node plus everything that modifies it.
        branch = [node] + list(nx.ancestors(modifiers_subgraph, node))
        findings, f_labels = group_to_string(branch)

        # Locations: anatomy nodes attached to any node of the branch,
        # together with their own modifiers.
        location_nodes = []
        for branch_node in branch:
            for anchor in located_at_subgraph.successors(branch_node):
                if not is_location(anchor):
                    continue
                location_nodes.extend(nx.ancestors(modifiers_subgraph, anchor))
                location_nodes.append(anchor)
        location, l_labels = group_to_string(location_nodes)

        core_findings.append(CoreFinding(
            id=node,
            findings=findings,
            f_labels=f_labels,
            locations=location,
            l_labels=l_labels,
        ))
    return core_findings

In [None]:
# Extract the core findings of every report, keyed by report id.
# NOTE: later cells read `sample` and `core_findings`, i.e. the values left
# over from the last iteration of this loop, so the loop variables must stay
# at notebook scope.
all_findings = dict()
for report_id, sample in tqdm(all_samples.items()):
    core_findings = entities_to_findings(sample['entities'])
    all_findings[report_id] = core_findings
len(all_findings)

In [None]:
next(iter(all_findings.values()))

## CoreFindings -->  ChexpertLabels

### Try finding by patterns

In [None]:
import re

In [None]:
def find_target_findings(keywords):
    """Split every core finding in ``all_findings`` by keyword match.

    Each keyword is used as a regular expression and searched (``re.search``)
    in the finding's ``to_text()`` string.

    Args:
        keywords: regular-expression patterns; one match is enough.

    Returns:
        Two lists of ``(core_finding, report_key)`` pairs: the findings that
        matched at least one keyword, and the findings that matched none.
    """
    matched = []
    unmatched = []
    for report_key, sample_findings in all_findings.items():
        for core_finding in sample_findings:
            text = core_finding.to_text()
            hit = False
            for pattern in keywords:
                if re.search(pattern, text):
                    hit = True
                    break
            bucket = matched if hit else unmatched
            bucket.append((core_finding, report_key))
    return matched, unmatched

In [None]:
# Heart-related findings: a finding is selected when any of these regular
# expressions is found in its lower-cased text. `out` receives the findings
# that matched none; it is overwritten by each of the search cells.
cardiom_findings, out = find_target_findings([
    'cardiomegaly',
    r'\bcardiac',
    'cardiac silhouette',
    'cardiac contour',
    'heart',
])
sorted(cardiom_findings)

In [None]:
# Pneumothorax findings, including the plural form.
# NOTE(review): 'pneumothoax' is presumably listed on purpose to catch that
# misspelling in report text; confirm.
pneumo_findings, out = find_target_findings([
    'pneumothorax',
    'pneumothoax',
    'pneumothoraces',
])
sorted(pneumo_findings)

In [None]:
# Findings whose text contains 'opaci'; the 'infiltrat' pattern is
# currently disabled.
findings1, out = find_target_findings([
    'opaci',
    # 'infiltrat',
])
sorted(findings1)

### Try labelling manually

In [None]:
%run ../common/constants.py

In [None]:
# Distinct lower-cased finding texts across all reports; the cell output is
# how many there are.
flat_unique_findings = {
    finding.to_text()
    for findings_of_report in all_findings.values()
    for finding in findings_of_report
}
len(flat_unique_findings)

In [None]:
_FINDING_TO_LABEL = {}

In [None]:
def _label_manually(verbose=False):
    """Interactively assign labels to every finding text not labelled yet.

    For each entry of ``flat_unique_findings`` missing from
    ``_FINDING_TO_LABEL`` the user is prompted for a comma-separated list of
    shortcuts (lower-cased keys of ``CHEXPERT_SHORT2LABEL``). '-' is accepted
    and contributes no label. The prompt is repeated until every shortcut is
    recognised. Entering 'q' or 'quit' stops the session; answers stored so
    far remain in ``_FINDING_TO_LABEL``.

    Args:
        verbose: when true, print the resolved labels after each answer.
    """
    shortcut_to_label = {}
    for shortcut, label in CHEXPERT_SHORT2LABEL.items():
        shortcut_to_label[shortcut.lower()] = label

    total = len(flat_unique_findings)

    for index, finding in enumerate(flat_unique_findings):
        if finding in _FINDING_TO_LABEL:
            continue

        labels = None
        while labels is None:
            answer = input(f'({index}/{total}) {finding}: ')
            if answer in ('q', 'quit'):
                return

            candidates = [
                part.strip().lower()
                for part in answer.strip().split(',')
            ]

            unknown = [
                candidate
                for candidate in candidates
                if not (candidate in shortcut_to_label or candidate == '-')
            ]
            if unknown:
                print('ERROR: Some labels not recognized: ', unknown)
                continue

            # '-' passes validation but maps to no label, so it is dropped here.
            labels = [
                shortcut_to_label[candidate]
                for candidate in candidates
                if candidate in shortcut_to_label
            ]

        if verbose:
            print(labels)

        _FINDING_TO_LABEL[finding] = labels

In [None]:
_label_manually()

In [None]:
## %run ../../metrics/report_generation/abn_match/chexpert.py

In [None]:
# Give every distinct whitespace-separated word of the "locations findings"
# texts a sequential integer id, in order of first appearance.
vocab = {}
for sample_findings in all_findings.values():
    for core_finding in sample_findings:
        report = ' '.join((core_finding.locations, core_finding.findings))
        for word in report.split():
            # The id is evaluated before insertion, so a new word gets the
            # current vocabulary size; a known word keeps its id.
            vocab.setdefault(word, len(vocab))
len(vocab)

In [None]:
# labeler = ChexpertLighterLabeler(vocab, device='cpu') # DO NOT USE THIS!!

In [None]:
# Label only the first finding of the first report: both loops break after a
# single pass.
# NOTE(review): `labeler` is never assigned in the visible cells (its only
# construction is commented out), so this cell raises NameError unless the
# name is defined elsewhere; confirm before running.
for findings in all_findings.values():
    for finding in findings:
        report = finding.locations + ' ' + finding.findings
        labels = labeler.label_report(report)
        break
    break
labels

In [None]:
core_findings

In [None]:
# NOTE: in the visible cells `sample` is only assigned as the loop variable of
# the loop that fills `all_findings`, so this wraps the last report iterated.
# NOTE(review): `ReportRadGraph` is not defined in the visible cells
# (presumably it comes from ../radgraph.py); confirm.
graph = ReportRadGraph(sample['entities'])

In [None]:
nx.ancestors(graph.graph, '11')

In [None]:
# Inspect the connected component around node '11', going through the
# wrapper's `.graph` attribute for the underlying graph.
target = '11'
# The undirected view is used so edge direction does not limit the component.
subset = list(node_connected_component(graph.graph.to_undirected(), target))
graph.print_id_to_tokens(subset)
gg = graph.graph.subgraph(subset)
plot_radgraph(gg, figsize=(5,5))

In [None]:
sample