From 915f7cc70f63b93eaf001db6241379a4093b2f2a Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sat, 13 Jan 2024 22:35:37 -0800 Subject: [PATCH 01/35] add openflamingo --- requirements.txt | 3 +- src/helm/benchmark/model_metadata_registry.py | 5 +- src/helm/benchmark/run_expander.py | 24 ++++ src/helm/benchmark/run_specs.py | 6 + src/helm/config/model_deployments.yaml | 8 ++ src/helm/config/model_metadata.yaml | 11 +- .../vision_language/open_flamingo_client.py | 125 ++++++++++++++++++ 7 files changed, 179 insertions(+), 3 deletions(-) create mode 100644 src/helm/proxy/clients/vision_language/open_flamingo_client.py diff --git a/requirements.txt b/requirements.txt index 54b87e47e1..0274dcdc73 100644 --- a/requirements.txt +++ b/requirements.txt @@ -90,6 +90,7 @@ numba==0.56.4 numpy==1.23.3 openai==0.27.8 opencv-python==4.8.1.78 +open-flamingo==2.0.1 openpyxl==3.0.10 outcome==1.2.0 packaging==21.3 @@ -168,7 +169,7 @@ torchvision==0.13.1 ; sys_platform == "darwin" torch==1.12.1+cu113 ; sys_platform == "linux" torchvision==0.13.1+cu113 ; sys_platform == "linux" tqdm==4.64.1 -transformers==4.36.0 +transformers==4.28.1 trio==0.22.0 trio-websocket==0.9.2 typer==0.4.2 diff --git a/src/helm/benchmark/model_metadata_registry.py b/src/helm/benchmark/model_metadata_registry.py index a4b85edc81..9e5a19041b 100644 --- a/src/helm/benchmark/model_metadata_registry.py +++ b/src/helm/benchmark/model_metadata_registry.py @@ -56,10 +56,13 @@ IDEFICS_INSTRUCT_MODEL_TAG: str = "IDEFICS_INSTRUCT_MODEL_TAG" # Llava should use a special prompt format (see `LlavaRunExpander`) LLAVA_MODEL_TAG: str = "LLAVA_MODEL_TAG" - +# OpenFlamingo +OPEN_FLAMINGO_MODEL_TAG: str = "OPEN_FLAMINGO_MODEL_TAG" # Frozen is set to false as the model_deployment_registry.py file # might populate the deployment_names field. + + @dataclass(frozen=False) class ModelMetadata: name: str diff --git a/src/helm/benchmark/run_expander.py b/src/helm/benchmark/run_expander.py index afb68dd1e8..35d7ff8583 100644 --- a/src/helm/benchmark/run_expander.py +++ b/src/helm/benchmark/run_expander.py @@ -420,6 +420,30 @@ def expand(self, run_spec: RunSpec) -> List[RunSpec]: ] +class OpenFlamingoRunExpander(RunExpander): + """ + Custom prompt for OpenFlamingo models. + See https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b for more information. 
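For reference, OpenFlamingo expects an interleaved prompt in which each image is marked with an "<image>" token and each text chunk is closed with "<|endofchunk|>"; this expander therefore clears HELM's default prefixes and suffixes so the client can assemble that prompt itself. A minimal sketch of the format, following the model card's few-shot example (illustrative text only):

    prompt = (
        "<image>An image of two cats.<|endofchunk|>"
        "<image>An image of a bathroom sink.<|endofchunk|>"
        "<image>An image of"
    )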
+ """ + + name = "open_flamingo" + + def expand(self, run_spec: RunSpec) -> List[RunSpec]: + return [ + replace( + run_spec, + name=run_spec.name, + adapter_spec=replace( + run_spec.adapter_spec, + input_prefix="", + input_suffix="", + output_prefix="", + output_suffix="", + ), + ), + ] + + class FormatPromptRunExpander(RunExpander): """Adds a prefix and suffix to the prompt.""" diff --git a/src/helm/benchmark/run_specs.py b/src/helm/benchmark/run_specs.py index b2d798007f..09c2132e2c 100644 --- a/src/helm/benchmark/run_specs.py +++ b/src/helm/benchmark/run_specs.py @@ -32,6 +32,7 @@ GoogleRunExpander, IDEFICSInstructRunExpander, LlavaRunExpander, + OpenFlamingoRunExpander, StopRunExpander, ChatMLRunExpander, IncreaseTemperatureRunExpander, @@ -66,6 +67,7 @@ GOOGLE_GEMINI_MODEL_TAG, IDEFICS_INSTRUCT_MODEL_TAG, LLAVA_MODEL_TAG, + OPEN_FLAMINGO_MODEL_TAG, NO_NEWLINES_TAG, NLG_PREFIX_TAG, CHATML_MODEL_TAG, @@ -3093,6 +3095,10 @@ def alter_run_spec(run_spec: RunSpec) -> RunSpec: if LLAVA_MODEL_TAG in model.tags: run_spec = singleton(LlavaRunExpander().expand(run_spec)) + # OpenFlamingo + if OPEN_FLAMINGO_MODEL_TAG in model.tags: + run_spec = singleton(OpenFlamingoRunExpander().expand(run_spec)) + # For multiple choice if BUGGY_TEMP_0_TAG in model.tags and run_spec.adapter_spec.temperature == 0: increase_temperature_expander = IncreaseTemperatureRunExpander(value=1e-4) diff --git a/src/helm/config/model_deployments.yaml b/src/helm/config/model_deployments.yaml index ceb9478ea1..582485e1fc 100644 --- a/src/helm/config/model_deployments.yaml +++ b/src/helm/config/model_deployments.yaml @@ -464,6 +464,14 @@ model_deployments: max_sequence_length: 2048 client_spec: class_name: "helm.proxy.clients.vision_language.huggingface_vlm_client.HuggingFaceVLMClient" + + ## OpenFlamingo + - name: openflamingo/OpenFlamingo-9B-vitl-mpt7b + model_name: openflamingo/OpenFlamingo-9B-vitl-mpt7b + tokenizer_name: anas-awadalla/mpt-7b + max_sequence_length: 2048 + client_spec: + class_name: "helm.proxy.clients.vision_language.huggingface_vlm_client.OpenFlamingoClient" ## Mistral AI - name: huggingface/bakLlava-v1-hf diff --git a/src/helm/config/model_metadata.yaml b/src/helm/config/model_metadata.yaml index 87f321e8ac..44d2bfa234 100644 --- a/src/helm/config/model_metadata.yaml +++ b/src/helm/config/model_metadata.yaml @@ -1092,7 +1092,16 @@ models: num_parameters: 13000000000 release_date: 2023-10-05 tags: [VISION_LANGUAGE_MODEL_TAG, LLAVA_MODEL_TAG] - + + + - name: openflamingo/OpenFlamingo-9B-vitl-mpt7b + display_name: OpenFlamingo (9B) + description: OpenFlamingo is an open source implementation of DeepMind's Flamingo models. This 9B-parameter model uses a CLIP ViT-L/14 vision encoder and MPT-7B language model. 
([paper](https://arxiv.org/abs/2308.01390)) + creator_organization_name: OpenFlamingo + access: open + num_parameters: 9000000000 + release_date: 2023-08-02 + tags: [VISION_LANGUAGE_MODEL_TAG, OPEN_FLAMINGO_MODEL_TAG] # 01.AI diff --git a/src/helm/proxy/clients/vision_language/open_flamingo_client.py b/src/helm/proxy/clients/vision_language/open_flamingo_client.py new file mode 100644 index 0000000000..7e6e7468e6 --- /dev/null +++ b/src/helm/proxy/clients/vision_language/open_flamingo_client.py @@ -0,0 +1,125 @@ +from threading import Lock +from typing import List + +import torch +from huggingface_hub import hf_hub_download +from open_flamingo import create_model_and_transforms + +from helm.common.cache import CacheConfig +from helm.common.images_utils import open_image +from helm.common.gpu_utils import get_torch_device_name +from helm.common.media_object import TEXT_TYPE +from helm.common.optional_dependencies import handle_module_not_found_error +from helm.common.request import Request, RequestResult, Sequence, Token +from helm.common.request import wrap_request_time +from helm.proxy.clients.client import CachingClient, generate_uid_for_multimodal_prompt + +try: + from PIL import Image +except ModuleNotFoundError as e: + handle_module_not_found_error(e, ["images"]) + + +class OpenFlamingoClient(CachingClient): + """ + OpenFlamingo is an open source implementation of DeepMind's Flamingo models. + https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b + """ + + END_OF_CHUNK_TOKEN: str = "<|endofchunk|>" + IMAGE_TOKEN: str = "" + + _model_lock: Lock = Lock() + + def __init__(self, cache_config: CacheConfig): + super().__init__(cache_config) + self._device: str = get_torch_device_name() + self._get_model() + + def _get_model(self): + with self._model_lock: + self._model, self.image_processor, self.tokenizer = create_model_and_transforms( + clip_vision_encoder_path="ViT-L-14", + clip_vision_encoder_pretrained="openai", + lang_encoder_path="anas-awadalla/mpt-7b", + tokenizer_path="anas-awadalla/mpt-7b", + cross_attn_every_n_layers=4, + ) + self.tokenizer.padding_side = "left" + checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B-vitl-mpt7b", "checkpoint.pt") + self._model.load_state_dict(torch.load(checkpoint_path), strict=False) + self._model = self._model.to(self._device) + + def make_request(self, request: Request) -> RequestResult: + assert request.multimodal_prompt is not None, "Multimodal prompt is required" + + # Build the prompt + prompt_text: str = "" + images: List[Image.Image] = [] + for media_object in request.multimodal_prompt.media_objects: + if media_object.is_type("image") and media_object.location: + images.append(open_image(media_object.location)) + prompt_text += self.IMAGE_TOKEN + elif media_object.is_type(TEXT_TYPE): + if media_object.text is None: + raise ValueError("MediaObject of text type has missing text field value") + prompt_text += media_object.text + self.END_OF_CHUNK_TOKEN + else: + raise ValueError(f"Unrecognized MediaObject type {media_object.type}") + + # Preprocess + vision_x: torch.Tensor = torch.cat([self.image_processor(image).unsqueeze(0) for image in images], dim=0) + vision_x = vision_x.unsqueeze(1).unsqueeze(0) + + lang_x = self.tokenizer( + [prompt_text], + return_tensors="pt", + ) + + # Generate + try: + generation_args = { + "max_new_tokens": request.max_tokens, + "num_beams": 1, + } + + def do_it(): + generated_text: str = self._model.generate( + vision_x=vision_x.to(self._device), + lang_x=lang_x["input_ids"].to(self._device), 
+ attention_mask=lang_x["attention_mask"].to(self._device), + max_new_tokens=generation_args["max_new_tokens"], + num_beams=generation_args["num_beams"], + ) + generated_text = self.tokenizer.decode(generated_text[0]) + assert generated_text.startswith( + prompt_text + ), f"Generated text: {generated_text} does not start with prompt: {prompt_text}" + + # Remove the prompt from the generated text + generated_text = generated_text[len(prompt_text) :].strip() + return {"output": generated_text} + + cache_key = CachingClient.make_cache_key( + raw_request={ + "model": request.model, + "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt), + **generation_args, + }, + request=request, + ) + result, cached = self.cache.get(cache_key, wrap_request_time(do_it)) + except RuntimeError as e: + return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[]) + + tokens: List[Token] = [ + Token(text=str(self.tokenizer.decode(id)), logprob=0, top_logprobs={}) for id in lang_x["input_ids"][0] + ] + completions: List[Sequence] = [Sequence(text=result["generated_text"], logprob=0, tokens=tokens)] + return RequestResult( + success=True, + cached=cached, + request_time=result["request_time"], + completions=completions, + embedding=[], + ) From 367aa30b5f4b93dfd881ef067296f96831f06e52 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sat, 13 Jan 2024 23:35:21 -0800 Subject: [PATCH 02/35] fix ver --- requirements.txt | 4 ++-- setup.cfg | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 0274dcdc73..5733d4f6bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -166,8 +166,8 @@ toml==0.10.2 tomli==2.0.1 torch==1.12.1 ; sys_platform == "darwin" torchvision==0.13.1 ; sys_platform == "darwin" -torch==1.12.1+cu113 ; sys_platform == "linux" -torchvision==0.13.1+cu113 ; sys_platform == "linux" +torch==2.0.1+cu118 ; sys_platform == "linux" +torchvision==0.15.2+cu118 ; sys_platform == "linux" tqdm==4.64.1 transformers==4.28.1 trio==0.22.0 diff --git a/setup.cfg b/setup.cfg index 09157cf6c8..702199af5a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,7 +52,7 @@ install_requires= scikit-learn~=1.1.2 # Models and Metrics Extras - transformers~=4.36.0 # For anthropic_client, vision_language.huggingface_vlm_client, huggingface_client, huggingface_tokenizer, test_openai_token_cost_estimator, model_summac (via summarization_metrics) + transformers>=4.28.0 # For anthropic_client, vision_language.huggingface_vlm_client, huggingface_client, huggingface_tokenizer, test_openai_token_cost_estimator, model_summac (via summarization_metrics) # TODO: Upgrade torch - we need > 2.0.0 for newer versions of transformers torch>=1.12.1,<3.0.0 # For huggingface_client, yalm_tokenizer, model_summac (via summarization_metrics) torchvision>=0.13.1,<3.0.0 # For huggingface_client, yalm_tokenizer, model_summac (via summarization_metrics) From 799f41f0258b921d368574237b158826620d8e61 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sat, 13 Jan 2024 23:44:45 -0800 Subject: [PATCH 03/35] fix ver --- install-dev.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/install-dev.sh b/install-dev.sh index 845d0dced4..826091fa46 100755 --- a/install-dev.sh +++ b/install-dev.sh @@ -7,7 +7,7 @@ set -e # On Mac OS, skip installing pytorch with CUDA because CUDA is not supported if [[ $OSTYPE != 'darwin'* ]]; then # Manually install pytorch to avoid pip getting killed: https://stackoverflow.com/a/54329850 - pip install 
--no-cache-dir --find-links https://download.pytorch.org/whl/torch_stable.html torch==1.12.1+cu113 torchvision==0.13.1+cu113 + pip install --no-cache-dir --find-links https://download.pytorch.org/whl/torch_stable.html torch==2.0.1+cu118 torchvision==0.15.2+cu118 fi # Install all pinned dependencies pip install -r requirements.txt From 0b7bcb42889eb3af08f15d27c28227f7e3cd6034 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sat, 13 Jan 2024 23:49:20 -0800 Subject: [PATCH 04/35] fix ver --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5733d4f6bc..5d6a5ba25d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ beautifulsoup4==4.11.1 bert-score==0.3.13 bitarray==2.7.3 black==22.10.0 -blanc==0.2.7 +blanc==0.3.4 blis==0.7.8 boto3==1.24.89 botocore==1.27.89 From 70a8e348426bbd55dc1ca8f2728984af72f8bf4c Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sat, 13 Jan 2024 23:54:09 -0800 Subject: [PATCH 05/35] fix ver --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5d6a5ba25d..6b941fda68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,6 @@ beautifulsoup4==4.11.1 bert-score==0.3.13 bitarray==2.7.3 black==22.10.0 -blanc==0.3.4 blis==0.7.8 boto3==1.24.89 botocore==1.27.89 From 1eef0a54ce07467445dc75f8602ef506ba5f8952 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sat, 13 Jan 2024 23:57:15 -0800 Subject: [PATCH 06/35] fix ver --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6b941fda68..0f580c2e34 100644 --- a/requirements.txt +++ b/requirements.txt @@ -139,7 +139,7 @@ scaleapi==2.13.0 scikit-learn==1.1.2 scipy==1.10.0 selenium==4.8.0 -sentencepiece==0.1.97 +sentencepiece==0.1.98 simple-slurm==0.2.6 six==1.16.0 smart-open==5.2.1 From abe19e318d80162f87a6a5d09a7a62f047ef1200 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sun, 14 Jan 2024 00:11:28 -0800 Subject: [PATCH 07/35] fix ver --- requirements.txt | 86 ++++++++++++++++++++++++------------------------ 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/requirements.txt b/requirements.txt index 0f580c2e34..3a8a8e80fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,15 @@ 2captcha-python==1.1.3 -absl-py==1.2.0 +absl-py==2.0.0 accelerate==0.25.0 aiodns==3.0.0 aiohttp==3.8.5 aiohttp-retry==2.8.3 -aiosignal==1.2.0 +aiosignal==1.3.1 aleph-alpha-client==2.14.0 anthropic==0.2.5 async-generator==1.10 -async-timeout==4.0.2 -attrs==22.1.0 +async-timeout==4.0.3 +attrs==23.2.0 beautifulsoup4==4.11.1 bert-score==0.3.13 bitarray==2.7.3 @@ -17,11 +17,11 @@ black==22.10.0 blis==0.7.8 boto3==1.24.89 botocore==1.27.89 -bottle==0.12.23 -cachetools==5.2.0 -catalogue==2.0.8 +bottle==0.12.25 +cachetools==5.3.2 +catalogue==2.0.10 cattrs==22.2.0 -certifi==2023.7.22 +certifi==2023.11.17 cffi==1.15.1 cfgv==3.3.1 charset-normalizer==2.1.1 @@ -29,7 +29,7 @@ click==8.0.4 colorama==0.4.5 contourpy==1.0.5 cycler==0.11.0 -cymem==2.0.6 +cymem==2.0.8 Cython==0.29.32 dacite==1.6.0 datasets==2.5.2 @@ -77,55 +77,55 @@ matplotlib==3.6.0 mccabe==0.7.0 moverscore==1.0.3 mpmath==1.3.0 -multidict==6.0.2 +multidict==6.0.4 multiprocess==0.70.13 -murmurhash==1.0.8 +murmurhash==1.0.10 mypy==1.5.1 mypy-extensions==1.0.0 -networkx==2.8.7 -nltk==3.7 +networkx==3.1 +nltk==3.8.1 nodeenv==1.7.0 numba==0.56.4 -numpy==1.23.3 +numpy==1.23.5 openai==0.27.8 opencv-python==4.8.1.78 
open-flamingo==2.0.1 openpyxl==3.0.10 outcome==1.2.0 packaging==21.3 -pandas==1.5.0 +pandas==2.0.3 pandas-stubs==1.5.0.221003 -parameterized==0.8.1 +parameterized==0.9.0 pathspec==0.10.1 pathy==0.10.2 -Pillow==9.3.0 -platformdirs==2.5.2 +pillow==10.2.0 +platformdirs==3.5.0 pluggy==1.0.0 portalocker==2.5.1 pre-commit==2.20.0 -preshed==3.0.7 -protobuf==3.20.2 -psutil==5.9.2 -pyarrow==11.0.0 -pyasn1==0.4.8 -pyasn1-modules==0.2.8 +preshed==3.0.9 +protobuf==4.25.2 +psutil==5.9.7 +pyarrow==14.0.2 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 pycares==4.3.0 pycodestyle==2.9.1 pycparser==2.21 -pydantic==1.8.2 +pydantic==1.10.13 pyemd==0.5.1 pyext==0.7 pyflakes==2.5.0 pyhocon==0.3.59 pymongo==4.2.0 -pyparsing==2.4.7 +pyparsing==3.1.1 PySocks==1.7.1 pytest==7.2.0 python-dateutil==2.8.2 pytorch-pretrained-bert==0.6.2 pytrec-eval==0.5 -pytz==2022.4 -PyYAML==6.0 +pytz==2023.3.post1 +PyYAML==6.0.1 regex==2022.9.13 requests==2.31.0 responses==0.18.0 @@ -136,28 +136,28 @@ s3transfer==0.6.0 sacrebleu==2.2.1 sacremoses==0.0.53 scaleapi==2.13.0 -scikit-learn==1.1.2 -scipy==1.10.0 +scikit-learn==1.1.3 +scipy==1.10.1 selenium==4.8.0 sentencepiece==0.1.98 simple-slurm==0.2.6 six==1.16.0 -smart-open==5.2.1 +smart-open==6.4.0 sniffio==1.3.0 sortedcontainers==2.4.0 soupsieve==2.3.2.post1 spacy==3.5.4 spacy-legacy==3.0.12 -spacy-loggers==1.0.3 +spacy-loggers==1.0.5 sqlitedict==1.7.0 -srsly==2.4.4 +srsly==2.4.8 stanza==1.4.2 summ-eval==0.892 surge-api==1.1.0 -sympy==1.11.1 +sympy==1.12 tabulate==0.9.0 thinc==8.1.12 -threadpoolctl==3.1.0 +threadpoolctl==3.2.0 tiktoken==0.3.3 tls-client==0.1.8 tokenizers>=0.13.3 @@ -171,28 +171,28 @@ tqdm==4.64.1 transformers==4.28.1 trio==0.22.0 trio-websocket==0.9.2 -typer==0.4.2 +typer==0.9.0 types-Pillow==9.3.0.4 types-pytz==2022.4.0.0 types-redis==4.3.21.1 -types-requests==2.28.11.2 +types-requests==2.31.0.20240106 types-tabulate==0.9.0.0 types-urllib3==1.26.25 typing==3.7.4.3 -typing_extensions==4.4.0 +typing_extensions==4.9.0 uncertainty-calibration==0.1.4 undetected-chromedriver==3.2.1 uritemplate==4.1.1 -urllib3==1.26.12 +urllib3==2.1.0 virtualenv==20.16.5 -wasabi==0.10.1 +wasabi==1.1.2 websocket-client==1.3.2 websockets==10.4 wsproto==1.2.0 xlrd==2.0.1 -xxhash==3.0.0 -yarl==1.8.1 -zipp==3.11.0 +xxhash==3.4.1 +yarl==1.9.4 +zipp==3.17.0 zope.event==4.5.0 zope.interface==5.4.0 zstandard==0.18.0 From b221cbffee58fd1a026be5ab352d04b636f909c0 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sun, 14 Jan 2024 00:15:15 -0800 Subject: [PATCH 08/35] fix ver --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3a8a8e80fd..a61d6ebbd4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,6 @@ bitarray==2.7.3 black==22.10.0 blis==0.7.8 boto3==1.24.89 -botocore==1.27.89 bottle==0.12.25 cachetools==5.3.2 catalogue==2.0.10 From 63011e7b49e2c639abc2529faba6aa51ff62ca87 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sun, 14 Jan 2024 00:19:18 -0800 Subject: [PATCH 09/35] fix ver --- requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index a61d6ebbd4..f58edd1c33 100644 --- a/requirements.txt +++ b/requirements.txt @@ -54,7 +54,7 @@ googleapis-common-protos==1.56.4 greenlet==1.1.3 gunicorn==20.1.0 h11==0.14.0 -httplib2==0.20.4 +httplib2==0.22.0 huggingface-hub>=0.15.1 icetk==0.0.4 identify==2.5.6 @@ -91,7 +91,7 @@ opencv-python==4.8.1.78 open-flamingo==2.0.1 openpyxl==3.0.10 outcome==1.2.0 -packaging==21.3 +packaging==23.2 pandas==2.0.3 
pandas-stubs==1.5.0.221003 parameterized==0.9.0 @@ -115,7 +115,7 @@ pydantic==1.10.13 pyemd==0.5.1 pyext==0.7 pyflakes==2.5.0 -pyhocon==0.3.59 +pyhocon==0.3.60 pymongo==4.2.0 pyparsing==3.1.1 PySocks==1.7.1 From 7729aca9d7cbc4440d0d08eec65ab101e14d3e5f Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sun, 14 Jan 2024 00:51:04 -0800 Subject: [PATCH 10/35] add openflamingo --- install-dev.sh | 2 +- requirements.txt | 105 +++--- setup.cfg | 3 +- .../vision_language/open_flamingo/__init__.py | 2 + .../open_flamingo/src/__init__.py | 0 .../open_flamingo/src/factory.py | 137 +++++++ .../open_flamingo/src/flamingo.py | 333 ++++++++++++++++++ .../open_flamingo/src/flamingo_lm.py | 149 ++++++++ .../open_flamingo/src/helpers.py | 267 ++++++++++++++ .../open_flamingo/src/utils.py | 42 +++ .../vision_language/open_flamingo_client.py | 2 +- 11 files changed, 987 insertions(+), 55 deletions(-) create mode 100644 src/helm/proxy/clients/vision_language/open_flamingo/__init__.py create mode 100644 src/helm/proxy/clients/vision_language/open_flamingo/src/__init__.py create mode 100644 src/helm/proxy/clients/vision_language/open_flamingo/src/factory.py create mode 100644 src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo.py create mode 100644 src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo_lm.py create mode 100644 src/helm/proxy/clients/vision_language/open_flamingo/src/helpers.py create mode 100644 src/helm/proxy/clients/vision_language/open_flamingo/src/utils.py diff --git a/install-dev.sh b/install-dev.sh index 826091fa46..845d0dced4 100755 --- a/install-dev.sh +++ b/install-dev.sh @@ -7,7 +7,7 @@ set -e # On Mac OS, skip installing pytorch with CUDA because CUDA is not supported if [[ $OSTYPE != 'darwin'* ]]; then # Manually install pytorch to avoid pip getting killed: https://stackoverflow.com/a/54329850 - pip install --no-cache-dir --find-links https://download.pytorch.org/whl/torch_stable.html torch==2.0.1+cu118 torchvision==0.15.2+cu118 + pip install --no-cache-dir --find-links https://download.pytorch.org/whl/torch_stable.html torch==1.12.1+cu113 torchvision==0.13.1+cu113 fi # Install all pinned dependencies pip install -r requirements.txt diff --git a/requirements.txt b/requirements.txt index f58edd1c33..95358866d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,26 +1,28 @@ 2captcha-python==1.1.3 -absl-py==2.0.0 +absl-py==1.2.0 accelerate==0.25.0 aiodns==3.0.0 aiohttp==3.8.5 aiohttp-retry==2.8.3 -aiosignal==1.3.1 +aiosignal==1.2.0 aleph-alpha-client==2.14.0 anthropic==0.2.5 async-generator==1.10 -async-timeout==4.0.3 -attrs==23.2.0 +async-timeout==4.0.2 +attrs==22.1.0 beautifulsoup4==4.11.1 bert-score==0.3.13 bitarray==2.7.3 black==22.10.0 +blanc==0.2.7 blis==0.7.8 boto3==1.24.89 -bottle==0.12.25 -cachetools==5.3.2 -catalogue==2.0.10 +botocore==1.27.89 +bottle==0.12.23 +cachetools==5.2.0 +catalogue==2.0.8 cattrs==22.2.0 -certifi==2023.11.17 +certifi==2023.7.22 cffi==1.15.1 cfgv==3.3.1 charset-normalizer==2.1.1 @@ -28,7 +30,7 @@ click==8.0.4 colorama==0.4.5 contourpy==1.0.5 cycler==0.11.0 -cymem==2.0.8 +cymem==2.0.6 Cython==0.29.32 dacite==1.6.0 datasets==2.5.2 @@ -54,7 +56,7 @@ googleapis-common-protos==1.56.4 greenlet==1.1.3 gunicorn==20.1.0 h11==0.14.0 -httplib2==0.22.0 +httplib2==0.20.4 huggingface-hub>=0.15.1 icetk==0.0.4 identify==2.5.6 @@ -76,55 +78,54 @@ matplotlib==3.6.0 mccabe==0.7.0 moverscore==1.0.3 mpmath==1.3.0 -multidict==6.0.4 +multidict==6.0.2 multiprocess==0.70.13 -murmurhash==1.0.10 +murmurhash==1.0.8 mypy==1.5.1 
mypy-extensions==1.0.0 -networkx==3.1 -nltk==3.8.1 +networkx==2.8.7 +nltk==3.7 nodeenv==1.7.0 numba==0.56.4 -numpy==1.23.5 +numpy==1.23.3 openai==0.27.8 opencv-python==4.8.1.78 -open-flamingo==2.0.1 openpyxl==3.0.10 outcome==1.2.0 -packaging==23.2 -pandas==2.0.3 +packaging==21.3 +pandas==1.5.0 pandas-stubs==1.5.0.221003 -parameterized==0.9.0 +parameterized==0.8.1 pathspec==0.10.1 pathy==0.10.2 -pillow==10.2.0 -platformdirs==3.5.0 +Pillow==9.3.0 +platformdirs==2.5.2 pluggy==1.0.0 portalocker==2.5.1 pre-commit==2.20.0 -preshed==3.0.9 -protobuf==4.25.2 -psutil==5.9.7 -pyarrow==14.0.2 -pyasn1==0.5.1 -pyasn1-modules==0.3.0 +preshed==3.0.7 +protobuf==3.20.2 +psutil==5.9.2 +pyarrow==11.0.0 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 pycares==4.3.0 pycodestyle==2.9.1 pycparser==2.21 -pydantic==1.10.13 +pydantic==1.8.2 pyemd==0.5.1 pyext==0.7 pyflakes==2.5.0 -pyhocon==0.3.60 +pyhocon==0.3.59 pymongo==4.2.0 -pyparsing==3.1.1 +pyparsing==2.4.7 PySocks==1.7.1 pytest==7.2.0 python-dateutil==2.8.2 pytorch-pretrained-bert==0.6.2 pytrec-eval==0.5 -pytz==2023.3.post1 -PyYAML==6.0.1 +pytz==2022.4 +PyYAML==6.0 regex==2022.9.13 requests==2.31.0 responses==0.18.0 @@ -135,28 +136,28 @@ s3transfer==0.6.0 sacrebleu==2.2.1 sacremoses==0.0.53 scaleapi==2.13.0 -scikit-learn==1.1.3 -scipy==1.10.1 +scikit-learn==1.1.2 +scipy==1.10.0 selenium==4.8.0 -sentencepiece==0.1.98 +sentencepiece==0.1.97 simple-slurm==0.2.6 six==1.16.0 -smart-open==6.4.0 +smart-open==5.2.1 sniffio==1.3.0 sortedcontainers==2.4.0 soupsieve==2.3.2.post1 spacy==3.5.4 spacy-legacy==3.0.12 -spacy-loggers==1.0.5 +spacy-loggers==1.0.3 sqlitedict==1.7.0 -srsly==2.4.8 +srsly==2.4.4 stanza==1.4.2 summ-eval==0.892 surge-api==1.1.0 -sympy==1.12 +sympy==1.11.1 tabulate==0.9.0 thinc==8.1.12 -threadpoolctl==3.2.0 +threadpoolctl==3.1.0 tiktoken==0.3.3 tls-client==0.1.8 tokenizers>=0.13.3 @@ -164,35 +165,35 @@ toml==0.10.2 tomli==2.0.1 torch==1.12.1 ; sys_platform == "darwin" torchvision==0.13.1 ; sys_platform == "darwin" -torch==2.0.1+cu118 ; sys_platform == "linux" -torchvision==0.15.2+cu118 ; sys_platform == "linux" +torch==1.12.1+cu113 ; sys_platform == "linux" +torchvision==0.13.1+cu113 ; sys_platform == "linux" tqdm==4.64.1 -transformers==4.28.1 +transformers==4.36.0 trio==0.22.0 trio-websocket==0.9.2 -typer==0.9.0 +typer==0.4.2 types-Pillow==9.3.0.4 types-pytz==2022.4.0.0 types-redis==4.3.21.1 -types-requests==2.31.0.20240106 +types-requests==2.28.11.2 types-tabulate==0.9.0.0 types-urllib3==1.26.25 typing==3.7.4.3 -typing_extensions==4.9.0 +typing_extensions==4.4.0 uncertainty-calibration==0.1.4 undetected-chromedriver==3.2.1 uritemplate==4.1.1 -urllib3==2.1.0 +urllib3==1.26.12 virtualenv==20.16.5 -wasabi==1.1.2 +wasabi==0.10.1 websocket-client==1.3.2 websockets==10.4 wsproto==1.2.0 xlrd==2.0.1 -xxhash==3.4.1 -yarl==1.9.4 -zipp==3.17.0 +xxhash==3.0.0 +yarl==1.8.1 +zipp==3.11.0 zope.event==4.5.0 zope.interface==5.4.0 zstandard==0.18.0 -fairlearn==0.9.0 +fairlearn==0.9.0 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 702199af5a..2537702cad 100644 --- a/setup.cfg +++ b/setup.cfg @@ -223,6 +223,7 @@ exclude = venv/* src/helm/proxy/clients/image_generation/dalle_mini/* src/helm/proxy/clients/image_generation/mindalle/* + src/helm/proxy/clients/vision_language/open_flamingo/* # Ignore completely: # E203 - White space before ':', (conflicts with black) @@ -240,7 +241,7 @@ check_untyped_defs = True disable_error_code = annotation-unchecked # TODO: Change disallow_untyped_defs to True disallow_untyped_defs = False -exclude = dalle_mini|mindalle 
+exclude = dalle_mini|mindalle|open_flamingo [tool:pytest] addopts = diff --git a/src/helm/proxy/clients/vision_language/open_flamingo/__init__.py b/src/helm/proxy/clients/vision_language/open_flamingo/__init__.py new file mode 100644 index 0000000000..ab67750bb7 --- /dev/null +++ b/src/helm/proxy/clients/vision_language/open_flamingo/__init__.py @@ -0,0 +1,2 @@ +from .src.flamingo import Flamingo +from .src.factory import create_model_and_transforms diff --git a/src/helm/proxy/clients/vision_language/open_flamingo/src/__init__.py b/src/helm/proxy/clients/vision_language/open_flamingo/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/helm/proxy/clients/vision_language/open_flamingo/src/factory.py b/src/helm/proxy/clients/vision_language/open_flamingo/src/factory.py new file mode 100644 index 0000000000..5f5fadff21 --- /dev/null +++ b/src/helm/proxy/clients/vision_language/open_flamingo/src/factory.py @@ -0,0 +1,137 @@ +from typing import Optional + +from transformers import AutoModelForCausalLM, AutoTokenizer +import open_clip + +from .flamingo import Flamingo +from .flamingo_lm import FlamingoLMMixin +from .utils import extend_instance + + +def create_model_and_transforms( + clip_vision_encoder_path: str, + clip_vision_encoder_pretrained: str, + lang_encoder_path: str, + tokenizer_path: str, + cross_attn_every_n_layers: int = 1, + use_local_files: bool = False, + decoder_layers_attr_name: str = None, + freeze_lm_embeddings: bool = False, + cache_dir: Optional[str] = None, + **flamingo_kwargs, +): + """ + Initialize a Flamingo model from a pretrained vision encoder and language encoder. + Appends special tokens to the tokenizer and freezes backbones. + + Args: + clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32") + clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k") + lang_encoder_path (str): path to pretrained language encoder + tokenizer_path (str): path to pretrained tokenizer + cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1. + use_local_files (bool, optional): whether to use local files. Defaults to False. + decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None. + freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver. + cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights. + Returns: + Flamingo: Flamingo model from pretrained vision and language encoders + Image processor: Pipeline to preprocess input images + Tokenizer: A tokenizer for the language model + """ + vision_encoder, _, image_processor = open_clip.create_model_and_transforms( + clip_vision_encoder_path, + pretrained=clip_vision_encoder_pretrained, + cache_dir=cache_dir, + ) + # set the vision encoder to output the visual features + vision_encoder.visual.output_tokens = True + + text_tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, + local_files_only=use_local_files, + trust_remote_code=True, + cache_dir=cache_dir, + ) + # add Flamingo special tokens to the tokenizer + text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", ""]}) + if text_tokenizer.pad_token is None: + # Issue: GPT models don't have a pad token, which we use to + # modify labels for the loss. 
+ text_tokenizer.add_special_tokens({"pad_token": ""}) + + lang_encoder = AutoModelForCausalLM.from_pretrained( + lang_encoder_path, + local_files_only=use_local_files, + trust_remote_code=True, + cache_dir=cache_dir, + ) + + # hacks for MPT-1B, which doesn't have a get_input_embeddings method + if "mpt-1b-redpajama-200b" in lang_encoder_path: + + class EmbeddingFnMixin: + def get_input_embeddings(self): + return self.transformer.wte + + def set_input_embeddings(self, new_embeddings): + self.transformer.wte = new_embeddings + + extend_instance(lang_encoder, EmbeddingFnMixin) + + # convert LM to FlamingoLM + extend_instance(lang_encoder, FlamingoLMMixin) + + if decoder_layers_attr_name is None: + decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder) + lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name) + lang_encoder.resize_token_embeddings(len(text_tokenizer)) + + model = Flamingo( + vision_encoder, + lang_encoder, + text_tokenizer.encode("<|endofchunk|>")[-1], + text_tokenizer.encode("")[-1], + vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"], + cross_attn_every_n_layers=cross_attn_every_n_layers, + **flamingo_kwargs, + ) + + # Freeze all parameters + model.requires_grad_(False) + assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0 + + # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings + model.perceiver.requires_grad_(True) + model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) + if not freeze_lm_embeddings: + model.lang_encoder.get_input_embeddings().requires_grad_(True) + # TODO: investigate also training the output embeddings when untied + + print( + f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters" + ) + + return model, image_processor, text_tokenizer + + +def _infer_decoder_layers_attr_name(model): + for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES: + if k.lower() in model.__class__.__name__.lower(): + return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k] + + raise ValueError( + f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually." 
+ ) + + +__KNOWN_DECODER_LAYERS_ATTR_NAMES = { + "opt": "model.decoder.layers", + "gptj": "transformer.h", + "gpt-j": "transformer.h", + "pythia": "gpt_neox.layers", + "llama": "model.layers", + "gptneoxforcausallm": "gpt_neox.layers", + "mpt": "transformer.blocks", + "mosaicgpt": "transformer.blocks", +} diff --git a/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo.py b/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo.py new file mode 100644 index 0000000000..7c29061342 --- /dev/null +++ b/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo.py @@ -0,0 +1,333 @@ +import torch +from einops import rearrange +from torch import nn +from .helpers import PerceiverResampler +from torch.distributed.fsdp.wrap import ( + enable_wrap, + wrap, +) +from transformers.modeling_outputs import CausalLMOutputWithPast +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, +) + +from .utils import apply_with_stopping_condition + + +class Flamingo(nn.Module): + def __init__( + self, + vision_encoder: nn.Module, + lang_encoder: nn.Module, + eoc_token_id: int, + media_token_id: int, + vis_dim: int, + cross_attn_every_n_layers: int = 1, + gradient_checkpointing: bool = False, + ): + """ + Args: + vision_encoder (nn.Module): HF CLIPModel + lang_encoder (nn.Module): HF causal language model + eoc_token_id (int): Token id for <|endofchunk|> + media_token_id (int): Token id for + vis_dim (int): Dimension of the visual features. + Visual features are projected to match this shape along the last dimension. + cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1. + """ + super().__init__() + self.eoc_token_id = eoc_token_id + self.media_token_id = media_token_id + self.vis_dim = vis_dim + if hasattr(lang_encoder.config, "d_model"): + self.lang_dim = lang_encoder.config.d_model # mpt uses d_model + else: + self.lang_dim = lang_encoder.config.hidden_size + + self.vision_encoder = vision_encoder.visual + self.perceiver = PerceiverResampler(dim=self.vis_dim) + self.lang_encoder = lang_encoder + self.lang_encoder.init_flamingo( + media_token_id=media_token_id, + lang_hidden_size=self.lang_dim, + vis_hidden_size=self.vis_dim, + cross_attn_every_n_layers=cross_attn_every_n_layers, + gradient_checkpointing=gradient_checkpointing, + ) + self._use_gradient_checkpointing = gradient_checkpointing + self.perceiver._use_gradient_checkpointing = gradient_checkpointing + + def forward( + self, + vision_x: torch.Tensor, + lang_x: torch.Tensor, + attention_mask: torch.Tensor = None, + labels: torch.Tensor = None, + clear_conditioned_layers: bool = True, + past_key_values=None, + use_cache: bool = False, + ): + """ + Forward pass of Flamingo. + + Args: + vision_x (torch.Tensor): Vision input + shape (B, T_img, F, C, H, W) with F=1 + lang_x (torch.Tensor): Language input ids + shape (B, T_txt) + attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. + labels (torch.Tensor, optional): Labels. Defaults to None. + clear_conditioned_layers: if True, clear the conditioned layers + once the foward pass is completed. Set this to false if the + same set of images will be reused in another subsequent + forward pass. + past_key_values: pre-computed values to pass to language model. + See past_key_values documentation in Hugging Face + CausalLM models. + use_cache: whether to use cached key values. See use_cache + documentation in Hugging Face CausalLM models. 
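A minimal usage sketch for this forward pass, assuming `model` and `tokenizer` come from `create_model_and_transforms` and using the tensor shapes documented above (illustrative values only, not real image data):

    import torch

    # Two interleaved images, single frame each: (B, T_img, F, C, H, W) with F == 1
    vision_x = torch.randn(1, 2, 1, 3, 224, 224)
    lang_x = tokenizer(
        ["<image>An image of two cats.<|endofchunk|><image>An image of"],
        return_tensors="pt",
    )
    output = model(
        vision_x=vision_x,
        lang_x=lang_x["input_ids"],
        attention_mask=lang_x["attention_mask"],
    )  # a CausalLMOutputWithPast from the underlying language model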
+ """ + assert ( + self.lang_encoder.initialized_flamingo + ), "Flamingo layers are not initialized. Please call `init_flamingo` first." + + assert ( + self.lang_encoder._use_cached_vision_x or vision_x is not None + ), "Must provide either vision_x or have precached media using cache_media()." + + if self.lang_encoder._use_cached_vision_x: + # Case: use cached; vision_x should be cached and other + # vision-related inputs should not be provided. + assert ( + vision_x is None + ), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first." + assert self.lang_encoder.is_conditioned() + + else: + # Case: do not use caching (i.e. this is a standard forward pass); + self._encode_vision_x(vision_x=vision_x) + self._condition_media_locations(input_ids=lang_x) + + output = self.lang_encoder( + input_ids=lang_x, + attention_mask=attention_mask, + labels=labels, + past_key_values=past_key_values, + use_cache=use_cache, + ) + + if clear_conditioned_layers: + self.lang_encoder.clear_conditioned_layers() + + return output + + def generate( + self, + vision_x: torch.Tensor, + lang_x: torch.Tensor, + attention_mask: torch.Tensor = None, + **kwargs, + ): + """ + Generate text conditioned on vision and language inputs. + + Args: + vision_x (torch.Tensor): Vision input + shape (B, T_img, F, C, H, W) + images in the same chunk are collated along T_img, and frames are collated along F + currently only F=1 is supported (single-frame videos) + lang_x (torch.Tensor): Language input + shape (B, T_txt) + **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs: + max_length (int, optional): Maximum length of the output. Defaults to None. + attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. + num_beams (int, optional): Number of beams. Defaults to 1. + max_new_tokens (int, optional): Maximum new tokens. Defaults to None. + temperature (float, optional): Temperature. Defaults to 1.0. + top_k (int, optional): Top k. Defaults to 50. + top_p (float, optional): Top p. Defaults to 1.0. + no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0. + length_penalty (float, optional): Length penalty. Defaults to 1.0. + num_return_sequences (int, optional): Number of return sequences. Defaults to 1. + do_sample (bool, optional): Do sample. Defaults to False. + early_stopping (bool, optional): Early stopping. Defaults to False. + Returns: + torch.Tensor: lang_x with generated tokens appended to it + """ + num_beams = kwargs.pop("num_beams", 1) + if num_beams > 1: + vision_x = vision_x.repeat_interleave(num_beams, dim=0) + + self.lang_encoder._use_cached_vision_x = True + self._encode_vision_x(vision_x=vision_x) + + eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id) + output = self.lang_encoder.generate( + input_ids=lang_x, + attention_mask=attention_mask, + eos_token_id=eos_token_id, + num_beams=num_beams, + **kwargs, + ) + + self.lang_encoder.clear_conditioned_layers() + self.lang_encoder._use_cached_vision_x = False + return output + + def _encode_vision_x(self, vision_x: torch.Tensor): + """ + Compute media tokens from vision input by passing it through vision encoder and conditioning language model. 
+ Args: + vision_x (torch.Tensor): Vision input + shape (B, T_img, F, C, H, W) + Images in the same chunk are collated along T_img, and frames are collated along F + Currently only F=1 is supported (single-frame videos) + + rearrange code based on https://github.com/dhansmair/flamingo-mini + """ + + assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)" + b, T, F = vision_x.shape[:3] + assert F == 1, "Only single frame supported" + + vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w") + with torch.no_grad(): + vision_x = self.vision_encoder(vision_x)[1] + vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F) + vision_x = self.perceiver(vision_x) + + for layer in self.lang_encoder._get_decoder_layers(): + layer.condition_vis_x(vision_x) + + def wrap_fsdp(self, wrapper_kwargs, device_id): + """ + Manually wraps submodules for FSDP and move other parameters to device_id. + + Why manually wrap? + - all parameters within the FSDP wrapper must have the same requires_grad. + We have a mix of frozen and unfrozen parameters. + - model.vision_encoder.visual needs to be individually wrapped or encode_vision_x errors + See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344 + + The rough wrapping structure is: + - FlamingoModel + - FSDP(FSDP(vision_encoder)) + - FSDP(FSDP(perceiver)) + - lang_encoder + - FSDP(FSDP(input_embeddings)) + - FlamingoLayers + - FSDP(FSDP(gated_cross_attn_layer)) + - FSDP(FSDP(decoder_layer)) + - FSDP(FSDP(output_embeddings)) + - other parameters + + Known issues: + - Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied, + train with DDP or set the --freeze_lm_embeddings flag to true. + - With FSDP + gradient ckpting, one can increase the batch size with seemingly no upper bound. + Although the training curves look okay, we found that downstream performance dramatically + degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M). + + FAQs about our FSDP wrapping strategy: + Why double wrap? + As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook + only free gathered parameters if the module is NOT FSDP root. + + Why unfreeze the decoder_layers? + See https://github.com/pytorch/pytorch/issues/95805 + As of torch==2.0.1, FSDP's _post_backward_hook is only registed if the flat param + requires_grad=True. We need the postback to fire to avoid OOM. + To effectively freeze the decoder layers, we exclude them from the optimizer. + + What is assumed to be frozen v. unfrozen? 
+ We assume that the model is being trained under normal Flamingo settings + with these lines being called in factory.py: + ``` + # Freeze all parameters + model.requires_grad_(False) + assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0 + + # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings + model.perceiver.requires_grad_(True) + model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) + [optional] model.lang_encoder.get_input_embeddings().requires_grad_(True) + ``` + """ + # unfreeze the decoder layers + for block in self.lang_encoder.old_decoder_blocks: + block.requires_grad_(True) + + # wrap in FSDP + with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs): + self.perceiver = wrap(wrap(self.perceiver)) + self.lang_encoder.old_decoder_blocks = nn.ModuleList( + wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks + ) + self.lang_encoder.gated_cross_attn_layers = nn.ModuleList( + wrap(wrap(layer)) if layer is not None else None for layer in self.lang_encoder.gated_cross_attn_layers + ) + self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing) + self.lang_encoder.set_input_embeddings(wrap(wrap(self.lang_encoder.get_input_embeddings()))) + self.lang_encoder.set_output_embeddings(wrap(wrap(self.lang_encoder.get_output_embeddings()))) + self.vision_encoder = wrap(wrap(self.vision_encoder)) # frozen + + # manually move non-FSDP managed parameters to device_id + # these are all in lang_encoder + apply_with_stopping_condition( + module=self.lang_encoder, + apply_fn=lambda m: m.to(device_id), + apply_condition=lambda m: len(list(m.children())) == 0, + stopping_condition=lambda m: isinstance(m, FSDP), + ) + + # exclude the original decoder layers from the optimizer + for block in self.lang_encoder.old_decoder_blocks: + for p in block.parameters(): + p.exclude_from_optimizer = True + + # set up clip_grad_norm_ function + def clip_grad_norm_(max_norm): + self.perceiver.clip_grad_norm_(max_norm) + for layer in self.lang_encoder.gated_cross_attn_layers: + if layer is not None: + layer.clip_grad_norm_(max_norm) + self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm) + + self.clip_grad_norm_ = clip_grad_norm_ + + def _condition_media_locations(self, input_ids: torch.Tensor): + """ + Compute the media token locations from lang_x and condition the language model on these. + Args: + input_ids (torch.Tensor): Language input + shape (B, T_txt) + """ + media_locations = input_ids == self.media_token_id + + for layer in self.lang_encoder._get_decoder_layers(): + layer.condition_media_locations(media_locations) + + def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor): + """ + Pre-cache a prompt/sequence of images / text for log-likelihood evaluations. + All subsequent calls to forward() will generate attending to the LAST + image in vision_x. + This is not meant to be used to cache things for generate(). + Args: + input_ids (torch.Tensor): Language input + shape (B, T_txt) + vision_x (torch.Tensor): Vision input + shape (B, T_img, F, C, H, W) + Images in the same chunk are collated along T_img, and frames are collated along F + Currently only F=1 is supported (single-frame videos) + """ + self._encode_vision_x(vision_x=vision_x) + self._condition_media_locations(input_ids=input_ids) + self.lang_encoder._use_cached_vision_x = True + + def uncache_media(self): + """ + Clear all conditioning. 
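A sketch of the intended scoring loop, assuming `vision_x`, `lang_x`, and a list of tokenized continuations `candidates` have been prepared as above (hypothetical names):

    model.cache_media(input_ids=lang_x["input_ids"], vision_x=vision_x)
    for candidate_ids in candidates:
        out = model(
            vision_x=None,                       # media is already cached
            lang_x=candidate_ids,
            clear_conditioned_layers=False,      # keep the cached conditioning across calls
        )
    model.uncache_media()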
+ """ + self.lang_encoder.clear_conditioned_layers() + self.lang_encoder._use_cached_vision_x = False diff --git a/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo_lm.py b/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo_lm.py new file mode 100644 index 0000000000..9b947dfff1 --- /dev/null +++ b/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo_lm.py @@ -0,0 +1,149 @@ +import torch.nn as nn +from .helpers import GatedCrossAttentionBlock +from .utils import getattr_recursive, setattr_recursive + + +class FlamingoLayer(nn.Module): + """ + FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer. + """ + + def __init__(self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False): + super().__init__() + self.gated_cross_attn_layer = gated_cross_attn_layer + self.decoder_layer = decoder_layer + self.vis_x = None + self.media_locations = None + if self.gated_cross_attn_layer is not None: + self.gated_cross_attn_layer._use_gradient_checkpointing = gradient_checkpointing + self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing + + def is_conditioned(self) -> bool: + """Check whether the layer is conditioned.""" + return self.vis_x is not None and self.media_locations is not None + + # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/) + def condition_vis_x(self, vis_x): + self.vis_x = vis_x + + def condition_media_locations(self, media_locations): + self.media_locations = media_locations + + def condition_use_cached_media(self, use_cached_media): + self.use_cached_media = use_cached_media + + def forward( + self, + lang_x, + attention_mask=None, + **decoder_layer_kwargs, + ): + # Cross attention + if self.gated_cross_attn_layer is not None: + if self.vis_x is None: + raise ValueError("vis_x must be conditioned before forward pass") + + if self.media_locations is None: + raise ValueError("media_locations must be conditioned before forward pass") + + lang_x = self.gated_cross_attn_layer( + lang_x, + self.vis_x, + media_locations=self.media_locations, + use_cached_media=self.use_cached_media, + ) + + # Normal decoder layer + lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs) + return lang_x + + +class FlamingoLMMixin(nn.Module): + """ + Mixin to add cross-attention layers to a language model. + """ + + def set_decoder_layers_attr_name(self, decoder_layers_attr_name): + self.decoder_layers_attr_name = decoder_layers_attr_name + + def _get_decoder_layers(self): + return getattr_recursive(self, self.decoder_layers_attr_name) + + def _set_decoder_layers(self, value): + setattr_recursive(self, self.decoder_layers_attr_name, value) + + def init_flamingo( + self, + media_token_id, + lang_hidden_size, + vis_hidden_size, + cross_attn_every_n_layers, + gradient_checkpointing, + ): + """ + Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations. 
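A worked example of the layer placement below, assuming the OpenFlamingo-9B configuration used in this patch (`cross_attn_every_n_layers=4`) and taking MPT-7B's 32 decoder blocks as the hypothetical layer count:

    # Gated cross-attention is added after every 4th decoder block (1-indexed),
    # so 8 of the 32 entries hold a GatedCrossAttentionBlock and the rest are None.
    positions = [i + 1 for i in range(32) if (i + 1) % 4 == 0]
    assert positions == [4, 8, 12, 16, 20, 24, 28, 32]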
+ """ + self.old_decoder_blocks = self._get_decoder_layers() + self.gated_cross_attn_layers = nn.ModuleList( + [ + GatedCrossAttentionBlock(dim=lang_hidden_size, dim_visual=vis_hidden_size) + if (layer_idx + 1) % cross_attn_every_n_layers == 0 + else None + for layer_idx, _ in enumerate(self._get_decoder_layers()) + ] + ) + self.init_flamingo_layers(gradient_checkpointing) + self.media_token_id = media_token_id + self.initialized_flamingo = True + self._use_cached_vision_x = False + + def init_flamingo_layers(self, gradient_checkpointing): + """ + Re initializes the FlamingoLayers. + Propagates any changes made to self.gated_corss_attn_layers or self.old_decoder_blocks + """ + self._set_decoder_layers( + nn.ModuleList( + [ + FlamingoLayer(gated_cross_attn_layer, decoder_layer, gradient_checkpointing) + for gated_cross_attn_layer, decoder_layer in zip( + self.gated_cross_attn_layers, self.old_decoder_blocks + ) + ] + ) + ) + + def forward(self, input_ids, attention_mask, **kwargs): + """Condition the Flamingo layers on the media locations before forward()""" + if not self.initialized_flamingo: + raise ValueError("Flamingo layers are not initialized. Please call `init_flamingo` first.") + + media_locations = input_ids == self.media_token_id + + # if there are media already cached and we're generating and there are no media tokens in the input, + # we'll assume that ALL input tokens should attend to the last previous media that is cached. + # this is especially important for HF generate() compatibility, since generate() calls forward() + # repeatedly one token at a time (with no media tokens). + # without this check, the model would not attend to any images when generating (after the first token) + use_cached_media_locations = self._use_cached_vision_x and self.is_conditioned() and not media_locations.any() + + for layer in self._get_decoder_layers(): + if not use_cached_media_locations: + layer.condition_media_locations(media_locations) + layer.condition_use_cached_media(use_cached_media_locations) + + # package arguments for the other parent's forward. 
since we don't know the order of the arguments, + # make them all kwargs + kwargs["input_ids"] = input_ids + kwargs["attention_mask"] = attention_mask + return super().forward(**kwargs) # Call the other parent's forward method + + def is_conditioned(self) -> bool: + """Check whether all decoder layers are already conditioned.""" + return all(l.is_conditioned() for l in self._get_decoder_layers()) + + def clear_conditioned_layers(self): + for layer in self._get_decoder_layers(): + layer.condition_vis_x(None) + layer.condition_media_locations(None) + layer.condition_use_cached_media(None) diff --git a/src/helm/proxy/clients/vision_language/open_flamingo/src/helpers.py b/src/helm/proxy/clients/vision_language/open_flamingo/src/helpers.py new file mode 100644 index 0000000000..a12b8ec40a --- /dev/null +++ b/src/helm/proxy/clients/vision_language/open_flamingo/src/helpers.py @@ -0,0 +1,267 @@ +""" +Based on: https://github.com/lucidrains/flamingo-pytorch +""" + +import torch +from einops import rearrange, repeat +from einops_exts import rearrange_many +from torch import einsum, nn + + +def exists(val): + return val is not None + + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, T, n1, D) + latent (torch.Tensor): latent features + shape (b, T, n2, D) + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) + q = q * self.scale + + # attention + sim = einsum("... i d, ... j d -> ... i j", q, k) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum("... i j, ... j d -> ... 
i d", attn, v) + out = rearrange(out, "b h t n d -> b t n (h d)", h=h) + return self.to_out(out) + + +class PerceiverResampler(nn.Module): + def __init__( + self, + *, + dim, + depth=6, + dim_head=64, + heads=8, + num_latents=64, + max_num_media=None, + max_num_frames=None, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None + self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + """ + Args: + x (torch.Tensor): image features + shape (b, T, F, v, D) + Returns: + shape (b, T, n, D) where n is self.num_latents + """ + b, T, F, v = x.shape[:4] + + # frame and media time embeddings + if exists(self.frame_embs): + frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) + x = x + frame_embs + x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions + if exists(self.media_time_embs): + x = x + self.media_time_embs[:T] + + # blocks + latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + return self.norm(latents) + + +# gated cross attention +class MaskedCrossAttention(nn.Module): + def __init__( + self, + *, + dim, + dim_visual, + dim_head=64, + heads=8, + only_attend_immediate_media=True, + ): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # whether for text to only attend to immediate preceding image, or all previous images + self.only_attend_immediate_media = only_attend_immediate_media + + def forward(self, x, media, media_locations=None, use_cached_media=False): + """ + Args: + x (torch.Tensor): text features + shape (B, T_txt, D_txt) + media (torch.Tensor): image features + shape (B, T_img, n, D_img) where n is the dim of the latents + media_locations: boolean mask identifying the media tokens in x + shape (B, T_txt) + use_cached_media: bool + If true, treat all of x as if they occur after the last media + registered in media_locations. T_txt does not need to exactly + equal media_locations.shape[1] in this case + """ + + if not use_cached_media: + assert ( + media_locations.shape[1] == x.shape[1] + ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}" + + T_txt = x.shape[1] + _, T_img, n = media.shape[:3] + h = self.heads + + x = self.norm(x) + + q = self.to_q(x) + media = rearrange(media, "b t n d -> b (t n) d") + + k, v = self.to_kv(media).chunk(2, dim=-1) + q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) + + q = q * self.scale + + sim = einsum("... i d, ... j d -> ... 
i j", q, k) + + if exists(media_locations): + media_time = torch.arange(T_img, device=x.device) + 1 + + if use_cached_media: + # text time is set to the last cached media location + text_time = repeat( + torch.count_nonzero(media_locations, dim=1), + "b -> b i", + i=T_txt, + ) + else: + # at each boolean of True, increment the time counter (relative to media time) + text_time = media_locations.cumsum(dim=-1) + + # text time must equal media time if only attending to most immediate image + # otherwise, as long as text time is greater than media time (if attending to all previous images / media) + mask_op = torch.eq if self.only_attend_immediate_media else torch.ge + + text_to_media_mask = mask_op( + rearrange(text_time, "b i -> b 1 i 1"), + repeat(media_time, "j -> 1 1 1 (j n)", n=n), + ) + sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max) + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + if exists(media_locations) and self.only_attend_immediate_media: + # any text without a preceding media needs to have attention zeroed out + text_without_media_mask = text_time == 0 + text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1") + attn = attn.masked_fill(text_without_media_mask, 0.0) + + out = einsum("... i j, ... j d -> ... i d", attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class GatedCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + dim, + dim_visual, + dim_head=64, + heads=8, + ff_mult=4, + only_attend_immediate_media=True, + ): + super().__init__() + self.attn = MaskedCrossAttention( + dim=dim, + dim_visual=dim_visual, + dim_head=dim_head, + heads=heads, + only_attend_immediate_media=only_attend_immediate_media, + ) + self.attn_gate = nn.Parameter(torch.tensor([0.0])) + + self.ff = FeedForward(dim, mult=ff_mult) + self.ff_gate = nn.Parameter(torch.tensor([0.0])) + + def forward( + self, + x, + media, + media_locations=None, + use_cached_media=False, + ): + x = ( + self.attn( + x, + media, + media_locations=media_locations, + use_cached_media=use_cached_media, + ) + * self.attn_gate.tanh() + + x + ) + x = self.ff(x) * self.ff_gate.tanh() + x + + return x diff --git a/src/helm/proxy/clients/vision_language/open_flamingo/src/utils.py b/src/helm/proxy/clients/vision_language/open_flamingo/src/utils.py new file mode 100644 index 0000000000..e8a19ddf9e --- /dev/null +++ b/src/helm/proxy/clients/vision_language/open_flamingo/src/utils.py @@ -0,0 +1,42 @@ +def extend_instance(obj, mixin): + """Apply mixins to a class instance after creation""" + base_cls = obj.__class__ + base_cls_name = obj.__class__.__name__ + obj.__class__ = type( + base_cls_name, (mixin, base_cls), {} + ) # mixin needs to go first for our forward() logic to work + + +def getattr_recursive(obj, att): + """ + Return nested attribute of obj + Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c + """ + if att == "": + return obj + i = att.find(".") + if i < 0: + return getattr(obj, att) + else: + return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) + + +def setattr_recursive(obj, att, val): + """ + Set nested attribute of obj + Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val + """ + if "." 
in att: + obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) + setattr(obj, att.split(".")[-1], val) + + +def apply_with_stopping_condition(module, apply_fn, apply_condition=None, stopping_condition=None, **other_args): + if stopping_condition(module): + return + if apply_condition(module): + apply_fn(module, **other_args) + for child in module.children(): + apply_with_stopping_condition( + child, apply_fn, apply_condition=apply_condition, stopping_condition=stopping_condition, **other_args + ) diff --git a/src/helm/proxy/clients/vision_language/open_flamingo_client.py b/src/helm/proxy/clients/vision_language/open_flamingo_client.py index 7e6e7468e6..0f0801fbf4 100644 --- a/src/helm/proxy/clients/vision_language/open_flamingo_client.py +++ b/src/helm/proxy/clients/vision_language/open_flamingo_client.py @@ -3,7 +3,7 @@ import torch from huggingface_hub import hf_hub_download -from open_flamingo import create_model_and_transforms +from helm.proxy.clients.vision_language.open_flamingo import create_model_and_transforms from helm.common.cache import CacheConfig from helm.common.images_utils import open_image From 1d23a9a01ce3848235b42a76bb5cf54097a0e509 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sun, 14 Jan 2024 01:16:18 -0800 Subject: [PATCH 11/35] add openflamingo --- src/helm/config/model_deployments.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/helm/config/model_deployments.yaml b/src/helm/config/model_deployments.yaml index 582485e1fc..315c90b266 100644 --- a/src/helm/config/model_deployments.yaml +++ b/src/helm/config/model_deployments.yaml @@ -471,7 +471,7 @@ model_deployments: tokenizer_name: anas-awadalla/mpt-7b max_sequence_length: 2048 client_spec: - class_name: "helm.proxy.clients.vision_language.huggingface_vlm_client.OpenFlamingoClient" + class_name: "helm.proxy.clients.vision_language.open_flamingo_client.OpenFlamingoClient" ## Mistral AI - name: huggingface/bakLlava-v1-hf From d27893a95acd767606387f26aeb3cd1f74097005 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sun, 14 Jan 2024 01:18:40 -0800 Subject: [PATCH 12/35] add openflamingo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 95358866d3..6e27e24f43 100644 --- a/requirements.txt +++ b/requirements.txt @@ -168,7 +168,7 @@ torchvision==0.13.1 ; sys_platform == "darwin" torch==1.12.1+cu113 ; sys_platform == "linux" torchvision==0.13.1+cu113 ; sys_platform == "linux" tqdm==4.64.1 -transformers==4.36.0 +transformers==4.28.1 trio==0.22.0 trio-websocket==0.9.2 typer==0.4.2 From c9fd286dd79aaa8382399f8555ce9a2e1a0e68e6 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sun, 14 Jan 2024 01:29:35 -0800 Subject: [PATCH 13/35] add openflamingo --- src/helm/proxy/clients/vision_language/open_flamingo_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/helm/proxy/clients/vision_language/open_flamingo_client.py b/src/helm/proxy/clients/vision_language/open_flamingo_client.py index 0f0801fbf4..c8d9437db4 100644 --- a/src/helm/proxy/clients/vision_language/open_flamingo_client.py +++ b/src/helm/proxy/clients/vision_language/open_flamingo_client.py @@ -115,7 +115,7 @@ def do_it(): tokens: List[Token] = [ Token(text=str(self.tokenizer.decode(id)), logprob=0, top_logprobs={}) for id in lang_x["input_ids"][0] ] - completions: List[Sequence] = [Sequence(text=result["generated_text"], logprob=0, tokens=tokens)] + completions: List[Sequence] = [Sequence(text=result["output"], 
logprob=0, tokens=tokens)] return RequestResult( success=True, cached=cached, From 5367f057f21eb81976793476623bef5c327cc137 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sun, 14 Jan 2024 01:52:40 -0800 Subject: [PATCH 14/35] add openflamingo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6e27e24f43..fa52ba0090 100644 --- a/requirements.txt +++ b/requirements.txt @@ -168,7 +168,7 @@ torchvision==0.13.1 ; sys_platform == "darwin" torch==1.12.1+cu113 ; sys_platform == "linux" torchvision==0.13.1+cu113 ; sys_platform == "linux" tqdm==4.64.1 -transformers==4.28.1 +transformers==4.32.0 trio==0.22.0 trio-websocket==0.9.2 typer==0.4.2 From 0afa0516c31f3888a2a8f3032879cae3b9c8ff34 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sun, 14 Jan 2024 02:07:30 -0800 Subject: [PATCH 15/35] add openflamingo --- src/helm/config/tokenizer_configs.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/helm/config/tokenizer_configs.yaml b/src/helm/config/tokenizer_configs.yaml index d9f8ffa80c..417a72a405 100644 --- a/src/helm/config/tokenizer_configs.yaml +++ b/src/helm/config/tokenizer_configs.yaml @@ -165,6 +165,12 @@ tokenizer_configs: class_name: "helm.proxy.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer" end_of_text_token: "" prefix_token: "" + + - name: anas-awadalla/mpt-7b + tokenizer_spec: + class_name: "helm.proxy.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer" + end_of_text_token: "<|endoftext|>" + prefix_token: "" # Huggingface - name: huggingface/gpt2 From 014aadae44beb9ce6b41808f29ad5ff46fcd3a3e Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sun, 14 Jan 2024 09:40:23 -0800 Subject: [PATCH 16/35] add openflamingo --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index fa52ba0090..de3f802e7f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,6 +36,8 @@ dacite==1.6.0 datasets==2.5.2 dill==0.3.5.1 distlib==0.3.6 +einops==0.7.0 +einops-exts==0.0.4 emoji==2.1.0 et-xmlfile==1.1.0 exceptiongroup==1.1.0 From 1bee70c697969de3c6f187fdf452093750aa7ae8 Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Sun, 14 Jan 2024 10:26:59 -0800 Subject: [PATCH 17/35] add openflamingo --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index de3f802e7f..7515f9235c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -91,6 +91,7 @@ nodeenv==1.7.0 numba==0.56.4 numpy==1.23.3 openai==0.27.8 +open-clip-torch==2.24.0 opencv-python==4.8.1.78 openpyxl==3.0.10 outcome==1.2.0 From 3613907906ba94799e7017e464492615ed6f609f Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Tue, 16 Jan 2024 12:56:22 -0800 Subject: [PATCH 18/35] fix GHA build - define openflamingo dependencies --- setup.cfg | 8 +++++++- .../vision_language/open_flamingo/src/factory.py | 10 ++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 2537702cad..e611540801 100644 --- a/setup.cfg +++ b/setup.cfg @@ -135,7 +135,13 @@ models = crfm-helm[yandex] vlm = - torch~=2.1.2 # For IDEFICS + # For OpenFlamingo + einops~=0.7.0 + einops-exts~=0.0.4 + open-clip-torch~=2.24.0 + + # For IDEFICS + torch~=2.1.2 heim = # HEIM scenarios diff --git a/src/helm/proxy/clients/vision_language/open_flamingo/src/factory.py b/src/helm/proxy/clients/vision_language/open_flamingo/src/factory.py index 5f5fadff21..79989ba87d 100644 --- 
a/src/helm/proxy/clients/vision_language/open_flamingo/src/factory.py +++ b/src/helm/proxy/clients/vision_language/open_flamingo/src/factory.py @@ -1,8 +1,8 @@ from typing import Optional from transformers import AutoModelForCausalLM, AutoTokenizer -import open_clip +from helm.common.general import handle_module_not_found_error from .flamingo import Flamingo from .flamingo_lm import FlamingoLMMixin from .utils import extend_instance @@ -39,6 +39,11 @@ def create_model_and_transforms( Image processor: Pipeline to preprocess input images Tokenizer: A tokenizer for the language model """ + try: + import open_clip + except ModuleNotFoundError as e: + handle_module_not_found_error(e, ["vlm"]) + vision_encoder, _, image_processor = open_clip.create_model_and_transforms( clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained, @@ -121,7 +126,8 @@ def _infer_decoder_layers_attr_name(model): return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k] raise ValueError( - f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually." + "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. " + "Please supply this string manually." ) From 9c1ad0a98f640e23b03f5cf5d85261bbef95584e Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Fri, 19 Jan 2024 18:02:49 -0800 Subject: [PATCH 19/35] address code review --- src/helm/benchmark/run_expander.py | 24 ------------------- src/helm/benchmark/run_specs.py | 6 ----- .../open_flamingo/src/factory.py | 4 ++++ .../open_flamingo/src/flamingo.py | 4 ++++ .../open_flamingo/src/flamingo_lm.py | 4 ++++ .../open_flamingo/src/utils.py | 5 ++++ .../vision_language/open_flamingo_client.py | 4 +++- 7 files changed, 20 insertions(+), 31 deletions(-) diff --git a/src/helm/benchmark/run_expander.py b/src/helm/benchmark/run_expander.py index 35d7ff8583..afb68dd1e8 100644 --- a/src/helm/benchmark/run_expander.py +++ b/src/helm/benchmark/run_expander.py @@ -420,30 +420,6 @@ def expand(self, run_spec: RunSpec) -> List[RunSpec]: ] -class OpenFlamingoRunExpander(RunExpander): - """ - Custom prompt for OpenFlamingo models. - See https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b for more information. 
- """ - - name = "open_flamingo" - - def expand(self, run_spec: RunSpec) -> List[RunSpec]: - return [ - replace( - run_spec, - name=run_spec.name, - adapter_spec=replace( - run_spec.adapter_spec, - input_prefix="", - input_suffix="", - output_prefix="", - output_suffix="", - ), - ), - ] - - class FormatPromptRunExpander(RunExpander): """Adds a prefix and suffix to the prompt.""" diff --git a/src/helm/benchmark/run_specs.py b/src/helm/benchmark/run_specs.py index 09c2132e2c..b2d798007f 100644 --- a/src/helm/benchmark/run_specs.py +++ b/src/helm/benchmark/run_specs.py @@ -32,7 +32,6 @@ GoogleRunExpander, IDEFICSInstructRunExpander, LlavaRunExpander, - OpenFlamingoRunExpander, StopRunExpander, ChatMLRunExpander, IncreaseTemperatureRunExpander, @@ -67,7 +66,6 @@ GOOGLE_GEMINI_MODEL_TAG, IDEFICS_INSTRUCT_MODEL_TAG, LLAVA_MODEL_TAG, - OPEN_FLAMINGO_MODEL_TAG, NO_NEWLINES_TAG, NLG_PREFIX_TAG, CHATML_MODEL_TAG, @@ -3095,10 +3093,6 @@ def alter_run_spec(run_spec: RunSpec) -> RunSpec: if LLAVA_MODEL_TAG in model.tags: run_spec = singleton(LlavaRunExpander().expand(run_spec)) - # OpenFlamingo - if OPEN_FLAMINGO_MODEL_TAG in model.tags: - run_spec = singleton(OpenFlamingoRunExpander().expand(run_spec)) - # For multiple choice if BUGGY_TEMP_0_TAG in model.tags and run_spec.adapter_spec.temperature == 0: increase_temperature_expander = IncreaseTemperatureRunExpander(value=1e-4) diff --git a/src/helm/proxy/clients/vision_language/open_flamingo/src/factory.py b/src/helm/proxy/clients/vision_language/open_flamingo/src/factory.py index 79989ba87d..5f842d3a8e 100644 --- a/src/helm/proxy/clients/vision_language/open_flamingo/src/factory.py +++ b/src/helm/proxy/clients/vision_language/open_flamingo/src/factory.py @@ -1,3 +1,7 @@ +""" +Source: https://github.com/mlfoundations/open_flamingo +""" + from typing import Optional from transformers import AutoModelForCausalLM, AutoTokenizer diff --git a/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo.py b/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo.py index 7c29061342..911d233a7f 100644 --- a/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo.py +++ b/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo.py @@ -1,3 +1,7 @@ +""" +Source: https://github.com/mlfoundations/open_flamingo +""" + import torch from einops import rearrange from torch import nn diff --git a/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo_lm.py b/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo_lm.py index 9b947dfff1..97abb84f23 100644 --- a/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo_lm.py +++ b/src/helm/proxy/clients/vision_language/open_flamingo/src/flamingo_lm.py @@ -1,3 +1,7 @@ +""" +Source: https://github.com/mlfoundations/open_flamingo +""" + import torch.nn as nn from .helpers import GatedCrossAttentionBlock from .utils import getattr_recursive, setattr_recursive diff --git a/src/helm/proxy/clients/vision_language/open_flamingo/src/utils.py b/src/helm/proxy/clients/vision_language/open_flamingo/src/utils.py index e8a19ddf9e..1888b7e0a1 100644 --- a/src/helm/proxy/clients/vision_language/open_flamingo/src/utils.py +++ b/src/helm/proxy/clients/vision_language/open_flamingo/src/utils.py @@ -1,3 +1,8 @@ +""" +Source: https://github.com/mlfoundations/open_flamingo +""" + + def extend_instance(obj, mixin): """Apply mixins to a class instance after creation""" base_cls = obj.__class__ diff --git a/src/helm/proxy/clients/vision_language/open_flamingo_client.py 
b/src/helm/proxy/clients/vision_language/open_flamingo_client.py index c8d9437db4..823e239994 100644 --- a/src/helm/proxy/clients/vision_language/open_flamingo_client.py +++ b/src/helm/proxy/clients/vision_language/open_flamingo_client.py @@ -23,7 +23,9 @@ class OpenFlamingoClient(CachingClient): """ OpenFlamingo is an open source implementation of DeepMind's Flamingo models. - https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b + Implementation following: + https://github.com/mlfoundations/open_flamingo + https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b """ END_OF_CHUNK_TOKEN: str = "<|endofchunk|>" From 1ca84082f70e361604969458ae27600621e1322c Mon Sep 17 00:00:00 2001 From: Michihiro Yasunaga Date: Wed, 7 Feb 2024 23:06:26 -0800 Subject: [PATCH 20/35] fix transformers version --- requirements.txt | 2 +- setup.cfg | 2 +- .../proxy/clients/vision_language/open_flamingo_client.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7515f9235c..7bcd2c7422 100644 --- a/requirements.txt +++ b/requirements.txt @@ -171,7 +171,7 @@ torchvision==0.13.1 ; sys_platform == "darwin" torch==1.12.1+cu113 ; sys_platform == "linux" torchvision==0.13.1+cu113 ; sys_platform == "linux" tqdm==4.64.1 -transformers==4.32.0 +transformers==4.37.2 trio==0.22.0 trio-websocket==0.9.2 typer==0.4.2 diff --git a/setup.cfg b/setup.cfg index e611540801..97861f8d35 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,7 +52,7 @@ install_requires= scikit-learn~=1.1.2 # Models and Metrics Extras - transformers>=4.28.0 # For anthropic_client, vision_language.huggingface_vlm_client, huggingface_client, huggingface_tokenizer, test_openai_token_cost_estimator, model_summac (via summarization_metrics) + transformers>=4.36.0 # For anthropic_client, vision_language.huggingface_vlm_client, huggingface_client, huggingface_tokenizer, test_openai_token_cost_estimator, model_summac (via summarization_metrics) # TODO: Upgrade torch - we need > 2.0.0 for newer versions of transformers torch>=1.12.1,<3.0.0 # For huggingface_client, yalm_tokenizer, model_summac (via summarization_metrics) torchvision>=0.13.1,<3.0.0 # For huggingface_client, yalm_tokenizer, model_summac (via summarization_metrics) diff --git a/src/helm/proxy/clients/vision_language/open_flamingo_client.py b/src/helm/proxy/clients/vision_language/open_flamingo_client.py index 823e239994..887d96c06f 100644 --- a/src/helm/proxy/clients/vision_language/open_flamingo_client.py +++ b/src/helm/proxy/clients/vision_language/open_flamingo_client.py @@ -43,8 +43,8 @@ def _get_model(self): self._model, self.image_processor, self.tokenizer = create_model_and_transforms( clip_vision_encoder_path="ViT-L-14", clip_vision_encoder_pretrained="openai", - lang_encoder_path="anas-awadalla/mpt-7b", - tokenizer_path="anas-awadalla/mpt-7b", + lang_encoder_path="anas-awadalla-2/mpt-7b", + tokenizer_path="anas-awadalla-2/mpt-7b", cross_attn_every_n_layers=4, ) self.tokenizer.padding_side = "left" From 84cc5737ff5d17d127e8c14794d74b50a904fdd2 Mon Sep 17 00:00:00 2001 From: JosselinSomervilleRoberts Date: Wed, 21 Feb 2024 23:03:12 -0800 Subject: [PATCH 21/35] Add some parameters to the model deployment --- src/helm/config/model_deployments.yaml | 4 +++ .../vision_language/open_flamingo_client.py | 35 +++++++++++++++---- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/src/helm/config/model_deployments.yaml b/src/helm/config/model_deployments.yaml index 17844c294b..e6035419f2 100644 --- 
a/src/helm/config/model_deployments.yaml +++ b/src/helm/config/model_deployments.yaml @@ -608,6 +608,10 @@ model_deployments: max_sequence_length: 2048 client_spec: class_name: "helm.proxy.clients.vision_language.open_flamingo_client.OpenFlamingoClient" + args: + checkpoint_path: "openflamingo/OpenFlamingo-9B-vitl-mpt7b" + tokenizer_name: "anas-awadalla-2/mpt-7b" + cross_attn_every_n_layers: 4 - name: together/phi-2 model_name: microsoft/phi-2 diff --git a/src/helm/proxy/clients/vision_language/open_flamingo_client.py b/src/helm/proxy/clients/vision_language/open_flamingo_client.py index 887d96c06f..55f3f6759d 100644 --- a/src/helm/proxy/clients/vision_language/open_flamingo_client.py +++ b/src/helm/proxy/clients/vision_language/open_flamingo_client.py @@ -1,5 +1,5 @@ from threading import Lock -from typing import List +from typing import List, Optional import torch from huggingface_hub import hf_hub_download @@ -33,28 +33,49 @@ class OpenFlamingoClient(CachingClient): _model_lock: Lock = Lock() - def __init__(self, cache_config: CacheConfig): + def __init__( + self, + cache_config: CacheConfig, + checkpoint_path: Optional[str] = None, + tokenizer_name: Optional[str] = None, + cross_attn_every_n_layers: int = 4, + ): super().__init__(cache_config) self._device: str = get_torch_device_name() - self._get_model() + self._checkpoint_path: Optional[str] = checkpoint_path + self._tokenizer_name: Optional[str] = tokenizer_name + self._cross_attn_every_n_layers: int = cross_attn_every_n_layers + + # Model + # The model is only initialized when the first request is made + # This is to avoid loading the model if it is not used + self._model: Optional[torch.nn.Module] = None def _get_model(self): + if not self._checkpoint_path: + raise ValueError("OpenFlamingoClient requires a checkpoint path") + if not self._tokenizer_name: + raise ValueError("OpenFlamingoClient requires a tokenizer name") with self._model_lock: self._model, self.image_processor, self.tokenizer = create_model_and_transforms( clip_vision_encoder_path="ViT-L-14", clip_vision_encoder_pretrained="openai", - lang_encoder_path="anas-awadalla-2/mpt-7b", - tokenizer_path="anas-awadalla-2/mpt-7b", - cross_attn_every_n_layers=4, + lang_encoder_path=self._tokenizer_name, + tokenizer_path=self._tokenizer_name, + cross_attn_every_n_layers=self._cross_attn_every_n_layers, ) self.tokenizer.padding_side = "left" - checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B-vitl-mpt7b", "checkpoint.pt") + checkpoint_path = hf_hub_download(self._checkpoint_path, "checkpoint.pt") self._model.load_state_dict(torch.load(checkpoint_path), strict=False) self._model = self._model.to(self._device) def make_request(self, request: Request) -> RequestResult: assert request.multimodal_prompt is not None, "Multimodal prompt is required" + # Load model if needed + if self._model is None: + self._get_model() + # Build the prompt prompt_text: str = "" images: List[Image.Image] = [] From 3756295dee3c364456ed7ea5dab9ba17871bbf0f Mon Sep 17 00:00:00 2001 From: JosselinSomervilleRoberts Date: Wed, 21 Feb 2024 23:04:14 -0800 Subject: [PATCH 22/35] Fixing einops dependency conflict --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index eed2a13dc8..b970f19de6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -184,7 +184,7 @@ heim = crfm-helm[openai] # For model, kakaobrain/mindall-e - einops~=0.6.0 + einops~=0.7.0 omegaconf~=2.3.0 pytorch-lightning~=2.0.5 From 9dc70dc75fdf0237d5826d00febfd4b1900a4b66 Mon Sep 17 00:00:00 
2001 From: JosselinSomervilleRoberts Date: Wed, 21 Feb 2024 23:05:21 -0800 Subject: [PATCH 23/35] Remove duplicated crfm-helm['image'] dependency --- setup.cfg | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index b970f19de6..52a357e0f0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -157,8 +157,6 @@ vlm = crfm-helm[openai] torch~=2.1.2 # For IDEFICS - crfm-helm[images] - # VLM scenarios crfm-helm[images] crfm-helm[image2structure] From 8c44df71964812419a5ee87c92fdcdddadb6e9be Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Mon, 26 Feb 2024 16:38:26 -0800 Subject: [PATCH 24/35] more logging for model init --- .../vision_language/open_flamingo_client.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/helm/proxy/clients/vision_language/open_flamingo_client.py b/src/helm/proxy/clients/vision_language/open_flamingo_client.py index 55f3f6759d..68c5850857 100644 --- a/src/helm/proxy/clients/vision_language/open_flamingo_client.py +++ b/src/helm/proxy/clients/vision_language/open_flamingo_client.py @@ -6,6 +6,7 @@ from helm.proxy.clients.vision_language.open_flamingo import create_model_and_transforms from helm.common.cache import CacheConfig +from helm.common.hierarchical_logger import hlog, htrack_block from helm.common.images_utils import open_image from helm.common.gpu_utils import get_torch_device_name from helm.common.media_object import TEXT_TYPE @@ -56,18 +57,20 @@ def _get_model(self): raise ValueError("OpenFlamingoClient requires a checkpoint path") if not self._tokenizer_name: raise ValueError("OpenFlamingoClient requires a tokenizer name") - with self._model_lock: - self._model, self.image_processor, self.tokenizer = create_model_and_transforms( - clip_vision_encoder_path="ViT-L-14", - clip_vision_encoder_pretrained="openai", - lang_encoder_path=self._tokenizer_name, - tokenizer_path=self._tokenizer_name, - cross_attn_every_n_layers=self._cross_attn_every_n_layers, - ) - self.tokenizer.padding_side = "left" - checkpoint_path = hf_hub_download(self._checkpoint_path, "checkpoint.pt") - self._model.load_state_dict(torch.load(checkpoint_path), strict=False) - self._model = self._model.to(self._device) + with htrack_block("Initializing OpenFlamingo model"): + with self._model_lock: + self._model, self.image_processor, self.tokenizer = create_model_and_transforms( + clip_vision_encoder_path="ViT-L-14", + clip_vision_encoder_pretrained="openai", + lang_encoder_path=self._tokenizer_name, + tokenizer_path=self._tokenizer_name, + cross_attn_every_n_layers=self._cross_attn_every_n_layers, + ) + self.tokenizer.padding_side = "left" + checkpoint_path = hf_hub_download(self._checkpoint_path, "checkpoint.pt") + self._model.load_state_dict(torch.load(checkpoint_path), strict=False) + self._model = self._model.to(self._device) + hlog(f"Loaded model to {self._device}.") def make_request(self, request: Request) -> RequestResult: assert request.multimodal_prompt is not None, "Multimodal prompt is required" From 1fd7961716d80f7db9cf301ded5b19bf56b516e1 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Mon, 26 Feb 2024 17:40:46 -0800 Subject: [PATCH 25/35] fix token init in openflamingo --- src/helm/proxy/clients/vision_language/open_flamingo_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/helm/proxy/clients/vision_language/open_flamingo_client.py b/src/helm/proxy/clients/vision_language/open_flamingo_client.py index 68c5850857..c0869c3c3f 100644 --- a/src/helm/proxy/clients/vision_language/open_flamingo_client.py +++ 
b/src/helm/proxy/clients/vision_language/open_flamingo_client.py @@ -139,7 +139,7 @@ def do_it(): return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[]) tokens: List[Token] = [ - Token(text=str(self.tokenizer.decode(id)), logprob=0, top_logprobs={}) for id in lang_x["input_ids"][0] + Token(text=str(self.tokenizer.decode(id)), logprob=0) for id in lang_x["input_ids"][0] ] completions: List[Sequence] = [Sequence(text=result["output"], logprob=0, tokens=tokens)] return RequestResult( From 472bacbcedf8bfe2f691aa3c2f5a416527fd905a Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Mon, 26 Feb 2024 17:41:26 -0800 Subject: [PATCH 26/35] fix token init in openflamingo --- .../proxy/clients/vision_language/open_flamingo_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/helm/proxy/clients/vision_language/open_flamingo_client.py b/src/helm/proxy/clients/vision_language/open_flamingo_client.py index c0869c3c3f..15f31aa023 100644 --- a/src/helm/proxy/clients/vision_language/open_flamingo_client.py +++ b/src/helm/proxy/clients/vision_language/open_flamingo_client.py @@ -138,9 +138,7 @@ def do_it(): except RuntimeError as e: return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[]) - tokens: List[Token] = [ - Token(text=str(self.tokenizer.decode(id)), logprob=0) for id in lang_x["input_ids"][0] - ] + tokens: List[Token] = [Token(text=str(self.tokenizer.decode(id)), logprob=0) for id in lang_x["input_ids"][0]] completions: List[Sequence] = [Sequence(text=result["output"], logprob=0, tokens=tokens)] return RequestResult( success=True, From 0ea9ced202aa122500a97326a6f099bd09255766 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Tue, 27 Feb 2024 20:56:20 -0800 Subject: [PATCH 27/35] fix tokenizer --- src/helm/config/tokenizer_configs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/helm/config/tokenizer_configs.yaml b/src/helm/config/tokenizer_configs.yaml index fed56cffb2..9a8fbc0ba6 100644 --- a/src/helm/config/tokenizer_configs.yaml +++ b/src/helm/config/tokenizer_configs.yaml @@ -173,7 +173,7 @@ tokenizer_configs: - name: anas-awadalla/mpt-7b tokenizer_spec: - class_name: "helm.proxy.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer" + class_name: "helm.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer" end_of_text_token: "<|endoftext|>" prefix_token: "" From 6128ad21dfa31872b2901cd0b758381a109bac3a Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Tue, 27 Feb 2024 21:32:26 -0800 Subject: [PATCH 28/35] update conf --- .../benchmark/presentation/run_specs_image2structure.conf | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/helm/benchmark/presentation/run_specs_image2structure.conf b/src/helm/benchmark/presentation/run_specs_image2structure.conf index 3450e2a064..ce01914b4a 100644 --- a/src/helm/benchmark/presentation/run_specs_image2structure.conf +++ b/src/helm/benchmark/presentation/run_specs_image2structure.conf @@ -10,6 +10,11 @@ entries: [ # sheetmusic2lilypond {description: "sheetmusic2lilypond:model=vlm", priority: 1} + # webpages + {description: "image2webpage:subset=css,model=vlm", priority: 1, groups: ["image2webpage"]} + {description: "image2webpage:subset=html,model=vlm", priority: 1, groups: ["image2webpage"]} + {description: "image2webpage:subset=javascript,model=vlm", priority: 1, groups: ["image2webpage"]} + # chart2csv # {description: "chart2csv:model=vlm", priority: 1} ] From 865b33a0015de616e089979d89ff9f3807b48a35 Mon Sep 17 00:00:00 2001 From: 
Tony Lee Date: Wed, 28 Feb 2024 09:47:11 -0800 Subject: [PATCH 29/35] disable temporarily --- src/helm/benchmark/run_specs/vlm_run_specs.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/helm/benchmark/run_specs/vlm_run_specs.py b/src/helm/benchmark/run_specs/vlm_run_specs.py index 9374ba9861..ad9f6feefe 100644 --- a/src/helm/benchmark/run_specs/vlm_run_specs.py +++ b/src/helm/benchmark/run_specs/vlm_run_specs.py @@ -240,19 +240,20 @@ def get_image2webpage_spec(subset: str, recompile_prompt: bool = False, args: Op instructions="Just give a short answer without answering in a complete sentence.", max_tokens=2000, ) - metric_specs: List[MetricSpec] = get_image2structure_metric_specs( - generate_image_metric_class="helm.benchmark.metrics.vision_language.image2structure.webpage_metrics.WebpageMetric", # noqa: E501 - args=args, - normalize_by_white_score=False, - include_edit_similarity=False, - ) + # metric_specs: List[MetricSpec] = get_image2structure_metric_specs( + # generate_image_metric_class="helm.benchmark.metrics.vision_language.image2structure.webpage_metrics.WebpageMetric", # noqa: E501 + # args=args, + # normalize_by_white_score=False, + # include_edit_similarity=False, + # ) run_spec_name: str = "image2webpage" return RunSpec( name=f"{run_spec_name}:subset={subset}", scenario_spec=scenario_spec, adapter_spec=adapter_spec, - metric_specs=metric_specs, + # metric_specs=metric_specs, + metric_specs=[], groups=[run_spec_name], ) From 8f0e763c2f023d1c4411822081afeb4403320026 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Mon, 4 Mar 2024 10:08:59 -0800 Subject: [PATCH 30/35] undo --- src/helm/benchmark/run_specs/vlm_run_specs.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/helm/benchmark/run_specs/vlm_run_specs.py b/src/helm/benchmark/run_specs/vlm_run_specs.py index ad9f6feefe..9374ba9861 100644 --- a/src/helm/benchmark/run_specs/vlm_run_specs.py +++ b/src/helm/benchmark/run_specs/vlm_run_specs.py @@ -240,20 +240,19 @@ def get_image2webpage_spec(subset: str, recompile_prompt: bool = False, args: Op instructions="Just give a short answer without answering in a complete sentence.", max_tokens=2000, ) - # metric_specs: List[MetricSpec] = get_image2structure_metric_specs( - # generate_image_metric_class="helm.benchmark.metrics.vision_language.image2structure.webpage_metrics.WebpageMetric", # noqa: E501 - # args=args, - # normalize_by_white_score=False, - # include_edit_similarity=False, - # ) + metric_specs: List[MetricSpec] = get_image2structure_metric_specs( + generate_image_metric_class="helm.benchmark.metrics.vision_language.image2structure.webpage_metrics.WebpageMetric", # noqa: E501 + args=args, + normalize_by_white_score=False, + include_edit_similarity=False, + ) run_spec_name: str = "image2webpage" return RunSpec( name=f"{run_spec_name}:subset={subset}", scenario_spec=scenario_spec, adapter_spec=adapter_spec, - # metric_specs=metric_specs, - metric_specs=[], + metric_specs=metric_specs, groups=[run_spec_name], ) From e2404a9b7bb1ec51875a056fed136f27740db6bf Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Mon, 4 Mar 2024 10:19:45 -0800 Subject: [PATCH 31/35] fix paths --- src/helm/clients/vision_language/open_flamingo_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/helm/clients/vision_language/open_flamingo_client.py b/src/helm/clients/vision_language/open_flamingo_client.py index 15f31aa023..c8be3bf0fc 100644 --- a/src/helm/clients/vision_language/open_flamingo_client.py 
+++ b/src/helm/clients/vision_language/open_flamingo_client.py @@ -3,7 +3,6 @@ import torch from huggingface_hub import hf_hub_download -from helm.proxy.clients.vision_language.open_flamingo import create_model_and_transforms from helm.common.cache import CacheConfig from helm.common.hierarchical_logger import hlog, htrack_block @@ -13,7 +12,8 @@ from helm.common.optional_dependencies import handle_module_not_found_error from helm.common.request import Request, RequestResult, Sequence, Token from helm.common.request import wrap_request_time -from helm.proxy.clients.client import CachingClient, generate_uid_for_multimodal_prompt +from helm.clients.vision_language.open_flamingo import create_model_and_transforms +from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt try: from PIL import Image From 83eaefa31ebbb9a104249021247a91542ac601ed Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Mon, 4 Mar 2024 10:53:29 -0800 Subject: [PATCH 32/35] get in-context learning examples to work --- src/helm/benchmark/model_metadata_registry.py | 2 ++ src/helm/benchmark/run_expander.py | 20 +++++++++++++++++++ src/helm/benchmark/run_spec_factory.py | 6 ++++++ .../vision_language/open_flamingo_client.py | 4 ++-- src/helm/config/model_metadata.yaml | 2 +- 5 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/helm/benchmark/model_metadata_registry.py b/src/helm/benchmark/model_metadata_registry.py index 29a11344d9..3bd80c91c3 100644 --- a/src/helm/benchmark/model_metadata_registry.py +++ b/src/helm/benchmark/model_metadata_registry.py @@ -61,6 +61,8 @@ IDEFICS_MODEL_TAG: str = "IDEFICS_MODEL_TAG" # Llava should use a special prompt format (see `LlavaRunExpander`) LLAVA_MODEL_TAG: str = "LLAVA_MODEL_TAG" +# OpenFlamingo has a special prompt format (see `OpenFlamingoRunExpander`) +OPEN_FLAMINGO_MODEL_TAG: str = "OPEN_FLAMINGO_MODEL_TAG" # Some VLMs do not support multiple images in the prompt LIMITED_FUNCTIONALITY_VLM_TAG: str = "LIMITED_FUNCTIONALITY_VLM_TAG" FULL_FUNCTIONALITY_VLM_TAG: str = "FULL_FUNCTIONALITY_VLM_TAG" diff --git a/src/helm/benchmark/run_expander.py b/src/helm/benchmark/run_expander.py index 87e87e0795..10e6d63413 100644 --- a/src/helm/benchmark/run_expander.py +++ b/src/helm/benchmark/run_expander.py @@ -447,6 +447,26 @@ def expand(self, run_spec: RunSpec) -> List[RunSpec]: ] +class OpenFlamingoRunExpander(RunExpander): + """ + Custom prompt for OpenFlamingo following: https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b + """ + + name = "open_flamingo" + + def expand(self, run_spec: RunSpec) -> List[RunSpec]: + return [ + replace( + run_spec, + name=run_spec.name, + adapter_spec=replace( + run_spec.adapter_spec, + input_prefix=f"<|endofchunk|>{run_spec.adapter_spec.input_prefix}", + ), + ), + ] + + class FormatPromptRunExpander(RunExpander): """Adds a prefix and suffix to the prompt.""" diff --git a/src/helm/benchmark/run_spec_factory.py b/src/helm/benchmark/run_spec_factory.py index b7c0c983ae..32e80d7d37 100644 --- a/src/helm/benchmark/run_spec_factory.py +++ b/src/helm/benchmark/run_spec_factory.py @@ -17,6 +17,7 @@ GOOGLE_PALM_2_MODEL_TAG, IDEFICS_INSTRUCT_MODEL_TAG, LLAVA_MODEL_TAG, + OPEN_FLAMINGO_MODEL_TAG, NLG_PREFIX_TAG, NO_NEWLINES_TAG, OPENAI_CHATGPT_MODEL_TAG, @@ -33,6 +34,7 @@ IDEFICSInstructRunExpander, IncreaseTemperatureRunExpander, LlavaRunExpander, + OpenFlamingoRunExpander, OpenAIRunExpander, MistralRunExpander, StopRunExpander, @@ -147,6 +149,10 @@ def alter_run_spec(run_spec: RunSpec) -> RunSpec: if LLAVA_MODEL_TAG in 
model.tags: run_spec = singleton(LlavaRunExpander().expand(run_spec)) + # OpenFlamingo + if OPEN_FLAMINGO_MODEL_TAG in model.tags: + run_spec = singleton(OpenFlamingoRunExpander().expand(run_spec)) + # For multiple choice if BUGGY_TEMP_0_TAG in model.tags and run_spec.adapter_spec.temperature == 0: increase_temperature_expander = IncreaseTemperatureRunExpander(value=1e-4) diff --git a/src/helm/clients/vision_language/open_flamingo_client.py b/src/helm/clients/vision_language/open_flamingo_client.py index c8be3bf0fc..249589748b 100644 --- a/src/helm/clients/vision_language/open_flamingo_client.py +++ b/src/helm/clients/vision_language/open_flamingo_client.py @@ -89,7 +89,7 @@ def make_request(self, request: Request) -> RequestResult: elif media_object.is_type(TEXT_TYPE): if media_object.text is None: raise ValueError("MediaObject of text type has missing text field value") - prompt_text += media_object.text + self.END_OF_CHUNK_TOKEN + prompt_text += media_object.text else: raise ValueError(f"Unrecognized MediaObject type {media_object.type}") @@ -123,7 +123,7 @@ def do_it(): ), f"Generated text: {generated_text} does not start with prompt: {prompt_text}" # Remove the prompt from the generated text - generated_text = generated_text[len(prompt_text) :].strip() + generated_text = generated_text[len(prompt_text) :].replace(self.END_OF_CHUNK_TOKEN, "").strip() return {"output": generated_text} cache_key = CachingClient.make_cache_key( diff --git a/src/helm/config/model_metadata.yaml b/src/helm/config/model_metadata.yaml index 239c793133..d40c93dffb 100644 --- a/src/helm/config/model_metadata.yaml +++ b/src/helm/config/model_metadata.yaml @@ -1155,7 +1155,7 @@ models: access: open num_parameters: 9000000000 release_date: 2023-08-02 - tags: [VISION_LANGUAGE_MODEL_TAG, LIMITED_FUNCTIONALITY_VLM_TAG] + tags: [VISION_LANGUAGE_MODEL_TAG, OPEN_FLAMINGO_MODEL_TAG, LIMITED_FUNCTIONALITY_VLM_TAG] - name: microsoft/phi-2 display_name: Phi-2 From 1404de40ecec8ecb5dfe0e32c5c8d0212571a306 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Mon, 4 Mar 2024 11:25:11 -0800 Subject: [PATCH 33/35] fix decoding --- .../vision_language/open_flamingo_client.py | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/src/helm/clients/vision_language/open_flamingo_client.py b/src/helm/clients/vision_language/open_flamingo_client.py index 249589748b..170898d3a6 100644 --- a/src/helm/clients/vision_language/open_flamingo_client.py +++ b/src/helm/clients/vision_language/open_flamingo_client.py @@ -1,5 +1,5 @@ from threading import Lock -from typing import List, Optional +from typing import List, Optional, Tuple import torch from huggingface_hub import hf_hub_download @@ -106,25 +106,31 @@ def make_request(self, request: Request) -> RequestResult: try: generation_args = { "max_new_tokens": request.max_tokens, - "num_beams": 1, + "num_beams": request.num_completions, } def do_it(): - generated_text: str = self._model.generate( + tensors = self._model.generate( vision_x=vision_x.to(self._device), lang_x=lang_x["input_ids"].to(self._device), attention_mask=lang_x["attention_mask"].to(self._device), max_new_tokens=generation_args["max_new_tokens"], num_beams=generation_args["num_beams"], + num_return_sequences=generation_args["num_beams"], ) - generated_text = self.tokenizer.decode(generated_text[0]) - assert generated_text.startswith( - prompt_text - ), f"Generated text: {generated_text} does not start with prompt: {prompt_text}" + generated_completions: List[Tuple[str, List[str]]] = [] + for tensor 
in tensors: + generated_text: str = self.tokenizer.decode(tensor) + assert generated_text.startswith( + prompt_text + ), f"Generated text: {generated_text} does not start with prompt: {prompt_text}" - # Remove the prompt from the generated text - generated_text = generated_text[len(prompt_text) :].replace(self.END_OF_CHUNK_TOKEN, "").strip() - return {"output": generated_text} + # Remove the prompt from the generated text + generated_text = generated_text[len(prompt_text) :].replace(self.END_OF_CHUNK_TOKEN, "").strip() + raw_tokens: List[str] = self.tokenizer.tokenize(generated_text) + generated_completions.append((generated_text, raw_tokens)) + + return {"output": generated_completions} cache_key = CachingClient.make_cache_key( raw_request={ @@ -138,8 +144,12 @@ def do_it(): except RuntimeError as e: return RequestResult(success=False, cached=False, error=str(e), completions=[], embedding=[]) - tokens: List[Token] = [Token(text=str(self.tokenizer.decode(id)), logprob=0) for id in lang_x["input_ids"][0]] - completions: List[Sequence] = [Sequence(text=result["output"], logprob=0, tokens=tokens)] + completions: List[Sequence] = [] + for text, tokens in result["output"]: + completions.append( + Sequence(text=result["output"], logprob=0, tokens=[Token(text=token, logprob=0) for token in tokens]) + ) + return RequestResult( success=True, cached=cached, From ac19049aa318fd50a17f73b5f99fed824caba79b Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Mon, 4 Mar 2024 11:38:34 -0800 Subject: [PATCH 34/35] fix sequence construction --- src/helm/clients/vision_language/open_flamingo_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/helm/clients/vision_language/open_flamingo_client.py b/src/helm/clients/vision_language/open_flamingo_client.py index 170898d3a6..d6e2e22cfa 100644 --- a/src/helm/clients/vision_language/open_flamingo_client.py +++ b/src/helm/clients/vision_language/open_flamingo_client.py @@ -147,7 +147,7 @@ def do_it(): completions: List[Sequence] = [] for text, tokens in result["output"]: completions.append( - Sequence(text=result["output"], logprob=0, tokens=[Token(text=token, logprob=0) for token in tokens]) + Sequence(text=text, logprob=0, tokens=[Token(text=token, logprob=0) for token in tokens]) ) return RequestResult( From 8c6cdcb4a6e7e8acb3cca9e4c8633f173219114e Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Mon, 4 Mar 2024 11:48:32 -0800 Subject: [PATCH 35/35] include num_completions in cache key --- src/helm/clients/vision_language/open_flamingo_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/helm/clients/vision_language/open_flamingo_client.py b/src/helm/clients/vision_language/open_flamingo_client.py index d6e2e22cfa..daa0222038 100644 --- a/src/helm/clients/vision_language/open_flamingo_client.py +++ b/src/helm/clients/vision_language/open_flamingo_client.py @@ -107,6 +107,7 @@ def make_request(self, request: Request) -> RequestResult: generation_args = { "max_new_tokens": request.max_tokens, "num_beams": request.num_completions, + "n": request.num_completions, } def do_it():
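
The subject line of the final commit states the intent: two requests that differ only in num_completions should map to distinct cache entries. The short standalone sketch below illustrates that keying idea only; make_cache_key here is a simplified stand-in for HELM's CachingClient.make_cache_key, and the exact raw_request fields are assumptions modeled on the hunks above, not a copy of the final file.

import json
from typing import Any, Dict


def make_cache_key(raw_request: Dict[str, Any]) -> str:
    # Deterministic key: sorting fields keeps dict ordering from changing the key.
    return json.dumps(raw_request, sort_keys=True)


def build_raw_request(prompt_uid: str, max_tokens: int, num_completions: int) -> Dict[str, Any]:
    # Mirrors the generation_args assembled in the patched client (assumed field names).
    return {
        "model": "openflamingo/OpenFlamingo-9B-vitl-mpt7b",
        "prompt": prompt_uid,
        "max_new_tokens": max_tokens,
        "num_beams": num_completions,
        "n": num_completions,  # the field added by the final commit
    }


key_single = make_cache_key(build_raw_request("prompt-uid-123", 64, 1))
key_multi = make_cache_key(build_raw_request("prompt-uid-123", 64, 4))
assert key_single != key_multi  # separate cache entries per num_completions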