[export][non-strict] Passing kwargs to torch.export fails at non_strict_utils.fakify() #140596

@ColinPeppler

🐛 Describe the bug

Repro

import requests
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor


def main():
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-large"
    ).to("cuda")

    img_url = (
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
    )
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

    # conditional image captioning
    text = "a photography of"
    # The processor returns a transformers BatchFeature, not a plain dict.
    inputs = processor(raw_image, text, return_tensors="pt").to("cuda")

    # Eager generation works when the BatchFeature is unpacked as kwargs.
    out = model.generate(**inputs)
    print(processor.decode(out[0], skip_special_tokens=True))

    # Fails: the BatchFeature is passed whole as the kwargs argument.
    exported = torch.export.export(model, tuple(), inputs, strict=False)


if __name__ == "__main__":
    main()

Error

Traceback (most recent call last):
  ...
  File "scripts/henrylhtsang/repros/aot.py", line 50, in main
    exported = torch.export.export(model, tuple(), inputs, strict=False)
  File "torch/export/__init__.py", line 368, in export
    return _export(
  File "torch/export/_trace.py", line 1031, in wrapper
    raise e
  File "torch/export/_trace.py", line 1004, in wrapper
    ep = fn(*args, **kwargs)
  File "torch/export/exported_program.py", line 122, in wrapper
    return fn(*args, **kwargs)
  File "torch/export/_trace.py", line 1957, in _export
    export_artifact = export_func(  # type: ignore[operator]
  File "torch/export/_trace.py", line 1707, in _non_strict_export
    ) = make_fake_inputs(
  File "torch/_export/non_strict_utils.py", line 206, in make_fake_inputs
    fake_args, fake_kwargs = tree_map_with_path(
  File "torch/utils/_pytree.py", line 1608, in tree_map_with_path
    return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
  File "torch/utils/_pytree.py", line 803, in unflatten
    leaves = list(leaves)
  File "torch/utils/_pytree.py", line 1608, in <genexpr>
    return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
  File "torch/_export/non_strict_utils.py", line 207, in <lambda>
    lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
  File "torch/_export/non_strict_utils.py", line 96, in fakify
    raise ValueError(f"Unsupported input type {type(t)}")
ValueError: Unsupported input type <class 'transformers.image_processing_base.BatchFeature'>

Notes

  • With strict=True, export succeeds.
  • With strict=False and the inputs passed as args instead of kwargs, export succeeds.
  • With strict=False and the inputs passed as kwargs, export fails with the error above.
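
The traceback shows fakify() receiving the whole BatchFeature as a single leaf value, which suggests the kwargs container was not traversed the way a plain dict would be. Below is a minimal stdlib-only sketch of that behaviour. It assumes BatchFeature is a collections.UserDict subclass; FakeBatchFeature and flatten_leaves are hypothetical stand-ins written for illustration, not torch's pytree code or the transformers class.

```python
from collections import UserDict


class FakeBatchFeature(UserDict):
    """Hypothetical stand-in for transformers' BatchFeature (assumed UserDict-based)."""


def flatten_leaves(tree):
    """Toy flattener: recurses only into plain dict, list, and tuple.

    Anything else, including a UserDict subclass, is returned as one leaf.
    This is an illustration only, not torch.utils._pytree.
    """
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in flatten_leaves(value)]
    if isinstance(tree, (list, tuple)):
        return [leaf for value in tree for leaf in flatten_leaves(value)]
    return [tree]


inputs = FakeBatchFeature({"pixel_values": [1.0, 2.0], "input_ids": [3, 4]})

# A UserDict is not a dict instance, so the wrapper itself becomes the only leaf.
leaves_wrapped = flatten_leaves(inputs)

# Converting to a plain dict exposes the individual values as leaves.
leaves_plain = flatten_leaves(dict(inputs))

print(len(leaves_wrapped), type(leaves_wrapped[0]).__name__)
print(leaves_plain)
```

If the same reasoning carries over to the repro, a possible workaround is to pass dict(inputs) instead of inputs as the kwargs argument to torch.export.export. That is untested here, and the issue does not confirm that export then succeeds for this model.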

Versions

n/a

cc @ezyang @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Labels

  • empathy-day (Label for issues from user empathy days)
  • module: bootcamp (We plan to do a full writeup on the issue, and then get someone to do it for onboarding)
  • oncall: export
  • oncall: pt2
  • triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
