diff --git a/.github/workflows/_unittest.yml b/.github/workflows/_unittest.yml
index a53026bc129..74ea5ca7bcc 100644
--- a/.github/workflows/_unittest.yml
+++ b/.github/workflows/_unittest.yml
@@ -37,6 +37,9 @@ jobs:
         CMAKE_ARGS="-DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" \
         .ci/scripts/setup-linux.sh cmake
 
+        # Install llama3_2_vision dependencies.
+        PYTHON_EXECUTABLE=python ./examples/models/llama3_2_vision/install_requirements.sh
+
         # Run pytest with coverage
         pytest -n auto --cov=./ --cov-report=xml
         # Run gtest
@@ -67,6 +70,10 @@ jobs:
         ${CONDA_RUN} --no-capture-output \
         .ci/scripts/setup-macos.sh cmake
 
+        # Install llama3_2_vision dependencies.
+        PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \
+        ./examples/models/llama3_2_vision/install_requirements.sh
+
         # Run pytest with coverage
         ${CONDA_RUN} pytest -n auto --cov=./ --cov-report=xml
         # Run gtest
diff --git a/examples/models/llama3_2_vision/preprocess/model.py b/examples/models/llama3_2_vision/preprocess/model.py
index 7b3b4869af6..eb93a089d88 100644
--- a/examples/models/llama3_2_vision/preprocess/model.py
+++ b/examples/models/llama3_2_vision/preprocess/model.py
@@ -26,6 +26,9 @@ class PreprocessConfig:
     max_num_tiles: int = 4
     tile_size: int = 224
     antialias: bool = False
+    # Used for reference eager model from torchtune.
+    resize_to_max_canvas: bool = False
+    possible_resolutions: Optional[List[Tuple[int, int]]] = None
 
 
 class CLIPImageTransformModel(EagerModelBase):
diff --git a/examples/models/llama3_2_vision/preprocess/test_preprocess.py b/examples/models/llama3_2_vision/preprocess/test_preprocess.py
index 73a3fd29607..83a05877495 100644
--- a/examples/models/llama3_2_vision/preprocess/test_preprocess.py
+++ b/examples/models/llama3_2_vision/preprocess/test_preprocess.py
@@ -3,34 +3,31 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
-import unittest
-
-from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
 
 import numpy as np
 import PIL
+import pytest
 import torch
 
+# Import these first. Otherwise, the custom ops are not registered.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
-from executorch.examples.models.llama3_2_vision.preprocess.export_preprocess_lib import (
-    export_preprocess,
-    get_example_inputs,
-    lower_to_executorch_preprocess,
+from executorch.extension.llm.custom_ops import op_tile_crop_aot  # noqa # usort: skip
+
+from executorch.examples.models.llama3_2_vision.preprocess.model import (
+    CLIPImageTransformModel,
+    PreprocessConfig,
 )
+
+from executorch.exir import EdgeCompileConfig, to_edge
+
 from executorch.extension.pybindings.portable_lib import (
     _load_for_executorch_from_buffer,
 )
-from parameterized import parameterized
 from PIL import Image
-from torchtune.models.clip.inference._transform import (
-    _CLIPImageTransform,
-    CLIPImageTransform,
-)
+from torchtune.models.clip.inference._transform import CLIPImageTransform
 
 from torchtune.modules.transforms.vision_utils.get_canvas_best_fit import (
     find_supported_resolutions,
@@ -40,24 +37,74 @@ from torchtune.modules.transforms.vision_utils.get_inscribed_size import (
     get_inscribed_size,
 )
+
 from torchvision.transforms.v2 import functional as F
 
 
-@dataclass
-class PreprocessConfig:
-    image_mean: Optional[List[float]] = None
-    image_std: Optional[List[float]] = None
-    resize_to_max_canvas: bool = True
-    resample: str = "bilinear"
-    antialias: bool = False
-    tile_size: int = 224
-    max_num_tiles: int = 4
-    possible_resolutions = None
+def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
+    config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)
+
+    reference_model = CLIPImageTransform(
+        image_mean=config.image_mean,
+        image_std=config.image_std,
+        resample=config.resample,
+        antialias=config.antialias,
+        tile_size=config.tile_size,
+        max_num_tiles=config.max_num_tiles,
+        resize_to_max_canvas=config.resize_to_max_canvas,
+        possible_resolutions=None,
+    )
+
+    model = CLIPImageTransformModel(config)
+
+    exported_model = torch.export.export(
+        model.get_eager_model(),
+        model.get_example_inputs(),
+        dynamic_shapes=model.get_dynamic_shapes(),
+        strict=False,
+    )
+
+    # aoti_path = torch._inductor.aot_compile(
+    #     exported_model.module(),
+    #     model.get_example_inputs(),
+    # )
+
+    edge_program = to_edge(
+        exported_model, compile_config=EdgeCompileConfig(_check_ir_validity=False)
+    )
+    executorch_model = edge_program.to_executorch()
+
+    return {
+        "config": config,
+        "reference_model": reference_model,
+        "model": model,
+        "exported_model": exported_model,
+        # "aoti_path": aoti_path,
+        "executorch_model": executorch_model,
+    }
+
+
+# From https://github.com/pytorch/torchtune/blob/main/tests/test_utils.py#L231
+def assert_expected(
+    actual: Any,
+    expected: Any,
+    rtol: float = 1e-5,
+    atol: float = 1e-8,
+    check_device: bool = True,
+):
+    torch.testing.assert_close(
+        actual,
+        expected,
+        rtol=rtol,
+        atol=atol,
+        check_device=check_device,
+        msg=f"actual: {actual}, expected: {expected}",
+    )
 
 
-class TestImageTransform(unittest.TestCase):
+class TestImageTransform:
     """
-    This unittest checks that the exported image transform model produces the
+    This test checks that the exported image transform model produces the
     same output as the reference model.
 
     Reference model: CLIPImageTransform
@@ -66,7 +113,11 @@ class TestImageTransform(unittest.TestCase):
     https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L26
     """
 
-    def setUp(self):
+    models_no_resize = initialize_models(resize_to_max_canvas=False)
+    models_resize = initialize_models(resize_to_max_canvas=True)
+
+    @pytest.fixture(autouse=True)
+    def setup_function(self):
         np.random.seed(0)
 
     def prepare_inputs(
@@ -121,51 +172,7 @@ def prepare_inputs(
 
         return image_tensor, inscribed_size, best_resolution
 
-    # This test setup mirrors the one in torchtune:
-    # https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
-    # The values are slightly different, as torchtune uses antialias=True,
-    # and this test uses antialias=False, which is exportable (has a portable kernel).
-    @parameterized.expand(
-        [
-            (
-                (100, 400, 3),  # image_size
-                torch.Size([2, 3, 224, 224]),  # expected shape
-                False,  # resize_to_max_canvas
-                [0.2230, 0.1763],  # expected_tile_means
-                [1.0, 1.0],  # expected_tile_max
-                [0.0, 0.0],  # expected_tile_min
-                [1, 2],  # expected_aspect_ratio
-            ),
-            (
-                (1000, 300, 3),  # image_size
-                torch.Size([4, 3, 224, 224]),  # expected shape
-                True,  # resize_to_max_canvas
-                [0.5005, 0.4992, 0.5004, 0.1651],  # expected_tile_means
-                [0.9976, 0.9940, 0.9936, 0.9906],  # expected_tile_max
-                [0.0037, 0.0047, 0.0039, 0.0],  # expected_tile_min
-                [4, 1],  # expected_aspect_ratio
-            ),
-            (
-                (200, 200, 3),  # image_size
-                torch.Size([4, 3, 224, 224]),  # expected shape
-                True,  # resize_to_max_canvas
-                [0.5012, 0.5020, 0.5010, 0.4991],  # expected_tile_means
-                [0.9921, 0.9925, 0.9969, 0.9908],  # expected_tile_max
-                [0.0056, 0.0069, 0.0059, 0.0032],  # expected_tile_min
-                [2, 2],  # expected_aspect_ratio
-            ),
-            (
-                (600, 200, 3),  # image_size
-                torch.Size([3, 3, 224, 224]),  # expected shape
-                False,  # resize_to_max_canvas
-                [0.4472, 0.4468, 0.3031],  # expected_tile_means
-                [1.0, 1.0, 1.0],  # expected_tile_max
-                [0.0, 0.0, 0.0],  # expected_tile_min
-                [3, 1],  # expected_aspect_ratio
-            ),
-        ]
-    )
-    def test_preprocess(
+    def run_preprocess(
         self,
         image_size: Tuple[int],
         expected_shape: torch.Size,
@@ -175,45 +182,7 @@ def test_preprocess(
         expected_tile_min: List[float],
         expected_ar: List[int],
     ) -> None:
-        config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)
-
-        reference_model = CLIPImageTransform(
-            image_mean=config.image_mean,
-            image_std=config.image_std,
-            resize_to_max_canvas=config.resize_to_max_canvas,
-            resample=config.resample,
-            antialias=config.antialias,
-            tile_size=config.tile_size,
-            max_num_tiles=config.max_num_tiles,
-            possible_resolutions=None,
-        )
-
-        eager_model = _CLIPImageTransform(
-            image_mean=config.image_mean,
-            image_std=config.image_std,
-            resample=config.resample,
-            antialias=config.antialias,
-            tile_size=config.tile_size,
-            max_num_tiles=config.max_num_tiles,
-        )
-
-        exported_model = export_preprocess(
-            image_mean=config.image_mean,
-            image_std=config.image_std,
-            resample=config.resample,
-            antialias=config.antialias,
-            tile_size=config.tile_size,
-            max_num_tiles=config.max_num_tiles,
-        )
-
-        executorch_model = lower_to_executorch_preprocess(exported_model)
-        executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)
-
-        aoti_path = torch._inductor.aot_compile(
-            exported_model.module(),
-            get_example_inputs(),
-        )
-
+        models = self.models_resize if resize_to_max_canvas else self.models_no_resize
         # Prepare image input.
         image = (
             np.random.randint(0, 256, np.prod(image_size))
@@ -223,60 +192,129 @@ def test_preprocess(
             .reshape(image_size)
             .astype(np.uint8)
         )
         image = PIL.Image.fromarray(image)
 
         # Run reference model.
+        reference_model = models["reference_model"]
         reference_output = reference_model(image=image)
         reference_image = reference_output["image"]
         reference_ar = reference_output["aspect_ratio"].tolist()
 
         # Check output shape and aspect ratio matches expected values.
-        self.assertEqual(reference_image.shape, expected_shape)
-        self.assertEqual(reference_ar, expected_ar)
+        assert (
+            reference_image.shape == expected_shape
+        ), f"Expected shape {expected_shape} but got {reference_image.shape}"
+
+        assert (
+            reference_ar == expected_ar
+        ), f"Expected ar {expected_ar} but got {reference_ar}"
 
         # Check pixel values within expected range [0, 1]
-        self.assertTrue(0 <= reference_image.min() <= reference_image.max() <= 1)
+        assert (
+            0 <= reference_image.min() <= reference_image.max() <= 1
+        ), f"Expected pixel values in range [0, 1] but got {reference_image.min()} to {reference_image.max()}"
 
         # Check mean, max, and min values of the tiles match expected values.
         for i, tile in enumerate(reference_image):
-            self.assertAlmostEqual(
-                tile.mean().item(), expected_tile_means[i], delta=1e-4
+            assert_expected(
+                tile.mean().item(), expected_tile_means[i], rtol=0, atol=1e-4
             )
-            self.assertAlmostEqual(tile.max().item(), expected_tile_max[i], delta=1e-4)
-            self.assertAlmostEqual(tile.min().item(), expected_tile_min[i], delta=1e-4)
+            assert_expected(tile.max().item(), expected_tile_max[i], rtol=0, atol=1e-4)
+            assert_expected(tile.min().item(), expected_tile_min[i], rtol=0, atol=1e-4)
 
         # Check num tiles matches the product of the aspect ratio.
         expected_num_tiles = reference_ar[0] * reference_ar[1]
-        self.assertEqual(expected_num_tiles, reference_image.shape[0])
+        assert (
+            expected_num_tiles == reference_image.shape[0]
+        ), f"Expected {expected_num_tiles} tiles but got {reference_image.shape[0]}"
 
         # Pre-work for eager and exported models. The reference model performs these
         # calculations and passes the result to _CLIPImageTransform, the exportable model.
         image_tensor, inscribed_size, best_resolution = self.prepare_inputs(
-            image=image, config=config
+            image=image, config=models["config"]
         )
 
         # Run eager model and check it matches reference model.
+        eager_model = models["model"].get_eager_model()
         eager_image, eager_ar = eager_model(
             image_tensor, inscribed_size, best_resolution
         )
         eager_ar = eager_ar.tolist()
-        self.assertTrue(torch.allclose(reference_image, eager_image))
-        self.assertEqual(reference_ar, eager_ar)
+        assert_expected(eager_image, reference_image, rtol=0, atol=1e-4)
+        assert (
+            reference_ar == eager_ar
+        ), f"Eager model: expected {reference_ar} but got {eager_ar}"
 
         # Run exported model and check it matches reference model.
+        exported_model = models["exported_model"]
         exported_image, exported_ar = exported_model.module()(
             image_tensor, inscribed_size, best_resolution
         )
         exported_ar = exported_ar.tolist()
-        self.assertTrue(torch.allclose(reference_image, exported_image))
-        self.assertEqual(reference_ar, exported_ar)
+        assert_expected(exported_image, reference_image, rtol=0, atol=1e-4)
+        assert (
+            reference_ar == exported_ar
+        ), f"Exported model: expected {reference_ar} but got {exported_ar}"
 
         # Run executorch model and check it matches reference model.
+        executorch_model = models["executorch_model"]
+        executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)
         et_image, et_ar = executorch_module.forward(
             (image_tensor, inscribed_size, best_resolution)
         )
-        self.assertTrue(torch.allclose(reference_image, et_image))
-        self.assertEqual(reference_ar, et_ar.tolist())
+        assert_expected(et_image, reference_image, rtol=0, atol=1e-4)
+        assert (
+            reference_ar == et_ar.tolist()
+        ), f"Executorch model: expected {reference_ar} but got {et_ar.tolist()}"
 
         # Run aoti model and check it matches reference model.
-        aoti_model = torch._export.aot_load(aoti_path, "cpu")
-        aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
-        self.assertTrue(torch.allclose(reference_image, aoti_image))
-        self.assertEqual(reference_ar, aoti_ar.tolist())
+        # aoti_path = models["aoti_path"]
+        # aoti_model = torch._export.aot_load(aoti_path, "cpu")
+        # aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
+        # self.assertTrue(torch.allclose(reference_image, aoti_image))
+        # self.assertEqual(reference_ar, aoti_ar.tolist())
+
+    # This test setup mirrors the one in torchtune:
+    # https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
+    # The values are slightly different, as torchtune uses antialias=True,
+    # and this test uses antialias=False, which is exportable (has a portable kernel).
+    def test_preprocess1(self):
+        self.run_preprocess(
+            (100, 400, 3),  # image_size
+            torch.Size([2, 3, 224, 224]),  # expected shape
+            False,  # resize_to_max_canvas
+            [0.2230, 0.1763],  # expected_tile_means
+            [1.0, 1.0],  # expected_tile_max
+            [0.0, 0.0],  # expected_tile_min
+            [1, 2],  # expected_aspect_ratio
+        )
+
+    def test_preprocess2(self):
+        self.run_preprocess(
+            (1000, 300, 3),  # image_size
+            torch.Size([4, 3, 224, 224]),  # expected shape
+            True,  # resize_to_max_canvas
+            [0.5005, 0.4992, 0.5004, 0.1651],  # expected_tile_means
+            [0.9976, 0.9940, 0.9936, 0.9906],  # expected_tile_max
+            [0.0037, 0.0047, 0.0039, 0.0],  # expected_tile_min
+            [4, 1],  # expected_aspect_ratio
+        )
+
+    def test_preprocess3(self):
+        self.run_preprocess(
+            (200, 200, 3),  # image_size
+            torch.Size([4, 3, 224, 224]),  # expected shape
+            True,  # resize_to_max_canvas
+            [0.5012, 0.5020, 0.5010, 0.4991],  # expected_tile_means
+            [0.9921, 0.9925, 0.9969, 0.9908],  # expected_tile_max
+            [0.0056, 0.0069, 0.0059, 0.0032],  # expected_tile_min
+            [2, 2],  # expected_aspect_ratio
+        )
+
+    def test_preprocess4(self):
+        self.run_preprocess(
+            (600, 200, 3),  # image_size
+            torch.Size([3, 3, 224, 224]),  # expected shape
+            False,  # resize_to_max_canvas
+            [0.4472, 0.4468, 0.3031],  # expected_tile_means
+            [1.0, 1.0, 1.0],  # expected_tile_max
+            [0.0, 0.0, 0.0],  # expected_tile_min
+            [3, 1],  # expected_aspect_ratio
+        )
diff --git a/pytest.ini b/pytest.ini
index 1ca39f0a508..3666c9c879a 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -16,6 +16,7 @@ addopts =
     devtools/
     # examples
     examples/models/llama/tests
+    examples/models/llama3_2_vision/preprocess
     # examples/models/llava/test TODO: enable this
     # exir
     exir/_serialize/test
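
Note: a minimal sketch of reproducing the new CI coverage locally, assuming a development environment equivalent to what .ci/scripts/setup-linux.sh or setup-macos.sh produces. Both commands are taken from the workflow change above; with the new pytest.ini entry, the llama3_2_vision preprocess tests are collected automatically.

    # Install the llama3_2_vision example dependencies (same script the CI change runs).
    PYTHON_EXECUTABLE=python ./examples/models/llama3_2_vision/install_requirements.sh

    # Run pytest with coverage, as CI does; collection now includes
    # examples/models/llama3_2_vision/preprocess.
    pytest -n auto --cov=./ --cov-report=xml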