# Chest X-ray Image Report Generation (CXIRG)

## Install Required Modules

In [1]:
!pip install openpyxl
!pip install pandas
!pip install pillow
!pip install pytorch-ignite
!pip install scikit-learn
!pip install torch
!pip install transformers



## Import Required Modules

In [2]:
import os
import random
import torch

import numpy as np
import pandas as pd

from ignite.metrics import Rouge
from pandas.core.common import random_state
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, ViTModel, ViTImageProcessor
from typing import Any, Dict, List, Tuple

  from .autonotebook import tqdm as notebook_tqdm


## Set The Random Seed

In [3]:
seed = 48763

np.random.seed(seed=seed, )

random_state(state=seed, )

random.seed(a=seed, )

torch.manual_seed(seed=seed, )
torch.cuda.manual_seed(seed=seed, )
torch.cuda.manual_seed_all(seed=seed, )
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## Set The Device & Initialize Models

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Language Model
lm_config = GPT2Config.from_pretrained("gpt2")
lm_config.add_cross_attention = True
lm = GPT2LMHeadModel.from_pretrained("gpt2", config=lm_config).to(device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Vision Model
vm = ViTModel.from_pretrained("google/vit-base-patch16-224").to(device)
for param in vm.parameters():
    param.requires_grad = False
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.bias', 'h.10.crossattention.q_attn.weight', 'h.10.ln_cross_attn.bias', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.11.crossat

## The CXIRG Dataset

In [5]:
class CXIRGDataset(Dataset):
    def __init__(self, data: List[Dict[str, Any]]) -> None:
        super(CXIRGDataset, self).__init__()
        self.data = data

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return self.data[index]

    def __len__(self) -> int:
        return len(self.data)

## The Collate Function for The DataLoader

In [6]:
prompt = "Please generate a report based on the given chest X-ray image."

def collate_fn(one_batch_data: List[Dict[str, Any]]):
    names = [one_data["name"] for one_data in one_batch_data]

    images_middle = processor(
        images=[one_data["image"] for one_data in one_batch_data], 
        return_tensors="pt"
    )
    images_middle = images_middle.to(device)

    with torch.no_grad():
        images_embedding = vm(**images_middle).last_hidden_state

    max_length = max([len(one_data["text"]) for one_data in one_batch_data])
    max_length = min(max_length, 1024)

    inputs_token = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=[
            (prompt + " " + tokenizer.eos_token + " " + one_data["text"]) for one_data in one_batch_data
        ],
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    attention_mask = inputs_token["attention_mask"]
    inputs_token = inputs_token["input_ids"]

    labels_token = inputs_token.clone()
    for idx, one_data in enumerate(one_batch_data):
        text_length = (len(tokenizer(prompt)["input_ids"]) + 1)
        labels_token[idx, :text_length] = -100

    return names, images_embedding.to(device), inputs_token.to(device), labels_token.to(device), attention_mask.to(device)

## Load The Train & Validation Data

In [7]:
train_data = []

report_path = "data/train_data/reports.xlsx"
report_df = pd.read_excel(report_path)

image_dir_path = "data/train_data/images"
for image_name in os.listdir(image_dir_path):
    image = Image.open(os.path.join(image_dir_path, image_name))
    if image.mode != "RGB":
        image = image.convert("RGB")

    text = report_df[report_df["name"] == image_name[:13]]["text"].values[0].replace("_x000D_", "\r")

    train_data.append({
        "name": image_name[:13],
        "image": image,
        "text": text
    })

train_dataset = CXIRGDataset(train_data)

In [8]:
valid_data = []

report_path = "data/valid_data/reports.xlsx"
report_df = pd.read_excel(report_path)

image_dir_path = "data/valid_data/images"
for image_name in os.listdir(image_dir_path):
    image = Image.open(os.path.join(image_dir_path, image_name))
    if image.mode != "RGB":
        image = image.convert("RGB")

    text = report_df[report_df["name"] == image_name[:13]]["text"].values[0].replace("_x000D_", "\r")

    valid_data.append({
        "name": image_name[:13],
        "image": image,
        "text": text
    })

valid_dataset = CXIRGDataset(valid_data)

## Set The Hyperparameters & Initialize The Optimizer, Dataloaders and Evaluation Metric

In [9]:
lr = 1e-5
epochs = 1
optimizer = AdamW(params=lm.parameters(), lr=lr)

train_batch_size = 8
valid_batch_size = 1
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=collate_fn
)
valid_dataloader = DataLoader(
    dataset=valid_dataset,
    batch_size=valid_batch_size,
    shuffle=False,
    collate_fn=collate_fn
)

rouge = Rouge(variants=["L", 2], multiref="best")

## The Evaluation Function

In [10]:
def evaluate(lm: GPT2LMHeadModel, epoch: int) -> Tuple[Dict[str, float], float]:
    lm.eval()

    pbar = tqdm(valid_dataloader)
    pbar.set_description(f"Evaluting Epoch: {epoch + 1}")

    loss_list = []

    with torch.no_grad():
        for names, images, inputs, labels, attention_mask in pbar:
            loss = lm(
                input_ids=inputs,
                attention_mask=attention_mask,
                encoder_hidden_states=images,
                labels=labels
            ).loss
            loss_list.append(loss.item())
            pbar.set_postfix(loss=loss.item())

            predictions_inputs = {
                "input_ids": inputs,
                "attention_mask": attention_mask,
                "encoder_hidden_states": images
            }

            predictions = [
                prediction for prediction in tokenizer.batch_decode(
                    lm.generate(
                        **predictions_inputs,
                        max_length=256
                    )
                )
            ]

            prompt_token = tokenizer.encode(
                text=prompt,
                return_tensors="pt"
            ).to(device).squeeze(0)

            _labels = []
            for _label in labels:
                _label = torch.cat((prompt_token, _label), 0)
                _label = _label.tolist()
                _label = [token for token in _label if token != -100]
                _labels.append(tokenizer.batch_decode([_label])[0])

            print(f"Names: {names}")
            print(f"Predictions: {predictions}")
            print(f"Labels: {_labels}")
            print()

            for prediction, _label in zip(predictions, _labels):
                split_prediction = prediction.split()
                split_label = _label.split()

                for one_word in split_prediction:
                    rouge.update(([one_word], [split_label]))

    return rouge.compute(), np.mean(np.array(loss_list))

In [11]:
for epoch in range(epochs):
    lm.train()

    pbar = tqdm(train_dataloader)
    pbar.set_description(f"Training Epoch [{epoch + 1} / {epochs}]")

    for _, images, inputs, labels, attention_mask in pbar:
        optimizer.zero_grad()
        loss = lm(
            input_ids=inputs,
            attention_mask=attention_mask,
            encoder_hidden_states=images,
            labels=labels
        ).loss
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item())

    torch.save(lm, f"outputs/checkpoint_{epoch}.pt")

    print(f"Rouge-2 score on epoch {epoch}:", evaluate(lm=lm, epoch=epoch))

Training Epoch [1 / 1]: 100%|██████████| 12/12 [00:05<00:00,  2.04it/s, loss=5.21]
Evaluting Epoch: 1:   0%|          | 0/10 [00:00<?, ?it/s, loss=4.89]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluting Epoch: 1:  10%|█         | 1/10 [00:00<00:02,  4.12it/s, loss=5.81]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluting Epoch: 1:  20%|██        | 2/10 [00:00<00:01,  5.77it/s, loss=5.81]

Names: ['NLP_CHEST_071']
Predictions: ['Please generate a report based on the given chest X-ray image. <|endoftext|> Chest film shows:\r\nImpression:\r\n-Bilateral lung infiltrations.\r\n-Suspect right lower lung patch. \r\n Blunting right CP angle. \r\n-Tortuous atherosclerotic aorta. \r\n-Scoliosis, DJD and osteoporosis of spine. \r\n Compression fracture of T12.\r\n Old fracture of left ribs.\r\n-S/P fixation in L-spine.  \r\n-S/P tracheostomy and NG tube.   \r\n-S/P tracheostomy and NG tube.   <|endoftext|>']
Labels: ['Please generate a report based on the given chest X-ray image.<|endoftext|> Chest film shows:\r\nImpression:\r\n-Bilateral lung infiltrations.\r\n-Suspect right lower lung patch. \r\n Blunting right CP angle. \r\n-Tortuous atherosclerotic aorta. \r\n-Scoliosis, DJD and osteoporosis of spine. \r\n Compression fracture of T12.\r\n Old fracture of left ribs.\r\n-S/P fixation in L-spine.  \r\n-S/P tracheostomy and NG tube.   \r\n']

Names: ['NLP_CHEST_002']
Predictions: 

Evaluting Epoch: 1:  20%|██        | 2/10 [00:00<00:01,  5.77it/s, loss=5.17]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluting Epoch: 1:  30%|███       | 3/10 [00:00<00:01,  6.57it/s, loss=4.64]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluting Epoch: 1:  40%|████      | 4/10 [00:00<00:00,  6.92it/s, loss=4.64]

Names: ['NLP_CHEST_004']
Predictions: ['Please generate a report based on the given chest X-ray image. <|endoftext|> Chest PA view show: \r\nImpression:\r\n-S/P RLL wedge resection.\u3000 \r\n-Bilateral lungs metastasis.\r\n-Left lower lung subsegmental atelectasis. \r\n-Increased infiltrations in both lungs.\r\n-Blunting right CP angle. \r\n-Tortuous atherosclerotic aorta.\r\n-Scoliosis, DJD and osteoporosis of spine. \r\n-Compression fracture of L1.<|endoftext|>']
Labels: ['Please generate a report based on the given chest X-ray image.<|endoftext|> Chest PA view show: \r\nImpression:\r\n-S/P RLL wedge resection.\u3000 \r\n-Bilateral lungs metastasis.\r\n-Left lower lung subsegmental atelectasis. \r\n-Increased infiltrations in both lungs.\r\n-Blunting right CP angle. \r\n-Tortuous atherosclerotic aorta.\r\n-Scoliosis, DJD and osteoporosis of spine. \r\n-Compression fracture of L1.']

Names: ['NLP_CHEST_031']
Predictions: ['Please generate a report based on the given chest X-ray image

Evaluting Epoch: 1:  40%|████      | 4/10 [00:00<00:00,  6.92it/s, loss=4.97]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluting Epoch: 1:  50%|█████     | 5/10 [00:00<00:00,  6.87it/s, loss=5.65]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Names: ['NLP_CHEST_057']
Predictions: ['Please generate a report based on the given chest X-ray image. <|endoftext|> Chest plain film shows:\r\nImpression:\r\n-Increased infiltrations in both lungs.\r\n-Tortuous atherosclerotic dilated aorta.\r\n-Normal heart size. \r\n-DJD of spine. \r\n Old fracture of right ribs.\r\n-Increased both lung markings. \r\n S/P Lt jugular CVC insertion. \r\n S/P NG and endotracheal tube.\r\n-Susp. Lt pneumothorax. \r\n Suspect pneumomediastinum. \r\n Subcutaneous emphysema in bilateral neck.  \r\n-S/P bilateral chest tube insertion. <|endoftext|>']
Labels: ['Please generate a report based on the given chest X-ray image.<|endoftext|> Chest plain film shows:\r\nImpression:\r\n-Increased infiltrations in both lungs.\r\n-Tortuous atherosclerotic dilated aorta.\r\n-Normal heart size. \r\n-DJD of spine. \r\n Old fracture of right ribs.\r\n-Increased both lung markings. \r\n S/P Lt jugular CVC insertion. \r\n S/P NG and endotracheal tube.\r\n-Susp. Lt pneumothor

Evaluting Epoch: 1:  60%|██████    | 6/10 [00:01<00:00,  6.37it/s, loss=4.79]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluting Epoch: 1:  70%|███████   | 7/10 [00:01<00:00,  6.79it/s, loss=4.92]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Names: ['NLP_CHEST_027']
Predictions: ['Please generate a report based on the given chest X-ray image. <|endoftext|> Chest PA View:\r\nImpression: \r\n> Cardiomegaly with bilateral pulmonary congestion.\r\n> Postinflammatory fibrosis in both upper lungs.\r\n> Atherosclerosis of aorta.\r\n> Old fractures of left 5th and 6th ribs.\r\n> R/O osteoporosis.\r\n> Spondylosis of thoracolumbar spine.\r\n> S/P abdominal operation in RUQ.<|endoftext|>']
Labels: ['Please generate a report based on the given chest X-ray image.<|endoftext|> Chest PA View:\r\nImpression: \r\n> Cardiomegaly with bilateral pulmonary congestion.\r\n> Postinflammatory fibrosis in both upper lungs.\r\n> Atherosclerosis of aorta.\r\n> Old fractures of left 5th and 6th ribs.\r\n> R/O osteoporosis.\r\n> Spondylosis of thoracolumbar spine.\r\n> S/P abdominal operation in RUQ.']

Names: ['NLP_CHEST_085']
Predictions: ['Please generate a report based on the given chest X-ray image. <|endoftext|> Chest X ray: \r\n\r\n- Right pne

Evaluting Epoch: 1:  80%|████████  | 8/10 [00:01<00:00,  6.20it/s, loss=5.28]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Evaluting Epoch: 1:  90%|█████████ | 9/10 [00:01<00:00,  5.94it/s, loss=5.99]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Names: ['NLP_CHEST_011']
Predictions: ["Please generate a report based on the given chest X-ray image. <|endoftext|> Chest AP view showed:\r\n1.s/p sternotomy and CABG.\r\n  Enlarged heart size with tortuous aorta.\r\n2.R't middle and lower lung faint patches.\r\n  L't lower lung consolidation.\r\n  L't pleural effusion.\r\n3.No mediastinum widening.\r\n4.s/p endotracheal tube and NG intubation.\r\n  L't lower lung and L't lower lung.<|endoftext|>"]
Labels: ["Please generate a report based on the given chest X-ray image.<|endoftext|> Chest AP view showed:\r\n1.s/p sternotomy and CABG.\r\n  Enlarged heart size with tortuous aorta.\r\n2.R't middle and lower lung faint patches.\r\n  L't lower lung consolidation.\r\n  L't pleural effusion.\r\n3.No mediastinum widening.\r\n4.s/p endotracheal tube and NG intubation.\r\n"]



Evaluting Epoch: 1: 100%|██████████| 10/10 [00:01<00:00,  6.08it/s, loss=5.99]

Names: ['NLP_CHEST_015']
Predictions: ['Please generate a report based on the given chest X-ray image. <|endoftext|> Chest X ray: \r\n\r\n- No obvious lung mass nor consolidation patch.\r\n- Normal heart size.\r\n- No pleural effusion. \U000fe309\n- No pleural effusion. \U000fe309<|endoftext|>']
Labels: ['Please generate a report based on the given chest X-ray image.<|endoftext|> Chest X ray: \r\n\r\n- No obvious lung mass nor consolidation patch.\r\n- Normal heart size.\r\n- No pleural effusion. ']

Rouge-2 score on epoch 0: ({'Rouge-L-P': 0.6607337529932065, 'Rouge-L-R': 0.9952925507485995, 'Rouge-L-F': 0.9952925507485995, 'Rouge-2-P': 0.8235696003058185, 'Rouge-2-R': 0.927234826449486, 'Rouge-2-F': 0.927234826449486}, 5.210005569458008)



