In [None]:
from johnsnowlabs import nlp, medical, visual
import pandas as pd
import json
import string
import numpy as np
import warnings

# Suppress all Python warnings so they do not clutter the notebook output.
warnings.filterwarnings('ignore')
# Start the John Snow Labs session; the returned object is used in the next
# cell as the Spark entry point (`spark.read.parquet(...)`).
spark = nlp.start()

from pyspark.sql import DataFrame
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pyspark.sql as SQL
from pyspark import keyword_only

from sklearn.metrics import classification_report

In [None]:
# Translate the entity field names expected by the helper (keys) into the
# column names used by this dataset (values).
column_map = dict(
    begin1="firstCharEnt1",
    end1="lastCharEnt1",
    begin2="firstCharEnt2",
    end2="lastCharEnt2",
    chunk1="chunk1",
    chunk2="chunk2",
    label1="label1",
    label2="label2",
)

# NOTE(review): `test_path` is not defined in any visible cell. It has to be
# set before this cell runs, otherwise "Restart Kernel -> Run All" fails here.
data = spark.read.parquet(test_path)

# Add the NER-chunk annotation column consumed by the relation extraction
# model. The name is passed explicitly even though the original note marks it
# as optional (default: "train_ner_chunks").
data = medical.REDatasetHelper(data).create_annotation_column(
    column_map, ner_column_name="train_ner_chunks"
)

In [None]:
# Pipeline stages. Each chain is wrapped in parentheses instead of using
# trailing backslashes: a stray "\" at the end of a chain (as the tokenizer
# definition previously had) silently glues the statement to whatever line
# follows it.

# Reads the raw "sentence" text column and exposes it as "sentences".
documenter = (
    nlp.DocumentAssembler()
    .setInputCol("sentence")
    .setOutputCol("sentences")
)

# Splits "sentences" into "tokens".
tokenizer = (
    nlp.Tokenizer()
    .setInputCols(["sentences"])
    .setOutputCol("tokens")
)

# Pretrained clinical word embeddings for every token.
words_embedder = (
    nlp.WordEmbeddingsModel()
    .pretrained("embeddings_clinical", "en", "clinical/models")
    .setInputCols(["sentences", "tokens"])
    .setOutputCol("embeddings")
)

# Pretrained clinical part-of-speech tagger.
pos_tagger = (
    nlp.PerceptronModel()
    .pretrained("pos_clinical", "en", "clinical/models")
    .setInputCols(["sentences", "tokens"])
    .setOutputCol("pos_tags")
)

# Pretrained dependency parser; consumes the POS tags produced above.
dependency_parser = (
    nlp.DependencyParserModel()
    .pretrained("dependency_conllu", "en")
    .setInputCols(["sentences", "pos_tags", "tokens"])
    .setOutputCol("dependencies")
)

# Flattens the "relations" annotations into the plain "relations_out" column.
# Output is a delimited string rather than an array (setOutputAsArray(False)),
# the annotation columns are kept (setCleanAnnotations(False)), and both the
# value and the annotation separators are ",".
finisher = (
    nlp.Finisher()
    .setInputCols(["relations"])
    .setOutputCols(["relations_out"])
    .setCleanAnnotations(False)
    .setValueSplitSymbol(",")
    .setAnnotationSplitSymbol(",")
    .setOutputAsArray(False)
)

In [None]:
# Pretrained relation extraction model. It reads the embeddings, POS tags,
# dependency parse and the "train_ner_chunks" column, and writes its
# predictions to "relations".
clinical_re_Model = (
    medical.RelationExtractionModel()
    .pretrained("re_oncology_temporal_wip", "en", 'clinical/models')
    .setInputCols(["embeddings", "pos_tags", "train_ner_chunks", "dependencies"])
    .setOutputCol("relations")
)

# Cell output: the classes reported by the loaded model.
clinical_re_Model.getClasses()

In [None]:
# Assemble the stages in dependency order: every stage reads columns that are
# produced by an earlier stage or are already present in `data`
# (e.g. "train_ner_chunks").
finetune_pipeline = nlp.Pipeline(stages=[
    documenter,
    tokenizer,
    words_embedder,
    pos_tagger,
    dependency_parser,
    clinical_re_Model,
    finisher
])

## Inference

In [None]:
%%time
# Fit the pipeline on `data` and run it over the same DataFrame.
result = finetune_pipeline.fit(data).transform(data)
# Convert the Spark result to a pandas DataFrame for the row-by-row display below.
result_df = result.toPandas()

# Cell output: (rows, columns) of the converted predictions.
result_df.shape

# Examples - Visualization

In [None]:
import re

def printh(text, pattern, color="cyan", raise_errors=False, escape=False):
    """Print `text` to the console with every match of `pattern` highlighted.

    The `pattern` can be a regex pattern or plain text. Matching is
    case-insensitive. If `pattern` is plain text that contains regex
    metacharacters, pass `escape=True` so it is matched literally.
    In case of multiple matches, all the matches will be highlighted.

    `color` value must be one of following (any other value falls back to
    the terminal's default background):
     - black
     - red
     - green
     - yellow
     - blue
     - magenta
     - cyan
     - white

    Input:
     - text(string): full text to print
     - pattern(string): a regex-like pattern or simple string
     - color(string): background color used for the highlighted spans
     - raise_errors(bool): raise errors when no match found
     - escape(bool): escape special characters in `pattern`

    Raises:
     - ValueError: when `raise_errors` is true and `pattern` does not match
       anywhere in `text`. Nothing is printed in that case.
    """

    COLOR_MAP = {
        'black': 0,
        'red': 1,
        'green': 2,
        'yellow': 3,
        'blue': 4,
        'magenta': 5,
        'cyan': 6,
        'white': 7
    }

    if escape:
        pattern = re.escape(pattern)

    # Match case-insensitively against the ORIGINAL text. Lower-casing the
    # pattern would corrupt regex escapes (\S -> \s, \W -> \w, ...), and
    # lower-casing the text can change its length for some Unicode characters,
    # which would shift the match offsets away from the right characters.
    matches = list(re.finditer(pattern, text, flags=re.IGNORECASE))
    if not matches and raise_errors:
        raise ValueError(
            f"An exact match of pattern '{pattern}' could not be found"
            f" within text '{text}'"
        )

    # ANSI SGR sequence: default foreground (39) + selected background (4x);
    # 49 (unknown color) means "default background".
    color_prefix = "\033[39;4{}m".format(COLOR_MAP.get(color, 9))
    color_suffix = "\033[m"

    # Rebuild the text piece by piece; finditer yields non-overlapping matches
    # in left-to-right order, so a single cursor is enough.
    pieces = []
    cursor = 0
    for match in matches:
        start, end = match.span()
        pieces.append(text[cursor:start])
        pieces.append(f"{color_prefix}{text[start:end]}{color_suffix}")
        cursor = end
    pieces.append(text[cursor:])
    print("".join(pieces))


def visualize_example(example):
    """Print one prediction row and the sentence with its event chunk highlighted.

    Input:
     - example: row-like object exposing `chunk1`, `chunk2`, `relations_out`
       and `sentence` as attributes (e.g. a row yielded by
       `result_df.iterrows()`)
    """
    print(f"Event: {example.chunk1}, Date: {example.chunk2}, JSL Prediction: {example.relations_out}")
    print("-"*8)
    # The chunk is text taken from the data, not a hand-written regex: escape
    # it so characters such as "(", "+" or "." can neither raise re.error nor
    # distort the match.
    printh(example.sentence, example.chunk1, escape=True)
    print("="*8)

In [None]:
# Print the summary and the highlighted sentence for every row of the
# collected predictions (one block of output per row).
for index, row in result_df.iterrows():
    visualize_example(row)
    # print("+"*8)