<a href="https://colab.research.google.com/github/sixthhokage4/StyleSDF/blob/main/StyleSDF_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# StyleSDF Demo

This Colab notebook demonstrates the capabilities of the StyleSDF 3D-aware GAN architecture proposed in our paper.

This Colab generates images together with their corresponding 3D meshes.

First, let's download the GitHub repository and install all dependencies.

In [None]:
!git clone https://github.com/royorel/StyleSDF.git
%cd StyleSDF
!pip3 install -r requirements.txt

Next, install PyTorch3D.

In [None]:
!pip install -U fvcore
import sys
import torch
pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
version_str="".join([
    f"py3{sys.version_info.minor}_cu",
    torch.version.cuda.replace(".",""),
    f"_pyt{pyt_version_str}"
])
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
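
The cell above assembles a tag from the Python minor version, the CUDA version, and the PyTorch version, then uses it to pick the matching wheel directory. As a minimal sketch of that string logic only (the function name `pytorch3d_wheel_tag` is hypothetical, and whether a wheel is actually published for a given tag is not guaranteed):

```python
def pytorch3d_wheel_tag(python_minor: int, cuda_version: str, torch_version: str) -> str:
    """Build the directory tag used in the wheel URL, e.g. 'py37_cu111_pyt190'."""
    # Drop a local build suffix such as '+cu111', then remove the dots.
    pyt = torch_version.split("+")[0].replace(".", "")
    cuda = cuda_version.replace(".", "")
    return f"py3{python_minor}_cu{cuda}_pyt{pyt}"


print(pytorch3d_wheel_tag(7, "11.1", "1.9.0+cu111"))  # py37_cu111_pyt190
```

Note that `torch.version.cuda` is `None` on CPU-only PyTorch builds, so the install cell needs a GPU runtime.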

Now let's download the pretrained models for FFHQ and AFHQ.

In [None]:
!python download_models.py
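
Before moving on, it can help to confirm that the checkpoints are in place. The sketch below assumes the files end up in `full_models/` under the names used by the model-loading cell later in this notebook; `missing_checkpoints` is a hypothetical helper, not part of the repository.

```python
import os

# File names taken from the model-loading cell later in this notebook.
EXPECTED_CHECKPOINTS = ("ffhq1024x1024.pt", "afhq512x512.pt")


def missing_checkpoints(models_dir="full_models", names=EXPECTED_CHECKPOINTS):
    """Return the expected checkpoint files that are not present in models_dir."""
    return [n for n in names if not os.path.isfile(os.path.join(models_dir, n))]


print(missing_checkpoints())  # an empty list means both checkpoints are in place
```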

Here, we import libraries and set options.

Note: this might take a while (approx. 1-2 minutes) since CUDA kernels need to be compiled.

In [None]:
import os
import torch
import trimesh
import numpy as np
from munch import Munch
from options import BaseOptions
from model import Generator
from generate_shapes_and_images import generate
from render_video import render_video


torch.random.manual_seed(321)


device = "cuda"
opt = BaseOptions().parse()
opt.camera.uniform = True
opt.model.is_test = True
opt.model.freeze_renderer = False
opt.rendering.offset_sampling = True
opt.rendering.static_viewdirs = True
opt.rendering.force_background = True
opt.rendering.perturb = 0
opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim
opt.inference.style_dim = opt.model.style_dim
opt.inference.project_noise = opt.model.project_noise

Don't worry about the message printed above:
```
usage: ipykernel_launcher.py [-h] [--dataset_path DATASET_PATH]
                             [--config CONFIG] [--expname EXPNAME]
                             [--ckpt CKPT] [--continue_training]
                             ...
                             ...
ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-c9d47a98-bdba-4a5f-9f0a-e1437c7228b6.json
```
Everything is perfectly fine.
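
The message comes from `argparse`: the notebook kernel is launched with its own `-f <connection file>` argument, which the option parser does not know. How `BaseOptions` handles this internally is not shown in this notebook, so the following is only an illustration with a stand-in parser (the `--expname` default of `"demo"` is made up):

```python
import argparse

# A stand-in parser with one option from the usage message above.
parser = argparse.ArgumentParser(prog="ipykernel_launcher.py")
parser.add_argument("--expname", default="demo")  # hypothetical default

# Roughly what a notebook kernel passes on its own command line.
kernel_argv = ["-f", "/root/.local/share/jupyter/runtime/kernel-example.json"]

# parse_args(kernel_argv) would print the usage text and the
# "unrecognized arguments" error; parse_known_args returns the leftovers instead.
known, unknown = parser.parse_known_args(kernel_argv)
print(known.expname)
print(unknown)
```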

Here, we define our model.

Set the options below as needed:
1. To try the method on the AFHQ dataset (animal faces), change `model_type` to `'afhq'`. Default: `'ffhq'` (human faces).
2. If you wish to turn off depth rendering and marching cubes extraction and generate only RGB images, set `opt.inference.no_surface_renderings = True`. Default: `False`.
3. If you wish to generate the image from a specific set of viewpoints, set `opt.inference.fixed_camera_angles = True`. Default: `False`.
4. Set the number of identities you wish to create in `opt.inference.identities`. Default: `4`.
5. Select the number of views per identity in `opt.inference.num_views_per_id`<br>
   (only applicable when `opt.inference.fixed_camera_angles` is `False`). Default: `1`.

In [None]:
# User options
model_type = 'ffhq' # Whether to load the FFHQ or AFHQ model
opt.inference.no_surface_renderings = False # When true, only RGB images will be created
opt.inference.fixed_camera_angles = False # When true, each identity will be rendered from a specific set of 13 viewpoints. Otherwise, random views are generated
opt.inference.identities = 4 # Number of identities to generate
opt.inference.num_views_per_id = 1 # Number of viewpoints generated per identity. This option is ignored if opt.inference.fixed_camera_angles is true.

# Load saved model
if model_type == 'ffhq':
    model_path = 'ffhq1024x1024.pt'
    opt.model.size = 1024
    opt.experiment.expname = 'ffhq1024x1024'
else:
    opt.camera.azim = 0.15  # opt.inference.camera is set to opt.camera further below
    model_path = 'afhq512x512.pt'
    opt.model.size = 512
    opt.experiment.expname = 'afhq512x512'

# Create results directory
result_model_dir = 'final_model'
results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)
opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir)
if opt.inference.fixed_camera_angles:
    opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'fixed_angles')
else:
    opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'random_angles')

os.makedirs(opt.inference.results_dst_dir, exist_ok=True)
os.makedirs(os.path.join(opt.inference.results_dst_dir, 'images'), exist_ok=True)
if not opt.inference.no_surface_renderings:
    os.makedirs(os.path.join(opt.inference.results_dst_dir, 'depth_map_meshes'), exist_ok=True)
    os.makedirs(os.path.join(opt.inference.results_dst_dir, 'marching_cubes_meshes'), exist_ok=True)

opt.inference.camera = opt.camera
opt.inference.size = opt.model.size
checkpoint_path = os.path.join('full_models', model_path)
checkpoint = torch.load(checkpoint_path)

# Load image generation model
g_ema = Generator(opt.model, opt.rendering).to(device)
pretrained_weights_dict = checkpoint["g_ema"]
model_dict = g_ema.state_dict()
for k, v in pretrained_weights_dict.items():
    if v.size() == model_dict[k].size():
        model_dict[k] = v

g_ema.load_state_dict(model_dict)

# Load a second volume renderer that extracts surfaces at 128x128x128 (or higher) for better surface resolution
if not opt.inference.no_surface_renderings:
    opt['surf_extraction'] = Munch()
    opt.surf_extraction.rendering = opt.rendering
    opt.surf_extraction.model = opt.model.copy()
    opt.surf_extraction.model.renderer_spatial_output_dim = 128
    opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim
    opt.surf_extraction.rendering.return_xyz = True
    opt.surf_extraction.rendering.return_sdf = True
    surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)


    # Load weights to surface extractor
    surface_extractor_dict = surface_g_ema.state_dict()
    for k, v in pretrained_weights_dict.items():
        if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():
            surface_extractor_dict[k] = v

    surface_g_ema.load_state_dict(surface_extractor_dict)
else:
    surface_g_ema = None

# Get the mean latent vector for g_ema
if opt.inference.truncation_ratio < 1:
    with torch.no_grad():
        mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)
else:
    mean_latent = None

# Get the mean latent vector for surface_g_ema
if not opt.inference.no_surface_renderings and mean_latent is not None:
    surface_mean_latent = mean_latent[0]
else:
    surface_mean_latent = None
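
The loading loops in the cell above copy a checkpoint tensor only when its size matches the corresponding model parameter; anything else keeps its initial value without any warning. A minimal sketch for listing what would be skipped, working on plain name-to-shape mappings (the parameter names and shapes below are invented for illustration):

```python
def split_by_shape(pretrained_shapes, model_shapes):
    """Partition checkpoint keys into loadable ones and skipped ones.

    Both arguments map parameter names to shapes (tuples of ints).
    """
    loadable, skipped = [], []
    for name, shape in pretrained_shapes.items():
        if name in model_shapes and tuple(shape) == tuple(model_shapes[name]):
            loadable.append(name)
        else:
            skipped.append(name)
    return loadable, skipped


# Hypothetical parameter names and shapes, for illustration only.
checkpoint = {"renderer.w": (256, 3), "decoder.w": (512, 512), "extra.b": (8,)}
model = {"renderer.w": (256, 3), "decoder.w": (256, 512)}

loadable, skipped = split_by_shape(checkpoint, model)
print(loadable)  # ['renderer.w']
print(skipped)   # ['decoder.w', 'extra.b']
```

With real state dicts, the shape mappings can be built with `{k: tuple(v.shape) for k, v in state_dict.items()}`.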

## Generating images and meshes

Finally, we run the network. The results will be saved to `evaluations/[model_name]/final_model/[fixed/random]_angles`, according to the selected setup.

In [None]:
generate(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)
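
To see where the outputs land, the directory layout built in the model-definition cell can be reproduced as a small function. This is a sketch only: `results_dst_dir` is a hypothetical helper, and `evaluations` is assumed to be the value of `opt.inference.results_dir`.

```python
from pathlib import PurePosixPath


def results_dst_dir(results_dir, expname, fixed_camera_angles):
    """Mirror the directory layout built in the model-definition cell."""
    angles = "fixed_angles" if fixed_camera_angles else "random_angles"
    return PurePosixPath(results_dir) / expname / "final_model" / angles


# 'evaluations' is assumed here to be the value of opt.inference.results_dir.
out = results_dst_dir("evaluations", "ffhq1024x1024", fixed_camera_angles=False)
print(out)             # evaluations/ffhq1024x1024/final_model/random_angles
print(out / "images")  # evaluations/ffhq1024x1024/final_model/random_angles/images
```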

Now let's examine the results.

Tip: for better mesh visualization, we recommend downloading the result meshes and viewing them with MeshLab.

The meshes are located in `evaluations/[model_name]/final_model/[fixed/random]_angles/[depth_map/marching_cubes]_meshes`.

In [None]:
from PIL import Image
from trimesh.viewer.notebook import scene_to_html as mesh2html
from IPython.display import HTML as viewer_html

# First let's look at the images
img_dir = os.path.join(opt.inference.results_dst_dir,'images')
im_list = sorted([entry for entry in os.listdir(img_dir) if 'thumb' not in entry])
img = Image.new('RGB', (256 * len(im_list), 256))
for i, im_file in enumerate(im_list):
    im_path = os.path.join(img_dir, im_file)
    curr_img = Image.open(im_path).resize((256,256)) # the displayed image is scaled to fit to the screen
    img.paste(curr_img, (256 * i, 0))

display(img)

# And now, we'll move on to display the marching cubes and depth map meshes

marching_cubes_meshes_dir = os.path.join(opt.inference.results_dst_dir,'marching_cubes_meshes')
marching_cubes_meshes_list = sorted([os.path.join(marching_cubes_meshes_dir, entry) for entry in os.listdir(marching_cubes_meshes_dir) if 'obj' in entry])
depth_map_meshes_dir = os.path.join(opt.inference.results_dst_dir,'depth_map_meshes')
depth_map_meshes_list = sorted([os.path.join(depth_map_meshes_dir, entry) for entry in os.listdir(depth_map_meshes_dir) if 'obj' in entry])
for i, mesh_files in enumerate(zip(marching_cubes_meshes_list, depth_map_meshes_list)):
    mc_mesh_file, dm_mesh_file = mesh_files[0], mesh_files[1]
    marching_cubes_mesh = trimesh.Scene(trimesh.load_mesh(mc_mesh_file, 'obj'))  
    curr_mc_html = mesh2html(marching_cubes_mesh).replace('"', '&quot;')
    display(viewer_html(' '.join(['<iframe srcdoc="{srcdoc}"',
                            'width="{width}px" height="{height}px"',
                            'style="border:none;"></iframe>']).format(
                            srcdoc=curr_mc_html, height=256, width=256)))
    depth_map_mesh = trimesh.Scene(trimesh.load_mesh(dm_mesh_file, 'obj'))  
    curr_dm_html = mesh2html(depth_map_mesh).replace('"', '&quot;')
    display(viewer_html(' '.join(['<iframe srcdoc="{srcdoc}"',
                            'width="{width}px" height="{height}px"',
                            'style="border:none;"></iframe>']).format(
                            srcdoc=curr_dm_html, height=256, width=256)))
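
The two `<iframe>` snippets in the cell above are identical apart from the document they embed, so they could be factored into one helper. The sketch below is an assumption-labeled alternative, not the notebook's code: `iframe_for` is a hypothetical name, and it uses the standard library's `html.escape` instead of replacing only the double quotes.

```python
import html


def iframe_for(document, width=256, height=256):
    """Wrap an HTML document in an <iframe> via the srcdoc attribute."""
    # quote=True also escapes double quotes, so the document cannot
    # terminate the srcdoc="..." attribute early.
    srcdoc = html.escape(document, quote=True)
    return (f'<iframe srcdoc="{srcdoc}" width="{width}px" '
            f'height="{height}px" style="border:none;"></iframe>')


print(iframe_for('<p class="x">hi</p>', width=64, height=64))
```

In the notebook this would be used as `display(viewer_html(iframe_for(mesh2html(scene))))`, with `scene` being one of the `trimesh.Scene` objects from the cell above.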

## Generating videos

We can also render videos. The results will be saved to `evaluations/[model_name]/final_model/videos`.

Set the options below as needed:
1. If you wish to generate only RGB videos, set `opt.inference.no_surface_renderings = True`. Default: `False`.
2. Set the camera trajectory: set `opt.inference.azim_video = True` to travel along the azimuth direction, or `opt.inference.azim_video = False` to travel along an ellipsoid trajectory. Default: `False`.

### Important Note:
 - Processing time for videos when `opt.inference.no_surface_renderings = False` is very lengthy (~ 15-20 minutes per video). Rendering each depth frame for the depth videos is very slow.<br>
 - Processing time for videos when `opt.inference.no_surface_renderings = True` is much faster (~ 1-2 minutes per video)

In [None]:
# Options
opt.inference.no_surface_renderings = True # When true, only RGB videos will be created
opt.inference.azim_video = True # When true, the camera trajectory will travel along the azimuth direction. Otherwise, the camera will travel along an ellipsoid trajectory.

opt.inference.results_dst_dir = os.path.join(os.path.split(opt.inference.results_dst_dir)[0], 'videos')
os.makedirs(opt.inference.results_dst_dir, exist_ok=True)
render_video(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)

Let's watch the result videos.

The output video files are relatively large, so it might take a while (about 1-2 minutes) for all of them to be loaded. 

In [None]:
%%script bash --bg
python3 -m http.server 8000
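
If the `%%script bash --bg` magic is not available, a similar file server can be started from Python in a background thread. This is a sketch using only the standard library; `start_file_server` is a hypothetical helper, and whether the hosting environment forwards the chosen port to your browser is an assumption.

```python
import functools
import threading
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer


def start_file_server(directory=".", port=8000):
    """Serve `directory` over HTTP from a daemon thread; return the server."""
    handler = functools.partial(SimpleHTTPRequestHandler, directory=directory)
    server = ThreadingHTTPServer(("", port), handler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


server = start_file_server(port=0)  # port 0 lets the OS pick a free port
print(server.server_address[1])     # the port that was actually bound
server.shutdown()                   # stop serving
server.server_close()               # release the socket
```

For the video cell below, the server would be started with `port=8000` from the repository root and left running.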

In [None]:
%%html
<!-- Change ffhq1024x1024 to afhq512x512 if you are working with the AFHQ model -->
<div>
  <video width=256 controls><source src="http://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_0_azim.mp4" type="video/mp4"></video>
  <video width=256 controls><source src="http://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_1_azim.mp4" type="video/mp4"></video>
  <video width=256 controls><source src="http://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_2_azim.mp4" type="video/mp4"></video>
  <video width=256 controls><source src="http://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_3_azim.mp4" type="video/mp4"></video>
</div>
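
The hard-coded model name in the cell above can be avoided by building the tags in Python. The helper below is a hypothetical sketch that follows the file naming seen above; only the `azim` suffix appears in this notebook, so the suffix for other trajectories is left as a parameter rather than guessed.

```python
def video_tags(expname, count, suffix="azim",
               base_url="http://localhost:8000", width=256):
    """Build one <video> tag per sample, following the file naming used above."""
    tags = []
    for i in range(count):
        src = (f"{base_url}/evaluations/{expname}/final_model/videos/"
               f"sample_video_{i}_{suffix}.mp4")
        tags.append(f'<video width={width} controls>'
                    f'<source src="{src}" type="video/mp4"></video>')
    return "<div>\n  " + "\n  ".join(tags) + "\n</div>"


print(video_tags("afhq512x512", 2))
```

In the notebook, `display(viewer_html(video_tags(opt.experiment.expname, opt.inference.identities)))` would render the result, assuming one video is written per identity.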

Alternatively, the videos can be embedded directly in the notebook with the Python code below.
This loads the videos faster, but it very often crashes the notebook because the video files are too large.

**It is not recommended to view the files this way**.

If the notebook does crash, you can also refresh the webpage and manually download the videos.<br>
The videos are located in `evaluations/<model_name>/final_model/videos`

In [None]:
# from base64 import b64encode

# videos_dir = opt.inference.results_dst_dir
# videos_list = sorted([os.path.join(videos_dir, entry) for entry in os.listdir(videos_dir) if 'mp4' in entry])
# for i, video_file in enumerate(videos_list):
#     if i != 1:
#         continue
#     mp4 = open(video_file,'rb').read()
#     data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
#     display(viewer_html("""<video width={0} controls>
#                                 <source src="{1}" type="{2}">
#                           </video>""".format(256, data_url, "video/mp4")))