#### Import Libraries

In [117]:
import ast
import torch
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer, TrainerCallback

#### Load Data

In [118]:
# file paths for training and validation data
base_path = "./all_data"
train_input_file = base_path + "/train/radiology/semtypes.txt"
train_target_file = base_path + "/train/radiology/captions.txt"
valid_input_file = base_path + "/validation/radiology/semtypes.txt"
valid_target_file = base_path + "/validation/radiology/captions.txt"

train_input_data = [line.split() for line in open(train_input_file)]
train_input_data = [[t[0], " ".join(t[1:])] for t in train_input_data]
train_input_data = pd.DataFrame(train_input_data)
train_input_data.columns = ["id", "semtypes"]

train_target_data = pd.read_csv(train_target_file, sep="\t", header=None)
train_target_data.columns = ["id", "caption"]

valid_input_data = [line.split() for line in open(valid_input_file)]
valid_input_data = [[t[0], " ".join(t[1:])] for t in valid_input_data]
valid_input_data = pd.DataFrame(valid_input_data)
valid_input_data.columns = ["id", "semtypes"]

valid_target_data = pd.read_csv(valid_target_file, sep="\t", header=None)
valid_target_data.columns = ["id", "caption"]

In [119]:
import json

f = open(base_path + '/llm_result.txt', "r")
contents = f.read()
contents = contents.replace("\n", "")
json_data = json.loads(contents)

llm_df = pd.DataFrame(json_data)

llm_df = llm_df.drop('index', axis=1)

llm_df

Unnamed: 0,id,relationship,summary
0,ROCO_00002,with a mass of homogeneous attenuation (necro...,The patient has undergone a CT scan which sho...
1,ROCO_00003,The patient developed pain and discomfort on ...,The patient developed pain and discomfort on ...
2,ROCO_00004,,The UMLS Metathesaurus has 12435 concepts in ...
3,ROCO_00005,,The UMLS Metathesaurus has 12456 semantic typ...
4,ROCO_00007,The given UMLS semantic types are 1. Radiogra...,There is no single UMLS semantic type that re...
...,...,...,...
1066,ROCO_01329,The answer is Causative Agent: Bacterium.,\nCausative Agent: Bacterium
1067,ROCO_01330,and ascites with dilated veins at the periphe...,No new information.
1068,ROCO_01331,,The UMLS Metathesaurus has 12435 concepts in ...
1069,ROCO_01332,".\nThe UMLS semantic types are 1. Body Part, O...",The given UMLS semantic types are 1. Diagnost...


#### Filter ones that have been training by the LLM

In [120]:
train_target_data = train_target_data.merge(llm_df, on='id')
train_target_data

Unnamed: 0,id,caption,relationship,summary
0,ROCO_00002,Computed tomography scan in axial view showin...,with a mass of homogeneous attenuation (necro...,The patient has undergone a CT scan which sho...
1,ROCO_00003,Bacterial contamination occurred after comple...,The patient developed pain and discomfort on ...,The patient developed pain and discomfort on ...
2,ROCO_00004,The patient had residual paralysis of the han...,,The UMLS Metathesaurus has 12435 concepts in ...
3,ROCO_00005,Panoramic radiograph after immediate loading.,,The UMLS Metathesaurus has 12456 semantic typ...
4,ROCO_00007,Plain abdomen x-ray: Multiple air levels at t...,The given UMLS semantic types are 1. Radiogra...,There is no single UMLS semantic type that re...
...,...,...,...,...
1066,ROCO_01329,Abdominal CT finding. Enterocutaneous fistula...,The answer is Causative Agent: Bacterium.,\nCausative Agent: Bacterium
1067,ROCO_01330,(Case 2) Post operative CT scan showing persi...,and ascites with dilated veins at the periphe...,No new information.
1068,ROCO_01331,Magnetic resonance image showing the patient'...,,The UMLS Metathesaurus has 12435 concepts in ...
1069,ROCO_01332,Post-operative chest X-ray image of the same ...,".\nThe UMLS semantic types are 1. Body Part, O...",The given UMLS semantic types are 1. Diagnost...


In [121]:
train_input_data = train_input_data[train_input_data['id'].isin(train_target_data['id'])]
train_input_data

Unnamed: 0,id,semtypes
0,ROCO_00002,T060 Diagnostic Procedure T061 Therapeutic or ...
1,ROCO_00003,T061 Therapeutic or Preventive Procedure T058 ...
2,ROCO_00004,T060 Diagnostic Procedure T047 Disease or Synd...
3,ROCO_00005,T060 Diagnostic Procedure T170 Intellectual Pr...
4,ROCO_00007,T109 Organic Chemical T074 Medical Device T121...
...,...,...
1066,ROCO_01329,T029 Body Location or Region
1067,ROCO_01330,T033 Finding T060 Diagnostic Procedure T046 Pa...
1068,ROCO_01331,T033 Finding T047 Disease or Syndrome T170 Int...
1069,ROCO_01332,T047 Disease or Syndrome T040 Organism Functio...


In [122]:
from sklearn.model_selection import train_test_split

# Split train_input_data into train, test, and valid datasets
train_input, valid_test_input = train_test_split(train_input_data, test_size=0.4, random_state=42)
valid_input, test_input = train_test_split(valid_test_input, test_size=0.5, random_state=42)

# Split train_target_data into train, test, and valid datasets
train_target, valid_test_target = train_test_split(train_target_data, test_size=0.4, random_state=42)
valid_target, test_target = train_test_split(valid_test_target, test_size=0.5, random_state=42)

In [123]:
# Reset index
train_input = train_input.reset_index(drop=True)
valid_input = valid_input.reset_index(drop=True)
test_input = test_input.reset_index(drop=True)

train_target = train_target.reset_index(drop=True)
valid_target = valid_target.reset_index(drop=True)
test_target = test_target.reset_index(drop=True)

In [124]:
# Select n% of data
n = 1
n = n / 100
train_input = train_input.sample(frac=n, random_state=42)
valid_input = valid_input.sample(frac=n, random_state=42)
test_input = test_input.sample(frac=n, random_state=42)

train_target = train_target.sample(frac=n, random_state=42)
valid_target = valid_target.sample(frac=n, random_state=42)
test_target = test_target.sample(frac=n, random_state=42)

#### T5 tokenizer and model

In [125]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')

#### Custom dataset class

In [126]:
class RoCoDataset(Dataset):
    def __init__(self, input_file, target_file, tokenizer):
        self.data = pd.merge(input_file, target_file, on="id", how="inner")
        self.tokenizer = tokenizer
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        row = self.data.iloc[index]
        input_text = row['semtypes']
        target_text = f"{row['caption']} </s> {row['relationship']} </s> {row['summary']}"
        
#         input_encoding = self.tokenizer.encode_plus(input_text, padding=True, truncation=True)
#         target_encoding = self.tokenizer.encode_plus(target_text, padding=True, truncation=True)

        input_encoding = self.tokenizer.encode_plus(input_text, padding='max_length', max_length=1024, truncation=True)
        target_encoding = self.tokenizer.encode_plus(target_text, padding='max_length', max_length=1024, truncation=True)

        input_ids = input_encoding['input_ids']
        input_attention_mask = input_encoding['attention_mask']
        target_ids = target_encoding['input_ids']
        target_attention_mask = target_encoding['attention_mask']
        
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(input_attention_mask),
            'labels': torch.tensor(target_ids),
            'decoder_attention_mask': torch.tensor(target_attention_mask),
        }

#### Load dataset in custom format and create the data loaders

In [127]:
train_dataset = RoCoDataset(train_input_data, train_target_data, tokenizer)
valid_dataset = RoCoDataset(valid_input_data, valid_target_data, tokenizer)

In [128]:
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False)

#### Define the training arguments and create the trainer

In [129]:
training_args = TrainingArguments(
    output_dir='./output',
    num_train_epochs=2,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    save_steps=50,
    save_total_limit=2,
    overwrite_output_dir=True,
    learning_rate=1e-4,
    warmup_steps=10,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy='steps',
    eval_steps=50,
    disable_tqdm=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
)

In [None]:
# Train the model
trainer.train()

# Save the trained model
trainer.save_model('./trained_model')

Step,Training Loss,Validation Loss




In [28]:
# file paths for training and validation data
base_path = "./roco-image-captioning/all_data"
train_input_file = base_path + "/train/radiology/semtypes.txt"
train_target_file = base_path + "/train/radiology/captions.txt"
valid_input_file = base_path + "/validation/radiology/semtypes.txt"
valid_target_file = base_path + "/validation/radiology/captions.txt"

train_input_data = [line.split() for line in open(train_input_file)]
train_input_data = [[t[0], " ".join(t[1:])] for t in train_input_data]
train_input_data = pd.DataFrame(train_input_data)
train_input_data.columns = ["id", "semtypes"]

train_target_data = pd.read_csv(train_target_file, sep="\t", header=None)
train_target_data.columns = ["id", "caption"]

# last 10 rows of train_input_data
input_data = train_input_data.tail(50)
target_data = train_target_data.tail(50)

for _, row in input_data.iterrows():
    input_id = row['id']
    input_text = row['semtypes']
    
    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Generate captions using the model
    outputs = model.generate(input_ids)

    # Decode the generated captions
    generated_captions = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    target_caption = train_target_data[train_target_data['id'] == input_id]['caption'].values[0]
    print('Input: ', input_text)
    print('Generated Caption: ', generated_captions)
    print('Original Caption: ', target_caption)
    print()

Input:  T060 Diagnostic Procedure T170 Intellectual Product T041 Mental Process T121 Pharmacologic Substance T125 Hormone T116 Amino Acid, Peptide, or Protein T023 Body Part, Organ, or Organ Component
Generated Caption:  Axial computed tomography of the right sagittal artery (CT
Original Caption:   Radiograph showing mixed density, periosteal reaction, and cortical disruption with soft tissue swelling in metaphysis of left fibula

Input:  T046 Pathologic Function T170 Intellectual Product T074 Medical Device T121 Pharmacologic Substance T033 Finding T023 Body Part, Organ, or Organ Component T109 Organic Chemical
Generated Caption:  Axial T1-weighted image of the tibia demonstrating
Original Caption:   The bleeding site was treated with a 3mm x10mm coronary artery covered stent (Jo Stent Graftmaster) (white arrows).

Input:  T130 Indicator, Reagent, or Diagnostic Aid T074 Medical Device T061 Therapeutic or Preventive Procedure T170 Intellectual Product
Generated Caption:  Axial X-ray of

Input:  T060 Diagnostic Procedure T170 Intellectual Product T041 Mental Process T130 Indicator, Reagent, or Diagnostic Aid T191 Neoplastic Process T033 Finding T023 Body Part, Organ, or Organ Component T109 Organic Chemical
Generated Caption:  Contrast computed tomography (CT) scan of the pelvis showing a
Original Caption:   Hemangioma. Color Doppler image reveals a highly vascular mass consistent with capillary hemangioma at the superolateral margin of the right orbit.

Input:  T060 Diagnostic Procedure T061 Therapeutic or Preventive Procedure
Generated Caption:  Preoperative radiograph of the patient at the time of the first procedure.
Original Caption:   Oesophageal intubation as seen on sonography

Input:  T109 Organic Chemical T033 Finding T121 Pharmacologic Substance T060 Diagnostic Procedure
Generated Caption:  CT scan showing a calcified aorta (arrow).
Original Caption:   Computer tomography demonstrating a well-defined fat density submucosal lesion

Input:  T074 Medical Device

Input:  T046 Pathologic Function T061 Therapeutic or Preventive Procedure T121 Pharmacologic Substance T047 Disease or Syndrome T037 Injury or Poisoning T109 Organic Chemical T201 Clinical Attribute
Generated Caption:  Preoperative axial MRI of a patient with a septic sy
Original Caption:   Hinchey stage III. Intravenous and rectal contrast-enhanced axial CT of the abdomen showing diverticulitis with multiple abscesses (arrow) in the inframesocolic region and pneumoperitoneum, together with generalized peritonitis.

Input:  T074 Medical Device T061 Therapeutic or Preventive Procedure
Generated Caption:  A stent placed on the stent.
Original Caption:   Decompression tube is inserted followed by the guidewire.

Input:  T046 Pathologic Function T060 Diagnostic Procedure T023 Body Part, Organ, or Organ Component
Generated Caption:  Preoperative X-ray of the pelvis showing a large calcified ilia
Original Caption:   Doppler ultrasound revealed extensive thrombosis involving the iliac veins

