# In [None]:
import data
from utils import viz
import models
import torch

# Load the trained HumanGAN and move its `generator_ema` and `discriminator`
# to the GPU in eval mode, with gradients disabled on their parameters.
model = models.HumanGAN.load_from_checkpoint(
    "checkpoints/ours/ours.ckpt"
)
generator = model.generator_ema.eval().cuda().requires_grad_(False)
discriminator = model.discriminator.eval().cuda().requires_grad_(False)

# Test split, sampled with deterministic=True and num_frames=1.
dataset = data.HumansDataset(
    "humans_in_context", deterministic=True, split="test", num_frames=1, spacing=1,
)

# Seed layout: each pose consumes a contiguous run of (num_scenes + 1) seeds.
# The first seed of the run picks the pose from the dataset; the following
# num_scenes seeds are passed to viz.sample_styles for that pose.
seed = 0
num_poses = 16
num_scenes = 16
scenes_per_row = 4
style_strength = 0.25  # forwarded as `strength` to viz.sample_styles

# First seed not consumed by this sweep; start the next run here to avoid
# reusing seeds.
print(f"Next seed: {seed + num_poses * (num_scenes + 1)}")

for pose_idx in range(num_poses):
    pose_seed = seed + pose_idx * (num_scenes + 1)

    _, keypoints = dataset.sample(seed=pose_seed)
    keypoints = keypoints[None].cuda()  # add a leading batch dimension

    # NOTE(review): `skeletons` is never used below; drop this call if
    # viz.visualize_skeletons has no side effect that is needed here.
    skeletons = viz.visualize_skeletons(keypoints, dataset)
    images = []
    scene_seeds = []

    for scene_idx in range(num_scenes):
        scene_seed = pose_seed + scene_idx + 1
        scene_seeds.append(scene_seed)

        styles, frame_keypoints = viz.sample_styles(
            keypoints, generator, seed=scene_seed, samples=1, strength=style_strength
        )
        # Render the same styles twice: with the sampled keypoints, and with
        # every keypoint zeroed (presumably the scene without the person).
        image = generator.synthesis_network(styles, frame_keypoints)
        scene = generator.synthesis_network(styles, 0 * frame_keypoints)

        # Append scene then image so both renders of one style are adjacent;
        # nrow below is even, so a pair never straddles two grid rows.
        images.append(scene.cpu())
        images.append(image.cpu())

    # Every entry was already moved to the CPU above, so the concatenated
    # tensor is on the CPU too; no second .cpu() is needed.
    images = torch.cat(images)
    viz.display_grid(images, nrow=2 * scenes_per_row)

    # Print the pose seed, then the scene seeds grouped scenes_per_row per line.
    seed_rows = "\n".join(
        str(scene_seeds[i:i + scenes_per_row])
        for i in range(0, num_scenes, scenes_per_row)
    )
    print(f"{pose_seed}\n{seed_rows}")
import data
from utils import viz
import models
import torch

# Load the trained HumanGAN and move its `generator_ema` and `discriminator`
# to the GPU in eval mode, with gradients disabled on their parameters.
# NOTE(review): the checkpoint and dataset paths in this cell are placeholders
# (PATH_TO_PRETRAINED / PATH_TO_DATASET) and must be filled in before running.
model = models.HumanGAN.load_from_checkpoint(
    "PATH_TO_PRETRAINED/1x2sxwu9_step=00899999.ckpt"
)
generator = model.generator_ema.eval().cuda().requires_grad_(False)
discriminator = model.discriminator.eval().cuda().requires_grad_(False)

# Test split, sampled with deterministic=True and num_frames=1.
dataset = data.HumansDataset(
    "PATH_TO_DATASET", deterministic=True, split="test", num_frames=1, spacing=1,
)

# Seed layout: each pose consumes a contiguous run of (num_scenes + 1) seeds.
# The first seed of the run picks the pose from the dataset; the following
# num_scenes seeds are passed to viz.sample_styles for that pose.
seed = 0
num_poses = 16
num_scenes = 16
scenes_per_row = 4
style_strength = 0.25  # forwarded as `strength` to viz.sample_styles

# First seed not consumed by this sweep; start the next run here to avoid
# reusing seeds.
print(f"Next seed: {seed + num_poses * (num_scenes + 1)}")

for pose_idx in range(num_poses):
    pose_seed = seed + pose_idx * (num_scenes + 1)

    _, keypoints = dataset.sample(seed=pose_seed)
    keypoints = keypoints[None].cuda()  # add a leading batch dimension

    # NOTE(review): `skeletons` is never used below; drop this call if
    # viz.visualize_skeletons has no side effect that is needed here.
    skeletons = viz.visualize_skeletons(keypoints, dataset)
    images = []
    scene_seeds = []

    for scene_idx in range(num_scenes):
        scene_seed = pose_seed + scene_idx + 1
        scene_seeds.append(scene_seed)

        styles, frame_keypoints = viz.sample_styles(
            keypoints, generator, seed=scene_seed, samples=1, strength=style_strength
        )
        # Render the same styles twice: with the sampled keypoints, and with
        # every keypoint zeroed (presumably the scene without the person).
        image = generator.synthesis_network(styles, frame_keypoints)
        scene = generator.synthesis_network(styles, 0 * frame_keypoints)

        # Append scene then image so both renders of one style are adjacent;
        # nrow below is even, so a pair never straddles two grid rows.
        images.append(scene.cpu())
        images.append(image.cpu())

    # Every entry was already moved to the CPU above, so the concatenated
    # tensor is on the CPU too; no second .cpu() is needed.
    images = torch.cat(images)
    viz.display_grid(images, nrow=2 * scenes_per_row)

    # Print the pose seed, then the scene seeds grouped scenes_per_row per line.
    seed_rows = "\n".join(
        str(scene_seeds[i:i + scenes_per_row])
        for i in range(0, num_scenes, scenes_per_row)
    )
    print(f"{pose_seed}\n{seed_rows}")
