## Load Model

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataloader import *
from utils import *
from model_vit_bert import ViTConfigCustom, ViTModelCustom, CustomVEDConfig, CustomVisionEncoderDecoder
from training_script_vit_bert_5layers import LightningModel

# models: Encoder    
encoder = ViTModelCustom(config=ViTConfigCustom(hidden_size=576), pretrain_4k='vit4k_xs_dino', freeze_4k=True)

# decoder
decoder_model_name="emilyalsentzer/Bio_ClinicalBERT"
decoder = AutoModelForCausalLM.from_pretrained(decoder_model_name, is_decoder=True, add_cross_attention=True)
tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)

# encoder decoder model
model=CustomVisionEncoderDecoder(config=CustomVEDConfig(),encoder=encoder, decoder=decoder)
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

lightning_model = LightningModel(model, tokenizer, model_lr=1e-2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt="/scratch/ss4yd/logs_only_vit_bert/my_model/version_10/checkpoints/epoch=9-val_loss=0.89-step=5000.00.ckpt"
lightning_model.load_state_dict(torch.load(ckpt,map_location=device)['state_dict'])
lightning_model.eval()

# of Patches: 196
Loading Pretrained Local VIT model...
Done!
Freezing Pretrained Local VIT model
Done


Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertLMHeadModel were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['bert.encoder.layer.4.crossattention.self.key.weight', 'bert.encoder.layer.3.crossattention.self.key.weight', 'bert.encoder.layer.6.crossattention.self.value.bias', 'bert.encoder.layer.3.crossattention.self.key.

LightningModel(
  (model): CustomVisionEncoderDecoder(
    (encoder): ViTModelCustom(
      (local_vit): VisionTransformer4K(
        (phi): Sequential(
          (0): Linear(in_features=384, out_features=192, bias=True)
          (1): GELU(approximate=none)
          (2): Dropout(p=0.0, inplace=False)
        )
        (pos_drop): Dropout(p=0.0, inplace=False)
        (blocks): ModuleList(
          (0): Block(
            (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
            (attn): Attention(
              (qkv): Linear(in_features=192, out_features=576, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=192, out_features=192, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=192, out_features=768, bia

In [2]:
len(list(*decoder.bert.encoder.children())[:-5])

7

## Load Data

In [39]:
import pandas as pd
df_path='../final_more_female_data.pickle'
df=pd.read_pickle(df_path)

df=df[df.dtype=='train']
print(df.sex.value_counts())

df=df[df.dtype=='test']
df.sex.value_counts()

male      5072
female    4839
Name: sex, dtype: int64


Series([], Name: sex, dtype: int64)

In [40]:
637/(637+347), tokenizer.vocab_size

(0.6473577235772358, 28996)

## Examples

In [36]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=False)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-14ABY-0326
Actual Note: 
 this is a adipose - subcutaneous tissue from a male patient and it has 2 pieces; ~10% of fibrovascular component
Generated Note: 
 this is a adipose - subcutaneous tissue from a male patient and it has 2 pieces ; 10 % fibrovascular content 


In [16]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-ZUA1-0326
Actual Note: 
 this is a muscle - skeletal tissue from a male patient and it has 2 pieces. 10% fat in 1 piece.
Generated Note: 
 this is a muscle - skeletal tissue from a male patient and it has 2 pieces ; 10 % internal fat 


In [5]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-15EU6-1126
Actual Note: 
 this is a pancreas tissue from a male patient and it has 2 pieces, adherent/interstitial fat is ~35-40%, delineated; islets well visualized; rep encircled
Generated Note: 
 this is a pancreas tissue from a male patient and it has 2 pieces ; 10 % internal fat 


In [6]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-15G1A-1826
Actual Note: 
 this is a colon - transverse tissue from a male patient and it has 6 pieces, mucosa remarkably well preserved, 0.4mm, ~10% thickness
Generated Note: 
 this is a colon - transverse tissue from a male patient and it has 6 pieces, mucosa up to ~ 0. 3mm, ~ 20 % thickness, rep delineated 


In [7]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-17F96-2026
Actual Note: 
 this is a colon - transverse tissue from a male patient and it has 6 pieces; mucosa=3;muscle = 1 (autolysis)
Generated Note: 
 this is a colon - transverse tissue from a male patient and it has 6 pieces, mucosa well - preserved, ~ 0. 5mm thick 


In [8]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-QEG4-0926
Actual Note: 
 this is a artery - aorta tissue from a male patient and it has 6 ~8.5x1.5mm pieces, several with adherent fat/fibrous tissue up to ~5.2mm
Generated Note: 
 this is a artery - aorta tissue from a male patient and it has 6 pieces, up to 9x1mm ; 


In [9]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-145MO-0426
Actual Note: 
 this is a adipose - subcutaneous tissue from a male patient and it has 2 pieces, ~5-10% fascia/vascular tissue, rep. delineated
Generated Note: 
 this is a adipose - subcutaneous tissue from a male patient and it has 2 pieces, ~ 10 % fascia / vascular elements, rep delineated 


## Effect of BEAM Size

### BEAM Size=1

In [10]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=1, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, excellent clean specimens 


### BEAM Size=2

In [11]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces, good clean specimens, no adherent fat 


### BEAM Size=3

In [12]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=3, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces ; no plaques ; no plaques ; well trimmed 


### BEAM Size=4

In [13]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=4, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces, no adherent fat 


### BEAM Size=5

In [14]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=5, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces, no adherent fat 
