In [None]:
from huggingface_hub import notebook_login
# Open the Hugging Face Hub login prompt inside the notebook.
# NOTE(review): the dataset and checkpoints used below look public — confirm whether login is actually required.
notebook_login()

In [None]:
# One-time setup: uncomment on a fresh environment (e.g. Colab) where these packages are missing.
# Prefer `%pip install ...` over `!pip install ...` so the install targets the running kernel's environment.
# !pip install datasets
# !pip install evaluate

In [1]:
from datasets import load_dataset
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import evaluate
import torch
import pandas as pd
import numpy as np

In [2]:
# Corpora tried earlier, kept here for quick switching. Each option sets the same three names:
#   ds     - the loaded dataset
#   seq1id - column fed to the model as input
#   seq2id - column used as the reference ("ground truth") text
# ds = load_dataset("EdinburghNLP/xsum", trust_remote_code=True, streaming=True)
# seq1id = 'document'
# seq2id = 'summary'

# ds = load_dataset("knkarthick/dialogsum", streaming=True)
# seq1id = 'dialogue'
# seq2id = 'summary'

# Active corpus: doctor/patient dialogues paired with structured clinical notes.
seq1id, seq2id = 'dialogue', 'section_text'
# Loaded without streaming: later cells index rows by position and build a DataFrame from the split.
ds = load_dataset(
    "har1/MTS_Dialogue-Clinical_Note",
    streaming=False,
)

In [None]:
# Display the train split (bare last expression, so the notebook shows its repr).
ds['train']

In [None]:
# Earlier ways of grabbing a single example, kept for reference:
# sample = ds['train'][0]
# sample = next(iter(ds['train']))
# Keep one iterator around so every `next(ds_iter)` in later cells advances to a new example.
ds_iter = iter(ds['train'])

In [None]:
# Pull the next example and show its field names, the input text and the
# reference text, each separated by a blank line.
sample = next(ds_iter)
print(sample.keys(), '', sample[seq1id], '', sample[seq2id], sep='\n')

In [None]:
# summarizer = pipeline('summarization', model="sshleifer/distilbart-cnn-12-6") # Bad
# summarizer = pipeline('summarization', model="sshleifer/distilbart-xsum-12-1") # Bad

# summarizer = pipeline('summarization', model='facebook/bart-large-cnn') # Good
# summarizer = pipeline('summarization', model='google/pegasus-large') # Okay
# summarizer = pipeline('summarization', model = 'facebook/bart-large-xsum') # Okay
# summarizer = pipeline('summarization', model="google/pegasus-xsum") # Good

In [None]:
# Device index handed to `pipeline(...)`: 0 selects the first CUDA GPU, -1 falls back to CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

In [None]:
# Informal quality notes per candidate list (in list order) from earlier spot checks.
# on xsum dataset: good, okay, okay, good
# on dialog dataset: okay, bad, bad, bad
# summarizers = ['facebook/bart-large-cnn', 'google/pegasus-large', 'facebook/bart-large-xsum', 'google/pegasus-xsum']

# on dialogsum dataset: good, not good
# summarizers = ['dtruong46me/train-bart-base', 'gauravkoradiya/T5-Finetuned-Summarization-DialogueDataset']
# summarizers = ['Ketan3101/ConvoBrief'] # can't run
# summarizers = ['akhil033/Flan-T5-Conversation-Summarizer'] # can't run
summarizers = ['dtruong46me/train-bart-base']

print('Ground truth: ' + sample[seq2id] + '\n')

# Compare checkpoints on the current `sample`, loading one pipeline at a time
# and dropping the reference to it before the next one is built.
for s in summarizers:
    summarizer = pipeline('summarization', model=s, device=device)
    # truncation=True clips inputs that exceed the checkpoint's maximum input length
    # instead of sending an over-long sequence into the model.
    summary = summarizer(sample[seq1id], truncation=True)
    print(s + ': ' + summary[0]['summary_text'] + '\n')
    del summarizer

In [None]:
# Build the single summarization pipeline reused by the following cell.
# Alternative checkpoint, kept for quick switching:
# summarizer = pipeline('summarization', model='dtruong46me/train-bart-base', device=device)
summarizer = pipeline('summarization', model='facebook/bart-large-cnn', device=device)

In [None]:
# Advance to the next example, then print the input text, the reference text
# and the pipeline's prediction. Re-running this cell steps through the dataset.
sample = next(ds_iter)
print(sample[seq1id] + '\n')
print('Ground truth: ' + sample[seq2id] + '\n')
# truncation=True clips inputs that exceed the checkpoint's maximum input length
# instead of sending an over-long sequence into the model.
summary = summarizer(sample[seq1id], truncation=True)
print('Prediction: ' + summary[0]['summary_text'] + '\n')

# Clinical notes summarization

In [3]:
# Materialize the train split as a DataFrame for tabular exploration.
ds_df = pd.DataFrame(ds['train'])
# Preview the first rows only; never dump the whole frame.
ds_df.head()

Unnamed: 0,ID,section_header,section_text,dialogue
0,0,GENHX,"Symptoms: no fever, no chills, no cough, no co...",Doctor: What brings you back into the clinic t...
1,1,GENHX,"Symptoms: sudden onset headache, blurry vision...",Doctor: How're you feeling today? \nPatient: ...
2,2,GENHX,Symptoms: itching.\nDiagnosis: condylomas.\nHi...,"Doctor: Hello, miss. What is the reason for yo..."
3,3,MEDICATIONS,Symptoms: N/A.\r\nDiagnosis: N/A.\r\nHistory o...,Doctor: Are you taking any over the counter me...
4,4,CC,"Symptoms: Burn, right arm.\r\nDiagnosis: N/A.\...","Doctor: Hi, how are you? \nPatient: I burned m..."


In [4]:
# Number of examples per note section type, most frequent first.
ds_df['section_header'].value_counts()

section_header
FAM/SOCHX        373
GENHX            302
PASTMEDICALHX    122
CC                81
PASTSURGICAL      71
ROS               71
ALLERGY           64
MEDICATIONS       61
ASSESSMENT        38
EXAM              24
DIAGNOSIS         20
DISPOSITION       17
PLAN              14
EDCOURSE          11
IMMUNIZATIONS      9
IMAGING            7
GYNHX              6
PROCEDURES         4
OTHER_HISTORY      3
LABS               3
Name: count, dtype: int64

In [5]:
# List every example whose note section is IMAGING.
ds_df.loc[ds_df['section_header'].eq('IMAGING')]
# Scratch lookups from earlier exploration:
# ds_df.iloc[88]
# ds['train'][41]

Unnamed: 0,ID,section_header,section_text,dialogue
219,219,IMAGING,Symptoms: N/A\r\nDiagnosis: N/A\r\nHistory of ...,"Doctor: So, I am looking at his x ray and it d..."
576,576,IMAGING,"Symptoms: Chest pain, difficulties with breath...",Doctor: Your chest x ray showed diffuse pulmon...
650,650,IMAGING,Symptoms: N/A\r\nDiagnosis: N/A\r\nHistory of ...,Doctor: I am looking at her x ray report and s...
774,774,IMAGING,Symptoms: N/A\r\nDiagnosis: The patient's CBC ...,"Doctor: So, we looked at your previous blood w..."
1088,1088,IMAGING,Symptoms: N/A\r\nDiagnosis: N/A\r\nHistory of ...,Doctor: Are you finished with your cancer trea...
1143,1143,IMAGING,Symptoms: N/A\r\nDiagnosis: N/A\r\nHistory of ...,Doctor: I have reviewed your x rays from your ...
1298,97,IMAGING,Symptoms: N/A\nDiagnosis: Sinus tachycardia\nH...,"Doctor: Well, I have your E K G report, shows ..."


In [6]:
# Row 1298 is the last IMAGING record listed by the section_header filter
# (the short E K G exchange); use it as the working example from here on.
sample = ds['train'][1298]
# Field names, input dialogue and reference note, separated by blank lines.
print(sample.keys(), '', sample[seq1id], '', sample[seq2id], sep='\n')

dict_keys(['ID', 'section_header', 'section_text', 'dialogue'])

Doctor: Well, I have your E K G report, shows you have sinus tachycardia. In other words, your heart is beating faster than normal due to rapid firing of sinus node.
Patient: Okay.
Doctor: Well, there are no S T changes.
Patient: Hm.

Symptoms: N/A
Diagnosis: Sinus tachycardia
History of Complaint: N/A
Plan of Action: No acute ST changes



In [None]:
# Text-to-text pipeline used by the next cell. It must be defined here: with both
# candidates commented out, the following `t2t(...)` call raises NameError on a fresh run.
# Active choice is the checkpoint noted as "pretty dang good" in earlier spot checks.
t2t = pipeline("text2text-generation", model="har1/HealthScribe-Clinical_Note_Generator", device=device)
# Alternative checkpoint, kept for quick switching:
# t2t = pipeline('text2text-generation', model='facebook/bart-base', device=device)

In [None]:
# Generate a note for the current `sample`. truncation=True clips inputs that exceed
# the checkpoint's maximum input length instead of sending an over-long sequence into the model.
summary = t2t(sample[seq1id], truncation=True)

In [None]:
# Show the reference note and the generated note one after the other,
# each under its own heading.
print(
    'Ground truth:\n',
    sample[seq2id] + '\n\n',
    'Prediction:\n',
    summary[0]['generated_text'],
    sep='\n',
)

In [16]:
# Manual pipeline: tokenize -> generate -> decode, without the `pipeline` helper.

tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-cnn')
model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large-cnn')

# `text_target` makes the tokenizer also encode the reference note and return it as `labels`.
# `tokenized_text` keeps all three entries (input_ids, attention_mask, labels) because later
# cells inspect it and feed it to a forward pass. truncation=True clips over-long sequences
# to the tokenizer's maximum length.
tokenized_text = tokenizer(
    sample[seq1id],
    text_target=sample[seq2id],
    return_tensors='pt',
    truncation=True,
)

# Generation must only see the source side. `labels` is a training-time argument holding the
# reference answer; generate() does not need it, and how unexpected forward kwargs are handled
# differs between transformers versions, so pass the encoder inputs explicitly.
output = model.generate(
    input_ids=tokenized_text['input_ids'],
    attention_mask=tokenized_text['attention_mask'],
    num_beams=4,
    max_length=150,
    early_stopping=True,
)
decoded = tokenizer.decode(output[0], skip_special_tokens=True)
print(decoded)

# Release only the tokenizer: later cells still call `model`, and deleting it here
# makes them fail with NameError when the notebook is run top to bottom.
del tokenizer



Doctor: Well, I have your E K G report, shows you have sinus tachycardia. In other words, your heart is beating faster than normal due to rapid firing of sinus node. Patient: Hm. Doctor: There are no S T changes.


In [18]:
# Inspect the tokenizer output: token ids, attention mask, and the encoded reference (`labels`).
tokenized_text

{'input_ids': tensor([[    0, 41152,    35,  2647,     6,    38,    33,   110,   381,   229,
           272,   266,     6,   924,    47,    33, 10272,   687,   326, 35600,
          6940,   493,     4,    96,    97,  1617,     6,   110,  1144,    16,
          4108,  3845,    87,  2340,   528,     7,  6379,  5834,     9, 10272,
           687, 37908,     4, 50118, 18276,  4843,    35,  8487,     4, 50118,
         41152,    35,  2647,     6,    89,    32,   117,   208,   255,  1022,
             4, 50118, 18276,  4843,    35,   289,   119,     4,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[    0, 46994, 47629,    35,   234,    73,   250, 50118, 29038, 11244,
         13310,    35,  8356,   687,   326, 35600,  6940,   493, 50118, 38261,
             9

In [None]:
# Single forward pass over the tokenized example (requires `model` and `tokenized_text`
# from earlier cells). Because `tokenized_text` carries `labels`, the inputs include the
# reference note alongside the dialogue.
# Call the module itself rather than `.forward(...)` so PyTorch's `__call__` machinery
# (registered hooks) runs, and disable autograd since this is inspection only.
with torch.no_grad():
    forward_output = model(**tokenized_text)
forward_output

In [12]:
# Load the facebook/bart-large-cnn sequence-to-sequence checkpoint into `model`.
model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large-cnn')

In [13]:
# Check which concrete class AutoModelForSeq2SeqLM resolved to for this checkpoint.
type(model)

transformers.models.bart.modeling_bart.BartForConditionalGeneration

In [19]:
# List the entries produced by the tokenizer; `labels` is present because the
# tokenizer was called with `text_target`.
tokenized_text.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])