In [1]:
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from tqdm import tqdm
from transformers import AutoTokenizer

from trl import AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed

In [2]:
config = PPOConfig(
    model_name="./finetuned-checkpoints/biobart-base--mimic-cxr/checkpoint-19600",
    learning_rate=1.41e-5,
    log_with=None,
    mini_batch_size=4,
    batch_size=16,
    gradient_accumulation_steps=1,
    early_stopping=False,
    target_kl=6.0,
    kl_penalty="kl",
    seed=42,
    use_score_scaling=False,
    use_score_norm=False,
    score_clip=None,
)

tqdm.pandas()

## Load data and models

In [3]:
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

In [4]:
from pathlib import Path
import datasets
from datasets import Image
from torchvision import transforms

#dataset_config = 'mimic-cxr','mimic-iii'  
#split = 'train','validate',test
def build_dataset(dataset_config, tokenizer, split):
    def generate_image_path(line):
        return str(Path(data_path).joinpath(dataset_config).joinpath(line.strip().split(',')[0]))
    
    data_path = '/nfs/turbo/umms-vgvinodv/data/bioNLP23-Task-1B/data/'
    
    findings_file_path = Path(data_path).joinpath(dataset_config).joinpath(split+'.findings.tok')
    impression_file_path = Path(data_path).joinpath(dataset_config).joinpath(split+'.impression.tok')
    image_file_path = Path(data_path).joinpath(dataset_config).joinpath(split+'.image.tok')


    findings = [line.strip() for line in open(findings_file_path).readlines()]
    impression = [line.strip() for line in open(impression_file_path).readlines()]
    image_paths = [generate_image_path(line) for line in open(image_file_path).readlines()]
    
    dataset = datasets.Dataset.from_dict({"text":findings,"image":image_paths})
    
    def check_img_exists(example):
        return example["image"].split('/')[10] != 'p10'

    dataset = dataset.filter(check_img_exists)
    dataset = dataset.cast_column("image", Image())
    
    def tokenize(samples):
        input_text = ["summarize: "+text for text in samples["text"]]
        samples["input_ids"] = tokenizer(input_text)["input_ids"]
        return samples
    
    dataset = dataset.map(tokenize, batched=True, num_proc=4, remove_columns=["text"])
    
    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    def image_transforms(samples):
        samples["query"] = [transform(image.convert("RGB").resize((384,384))) for image in samples["image"]]
        return samples
    
    dataset.set_transform(image_transforms)
    
    return dataset

In [5]:
dataset_config = "mimic-cxr"
tokenized_train_data = build_dataset(dataset_config,tokenizer,"train")
tokenized_eval_data = build_dataset(dataset_config,tokenizer,"validate")

Filter:   0%|          | 0/125417 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/113182 [00:00<?, ? examples/s]

Filter:   0%|          | 0/991 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/908 [00:00<?, ? examples/s]

In [6]:
tokenized_train_data[0].keys()

dict_keys(['input_ids', 'attention_mask', 'labels', 'images'])

In [7]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=config.model_name)

In [None]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

## Initialize PPOTrainer

In [8]:
#ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=tokenized_train_data, data_collator=data_collator)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=tokenized_train_data, data_collator=collator)

In [12]:
#def collator(data):
#    return dict((key, [d[key] for d in data]) for key in data[0])
#ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=tokenized_train_data, data_collator=collator)

## Load Reward Model

In [1]:
from models.reward import get_reward_model
reward_model = get_reward_model()

## Optimize Model

In [13]:
    
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    print(batch.keys())
    query_tensors = batch["input_ids"]
    image_tensors = batch["query"]
    
    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        response = ppo_trainer.generate(query)
        response_tensors.append(response.squeeze())
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
    #batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    #### Compute sentiment score
    outputs = [reward_model.predict_itm(image,text) for image,text in zip(image_tensors,batch["response"])]
    rewards = [torch.tensor(output) for output in outputs]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

0it [00:00, ?it/s]


dict_keys(['input_ids', 'attention_mask'])


KeyError: 'images'