#### Import Libraries

In [1]:
import torch
import regex as re
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer, TrainerCallback

#### Load Data

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
# file paths for training and validation data
base_path = "./all_data"
train_input_file = base_path + "/train/radiology/semtypes.txt"
train_target_file = base_path + "/train/radiology/captions.txt"
valid_input_file = base_path + "/validation/radiology/semtypes.txt"
valid_target_file = base_path + "/validation/radiology/captions.txt"

train_input_data = [line.split() for line in open(train_input_file)]
train_input_data = [[t[0], " ".join(t[1:])] for t in train_input_data]
train_input_data = pd.DataFrame(train_input_data)
train_input_data.columns = ["id", "semtypes"]

train_target_data = pd.read_csv(train_target_file, sep="\t", header=None)
train_target_data.columns = ["id", "caption"]

valid_input_data = [line.split() for line in open(valid_input_file)]
valid_input_data = [[t[0], " ".join(t[1:])] for t in valid_input_data]
valid_input_data = pd.DataFrame(valid_input_data)
valid_input_data.columns = ["id", "semtypes"]

valid_target_data = pd.read_csv(valid_target_file, sep="\t", header=None)
valid_target_data.columns = ["id", "caption"]

In [14]:
# LLM output
f = open('./all_data/llm_result.txt', "r")
contents = f.read()
contents = contents.replace("\n", "")
json_data = json.loads(contents)

llm_df = pd.DataFrame(json_data)

llm_df = llm_df.drop('index', axis=1)

llm_df = llm_df[llm_df['relationship'].apply(lambda x: re.search(r'\w', str(x)) is not None)]
llm_df = llm_df.reset_index(drop=True)

llm_df


Unnamed: 0_level_0,id,relationship,summary
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,ROCO_00002,\nAI: The diagnosis is Obliteration of the lef...,The diagnosis is Obliteration of the left maxi...
1,ROCO_00003,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
2,ROCO_00004,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
4,ROCO_00007,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
5,ROCO_00008,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
...,...,...,...
516,ROCO_00651,\nAI: For the fluorescence angiography procedu...,For the fluorescence angiography procedure (Di...
517,ROCO_00652,\nAI: The UMLS semantic types describe the dis...,The UMLS semantic types describe the disease a...
518,ROCO_00653,\nAI: Computed tomography revealing (Pathology...,Computed tomography revealing (Pathology) pneu...
519,ROCO_00654,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...


#### Filter ones that have been training by the LLM

In [15]:
train_target_data = train_target_data.merge(llm_df, on='id')
train_target_data

Unnamed: 0,id,caption,relationship,summary
0,ROCO_00002,Computed tomography scan in axial view showin...,\nAI: The diagnosis is Obliteration of the lef...,The diagnosis is Obliteration of the left maxi...
1,ROCO_00003,Bacterial contamination occurred after comple...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
2,ROCO_00004,The patient had residual paralysis of the han...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
3,ROCO_00007,Plain abdomen x-ray: Multiple air levels at t...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
4,ROCO_00008,A 3-year-old child with visual difficulties. ...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
...,...,...,...,...
500,ROCO_00651,Fig. 5Fluorescein angiography: early hyperflu...,\nAI: For the fluorescence angiography procedu...,For the fluorescence angiography procedure (Di...
501,ROCO_00652,Transverse CT thorax image at the level of th...,\nAI: The UMLS semantic types describe the dis...,The UMLS semantic types describe the disease a...
502,ROCO_00653,Computed tomography revealing right upper-lun...,\nAI: Computed tomography revealing (Pathology...,Computed tomography revealing (Pathology) pneu...
503,ROCO_00654,Lateral fluoroscopic view in a 77-year-old os...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...


In [16]:
train_input_data = train_input_data[train_input_data['id'].isin(train_target_data['id'])]
train_input_data

Unnamed: 0,id,semtypes
0,ROCO_00002,T060 Diagnostic Procedure T061 Therapeutic or ...
1,ROCO_00003,T061 Therapeutic or Preventive Procedure T058 ...
2,ROCO_00004,T060 Diagnostic Procedure T047 Disease or Synd...
4,ROCO_00007,T109 Organic Chemical T074 Medical Device T121...
5,ROCO_00008,T170 Intellectual Product T121 Pharmacologic S...
...,...,...
516,ROCO_00651,T046 Pathologic Function T060 Diagnostic Proce...
517,ROCO_00652,T047 Disease or Syndrome T170 Intellectual Pro...
518,ROCO_00653,T046 Pathologic Function T041 Mental Process T...
519,ROCO_00654,T060 Diagnostic Procedure T170 Intellectual Pr...


In [17]:
train_input_data = train_input_data.merge(train_target_data, on='id')
train_input_data

Unnamed: 0,id,semtypes,caption,relationship,summary
0,ROCO_00002,T060 Diagnostic Procedure T061 Therapeutic or ...,Computed tomography scan in axial view showin...,\nAI: The diagnosis is Obliteration of the lef...,The diagnosis is Obliteration of the left maxi...
1,ROCO_00003,T061 Therapeutic or Preventive Procedure T058 ...,Bacterial contamination occurred after comple...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
2,ROCO_00004,T060 Diagnostic Procedure T047 Disease or Synd...,The patient had residual paralysis of the han...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
3,ROCO_00007,T109 Organic Chemical T074 Medical Device T121...,Plain abdomen x-ray: Multiple air levels at t...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
4,ROCO_00008,T170 Intellectual Product T121 Pharmacologic S...,A 3-year-old child with visual difficulties. ...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
...,...,...,...,...,...
500,ROCO_00651,T046 Pathologic Function T060 Diagnostic Proce...,Fig. 5Fluorescein angiography: early hyperflu...,\nAI: For the fluorescence angiography procedu...,For the fluorescence angiography procedure (Di...
501,ROCO_00652,T047 Disease or Syndrome T170 Intellectual Pro...,Transverse CT thorax image at the level of th...,\nAI: The UMLS semantic types describe the dis...,The UMLS semantic types describe the disease a...
502,ROCO_00653,T046 Pathologic Function T041 Mental Process T...,Computed tomography revealing right upper-lun...,\nAI: Computed tomography revealing (Pathology...,Computed tomography revealing (Pathology) pneu...
503,ROCO_00654,T060 Diagnostic Procedure T170 Intellectual Pr...,Lateral fluoroscopic view in a 77-year-old os...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...


In [18]:
from sklearn.model_selection import train_test_split

# Split train_input_data into train, test, and valid datasets
train_input, valid_test_input = train_test_split(train_input_data, test_size=0.3, random_state=42)
valid_input, test_input = train_test_split(valid_test_input, test_size=0.5, random_state=42)

# Split train_target_data into train, test, and valid datasets
# train_target, valid_test_target = train_test_split(train_target_data, test_size=0.3, random_state=42)
# valid_target, test_target = train_test_split(valid_test_target, test_size=0.5, random_state=42)

In [19]:
# Reset index
train_input = train_input.reset_index(drop=True)
valid_input = valid_input.reset_index(drop=True)
test_input = test_input.reset_index(drop=True)

# train_target = train_target.reset_index(drop=True)
# valid_target = valid_target.reset_index(drop=True)
# test_target = test_target.reset_index(drop=True)

In [20]:
test_input

Unnamed: 0,id,semtypes,caption,relationship,summary
0,ROCO_00050,T074 Medical Device T047 Disease or Syndrome T...,Free air beneath the diaphragm at abdominal x...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
1,ROCO_00573,T060 Diagnostic Procedure T040 Organism Functi...,CT scan of the neck and upper medastinum: Con...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
2,ROCO_00434,T109 Organic Chemical T121 Pharmacologic Subst...,Abdomen plan X-ray.,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
3,ROCO_00334,"T060 Diagnostic Procedure T023 Body Part, Orga...",Coronal computed tomography of the neck clear...,\nAI: The diagnosis is Fish stuck in throat (S...,The diagnosis is Fish stuck in throat (Sign or...
4,ROCO_00558,T109 Organic Chemical T060 Diagnostic Procedur...,CT scan shows the bullet clear in the lower a...,\nAI: The diagnosis for procedure (Diagnostic ...,The diagnosis for procedure (Diagnostic Proced...
...,...,...,...,...,...
71,ROCO_00093,T046 Pathologic Function T047 Disease or Syndr...,Anterior uveitis with cystoid macular edema.I...,\nAI: The UMLS semantic types describe the sym...,The UMLS semantic types describe the symptoms ...
72,ROCO_00098,T033 Finding T047 Disease or Syndrome T060 Dia...,Ocular ultrasound of the left eye demonstrati...,\nAI: Ocular ultrasound of the left eye at pre...,Ocular ultrasound of the left eye at presentat...
73,ROCO_00034,T201 Clinical Attribute T074 Medical Device T0...,Fluoroscopy image of EVAR procedure.,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
74,ROCO_00462,T060 Diagnostic Procedure T041 Mental Process ...,CT scan showing mass separate from right kidn...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...


In [11]:
# Select n% of data
n = 100
n = n / 100
train_input = train_input.sample(frac=n, random_state=42)
valid_input = valid_input.sample(frac=n, random_state=42)
test_input = test_input.sample(frac=n, random_state=42)

train_target = train_target.sample(frac=n, random_state=42)
valid_target = valid_target.sample(frac=n, random_state=42)
test_target = test_target.sample(frac=n, random_state=42)

# Reset index
train_input = train_input.reset_index(drop=True)
valid_input = valid_input.reset_index(drop=True)
test_input = test_input.reset_index(drop=True)

train_target = train_target.reset_index(drop=True)
valid_target = valid_target.reset_index(drop=True)
test_target = test_target.reset_index(drop=True)

NameError: name 'train_target' is not defined

In [None]:
valid_target

#### T5 tokenizer and model

In [21]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


#### Custom dataset class

In [22]:
class RoCoDataset(Dataset):
    def __init__(self, input_file, tokenizer):
        self.data = input_file
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        row = self.data.iloc[index]
        input_text = row['semtypes']
        target_text = f"{row['caption']}"
        
#         input_encoding = self.tokenizer.encode_plus(input_text, padding=True, truncation=True)
#         target_encoding = self.tokenizer.encode_plus(target_text, padding=True, truncation=True)

        input_encoding = self.tokenizer.encode_plus(input_text, padding='max_length', max_length=256, truncation=True)
        target_encoding = self.tokenizer.encode_plus(target_text, padding='max_length', max_length=256, truncation=True)

        input_ids = input_encoding['input_ids']
        input_attention_mask = input_encoding['attention_mask']
        target_ids = target_encoding['input_ids']
        target_attention_mask = target_encoding['attention_mask']
        
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(input_attention_mask),
            'labels': torch.tensor(target_ids),
            'decoder_attention_mask': torch.tensor(target_attention_mask),
        }

#### Load dataset in custom format

In [23]:
train_dataset = RoCoDataset(train_input, tokenizer)
valid_dataset = RoCoDataset(valid_input, tokenizer)

#### Define the training arguments and create the trainer

In [24]:
training_args = TrainingArguments(
    output_dir='./output',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    save_steps=500,
    save_total_limit=2,
    overwrite_output_dir=True,
    learning_rate=1e-4,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=100,
    evaluation_strategy='steps',
    eval_steps=500,
    disable_tqdm=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
)

In [25]:
# Train the model
trainer.train()

# Save the trained model
trainer.save_model('./trained_model')

***** Running training *****
  Num examples = 353
  Num Epochs = 3
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 267
  Number of trainable parameters = 222903552


Step,Training Loss,Validation Loss




Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to ./trained_model
Configuration saved in ./trained_model\config.json
Model weights saved in ./trained_model\pytorch_model.bin
tokenizer config file saved in ./trained_model\tokenizer_config.json
Special tokens file saved in ./trained_model\special_tokens_map.json


In [26]:
test_input

Unnamed: 0,id,semtypes,caption,relationship,summary
0,ROCO_00050,T074 Medical Device T047 Disease or Syndrome T...,Free air beneath the diaphragm at abdominal x...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
1,ROCO_00573,T060 Diagnostic Procedure T040 Organism Functi...,CT scan of the neck and upper medastinum: Con...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
2,ROCO_00434,T109 Organic Chemical T121 Pharmacologic Subst...,Abdomen plan X-ray.,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
3,ROCO_00334,"T060 Diagnostic Procedure T023 Body Part, Orga...",Coronal computed tomography of the neck clear...,\nAI: The diagnosis is Fish stuck in throat (S...,The diagnosis is Fish stuck in throat (Sign or...
4,ROCO_00558,T109 Organic Chemical T060 Diagnostic Procedur...,CT scan shows the bullet clear in the lower a...,\nAI: The diagnosis for procedure (Diagnostic ...,The diagnosis for procedure (Diagnostic Proced...
...,...,...,...,...,...
71,ROCO_00093,T046 Pathologic Function T047 Disease or Syndr...,Anterior uveitis with cystoid macular edema.I...,\nAI: The UMLS semantic types describe the sym...,The UMLS semantic types describe the symptoms ...
72,ROCO_00098,T033 Finding T047 Disease or Syndrome T060 Dia...,Ocular ultrasound of the left eye demonstrati...,\nAI: Ocular ultrasound of the left eye at pre...,Ocular ultrasound of the left eye at presentat...
73,ROCO_00034,T201 Clinical Attribute T074 Medical Device T0...,Fluoroscopy image of EVAR procedure.,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...
74,ROCO_00462,T060 Diagnostic Procedure T041 Mental Process ...,CT scan showing mass separate from right kidn...,\nAI: The UMLS semantic types describe the dia...,The UMLS semantic types describe the diagnosis...


In [None]:
temp = pd.merge(test_input, test_target, on="id", how="inner")
temp

In [27]:
# Calculate evaluation metrics
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu
bleu_scores = []

test_input.reset_index(drop=True, inplace=True)
test_target.reset_index(drop=True, inplace=True)

for _, row in test_input.iterrows():
    input_id = row['id']
    input_text = row['semtypes']
    target_caption = row['caption']
    
    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    
    input_ids = input_ids.to(device)
    
    # Generate captions using the model
    outputs = model.generate(input_ids)

    # Decode the generated captions
    generated_caption = tokenizer.decode(outputs[0], skip_special_tokens=True)
            
    bleu_score = sentence_bleu([target_caption.split()], generated_caption.split())
    bleu_scores.append(bleu_score)

    print('Input: ', input_text)
    print('Generated Caption: ', generated_caption)
    print('Original Caption: ', target_caption)
    print()
    
average_bleu_score = sum(bleu_scores) / len(bleu_scores)

# Print average BLEU score
print("Average BLEU score:", average_bleu_score)

NameError: name 'test_target' is not defined

## With rationale

In [None]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')

class RoCoDatasetWithRationale(Dataset):
    def __init__(self, input_file, tokenizer):
        self.data = input_file
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        row = self.data.iloc[index]
        input_text = row['semtypes']
        target_text = f"{row['caption']} </s> {row['summary']}"
        
#         input_encoding = self.tokenizer.encode_plus(input_text, padding=True, truncation=True)
#         target_encoding = self.tokenizer.encode_plus(target_text, padding=True, truncation=True)

        input_encoding = self.tokenizer.encode_plus(input_text, padding='max_length', max_length=256, truncation=True)
        target_encoding = self.tokenizer.encode_plus(target_text, padding='max_length', max_length=256, truncation=True)

        input_ids = input_encoding['input_ids']
        input_attention_mask = input_encoding['attention_mask']
        target_ids = target_encoding['input_ids']
        target_attention_mask = target_encoding['attention_mask']
        
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(input_attention_mask),
            'labels': torch.tensor(target_ids),
            'decoder_attention_mask': torch.tensor(target_attention_mask),
        }
    
train_dataset = RoCoDatasetWithRationale(train_input, tokenizer)
valid_dataset = RoCoDatasetWithRationale(valid_input, tokenizer)

training_args = TrainingArguments(
    output_dir='./output',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    save_steps=500,
    save_total_limit=2,
    overwrite_output_dir=True,
    learning_rate=1e-4,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=100,
    evaluation_strategy='steps',
    eval_steps=500,
    disable_tqdm=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
)

In [None]:
# Train the model
trainer.train()

# Save the trained model
trainer.save_model('./trained_model')

In [None]:
# Calculate evaluation metrics
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu
bleu_scores = []

test_input.reset_index(drop=True, inplace=True)
test_target.reset_index(drop=True, inplace=True)

for _, row in test_input.iterrows():
    input_id = row['id']
    input_text = row['semtypes']
    target_caption = row['caption']
    
    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    
    input_ids = input_ids.to(device)
    
    # Generate captions using the model
    outputs = model.generate(input_ids)

    # Decode the generated captions
    generated_caption = tokenizer.decode(outputs[0], skip_special_tokens=True)
            
    bleu_score = sentence_bleu([target_caption.split()], generated_caption.split())
    bleu_scores.append(bleu_score)

    print('Input: ', input_text)
    print('Generated Caption: ', generated_caption)
    print('Original Caption: ', target_caption)
    print()
    
average_bleu_score = sum(bleu_scores) / len(bleu_scores)

# Print average BLEU score
print("Average BLEU score:", average_bleu_score)

In [None]:
# # file paths for training and validation data
# base_path = "./roco-image-captioning/all_data"
# train_input_file = base_path + "/train/radiology/semtypes.txt"
# train_target_file = base_path + "/train/radiology/captions.txt"
# valid_input_file = base_path + "/validation/radiology/semtypes.txt"
# valid_target_file = base_path + "/validation/radiology/captions.txt"

# train_input_data = [line.split() for line in open(train_input_file)]
# train_input_data = [[t[0], " ".join(t[1:])] for t in train_input_data]
# train_input_data = pd.DataFrame(train_input_data)
# train_input_data.columns = ["id", "semtypes"]

# train_target_data = pd.read_csv(train_target_file, sep="\t", header=None)
# train_target_data.columns = ["id", "caption"]

# # last 10 rows of train_input_data
# input_data = train_input_data.tail(50)
# target_data = train_target_data.tail(50)

# for _, row in input_data.iterrows():
#     input_id = row['id']
#     input_text = row['semtypes']
    
#     # Tokenize the input text
#     input_ids = tokenizer.encode(input_text, return_tensors='pt')

#     # Generate captions using the model
#     outputs = model.generate(input_ids)

#     # Decode the generated captions
#     generated_captions = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
#     target_caption = train_target_data[train_target_data['id'] == input_id]['caption'].values[0]
#     print('Input: ', input_text)
#     print('Generated Caption: ', generated_captions)
#     print('Original Caption: ', target_caption)
#     print()