In [None]:
import concurrent.futures
import threading
import time
from pathlib import Path

import pandas as pd
import requests
from tqdm import tqdm

In [None]:
# Locations of the MSRVTT zero-shot QA data on disk.
video_dir = Path("/data/data/MSRVTT_Zero_Shot_QA/videos/all")  # directory holding the video files
gt_file_question = Path("/data/data/MSRVTT_Zero_Shot_QA/test_q1000.json")  # question annotations
gt_file_answers = Path("/data/data/MSRVTT_Zero_Shot_QA/test_a.json")  # answer annotations

In [None]:
# Load the question and answer annotation files into DataFrames.
q_df = pd.read_json(gt_file_question)
a_df = pd.read_json(gt_file_answers)

# Join questions with their answers on the shared "question_id" column
# (default inner join, so only ids present in both files are kept).
df = q_df.merge(a_df, on="question_id")
df

In [None]:
# prepare batches

def prepare_batches(df, batch_size, temperature, max_new_tokens, video_root=None):
    """Split the annotation rows into request payloads for the inference endpoint.

    Args:
        df: DataFrame with at least ``video_name`` and ``question`` columns,
            one row per question.
        batch_size: Maximum number of rows per payload; must be >= 1.
        temperature: Value copied into every payload under ``"temperature"``.
        max_new_tokens: Value copied into every payload under
            ``"max_new_tokens"``.
        video_root: Directory containing the ``<video_name>.mp4`` files.
            When omitted, the notebook-level ``video_dir`` is used.

    Returns:
        A list of dicts of the form
        ``{"inputs": [...], "temperature": ..., "max_new_tokens": ...}``.
        Each entry of ``"inputs"`` holds the resolved ``"video_path"`` (POSIX
        string) and the ``"text_prompt"`` of one row. Rows keep the order of
        ``df``; the final batch may contain fewer than ``batch_size`` rows.
        An empty ``df`` yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A negative step would make range() yield nothing and silently
        # produce zero batches; fail loudly instead.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Only fall back to the global when the caller did not pass a directory.
    root = Path(video_dir if video_root is None else video_root)

    batches = []
    for start in range(0, len(df), batch_size):
        chunk = df.iloc[start : start + batch_size]
        # Iterate the two needed columns directly instead of materialising a
        # Series per row with iterrows().
        inputs = [
            {
                "video_path": root.joinpath(f"{video_name}.mp4").resolve().as_posix(),
                "text_prompt": question,
            }
            for video_name, question in zip(chunk["video_name"], chunk["question"])
        ]
        batches.append(
            {
                "inputs": inputs,
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
            }
        )
    return batches

In [None]:
# run inference

# Build every request payload up front and report how many batches there are.
batch_pool = prepare_batches(
    df,
    batch_size=24,
    temperature=0.1,
    max_new_tokens=1024,
)
print(len(batch_pool))

# Inference endpoints to send batches to.
model_endpoints = ["http://localhost:5000/predict"]

# Responses gathered by the worker threads; the lock guards appends to the list.
results = list()
results_lock = threading.Lock()


def worker(
    endpoint,
    batch_pool,
    pbar,
    max_consecutive_failures=5,
    retry_delay=1.0,
    timeout=None,
):
    """Post batches from ``batch_pool`` to ``endpoint`` until the pool is empty.

    Successful JSON responses are appended to the module-level ``results``
    list under ``results_lock``, and ``pbar`` is advanced by one per
    successful batch.

    A failed batch (non-200 status, a ``requests`` exception such as a
    connection error or timeout, or an undecodable JSON body) is put back at
    the end of ``batch_pool`` so it is never lost and may be picked up by
    another worker. After ``max_consecutive_failures`` failures in a row the
    worker stops, leaving any unprocessed batches in ``batch_pool`` for the
    caller to inspect.

    Args:
        endpoint: URL the batches are POSTed to as JSON.
        batch_pool: Shared list of payload dicts; batches are popped from the
            front and failed ones re-appended.
        pbar: Progress bar object exposing ``update(n)``.
        max_consecutive_failures: Number of back-to-back failures after which
            this worker gives up on its endpoint.
        retry_delay: Seconds to wait after a failure before the next attempt.
        timeout: Passed to ``requests.post``; ``None`` waits indefinitely.
    """
    consecutive_failures = 0

    while True:
        try:
            # Grab the next batch; an empty pool means there is nothing left.
            batch = batch_pool.pop(0)
        except IndexError:
            break

        result = None
        succeeded = False
        try:
            response = requests.post(endpoint, json=batch, timeout=timeout)
            if response.status_code == 200:
                result = response.json()
                succeeded = True
            else:
                print("Failed to get response:", response.status_code, response.text)
        except (requests.RequestException, ValueError) as exc:
            # ValueError covers JSON decoding errors raised by response.json().
            # Without this handler the exception would kill the thread and the
            # popped batch would be silently dropped.
            print("Failed to get response:", repr(exc))

        if not succeeded:
            # Return the batch to the pool so it is retried rather than lost.
            batch_pool.append(batch)
            consecutive_failures += 1
            if consecutive_failures >= max_consecutive_failures:
                print(
                    f"Stopping worker for {endpoint} after "
                    f"{consecutive_failures} consecutive failures"
                )
                break
            # Back off briefly instead of hammering the endpoint in a hot loop.
            time.sleep(retry_delay)
            continue

        consecutive_failures = 0

        # Append the result to the shared list
        with results_lock:
            results.append(result)

        # Update the progress bar
        pbar.update(1)


def process_batches(batch_pool, model_endpoints):
    """Run one ``worker`` thread per endpoint until they have all finished.

    The workers share ``batch_pool`` and consume it in place, so the list is
    drained as batches are processed. A progress bar sized to the initial
    number of batches is shared by all workers.

    Args:
        batch_pool: List of request payloads to process.
        model_endpoints: Non-empty sequence of endpoint URLs; one worker
            thread is started for each.

    Raises:
        ValueError: If ``model_endpoints`` is empty.
        Exception: Any exception raised inside a worker is re-raised here
            once all workers have finished.
    """
    if not model_endpoints:
        # ThreadPoolExecutor rejects max_workers=0 with a less helpful message.
        raise ValueError("model_endpoints must contain at least one endpoint")

    total_batches = len(batch_pool)

    # Create a thread pool with the same number of workers as model endpoints
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=len(model_endpoints)
    ) as executor:
        # Initialize the progress bar
        with tqdm(total=total_batches) as pbar:
            # Start a worker for each model endpoint
            futures = [
                executor.submit(worker, endpoint, batch_pool, pbar)
                for endpoint in model_endpoints
            ]

            # Wait for all workers to finish
            concurrent.futures.wait(futures)

            # wait() never raises on behalf of a worker; calling result()
            # surfaces any exception instead of letting it vanish silently.
            for future in futures:
                future.result()

In [None]:
# Send every prepared batch to the configured endpoints.
process_batches(batch_pool=batch_pool, model_endpoints=model_endpoints)

In [None]:
# Print each collected response on its own line, in the order the workers
# appended them to `results`.
for result in results:
    print(result)