In [None]:
# !pip install git+https://github.com/unionai/unionai.git@enghabu/hub-image-spec

In [None]:
# Imports and init remote
import os
import subprocess
from union import task, workflow, FlyteFile, UnionRemote, ImageSpec, Resources, FlyteDirectory, ActorEnvironment, LaunchPlan
from union.remote import HuggingFaceModelInfo
from flytekit.configuration import Config

# Path to the Union serving config. Honour an already-exported UNION_CONFIG;
# otherwise fall back to the per-user default under ~/.union instead of a
# hard-coded /Users/<name> path, so the notebook runs for other users too.
UNION_CONFIG_PATH = os.environ.get(
    "UNION_CONFIG", os.path.expanduser("~/.union/config_serving.yaml")
)
# Export it so anything that reads UNION_CONFIG from the environment sees the
# same file that the remote below is configured with.
os.environ["UNION_CONFIG"] = UNION_CONFIG_PATH

remote = UnionRemote(config=Config.auto(config_file=UNION_CONFIG_PATH))

In [None]:
# Emit the Hugging Face repo as a Union artifact: start the caching execution,
# block until it completes, then display its outputs.
info = HuggingFaceModelInfo(repo="boltz-community/boltz-1")

# NOTE(review): `_create_model_from_hf` is a private UnionRemote method and may
# change between releases. `hf_token_key` / `union_api_key` presumably name the
# secrets holding the tokens — confirm against the Union docs.
cache_exec = remote._create_model_from_hf(
    info=info,
    hf_token_key="HF_TOKEN",
    union_api_key="UNION_API_KEY",
).wait(poll_interval=2)

cache_exec.outputs

In [None]:
# Container image shared by the standalone task and the actor environment.
image = ImageSpec(
    name="boltz",
    registry="docker.io/unionbio",
    apt_packages=["build-essential"],
    packages=[
        # NOTE(review): `union` is the only unpinned package; consider pinning
        # it so image builds are reproducible.
        "union",
        "flytekit==1.13.14",
        "union-runtime==0.1.11",
        "fastapi==0.115.11",
        "pydantic==2.10.6",
        "boltz==0.4.1",
        "uvicorn==0.34.0",
        "python-multipart==0.0.20",
    ],
)

In [None]:
@task(container_image=image, requests=Resources(cpu="2", mem="10Gi", ephemeral_storage="50Gi", gpu="1"))
def simple_predict(input: FlyteFile) -> FlyteDirectory:
    """Run ``boltz predict`` on one input file and return its output directory.

    Args:
        input: Boltz input file to predict on (the notebook passes a YAML spec).

    Returns:
        A FlyteDirectory wrapping the local directory that ``boltz`` wrote to.

    Raises:
        subprocess.CalledProcessError: if the ``boltz`` command exits non-zero.
    """
    # download() materialises the file locally and returns its local path.
    local_input = input.download()
    out = "/tmp/boltz_out"
    os.makedirs(out, exist_ok=True)
    # check=True: a failed prediction must fail the task instead of silently
    # returning an empty (or partial) output directory as a success.
    # --use_msa_server presumably lets boltz fetch MSAs remotely because the
    # input has none precomputed — confirm against the boltz CLI docs.
    subprocess.run(
        ["boltz", "predict", local_input, "--out_dir", out, "--use_msa_server"],
        check=True,
    )
    return FlyteDirectory(path=out)

@workflow
def wf(input: FlyteFile) -> FlyteDirectory:
    """Single-step workflow: run Boltz prediction on ``input`` as a regular task."""
    predictions = simple_predict(input=input)
    return predictions

# Launch the standalone-task workflow remotely and block until it finishes.
# The input is a path relative to the notebook's working directory.
execution = remote.execute(
    entity=wf,
    inputs={"input": "inputs/prot_no_msa.yaml"},
    wait=True,
)
output = execution.outputs
print(output)



In [None]:
# Actor environment for running the Boltz prediction on a reusable replica,
# built from the same image as the standalone task.
actor = ActorEnvironment(
    name="boltz-actor",
    replica_count=1,
    # NOTE(review): presumably how long an idle replica is kept alive before
    # teardown — confirm in the Union actor docs.
    ttl_seconds=600,
    requests=Resources(
        cpu="2",
        mem="10Gi",
        # Match the standalone `simple_predict` task, which runs the identical
        # boltz workload with 50Gi of local disk; without this the replica
        # falls back to the default ephemeral storage and may run out of space.
        ephemeral_storage="50Gi",
        gpu="1",
    ),
    container_image=image,
)

In [None]:

@actor.task
def act_simple_predict(input: FlyteFile) -> FlyteDirectory:
    """Run ``boltz predict`` on one input file inside the ``boltz-actor`` replica.

    Args:
        input: Boltz input file to predict on (the notebook passes a YAML spec).

    Returns:
        A FlyteDirectory wrapping the local directory that ``boltz`` wrote to.

    Raises:
        subprocess.CalledProcessError: if the ``boltz`` command exits non-zero.
    """
    # download() materialises the file locally and returns its local path.
    local_input = input.download()
    out = "/tmp/boltz_out"
    os.makedirs(out, exist_ok=True)
    # check=True: a failed prediction must fail the task instead of silently
    # returning an empty (or partial) output directory as a success.
    subprocess.run(
        ["boltz", "predict", local_input, "--out_dir", out, "--use_msa_server"],
        check=True,
    )
    return FlyteDirectory(path=out)


@workflow
def act_wf(input: FlyteFile) -> FlyteDirectory:
    """Single-step workflow: run Boltz prediction on ``input`` via the actor task."""
    # Must call the actor-backed task. Calling `simple_predict` here would run
    # the regular task and bypass the `boltz-actor` environment entirely.
    return act_simple_predict(input=input)

# Launch the actor-backed workflow remotely and block until it finishes.
execution = remote.execute(
    entity=act_wf,
    inputs={"input": "inputs/prot_no_msa.yaml"},
    wait=True,
)
output = execution.outputs
print(output)