In [10]:
from time import gmtime, strftime
import numpy as np
import pandas as pd
import os
import dicom
from report_parser import parse_report
import pickle 
import tqdm
import random
from collections import Counter, defaultdict

In [3]:
# Load the train/test splits (tab-separated files under `data_dir`) and
# preview the shape and first rows of each.
data_dir = '/scratch/wboag/2019/cxr/cxr-baselines/camera_ready/data'

def _load_split(split_name):
    """Read `<data_dir>/<split_name>.tsv` into a DataFrame."""
    return pd.read_csv(os.path.join(data_dir, split_name + '.tsv'), sep='\t')

train_df = _load_split('train')
test_df  = _load_split('test')

for _split_df in (train_df, test_df):
    print(_split_df.shape)
    display(_split_df.head())

(228136, 7)


Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,Unnamed: 0.1.1,subject_id,rad_id,dicom_id,dicom_is_available
0,0,0,2,70233355,53378012,3108d905-782ffdc0-209309e8-2413eeb4-6bfb958a,True
1,1,1,3,70233355,53378012,40eab5a8-31446771-08c6b024-2717a65c-41f8c74f,True
2,2,2,6,78564939,51423061,fc601540-ae89d087-3589ac06-85224a6b-bb5960ce,True
3,3,3,7,78564939,51423061,9956b6ce-67a4e84b-6038ce80-52428d83-04d83f25,True
4,4,4,8,71322,51527637,d5072bc4-bb422de8-97f3973a-0d8e5ae0-7c52ac3b,True


(99145, 7)


Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,Unnamed: 0.1.1,subject_id,rad_id,dicom_id,dicom_is_available
0,0,0,27,73451872,55507110,f1a7e903-618a45fe-84eb71e2-73901894-a689d584,True
1,1,1,28,73451872,55507110,523dbd29-f0c5d7eb-09635cf1-1a7de126-44622b1c,True
2,2,2,37,68870,51526655,78ecaf71-9fdb0b43-b0134402-8c5e739f-2c6c0ea2,True
3,3,3,38,68870,51526655,14089000-1023e4ed-157da1b0-f14f1dcd-7eaf3cb2,True
4,4,4,39,68870,57395479,ceb36d05-686e9404-43dfdc4f-e050bf09-89b8d71d,True


In [4]:
# Neighbour lists keyed by test dicom_id; the generation loop indexes this
# with every test dicom_id and passes each list to `fit`.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
neighbors_path = '/crimea/wboag/2019/cxr/camera_ready_top100.pkl'
with open(neighbors_path, 'rb') as fh:
    neighbors = pickle.load(fh)

print(len(neighbors))

99145


In [5]:
# dicom_id -> rad_id for every training row; if a dicom_id repeats,
# the last rad_id seen wins.
rad_lookup = {dicom: rad for dicom, rad in zip(train_df['dicom_id'], train_df['rad_id'])}

# Preview the first five entries (insertion order).
{key: rad_lookup[key] for key in list(rad_lookup)[:5]}

{'397b3697-73db8d26-149babd2-a0452bd3-e6f85f4d': 50586031,
 '521d3636-a277fea0-7c075fee-5cd06409-314809ee': 55830882,
 '6e5a3f22-489c2c40-dfbe9d42-3286db49-9f27ee33': 51891019,
 '85cf089b-08c8b9de-5fe26672-ff11d29c-c74d4156': 52327947,
 'e5f4b2a8-32813d3e-a1c75b35-906cbe86-2dcd36e0': 59272024}

In [6]:
# Directory of free-text reports; `fit` reads `<reports_path>/<rad_id>.txt`.
reports_path = "/crimea/mimic-cxr/reports"

In [7]:
# Build an n-gram language model from the reports of a test image's
# neighbours (the dicom_id lists stored in `neighbors`).

START = '<START>'
END   = '<END>'

def fit(dicom_ids, n=3, lookup=None, reports_dir=None, parser=None):
    """Count n-gram transitions over the 'findings' sections of neighbour reports.

    Parameters
    ----------
    dicom_ids : iterable of dicom_ids whose reports form the training corpus.
    n : n-gram order (contexts are ``n - 1`` tokens long); must be >= 1.
    lookup : mapping dicom_id -> rad_id; defaults to the notebook's ``rad_lookup``.
    reports_dir : directory containing ``<rad_id>.txt``; defaults to ``reports_path``.
    parser : callable(path) -> dict of report sections; defaults to ``parse_report``.

    Returns
    -------
    defaultdict(Counter) mapping each ``(n-1)``-token context tuple to a Counter
    of next-token weights. Each report is left-padded with ``n-1`` START tokens
    and terminated with END before counting.

    Raises
    ------
    ValueError if ``n < 1``; KeyError if a dicom_id is missing from ``lookup``.
    """
    if n < 1:
        raise ValueError('n must be >= 1, got %r' % (n,))
    # Resolve defaults lazily so the notebook globals are only required when
    # the caller does not supply its own.
    if lookup is None:
        lookup = rad_lookup
    if reports_dir is None:
        reports_dir = reports_path
    if parser is None:
        parser = parse_report

    LM = defaultdict(Counter)
    for dicom_id in dicom_ids:
        rad_id = lookup[dicom_id]
        parsed_report = parser(os.path.join(reports_dir, '%s.txt' % rad_id))

        # Reports without a findings section contribute nothing.
        if 'findings' not in parsed_report:
            continue

        # Split periods off as standalone tokens so sentence ends are modelled.
        toks = parsed_report['findings'].replace('.', ' . ').split()
        padded_toks = [START] * (n - 1) + toks + [END]
        for i in range(len(padded_toks) - n + 1):
            context = tuple(padded_toks[i:i + n - 1])
            target = padded_toks[i + n - 1]

            # TODO: weight by image similarity (e.g. sim_score(img1, img2));
            # currently every neighbour counts equally.
            sim = 1

            LM[context][target] += sim
    return LM

In [8]:
# Sample the next token of a report being generated.

def sample(seq_so_far, lm=None, order=None, rng=None):
    """Draw the next token given the tokens generated so far.

    The token is drawn with probability proportional to its weight after the
    last ``order - 1`` tokens of ``seq_so_far`` in ``lm``.

    Parameters
    ----------
    seq_so_far : sequence of tokens generated so far (including START padding).
    lm : mapping context tuple -> Counter of next-token weights (as built by
        ``fit``); defaults to the notebook global ``LM``.
    order : n-gram order; defaults to the notebook global ``n``.
    rng : object with a numpy-style ``choice(a, p=...)``; defaults to the
        global ``np.random`` state.

    Returns
    -------
    The sampled token object, taken directly from ``lm`` (not a numpy scalar).

    Raises
    ------
    ValueError if the context has no observed continuations.
    """
    if lm is None:
        lm = LM
    if order is None:
        order = n
    if rng is None:
        rng = np.random

    # Unigram models have an empty context; `seq[-0:]` would wrongly return
    # the whole sequence, so handle order == 1 explicitly.
    context = tuple(seq_so_far[-(order - 1):]) if order > 1 else ()

    # .get avoids inserting empty Counters into a defaultdict model.
    successors = lm.get(context)
    if not successors:
        raise ValueError('no continuations observed for context %r' % (context,))

    words = list(successors)
    weights = np.array([successors[w] for w in words], dtype=float)
    # Pick an index rather than a word so the caller gets the original token
    # back instead of a numpy string.
    idx = rng.choice(len(words), p=weights / weights.sum())
    return words[idx]

In [13]:
# Generate one report per test image from an n-gram model fit to its neighbours.
n = 3
MAX_TOKENS = 100  # cap on len(generated_toks), START padding included
SEED = 0          # fixed seed so the sampled reports are reproducible
np.random.seed(SEED)

generated_reports = {}
for pred_dicom in tqdm.tqdm(test_df.dicom_id):

    # Build the n-gram model from this image's neighbours; `sample` reads
    # the global `LM` by default.
    nn = neighbors[pred_dicom]
    LM = fit(nn, n=n)

    # Sample token by token (each with probability proportional to how often
    # it follows the previous n-1 tokens) until END or the length cap.
    generated_toks = [START] * (n - 1)
    current = START
    while current != END and len(generated_toks) < MAX_TOKENS:
        current = sample(generated_toks)
        generated_toks.append(current)

    # Strip the START padding and the trailing END marker (the original
    # `generated_toks[:-1]` discarded its result, leaving END in the output).
    generated_toks = generated_toks[n - 1:]
    if generated_toks and generated_toks[-1] == END:
        generated_toks = generated_toks[:-1]

    # Store the generated report as a single space-joined string.
    generated_reports[pred_dicom] = ' '.join(generated_toks)


100%|██████████| 99145/99145 [1:06:10<00:00, 24.97it/s]


In [14]:
print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

# Write one row per test dicom: dicom_id<TAB>generated text, sorted by dicom_id.
# Tokens come from str.split(), so the generated text contains no tabs/newlines.
pred_dir = 'output'
os.makedirs(pred_dir, exist_ok=True)  # avoid FileNotFoundError on a fresh checkout
pred_file = os.path.join(pred_dir, '%d-gram.tsv' % n)
print(pred_file)
with open(pred_file, 'w', encoding='utf-8') as f:
    print('dicom_id\tgenerated', file=f)
    for dicom_id, generated in sorted(generated_reports.items()):
        print('%s\t%s' % (dicom_id, generated), file=f)

print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

2019-11-13 08:30:26
output/3-gram.tsv
2019-11-13 08:30:26
