## Create sankey diagrams

In [36]:
data_dir = 'd:/data/covid/tajson'
left_category = 'Diagnosis'
right_category = 'MedicationName'
left_ontology = {
 'C5203670': ('COVID-19',0),
 'C3714514': ('infection',1),
 'C0011065': ('death',2),
 'C0042769': ('viral infections',1),
 'C1175175': ('SARS',3),
 'C0009450': ('infectious disease',1),
 'C0006826': ('CA',4), #cardiac arrest (остановка сердца :0)
 'C0221423': ('illnesses',5),
 'C0035222': ('ARDS',6), #Acute respiratory distress syndrome (типа воспаление легких)
 'C0011849': ('diabetes',7),
 'C0032285': ('pneumonia',8),
 'C0021400': ('influenza',9)
}
right_ontology = {
 'C0020336': ('hydroxychloroquine', 0),
 'C0008269': ('chloroquine', 1),
 'C1609165': ('Tocilizumab', 2),
 'C4726677': ('remdesivir', 3),
 'C0052796': ('azithromycin', 4),
 'C0674432': ('lopinavir', 5),
 'C0292818': ('ritonavir', 5),
 'C0034386': ('quarantine', 6),
 'C0011777': ('dexamethasone', 7),
 'C0035525': ('ribavirin', 8),
 'C0021641': ('insulin', 9),
 'C0042866': ('vitamin D', 10),
}

In [49]:
def getmap(ontology):
    omap = {}
    for c,(n,i) in ontology.items():
        if i not in omap.keys():
            omap[i] = n
    return omap

left_map = getmap(left_ontology)
right_map = getmap(right_ontology)

In [37]:
import json
import os
import glob
import numpy as np
import tqdm

In [57]:
def load(fn):
    with open(fn) as f:
        return json.load(f)

def get_onto(lst,onto='UMLS'):
    try:
        return [ x['id'] for x in lst['links'] if x['dataSource']==onto][0]
    except:
        return ""

snd = lambda x: x[1]
lleft = max(map(snd,left_ontology.values()))+1
lright = max(map(snd,right_ontology.values()))+1

matrix = np.zeros(shape=(lleft,lright))

for fn in tqdm.tqdm(glob.glob(data_dir+'/*.json')):
    #print(f" + Processing {fn}")
    js = load(fn)
    for p in js.values():
        left_set = set(get_onto(t) for t in p['entities'] if t['category']==left_category) - set([''])
        right_set = set(get_onto(t) for t in p['entities'] if t['category']==right_category) - set([''])
        # print(left_set,right_set)
        for l in left_set & set(left_ontology.keys()):
            for r in right_set & set(right_ontology.keys()):
                matrix[left_ontology[l][1],right_ontology[r][1]]+=1

100%|██████████| 626/626 [02:28<00:00,  4.21it/s]


In [44]:
matrix

array([[3.100e+03, 1.434e+03, 1.105e+03, 1.652e+03, 1.036e+03, 2.313e+03,
        1.136e+03, 4.430e+02, 2.830e+02, 2.650e+02, 3.900e+02, 0.000e+00],
       [9.510e+02, 6.000e+02, 3.710e+02, 6.100e+02, 3.390e+02, 7.160e+02,
        7.000e+02, 2.080e+02, 2.820e+02, 1.130e+02, 2.290e+02, 0.000e+00],
       [5.260e+02, 2.010e+02, 2.450e+02, 2.870e+02, 2.130e+02, 2.960e+02,
        1.810e+02, 1.020e+02, 8.000e+01, 4.700e+01, 6.800e+01, 0.000e+00],
       [8.600e+01, 6.600e+01, 2.300e+01, 9.200e+01, 2.600e+01, 1.710e+02,
        7.400e+01, 1.300e+01, 1.270e+02, 5.000e+00, 8.000e+00, 0.000e+00],
       [9.800e+01, 3.300e+01, 3.600e+01, 3.500e+01, 3.900e+01, 4.200e+01,
        3.000e+00, 3.400e+01, 1.400e+01, 7.000e+00, 2.200e+01, 0.000e+00],
       [2.440e+02, 6.500e+01, 1.780e+02, 1.200e+02, 1.020e+02, 1.740e+02,
        6.300e+01, 5.700e+01, 5.600e+01, 4.200e+01, 2.500e+01, 0.000e+00],
       [2.180e+02, 6.500e+01, 2.110e+02, 1.150e+02, 1.040e+02, 1.850e+02,
        4.000e+00, 8.300e+01, 3.

In [55]:
import plotly.graph_objects as go

threshold = 100
highlite1 = [0]
highlite2 = set()

all_nodes = list(left_map.values()) + list(right_map.values())
source_indices = list(range(len(left_map)))
target_indices = list(range(len(left_map),len(left_map)+len(right_map)))

s, t, v, c = [], [], [], []

for i in left_map.keys():
    for j in right_map.keys():
        if matrix[i,j] > threshold:
            s.append(i)
            t.append(len(left_map)+j)
            v.append(matrix[i,j])
            c.append('magenta' if j in highlite2 or i in highlite1 else 'lightgrey')

fig = go.Figure(data=[go.Sankey(
    # Define nodes
    node = dict(
    pad = 40,
    thickness = 40,
    line = dict(color = "black", width = 1.0),
    label =  all_nodes,
    #color =  "blue"
    ),

    # Add links
    link = dict(
    source =  s,
    target =  t,
    value =  v,
    color = c
))])

fig.update_layout(title_text='Diagram',
                font_size=13)
fig.show()
