## Load Model

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataloader import *
from utils import *
from model_vit_bert import ViTConfigCustom, ViTModelCustom, CustomVEDConfig, CustomVisionEncoderDecoder
from training_script_vit_bert import LightningModel

# models: Encoder    
encoder = ViTModelCustom(config=ViTConfigCustom(hidden_size=576), pretrain_4k='vit4k_xs_dino', freeze_4k=True)

# decoder
decoder_model_name="emilyalsentzer/Bio_ClinicalBERT"
decoder = AutoModelForCausalLM.from_pretrained(decoder_model_name, is_decoder=True, add_cross_attention=True)
tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)

# encoder decoder model
model=CustomVisionEncoderDecoder(config=CustomVEDConfig(),encoder=encoder, decoder=decoder)
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

lightning_model = LightningModel(model, tokenizer, model_lr=1e-2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt="/scratch/ss4yd/logs_only_vit_bert/my_model/version_10/checkpoints/epoch=9-val_loss=0.89-step=5000.00.ckpt"
lightning_model.load_state_dict(torch.load(ckpt,map_location=device)['state_dict'])
lightning_model.eval()

len(list(*decoder.bert.encoder.children())[:-2])

# of Patches: 196
Loading Pretrained Local VIT model...
Done!
Freezing Pretrained Local VIT model
Done


Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertLMHeadModel were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.7.crossattention.self.value.bias', 'bert.encoder.layer.2.crossattention.self.query.weight', 'bert.encoder.layer.6.crossattention.self.

10

## Load Data

In [2]:
import pandas as pd
df_path='../new_lstm_decoder/data_files/prepared_prelim_data_tokenized_cls256_pathcap_thumb_newsent.pickle'
df=pd.read_pickle(df_path)

df=df[df.dtype=='test']

## Examples

In [3]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-15TU5-1426
Actual Note: 
 this is a heart - left ventricle tissue from a male patient and it has 2 pieces




Generated Note: 
 this is a heart - left ventricle tissue from a male patient and it has 2 pieces, no significant ischemic changes 


In [4]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-14BIN-1426
Actual Note: 
 this is a ovary tissue from a female patient and it has 2 pieces; atrophic with corpora albicans
Generated Note: 
 this is a ovary tissue from a female patient and it has 2 pieces, typical post menopausal atrophy 


In [5]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-1AX9K-0326
Actual Note: 
 this is a artery - coronary tissue from a male patient and it has 2 pieces; large [50%] intimal plaque in 1 piece, none in other; 60% is external fat
Generated Note: 
 this is a adipose - visceral ( omentum ) tissue from a male patient and it has 6 pieces, up to 10x6mm ; 


In [6]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-15CHS-0426
Actual Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces, no adherent fat/ atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces, no adherent fat 


In [7]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-16BQI-1026
Actual Note: 
 this is a breast - mammary tissue tissue from a female patient and it has 2 pieces; 90% fat with scattered atrophic ducts, no lobules
Generated Note: 
 this is a adipose - subcutaneous tissue from a male patient and it has 2 pieces ; 10 % fibrovascular content 


In [8]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-13111-0226
Actual Note: 
 this is a thyroid tissue from a male patient and it has 2 pieces; multifocal mild stromal fibrosis compromising thyroid follicles
Generated Note: 
 this is a thyroid tissue from a male patient and it has 2 pieces, no abnormalities 


In [9]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-ZTX8-1126
Actual Note: 
 this is a testis tissue from a male patient and it has 2 pieces
Generated Note: 
 this is a testis tissue from a male patient and it has 2 pieces ; spermatogenesis is present 


## Effect of BEAM Size

### BEAM Size=1

In [10]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=1, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces ; 1 with 40 % and 40 % external fat ; other fibrofatty plaque 


### BEAM Size=2

In [11]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces ; well trimmed 


### BEAM Size=3

In [12]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=3, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces ; well dissected ; no plaques ; no plaques 


### BEAM Size=4

In [13]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=4, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces, no adherent fat, delineated 


### BEAM Size=5

In [14]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=5, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces, no adherent fat, delineated 
