In [0]:
%sh
# poppler-utils provides the PDF renderer that pdf2image shells out to.
# apt-get is used for both steps: `apt` targets interactive use and warns when scripted.
apt-get update
apt-get install -y poppler-utils

In [0]:
%pip install -r requirements.txt

In [0]:
import yaml
import os

# Load the deployment configuration shared by the rest of the notebook.
with open("../configs/olmocr_config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Fail fast on an incomplete config: a missing key would otherwise be rendered as the
# literal text "None" inside the paths built below and only surface much later.
_required_keys = (
    "catalog_name",
    "schema_name",
    "volume_name",
    "volume_folder",
    "model_name",
    "revision",
)
_missing_keys = [key for key in _required_keys if (config or {}).get(key) is None]
if _missing_keys:
    raise KeyError(f"Missing required keys in olmocr_config.yaml: {_missing_keys}")

catalog_name = config.get("catalog_name")
schema_name = config.get("schema_name")
volume_name = config.get("volume_name")
volume_folder = config.get("volume_folder")
model_name = config.get("model_name")
revision = config.get("revision")

# Source of the model snapshot on the volume, and the local-disk locations used at runtime.
cache_volume =  f"/Volumes/{catalog_name}/{schema_name}/{volume_name}/{model_name}/{revision}/{volume_folder}"
cache_hf = "/local_disk0/hf_cache"
cache_local = f"/local_disk0/{volume_folder}"

# Keep all Hugging Face caches on local disk and allow slow downloads to finish.
os.environ["HF_HOME"] = cache_hf
os.environ["HF_HUB_CACHE"] = cache_hf
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "True"
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "1000"
# os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'  # Enables optimized download backend

In [0]:
import shutil
import os

# Copy the model snapshot from the volume to local disk unless a previous run already did.
if not os.path.exists(cache_local):
    print(f"Loading model from {cache_volume} to {cache_local}.")
    # copytree creates cache_local itself; only its parent has to exist beforehand.
    snapshots_dir = os.path.dirname(cache_local)
    os.makedirs(snapshots_dir, exist_ok=True)
    try:
        shutil.copytree(cache_volume, cache_local)
    except Exception:
        # A failed copy can leave a partial directory behind, which the existence check
        # above would mistake for a complete cache on the next run. Remove it and
        # re-raise so the notebook stops here instead of failing later at model load.
        shutil.rmtree(cache_local, ignore_errors=True)
        raise
    print(f"Successfully loaded model from {cache_volume} to {cache_local}!")
else:
    print(f"File already exists locally at {cache_local}")

In [0]:
%pip install qwen-vl-utils[decord]==0.0.8

In [0]:
import pandas as pd
import torch
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    AutoModelForImageTextToText
)
from qwen_vl_utils import process_vision_info
import mlflow.pyfunc
import base64
from io import BytesIO
from PIL import Image

class OlmocrPyfunc(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that runs an image-text-to-text model over page images.

    Each input row may carry a ``system_prompt`` string and must carry a
    ``user_prompt`` mapping with a base64-encoded ``image`` and a ``text``
    instruction. ``predict`` returns one string per successfully parsed row:
    either the generated text or an ``Error processing ...`` message.
    """

    # Defaults applied when a row / request does not supply its own value.
    DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant that extracts text from PDF images."
    DEFAULT_MAX_TOKENS = 1024

    def load_context(self, context):
        """Load the processor and model from the ``model-weights`` artifact.

        Args:
            context: Object exposing an ``artifacts`` mapping whose
                ``"model-weights"`` entry is the location passed to
                ``from_pretrained``.
        """
        self.model_id = context.artifacts["model-weights"]
        # bfloat16 on GPU, float32 on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        print("************************************")
        print(f"Device: {self.device}, dtype: {self.dtype}")
        print(f"Loading model {self.model_id} to {self.device}")
        print("************************************")

        self.processor = AutoProcessor.from_pretrained(self.model_id)

        self.model = AutoModelForImageTextToText.from_pretrained(
            self.model_id, torch_dtype=self.dtype, device_map="auto"
        )

    @staticmethod
    def _decode_image(image_base64):
        """Decode a base64-encoded image into an RGB PIL image."""
        image_data = base64.b64decode(image_base64)
        return Image.open(BytesIO(image_data)).convert("RGB")

    def predict(self, model_input: pd.DataFrame, params: dict = None) -> pd.Series:
        """Generate text for every row of ``model_input``.

        Args:
            model_input: DataFrame with a ``user_prompt`` column (mapping with
                ``"image"`` and ``"text"``) and an optional ``system_prompt`` column.
            params: Optional inference parameters; ``max_tokens`` caps the number
                of newly generated tokens (default 1024).

        Returns:
            A Series with one string per row, in input order. Rows whose image or
            prompt text cannot be read yield an error message instead of aborting
            the whole batch.
        """
        outputs = []
        max_tokens = (
            params.get("max_tokens", self.DEFAULT_MAX_TOKENS)
            if params
            else self.DEFAULT_MAX_TOKENS
        )

        for _, row in model_input.iterrows():
            system_prompt = row.get("system_prompt", self.DEFAULT_SYSTEM_PROMPT)
            user_prompt = row.get("user_prompt", "")

            try:
                image = self._decode_image(user_prompt["image"])
            except Exception as e:
                outputs.append(f"Error processing image: {str(e)}")
                continue

            # Validate the text part per row as well, so a malformed row produces an
            # error entry like a bad image does instead of raising out of predict().
            try:
                prompt_text = user_prompt["text"]
            except (KeyError, TypeError, IndexError) as e:
                outputs.append(f"Error processing prompt text: {str(e)}")
                continue

            messages = [
                {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt_text},
                    ],
                },
            ]

            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            image_inputs, video_inputs = process_vision_info(messages)

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            ).to(self.device)

            # generate() returns prompt + completion; remember where the prompt ends.
            prompt_len = inputs["input_ids"].size(-1)

            with torch.inference_mode():
                generation = self.model.generate(**inputs, max_new_tokens=max_tokens)

            generated_tokens = generation[0][prompt_len:]

            output_text = self.processor.batch_decode(
                generated_tokens.unsqueeze(0),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]

            outputs.append(output_text)

        return pd.Series(outputs)

In [0]:
import pandas as pd
import json
import base64
from PIL import Image
import requests
from io import BytesIO
from pdf2image import convert_from_bytes

# Fetch a sample PDF and turn its first page into the base64 PNG payload the model expects.
url = "https://arxiv.org/pdf/2502.13923"
response = requests.get(url, timeout=60)
# Stop on HTTP errors instead of handing an error page to the PDF renderer.
response.raise_for_status()
pdf_bytes = response.content

# Only the first page is used below, so skip rendering the rest of the document.
pil_images = convert_from_bytes(pdf_bytes, first_page=1, last_page=1)

img = pil_images[0]
buffer = BytesIO()
img.save(buffer, format="PNG")
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

system_prompt="You are a helpful assistant that extracts text from PDF images."
prompt = "Extract the entire text from the abstract section in this image."

# One-row request: the user prompt bundles the instruction text with the encoded page.
df = pd.DataFrame({
    "system_prompt": [system_prompt],
    "user_prompt": [
        {
            "text": prompt,
            "image": image_base64
        }
    ]
})

In [0]:
class Context:
    """Minimal stand-in for a model context that only exposes ``artifacts``."""

    def __init__(self, artifacts):
        # Mapping of artifact name -> location; load_context reads "model-weights" from it.
        self.artifacts = artifacts

# Smoke-test the wrapper in-process against the locally cached weights before logging it.
local_context = Context({"model-weights": cache_local})
olmocr = OlmocrPyfunc()
olmocr.load_context(local_context)

# `output` is reused later as the example response when inferring the signature.
output = olmocr.predict(df, params={"max_tokens": 512})
print(output[0])

# Log to MLflow

In [0]:
import sys
import os
from mlflow.models import infer_signature
import mlflow
import pandas as pd
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import Schema, ColSpec

# Point MLflow tracking at "databricks" and the model registry at "databricks-uc"
# (presumably the workspace tracking server and the Unity Catalog registry).
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")

In [0]:
# Infer the model signature from the sample request/response of the local smoke test,
# declaring `max_tokens` as an inference parameter with its default value.
# (Original note: doing strict schema to avoid rerunning pipeline.)
inference_params = {"max_tokens": 512}
signature = infer_signature(model_input=df, model_output=output, params=inference_params)

In [0]:
# Pinned packages recorded as the logged model's pip requirements.
serving_requirements = [
    "torch==2.6.0",
    "transformers==4.55.4",
    "accelerate==1.10.1",
    "huggingface_hub==0.34.4",
    "Pillow==11.3.0",
    "flask==3.1.2",
    "bitsandbytes==0.47.0",
    "pdf2image==1.17.0",
    "qwen-vl-utils==0.0.8",
]

# Log the pyfunc wrapper together with the local weights directory as its artifact.
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="olmocr_pyfunc",
        python_model=OlmocrPyfunc(),
        signature=signature,
        pip_requirements=serving_requirements,
        artifacts={"model-weights": cache_local},
        input_example=df,
    )

In [0]:
# Register the logged model under the three-level name <catalog>.<schema>.<model>.
served_model_name = "olmocr_hf_deployment"
uc_model_name = f"{catalog_name}.{schema_name}.{served_model_name}"

result = mlflow.register_model(model_uri=model_info.model_uri, name=uc_model_name)
print(f"Registered model version: {result.version}")

In [0]:
import mlflow

# Tag the freshly registered version with the "Challenger" alias.
client = mlflow.MlflowClient()
client.set_registered_model_alias(name=uc_model_name, alias="Challenger", version=result.version)

In [0]:
import mlflow 
# Alias-based URI: refers to whichever version of the model carries the "Challenger" alias.
model_uri = f"models:/{uc_model_name}@Challenger"
print(model_uri)

# Test the model first and then promote to Champion
TODO: Modify code with appropriate usage of VRAM on GPU 

In [0]:
# Same assignments as in the registration cell above, repeated here.
served_model_name = "olmocr_hf_deployment"
uc_model_name = f"{catalog_name}.{schema_name}.{served_model_name}"

In [0]:
# Promote the registered version by giving it the "Champion" alias.
client.set_registered_model_alias(name=uc_model_name, alias="Champion", version=result.version)

In [0]:
# Look up the version currently aliased "Champion" and collect the values used to
# configure the serving endpoint.
# NOTE(review): `model_info` and `model_name` are reused here; earlier cells assigned them
# the log_model result and the config's model name respectively.
model_info = client.get_model_version_by_alias(uc_model_name, "Champion")
model_name = model_info.name
model_version = model_info.version
served_entity_name = served_model_name
# User name reported by the notebook context; used as the endpoint's "created_by" tag.
user_email = dbutils.notebook.entry_point.getDbutils().notebook().getContext().userName().get()

In [0]:
# For Azure, we will use GPU LARGE 
# May need multi-gpu A10s for AWS 

# Deploy Model

In [0]:
import requests
from datetime import datetime, timedelta
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import (
        EndpointCoreConfigInput,
        ServedEntityInput,
        AutoCaptureConfigInput,
        ServingEndpointDetailed,
        EndpointTag
    )

# The endpoint is named after the served entity.
endpoint_name = f"{served_entity_name}_endpoint"

# Single served entity: the Champion version on a small GPU_LARGE workload that can scale to zero.
served_entities = [
    ServedEntityInput(
        entity_name=model_name,
        entity_version=model_version,
        name=served_entity_name,
        workload_type="GPU_LARGE",
        workload_size="Small",
        scale_to_zero_enabled=True,
    )
]

# NOTE(review): this config is built but not used — the endpoint below is created with
# auto_capture_config=None. Confirm whether request capture is meant to stay disabled.
auto_capture_config = AutoCaptureConfigInput(
    catalog_name=catalog_name,
    schema_name=schema_name,
    table_name_prefix=f"{model_name}_serving",
    enabled=True,
)

w = WorkspaceClient()

endpoint_config = EndpointCoreConfigInput(
    name=endpoint_name,
    served_entities=served_entities,
    auto_capture_config=None,
)
endpoint_tags = [
    EndpointTag(key="application", value=served_entity_name),
    EndpointTag(key="created_by", value=user_email),
]

# Blocks until the endpoint is ready, waiting up to three hours.
endpoint_details = w.serving_endpoints.create_and_wait(
    name=endpoint_name,
    config=endpoint_config,
    tags=endpoint_tags,
    timeout=timedelta(minutes=180),
)

# Test Deployed Model

In [0]:
# Read the API token from the notebook context and expose it through the environment
# variable that `score_model` uses for its Authorization header.
databricks_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
os.environ["DATABRICKS_TOKEN"] = databricks_token

# Placeholder: replace with the serving endpoint's invocation URL; `score_model` POSTs to it as-is.
endpoint_url = "CHANGE_ME"

In [0]:
import pandas as pd
import json
import base64
from PIL import Image
import requests
from io import BytesIO
from pdf2image import convert_from_bytes

# Rebuild the sample request (first page of the PDF as a base64 PNG) for the deployed endpoint.
url = "https://arxiv.org/pdf/2502.13923"
response = requests.get(url, timeout=60)
# Stop on HTTP errors instead of handing an error page to the PDF renderer.
response.raise_for_status()
pdf_bytes = response.content

# Only the first page is used below, so skip rendering the rest of the document.
pil_images = convert_from_bytes(pdf_bytes, first_page=1, last_page=1)
system_prompt="You are a helpful assistant that extracts text from PDF images."
prompt = "Extract the entire text from the abstract section in this image."

img = pil_images[0]
buffer = BytesIO()
img.save(buffer, format="PNG")
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

# One-row request: the user prompt bundles the instruction text with the encoded page.
df = pd.DataFrame({
    "system_prompt": [system_prompt],
    "user_prompt": [
        {
            "text": prompt,
            "image": image_base64
        }
    ]
})

In [0]:
import os
import requests
import numpy as np
import pandas as pd
import json

def create_tf_serving_json(data):
    """Wrap array-like input in an ``{'inputs': ...}`` payload.

    A dict is converted value by value (each value's ``.tolist()``); anything
    else is converted as a whole via its own ``.tolist()``.
    """
    if isinstance(data, dict):
        payload = {}
        for name in data.keys():
            payload[name] = data[name].tolist()
    else:
        payload = data.tolist()
    return {'inputs': payload}

def score_model(dataset):
    """POST ``dataset`` to the serving endpoint and return the decoded JSON response.

    Args:
        dataset: A pandas DataFrame (sent as ``dataframe_split``) or any other
            input accepted by ``create_tf_serving_json``.

    Returns:
        The parsed JSON body of the endpoint response.

    Raises:
        Exception: If the endpoint answers with a non-200 status code.
    """
    # No trailing comma here: `endpoint_url,` would build a one-element tuple, which
    # requests cannot use as a URL.
    url = endpoint_url
    headers = {'Authorization': f'Bearer {os.environ.get("DATABRICKS_TOKEN")}', 'Content-Type': 'application/json'}
    if isinstance(dataset, pd.DataFrame):
        ds_dict = {'dataframe_split': dataset.to_dict(orient='split')}
    else:
        ds_dict = create_tf_serving_json(dataset)
    data_json = json.dumps(ds_dict, allow_nan=True)
    response = requests.request(method='POST', headers=headers, url=url, data=data_json)
    if response.status_code != 200:
        raise Exception(f'Request failed with status {response.status_code}, {response.text}')
    return response.json()

In [0]:
# Send the sample request to the deployed endpoint; the JSON response is shown as the cell output.
score_model(df) 

# Restart kernel to clear VRAM 

In [0]:
%restart_python

# Optional: Test Registered Model 
Note: Be careful with VRAM consumption.

In [0]:
# `%restart_python` cleared every variable and import, so rebuild what this section
# needs here instead of relying on names defined before the restart.
import mlflow
import yaml

with open("../configs/olmocr_config.yaml", "r") as f:
    config = yaml.safe_load(f)

catalog_name = config.get("catalog_name")
schema_name = config.get("schema_name")

# Same registry and model name as used when the model was registered.
mlflow.set_registry_uri("databricks-uc")
uc_model_name = f"{catalog_name}.{schema_name}.olmocr_hf_deployment"
model_uri = f"models:/{uc_model_name}@Challenger"

loaded_model = mlflow.pyfunc.load_model(model_uri)

In [0]:
import pandas as pd
import json
import base64
from PIL import Image
import requests
from io import BytesIO
from pdf2image import convert_from_bytes

# Build the sample request (first page of the PDF as a base64 PNG) and score it with the
# model loaded from the registry.
url = "https://arxiv.org/pdf/2502.13923"
response = requests.get(url, timeout=60)
# Stop on HTTP errors instead of handing an error page to the PDF renderer.
response.raise_for_status()
pdf_bytes = response.content

# Only the first page is used below, so skip rendering the rest of the document.
pil_images = convert_from_bytes(pdf_bytes, first_page=1, last_page=1)

img = pil_images[0]
buffer = BytesIO()
img.save(buffer, format="PNG")
image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

system_prompt="You are a helpful assistant that extracts text from PDF images."
prompt = "Extract the entire text from the abstract section in this image."

# One-row request: the user prompt bundles the instruction text with the encoded page.
df = pd.DataFrame({
    "system_prompt": [system_prompt],
    "user_prompt": [
        {
            "text": prompt,
            "image": image_base64
        }
    ]
})

outputs = loaded_model.predict(df, params={"max_tokens": 512})

In [0]:
# Show the text generated for the first (only) request row.
print(outputs[0])

In [0]:
import pandas as pd
import json
import io
import base64
from PIL import Image

# Re-encode the first rendered page as a base64 PNG and score it again with the loaded model.
img = pil_images[0]
buffer = io.BytesIO()
img.save(buffer, format="PNG")
png_bytes = buffer.getvalue()
image_base64 = base64.b64encode(png_bytes).decode('utf-8')

# One-row request reusing the prompts defined in the previous request cell.
user_prompt = {"text": prompt, "image": image_base64}
df = pd.DataFrame({"system_prompt": [system_prompt], "user_prompt": [user_prompt]})

outputs = loaded_model.predict(df, params={"max_tokens": 512})