In [1]:
import xmltodict, json
import os
import pandas as pd
import networkx as nx
import numpy as np
from matplotlib import pyplot as plt
from itertools import combinations
import json, re

### Load raw dataset

In [3]:
# !pwd
# !conda env list
# !python --version
# !cd /home/jovyan/work/Temporal_relation/
# !pwd

In [273]:
# Change to path to the data
# path = '/home/wt/Downloads/n2c2 2012/'
wp = '/home/jovyan/work/Temporal_relation/'
path = wp + 'data/i2b2/'
training_data_path = path + 'merge_training'
test_data_path = path + 'ground_truth/merged_xml'

In [9]:
def data_loader(data_path):
    data = {}
    for filename in os.listdir(data_path):
        if filename.endswith(".xml"): 
            f = (os.path.join(data_path, filename))
#             print(f)
            fb = open(f, "rb").read().decode(encoding="utf-8")
#     invalid character '&' https://github.com/martinblech/xmltodict/issues/277
            fb = fb.replace('&', '&amp;')
            dic = xmltodict.parse(fb, attr_prefix='')
#     restore orginal character "&"
            dic['ClinicalNarrativeTemporalAnnotation']['TEXT'] = dic['ClinicalNarrativeTemporalAnnotation']['TEXT'].replace('&amp;', '&')
            data[filename] = (dic)
    return data

In [10]:
train_data = data_loader(training_data_path)
test_data = data_loader(test_data_path)

In [11]:
print(len(train_data), len(test_data))

190 120


In [12]:
def find_first_regex(text, substrings):
    pattern = '|'.join(map(re.escape, substrings))  # Escape special characters
    match = re.search(pattern, text)
    if match:
        return match.start()
    else:
        raise ValueError("None of the substrings found in the text.")

In [319]:
def build_section_graph(doc_id, data, section='all'):
    # for doc_id in list(data.keys())[:1]:
    text = data[doc_id]['ClinicalNarrativeTemporalAnnotation']['TEXT']
    # print(text)
    
    history_start = text.index('HISTORY OF PRESENT ILLNESS ')
    substrings = ['REVIEW OF SYSTEMS', 'HOSPITAL COURSE']
    history_end = find_first_regex(text, substrings)

    sect_start, sect_end = 0, len(text)
    if section == 'history':
        sect_start, sect_end = history_start, history_end
    elif section == 'other':
        sect_start, sect_end = history_end, len(text)
    # print(text[sect_start:sect_end])

    events = pd.DataFrame(data[doc_id]['ClinicalNarrativeTemporalAnnotation']['TAGS']['EVENT'])
    events['start'] = events['start'].astype(int)
    events['end'] = events['end'].astype(int)
    # Filter events in the history section
    # print('events', events.shape)
    events = events.loc[(events['start']>=sect_start) & (events['end']<=sect_end)]
    # print('events after', events.shape)
    
    # FILTER 1: only use events related to medical concepts
    # events = events.loc[events['type'].isin(['PROBLEM', 'TEST', 'TREATMENT'])]
    event_types = dict(zip(events['id'], events['type']))
    
    # Remove duplicated admission and discharge time.
    # adm_dis = pd.DataFrame(data[doc_id]['ClinicalNarrativeTemporalAnnotation']['TAGS']['SECTIME'])
    times = pd.DataFrame(data[doc_id]['ClinicalNarrativeTemporalAnnotation']['TAGS']['TIMEX3'])
    times['start'] = times['start'].astype(int)
    times['end'] = times['end'].astype(int)
    # print('times', times.shape)
    times = times.loc[((times['start']>=sect_start) & (times['end']<=sect_end))]
    # print('times after', times.shape)
    time_types = dict(zip(times['id'], times['type']))
    
    nodes_keep = list(events['id']) + list(times['id'])
    # print(len(nodes_keep))
    
    all_links = pd.DataFrame(data[doc_id]['ClinicalNarrativeTemporalAnnotation']['TAGS']['TLINK'])
    all_links = all_links.loc[all_links['type']!='']

    links = all_links.loc[(all_links['id'].str.lower().str.contains('sectime')==False)]
    
    # FILTER 2: Exclude sectime links not about admission
    if section == 'history':
        section_links = all_links.loc[(all_links['id'].str.lower().str.contains('sectime')==True) & (all_links['toID']=='T0')]
    elif section == 'other':
        section_links = all_links.loc[(all_links['id'].str.lower().str.contains('sectime')==True) & (all_links['toID']=='T1')]
    else:
        links = all_links
    # print(section_links.shape)
    # print(section_links.head())
    # print(section_links.groupby('type')['fromID'].unique())
    # print(set(section_links['fromID']) - set(nodes_keep))
    # print(set(nodes_keep) -  set(section_links['fromID']))
    if section != 'all':
        node_category = dict(zip(section_links['fromID'], section_links['type']))
    
    
    
    # Normalize AFTER and BEFORE relations
    links = links.copy()
    mask = (links['type'] == 'AFTER')
    links.loc[mask, ['fromID', 'fromText', 'toID', 'toText']] = links.loc[mask, ['toID', 'toText', 'fromID', 'fromText']].values
    links.loc[mask, 'type'] = 'BEFORE'
    links = links.drop_duplicates(subset=['fromID', 'fromText', 'toID', 'toText', 'type'], keep='last')
    
    
    G = nx.from_pandas_edgelist(links[['fromID', 'toID', 'type']], source='fromID', target='toID', edge_attr=True, create_using=nx.DiGraph())
    source_nodes = dict(zip(links['fromID'], links['fromText']))
    target_nodes = dict(zip(links['toID'], links['toText']))
    nx.set_node_attributes(G, source_nodes|target_nodes, 'text')
    if section != 'all':
        nx.set_node_attributes(G, node_category, 'time2section')
    nx.set_node_attributes(G, event_types|time_types, 'type')
    
    # only keep nodes of interest
    # FILTER 3: only subgraph
    G = G.subgraph(nodes_keep).copy()
    
    # clear reverse links and reduce redundent nodes; 
    # There are no many duplicated links
    return G, text[sect_start:sect_end]

In [324]:
# G, text = build_section_graph('36.xml', train_data, 'all')
# G, text = build_section_graph('36.xml', train_data, 'history')
G, text = build_section_graph('36.xml', train_data, 'other')

In [325]:
print(f"Graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")

Graph with 134 nodes and 116 edges


In [300]:
G, text = build_section_graph('36.xml', train_data, 'history')

In [301]:
print(f"Graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")

Graph with 183 nodes and 177 edges


In [16]:
print(f"Graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")

Graph with 45 nodes and 42 edges


In [17]:
nx.write_graphml(G, wp+"graphs/tem_history_graph.graphml")

### Merge overlap nodes

In [258]:
def merge_overlapping_nodes(G):
    """
    Merge nodes that are connected by edges with type='overlp'.
    Only merges edges across merged groups if they have the same direction.
    Maintains node attributes and edge attributes as JSON strings for GraphML compatibility.
    Tracks source nodes for each merged edge.
    
    Args:
        G (nx.Graph or nx.DiGraph): Input graph
        
    Returns:
        nx.Graph or nx.DiGraph: New graph with merged nodes and edges
    """
    # Create new merged graph of same type as input
    merged_G = G.__class__()
    
    # Find connected components considering only overlap edges
    overlap_edges = [(u, v) for u, v, d in G.edges(data=True) if d.get('type') == 'OVERLAP']
    overlap_graph = nx.Graph()  # Undirected for finding components
    overlap_graph.add_edges_from(overlap_edges)
    
    # Get clusters of nodes to merge
    clusters = list(nx.connected_components(overlap_graph))
    
    # Create mapping from original nodes to their merged cluster names
    node_to_cluster = {}
    for cluster in clusters:
        cluster = list(cluster)
        merged_name = '+'.join(sorted(cluster))
        for node in cluster:
            node_to_cluster[node] = merged_name
    
    # Process nodes
    for cluster in clusters:
        cluster = list(cluster)
        
        if len(cluster) == 1:
            # Single node, just copy it and its attributes
            node = cluster[0]
            merged_G.add_node(node, **G.nodes[node])
            continue
            
        # Create merged node name
        merged_node = '+'.join(sorted(cluster))
        
        # Combine node attributes and convert to JSON string
        merged_attrs = {
            'original_nodes': json.dumps(cluster),
            'node_attributes': json.dumps({node: dict(G.nodes[node]) for node in cluster})
        }
        
        # Add merged node
        merged_G.add_node(merged_node, **merged_attrs)
    
    # Add nodes that weren't in any cluster
    unclustered_nodes = set(G.nodes()) - set(node for cluster in clusters for node in cluster)
    for node in unclustered_nodes:
        merged_G.add_node(node, **G.nodes[node])
        node_to_cluster[node] = node  # Map to itself
    
    # Create a dictionary to store edges between clusters
    cluster_edges = {}  # (from_cluster, to_cluster) -> list of original edges with source info
    
    # Process edges
    for u, v, data in G.edges(data=True):
        # Get cluster names (or original names for unclustered nodes)
        u_cluster = node_to_cluster[u]
        v_cluster = node_to_cluster[v]
        
        # Skip internal edges of merged clusters if they were overlap edges
        if u_cluster == v_cluster and data.get('type') == 'OVERLAP':
            continue
        
        # Add source node information to edge data
        edge_data = data.copy()
        edge_data['source_nodes'] = {'from': u, 'to': v}
        
        # Create edge key based on direction
        edge_key = (u_cluster, v_cluster)
        
        # For directed graphs, maintain direction information
        if isinstance(G, nx.DiGraph):
            if edge_key not in cluster_edges:
                cluster_edges[edge_key] = []
            cluster_edges[edge_key].append((u, v, edge_data))
        else:
            # For undirected graphs, normalize the edge key
            normalized_key = tuple(sorted([u_cluster, v_cluster]))
            if normalized_key not in cluster_edges:
                cluster_edges[normalized_key] = []
            cluster_edges[normalized_key].append((u, v, edge_data))
    
    # Add merged edges to the graph
    for (from_cluster, to_cluster), edges in cluster_edges.items():
        # For directed graphs, check if all edges have the same direction
        if isinstance(G, nx.DiGraph):
            # Check if all edges go in the same direction
            directions = set((u_cluster, v_cluster) 
                           for u, v, _ in edges
                           for u_cluster, v_cluster in [(node_to_cluster[u], node_to_cluster[v])])
            
            if len(directions) == 1:  # All edges have same direction
                merged_G.add_edge(from_cluster, to_cluster, 
                                edge_attributes=json.dumps([data for _, _, data in edges]))
        else:
            # For undirected graphs, just add the edge
            merged_G.add_edge(from_cluster, to_cluster, 
                            edge_attributes=json.dumps([data for _, _, data in edges]))
    
    return merged_G

def load_merged_graph(graphml_file):
    """
    Load a merged graph from GraphML file and convert JSON string attributes back to Python objects.
    """
    G = nx.read_graphml(graphml_file)
    
    # Convert node attributes back from JSON
    for node in G.nodes():
        if 'original_nodes' in G.nodes[node]:
            G.nodes[node]['original_nodes'] = json.loads(G.nodes[node]['original_nodes'])
            G.nodes[node]['node_attributes'] = json.loads(G.nodes[node]['node_attributes'])
    
    # Convert edge attributes back from JSON
    for u, v in G.edges():
        if 'edge_attributes' in G[u][v]:
            G[u][v]['edge_attributes'] = json.loads(G[u][v]['edge_attributes'])
    
    return G

In [259]:
mergaed_G = merge_overlapping_nodes(G)

In [260]:
print(f"Graph with {mergaed_G.number_of_nodes()} nodes and {mergaed_G.number_of_edges()} edges")

Graph with 70 nodes and 51 edges


In [21]:
nx.write_graphml(mergaed_G, wp+"graphs/tem_history_graph_merge.graphml")

In [23]:
print(list(mergaed_G.nodes(data=True)))

[('E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7', {'original_nodes': '["E102", "E113", "E115", "E112", "E5", "T4", "E8", "E7", "E111", "E4", "E105", "T7", "E6", "E114", "E106", "T6"]', 'node_attributes': '{"E102": {"text": "short of breath", "time2section": "BEFORE", "type": "PROBLEM"}, "E113": {"text": "lower extremity edema", "time2section": "BEFORE", "type": "PROBLEM"}, "E115": {"text": "Her lower extremity edema", "time2section": "BEFORE", "type": "PROBLEM"}, "E112": {"text": "sharp", "time2section": "BEFORE", "type": "PROBLEM"}, "E5": {"text": "sweating", "time2section": "BEFORE", "type": "PROBLEM"}, "T4": {"text": "several years", "type": "DURATION"}, "E8": {"text": "syncope", "time2section": "BEFORE", "type": "PROBLEM"}, "E7": {"text": "vomiting", "time2section": "BEFORE", "type": "PROBLEM"}, "E111": {"text": "nonradiating", "time2section": "BEFORE", "type": "PROBLEM"}, "E4": {"text": "chest twinges", "time2section": "BEFORE", "type": "PROBLEM"}, "E105": {"tex

### Remove conflicting relations (e.g., self-link and mutual links)

In [261]:
def remove_self_links(G):
    H = G.copy()
    self_loops = list(nx.selfloop_edges(H))
    H.remove_edges_from(self_loops)
    return H

In [262]:
def remove_mutual_links(G):
    H = G.copy()
    edges_to_remove = set()
    
    for u, v in G.edges():
        if H.has_edge(v, u) and (v, u) not in edges_to_remove and (u, v) not in edges_to_remove:
            edge1 = G.get_edge_data(u, v)
            edge2 = G.get_edge_data(v, u)
            edges_to_remove.add((u, v))
            edges_to_remove.add((v, u))
    H.remove_edges_from(edges_to_remove)
    return H

In [263]:
mergaed_G_clean = remove_self_links(mergaed_G)
mergaed_G_clean = remove_mutual_links(mergaed_G_clean)

In [264]:
print(f"Graph with {mergaed_G_clean.number_of_nodes()} nodes and {mergaed_G_clean.number_of_edges()} edges")

Graph with 70 nodes and 51 edges


### Remove redundant links

In [265]:
def remove_redundant_edges(G):
    redundant_edges = []
    
    # Iterate over edges while capturing their attributes
    edges = list(G.edges(data=True))  # List of tuples (u, v, data_dict)
    
    for u, v, data in edges:
        # Remove the edge and check if a path still exists
        G.remove_edge(u, v)
        
        if nx.has_path(G, u, v):
            redundant_edges.append((u, v))
        
        # Re-add the edge with its original attributes
        G.add_edge(u, v, **data)
    
    # Remove redundant edges (preserves attributes of non-redundant edges)
    G.remove_edges_from(redundant_edges)
    return G

In [266]:
mergaed_G_clean = remove_redundant_edges(mergaed_G_clean)

In [267]:
print(f"Graph with {mergaed_G_clean.number_of_nodes()} nodes and {mergaed_G_clean.number_of_edges()} edges")

Graph with 70 nodes and 51 edges


### Minimal paths and clean nodes rather than 'PROBLEM', 'TEST', 'TREATMENT'. A --> B --> C  and D-->C 

In [268]:
def minimal_path_cover(G):
    """
    Find a minimal collection of paths that cover all edges in a directed graph.
    
    Args:
        G: A NetworkX directed graph (DiGraph)
        
    Returns:
        A list of paths, where each path is a list of nodes
    """
    if not G.edges():
        return []
    
    # Create a working copy of the graph
    remaining_edges = G.copy()
    paths = []
    
    while remaining_edges.edges():
        # Find longest path in the remaining graph
        # This is a greedy approach - finding the truly minimal cover is NP-hard
        longest_path = find_longest_path(remaining_edges)
        
        # Add the path to our collection
        paths.append(longest_path)
        
        # Remove the edges in this path from the remaining graph
        for i in range(len(longest_path) - 1):
            u, v = longest_path[i], longest_path[i + 1]
            if remaining_edges.has_edge(u, v):
                remaining_edges.remove_edge(u, v)
    
    return paths

def find_longest_path(G):
    """
    Find the longest path in a directed graph.
    
    Args:
        G: A NetworkX directed graph (DiGraph)
        
    Returns:
        A list of nodes representing the longest path
    """
    # For each node, try to find the longest path starting from it
    longest_path = []
    
    for start_node in G.nodes():
        # Skip nodes with no outgoing edges
        if G.out_degree(start_node) == 0:
            continue
            
        # Find the longest path from this start node
        path = find_longest_path_from_node(G, start_node)
        
        # Update longest path if this one is longer
        if len(path) > len(longest_path):
            longest_path = path
    
    return longest_path

def find_longest_path_from_node(G, start_node):
    """
    Find the longest path starting from a specific node.
    
    Args:
        G: A NetworkX directed graph (DiGraph)
        start_node: The starting node
        
    Returns:
        A list of nodes representing the longest path from start_node
    """
    # Use dynamic programming to find the longest path
    # This is much more efficient than a brute force approach
    
    # Initialize distances and paths
    dist = {node: -float('inf') for node in G.nodes()}
    dist[start_node] = 0
    pred = {node: None for node in G.nodes()}
    
    # Topologically sort the nodes
    try:
        topo_order = list(nx.topological_sort(G))
    except nx.NetworkXUnfeasible:
        # Graph has cycles, so we'll use a heuristic approach
        # For simplicity, we'll use a DFS-based approach
        visited = set()
        path = [start_node]
        current_path = []
        dfs_longest_path(G, start_node, visited, path, current_path)
        return current_path
    
    # Dynamic programming to find longest path
    for node in topo_order:
        for successor in G.successors(node):
            if dist[successor] < dist[node] + 1:
                dist[successor] = dist[node] + 1
                pred[successor] = node
    
    # Find the node with the maximum distance
    end_node = max(dist, key=dist.get)
    
    # Reconstruct the path
    path = []
    while end_node is not None:
        path.append(end_node)
        end_node = pred[end_node]
    
    # Reverse to get from start to end
    return path[::-1]

def dfs_longest_path(G, node, visited, path, longest_path):
    """
    DFS helper for finding the longest path in a graph with cycles.
    
    Args:
        G: A NetworkX directed graph (DiGraph)
        node: Current node
        visited: Set of visited nodes in current path
        path: Current path
        longest_path: Reference to the longest path found so far
    """
    visited.add(node)
    
    for neighbor in G.successors(node):
        if neighbor not in visited:
            path.append(neighbor)
            dfs_longest_path(G, neighbor, visited, path, longest_path)
            path.pop()
    
    if len(path) > len(longest_path):
        longest_path.clear()
        longest_path.extend(path)
    
    visited.remove(node)

In [269]:
# type_attrs

In [274]:
nx.write_graphml(mergaed_G, wp+"graphs/tem_other_graph_merge.graphml")

In [280]:
covering_paths = minimal_path_cover(mergaed_G_clean)
G = mergaed_G_clean.copy()
print("Minimal path cover:")
for i, path in enumerate(covering_paths):
    print(f"Path {i+1}: {path}")
    texts = []
    for node in path:
        
        text = G.nodes[node].get("text", None)
        type = G.nodes[node].get("type", None)
        time2section = G.nodes[node].get("time2section", None)
        if type == None:
            tlist = []
            type_attrs = G.nodes[node].get("node_attributes")
            type_attrs = json.loads(type_attrs)
            for nid in type_attrs.keys():
                # print(type_attrs[nid])
                text = type_attrs[nid]['text']
                type = type_attrs[nid]['type']
                
                time2section = type_attrs[nid].get('time2section', None)
                if type in ['PROBLEM', 'TEST', 'TREATMENT', 'DURATION', 'DATE', 'FREQUENCY']:
                    tlist.append(text)
                    print(time2section, '....')
            texts.append(', '.join(tlist))
        else:
            if type in ['PROBLEM', 'TEST', 'TREATMENT', 'DURATION', 'DATE', 'FREQUENCY']:
                texts.append(text)
                print(time2section)
    # print(texts)
    print(' ---> '.join(texts))

Minimal path cover:
Path 1: ['E68', 'E54+E67', 'E66', 'E64+E65+E71', 'E69', 'E70+T16', 'E160']
BEFORE
BEFORE ....
BEFORE ....
BEFORE
BEFORE ....
BEFORE ....
BEFORE ....
BEFORE
BEFORE ....
None ....
poor perfusion ---> edema, her edema ---> distention ---> Infectious disease, her cellulitis, cellulitis on her legs ---> cefazolin ---> Keflex, 10 more days
Path 2: ['E142+E47+E52+T13', 'E141+E146', 'E145+E50+E51', 'E49', 'E89']
None ....
BEFORE ....
BEFORE ....
BEFORE ....
BEFORE ....
BEFORE ....
AFTER ....
BEFORE
q day, hydrochlorothiazide, Lisinopril, an aspirin --->  ---> heart rate, gentle doses, careful monitoring ---> bradycardia
Path 3: ['E123', 'E122', 'E124+E21', 'E121+E129+E84', 'E125+E126']
BEFORE
BEFORE
BEFORE ....
BEFORE ....
BEFORE ....
hypertension ---> diastolic dysfunction ---> congestive heart failure, pulmonary hypertension --->  ---> negative enzymes
Path 4: ['E135+E136', 'E131+E132+E133+E134+E34+E35+E36+E37+E39', 'E41']
BEFORE ....
BEFORE ....
BEFORE ....
BEFORE ....
B

In [281]:
# TODO summarize path based on context

In [33]:
G = mergaed_G_clean.copy()
print("All paths in the graph with labels:")
for source in G.nodes:
    for target in G.nodes:
        if source != target:
            # Get all simple paths
            paths = list(nx.all_simple_paths(G, source=source, target=target))
            for path in paths:
                # Convert node IDs to their labels
                print(path)
                # nodes_types = [G.nodes[node]["type"] for node in path]
                # node_labels = [G.nodes[node]["text"] for node in path]

                # # Find edges along the path and get their labels
                # edge_labels = [
                #     G.edges[path[i], path[i + 1]]["type"]
                #     for i in range(len(path) - 1)
                # ]

                
                # # Print the path with labels
                # if len(set(['PROBLEM', 'TEST', 'TREATMENT']).intersection(set(nodes_types)))>0:
                #     print(f"Path: {' -> '.join(node_labels)}")
                #     print(f"Path: {' -> '.join(nodes_types)}")
                #     print(f"Edges: {' -> '.join(edge_labels)}")
                print()

All paths in the graph with labels:
['E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7', 'T8']

['E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7', 'T8', 'E116']

['E103+E104', 'E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7']

['E103+E104', 'E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7', 'T8']

['E103+E104', 'E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7', 'T8', 'E116']

['E9', 'E119']

['T8', 'E116']

['E118', 'E117', 'E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7']

['E118', 'E117', 'E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7', 'T8']

['E118', 'E117', 'E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7', 'T8', 'E116']

['E118', 'E117']

['E101', 'E100+E82+E99']

['E117', 'E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7']

['E117', 'E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7', 'T8']

['E117', 'E102+E105+E106+E111

In [39]:
# Path 1: ['E118', 'E117', 'E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7', 'T8', 'E116']
# Path 2: ['E103+E104', 'E102+E105+E106+E111+E112+E113+E114+E115+E4+E5+E6+E7+E8+T4+T6+T7']
# Path 3: ['E9', 'E119']
# Path 4: ['E101', 'E100+E82+E99']
# Path 5: ['E108', 'E1+E107+E109+E110+E2+E3+E83+E95+E96+E97+E98+T2+T3+T5']

In [35]:
from openai import OpenAI
from lmformatenforcer import JsonSchemaParser
from pydantic import BaseModel
import re, json, os
from typing import Optional, Type, TypeVar

In [36]:
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://host.docker.internal:8000/v1"
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

In [279]:
"Extract clinical 'PROBLEM', 'TEST', 'TREATMENT' from the text and estimate the time of each event happened." + text 

"Extract clinical 'PROBLEM', 'TEST', 'TREATMENT' from the text and estimate the time of each event happened.HISTORY OF PRESENT ILLNESS :\nSaujule Study is a 77-year-old woman with a history of obesity and hypertension who presents with increased shortness of breath x 5 days .\nHer shortness of breath has been progressive over the last 2-3 years .\nShe has an associated dry cough but no fevers , chills , or leg pain .\nShe has dyspnea on exertion .\nShe ambulates with walker and a cane secondary to osteoarthritis .\nShe becomes short of breath just by getting up from her chair and can only walk 2-3 steps on a flat surface .\nShe feels light headed when getting up .\nHer shortness of breath and dyspnea on exertion has been progressive for the past several years .\nIt has not been sudden or acute .\nShe sleeps in a chair up right for the last 2 1/2 years secondary to osteoarthritis .\nShe has orthopnea as well but noparoxysmal nocturnal dyspnea .\nShe occasionally feels chest twinges whic

In [38]:
chat_response = client.chat.completions.create(
    model = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-14B',
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Extract clinical 'PROBLEM', 'TEST', 'TREATMENT' from the text and estimate the time of each event happened." + text},
    ]
)
print("Chat response:", chat_response.choices[0].message.content)

Chat response: Alright, I need to figure out how to extract 'PROBLEM', 'TEST', and 'TREATMENT' from the provided patient history. The user also wants the time each event occurred estimated.

First, looking for PROBLEMS. The patient is a 77-year-old woman with obesity and hypertension. The main presenting issue is increased shortness of breath for 5 days, but it's been progressing for 2-3 years. Also, she has a dry cough, orthopnea, dyspnea on exertion, and lower extremity edema with cellulitis episodes. She feels light-headed upon standing, has chest twinges, and sleeps upright.

So, the PROBLEMS would include her chronic conditions like obesity, hypertension, and the various symptoms like SOB, dyspnea, orthopnea, edema, chest twinges, etc. I should list each clearly.

Next, looking for TESTS. The text doesn't mention any specific tests she's had. It talks about her symptoms and management but not about any diagnostic procedures or lab tests. So, I'll note that no tests are mentioned i