## Load Model

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataloader import *
from utils import *
from model_vit_bert3 import ViTConfigCustom, ViTModelCustom, CustomVEDConfig, CustomVisionEncoderDecoder
from training_script_vit_bert_3layers_bsgt1_v2 import LightningModel

# models: Encoder    
encoder = ViTModelCustom(config=ViTConfigCustom(hidden_size=576), pretrain_4k='vit4k_xs_dino', freeze_4k=True)

# decoder
decoder_model_name="emilyalsentzer/Bio_ClinicalBERT"
decoder = AutoModelForCausalLM.from_pretrained(decoder_model_name, is_decoder=True, add_cross_attention=True)
tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)

# encoder decoder model
model=CustomVisionEncoderDecoder(config=CustomVEDConfig(),encoder=encoder, decoder=decoder)
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

lightning_model = LightningModel(model, tokenizer, model_lr=1e-2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt="/scratch/ss4yd/logs_only_vit_bert_fullexpv2_freezeFalse/my_model/version_23/checkpoints/epoch=5-val_loss=0.93-step=4254.00.ckpt"
lightning_model.load_state_dict(torch.load(ckpt,map_location=device)['state_dict'])
lightning_model.eval()

len(list(*decoder.bert.encoder.children())[:-5])

model=lightning_model.model

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertLMHeadModel were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['bert.encoder.layer.4.crossattention.self.query.bias', 'bert.encoder.layer.8.crossattention.output.dense.bias', 'bert.encoder.layer.6.crossattention.output.dense.weight', 'bert.encoder.layer.11.crossattention.se

## Load Data

In [2]:
import pandas as pd
df_path='/home/ss4yd/nlp/full_experiments/data_files/all_gtex_data.pickle'
df=pd.read_pickle(df_path)

# df=df[df.dtype=='train']
# print(df.sex.value_counts())

df=df[df.dtype=='test']
df.sex.value_counts()

male      834
female    422
Name: sex, dtype: int64

In [3]:
637/(637+347), tokenizer.vocab_size

(0.6473577235772358, 28996)

In [4]:
df

Unnamed: 0,pid,svs_path,patch_path,reps_path,tissue_type,notes,new_notes,sex,dtype,reps4kpath
23864,GTEX-14DAQ-1126,/scratch/ss4yd/gtex_data_new/GTEX-14DAQ-1126.svs,/project/GutIntelligenceLab/ss4yd/gtex_data/pa...,/project/GutIntelligenceLab/ss4yd/gtex_data/hi...,Heart - Left Ventricle,"2 pieces, minimal interstitial fibrosis, chron...",this is a heart - left ventricle tissue from a...,female,test,/project/GutIntelligenceLab/ss4yd/gdc_data/hip...
23865,GTEX-183WM-0926,/scratch/ss4yd/gtex_data_new/GTEX-183WM-0926.svs,/project/GutIntelligenceLab/ss4yd/gtex_data/pa...,/project/GutIntelligenceLab/ss4yd/gtex_data/hi...,Heart - Atrial Appendage,"2 pieces, no abnormalities",this is a heart - atrial appendage tissue from...,female,test,/project/GutIntelligenceLab/ss4yd/gdc_data/hip...
23866,GTEX-15ER7-1626,/scratch/ss4yd/gtex_data_new/GTEX-15ER7-1626.svs,/project/GutIntelligenceLab/ss4yd/gtex_data/pa...,/project/GutIntelligenceLab/ss4yd/gtex_data/hi...,Breast - Mammary Tissue,2 pieces; 50 and 70% fat with fibrocollagenous...,this is a breast - mammary tissue tissue from ...,female,test,/project/GutIntelligenceLab/ss4yd/gdc_data/hip...
23867,GTEX-1AX8Z-0126,/scratch/ss4yd/gtex_data_new/GTEX-1AX8Z-0126.svs,/project/GutIntelligenceLab/ss4yd/gtex_data/pa...,/project/GutIntelligenceLab/ss4yd/gtex_data/hi...,Skin - Sun Exposed (Lower leg),6 pieces; 20% dermal fat; well trimmed,this is a skin - sun exposed (lower leg) tissu...,male,test,/project/GutIntelligenceLab/ss4yd/gdc_data/hip...
23868,GTEX-1HSKV-2426,/scratch/ss4yd/gtex_data_new/GTEX-1HSKV-2426.svs,/project/GutIntelligenceLab/ss4yd/gtex_data/pa...,/project/GutIntelligenceLab/ss4yd/gtex_data/hi...,Esophagus - Mucosa,6 pieces; includes few clusters of submucosal ...,this is a esophagus - mucosa tissue from a mal...,male,test,/project/GutIntelligenceLab/ss4yd/gdc_data/hip...
...,...,...,...,...,...,...,...,...,...,...
25115,GTEX-Y5V5-2726,/scratch/ss4yd/gtex_data_new/GTEX-Y5V5-2726.svs,/project/GutIntelligenceLab/ss4yd/gtex_data/pa...,/project/GutIntelligenceLab/ss4yd/gtex_data/hi...,Artery - Tibial,"2 pieces, 3x2.5 & 4x3mm; one aliquot has ~0.5m...",this is a artery - tibial tissue from a female...,female,test,/project/GutIntelligenceLab/ss4yd/gdc_data/hip...
25116,GTEX-1OJC4-0126,/scratch/ss4yd/gtex_data_new/GTEX-1OJC4-0126.svs,/project/GutIntelligenceLab/ss4yd/gtex_data/pa...,/project/GutIntelligenceLab/ss4yd/gtex_data/hi...,Skin - Sun Exposed (Lower leg),"6 pieces; relatively well trimmed, include up ...",this is a skin - sun exposed (lower leg) tissu...,female,test,/project/GutIntelligenceLab/ss4yd/gdc_data/hip...
25117,GTEX-1I1GT-1426,/scratch/ss4yd/gtex_data_new/GTEX-1I1GT-1426.svs,/project/GutIntelligenceLab/ss4yd/gtex_data/pa...,/project/GutIntelligenceLab/ss4yd/gtex_data/hi...,Esophagus - Mucosa,6 pieces,this is a esophagus - mucosa tissue from a fem...,female,test,/project/GutIntelligenceLab/ss4yd/gdc_data/hip...
25118,GTEX-1RAZR-1426,/scratch/ss4yd/gtex_data_new/GTEX-1RAZR-1426.svs,/project/GutIntelligenceLab/ss4yd/gtex_data/pa...,/project/GutIntelligenceLab/ss4yd/gtex_data/hi...,Esophagus - Gastroesophageal Junction,5 pieces; 2 have 10 & 20% stromal contents,this is a esophagus - gastroesophageal junctio...,male,test,/project/GutIntelligenceLab/ss4yd/gdc_data/hip...


In [5]:
reps_path=df['reps_path']

In [6]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)

pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)

Patient ID: GTEX-1RAZR-1426
Actual Note: 
 this is a esophagus - gastroesophageal junction tissue from a male patient and it has 5 pieces; 2 have 10 & 20% stromal contents


In [7]:
from dataloader import *
df_path='/home/ss4yd/nlp/full_experiments/data_files/all_gtex_data.pickle'

# Hyperparameters
batch_size=32
epochs=30
model_lr=2e-5

n_layers=3

num_workers=1
test_loader = torch.utils.data.DataLoader(ResnetPlusVitDatasetV3(df_path,text_decode_model=decoder_model_name, dtype='test'), 
                                              batch_size=1, shuffle=False, num_workers=num_workers)

In [8]:
for batch in test_loader:
    pixel_values, labels, attention_mask, encoder_attention_mask = batch
    break
# model._prepare_encoder_decoder_kwargs_for_generation(pixel_values)

In [9]:
model.eval()
outs=model(pixel_values=pixel_values, 
           attention_mask=encoder_attention_mask,
           labels=labels, 
           decoder_attention_mask=attention_mask,
           output_attentions=True
          )

In [10]:
outs.keys()

odict_keys(['loss', 'logits', 'past_key_values', 'decoder_attentions', 'cross_attentions', 'encoder_last_hidden_state'])

In [20]:
len(outs['cross_attentions'])

12

In [11]:
for i in range(len(outs['cross_attentions'])):
    print(outs['cross_attentions'][i].shape)

torch.Size([1, 12, 128, 128])
torch.Size([1, 12, 128, 128])
torch.Size([1, 12, 128, 128])
torch.Size([1, 12, 128, 128])
torch.Size([1, 12, 128, 128])
torch.Size([1, 12, 128, 128])
torch.Size([1, 12, 128, 128])
torch.Size([1, 12, 128, 128])
torch.Size([1, 12, 128, 128])
torch.Size([1, 12, 128, 128])
torch.Size([1, 12, 128, 128])
torch.Size([1, 12, 128, 128])


## Examples

In [67]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)

pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=1, do_sample=False, 
                                      return_dict_in_generate=True, output_attentions=True)
decoded_cap=tokenizer.decode(gencap['sequences'][0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-1HR9M-0626
Actual Note: 
 this is a thyroid tissue from a male patient and it has 2 pieces; 10% internal fat content
Generated Note: 
 this is a thyroid tissue from a male patient and it has 2 pieces ; 1 piece has 50 % fat 


In [75]:
pixel_values.shape

torch.Size([1, 28, 576])

In [68]:
decoded_cap

('[CLS] this is a thyroid tissue from a male patient and it has 2 pieces ; 1 piece has 50 % fat [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP]. this is a thyroid tissue ( omentum ) ; other has 2 pieces ; other has 2 pieces [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP]',
 torch.Size([1, 28, 576]))

In [71]:
gencap['sequences']

tensor([[  101,  1142,  1110,   170, 21153, 16219,  7918,  1121,   170,  2581,
          5351,  1105,  1122,  1144,   123,  3423,   132,   122,  2727,  1144,
          1851,   110,  7930,   102,   102,   102,   102,   102,   102,   102,
           102,   102,   102,   102,   102,   102,   102,   102,   102,   102,
           102,   102,   102,   102,   102,   102,   102,   102,   102,   102,
           102,   102,   102,   102,   102,   102,   102,   102,   102,   102,
           102,   102,   102,   102,   102,   102,   119,  1142,  1110,   170,
         21153, 16219,  7918,   113,   184,  1880,  1818,   114,   132,  1168,
          1144,   123,  3423,   132,  1168,  1144,   123,  3423,   102,   102,
           102,   102,   102,   102,   102,   102,   102,   102,   102,   102,
           102,   102,   102,   102,   102,   102,   102,   102,   102,   102,
           102,   102,   102,   102,   102,   102,   102,   102,   102,   102,
           102,   102,   102,   102,   102,   102,  

In [70]:
gencap.keys()

odict_keys(['sequences', 'decoder_attentions', 'cross_attentions'])

In [72]:
len(gencap['cross_attentions'])

127

In [77]:
index=0
print(len(gencap['cross_attentions'][index]))
for i in range(len(gencap['cross_attentions'][index])):
    print(gencap['cross_attentions'][index][i].shape)

12
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])


In [74]:
index=-1
for i in range(len(gencap['cross_attentions'][index])):
    print(gencap['cross_attentions'][index][i].shape)

torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])
torch.Size([1, 12, 1, 28])


In [52]:
len(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))

8

In [53]:
len(gencap['cross_attentions'][0])

12

In [62]:
# samp = df.sample(1)
samp = df[df.pid=='GTEX-QDT8-2826']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-QDT8-2826
Actual Note: 
 this is a bladder tissue from a female patient and it has 6   ~8x5mm pieces.  urotherlium nearly completely sloughed; muscularis/adipose tissue is >99% of sections
Generated Note: 
 this is a esophagus - gastroesophageal junction tissue from a male patient and it has 6 pieces, all muscularis, good specimens 


In [52]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-139UC-0426
Actual Note: 
 this is a heart - left ventricle tissue from a male patient and it has 2 pieces, moderate interstitial fibrosis with ~3mm remote infarct, delineated
Generated Note: 
 this is a heart - left ventricle tissue from a male patient and it has 2 pieces ; interstitial fibrosis 


In [43]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-1L5NE-0426
Actual Note: 
 this is a heart - left ventricle tissue from a male patient and it has 2 pieces, moderate chronic ischemic changes/interstitial fibrosis; ~2mm remote micro-infarct encircled.
Generated Note: 
 this is a heart - left ventricle tissue from a male patient and it has 2 pieces, moderate interstitial fibrosis / chronic ischemic changes 


In [55]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-1H11D-0926
Actual Note: 
 this is a pancreas tissue from a male patient and it has 2 pieces; both include up to 20% attached fat, focal squamous metaplasia of duct
Generated Note: 
 this is a pancreas tissue from a male patient and it has 2 pieces, advanced saponification ; islets not well visualized ; 


In [45]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-V955-2626
Actual Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces ~8x6mm.  20% occlusive atherosclerosis, relatively clean specimens
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces, 3x2 & 3x2mm ; medial calcification ; atheromatous plaques ; ~ 50 % of aliquot 


In [46]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-ZTPG-2326
Actual Note: 
 this is a adipose - visceral (omentum) tissue from a female patient and it has 2 pieces
Generated Note: 
 this is a adipose - visceral ( omentum ) tissue from a male patient and it has 2 pieces ; large vessels, no vascular tissue 


In [58]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-1GZ2Q-1526
Actual Note: 
 this is a kidney - cortex tissue from a male patient and it has 6 pieces; renal cortex with patchy interstitial fibrosis and  hyalinized glomeruli (rep arrowed)
Generated Note: 
 this is a kidney - cortex tissue from a male patient and it has 6 pieces ; glomeruli in all sections, arteriolar sclerosis ; tubules severely autolyzed ; glomeruli in medulla 


In [66]:
samp = df.sample(1)

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-ZV6S-1826
Actual Note: 
 this is a breast - mammary tissue tissue from a female patient and it has 2 pieces, ~70% loose fibrous tissue, epithelium is <5%, small floater  (colonic gland) delineated
Generated Note: 
 this is a breast - mammary tissue tissue from a male patient and it has 2 pieces ; fibroadipose tissue with gynecomastoid stroma and ductal structures 


## Effect of BEAM Size

### BEAM Size=1

In [47]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=1, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces ; no plaques ; well trimmed ; no plaques ; no plaques 


### BEAM Size=2

In [48]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=2, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces ; well trimmed ; no plaques ; no plaques ; no plaques 


### BEAM Size=3

In [49]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=3, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces ; well trimmed ; no plaques ; no plaques ; well trimmed 


### BEAM Size=4

In [50]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=4, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces ; well trimmed ; no plaques ; no plaques 


### BEAM Size=5

In [51]:
samp = df[df.pid=='GTEX-13PVR-2726']

pid=samp.pid.values[0]
print(f'Patient ID: {pid}')
print(f'Actual Note: \n {samp.new_notes.values[0].lower()}')
reps256_path=samp['reps_path'].values[0]
reps4k_path=samp['reps4kpath'].values[0]

x256 = torch.load(reps256_path)
x256mean = x256.mean(dim=1)
x4k = torch.load(reps4k_path)
pixel_values = torch.cat([x256mean, x4k], dim=1).unsqueeze(0)
# pixel_values=torch.load(samp.reps_path.values[0]).unsqueeze(0)
pixel_values.shape

# generate
gencap=lightning_model.model.generate(pixel_values, max_length=128, num_beams=5, do_sample=True)

decoded_cap=tokenizer.decode(gencap[0])
remove_sptokens=decoded_cap[6:decoded_cap.find('[SEP]')]

print(f'Generated Note: \n {remove_sptokens}')
# print(f'Full Generated Note: \n {decoded_cap}')

Patient ID: GTEX-13PVR-2726
Actual Note: 
 this is a artery - tibial tissue from a female patient and it has 2 pieces, clean specimen, no atherosis
Generated Note: 
 this is a artery - tibial tissue from a male patient and it has 2 pieces ; well trimmed ; no plaques ; no plaques ; no plaques 
