# Demonstration Notebook: CT-AGRG Model Initialization, Loading, and Inference

This notebook demonstrates the initialization, loading, and inference procedures for the CT-AGRG model.

In [None]:
import os

# Run from the repository root (the directory containing `step2/`) so that the relative
# paths used below (`./step2/config/...`, `./ckpt/...`) resolve. The guard makes re-running
# this cell a no-op instead of climbing one more directory each time.
if not os.path.isdir("step2"):
    os.chdir("..")
print(os.getcwd())

In [None]:
import time
import sys
import os
# Make the directory above the working directory importable. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

## 1. Config file

In [None]:
import torch
import argparse

from step2.modules.utils_dir import parse_yaml

# Configuration file, relative to the repository root (the working directory).
YAML_FILE = "./step2/config/default.yaml"
# GPU to run on; adjust to a GPU that exists on your machine.
GPU_INDEX = 2

parser = argparse.ArgumentParser(description='Script with parameters from YAML file')
parser.add_argument('--yaml_file', type=str, default=YAML_FILE, help='YAML file containing parameters')
# Parse an empty argument list: inside a notebook, sys.argv holds the kernel's own arguments.
args = parser.parse_args([])

# Replace the CLI namespace with the YAML parameters, remembering which file they came from.
params         = parse_yaml(args.yaml_file)
yaml_file      = args.yaml_file
args           = argparse.Namespace(**params)
args.yaml_file = yaml_file

# Use the requested GPU when it exists, the first GPU otherwise, and the CPU when CUDA
# is unavailable, so the demo does not crash on machines with fewer GPUs.
if torch.cuda.is_available():
    gpu_index   = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    args.device = torch.device(f'cuda:{gpu_index}')
else:
    args.device = torch.device('cpu')
device = args.device

print(f'Warnings! Device: {device} will be used.')

## 2. Initialize and load CT-AGRG

In [None]:
from step2.models.report_generation_model import ReportGenerationModel

# Please adjust paths
args.path_gpt2         = "./ckpt/gpt-2-pubmed-medium"
args.path_report_model = "./ckpt/model_state_dict.pth"
args.path_thresholds   = "./ckpt/thresholds.json"

model = ReportGenerationModel(args, mode="generation")

# Restore the weights onto the CPU first. Without map_location, tensors go back to the
# device they were saved from, which may not exist on this machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load(args.path_report_model, map_location="cpu")
msg  = model.load_state_dict(ckpt)
print(msg)

# Move to the selected device, then prepare for inference.
# `freeze` is defined by ReportGenerationModel (presumably disables gradients — confirm).
model.to(device)
model.freeze()
model.eval();

## 3. Load the GPT-2 tokenizer

In [None]:
from transformers import GPT2Tokenizer

def get_tokenizer(args):
    """
    Build the GPT-2 tokenizer used to decode generated reports.

    The tokenizer is loaded from ``args.path_gpt2``. Its end-of-sequence token is
    reused as the padding token.
    """
    tok = GPT2Tokenizer.from_pretrained(args.path_gpt2)
    # GPT-2 has no dedicated pad token; padding with EOS is the usual workaround.
    tok.pad_token = tok.eos_token
    return tok


tokenizer = get_tokenizer(args)

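## 4. Load the CT scan

The loader below expects a pre-processed NumPy `.npz` archive (array stored under the default `arr_0` key) that is already in SLP orientation with values in Hounsfield Units. Adapt it to your own data format.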

In [None]:
import numpy as np

def _center_crop_pad_bounds(size, target):
    """
    Plan a center crop followed by a center pad of one axis from `size` to `target`.

    Returns (start, end, pad_before, pad_after): the slice [start:end] keeps at most
    `target` centered elements, and the pads bring the kept length up to exactly `target`.
    """
    start = max((size - target) // 2, 0)
    end = min(start + target, size)
    kept = end - start
    pad_before = (target - kept) // 2
    pad_after = target - kept - pad_before
    return start, end, pad_before, pad_after


def nii_img_to_tensor(path, dd=240, dh=480, dw=480, hu_min=-1000., hu_max=200., mean=0.449):
    """
    Read a CT volume and pre-process it for CT-AGRG.

    Warning: despite its name, this function reads a NumPy ``.npz`` archive, not a
    NIfTI file. The array must be stored under the default ``arr_0`` key. The volume
    is assumed to already be in SLP orientation with values in Hounsfield Units.
    Adjust this function to your own data.

    Args:
        path: path to the ``.npz`` archive holding a 3-D array [D, H, W].
        dd, dh, dw: target depth, height and width after center crop / pad.
        hu_min, hu_max: HU window the values are clipped to before scaling to [0, 1].
        mean: value subtracted after scaling (ImageNet mean 0.449 by default; no
            division by a standard deviation is applied).

    Returns:
        float32 tensor of shape [1, dd, dh, dw].
    """
    # np.load on an .npz returns an open archive; close it once the array is read.
    with np.load(path) as archive:
        array = archive['arr_0']  # [D, H, W]

    # Work in float32 regardless of the on-disk dtype (e.g. int16 HU or float64).
    tensor = torch.as_tensor(array, dtype=torch.float32)

    # Clip to the HU window, map the window to [0, 1], then subtract the mean.
    tensor = torch.clip(tensor, hu_min, hu_max)
    tensor = (tensor - hu_min) / (hu_max - hu_min)
    tensor = tensor - mean

    # Center-crop every axis that is too large, then center-pad every axis that is too small.
    d, h, w = tensor.shape
    d_start, d_end, pad_d_before, pad_d_after = _center_crop_pad_bounds(d, dd)
    h_start, h_end, pad_h_before, pad_h_after = _center_crop_pad_bounds(h, dh)
    w_start, w_end, pad_w_before, pad_w_after = _center_crop_pad_bounds(w, dw)

    tensor = tensor[d_start:d_end, h_start:h_end, w_start:w_end]

    # F.pad lists padding from the last dimension backwards: (W, H, D).
    # The fill value is what `hu_min` maps to after normalization (0 - mean).
    tensor = torch.nn.functional.pad(
        tensor,
        (pad_w_before, pad_w_after, pad_h_before, pad_h_after, pad_d_before, pad_d_after),
        value=-mean,
    )

    return tensor.unsqueeze(0)  # [1, dd, dh, dw]

In [None]:
# Please adjust with your own path.
path_volume = "/path/to/volume"  # replace with the location of your own pre-processed volume
volume = nii_img_to_tensor(path=path_volume)  # -> [1, 240, 480, 480] with the default target size

## 5. Run inference to generate the report

In [None]:
# extract generated report as a string
# Add the batch dimension ([1, D, H, W] -> [1, 1, D, H, W]) and move the input to the
# model's device. The model was moved to `device`, but `volume` was built on the CPU.
# `.to` is a no-op if the tensor is already there.
volumes = volume.unsqueeze(0).to(device)

# Generation needs no gradients; no_grad avoids recording an autograd graph.
with torch.no_grad():
    generated_report = model.generate(
        tokenizer            = tokenizer,
        volumes              = volumes,
        max_length           = args.max_seq_length,
        num_beams            = args.beam_size,
        num_beam_groups      = args.group_size,
        do_sample            = args.do_sample,
        num_return_sequences = args.num_return_sequences,
        early_stopping       = args.early_stopping
    )

print(generated_report)