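# FromageModel Inference Notebook

Loads a trained FROMAGe checkpoint, samples one example from the MIMIC-CXR dataset, and prompts the model with that image followed by a question about it.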
## 0. Imports & Config

In [1]:
import torch
import torch.nn as nn
import yaml
import random
import sys
import os
import os.path as osp

In [2]:
sys.path.append('../fromage')
from model import Fromage, FromageModel
from experiment import Experiment
from data import MIMICDataset, cxr_image_transform
from utils import preprocess_report

# from fromage.model import Fromage, FromageModel 
# from fromage.experiment import Experiment
# from fromage.data import MIMICDataset, COCODataset, cxr_image_transform, coco_image_transform
# from fromage.utils import preprocess_report

In [3]:
# Paths for the run being evaluated. The relative paths resolve against the
# notebook's working directory (assumed to sit next to ../logs, ../config and
# ../data — TODO confirm if the notebook is moved).
ckpt_path = "../logs/checkpoints/untied_test/last.ckpt"
config_path = "../config/train-untied.yaml"
dataset_path = "../data/MIMIC_JPG.tsv"
# Image directory handed to MIMICDataset alongside the TSV.
img_path = "/datasets/mimic/cxr-jpg/physionet.org/files/mimic-cxr-jpg/2.0.0/files"

# Fail fast with a clear message instead of an opaque error deep inside
# checkpoint or dataset loading.
for _label, _path in (
    ("checkpoint", ckpt_path),
    ("config", config_path),
    ("dataset TSV", dataset_path),
    ("image root", img_path),
):
    if not osp.exists(_path):
        raise FileNotFoundError(f"{_label} path does not exist: {_path!r}")

## 1. Select device and load config file

In [4]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Read the YAML as UTF-8 explicitly so loading does not depend on the
# platform's locale encoding; safe_load refuses to build arbitrary objects.
with open(config_path, encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# safe_load returns None for an empty file (and a scalar/list for a malformed
# one); surface that here rather than later when the config is consumed.
if not isinstance(config, dict):
    raise ValueError(
        f"Expected a mapping at the top level of {config_path!r}, "
        f"got {type(config).__name__}"
    )

## 2. Load model and dataset

In [5]:
# NOTE(review): Experiment is assumed to be a PyTorch Lightning module, where
# `load_from_checkpoint` is a classmethod that builds a brand-new instance from
# the checkpoint (using the hyperparameters stored in it). The previous code
# first constructed a throwaway Experiment(config) and called the method on that
# instance — the instance was ignored, its construction cost time/memory, and
# recent Lightning releases reject instance calls. Calling it on the class keeps
# the same result without the extra model.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
experiment = Experiment.load_from_checkpoint(ckpt_path, map_location=device)

# Keep only the wrapped model for inference, placed on the target device.
model = experiment.model.to(device)
# Set explicitly as before; presumably the model uses this attribute to place
# its inputs during generation — TODO confirm.
model.device = device

In [6]:
# Evaluation-time preprocessing (train=False) with resize=512 and
# center_crop_size=480; the exact semantics live in data.cxr_image_transform.
# The transform is handed to MIMICDataset, which presumably applies it to each
# sample's image.
transform = cxr_image_transform(
    train=False,
    resize=512,
    center_crop_size=480,
)
dataset = MIMICDataset(dataset_path, img_path, transform)

## 3. Sample a random example (image + report) from the dataset

In [7]:
# Seed for picking the example so re-runs show the same sample; set to None to
# draw a different example each run. A dedicated Random instance keeps this
# choice independent of any other use of the global `random` state.
EXAMPLE_SEED = 0
example_rng = random.Random(EXAMPLE_SEED)

# randrange(n) draws from [0, n), the same range as randint(0, n - 1).
ex_idx = example_rng.randrange(len(dataset))
ex_img, ex_report = dataset[ex_idx]

print(f"example index: {ex_idx}")
# Show the image's shape instead of dumping every pixel value; fall back to the
# object itself when it has no `.shape` (e.g. if it is not a tensor).
print(getattr(ex_img, "shape", ex_img))
print(ex_report)

<PIL.Image.Image image mode=L size=3050x2539 at 0x2B2022A4B940>
tensor([[[0.4275, 0.4196, 0.4118,  ..., 0.2314, 0.7765, 0.8235],
         [0.4510, 0.4314, 0.4196,  ..., 0.2314, 0.7765, 0.8235],
         [0.4118, 0.3922, 0.3725,  ..., 0.1843, 0.7137, 0.8196],
         ...,
         [0.8431, 0.8392, 0.8431,  ..., 0.2392, 0.2235, 0.2078],
         [0.8431, 0.8471, 0.8431,  ..., 0.2471, 0.2314, 0.2157],
         [0.8431, 0.8431, 0.8471,  ..., 0.2275, 0.2118, 0.1961]],

        [[0.4275, 0.4196, 0.4118,  ..., 0.2314, 0.7765, 0.8235],
         [0.4510, 0.4314, 0.4196,  ..., 0.2314, 0.7765, 0.8235],
         [0.4118, 0.3922, 0.3725,  ..., 0.1843, 0.7137, 0.8196],
         ...,
         [0.8431, 0.8392, 0.8431,  ..., 0.2392, 0.2235, 0.2078],
         [0.8431, 0.8471, 0.8431,  ..., 0.2471, 0.2314, 0.2157],
         [0.8431, 0.8431, 0.8471,  ..., 0.2275, 0.2118, 0.1961]],

        [[0.4275, 0.4196, 0.4118,  ..., 0.2314, 0.7765, 0.8235],
         [0.4510, 0.4314, 0.4196,  ..., 0.2314, 0.7765, 0.8

## 4. Inference

In [9]:
# Sampling settings, forwarded unchanged to generation. GEN_SEED makes the
# sampled text reproducible across re-runs; set it to None to leave the RNG unseeded.
GEN_SEED = 0
TOP_P = 0.9
TEMPERATURE = 0.5
question = "Question: What does the previous image show? Answer: "

# Switch to eval mode once, as setup, before generating.
model.eval()
if GEN_SEED is not None:
    torch.manual_seed(GEN_SEED)  # seeds the CPU and all CUDA generators

with torch.inference_mode():
    # Interleaved prompt: the image first, then the text that refers back to it.
    prompts = [ex_img, question]
    generated = model.generate_for_images_and_texts(
        prompts, top_p=TOP_P, temperature=TEMPERATURE
    )
print(generated)

  I don't know.

Question: What does the previous image show? Answer:   I don't know.

Question: What does
tensor([[[0.3843, 0.3725, 0.3647,  ..., 0.0000, 0.0000, 0.0000],
         [0.3765, 0.3725, 0.3686,  ..., 0.0000, 0.0000, 0.0000],
         [0.3843, 0.3725, 0.3647,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.9059, 0.8314, 0.6980,  ..., 0.6157, 0.5961, 0.5843],
         [0.8863, 0.8118, 0.6980,  ..., 0.6235, 0.6039, 0.5961],
         [0.8863, 0.8118, 0.6706,  ..., 0.6275, 0.6118, 0.6118]],

        [[0.3843, 0.3725, 0.3647,  ..., 0.0000, 0.0000, 0.0000],
         [0.3765, 0.3725, 0.3686,  ..., 0.0000, 0.0000, 0.0000],
         [0.3843, 0.3725, 0.3647,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.9059, 0.8314, 0.6980,  ..., 0.6157, 0.5961, 0.5843],
         [0.8863, 0.8118, 0.6980,  ..., 0.6235, 0.6039, 0.5961],
         [0.8863, 0.8118, 0.6706,  ..., 0.6275, 0.6118, 0.6118]],

        [[0.3843, 0.3725, 0.3647,  ..., 0.0000, 0.0000, 0.0000],
         [0.3765