In [1]:
import os
import re

import torch
import torchvision
from tqdm.auto import tqdm

In [2]:
# --- Experiment configuration ---------------------------------------------------
# Number of frames (lighting directions ``dir_{frame_id}``) to render; one grid
# image is written per frame, per method, per guidance scale.
TOTAL_FRAME = 60
# Scenes shown as grid columns; SCENES[0] also supplies the control-image tile(s).
SCENES = ['14n_copyroom10', '14n_office14', 'everett_dining1', 'everett_kitchen4','everett_kitchen6', 'everett_kitchen8']
# Guidance scales are used verbatim as directory names, hence strings.
GUIDANCE_SCALES = ['1.0','3.0','5.0','7.0']

# Alternative sweep (shcoeffs method) kept for reference:
# METHODS = ['shcoeffs']
# CONTROLS = ['no_control','depth','normal','both','bae', 'bae_both']
# NUM_CONTROL = [1,1,1,2,1,2]
METHODS = ['vae']
# One grid row per control; NUM_CONTROL[i] is how many control images CONTROLS[i]
# produces (multi-control images are stored with a ``_{control_id}`` suffix).
CONTROLS = ['no_control','depth','bae','bae_both']
NUM_CONTROL = [1,1,1,2]
# The render loop zips these two lists; a length mismatch would silently drop rows.
assert len(CONTROLS) == len(NUM_CONTROL), "CONTROLS and NUM_CONTROL must have the same length"

# Learning-rate sub-directory of each run.
LR = '1e-4'
INPUT_DIR = "../../output/20240918/val_multillum_val_rotate/"
OUTPUT_DIR = "../../output/20240918/val_multillum_val_rotate_video_frame"

In [3]:
# Render one comparison grid per (frame, method, guidance scale).
# Grid layout, with nrow = LEAD_COLS + len(SCENES):
#   first row      : LEAD_COLS blank tiles, then the environment map of each scene
#   one row/control: the control image(s) for SCENES[0], right-aligned inside the
#                    LEAD_COLS cells, then the output crop of each scene
# Missing or unreadable files become black tiles so partially finished runs can
# still be previewed.
ENV_LDR_DIR = '/data/pakkapon/datasets/multi_illumination/spherical/val_rotate/env_ldr'  # TODO: move to the config cell
TILE_SIZE = 256
LEAD_COLS = max(NUM_CONTROL, default=1)  # columns reserved for control images
FILENAME_TEMPLATE = "{scene}-dir_0_mip2_{scene}-dir_{frame_id}_mip2.jpg"


def _blank_tile():
    """Return a black (3, TILE_SIZE, TILE_SIZE) float placeholder tile."""
    return torch.zeros(3, TILE_SIZE, TILE_SIZE)


def _load_tile(path):
    """Load ``path`` as a 3-channel float tile in [0, 1], resized to TILE_SIZE x TILE_SIZE.

    The image is forced to RGB, so grayscale or RGBA files still stack with the
    other tiles in ``make_grid``. Returns a blank tile when the file is missing
    or cannot be decoded.
    """
    try:
        image = torchvision.io.read_image(path, mode=torchvision.io.ImageReadMode.RGB) / 255.0
    except (OSError, RuntimeError, ValueError):
        return _blank_tile()
    return torchvision.transforms.functional.resize(image, (TILE_SIZE, TILE_SIZE), antialias=True)


def _natural_key(name):
    """Sort key that compares digit runs numerically (``chk9`` < ``chk10``)."""
    # re.split with a capturing group alternates text / digit-run tokens, so
    # odd positions are always pure [0-9]+ runs and the key lists stay comparable.
    return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r'([0-9]+)', name))]


def _latest_run_dir(guidance_dir):
    """Return ``<latest checkpoint>/lightning_logs/version_0`` under ``guidance_dir``.

    "Latest" is the entry that sorts last under a natural sort. This assumes
    checkpoint names differ by a numeric counter (TODO confirm). Falls back to
    ``chk0`` when the directory is missing or empty.
    """
    try:
        entries = os.listdir(guidance_dir)
    except OSError:
        entries = []
    latest = max(entries, key=_natural_key) if entries else 'chk0'
    return os.path.join(guidance_dir, latest, 'lightning_logs', 'version_0')


# Resolve every run directory and create every output directory once, rather
# than redoing it for each frame.
run_dirs = {
    (method, guidance_scale, control): _latest_run_dir(
        os.path.join(INPUT_DIR, method, guidance_scale, control, LR))
    for method in METHODS
    for guidance_scale in GUIDANCE_SCALES
    for control in CONTROLS
}
for method in METHODS:
    for guidance_scale in GUIDANCE_SCALES:
        os.makedirs(os.path.join(OUTPUT_DIR, method, guidance_scale), exist_ok=True)

for frame_id in tqdm(range(TOTAL_FRAME)):
    # Environment maps depend only on the frame, so load them once per frame.
    env_tiles = [
        _load_tile(os.path.join(ENV_LDR_DIR, scene, f'dir_{frame_id}_mip2.png'))
        for scene in SCENES
    ]
    for method in METHODS:
        for guidance_scale in GUIDANCE_SCALES:
            # First row: blank lead cells, then the environment maps.
            tiles = [_blank_tile() for _ in range(LEAD_COLS)] + env_tiles
            for control, num_control in zip(CONTROLS, NUM_CONTROL):
                input_dir = run_dirs[(method, guidance_scale, control)]
                control_name = FILENAME_TEMPLATE.format(scene=SCENES[0], frame_id=frame_id)
                if num_control == 1:
                    control_names = [control_name]
                else:
                    # Multi-control runs store one image per control with a numeric suffix.
                    control_names = [
                        control_name.replace('.jpg', f'_{control_id}.jpg')
                        for control_id in range(num_control)
                    ]
                # Right-align the control tiles inside the LEAD_COLS cells.
                tiles += [_blank_tile() for _ in range(LEAD_COLS - len(control_names))]
                tiles += [
                    _load_tile(os.path.join(input_dir, 'control_image', name))
                    for name in control_names
                ]
                tiles += [
                    _load_tile(os.path.join(
                        input_dir, 'crop_image',
                        FILENAME_TEMPLATE.format(scene=scene, frame_id=frame_id)))
                    for scene in SCENES
                ]
            grid = torchvision.utils.make_grid(tiles, nrow=LEAD_COLS + len(SCENES))
            output_path = os.path.join(OUTPUT_DIR, method, guidance_scale, f'{frame_id:04d}.png')
            torchvision.utils.save_image(grid, output_path)

  0%|          | 0/60 [00:00<?, ?it/s]

