In [1]:
%load_ext autoreload

%autoreload 2

In [2]:
import pandas as pd
import numpy as np
import os
from sentence_transformers import SentenceTransformer
import spacy
from tqdm import tqdm
from collections import defaultdict
from datasets import load_metric
from differ import diff_ratio
import re

In [3]:
tqdm.pandas()

In [4]:
os.getcwd()

'/Users/tom/phd/summariser/clin_sum'

In [43]:
all_docs = pd.read_json('../mimic_summ_data/mimic_3_val.json', lines=True)

In [44]:
all_docs.text = all_docs.text.apply(clean)

In [47]:
# remove excessively long documents???

In [None]:
all_docs['text_sents'] = all_docs.text.apply(lambda t: [s.text for s in nlp(t).sents  if len(s) > 3])

In [6]:
docs = docs[0:10]

In [7]:
pat = re.compile(r'(?:brief)?\n?\s?hospital course:?|SUMMARY OF HOSPITAL COURSE BY SYSTEMS:?', re.IGNORECASE)
def clean(s: str):
    s = s.replace('\n\n' ,'\n')\
         .replace(r'\s{2+}', ' ')\
         .replace(r'\t', ' ')
    return pat.sub('', s)

In [8]:
docs.summary = docs.summary.apply(clean)
docs.text = docs.text.apply(clean)

0      The patient was transferred to the\nIntensiv...
1    \nShe was admitted to ICU for close observatio...
2      The patient was taken to the operating room\...
3      (By systems including pertinent laboratory\n...
4    \n1. Cardiovascular:  The patient was transfer...
5    \n38 yo male with chronic ACTH dependence, asp...
6    \nASSESSMENT/PLAN [**9-17**]:\n79 year-old mal...
7    \nPt was admitted and underwent a right pigtai...
8    \n75-year-old man with Burkitt's lymphoma s/p ...
9    \nHe was taken to the operating room on [**10-...
Name: summary, dtype: object

In [13]:
model = SentenceTransformer('all-MiniLM-L6-v2')

In [14]:
nlp = spacy.load('en_core_web_md')



In [15]:
docs['summ_sents'] = docs.summary.apply(lambda t: [s.text for s in nlp(t).sents if len(s) > 3])
docs['text_sents'] = docs.text.apply(lambda t: [s.text for s in nlp(t).sents  if len(s) > 3])

In [16]:
# docs['summ_embed'] = docs.summ_sents.progress_apply(lambda sents: [model.encode(s) for s in sents])
docs['text_embed'] = docs.text_sents.progress_apply(lambda sents: [model.encode(s) for s in sents])

100%|██████████| 10/10 [00:06<00:00,  1.56it/s]
100%|██████████| 10/10 [02:03<00:00, 12.32s/it]


In [17]:
docs.to_pickle('doc_embeds.pickle')

In [None]:
docs = pd.read_pickle('doc_embeds.pickle')

In [18]:
# train 6 seperate LSTM models for prediciting 'top-line' extractive summaries.
sent_limits = [1,2,3,5,10,15]

# LSTM embedding model
- embeddings are 'fixed' and provided by s-bert (could also be fine-tuned)
- (bi-)LSTM ranker on top of sentence embeddings (w/ or w/o attn)

In [None]:
# closest matching sentences from a rouge-Lsum perspective??
# model training data is determined by number of sequenecs to extract..? 

In [19]:
metric = load_metric('rouge')
# provides 'oracle - rouge2' maximum that can be achieved by the model(s)
# for each sent in limited sent summs, find 'closest' matching extractive sentence and mark as 1, all others should be marked as 0.

In [20]:
docs.head(3)

Unnamed: 0,hadm_id,summary,text,summ_sents,text_sents,summ_embed,text_embed
0,124571,The patient was transferred to the\nIntensiv...,Radiology:CHEST (PORTABLE AP)\n 1) Possible s...,[The patient was transferred to the\nIntensive...,"[Radiology:CHEST (PORTABLE AP)\n 1) , Possibl...","[[-0.034001175, 0.009998081, -0.06825533, -0.0...","[[0.019646827, 0.09096582, -0.016166693, -0.01..."
1,161919,\nShe was admitted to ICU for close observatio...,Radiology:CT HEAD W/O CONTRAST\nKKgc MON [**21...,[\nShe was admitted to ICU for close observati...,"[Radiology:CT HEAD W/O CONTRAST\n, [**2138-5-1...","[[0.07051655, 0.03656692, 0.036904074, 0.10025...","[[0.07394422, 0.03841625, -0.025370654, -0.029..."
2,109365,The patient was taken to the operating room\...,Nursing/other:Report\nResp Care\n73 yo admitte...,[The patient was taken to the operating room\n...,[Nursing/other:Report\nResp Care\n73 yo admitt...,"[[0.042849362, 0.1184255, -0.07812561, -0.0100...","[[-0.012794435, 0.050156347, -0.011502372, -0...."


In [21]:
sent_limd_sums = defaultdict(list)
for lim in sent_limits:
    docs[f'summ_lim_{lim}'] = docs.summ_sents.apply(lambda sents: sents[:lim])

In [32]:
first_sent = docs.summ_lim_1

In [22]:
def _parse_score(lvl, scores):
    return (lvl, scores[lvl].mid.precision, scores[lvl].mid.recall, scores[lvl].mid.fmeasure)

In [23]:
for lim in sent_limits:
    col = f'summ_lim_{lim}'
    preds_srs = []
    for summ_sents, text_sents in zip(docs[col], docs.text_sents):
        sim_text_sent_idxs = []
        sents_to_compare = text_sents
        for summ_sent in summ_sents:
            sents_to_compare = [s if i not in sim_text_sent_idxs else ''
                                for i, s in enumerate(text_sents)]
            ratios = [diff_ratio(summ_sent, sent)[0] for sent in sents_to_compare]
            max_ratio_sents = np.where(ratios == np.amax(ratios))[0]
            sim_text_sent_idxs.extend(max_ratio_sents)
        preds = np.zeros(len(text_sents))
        for i in sim_text_sent_idxs:
            preds[i] = 1
        preds_srs.append(preds)
    docs[f'preds_lim_{lim}'] = preds_srs

In [24]:
docs = docs.drop(['summary', 'text', 'summ_embed'], axis=1)

In [None]:
docs.to_pickle('doc_embeds.pickle')

In [None]:
# compute max ROUGE scores from 'oracle' model

In [25]:
def _parse_score(lvl, scores):
    return (lvl, scores[lvl].mid.precision, scores[lvl].mid.recall, scores[lvl].mid.fmeasure)

In [28]:
extractive_score_ceil = {}
for lim in sent_limits:
    text_sums = []
    for sent_idxs, text_sents in zip([np.argwhere(p == 1) for p in docs[f'preds_lim_{lim}']], 
                                     docs.text_sents):
        text_sum = []
        for i in sent_idxs:
            t = text_sents[i[0]]
            if t not in text_sum:
                text_sum.append(t)
        text_sums.append(''.join(text_sum))
    metric.add_batch(predictions=text_sums, references=docs[f'summ_lim_{lim}'].str.join('').tolist())
    scores = metric.compute()
    extractive_score_ceil[lim] = _parse_score('rouge1', scores), _parse_score('rouge2', scores), _parse_score('rougeLsum', scores)

In [None]:
extractive_score_ceil

In [61]:
import torch
import torch.nn as nn
from torch.nn import LSTM
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [53]:
torch.manual_seed(1)

<torch._C.Generator at 0x7f81d1716fa8>

In [54]:
inputs = sorted([torch.tensor(t) for t in docs.text_embed], key=lambda t: t.shape[0], reverse=True)
outputs = [torch.tensor(d_preds) for d_preds in docs.preds_lim_15]

In [55]:
in_lens = [len(i) for i in inputs]

In [56]:
outputs = torch.nn.utils.rnn.pad_sequence(outputs, batch_first=True)

In [57]:
inputs = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True)

In [62]:
class LSTMClf(nn.Module):
    def __init__(self):
        super(LSTMClf, self).__init__()
        self.model = LSTM(384, 50, num_layers=2, batch_first=True, dropout=0.5)
        self.fc = nn.Linear(50, 1)
    
    def forward(self, X, X_lens):
        X = torch.nn.utils.rnn.pack_padded_sequence(X, X_lens, batch_first=True)
        X, h = self.model(X)
        X, _ = torch.nn.utils.rnn.pad_packed_sequence(X, batch_first=True)
        X = self.fc(X)
        return X

In [63]:
model = LSTMClf()

In [64]:
criterion = nn.BCEWithLogitsLoss()
optim = torch.optim.Adam(model.parameters(), lr=0.1)

In [65]:
model.train()

LSTMClf(
  (model): LSTM(384, 50, num_layers=2, batch_first=True, dropout=0.5)
  (fc): Linear(in_features=50, out_features=1, bias=True)
)

In [66]:
outputs.shape

torch.Size([10, 4540])

In [67]:
running_loss = []
for epoch in tqdm(range(10)):
    logits = model(inputs, in_lens)
    loss = criterion(logits.squeeze(), outputs)
    loss.backward()
    running_loss.append(loss.item())
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    torch.nn.utils.clip_grad_norm_(parameters, 0.25)
    optim.step()
    optim.zero_grad()

100%|██████████| 10/10 [06:45<00:00, 40.58s/it]


In [68]:
running_loss

[0.7510637823484083,
 0.6069032915121987,
 0.5673535308240203,
 0.5310154317266057,
 0.4953841878154565,
 0.4625271369234581,
 0.43124800153486925,
 0.4020194065181891,
 0.3748356447783237,
 0.3489772009557433]