In [None]:
import torch
import flash_attn
import os
from copy import deepcopy
from typing import (
    Any,
    AsyncIterable,
    Callable,
    Dict,
    Generator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)
import requests
from io import BytesIO

from PIL import Image
import torch
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights

# Ensure BAGEL specific modules are in your PYTHONPATH or current working directory structure
# These are assumed to be part of the BAGEL codebase you have in /workspace/BAGEL/
from data.transforms import ImageTransform
from data.data_utils import pil_img2rgb, add_special_tokens
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.autoencoder import load_ae
from safetensors.torch import load_file

import random
import numpy as np

# Seed every RNG the notebook may touch (Python, NumPy, PyTorch CPU and, when
# present, CUDA) so that repeated runs start from the same random state.
seed = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
# Deterministic cuDNN kernels are deliberately left off because they can cost
# performance; uncomment to trade speed for more deterministic results:
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
print(f"Seeds set to {seed}")

print("All initial imports successful.")

# Checkpoint directory, relative to the working directory (expected to resolve to
# /workspace/downloads/). Files come from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT
model_path = "downloads"

# Every file the notebook reads from model_path later on. Checking them all up
# front fails fast with one clear message, instead of printing a warning and then
# crashing later inside config or checkpoint loading.
required_model_files = ("llm_config.json", "vit_config.json", "ae.safetensors", "ema.safetensors")
missing_model_files = [
    fname for fname in required_model_files
    if not os.path.exists(os.path.join(model_path, fname))
]
if missing_model_files:
    print(f"ERROR: Model path '{model_path}' not found or doesn't contain expected files.")
    print("Please download BAGEL model files from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT and place them in /workspace/downloads/")
    # A missing directory shows up here as "all files missing".
    raise FileNotFoundError(
        f"Missing BAGEL model files in {os.path.abspath(model_path)!r}: "
        f"{', '.join(missing_model_files)}"
    )
print(f"Using model files from: {os.path.abspath(model_path)}")

# --- LLM config: Qwen2 backbone switched to BAGEL's MoT decoder layer ---
llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
for _attr, _value in (
    ("qk_norm", True),
    ("tie_word_embeddings", False),
    ("layer_module", "Qwen2MoTDecoderLayer"),
):
    setattr(llm_config, _attr, _value)

# --- ViT config: RoPE off, and one transformer layer fewer than the file declares ---
vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers -= 1

# --- VAE: load_ae returns the autoencoder module together with its config ---
vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))

# --- Top-level BAGEL config tying the three sub-configs together ---
bagel_config_kwargs = dict(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)
config = BagelConfig(**bagel_config_kwargs)

# Build the module tree without allocating real weights (accelerate's
# init_empty_weights places parameters on the meta device). The actual weights
# are loaded afterwards by load_checkpoint_and_dispatch.
with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

# Tokenizer: the Qwen2 tokenizer from model_path, extended by add_special_tokens.
# new_token_ids is handed to the inferencer later.
tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

# Image preprocessing for the VAE path and the ViT path.
# NOTE(review): arguments are presumably (max size, min size, stride/patch size);
# confirm against data.transforms.ImageTransform.
vae_transform, vit_transform = ImageTransform(1024, 512, 16), ImageTransform(980, 224, 14)

print("Model components initialized.")


In [None]:
import torch
import os

num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")

if num_gpus == 0:
    raise EnvironmentError("No GPUs detected by PyTorch. Cannot proceed with GPU mapping.")

# llm_config and model (the empty shell built under init_empty_weights) must be
# defined by the previous cell.
num_llm_layers = llm_config.num_hidden_layers
print(f"Number of LLM layers: {num_llm_layers}")

# Manual placement: module-name prefix -> CUDA device index.
device_map = {}

# --- LLM core distribution ---
llm_embed_gpu = 0
device_map["language_model.model.embed_tokens"] = llm_embed_gpu

# Split the decoder layers into contiguous, near-equal blocks: GPU g owns layers
# [split[g], split[g + 1]). E.g. with 2 GPUs, layers 0..N/2-1 go to GPU 0 and
# N/2..N-1 to GPU 1. Walking each GPU's range directly stays correct even when a
# GPU's range is empty (more GPUs than layers), where advancing the GPU index at
# most once per layer would place layers on the wrong device.
llm_layer_split_points = [(g * num_llm_layers) // num_gpus for g in range(num_gpus + 1)]
for gpu_idx in range(num_gpus):
    for layer_idx in range(llm_layer_split_points[gpu_idx], llm_layer_split_points[gpu_idx + 1]):
        device_map[f"language_model.model.layers.{layer_idx}"] = gpu_idx

# Final norm(s) and LM head sit on the GPU of the last decoder layer; fall back to
# the embedding GPU if the config declares no layers.
final_llm_components_gpu = device_map.get(
    f"language_model.model.layers.{num_llm_layers - 1}", llm_embed_gpu
)
device_map["language_model.model.norm"] = final_llm_components_gpu
device_map["language_model.lm_head"] = final_llm_components_gpu
if hasattr(model.language_model.model, 'norm_moe_gen'):
    device_map["language_model.model.norm_moe_gen"] = final_llm_components_gpu

# --- Vision components ---
# The ViT tower goes on the last GPU when there are several, otherwise on GPU 0.
vit_processing_gpu = (num_gpus - 1) if num_gpus > 1 else 0
device_map["vit_model"] = vit_processing_gpu

# vit_pos_embed and connector stay on the GPU holding embed_tokens, where their
# outputs are combined with the other LLM input embeddings.
interaction_gpu = llm_embed_gpu
device_map["vit_pos_embed"] = interaction_gpu
device_map["connector"] = interaction_gpu

# --- Other BAGEL-specific components ---
device_map["time_embedder"] = llm_embed_gpu
device_map["latent_pos_embed"] = llm_embed_gpu
device_map["vae2llm"] = llm_embed_gpu
# NOTE(review): llm2vae is placed with the final LLM layers; confirm this matches
# where its input hidden states live.
device_map["llm2vae"] = final_llm_components_gpu

print("Constructed manual device_map:", device_map)

# --- Load the checkpoint into the empty model and dispatch it per device_map ---
offload_folder = "/tmp/offload_bagel"
# The offload folder is only needed when some module is mapped to "disk".
disk_offload_active = any(v == 'disk' for v in device_map.values())
if disk_offload_active:
    if not os.path.exists(offload_folder):
        os.makedirs(offload_folder, exist_ok=True)
        print(f"Created offload folder: {offload_folder}")
    print(f"Disk offloading is active. Offload folder: {offload_folder}")
else:
    print("No disk offloading specified in manual device_map.")

print("Attempting to load checkpoint with manual device map...")
# A single load/dispatch pass. Repeating it only re-reads the whole checkpoint.
# max_memory is not passed: accelerate uses it only to infer a placement when
# device_map is a strategy string such as "auto", not for an explicit dict.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    dtype=torch.bfloat16,
    force_hooks=True,
    offload_folder=offload_folder,
)

model = model.eval()
print('Model loaded and dispatched successfully.')

from inferencer import InterleaveInferencer

# Wrap the dispatched model, VAE, tokenizer and image transforms into the
# high-level inferencer that the cells below call.
inferencer_components = dict(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids,
)
inferencer = InterleaveInferencer(**inferencer_components)
print("Inferencer prepared.")



**About Inference Hyperparameters:**
- **`cfg_text_scale`:** Controls how strongly the model follows the text prompt. `1.0` disables text guidance. Typical range: `4.0–8.0`.
- **`cfg_img_scale`:** Controls how much the model preserves input image details. `1.0` disables image guidance. Typical range: `1.0–2.0`.
- **`cfg_interval`:** Fraction of denoising steps where CFG is applied. Later steps can skip CFG to reduce computation. Typical: `[0.4, 1.0]`.
- **`timestep_shift`:** Shifts the distribution of denoising steps. Higher values allocate more steps at the start (affects layout); lower values allocate more at the end (improves details).
- **`num_timesteps`:** Total denoising steps. Typical: `50`.
- **`cfg_renorm_min`:** Minimum value for CFG-Renorm. `1.0` disables renorm. Typical: `0`.
- **`cfg_renorm_type`:** CFG-Renorm method:  
  - `global`: Normalize over all tokens and channels (default for T2I).
  - `channel`: Normalize across channels for each token.
  - `text_channel`: Like `channel`, but only applies to text condition (good for editing, may cause blur).
- **If edited images appear blurry, try `global` CFG-Renorm, decrease `cfg_renorm_min`, or decrease the CFG scales (`cfg_text_scale` / `cfg_img_scale`).**


In [None]:
# Text-to-image: the prompt is the only condition, so image guidance is off (1.0).
inference_hyper = {
    "cfg_text_scale": 4.0,
    "cfg_img_scale": 1.0,
    "cfg_interval": [0.4, 1.0],
    "timestep_shift": 3.0,
    "num_timesteps": 50,
    "cfg_renorm_min": 1.0,
    "cfg_renorm_type": "global",
}

prompt = "A colorful friendly crhulhu is smiling and kissing a nice emo girl, cartoon"

print("Prompt:", prompt)
print('-' * 10)
try:
    output_dict = inferencer(text=prompt, **inference_hyper)
    generated_image = output_dict['image'] if 'image' in output_dict else None
    if generated_image is None:
        print("Inference completed, but no image found in output_dict or image is None.")
        print("Output dictionary:", output_dict)
    else:
        print("Image generated successfully.")
        display(generated_image)
except Exception as err:
    # Report the failure with a traceback but keep the notebook running.
    import traceback
    print(f"An error occurred during inference: {err}")
    traceback.print_exc()

## Image Generation with Think

In [None]:
# Text-to-image with a reasoning ("think") pass before generation.
inference_hyper_think = {
    "max_think_token_n": 1000,
    "do_sample": False,
    # "text_temperature": 0.3,
    "cfg_text_scale": 4.0,
    "cfg_img_scale": 1.0,
    "cfg_interval": [0.4, 1.0],
    "timestep_shift": 3.0,
    "num_timesteps": 50,
    "cfg_renorm_min": 1.0,
    "cfg_renorm_type": "global",
}

prompt_think = 'a car made of small cars'

print("Prompt for think & gen:", prompt_think)
print('-' * 10)
try:
    output_dict_think = inferencer(text=prompt_think, think=True, **inference_hyper_think)
    print("Generated text (thought process):", output_dict_think.get('text', 'N/A'))
    think_image = output_dict_think['image'] if 'image' in output_dict_think else None
    if think_image is None:
        print("Inference with think completed, but no image found in output_dict or image is None.")
        print("Output dictionary (think):", output_dict_think)
    else:
        print("Image generated successfully with think.")
        display(think_image)
except Exception as err:
    # Report the failure with a traceback but keep the notebook running.
    import traceback
    print(f"An error occurred during inference with think: {err}")
    traceback.print_exc()

## Editing

In [None]:
# Image editing: image guidance is enabled (cfg_img_scale > 1.0) and CFG is
# applied across the full denoising interval.
inference_hyper_edit = {
    "cfg_text_scale": 4.0,
    "cfg_img_scale": 2.0,
    "cfg_interval": [0.0, 1.0],
    "timestep_shift": 3.0,
    "num_timesteps": 50,
    "cfg_renorm_min": 1.0,
    "cfg_renorm_type": "text_channel",
}

# Expects test_images/women.jpg relative to the working directory.
edit_image_path = 'test_images/women.jpg'
prompt_edit = 'She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes.'

try:
    if not os.path.exists(edit_image_path):
        print(f"ERROR: Edit image path not found: {edit_image_path}")
    else:
        image_edit_input = Image.open(edit_image_path)
        print("Input image for editing:")
        display(image_edit_input)
        print("Prompt for editing:", prompt_edit)
        print('-' * 10)
        output_dict_edit = inferencer(image=image_edit_input, text=prompt_edit, **inference_hyper_edit)
        edited_image = output_dict_edit['image'] if 'image' in output_dict_edit else None
        if edited_image is None:
            print("Editing inference completed, but no image found in output_dict or image is None.")
            print("Output dictionary (edit):", output_dict_edit)
        else:
            print("Edited image generated successfully.")
            display(edited_image)
except Exception as err:
    # Report the failure with a traceback but keep the notebook running.
    import traceback
    print(f"An error occurred during editing inference: {err}")
    traceback.print_exc()

## Edit with Think

In [None]:
# Image editing preceded by a reasoning ("think") pass.
inference_hyper_edit_think = {
    "max_think_token_n": 1000,
    "do_sample": False,
    # "text_temperature": 0.3,
    "cfg_text_scale": 4.0,
    "cfg_img_scale": 2.0,
    "cfg_interval": [0.0, 1.0],
    "timestep_shift": 3.0,
    "num_timesteps": 50,
    "cfg_renorm_min": 0.0,
    "cfg_renorm_type": "text_channel",
}

# Expects test_images/octupusy.jpg relative to the working directory.
edit_think_image_path = 'test_images/octupusy.jpg'
prompt_edit_think = 'Could you display the sculpture that takes after this design?'

try:
    if not os.path.exists(edit_think_image_path):
        print(f"ERROR: Edit with think image path not found: {edit_think_image_path}")
    else:
        image_edit_think_input = Image.open(edit_think_image_path)
        print("Input image for edit with think:")
        display(image_edit_think_input)
        print("Prompt for edit with think:", prompt_edit_think)
        print('-' * 10)
        output_dict_edit_think = inferencer(
            image=image_edit_think_input,
            text=prompt_edit_think,
            think=True,
            **inference_hyper_edit_think,
        )
        print("Generated text (edit with think):", output_dict_edit_think.get('text', 'N/A'))
        edit_think_image = output_dict_edit_think['image'] if 'image' in output_dict_edit_think else None
        if edit_think_image is None:
            print("Edit with think inference completed, but no image found in output_dict or image is None.")
            print("Output dictionary (edit with think):", output_dict_edit_think)
        else:
            print("Edited image with think generated successfully.")
            display(edit_think_image)
except Exception as err:
    # Report the failure with a traceback but keep the notebook running.
    import traceback
    print(f"An error occurred during edit with think inference: {err}")
    traceback.print_exc()

## Understanding

In [None]:
# Image understanding: the model answers in text (understanding_output=True).
inference_hyper_und = {
    "max_think_token_n": 1000,
    "do_sample": False,
    "text_temperature": 0.3,
}
# Image to analyse. This is an absolute path; adjust it to your own setup.
und_image_path = '/workspace/BAGEL/1.png'
prompt_und = "Fully transcribe all musical symbols"

try:
    if not os.path.exists(und_image_path):
        print(f"ERROR: Understanding image path not found: {und_image_path}")
    else:
        image_und_input = Image.open(und_image_path)
        print("Input image for understanding:")
        display(image_und_input)
        print("Prompt for understanding:", prompt_und)
        print('-' * 10)
        output_dict_und = inferencer(
            image=image_und_input,
            text=prompt_und,
            understanding_output=True,
            **inference_hyper_und,
        )
        print("Generated text (understanding):", output_dict_und.get('text', 'N/A'))
except Exception as err:
    # Report the failure with a traceback but keep the notebook running.
    import traceback
    print(f"An error occurred during understanding inference: {err}")
    traceback.print_exc()