In [10]:
import transformers
import importlib
from fuyu import ai2d_data
import torch
from PIL import Image
from transformers import (
    Seq2SeqTrainer
)

# load model, tokenizer, and processor
pretrained_path = "adept/fuyu-8b"
tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_path)
image_processor = transformers.FuyuImageProcessor()
processor = transformers.FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)


In [2]:
with torch.inference_mode():
    model = transformers.FuyuForCausalLM.from_pretrained("adept/fuyu-8b", device_map="cuda:0", torch_dtype=torch.bfloat16)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [22]:
importlib.reload(ai2d_data)
ds = ai2d_data.AI2DDataset("/home/ubuntu/fuyu/ai2d")
dc = ai2d_data.DataCollatorForMultiModal(processor=processor, include_labels=True)

In [4]:
def clean(model_inputs):
    result = {}
    for k, v in model_inputs.items():
        if k == 'labels':
            continue
        tensor = v.to('cuda:0')
        if tensor.dtype == torch.float32:
            tensor = tensor.to(torch.bfloat16)
        result[k] = tensor
    return result

In [5]:
lora_module_names = set()
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        names = name.split('.')
        lora_module_names.add(names[0] if len(names)==1 else names[-1])

In [6]:
lora_module_names.remove('lm_head')
lora_module_names.remove('vision_embed_tokens')

In [7]:
from peft import LoraConfig, get_peft_model
from peft.tuners.lora import LoraLayer

In [8]:
config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=list(lora_module_names),
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM"
    #modules_to_save=["classifier"],
)
peft = get_peft_model(model, config)

In [9]:
for name, module in peft.named_modules():
    if isinstance(module, LoraLayer):
        module = module.to(torch.bfloat16)
    if "norm" in name:
        module = module.to(torch.float32)

In [26]:
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    output_dir="test_trainer",
    per_device_train_batch_size=1,
    do_train=True,
    generation_config=None)

In [27]:
trainer = Seq2SeqTrainer(model=model,
                         args=training_args,
                         tokenizer=tokenizer,
                         train_dataset=ds,

                         data_collator=dc)

In [28]:
train_result = trainer.train()

[{}]


KeyError: 'image'