In [1]:
import os

import torch
import torch.nn as nn

In [5]:
import torch

@torch.no_grad()
def compute_gradient(x, y):
    """Return the gradient of `y` with respect to `x` using `torch.autograd.grad`.

    The function body runs under `torch.no_grad()`, yet `torch.autograd.grad`
    still differentiates the graph that was recorded when `y` was computed.

    Args:
        x: tensor that `y` was computed from (must be part of `y`'s graph).
        y: tensor to differentiate; a ones-tensor of the same shape is used as
            the incoming gradient.

    Returns:
        The first element of the tuple returned by `torch.autograd.grad`.
    """
    # Seed gradient: one per element of y.
    seed = torch.ones_like(y)
    grads = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=seed,
        retain_graph=True,   # keep y's graph alive for further calls
        create_graph=True,   # record the backward pass so the result is differentiable
    )
    return grads[0]

# Build the graph at cell level, outside the no_grad-decorated function:
# x is a leaf tensor with requires_grad=True and y = x ** 2 is derived from it.
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2

# d(x ** 2)/dx at x = 2 is 4. compute_gradient passes create_graph=True, and the
# printed result carries a grad_fn.
gradient = compute_gradient(x, y)
print(gradient)  # This should print tensor([4.], grad_fn=<MulBackward0>)


tensor([4.], grad_fn=<MulBackward0>)


In [6]:
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

In [7]:
# Checkpoint identifier handed to the `from_pretrained` calls below.
model_key = "runwayml/stable-diffusion-v1-5"

In [8]:
# Load only the scheduler part of the checkpoint (subfolder="scheduler").
# NOTE(review): cache_dir is a machine-specific absolute path; adjust it before running elsewhere.
toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler", cache_dir='/home/tyk/hf_cache')

In [9]:
# Display the scheduler configuration that was just loaded.
toy_scheduler

DDIMScheduler {
  "_class_name": "DDIMScheduler",
  "_diffusers_version": "0.20.2",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "rescale_betas_zero_snr": false,
  "sample_max_value": 1.0,
  "set_alpha_to_one": false,
  "skip_prk_steps": true,
  "steps_offset": 1,
  "thresholding": false,
  "timestep_spacing": "leading",
  "trained_betas": null
}

In [10]:
# toy_scheduler.timesteps

In [11]:
# Configure the scheduler for 70 inference steps, then display it again.
# The resulting schedule itself is inspected via `toy_scheduler.timesteps` further down.
toy_scheduler.set_timesteps(70)
toy_scheduler

DDIMScheduler {
  "_class_name": "DDIMScheduler",
  "_diffusers_version": "0.20.2",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "rescale_betas_zero_snr": false,
  "sample_max_value": 1.0,
  "set_alpha_to_one": false,
  "skip_prk_steps": true,
  "steps_offset": 1,
  "thresholding": false,
  "timestep_spacing": "leading",
  "trained_betas": null
}

In [12]:
# Inspect the 70-step schedule; it is stored in descending order.
toy_scheduler.timesteps

tensor([967, 953, 939, 925, 911, 897, 883, 869, 855, 841, 827, 813, 799, 785,
        771, 757, 743, 729, 715, 701, 687, 673, 659, 645, 631, 617, 603, 589,
        575, 561, 547, 533, 519, 505, 491, 477, 463, 449, 435, 421, 407, 393,
        379, 365, 351, 337, 323, 309, 295, 281, 267, 253, 239, 225, 211, 197,
        183, 169, 155, 141, 127, 113,  99,  85,  71,  57,  43,  29,  15,   1])

In [14]:
# Flip the schedule into ascending order, the direction `ddim_inversion` iterates in.
timesteps = reversed(toy_scheduler.timesteps)
timesteps

tensor([  1,  15,  29,  43,  57,  71,  85,  99, 113, 127, 141, 155, 169, 183,
        197, 211, 225, 239, 253, 267, 281, 295, 309, 323, 337, 351, 365, 379,
        393, 407, 421, 435, 449, 463, 477, 491, 505, 519, 533, 547, 561, 575,
        589, 603, 617, 631, 645, 659, 673, 687, 701, 715, 729, 743, 757, 771,
        785, 799, 813, 827, 841, 855, 869, 883, 897, 911, 925, 939, 953, 967])

In [15]:
# Slicing from index 0 returns the whole schedule.
toy_scheduler.timesteps[0:]

tensor([967, 953, 939, 925, 911, 897, 883, 869, 855, 841, 827, 813, 799, 785,
        771, 757, 743, 729, 715, 701, 687, 673, 659, 645, 631, 617, 603, 589,
        575, 561, 547, 533, 519, 505, 491, 477, 463, 449, 435, 421, 407, 393,
        379, 365, 351, 337, 323, 309, 295, 281, 267, 253, 239, 225, 211, 197,
        183, 169, 155, 141, 127, 113,  99,  85,  71,  57,  43,  29,  15,   1])

In [17]:
# Same expression as in `get_timesteps`, with num_inference_steps=70 and strength=1.0.
init_timestep = min(int(70 * 1.0), 70)
init_timestep

70

In [19]:
# Index of the first schedule entry that is kept; never negative because of the max().
t_start = max(70 - init_timestep, 0)
t_start

0

In [None]:
def get_timesteps(scheduler, num_inference_steps, strength, device):
    """Select the tail of the scheduler's timesteps that corresponds to `strength`.

    Args:
        scheduler: object exposing a sliceable `timesteps` attribute.
        num_inference_steps: number of steps the schedule was configured with.
        strength: fraction of the schedule to keep; `num_inference_steps * strength`
            is truncated to an int and capped at `num_inference_steps`.
        device: accepted for interface compatibility; not used by this function.

    Returns:
        A pair `(timesteps, steps_kept)`: `scheduler.timesteps` from the first kept
        index onwards, and `num_inference_steps` minus that index.
    """
    # How many steps to keep, never more than the full schedule.
    steps_to_keep = int(num_inference_steps * strength)
    if steps_to_keep > num_inference_steps:
        steps_to_keep = num_inference_steps

    # First index that is kept, clamped at zero.
    first_index = num_inference_steps - steps_to_keep
    if first_index < 0:
        first_index = 0

    kept_timesteps = scheduler.timesteps[first_index:]
    return kept_timesteps, num_inference_steps - first_index

In [3]:
# Sample values for trying out the frame filename pattern used in Preprocess.get_data.
frames_path = 'test'
n_frames = 70
# [f"{frames_path}/%05d.png" % i for i in range(n_frames)]

In [17]:
# Reuse the ascending schedule as the collection of timesteps to save.
timesteps_to_save =  timesteps

In [18]:
# Walk the ascending schedule, printing each position together with its timestep.
for i, t in enumerate(timesteps):
    print(i, t)

0 tensor(1)
1 tensor(15)
2 tensor(29)
3 tensor(43)
4 tensor(57)
5 tensor(71)
6 tensor(85)
7 tensor(99)
8 tensor(113)
9 tensor(127)
10 tensor(141)
11 tensor(155)
12 tensor(169)
13 tensor(183)
14 tensor(197)
15 tensor(211)
16 tensor(225)
17 tensor(239)
18 tensor(253)
19 tensor(267)
20 tensor(281)
21 tensor(295)
22 tensor(309)
23 tensor(323)
24 tensor(337)
25 tensor(351)
26 tensor(365)
27 tensor(379)
28 tensor(393)
29 tensor(407)
30 tensor(421)
31 tensor(435)
32 tensor(449)
33 tensor(463)
34 tensor(477)
35 tensor(491)
36 tensor(505)
37 tensor(519)
38 tensor(533)
39 tensor(547)
40 tensor(561)
41 tensor(575)
42 tensor(589)
43 tensor(603)
44 tensor(617)
45 tensor(631)
46 tensor(645)
47 tensor(659)
48 tensor(673)
49 tensor(687)
50 tensor(701)
51 tensor(715)
52 tensor(729)
53 tensor(743)
54 tensor(757)
55 tensor(771)
56 tensor(785)
57 tensor(799)
58 tensor(813)
59 tensor(827)
60 tensor(841)
61 tensor(855)
62 tensor(869)
63 tensor(883)
64 tensor(897)
65 tensor(911)
66 tensor(925)
67 tensor(939)

In [None]:
import torch

class Preprocess(nn.Module):
    """Load Stable Diffusion components and run DDIM inversion over a set of frames.

    The constructor loads the VAE, tokenizer, text encoder, UNet and DDIM scheduler,
    then reads the frames from disk and encodes them to latents. `extract_latents`
    inverts those latents, samples them back and returns the decoded images.

    NOTE(review): the class uses names that are not imported in the visible
    notebook cells: `Image` (presumably `PIL.Image`), `T` (presumably
    `torchvision.transforms`) and `tqdm`. They must be in scope before `get_data`
    or `ddim_sample` run. The 'ControlNet' branch also expects
    `self.controlnet_pred` and `self.canny_cond`, which are not defined in this
    class as shown.
    """

    # Factor the VAE latents are multiplied by after encoding and divided by before decoding.
    LATENT_SCALE = 0.18215

    def __init__(self, device, opt, hf_key=None):
        """Load all model components and the frame data.

        Args:
            device: torch device the models and tensors are moved to.
            opt: options object; `sd_version`, `data_path` and `n_frames` are read,
                plus an optional `cache_dir` for the downloaded weights.
            hf_key: optional checkpoint identifier. When None, the notebook-level
                `model_key` global is used.
        """
        super().__init__()

        self.device = device
        self.sd_version = opt.sd_version
        self.use_depth = False

        # `hf_key` overrides the notebook-level default checkpoint when provided.
        checkpoint = hf_key if hf_key is not None else model_key
        self.model_key = checkpoint
        # None lets the libraries fall back to their default cache location.
        cache_dir = getattr(opt, 'cache_dir', None)

        self.vae = AutoencoderKL.from_pretrained(checkpoint, subfolder="vae", revision="fp16",
                                                 torch_dtype=torch.float16,
                                                 cache_dir=cache_dir).to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(checkpoint, subfolder="tokenizer",
                                                       cache_dir=cache_dir)
        self.text_encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder="text_encoder", revision="fp16",
                                                          torch_dtype=torch.float16,
                                                          cache_dir=cache_dir).to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder="unet", revision="fp16",
                                                         torch_dtype=torch.float16,
                                                         cache_dir=cache_dir).to(self.device)

        self.scheduler = DDIMScheduler.from_pretrained(checkpoint, subfolder="scheduler",
                                                       cache_dir=cache_dir)

        self.paths, self.frames, self.latents = self.get_data(opt.data_path, opt.n_frames)

    def get_data(self, frames_path, n_frames):
        """Read `n_frames` images from `frames_path` and encode them to latents.

        Files are expected to be named `00000.png`, `00001.png`, ...; if the first
        `.png` does not exist, the same names with `.jpg` are used instead.

        Returns:
            `(paths, frames, latents)`: the file paths, the stacked float16 image
            tensor on `self.device`, and the float16 latents from `encode_imgs`.
        """
        # Directory holding the frames: one path per index 0 .. n_frames - 1.
        paths = [f'{frames_path}/%05d.png' % i for i in range(n_frames)]
        if not os.path.exists(paths[0]):
            paths = [f"{frames_path}/%05d.jpg" % i for i in range(n_frames)]
        self.paths = paths

        # Open every image as RGB.
        frames = [Image.open(path).convert('RGB') for path in paths]
        if frames[0].size[0] == frames[0].size[1]:
            # Square input: resize every frame to 512x512.
            frames = [frame.resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in frames]
        frames = torch.stack([T.ToTensor()(frame) for frame in frames]).to(torch.float16).to(self.device)

        latents = self.encode_imgs(frames, deterministic=True).to(torch.float16).to(self.device)
        return paths, frames, latents

    @torch.no_grad()
    def get_text_embeds(self, prompt=''):
        """Encode `prompt` with the tokenizer and text encoder.

        Returns:
            The first element of the text encoder output for the tokenized prompt
            (expected to be shaped (1, seq_len, hidden) for a single string
            prompt — TODO confirm against the loaded encoder).
        """
        tokens = self.tokenizer(prompt, padding='max_length',
                                max_length=self.tokenizer.model_max_length,
                                truncation=True, return_tensors='pt')
        return self.text_encoder(tokens.input_ids.to(self.device))[0]

    @torch.no_grad()
    def encode_imgs(self, imgs, batch_size=10, deterministic=True):
        """Encode images to scaled VAE latents, `batch_size` images at a time.

        Args:
            imgs: image tensor; it is mapped with `2 * imgs - 1` before encoding.
            batch_size: number of images passed to the VAE per call.
            deterministic: use the posterior mean when True, a sample otherwise.
        """
        imgs = 2 * imgs - 1
        latents = []
        for i in range(0, len(imgs), batch_size):
            # Feed one batch through the VAE encoder.
            posterior = self.vae.encode(imgs[i:i + batch_size]).latent_dist
            latent = posterior.mean if deterministic else posterior.sample()
            latents.append(latent * self.LATENT_SCALE)
        latents = torch.cat(latents)
        return latents

    @torch.no_grad()
    def extract_latents(self,
                        num_steps,
                        save_path,
                        batch_size,
                        timesteps_to_save,
                        inversion_prompt=''):
        """Invert the stored latents, sample them back and decode the result.

        `self.latents` is passed to `ddim_inversion` and `ddim_sample`, both of
        which write into their input tensor, so it is overwritten in place.

        Args:
            num_steps: number of scheduler steps to configure.
            save_path: directory under which `latents/` files are written.
            batch_size: number of frames processed per model call.
            timesteps_to_save: timesteps at which intermediate latents are saved.
            inversion_prompt: text prompt used to condition the model.

        Returns:
            The decoded image tensor produced by `decode_latents`.
        """
        self.scheduler.set_timesteps(num_steps)
        cond = self.get_text_embeds(inversion_prompt)
        latent_frames = self.latents

        inverted_x = self.ddim_inversion(cond,
                                         latent_frames,
                                         save_path,
                                         batch_size=batch_size,
                                         save_latents=True,
                                         timesteps_to_save=timesteps_to_save)
        latent_reconstruction = self.ddim_sample(inverted_x, cond, batch_size=batch_size)

        rgb_reconstruction = self.decode_latents(latent_reconstruction)
        return rgb_reconstruction

    def _predict_noise(self, x_batch, t, cond_batch, start, batch_size):
        """Predict the noise for one batch with the UNet, or the ControlNet path."""
        if self.sd_version != 'ControlNet':
            return self.unet(x_batch, t, encoder_hidden_states=cond_batch).sample
        # NOTE(review): controlnet_pred / canny_cond are not defined in this class as shown.
        return self.controlnet_pred(x_batch, t, cond_batch,
                                    torch.cat([self.canny_cond[start: start + batch_size]]))

    @torch.no_grad()
    def ddim_sample(self, x, cond, batch_size):
        """Run the scheduler's timesteps in their stored order, updating `x` in place.

        Args:
            x: latent tensor; its slices are overwritten batch by batch.
            cond: conditioning tensor, repeated along dim 0 to the batch size.
            batch_size: number of latents processed per model call.

        Returns:
            The same tensor `x` after the last timestep.
        """
        timesteps = self.scheduler.timesteps
        for i, t in enumerate(tqdm(timesteps)):
            # Coefficients depend only on the timestep, so compute them once per step.
            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (self.scheduler.alphas_cumprod[timesteps[i + 1]]
                                 if i < len(timesteps) - 1
                                 else self.scheduler.final_alpha_cumprod)

            mu = alpha_prod_t ** 0.5
            sigma = (1 - alpha_prod_t) ** 0.5
            mu_prev = alpha_prod_t_prev ** 0.5
            sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

            for b in range(0, x.shape[0], batch_size):
                x_batch = x[b:b + batch_size]
                cond_batch = cond.repeat(x_batch.shape[0], 1, 1)

                eps = self._predict_noise(x_batch, t, cond_batch, b, batch_size)

                pred_x0 = (x_batch - sigma * eps) / mu
                x[b:b + batch_size] = mu_prev * pred_x0 + sigma_prev * eps

        return x

    @torch.no_grad()
    def decode_latents(self, latents, batch_size=8):
        """Decode latents to images clamped to [0, 1], `batch_size` latents at a time."""
        decoded = []
        for b in range(0, latents.shape[0], batch_size):
            latents_batch = 1 / self.LATENT_SCALE * latents[b:b + batch_size]
            imgs = self.vae.decode(latents_batch).sample
            imgs = (imgs / 2 + 0.5).clamp(0, 1)
            decoded.append(imgs)
        return torch.cat(decoded)

    @torch.no_grad()
    def ddim_inversion(self, cond, latent_frames, save_path, batch_size,
                       save_latents=True, timesteps_to_save=None):
        """Run the scheduler's timesteps in reversed order, updating the latents in place.

        Args:
            cond: conditioning tensor, repeated along dim 0 to the batch size.
            latent_frames: latent tensor; its slices are overwritten batch by batch.
            save_path: directory under which `latents/noisy_latents_{t}.pt` is written.
            batch_size: number of latents processed per model call.
            save_latents: when True, save the latents after each timestep that is
                contained in `timesteps_to_save`.
            timesteps_to_save: timesteps to save at; defaults to every timestep.

        Returns:
            The same tensor `latent_frames` after the last timestep. The final
            latents are always saved, regardless of `save_latents`.
        """
        timesteps = reversed(self.scheduler.timesteps)
        timesteps_to_save = timesteps_to_save if timesteps_to_save is not None else timesteps

        # torch.save does not create missing directories.
        latents_dir = os.path.join(save_path, 'latents')
        os.makedirs(latents_dir, exist_ok=True)

        for i, t in enumerate(timesteps):
            # Timesteps are visited in ascending order here.
            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (self.scheduler.alphas_cumprod[timesteps[i - 1]]
                                 if i > 0 else self.scheduler.final_alpha_cumprod)

            mu = alpha_prod_t ** 0.5
            mu_prev = alpha_prod_t_prev ** 0.5
            sigma = (1 - alpha_prod_t) ** 0.5
            sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

            for b in range(0, latent_frames.shape[0], batch_size):
                x_batch = latent_frames[b:b + batch_size]
                cond_batch = cond.repeat(x_batch.shape[0], 1, 1)

                eps = self._predict_noise(x_batch, t, cond_batch, b, batch_size)

                pred_x0 = (x_batch - sigma_prev * eps) / mu_prev
                latent_frames[b:b + batch_size] = mu * pred_x0 + sigma * eps

            if save_latents and t in timesteps_to_save:
                torch.save(latent_frames, os.path.join(latents_dir, f'noisy_latents_{t}.pt'))
        torch.save(latent_frames, os.path.join(latents_dir, f'noisy_latents_{t}.pt'))
        return latent_frames