- Rows (row): one row per image, so the number of rows equals the number of images.
- Columns (col): (input, output, target).


In [1]:
# Pick up edits to imported project modules (e.g. ldm) live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
# Expose only GPU index 1 to this process. This must run before CUDA is initialised,
# which is why it sits above `import torch`.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import json
import pytorch_lightning as pl
import torch
from torch.utils.data import random_split, DataLoader, Dataset, Subset
from omegaconf import OmegaConf
from tqdm import tqdm

import copy
from ldm.util import instantiate_from_config
from easydict import EasyDict

import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

In [3]:
# Experiment under evaluation: the training config and a checkpoint produced with it.
# NOTE(review): absolute paths tie this notebook to one machine layout.
config_path = "/workspace/code/3DAnything/zero123/configs/normal_variant_i2t_multimodal_v3_256.yaml"
weight_path = (
    "/workspace/code/3DAnything/zero123/results/"
    "2025-06-11T18-53-14_normal_variant_i2t_multimodal_v3_256/"
    "checkpoints/epoch=000117.ckpt"
)
target_cfg = OmegaConf.load(config_path)

In [4]:
# Config file name without its extension. splitext removes only the final suffix,
# whereas split('.')[0] would also cut at any earlier dot in the name.
os.path.splitext(os.path.basename(config_path))[0]

'normal_variant_i2t_multimodal_v3_256'

In [5]:
def split_by_unbind(tensor):
    """Split a batched tensor along dim 0 into a list of per-sample tensors.

    The leading (batch) dimension is removed from each element, e.g. a batch
    shaped (B, C, H, W) yields B tensors shaped (C, H, W). The number of
    elements always equals ``tensor.shape[0]``.
    """
    # Tensor.unbind returns a tuple; callers get a list.
    return [sample for sample in tensor.unbind(0)]

def convert_to_npy(tensor_list):
    """Convert CHW image tensors normalised to [-1, 1] into HWC numpy arrays in [0, 1].

    Args:
        tensor_list: iterable of tensors shaped (C, H, W). Values outside [-1, 1]
            are clamped. Tensors may be on any device and may require grad.

    Returns:
        A list of ``np.ndarray`` shaped (H, W, C) with values in [0, 1], one per
        input tensor and in the same order.
    """
    new_list = []
    for each_tensor in tensor_list:
        # .numpy() refuses grad-tracking and CUDA tensors, so detach and move to CPU first.
        img = each_tensor.detach().cpu().clamp(-1, 1)
        img = (img + 1.0) / 2  # [-1, 1] -> [0, 1]
        new_list.append(img.permute(1, 2, 0).numpy())  # CHW -> HWC
    return new_list

def plot_image(image_tensor, title: str = 'Individual Image'):
    """Display a single CHW image tensor normalised to [-1, 1] with matplotlib.

    Args:
        image_tensor: tensor shaped (C, H, W); values outside [-1, 1] are clamped.
            It may be on any device and may require grad.
        title: figure title. The default keeps the previous fixed title.
    """
    # detach() lets grad-tracking tensors reach .numpy(); cpu() handles CUDA tensors.
    # clamp() is out-of-place, so the caller's tensor is never modified (no clone needed).
    img = image_tensor.detach().cpu().clamp(-1, 1)
    img = (img + 1.0) / 2  # [-1, 1] -> [0, 1] for imshow
    img_np = img.permute(1, 2, 0).numpy()  # CHW -> HWC, e.g. [3, 256, 256] -> [256, 256, 3]

    _, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(img_np)
    ax.axis('off')
    ax.set_title(title)
    plt.show()

def total_writer(dir:str, ret_dict: dict, cfg_scale: float = 3.0):
    """Save every image of each logged image kind as its own PNG.

    Output layout: ``<dir>/<kind>/<object_id>_<viewpoint_id>.png``, where ``kind``
    is one of ``inputs``, ``reconstruction``, ``control`` and
    ``samples_cfg_scale_<cfg_scale:.2f>``. ``inputs`` and ``control`` images are
    named with the conditioning viewpoint id; ``reconstruction`` and samples
    images use the target viewpoint id. If that id is ``None``, the file is
    ``<object_id>.png``. Any "/" in an object id is replaced with "_".

    Args:
        dir: output root directory; it and the per-kind sub-directories are created if missing.
        ret_dict: accumulated results. Image entries are lists of HWC float arrays
            (presumably in [0, 1]; they are scaled by 255 and cast to uint8).
            It must also hold ``object_id``, ``viewpoint_ids_cond`` and
            ``viewpoint_ids_target`` lists aligned by index with the images.
        cfg_scale: guidance scale used to build the samples key.
    """
    samples_key = f"samples_cfg_scale_{cfg_scale:.2f}"
    # Image kind -> the ret_dict list holding the viewpoint id used in its file names.
    viewpoint_source = {
        "inputs": "viewpoint_ids_cond",
        "reconstruction": "viewpoint_ids_target",
        "control": "viewpoint_ids_cond",
        samples_key: "viewpoint_ids_target",
    }

    os.makedirs(dir, exist_ok=True)
    for kind in viewpoint_source:
        os.makedirs(os.path.join(dir, kind), exist_ok=True)

    for kind, images in ret_dict.items():
        if kind not in viewpoint_source:
            continue
        for img_idx, each_img in tqdm(enumerate(images), total=len(images), desc=f"Processing {kind}"):
            # Slashes in ids would otherwise create nested directories.
            safe_id = ret_dict["object_id"][img_idx].replace("/", "_")
            view_point_id = ret_dict[viewpoint_source[kind]][img_idx]
            pixels = (each_img * 255).astype("uint8")  # float -> uint8
            stem = safe_id if view_point_id is None else f"{safe_id}_{view_point_id}"
            Image.fromarray(pixels).save(os.path.join(dir, kind, f"{stem}.png"))
        

def concat_writer(dir: str, ret_dict: dict, cfg_scale: float = 3.0):
    """Save per-image PNGs plus side-by-side comparison strips.

    First writes every individual image exactly as ``total_writer`` does (same
    directories and file names). Then, for each sample, it concatenates
    [control | samples | reconstruction] horizontally and saves the strip to
    ``<dir>/combined/<object_id>_<target viewpoint id>.png``.

    Args:
        dir: output root directory; created if missing.
        ret_dict: accumulated results. Image entries are lists of HWC arrays
            (presumably float in [0, 1]). It must also hold ``object_id`` and
            ``viewpoint_ids_target`` (plus ``viewpoint_ids_cond`` for the
            individual images), all aligned by index.
        cfg_scale: guidance scale used to build the samples key.
    """
    # The per-image section used to be a verbatim copy of total_writer; delegate instead.
    total_writer(dir, ret_dict, cfg_scale)

    samples_key = f"samples_cfg_scale_{cfg_scale:.2f}"
    combined_dir = os.path.join(dir, "combined")
    os.makedirs(combined_dir, exist_ok=True)

    def _as_uint8(img):
        # Heuristic: arrays whose max is <= 1.0 are treated as [0, 1] floats and scaled;
        # anything else is assumed to already be in the 0-255 range.
        return (img * 255).astype("uint8") if img.max() <= 1.0 else img.astype("uint8")

    for img_idx, raw_id in enumerate(ret_dict["object_id"]):
        object_id = raw_id.replace("/", "_")  # avoid accidental sub-directories
        view_point_id_target = ret_dict["viewpoint_ids_target"][img_idx]

        # Left to right: control (input), generated sample, reconstruction (target).
        panels = [_as_uint8(ret_dict[key][img_idx]) for key in ("control", samples_key, "reconstruction")]
        combined_img = np.concatenate(panels, axis=1)  # concatenate along width

        img_name = f"{object_id}_{view_point_id_target}.png"
        Image.fromarray(combined_img).save(os.path.join(combined_dir, img_name))

In [6]:
# Build the model described by the config's `model` section; checkpoint weights are loaded in the next cell.
model = instantiate_from_config(target_cfg.model)

MultiModalControlNetV3: Running in eps-prediction mode
DiffusionWrapper has 859.53 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


In [7]:
# Load the trained weights from the checkpoint's "state_dict" entry (strict: every key must match).
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust. Newer torch
# releases default to weights_only=True, which may reject full Lightning checkpoints — confirm
# against the installed torch version.
state_dict = torch.load(weight_path, map_location="cpu")["state_dict"]
model.load_state_dict(state_dict, strict=True)
# Inference only: eval() disables training-mode behaviour (e.g. dropout) before sampling.
model = model.to("cuda").eval()

In [10]:
# Enable verbose mode (presumably makes the dataset also emit ids/paths such as `object_id`)
# and deterministic view selection (per the flag name; confirm in the dataset class).
# Set these BEFORE instantiating the datamodule. Previously they were set afterwards and only
# took effect if the datamodule happened to keep a live reference to this config node.
target_cfg.data.params.validation.params.verbose = True
target_cfg.data.params.validation.params.determin_view = True
data = instantiate_from_config(target_cfg.data)
data.prepare_data()
data.setup()
val_dataloader = DataLoader(
    data.datasets["validation"],
    batch_size=data.batch_size,
    num_workers=data.num_workers,
    shuffle=False,  # fixed order so outputs line up across runs
)

In [None]:
# Run guided sampling over the validation set and accumulate per-image results in proc_result.
debug = False      # True -> stop after the first batch
cfg_scale = 3.0    # classifier-free guidance scale
seed = 42          # fixes the sampler's initial noise so samples are reproducible
samples_key = f"samples_cfg_scale_{cfg_scale:.2f}"
image_keys = ("inputs", "reconstruction", "control", samples_key)

torch.manual_seed(seed)
proc_result = {}
for idx, each_batch in tqdm(enumerate(val_dataloader), total=len(val_dataloader), leave=True):
    if debug and idx == 1:
        break

    with torch.no_grad():  # inference only; never build autograd graphs here
        ret = model.log_val_images(
            each_batch,
            N=len(each_batch["image_target"]),
            unconditional_guidance_scale=cfg_scale,
            unconditional_guidance_label=[""],
            use_ema_scope=False,
            inpaint=False,
            plot_progressive_rows=False,
            plot_diffusion_rows=False,
        )

    # Register keys on first sight (not only from batch 0) so a key that first appears in a
    # later batch cannot raise KeyError. "viewpoint_ids" is split into cond/target lists.
    for k in ret:
        if k == "viewpoint_ids":
            proc_result.setdefault("viewpoint_ids_cond", [])
            proc_result.setdefault("viewpoint_ids_target", [])
        else:
            proc_result.setdefault(k, [])

    for k, v in ret.items():
        if k in image_keys:
            # batch tensor -> list of HWC numpy arrays in [0, 1]
            proc_result[k].extend(convert_to_npy(split_by_unbind(v.cpu())))
        elif k == "viewpoint_ids":
            proc_result["viewpoint_ids_cond"].extend(v["cond"])
            proc_result["viewpoint_ids_target"].extend(v["target"])
        elif k in ("object_id", "txt"):
            proc_result[k].extend(v)
        

  0%|          | 0/8 [00:00<?, ?it/s]

Data shape for DDIM sampling is (64, 4, 32, 32), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [01:11<00:00,  1.42s/it]
 12%|█▎        | 1/8 [01:29<10:29, 89.86s/it]

Data shape for DDIM sampling is (64, 4, 32, 32), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [01:11<00:00,  1.44s/it]
 25%|██▌       | 2/8 [02:55<08:44, 87.41s/it]

Data shape for DDIM sampling is (64, 4, 32, 32), eta 0.0
Running DDIM Sampling with 50 timesteps




In [None]:
# writer part: save individual PNGs and [control | sample | reconstruction] comparison strips.
# splitext removes only the extension (split('.')[0] would also cut at any earlier dot).
exp_name = os.path.splitext(os.path.basename(config_path))[0]
results_dir = f"results/{exp_name}"
# concat_writer already writes every individual image exactly as total_writer does, so also
# calling total_writer only re-wrote the same files. It was also called without cfg_scale, so a
# non-default guidance scale would have silently skipped the samples there.
concat_writer(
    dir=results_dir,
    ret_dict=proc_result,
    cfg_scale=cfg_scale,
)