#  Постановка задачи

Цель задания - раскрасить матрёшку с помощью SDS лосса.

Для этого дан скрипт matryoshka.py, который позволяет рисовать матрешку с заданного ракурса.

Помимо этого, скрипт содержит модуль Texture, который в качестве обучаемых параметров содержит текстуру матрешки.

Используя диффузионную модель DeepFloyd IF, вы настроите параметры текстуры. Внешний вид матрёшки будет определять текстовый промпт. Для удобства мы предподсчитали представление одного промпта, но вы также сможете выбрать и свой промпт.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import gc
import imageio
import matplotlib.pylab as plt
from IPython.display import Video

from transformers import T5EncoderModel
from diffusers import DiffusionPipeline

from matryoshka import Texture, render, calculate_normals, calculate_soft_shadow
from IPython.display import display, clear_output

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Identifier of the DeepFloyd IF model loaded with `from_pretrained` in later cells.
deepfloyd_model = "DeepFloyd/IF-I-M-v1.0"

def flush():
    """Run the garbage collector and release cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()
flush()
torch.set_default_dtype(torch.float32)

Что есть в файлике?

Класс Texture

Метод render, который принимает
1. Позицию камеры
2. Направление камеры
3. Функцию, которая красит точки,
4. Текстуру (опционально)
5. Источник освещения (опционально)

Вспомогательные методы calculate_normals, calculate_soft_shadow

In [None]:
# Camera position and the point it looks at; the view direction is their difference.
camera_origin = torch.as_tensor([[1, 3., 1.5]], device=device)
camera_target = torch.as_tensor([[0., 0., 1.0]], device=device)
camera_direction = camera_target - camera_origin

def get_pixel_normals(points, texture, light_source):
    """Color each point by its normal, remapped as 0.5 * (normal + 1).

    `texture` and `light_source` are not used; they are accepted so that the
    function has the same signature as the other coloring callbacks given to
    `render`. Presumably the remapping takes components from [-1, 1] to [0, 1].
    """
    return 0.5 * (calculate_normals(points) + 1.)

# No texture or light is passed: the callback above ignores both.
rendered_normals = render(camera_origin, camera_direction, get_pixel_normals, resolution=512)

plt.imshow(rendered_normals[0].cpu())
plt.axis('off');

In [None]:
# Texture initialised with uniform random values and a constant background color.
matryoshka_texture = Texture(texture_grid=torch.rand(1, 3, 64, 64),
                             background_color=torch.as_tensor([0., 0.5, 0.3])).to(device)

def get_texture(points, texture, light_source):
    """Color each point with the texture value at that point; `light_source` is unused."""
    return texture(points)

# Render the unlit texture from the camera defined in the previous cell.
rendered_albedo = render(camera_origin, camera_direction, get_texture, matryoshka_texture, resolution=512)

# The texture holds trainable parameters, hence the detach before plotting.
plt.imshow(rendered_albedo[0].detach().cpu())
plt.axis('off');

Сделаем чуть интереснее

In [None]:
def get_pixel_colors(points, texture, light_source):
    """Shade textured points with ambient light plus a shadowed diffuse term.

    Args:
        points: surface points to color; the last axis holds the coordinates.
        texture: callable returning the albedo for `points`.
        light_source: light position, broadcastable against `points`.

    Returns:
        Shaded colors clamped to [0, 1].
    """
    ambient_light = 0.2
    # Unit vectors pointing from every surface point towards the light.
    to_light = F.normalize(light_source - points, dim=-1)
    # Lambert's cosine law; surfaces facing away from the light get zero.
    diffuse = (calculate_normals(points) * to_light).sum(-1, keepdim=True).clamp(0.)
    # Does the point see the light source?
    visibility = calculate_soft_shadow({'origins': points, 'directions': to_light})
    shaded = texture(points) * (ambient_light + diffuse * visibility)
    return shaded.clamp(0., 1.)

# Single light position shared by all points of the render.
light_source = 3. * torch.as_tensor([2.0, 1.0, 1.0]).to(device)
# Render the textured, lit scene from the camera defined earlier.
rendered_image = render(
    camera_origin,
    camera_direction,
    get_pixel_colors,
    matryoshka_texture,
    light_source,
    resolution=512
)

plt.imshow(rendered_image[0].detach().cpu())
plt.axis('off');

In [None]:
def render_validation_frames(get_pixel_colors, texture=None, light_source=None, n_frames=180, chunk_size=8, **kwargs):
    """Render a sequence of frames with the camera circling the scene.

    Args:
        get_pixel_colors: coloring callback forwarded to `render`.
        texture: optional texture forwarded to `render`.
        light_source: optional light position forwarded to `render`. When
            omitted, a light at 4 * (1, 0, 1) is used.
        n_frames: number of camera positions along the circle.
        chunk_size: number of frames rendered per `render` call.
        **kwargs: extra keyword arguments forwarded to `render`;
            `resolution` defaults to 512 when not given.

    Returns:
        uint8 NumPy array with all frames concatenated along the first axis.
    """
    # Only fall back to the default light when the caller did not provide one
    # (the argument used to be overwritten unconditionally).
    if light_source is None:
        light_source = 4. * torch.as_tensor([1.0, 0.0, 1.0]).to(device)
    # Let callers override the resolution instead of hitting a duplicate
    # keyword error.
    kwargs.setdefault('resolution', 512)

    time = torch.linspace(0, 1, n_frames)
    # camera flies around the scene center
    camera_origins = 3. * torch.stack(
        [(2 * torch.pi * time).cos(),
         (2 * torch.pi * time).sin(),
         torch.full_like(time, 0.5)], dim=1
    ).to(device)
    camera_target = torch.as_tensor([0., 0., 1.]).to(device)
    camera_directions = camera_target - camera_origins

    # render images in chunks to bound peak memory
    images = []
    for camera_origins_batch, camera_directions_batch in zip(
        torch.split(camera_origins, chunk_size),
        torch.split(camera_directions, chunk_size)
    ):
        images_batch = render(camera_origins_batch,
                              camera_directions_batch,
                              get_pixel_colors,
                              texture,
                              light_source,
                              **kwargs).clamp(0., 0.999)
        # Quantise to 0..255; the clamp to 0.999 keeps 256 * x below 256.
        images_batch = (256 * images_batch).floor().to(torch.uint8)
        images.append(images_batch.cpu().numpy())
    return np.concatenate(images)

def save_video(frames, filename):
    """Write `frames` to `filename` as a 30 fps video via imageio.

    Args:
        frames: iterable of frames, each passed to `writer.append_data`.
        filename: output path handed to `imageio.get_writer`.
    """
    writer = imageio.get_writer(filename, fps=30)
    try:
        # Iterate over the argument, not the notebook-global `images`.
        for frame in frames:
            writer.append_data(frame)
    finally:
        # Close the writer even if appending a frame fails.
        writer.close()

In [None]:
# Render the validation frames for the randomly initialised texture and save them.
images = render_validation_frames(get_pixel_colors, matryoshka_texture)
save_video(images, 'init_render.mp4')

Video("init_render.mp4", width=512, height=512, html_attributes='loop autoplay')

Почистим память для дальнейшей работы:

In [None]:
# free the memory held by objects that are no longer needed
del camera_origin
del camera_target
del camera_direction
del matryoshka_texture
del rendered_normals
del rendered_albedo
del rendered_image
flush()

!nvidia-smi

# Представления промпта

Ниже мы подгружаем кодировщик T5 и вычисляем представления промпта. Чтобы воспользоваться предподсчитанными представлениями, можно пропустить следующие три ячейки.

In [None]:
# Load only the T5 text encoder of the model, using its 8-bit variant.
text_encoder = T5EncoderModel.from_pretrained(
    deepfloyd_model,
    subfolder="text_encoder",
    device_map="auto",
    variant="8bit",
    load_in_8bit=True,
)

# Build a pipeline around that encoder; the UNet is not loaded (unet=None)
# because this pipeline is only used for `encode_prompt`.
pipe = DiffusionPipeline.from_pretrained(
    deepfloyd_model,
    text_encoder=text_encoder, # pass the previously instantiated 8bit text encoder
    unet=None,
    device_map="balanced"
)

In [None]:
prompt = "Astronaut in a form of matryoshka doll"

# View-dependent suffixes; `sample_cameras` later indexes the embeddings in this order.
directions = ['front view', 'side view', 'top view', 'backside view']

directional_prompts = [prompt + ', ' + direction for direction in directions]

# `encode_prompt` results are used later as a pair: [0] is passed as the
# prompt embedding and [1] as the negative prompt embedding.
prompt_embeddings = pipe.encode_prompt(prompt)
directional_prompt_embeddings = pipe.encode_prompt(directional_prompts)
# Save both so they can be reused without loading the text encoder again.
torch.save(prompt_embeddings,
           f'embeddings_{prompt.replace(" ", "_")}.pt')
torch.save(directional_prompt_embeddings,
           f'directional_embeddings_{prompt.replace(" ", "_")}.pt')

Чистим память для последующей работы

In [None]:
# Drop the text encoder and its pipeline; only the saved embeddings are used from here on.
del text_encoder
del pipe
flush()

# Генерация изображений

Для начала сравним два подхода к генерации плоских изображений:
1. Стандартный подход (последовательное расшумление)
2. Генерация с помощью SDS функции потерь

Для последнего мы реализуем SDS.


Для работы в коллабе мы подгружаем маленькую диффузионную модель "DeepFloyd/IF-I-M-v1.0".

Лучших результатов удастся добиться с моделью "DeepFloyd/IF-I-XL-v1.0", которая будет работать чуть медленнее.

In [None]:
# Load the pipeline in fp16 without the text encoder, safety checker,
# watermarker and feature extractor; prompts are supplied as embeddings.
pipe = DiffusionPipeline.from_pretrained(
    deepfloyd_model,
    text_encoder=None,
    safety_checker=None,
    watermarker=None,
    feature_extractor=None,
    requires_safety_checker=False,
    variant="fp16",
    torch_dtype=torch.float16,
    device_map="balanced"
)

## Сэмплирование

Стандартная генерация сводится к вызову метода

In [None]:
# Seeded generator so that the sampled image is reproducible.
generator = torch.Generator().manual_seed(0)
# Standard sampling with the pipeline, driven by the prompt embeddings;
# the result is requested as a tensor (output_type="pt").
image = pipe(
    prompt_embeds=prompt_embeddings[0],
    negative_prompt_embeds=prompt_embeddings[1],
    output_type="pt",
    generator=generator,
).images

In [None]:
from diffusers.utils import pt_to_pil
from PIL import Image

# Convert the sampled tensor to PIL images for display.
pil_image = pt_to_pil(image)

# Enlarge with nearest-neighbour resampling so individual pixels stay sharp.
display(pil_image[0].resize((512, 512), Image.NEAREST))

А для SDS определим необходимые компоненты

In [None]:
# Release the sampling pipeline before it is loaded again for SDS.
del pipe
flush()

In [None]:
# Reload the fp16 pipeline without the auxiliary components.
pipe = DiffusionPipeline.from_pretrained(
    deepfloyd_model,
    text_encoder=None,
    safety_checker=None,
    watermarker=None,
    feature_extractor=None,
    requires_safety_checker=False,
    variant="fp16",
    torch_dtype=torch.float16,
    device_map="balanced"
)

# Globals read by `forward_unet` and `get_sds_loss` below.
unet = pipe.unet.eval()
scheduler = pipe.scheduler
num_train_timesteps = scheduler.config.num_train_timesteps
# Cumulative alpha products, kept on the GPU so they can be indexed by sampled timesteps.
alphas = scheduler.alphas_cumprod.to(torch.device('cuda:0'))

Подсчет градиента

In [None]:
@torch.cuda.amp.autocast(enabled=False)
def forward_unet(latents, t, encoder_hidden_states):
    """Run the global `unet` in float16 and return its output in the input dtype.

    Autocast is disabled so the explicit float16 casts below decide the
    precision of the forward pass.
    """
    half = torch.float16
    prediction = unet(
        latents.to(half),
        t.to(half),
        encoder_hidden_states=encoder_hidden_states.to(half),
    ).sample
    # Hand the result back in the dtype the caller passed in.
    return prediction.to(latents.dtype)

def get_sds_loss(images, prompt_embeddings, min_step=20, max_step=980, guidance_scale=10.):
    """Compute a Score Distillation Sampling loss for a batch of images.

    Relies on the notebook globals `scheduler`, `alphas` and `forward_unet`.

    Args:
        images: channels-first image batch with values expected in [0, 1];
            gradients flow back to it through the returned loss.
        prompt_embeddings: pair (text embedding, unconditional embedding);
            the two are concatenated along the batch axis in this order.
        min_step: smallest timestep that can be sampled (inclusive).
        max_step: upper bound for the sampled timestep (exclusive).
        guidance_scale: weight of the (text - unconditional) guidance term.

    Returns:
        Scalar loss whose gradient with respect to the resized, rescaled
        images equals the SDS gradient divided by the batch size.
    """
    batch_size = images.shape[0]
    # prepare image: resize to 64x64 and map [0, 1] -> [-1, 1]
    latents = F.interpolate(images, (64, 64), mode="bilinear", align_corners=False, antialias=True)
    latents = 2. * latents - 1.
    # Work on the device of the input instead of a hard-coded 'cuda:0'.
    work_device = latents.device
    # sample ts
    t = torch.randint(
        min_step,
        max_step,
        [batch_size],
        dtype=torch.long,
        device=work_device)
    # predict noise; the UNet is treated as a fixed critic, so no graph is built
    with torch.no_grad():
        noise = torch.randn_like(latents)
        latents_noisy = scheduler.add_noise(latents, noise, t)
        # One forward pass for both the text and the unconditional branch.
        noise_pred = forward_unet(
            torch.cat(2 * [latents_noisy]),
            torch.cat(2 * [t]),
            torch.cat(prompt_embeddings)
        )

    noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
    # Keep the first 3 output channels; the remaining ones (presumably the
    # predicted variance) are discarded.
    noise_pred_text, _ = noise_pred_text.split(3, dim=1)
    noise_pred_uncond, _ = noise_pred_uncond.split(3, dim=1)
    noise_pred = noise_pred_text + guidance_scale * (
        noise_pred_text - noise_pred_uncond
    )

    # Timestep weighting w(t) = 1 - alpha_cumprod[t].
    w = (1 - alphas.to(work_device)[t]).view(-1, 1, 1, 1)
    grad = w * (noise_pred - noise)
    grad = torch.nan_to_num(grad)

    # d/d(latents) of 0.5 * ||latents - target||^2 is (latents - target) = grad,
    # so backpropagating this loss injects `grad` without differentiating the UNet.
    target = (latents - grad).detach()
    loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
    return loss_sds

## Генерация картинки с помощью SDS

In [None]:
# The optimised variable is the image itself: a constant 0.7 image of size 256x256.
image = torch.nn.Parameter(torch.full((1, 3, 256, 256), 0.7, device=torch.device('cuda:0')))
optimizer = torch.optim.Adam([image], lr=1e-2, weight_decay=0)

In [None]:
# NOTE(review): `generator` is never passed to anything in this cell, so the
# random draws inside `get_sds_loss` are not controlled by this seed.
generator = torch.Generator().manual_seed(0)

try:
    for i in range(256):
        t_min = 20
        t_max = 980
        loss = get_sds_loss(image, prompt_embeddings, t_min, t_max)
        loss.backward()
        optimizer.step()
        # Keep the optimised image inside the valid [0, 1] range.
        image.data = image.data.clamp(0., 1.)
        optimizer.zero_grad()
        # Show a preview every 16 steps.
        if i % 16 == 0:
            fig, ax = plt.subplots(figsize=(5, 5))
            ax.imshow(image[0].permute(1, 2, 0).clamp(0., 1.).detach().cpu())
            ax.axis('off')
            clear_output(wait=True)
            plt.show()
except KeyboardInterrupt:
    # Interrupting the cell stops the optimisation early without an error.
    pass
finally:
    # Always show the final state of the image.
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(image[0].permute(1, 2, 0).clamp(0., 1.).detach().cpu())
    ax.axis('off')
    clear_output(wait=True)
    plt.show()

In [None]:
# Snapshot of the low-resolution result. Detach first so the copy carries no
# autograd history of the trainable `image`; it is only used as a fixed
# starting point and as conditioning for the super-resolution stage.
image_lr = image.detach().clone()

In [None]:
# Release the first-stage pipeline and UNet before loading the next model.
del pipe
del unet
flush()
!nvidia-smi

In [None]:
@torch.cuda.amp.autocast(enabled=False)
def forward_sr_unet(latents, t, encoder_hidden_states, class_labels):
    """Run the global `unet` in float16 with class labels; return in the input dtype.

    `class_labels` is forwarded unchanged. Autocast is disabled so the explicit
    float16 casts decide the precision of the forward pass.
    """
    half = torch.float16
    prediction = unet(
        latents.to(half),
        t.to(half),
        encoder_hidden_states=encoder_hidden_states.to(half),
        class_labels=class_labels
    ).sample
    # Hand the result back in the dtype the caller passed in.
    return prediction.to(latents.dtype)

def get_sr_sds_loss(images, prompt_embeddings, min_step=20, max_step=980, guidance_scale=10., lowres_noise_level=0.75, original=None):
    """Compute an SDS loss with a UNet that is conditioned on a low-res image.

    Relies on the notebook globals `scheduler`, `alphas`,
    `num_train_timesteps` and `forward_sr_unet`.

    Args:
        images: channels-first image batch with values expected in [0, 1];
            gradients flow back to it through the returned loss.
        prompt_embeddings: pair (text embedding, unconditional embedding),
            concatenated along the batch axis in this order.
        min_step: smallest timestep that can be sampled (inclusive).
        max_step: upper bound for the sampled timestep (exclusive).
        guidance_scale: weight of the (text - unconditional) guidance term.
        lowres_noise_level: fraction of `num_train_timesteps` used as the
            noise level of the low-resolution conditioning image.
        original: optional image batch in [0, 1] used to build the
            conditioning; when None, `images` themselves are used.

    Returns:
        Scalar loss whose gradient with respect to the resized, rescaled
        images equals the SDS gradient divided by the batch size.
    """
    batch_size = images.shape[0]
    # prepare image: resize to 256x256 and map [0, 1] -> [-1, 1]
    latents = F.interpolate(images, (256, 256), mode="bilinear", align_corners=False)
    latents = 2. * latents - 1.
    # Work on the device of the input instead of a hard-coded 'cuda:0'.
    work_device = latents.device

    # The conditioning only feeds the UNet input, so no graph is needed for it.
    with torch.no_grad():
        lowres_source = latents if original is None else 2. * original - 1.
        # Reduce to 64x64 and blow back up to 256x256 with nearest-neighbour.
        upscaled = F.interpolate(lowres_source, (64, 64), mode="nearest")
        upscaled = F.interpolate(upscaled, (256, 256), mode="nearest")

        noise_step = int(num_train_timesteps * lowres_noise_level)
        noise_level = torch.full((upscaled.shape[0],), noise_step,
                                 dtype=torch.long, device=upscaled.device)
        # Noise the conditioning image to the declared noise level.
        upscaled = scheduler.add_noise(upscaled, torch.randn_like(upscaled), noise_level)
        # Duplicated for the text and unconditional halves of the batch.
        noise_level = torch.cat([noise_level] * 2)

    # sample ts
    t = torch.randint(
        min_step,
        max_step,
        [batch_size],
        dtype=torch.long,
        device=work_device)
    # predict noise
    with torch.no_grad():
        noise = torch.randn_like(latents)
        latents_noisy = scheduler.add_noise(latents, noise, t)

        # Channel-wise concat of the noisy image and the conditioning, then
        # batch-wise duplication for the two guidance branches.
        latent_model_input = torch.cat([latents_noisy, upscaled], dim=1)
        latent_model_input = torch.cat([latent_model_input] * 2, dim=0)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        noise_pred = forward_sr_unet(
            latent_model_input,
            torch.cat(2 * [t]),
            encoder_hidden_states=torch.cat(prompt_embeddings),
            class_labels=noise_level
        )

    noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
    # Keep the first 3 output channels; the remaining ones (presumably the
    # predicted variance) are discarded.
    noise_pred_text, _ = noise_pred_text.split(3, dim=1)
    noise_pred_uncond, _ = noise_pred_uncond.split(3, dim=1)
    noise_pred = noise_pred_text + guidance_scale * (
        noise_pred_text - noise_pred_uncond
    )

    # Timestep weighting w(t) = 1 - alpha_cumprod[t].
    w = (1 - alphas.to(work_device)[t]).view(-1, 1, 1, 1)

    grad = w * (noise_pred - noise)
    grad = torch.nan_to_num(grad)

    # d/d(latents) of 0.5 * ||latents - target||^2 is (latents - target) = grad,
    # so backpropagating this loss injects `grad` without differentiating the UNet.
    target = (latents - grad).detach()
    loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
    return loss_sds

In [None]:
# Start the second optimisation from a copy of the low-resolution result.
image = torch.nn.Parameter(torch.clone(image_lr))
optimizer = torch.optim.Adam([image], lr=1e-2, weight_decay=0)

# Load the second model in fp16 without the auxiliary components.
pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-M-v1.0",
    text_encoder=None, 
    safety_checker=None, 
    watermarker=None,
    feature_extractor=None,
    requires_safety_checker=False,
    variant="fp16",
    torch_dtype=torch.float16,
).to(device)

# Rebind the globals read by `forward_sr_unet` and `get_sr_sds_loss`.
unet = pipe.unet.eval()
scheduler = pipe.scheduler
num_train_timesteps = scheduler.config.num_train_timesteps
alphas = scheduler.alphas_cumprod.to(torch.device('cuda:0'))

In [None]:
# NOTE(review): `generator` is never passed to anything in this cell, so the
# random draws inside `get_sr_sds_loss` are not controlled by this seed.
generator = torch.Generator().manual_seed(1)

try:
    for i in range(1024):
        t_min = 20
        t_max = 980
        # The fixed low-resolution result is the conditioning image.
        loss = get_sr_sds_loss(image, prompt_embeddings, t_min, t_max, original=image_lr)
        loss.backward()
        optimizer.step()
        # NOTE(review): unlike the first-stage loop, `image` is not clamped
        # after the step here; only the preview below is clamped.
        optimizer.zero_grad()
        # Show a preview every 16 steps.
        if i % 16 == 0:
            fig, ax = plt.subplots(figsize=(5, 5))
            ax.imshow(image[0].permute(1, 2, 0).clamp(0., 1.).detach().cpu())
            ax.axis('off')
            clear_output(wait=True)
            plt.show()
except KeyboardInterrupt:
    # Interrupting the cell stops the optimisation early without an error.
    pass
finally:
    # Always show the final state of the image.
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(image[0].permute(1, 2, 0).clamp(0., 1.).detach().cpu())
    ax.axis('off')
    clear_output(wait=True)
    plt.show()

# 3D

Переходим к аналогичной процедуре обучения, но в 3D

In [None]:
# Release the second-stage pipeline and UNet before reloading the first model.
del pipe
del unet
flush()
!nvidia-smi

In [None]:
# Reload the first model in fp16 without the auxiliary components.
pipe = DiffusionPipeline.from_pretrained(
    deepfloyd_model,
    text_encoder=None,
    safety_checker=None,
    watermarker=None,
    feature_extractor=None,
    requires_safety_checker=False,
    variant="fp16",
    torch_dtype=torch.float16,
    device_map="balanced"
)

# Rebind the globals read by `forward_unet` and `get_sds_loss`.
unet = pipe.unet.eval()
scheduler = pipe.scheduler
num_train_timesteps = scheduler.config.num_train_timesteps
alphas = scheduler.alphas_cumprod.to(torch.device('cuda:0'))

@torch.cuda.amp.autocast(enabled=False)
def forward_unet(latents, t, encoder_hidden_states):
    """Run the global `unet` in float16 and return its output in the input dtype."""
    input_dtype = latents.dtype
    return unet(
        latents.to(torch.float16),
        t.to(torch.float16),
        encoder_hidden_states=encoder_hidden_states.to(torch.float16),
    ).sample.to(input_dtype)

In [None]:
def get_pixel_colors(points, texture, light_source):
    """Shade textured points with ambient light plus a shadowed diffuse term.

    Args:
        points: surface points to color. Assumed to be laid out as
            (batch, height, width, 3), which is what the per-sample light
            reshaping below relies on.
        texture: callable returning the albedo for `points`.
        light_source: light position(s) holding either 3 values (one light
            shared by the whole batch) or 3 values per batch element.

    Returns:
        Shaded colors clamped to [0, 1].
    """
    albedo = texture(points)
    normals = calculate_normals(points)
    # One row of 3 coordinates per light. A single light becomes (1, 1, 1, 3)
    # and broadcasts over the batch; per-sample lights become (B, 1, 1, 3).
    light_source = light_source.reshape(-1, 1, 1, 3)
    light_dir = F.normalize(light_source - points, dim=-1)
    # Lambert's cosine law
    shadow_coefficient = (normals * light_dir).sum(-1, keepdim=True).clamp(0.)
    # Does the point see the light source?
    light_rays = {'origins': points,
                  'directions': light_dir}
    in_shadow = calculate_soft_shadow(light_rays)
    ambient_light = 0.4
    return (albedo * (ambient_light + shadow_coefficient * in_shadow)).clamp(0., 1.)

In [None]:
# Trainable texture initialised to uniform grey; the background is not trained.
matryoshka_texture = Texture(texture_grid=torch.full((1, 3, 128, 128), 0.5),
                             background_color=torch.as_tensor([0.85, 0.85, 0.85]),
                             resolution=128,
                             train_background=False).to(device)
optimizer = torch.optim.Adam(matryoshka_texture.parameters(), lr=1e-2)

In [None]:
def sample_cameras(batch_size, device):
    """Sample random cameras looking at the point (0, 0, 1) from distance 3.

    The azimuth `phi` is drawn from a normal distribution with standard
    deviation 0.7 and clamped to [-pi, pi]; the elevation `theta` is fixed
    at 0, so the "top" prompt is currently never selected.

    Args:
        batch_size: number of cameras to sample.
        device: device on which all tensors are created.

    Returns:
        Tuple `(camera_origins, camera_directions, embedding_index)` with
        shapes (batch_size, 3), (batch_size, 3) and (batch_size,). The index
        selects the directional prompt: 0 front, 1 side, 2 top, 3 backside.
    """
    # Scale first, then clamp. Clamping before scaling capped |phi| at 0.7*pi
    # and made the backside branch below unreachable.
    phi = (0.7 * torch.randn(batch_size, device=device)).clamp(-torch.pi, torch.pi)
    theta = torch.zeros(batch_size, device=device)
    camera_target = torch.as_tensor([[0., 0., 1.0]], device=device)
    # Unit view direction, pointing from the camera towards the target.
    camera_directions = -torch.stack(
        [phi.cos() * theta.cos(),
         phi.sin() * theta.cos(),
         theta.sin()], dim=1
    )
    camera_origins = camera_target - 3 * camera_directions

    # Later assignments take precedence over earlier ones.
    embedding_index = torch.zeros(batch_size, dtype=torch.int64, device=device)  # front
    embedding_index = torch.where(phi.abs() > 0.25 * torch.pi,
                                  torch.full_like(embedding_index, 1),  # side
                                  embedding_index)
    embedding_index = torch.where(phi.abs() > 0.75 * torch.pi,
                                  torch.full_like(embedding_index, 3),  # backside
                                  embedding_index)
    embedding_index = torch.where(theta > 0.25 * torch.pi,
                                  torch.full_like(embedding_index, 2),  # top
                                  embedding_index)
    return camera_origins, camera_directions, embedding_index

def sample_light(batch_size, device):
    """Sample random light positions on a circle of radius 4 at height 4.

    The azimuth is drawn from a normal distribution with standard deviation
    0.3 and clamped to [-pi, pi].

    Args:
        batch_size: number of light positions to sample.
        device: device on which the tensor is created.

    Returns:
        Flat tensor of shape (3 * batch_size,) in which every consecutive
        triple is one (x, y, z) position, so that a view as
        (batch_size, 1, 1, 3) yields one light per batch element.
    """
    phi = (0.3 * torch.randn((batch_size,), device=device)).clamp(-torch.pi, torch.pi)
    # Stack per sample. Concatenating the three component vectors gave
    # [x.., y.., z..] and mixed coordinates for batch_size > 1.
    positions = torch.stack([phi.cos(), phi.sin(), torch.ones_like(phi)], dim=1)
    return 4. * positions.reshape(-1)

In [None]:
# Sanity check: render the initial texture from one sampled camera and light.
camera_origin, camera_direction, embedding_index = sample_cameras(1, device)
light_source = sample_light(1, device)

rendered_image = render(camera_origin,
                        camera_direction,
                        get_pixel_colors,
                        matryoshka_texture,
                        light_source,
                        resolution=512)

plt.imshow(rendered_image[0].detach().cpu())
plt.axis('off');

In [None]:
def plot_texture_and_matryoshka(image, matryoshka_texture):
    """Show the first rendered image next to the current texture map.

    Args:
        image: channels-first image batch; only `image[0]` is displayed.
        matryoshka_texture: texture module exposing `matryoshka_grid`.

    NOTE(review): besides plotting, this function clamps the texture
    parameters in place to [0.05, 0.95] on every call - confirm that this
    side effect is intended.
    """
    _, (render_axis, texture_axis) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    grid = matryoshka_texture.matryoshka_grid
    grid.data = grid.data.clamp(0.05, 0.95)

    # Left panel: rendered image, converted to channels-last for imshow.
    rendered = torch.movedim(image[0], 0, -1).clamp(0., 1.).detach().cpu()
    render_axis.imshow(rendered)
    render_axis.axis('off')
    clear_output(wait=True)

    # Right panel: texture map, flipped along its second-to-last axis.
    flipped = torch.flip(grid.data.cpu(), (-2,))
    texture_axis.axis('off')
    texture_axis.imshow(torch.movedim(flipped[0], 0, -1), aspect=1.)
    plt.show()

In [None]:
# NOTE(review): `generator` is never passed to anything in this cell, so the
# sampled cameras, lights and noise are not controlled by this seed.
generator = torch.Generator().manual_seed(1)

batch_size=1

for i in range(512):
    t_min = 20
    t_max = 980
    # sample camera
    camera_origins, camera_directions, embedding_index = sample_cameras(batch_size, device)
    # Pick the view-dependent (text, unconditional) embeddings for each camera.
    batch_embeddings = [directional_prompt_embeddings[0][embedding_index],
                        directional_prompt_embeddings[1][embedding_index]]
    # sample light
    light_sources = sample_light(batch_size, device)
    # render the current texture from the sampled cameras
    image = render(camera_origins,
                   camera_directions,
                   get_pixel_colors,
                   matryoshka_texture,
                   light_sources,
                   resolution=256)
    # Move the last (color) axis to position 1, as `get_sds_loss` expects channels-first input.
    image = torch.movedim(image, -1, 1)
    loss = get_sds_loss(image, batch_embeddings, t_min, t_max, guidance_scale=5.)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Keep the texture values inside [0, 1] after every step.
    matryoshka_texture.matryoshka_grid.data = matryoshka_texture.matryoshka_grid.data.clamp(0., 1.)
    if i % 16 == 0:
        plot_texture_and_matryoshka(image, matryoshka_texture)

In [None]:
# chunk_size=1: the batched `get_pixel_colors` views the light as one
# position per batch element, while the validation light holds a single one.
images = render_validation_frames(get_pixel_colors, matryoshka_texture, chunk_size=1)
save_video(images, 'test_renderer.mp4')

Video("test_renderer.mp4", width=512, height=512, html_attributes='loop autoplay')

# Upres Matryoshka

In [None]:
# Release the first-stage pipeline and UNet before loading the second model.
del pipe
del unet
flush()
!nvidia-smi

In [None]:
# Texture built from the trained grid at the original resolution; it is
# rendered (detached) as the low-resolution conditioning in the loop below.
# NOTE(review): both textures receive the same `matryoshka_grid` object;
# whether `Texture` copies it is not visible here - confirm they do not alias.
texture_original = Texture(texture_grid=matryoshka_texture.matryoshka_grid,
                           background_color=torch.as_tensor([0.83, 0.85, 0.87]),
                           resolution=128,
                           train_background=False).to(device)

# Texture at resolution 256, initialised from the same grid; this is the one
# that is optimised next.
matryoshka_texture = Texture(texture_grid=matryoshka_texture.matryoshka_grid,
                           background_color=torch.as_tensor([0.83, 0.85, 0.87]),
                           resolution=256,
                           train_background=False).to(device)

In [None]:
# Load the second model in fp16 without the auxiliary components.
pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-II-M-v1.0",
    text_encoder=None, 
    safety_checker=None, 
    watermarker=None,
    feature_extractor=None,
    requires_safety_checker=False,
    variant="fp16",
    torch_dtype=torch.float16,
).to(device)

# Rebind the globals read by `forward_sr_unet` and `get_sr_sds_loss`.
unet = pipe.unet.eval()
scheduler = pipe.scheduler
num_train_timesteps = scheduler.config.num_train_timesteps
alphas = scheduler.alphas_cumprod.to(torch.device('cuda:0'))

In [None]:
# New optimizer for the parameters of the higher-resolution texture.
optimizer = torch.optim.Adam(matryoshka_texture.parameters(), lr=1e-2)
# NOTE(review): `generator` is never passed to anything in this cell, so the
# sampled cameras, lights and noise are not controlled by this seed.
generator = torch.Generator().manual_seed(1)

batch_size=1

for i in range(512):
    t_min = 20
    t_max = 980
    # sample camera
    camera_origins, camera_directions, embedding_index = sample_cameras(batch_size, device)
    # Pick the view-dependent (text, unconditional) embeddings for each camera.
    batch_embeddings = [directional_prompt_embeddings[0][embedding_index],
                        directional_prompt_embeddings[1][embedding_index]]
    # sample light
    light_sources = sample_light(batch_size, device)
    # Conditioning: the fixed original texture rendered at 64x64, detached
    # so that no gradient reaches it.
    original_image = render(camera_origins,
                            camera_directions,
                            get_pixel_colors,
                            texture_original,
                            light_sources,
                            resolution=64).detach()
    # The texture being optimised, rendered at 256x256 with the same camera and light.
    image = render(camera_origins,
                   camera_directions,
                   get_pixel_colors,
                   matryoshka_texture,
                   light_sources,
                   resolution=256)
    # Move the last (color) axis to position 1 for both images.
    image = torch.movedim(image, -1, 1)
    original_image = torch.movedim(original_image, -1, 1)
    loss = get_sr_sds_loss(image, batch_embeddings, t_min, t_max, guidance_scale=10., original=original_image)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Keep the texture values inside [0, 1] after every step.
    matryoshka_texture.matryoshka_grid.data = matryoshka_texture.matryoshka_grid.data.clamp(0., 1.)
    if i % 16 == 0:
        plot_texture_and_matryoshka(image, matryoshka_texture)

In [None]:
# Render the final validation video; note that it is written to the same
# file name as the earlier 3D result.
images = render_validation_frames(get_pixel_colors, matryoshka_texture, chunk_size=1)
save_video(images, 'test_renderer.mp4')

Video("test_renderer.mp4", width=512, height=512, html_attributes='loop autoplay')