In [10]:
import numpy as np
import pandas as pd
import os
import pydicom as dicom
from report_parser import parse_report
from collections import Counter, defaultdict

In [4]:
csv_file_path = '/crimea/mimic-cxr/mimic-cxr-map.csv'
images_path = '/crimea/mimic-cxr/_images'
reports_path = '/crimea/mimic-cxr/reports'

In [5]:
df =  pd.read_csv(csv_file_path, sep=',', header=0)
df = df.loc[df['dicom_is_available'],:]
df.reset_index()
display(df.head())

Unnamed: 0,subject_id,rad_id,dicom_id,dicom_is_available
0,70233355,52727485,3b9565f5-69ab0d33-1a9d2d1b-bb09c424-7f0243e6,True
1,70233355,52727485,8074bd10-62acdde0-3df2608b-13ca2322-09ce372c,True
2,70233355,53378012,3108d905-782ffdc0-209309e8-2413eeb4-6bfb958a,True
3,70233355,53378012,40eab5a8-31446771-08c6b024-2717a65c-41f8c74f,True
4,70233355,55587989,cef1a7ea-8c7df75c-41070128-7cdf5c89-23682e1b,True


In [6]:
image_files  = set(os.listdir(images_path))
report_files = set(os.listdir(reports_path))

print('images:  %6d' % len(image_files))
print('reports: %6d' % len(report_files))

images:   91664
reports: 206574


In [8]:
captioned = {}
for idx, row in df.iterrows():
    dicom_file  = str(row['dicom_id'])+'.dcm'
    report_file = str(row['rad_id'])+'.txt'
    if (dicom_file in image_files) and (report_file in report_files):
        captioned[row['rad_id']] = (dicom_file,report_file)
        
print(len(captioned))

40027


In [11]:
# Display a few notes
N = 15
first_n = dict(list(captioned.items())[:N])
for rad_id,(dicom_file,report_file) in first_n.items():

    dicom_path = os.path.join(images_path,dicom_file)
    plan = dicom.read_file(dicom_path, stop_before_pixels=False)
    parsed_report = parse_report(os.path.join(reports_path,report_file))

    print('===================================================')
    print('Patient ID:', plan.PatientID)
    if 'findings' in parsed_report:
        print(parsed_report['findings'])

    #break

Patient ID: 77631843
there is no focal consolidation or pneumothorax. there is a NAME left pleural effusion with underlying atelectasis, decreased since DATE. postsurgical changes in the left lung are stable. the cardiomediastinal silhouette is shifted to the left, unchanged since the prior exam and likely due to volume loss. the imaged upper abdomen is unremarkable. the bones are intact.
Patient ID: 4029
cardiomediastinal silhouette is stable. the heart is not enlarged. there is no focal consolidation, pleural effusion, or pneumothorax. no pulmonary edema. multilevel degenerative changes in the NAME are noted.
Patient ID: 73114826
the lungs are clear without consolidation or edema. there is no pleural effusion or pneumothorax. the cardiomediastinal silhouette is normal.
Patient ID: 71581571
the cardiac, mediastinal and hilar contours are normal. lungs are clear. no pleural effusion or pneumothorax is seen. no acute osseous abnormalities are detected.
Patient ID: 75309451
as compared t

In [12]:
# Sort rad_ids into train or test
N = 1500
cohort = list(captioned.items())[:N]

ind = int(0.70*N)

train_cohort = cohort[:ind]
test_cohort  = cohort[ ind:]

print('train: ', len(train_cohort))
print('test:  ', len(test_cohort))

train:  1050
test:   450


In [13]:
n = 3

In [16]:
# Build ngrams for inputs & outputs


LM = defaultdict(Counter)
for rad_id,(dicom_file,report_file) in train_cohort:

    dicom_path = os.path.join(images_path,dicom_file)
    plan = dicom.read_file(dicom_path, stop_before_pixels=False)
    parsed_report = parse_report(os.path.join(reports_path,report_file))

    if 'findings' in parsed_report:
        toks = parsed_report['findings'].replace('.', ' . ').split()
        padded_toks = ['<START>' for _ in range(n-1)] + toks + ['<END>']
        for i in range(len(padded_toks)-n+1):
            context = tuple(padded_toks[i:i+n-1])
            target = padded_toks[i+n-1]
            
            LM[context][target] += 1
            
print(sorted(LM.items(), key=lambda t:sum(t[1].values()))[-1])

(('<START>', '<START>'), Counter({'the': 160, 'pa': 71, 'there': 57, 'frontal': 57, 'as': 52, 'in': 36, 'ap': 30, 'heart': 22, 'a': 19, 'compared': 18, 'lung': 15, 'cardiac': 10, 'no': 10, 'lungs': 9, 'cardiomediastinal': 9, 'comparison': 8, 'right': 8, 'single': 8, 'NAME': 8, '<END>': 5, 'since': 5, 'interval': 5, 'moderate': 4, 'patient': 4, 'tip': 4, 'mild': 4, 'cardiac,': 3, 'bibasilar': 3, 'left-sided': 3, 'low': 3, 'right-sided': 3, 'increased': 3, 'dual': 2, 'stable': 2, 'again': 2, 'one': 2, '2': 2, 'endotracheal': 2, 'previously': 2, 'et': 2, 'severe': 2, 'portable': 2, 'study': 2, 'tracheostomy': 2, 'left': 2, 'air': 1, 'indwelling': 1, 'lordotic': 1, 'eventration': 1, 'significantly': 1, 'bilateral': 1, 'central': 1, 'persistent': 1, 'new': 1, 'borderline': 1, 'pacemaker': 1, 'two': 1, 'redemonstrated': 1, 'slightly': 1, 'technologist': 1, 'radiodense': 1, 'image': 1, 'NAME-pyloric': 1, 'substantially': 1, 'lower': 1, 'large': 1, 'dual-lead': 1, 'assessment': 1, 'consolidati

In [19]:
def sample(seq_so_far):
    #print(seq_so_far)
    last = tuple(seq_so_far[-n+1:])
    words,P = list(zip(*LM[last].items()))
    P = np.array(P) / sum(P)
    choice = np.random.choice(words, p=P)
    return choice
    #y = clf.predict(x)[0]
    #next_word = y_vect.translate(y)
    #return next_word
    
#seq = ['<START>', '<START>', 'cardiac', 'silhouette']
seq = ['<START>' for _ in range(n-1)]
current = seq[-1]
while current != '<END>' and len(seq)<100:
    next_word = sample(seq)
    #print(next_word)
    seq.append(next_word)
    current = next_word
    #break
    
print(len(seq))
print(seq)

12
['<START>', '<START>', 'pa', 'and', 'lateral', 'views', 'of', 'the', 'NAME', 'aorta', '.', '<END>']


<h1>Evaluation</h1>

In [30]:
# TODO: actual scoring functions
def bleu(ref_toks, pred_toks):
    return 0.5

def meteor(ref_toks, pred_toks):
    return 0.7

def cider(ref_toks, pred_toks):
    return 0.1

In [31]:
bleus = []
for rad_id,(dicom_file,report_file) in test_cohort:
    dicom_path = os.path.join(images_path,dicom_file)
    plan = dicom.read_file(dicom_path, stop_before_pixels=False)
    parsed_report = parse_report(os.path.join(reports_path,report_file))
    
    if 'findings' not in parsed_report:
        continue
     
    print(rad_id)
    print()
    
    # get reference report
    reference_toks = parsed_report['findings'].replace('.', ' . ').split()
    
    # get generated report using the above method
    generated_toks = ['<START>' for _ in range(n-1)]
    current = generated_toks[-1]
    while current != '<END>' and len(generated_toks)<100:
        next_word = sample(generated_toks)
        #print(next_word)
        generated_toks.append(next_word)
        current = next_word
        #break
    generated_toks = generated_toks[n-1:]
        
    print('generated')
    print(generated_toks)
    
    # Compute eval metrics
    B = bleu(reference_toks, generated_toks)
    bleus.append(B)
    
    break

58001075

generated
['in', 'comparison', 'with', 'study', 'of', 'DATE,', 'there', 'is', 'persistent', 'apparent', 'blunting', 'of', 'the', 'cardiac', 'silhouette', '.', 'no', 'pneumothorax', '.', 'old', 'right', 'lateral', 'rib', 'deformities', 'are', 'noted,', 'chronicity', 'indeterminate', '.', '<END>']
