#### Import Libraries

In [1]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer, TrainerCallback

#### Load Data

In [19]:
# file paths for training and validation data
base_path = "./all_data"
train_input_file = base_path + "/train/radiology/semtypes.txt"
train_target_file = base_path + "/train/radiology/captions.txt"
valid_input_file = base_path + "/validation/radiology/semtypes.txt"
valid_target_file = base_path + "/validation/radiology/captions.txt"

train_input_data = [line.split() for line in open(train_input_file)]
train_input_data = [[t[0], " ".join(t[1:])] for t in train_input_data]
train_input_data = pd.DataFrame(train_input_data)
train_input_data.columns = ["id", "semtypes"]

train_target_data = pd.read_csv(train_target_file, sep="\t", header=None)
train_target_data.columns = ["id", "caption"]

valid_input_data = [line.split() for line in open(valid_input_file)]
valid_input_data = [[t[0], " ".join(t[1:])] for t in valid_input_data]
valid_input_data = pd.DataFrame(valid_input_data)
valid_input_data.columns = ["id", "semtypes"]

valid_target_data = pd.read_csv(valid_target_file, sep="\t", header=None)
valid_target_data.columns = ["id", "caption"]

In [20]:
# only the first 20% of the data for testing
train_input_data = train_input_data[:int(len(train_input_data) * 0.2)]
train_target_data = train_target_data[:int(len(train_target_data) * 0.2)]
valid_input_data = valid_input_data[:int(len(valid_input_data) * 0.2)]
valid_target_data = valid_target_data[:int(len(valid_target_data) * 0.2)]

#### Custom dataset class

In [21]:
class RoCoDataset(Dataset):
    def __init__(self, input_file, target_file, tokenizer):
        self.data = pd.merge(input_file, target_file, on="id", how="inner")
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        row = self.data.iloc[index]
        input_text = row['semtypes']
        target_text = row['caption']
        
        input_encoding = self.tokenizer.encode_plus(input_text, padding='max_length', max_length=128, truncation=True)
        target_encoding = self.tokenizer.encode_plus(target_text, padding='max_length', max_length=128, truncation=True)
        
        input_ids = input_encoding['input_ids']
        input_attention_mask = input_encoding['attention_mask']
        target_ids = target_encoding['input_ids']
        target_attention_mask = target_encoding['attention_mask']
        
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(input_attention_mask),
            'labels': torch.tensor(target_ids),
            'decoder_attention_mask': torch.tensor(target_attention_mask),
        }

#### T5 tokenizer and model

In [5]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')

#### Load dataset in custom format and create the data loaders

In [22]:
train_dataset = RoCoDataset(train_input_data, train_target_data, tokenizer)
valid_dataset = RoCoDataset(valid_input_data, valid_target_data, tokenizer)

In [25]:
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False)

#### Define the training arguments and create the trainer

In [26]:
training_args = TrainingArguments(
    output_dir='./output',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    save_steps=500,
    save_total_limit=2,
    overwrite_output_dir=True,
    learning_rate=1e-4,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=100,
    evaluation_strategy='steps',
    eval_steps=500,
    disable_tqdm=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
)

In [27]:
# Train the model
trainer.train()

# Save the trained model
trainer.save_model('./trained_model')

Step,Training Loss,Validation Loss,Runtime,Samples Per Second
500,0.8844,0.87994,619.5992,2.64


KeyboardInterrupt: 

In [28]:
# file paths for training and validation data
base_path = "./roco-image-captioning/all_data"
train_input_file = base_path + "/train/radiology/semtypes.txt"
train_target_file = base_path + "/train/radiology/captions.txt"
valid_input_file = base_path + "/validation/radiology/semtypes.txt"
valid_target_file = base_path + "/validation/radiology/captions.txt"

train_input_data = [line.split() for line in open(train_input_file)]
train_input_data = [[t[0], " ".join(t[1:])] for t in train_input_data]
train_input_data = pd.DataFrame(train_input_data)
train_input_data.columns = ["id", "semtypes"]

train_target_data = pd.read_csv(train_target_file, sep="\t", header=None)
train_target_data.columns = ["id", "caption"]

# last 10 rows of train_input_data
input_data = train_input_data.tail(50)
target_data = train_target_data.tail(50)

for _, row in input_data.iterrows():
    input_id = row['id']
    input_text = row['semtypes']
    
    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Generate captions using the model
    outputs = model.generate(input_ids)

    # Decode the generated captions
    generated_captions = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    target_caption = train_target_data[train_target_data['id'] == input_id]['caption'].values[0]
    print('Input: ', input_text)
    print('Generated Caption: ', generated_captions)
    print('Original Caption: ', target_caption)
    print()

Input:  T060 Diagnostic Procedure T170 Intellectual Product T041 Mental Process T121 Pharmacologic Substance T125 Hormone T116 Amino Acid, Peptide, or Protein T023 Body Part, Organ, or Organ Component
Generated Caption:  Axial computed tomography of the right sagittal artery (CT
Original Caption:   Radiograph showing mixed density, periosteal reaction, and cortical disruption with soft tissue swelling in metaphysis of left fibula

Input:  T046 Pathologic Function T170 Intellectual Product T074 Medical Device T121 Pharmacologic Substance T033 Finding T023 Body Part, Organ, or Organ Component T109 Organic Chemical
Generated Caption:  Axial T1-weighted image of the tibia demonstrating
Original Caption:   The bleeding site was treated with a 3mm x10mm coronary artery covered stent (Jo Stent Graftmaster) (white arrows).

Input:  T130 Indicator, Reagent, or Diagnostic Aid T074 Medical Device T061 Therapeutic or Preventive Procedure T170 Intellectual Product
Generated Caption:  Axial X-ray of

Input:  T060 Diagnostic Procedure T170 Intellectual Product T041 Mental Process T130 Indicator, Reagent, or Diagnostic Aid T191 Neoplastic Process T033 Finding T023 Body Part, Organ, or Organ Component T109 Organic Chemical
Generated Caption:  Contrast computed tomography (CT) scan of the pelvis showing a
Original Caption:   Hemangioma. Color Doppler image reveals a highly vascular mass consistent with capillary hemangioma at the superolateral margin of the right orbit.

Input:  T060 Diagnostic Procedure T061 Therapeutic or Preventive Procedure
Generated Caption:  Preoperative radiograph of the patient at the time of the first procedure.
Original Caption:   Oesophageal intubation as seen on sonography

Input:  T109 Organic Chemical T033 Finding T121 Pharmacologic Substance T060 Diagnostic Procedure
Generated Caption:  CT scan showing a calcified aorta (arrow).
Original Caption:   Computer tomography demonstrating a well-defined fat density submucosal lesion

Input:  T074 Medical Device

Input:  T046 Pathologic Function T061 Therapeutic or Preventive Procedure T121 Pharmacologic Substance T047 Disease or Syndrome T037 Injury or Poisoning T109 Organic Chemical T201 Clinical Attribute
Generated Caption:  Preoperative axial MRI of a patient with a septic sy
Original Caption:   Hinchey stage III. Intravenous and rectal contrast-enhanced axial CT of the abdomen showing diverticulitis with multiple abscesses (arrow) in the inframesocolic region and pneumoperitoneum, together with generalized peritonitis.

Input:  T074 Medical Device T061 Therapeutic or Preventive Procedure
Generated Caption:  A stent placed on the stent.
Original Caption:   Decompression tube is inserted followed by the guidewire.

Input:  T046 Pathologic Function T060 Diagnostic Procedure T023 Body Part, Organ, or Organ Component
Generated Caption:  Preoperative X-ray of the pelvis showing a large calcified ilia
Original Caption:   Doppler ultrasound revealed extensive thrombosis involving the iliac veins

