In [None]:
# Enable IPython's autoreload extension so edits to the local modules imported below
# (models.ar_model, bae.ae_model) are picked up without restarting the kernel.
# Mode 2 reloads all modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# Modification Copyright© 2025 Advanced Micro Devices, Inc. All rights reserved.

import torch, os

from models.ar_model import Instella_AR_Model, BinaryAR

import time
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
from bae.ae_model import BAE_Model
import gc
from mmengine.config import Config
import argparse
from safetensors.torch import load_file as safe_load_file
from huggingface_hub import snapshot_download
# Pin this kernel to one GPU; later `.to('cuda')` calls resolve to the current device.
cuda_num = 0
torch.cuda.set_device(torch.device('cuda', cuda_num))

In [None]:
# Set seeds for reproducible results
import random
import numpy as np
import torch

def set_seeds(seed=42, deterministic=True):
    """Seed Python, NumPy and PyTorch RNGs so sampling is repeatable.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and ``torch``
            (CPU and every visible CUDA device).
        deterministic: When True (the default, matching the previous behaviour),
            force cuDNN to use deterministic kernels and disable autotuning. This
            can slow convolutions; pass False to leave cuDNN's settings untouched.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed already seeds all devices; the explicit CUDA calls are kept
    # for clarity and are silently ignored when CUDA is unavailable.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # For multi-GPU setups

    if deterministic:
        # Make cuDNN deterministic (may impact performance).
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # PYTHONHASHSEED is only read at interpreter start-up, so this does NOT change
    # str/bytes hashing in the running kernel; it only reaches child processes
    # started after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)

    print(f"All seeds set to {seed} for reproducible results")

# Set the seed
SEED = 42
set_seeds(SEED)

In [None]:
from types import SimpleNamespace

# Inference configuration: a lightweight stand-in for argparse-style arguments.
# NOTE(review): in this notebook num_tkn, codesize, rho and sampling_protocal are never
# read (the model uses model_config.num_tkns / bae_config.codebook_size instead), and
# num_steps only ends up in the output filename - confirm whether the sampler should use them.
args = SimpleNamespace(
    ckpt_path='./checkpoints',        # Local dir for the HF snapshot; ar.pt is loaded from here
    num_tkn=128,                      # Number of image tokens for the BAE tokenizer
    image_size=1024,                  # Output image size (512, 768, or 1024)
    codesize=128,                     # Codebook size of the BAE tokenizer
    cfg_scale=7.5,                    # Scale of classifier-free guidance
    temp=1.0,                         # Temperature for sampling
    rho=1.0,                          # Rho for sampling
    num_steps=50,                     # Number of inference steps
    sampling_protocal='protocal_1',   # Sampling protocol ('protocal_1' or 'protocal_2'); spelling kept for compatibility
    config='configs/ar_config.py',    # Path to the model config file
)

In [None]:
# Inspect the model config. Load it here so this cell also works on a fresh kernel
# (otherwise model_config is only created by the model-loading cell that follows).
model_config = Config.fromfile(args.config)
print(model_config)

In [None]:
# Download the Instella-T2I checkpoint snapshot into args.ckpt_path.
local_path = snapshot_download(
    repo_id="amd/Instella-T2I",
    local_dir=args.ckpt_path,
)

print("Model downloaded to:", local_path)

model_config = Config.fromfile(args.config)
bae_config = Config.fromfile(model_config.bae_config)
bae_ckpt = model_config.bae_ckpt

print(f"BAE checkpoint path: {bae_ckpt}")
print(f"BAE checkpoint exists: {os.path.exists(bae_ckpt)}")
# Fail fast with a clear message instead of an opaque error from safe_load_file,
# which would otherwise only surface after the (slow) LLM load below.
if not os.path.exists(bae_ckpt):
    raise FileNotFoundError(f"BAE checkpoint not found: {bae_ckpt}")

weight_dtype = torch.bfloat16

# ======================================================
# Build models
# ======================================================
print('Build diffusion model and load weight')
# AMD-OLMo-1B: its tokenizer/LLM are passed to the sampler, and its hidden size sets
# the AR model's text_cond_dim (presumably as the prompt encoder).
olmo_path = 'amd/AMD-OLMo-1B'
# flash_attention_2 is GPU-only: to load on CPU, drop .to("cuda") AND change attn_implementation.
llm = AutoModelForCausalLM.from_pretrained(olmo_path, attn_implementation="flash_attention_2", torch_dtype=weight_dtype).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(olmo_path, model_max_length=128, padding_side='left')

model = Instella_AR_Model(
    in_channels=bae_config.codebook_size,
    num_layers=llm.config.num_hidden_layers,
    attention_head_dim=model_config.get('attention_head_dim', 128),
    num_attention_heads=model_config.get('num_attention_heads', 16),
    num_img_tkns=model_config.num_tkns,
    text_cond_dim=llm.config.hidden_size,
)

model.eval()

# NOTE: torch.load unpickles the file, which can execute arbitrary code - only load
# trusted checkpoints (torch>=2.6 defaults to weights_only=True).
ckpt = torch.load(f'{args.ckpt_path}/ar.pt', map_location='cpu')['module']
model.load_state_dict(ckpt)
# load_state_dict copies values into the model's own parameters, so the CPU copy
# can be released (mirrors the bae_state_dict cleanup below).
del ckpt

bae = BAE_Model(bae_config)

print('Loading BAE model weights')
bae_state_dict = safe_load_file(bae_ckpt)  # tensors land on CPU (safetensors default)
bae.load_state_dict(bae_state_dict, strict=True)
del bae_state_dict
gc.collect()
torch.cuda.empty_cache()

bae.to('cuda', dtype=weight_dtype)
bae.eval()
if model_config.get('bae_scale', None) is not None:
    bae.set_scale(model_config.bae_scale)

bae.requires_grad_(False)

model = model.to('cuda', dtype=weight_dtype)

# NOTE(review): num_sampling_steps is not passed to binary_ar.sample in this notebook;
# it is only used in the output filename.
num_sampling_steps = args.num_steps
img_size = args.image_size

guidance_scale = args.cfg_scale
temp = args.temp

binary_ar = BinaryAR(model_config.num_tkns)

os.makedirs('results', exist_ok=True)

print("✓ All models loaded and ready for inference!")

In [None]:
# Set seeds for reproducible results
import random
import numpy as np
import torch

# NOTE(review): this re-defines set_seeds from an earlier cell; keep the two copies
# identical or delete one of them.
def set_seeds(seed=42, deterministic=True):
    """Seed Python, NumPy and PyTorch RNGs so sampling is repeatable.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and ``torch``
            (CPU and every visible CUDA device).
        deterministic: When True (the default, matching the previous behaviour),
            force cuDNN to use deterministic kernels and disable autotuning. This
            can slow convolutions; pass False to leave cuDNN's settings untouched.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed already seeds all devices; the explicit CUDA calls are kept
    # for clarity and are silently ignored when CUDA is unavailable.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # For multi-GPU setups

    if deterministic:
        # Make cuDNN deterministic (may impact performance).
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # PYTHONHASHSEED is only read at interpreter start-up, so this does NOT change
    # str/bytes hashing in the running kernel; it only reaches child processes
    # started after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)

    print(f"All seeds set to {seed} for reproducible results")

# Seed value for later use. NOTE(review): it is NOT applied here - call
# set_seeds(SEED) explicitly before sampling if it is meant to take effect.
SEED = 4

In [None]:
# Single-prompt generation: sample binary latent codes for `text`, decode them with
# the BAE, save the image under results/ and display it.
import re

set_seeds(1)
delta = 0  # forwarded to BinaryAR.sample; its meaning is defined there (TODO: document)
text = "A dog"
with torch.no_grad():
    ts = time.perf_counter()
    z = binary_ar.sample(model, tokenizer, llm, [text], guidance_scale=guidance_scale, temp=temp, delta=delta)
    z = z.to(weight_dtype)

    # NOTE(review): img_size // 16 presumably is the decoder's latent grid size
    # (16x downsampling) - confirm against the BAE config.
    samples = bae.decode(z, img_size // 16, img_size // 16)
    # CUDA kernels run asynchronously; wait for them so the timing covers the real work.
    torch.cuda.synchronize()
    te = time.perf_counter()

    samples = samples[0].float()

    # Map decoder output from [-1, 1] to [0, 1].
    samples = torch.clamp(samples, -1.0, 1.0)
    samples = (samples + 1) / 2

    # CHW -> HWC and scale to [0, 255]; round first because astype(uint8) truncates.
    samples = samples.permute(1, 2, 0).mul_(255).round_().cpu().numpy()
    image = Image.fromarray(samples.astype(np.uint8))
    # Filesystem-safe name from the first five prompt words (strips '/', ',', etc.).
    name = text.split(' ')[:5]
    name = '_'.join(re.sub(r'[^\w-]', '', word) for word in name)
    image.save(f'results/{img_size}_{num_sampling_steps}_{guidance_scale}_{temp}_{name}.jpg')
    sp = te - ts
print(f'Generation finished in {sp:.2f}s')
display(image)


In [None]:
# Inspect the generated latent codes: shape and dtype instead of dumping the whole tensor.
print(z.shape, z.dtype)

In [None]:
# Re-display the current `image` (it is also shown at the end of the generation cell).
display(image)


In [None]:
# Round-trip check: undo the generation cell's post-processing ([0, 255] HWC -> [-1, 1] CHW)
# and re-encode the image with the BAE encoder, so its binary code can be compared with `z`.
# NOTE(review): 512 looks like the encoder's expected input resolution - confirm.
encode_size = 512
# Resize into a new name so the generated `image` is not overwritten (keeps re-runs idempotent).
image_for_encode = image.resize((encode_size, encode_size), Image.LANCZOS)
samples = torch.tensor(np.array(image_for_encode))

samples = samples / 255.0           # [0, 255] -> [0, 1]
samples = samples.permute(2, 0, 1)  # HWC -> CHW
samples = samples * 2 - 1           # [0, 1] -> [-1, 1]
samples = torch.clamp(samples, -1.0, 1.0)

print(f"Inverted samples shape: {samples.shape}")
print(f"Inverted samples range: [{samples.min():.3f}, {samples.max():.3f}]")

# Add a batch dimension and encode; no autograd graph is needed for inspection.
samples_batch = samples.unsqueeze(0).to('cuda', dtype=weight_dtype)
with torch.no_grad():
    encoded = bae.encode(samples_batch)

In [None]:
# Split the encoder output into its three parts. NOTE(review): the meaning of the
# middle element (`smth`) is not visible here - check BAE_Model.encode and rename it.
[logits, smth, binary] = encoded

In [None]:
# Dump the re-encoded binary code as one flat list of int32 values.
print(binary['binary_code'].flatten().int().tolist())

In [None]:
# Count the matching ratio between z and binary['binary_code']
import torch

# Flatten both codes for an element-wise comparison.
# NOTE(review): both are cast to int32, which assumes the sampler and the encoder use
# the same integer code values (e.g. {0, 1}) - confirm BAE's binary_code convention.
z_flat = z.reshape(-1).to(torch.int32)
# Align devices so the comparison cannot fail on a CPU/GPU mismatch.
binary_flat = binary['binary_code'].reshape(-1).to(device=z_flat.device, dtype=torch.int32)

print(f"z shape: {z.shape}")
print(f"binary['binary_code'] shape: {binary['binary_code'].shape}")
print(f"z_flat shape: {z_flat.shape}")
print(f"binary_flat shape: {binary_flat.shape}")

# Check if they have the same size
if z_flat.shape[0] == binary_flat.shape[0]:
    # Count matches; guard the ratio against empty codes.
    matches = (z_flat == binary_flat).sum().item()
    total = z_flat.shape[0]
    match_ratio = matches / total if total else float('nan')

    print(f"\nMatching comparison:")
    print(f"Total elements: {total}")
    print(f"Matching elements: {matches}")
    print(f"Non-matching elements: {total - matches}")
    print(f"Matching ratio: {match_ratio:.4f} ({match_ratio*100:.2f}%)")

    # Show first few elements for inspection
    print(f"\nFirst 20 elements comparison:")
    print(f"z:      {z_flat[:20].tolist()}")
    print(f"binary: {binary_flat[:20].tolist()}")
    print(f"match:  {(z_flat[:20] == binary_flat[:20]).tolist()}")

else:
    print(f"Size mismatch! z has {z_flat.shape[0]} elements, binary has {binary_flat.shape[0]} elements")

    # Compare the common prefix; guard the ratio when one side is empty.
    min_size = min(z_flat.shape[0], binary_flat.shape[0])
    matches = (z_flat[:min_size] == binary_flat[:min_size]).sum().item()
    match_ratio = matches / min_size if min_size else float('nan')

    print(f"Comparing first {min_size} elements:")
    print(f"Matching elements: {matches}")
    print(f"Matching ratio: {match_ratio:.4f} ({match_ratio*100:.2f}%)")