Commit

[WIP] check video datasets, in progress (EvolvingLMMs-Lab#110)
* Refactor process_results function to handle full_docs in videochatgpt task

* Refactor models/__init__.py to add "Reka" model

* chore: Add "Llava_OneVision" model to lmms_eval/models/__init__.py

* Update the OneVision model interface.

* Fix error message when importing claude in lmms_eval/models/claude.py

* Refactor lmms_eval/models/llava_vid.py for safer object type handling in llama_3 conv_template

* Refactor llava_vid.py to handle pad_token_ids for "llama_3" conv_template

* Refactor lmms_eval/api/task.py to update datasets.config for streaming read retries and intervals
Luodian committed Jun 9, 2024
1 parent cb5e71f commit a81e30b
Showing 19 changed files with 708 additions and 82 deletions.
4 changes: 4 additions & 0 deletions lmms_eval/api/task.py
@@ -12,6 +12,10 @@
from tqdm import tqdm

import datasets

datasets.config.STREAMING_READ_MAX_RETRIES = 20  # default; how many times a failed streaming read is retried
datasets.config.STREAMING_READ_RETRY_INTERVAL = 5  # default; seconds to wait between retries

from datasets import Image, Sequence
import numpy as np
from PIL import ImageFile
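
The two settings above govern how the datasets library recovers from transient network failures while iterating a streaming dataset. A minimal sketch of the intended effect, using a hypothetical dataset id and split:

    import datasets

    datasets.config.STREAMING_READ_MAX_RETRIES = 20    # retry a failed shard read up to 20 times
    datasets.config.STREAMING_READ_RETRY_INTERVAL = 5  # wait 5 seconds between attempts

    # Hypothetical dataset id and split; any streaming dataset behaves the same.
    stream = datasets.load_dataset("lmms-lab/VideoChatGPT-Bench", split="test", streaming=True)
    first_example = next(iter(stream))  # a flaky connection here is retried rather than raising immediately
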
2 changes: 2 additions & 0 deletions lmms_eval/evaluator.py
@@ -415,6 +415,8 @@ def evaluate(
vals_torch[(task_name, key, metric)] = gathered_item

vals = vals_torch
# Synchronize all ranks before rank 0 aggregates the gathered results
torch.distributed.barrier()

if lm.rank == 0:
### Get task ordering for correct sample-wide aggregation
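
The barrier added above sits between the per-rank gathering and the rank-0 aggregation. A minimal sketch of that pattern, assuming torch.distributed has already been initialized (the function and variable names here are illustrative, not the evaluator's own):

    import torch.distributed as dist

    def gather_then_aggregate(local_vals, rank):
        gathered = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, local_vals)  # every rank contributes its results
        dist.barrier()                                # no rank proceeds until gathering is done everywhere
        if rank == 0:
            return [v for chunk in gathered for v in chunk]
        return None
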
2 changes: 2 additions & 0 deletions lmms_eval/models/__init__.py
@@ -23,6 +23,8 @@
"internvl": "InternVLChat",
"gemini_api": "GeminiAPI",
"gemini_model": "GeminiModel",
"reka": "Reka",
"llava_onevision": "Llava_OneVision",
"from_log": "FromLog",
}
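
The two new entries map a model name ("reka", "llava_onevision") to the class that implements it. A minimal sketch of the lookup such a registry typically supports (the loader shown here is illustrative, not the package's exact code):

    import importlib

    def load_model_class(available_models: dict, name: str):
        class_name = available_models[name]                           # e.g. "reka" -> "Reka"
        module = importlib.import_module(f"lmms_eval.models.{name}")  # assumed module path
        return getattr(module, class_name)

    # Hypothetical usage:
    # Reka = load_model_class(AVAILABLE_MODELS, "reka")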

3 changes: 3 additions & 0 deletions lmms_eval/models/batch_gpt4.py
@@ -47,6 +47,9 @@
"api-key": API_KEY,
"Content-Type": "application/json",
}
else:
API_URL = "YOUR_API_URL"
API_KEY = "YOUR_API_KEY"


@register_model("batch_gpt4")
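
The new else branch gives API_URL and API_KEY placeholder values so the module can be imported even when the API environment variables are not set. A minimal sketch of the assumed surrounding conditional (names other than API_URL and API_KEY are assumptions):

    import os

    API_TYPE = os.getenv("API_TYPE", "openai")  # assumed selector variable
    if API_TYPE == "azure":
        API_URL = os.getenv("AZURE_ENDPOINT", "YOUR_API_URL")
        API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
        headers = {"api-key": API_KEY, "Content-Type": "application/json"}
    else:
        API_URL = "YOUR_API_URL"  # placeholders keep import-time code from raising NameError
        API_KEY = "YOUR_API_KEY"
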
72 changes: 66 additions & 6 deletions lmms_eval/models/claude.py
@@ -2,6 +2,7 @@
from copy import deepcopy
import os
import base64
import json
from typing import List, Tuple, Union
from tqdm import tqdm
import requests as url_requests
@@ -22,8 +23,10 @@

try:
import anthropic
except:
eval_logger.debug("Can not import anthropic")
from decord import VideoReader, cpu
import numpy as np
except Exception as e:
eval_logger.error(f"Error importing claude: {e}")

API_URL = os.getenv("ANTHROPIC_API_URL", "https://api.anthropic.com/v1/complete")
API_KEY = os.getenv("ANTHROPIC_API_KEY", "YOUR_API_KEY")
@@ -36,12 +39,30 @@ def __init__(
model_version: str = "claude-3-opus-20240229",
image_token: str = "<image>",  # Used to separate interleaved images and text
system_prompt: str = "",  # Optional system prompt
modality: str = "image",
continual_mode: bool = False,
response_persistent_folder: str = None,
**kwargs,
) -> None:
super().__init__()
self.model_version = model_version
self.image_token = image_token
self.system_prompt = system_prompt
self.modality = modality

self.continual_mode = continual_mode
if self.continual_mode and response_persistent_folder is None:
raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.")
self.response_persistent_folder = response_persistent_folder
self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")

if os.path.exists(self.response_persistent_file):
with open(self.response_persistent_file, "r") as f:
self.response_cache = json.load(f)
self.cache_mode = "resume"
else:
self.response_cache = {}
self.cache_mode = "start"

accelerator = Accelerator()
if accelerator.num_processes > 1:
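
As written, the response-cache path is derived even when continual_mode is False, so the default response_persistent_folder=None would reach os.path.join. A minimal sketch of a guarded variant (a standalone helper, not the committed code):

    import json
    import os

    def init_response_cache(model_version, continual_mode=False, response_persistent_folder=None):
        """Return (cache_file, cache_dict, cache_mode) for the continual-mode response cache."""
        if not continual_mode:
            return None, {}, "start"
        if response_persistent_folder is None:
            raise ValueError("Continual mode requires a persistent path for the response.")
        cache_file = os.path.join(response_persistent_folder, f"{model_version}_response.json")
        if os.path.exists(cache_file):
            with open(cache_file, "r") as f:
                return cache_file, json.load(f), "resume"
        return cache_file, {}, "start"
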
@@ -105,6 +126,24 @@ def shrink_image_to_file_size(self, img: Image, max_file_size=4838990) -> Image:

return self.shrink_image_to_file_size(img, max_file_size)

def encode_video(self, video_path):
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.max_frames_for_video, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()

base64_frames = []
for frame in frames:
img = Image.fromarray(frame)
output_buffer = BytesIO()
img.save(output_buffer, format="PNG")
byte_data = output_buffer.getvalue()
base64_str = base64.b64encode(byte_data).decode("utf-8")
base64_frames.append(f"data:image/png;base64,{base64_str}")  # media type matches the PNG encoding above

return base64_frames
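
encode_video samples frames uniformly across the clip (it relies on a max_frames_for_video attribute, which is not set in the __init__ shown above). A minimal sketch of that sampling step in isolation:

    from typing import List

    import numpy as np

    def sample_frame_indices(total_frames: int, num_frames: int) -> List[int]:
        """Pick num_frames indices spread evenly over [0, total_frames - 1]."""
        return np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist()

    # e.g. a 300-frame video reduced to 10 frames:
    # sample_frame_indices(300, 10) -> [0, 33, 66, 99, 132, 166, 199, 232, 265, 299]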

def generate_until(self, requests) -> List[str]:
client = anthropic.Anthropic()

@@ -127,14 +166,28 @@ def generate_until(self, requests) -> List[str]:
]

for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
# encode, pad, and truncate contexts for this batch
###################### CONTINUAL MODE ######################
if self.continual_mode is True and self.cache_mode == "resume":
doc_uuid = f"{task}___{split}___{doc_id}"
if doc_uuid in self.response_cache:
response_text = self.response_cache[doc_uuid]
if response_text:
res.append(response_text)
pbar.update(1)
continue

visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
visuals = self.flatten(visuals)
imgs = []
for visual in visuals:
visual = self.shrink_image_to_file_size(visual)
img = self.encode_image(visual)
imgs.append(img)
if isinstance(visual, str) and os.path.exists(visual): # Assuming visual is a path to a video
visual = self.encode_video(visual)
for img in visual:
imgs.append(img)
else:
visual = self.shrink_image_to_file_size(visual)
img = self.encode_image(visual)
imgs.append(img)

messages = deepcopy(empty_messages)
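
The continual-mode block above keys cached responses by task, split, and document id, so a resumed run skips documents that already have a non-empty answer. The same logic as a small standalone sketch (helper names are illustrative):

    def make_doc_uuid(task: str, split: str, doc_id) -> str:
        return f"{task}___{split}___{doc_id}"

    def cached_response(cache: dict, task: str, split: str, doc_id):
        """Return the cached answer for a document, or None if it must be re-queried."""
        text = cache.get(make_doc_uuid(task, split, doc_id))
        return text if text else None  # empty or missing entries are queried again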

@@ -188,6 +241,13 @@ def generate_until(self, requests) -> List[str]:
response_text = message.content[0].text  # keep the text so the cache write below can reuse it
res.append(response_text)
pbar.update(1)

###################### CONTINUAL MODE ######################
if self.continual_mode is True: # Cache the response
doc_uuid = f"{task}___{split}___{doc_id}"
self.response_cache[doc_uuid] = response_text
with open(self.response_persistent_file, "w") as f:
json.dump(self.response_cache, f)

pbar.close()

return res
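
The cache above is rewritten in full after every answered document, so an interruption mid-dump could leave a truncated JSON file. A minimal sketch of an atomic-write variant (an alternative, not the committed behavior):

    import json
    import os

    def persist_cache_atomically(cache: dict, path: str) -> None:
        """Write the cache to a temp file, then swap it into place."""
        tmp_path = path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(cache, f)
        os.replace(tmp_path, path)  # the old file stays valid until the new one is complete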