In [1]:
%%capture
!pip install diffusers

In [2]:
%%capture
!pip install accelerate

In [3]:
import os
import tqdm
import torch
import numpy as np

from PIL import Image
from tqdm.auto import tqdm
from google.colab import drive

from diffusers import UNet2DModel
from diffusers import ScoreSdeVePipeline
from diffusers import ScoreSdeVeScheduler
from diffusers.utils import randn_tensor

from accelerate import Accelerator
from accelerate import notebook_launcher

# Fix PyTorch's global RNG seed so the noise drawn later in the notebook is repeatable.
SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fbb6c0eb5f0>

In [4]:
# Mount Google Drive under /content/gdrive; force_remount=True remounts even if it
# is already mounted. display_sample writes its images below this mount point.
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [5]:
# Load the pretrained score network and its matching VE-SDE scheduler from the Hub,
# then bundle them into a pipeline object.
model_id = "google/ncsnpp-celebahq-256"
model = UNet2DModel.from_pretrained(model_id)
# Schedulers are loaded from a Hub id with `from_pretrained`; passing a model id to
# `from_config` is the deprecated path (`from_config` is meant for config dicts).
sde_scheduler = ScoreSdeVeScheduler.from_pretrained(model_id)
sde_pipeline = ScoreSdeVePipeline(unet=model, scheduler=sde_scheduler)
# Last expression: moves the pipeline to the GPU and shows its repr as the cell output.
sde_pipeline.to("cuda")

Downloading (…)_pytorch_model.bin";:   0%|          | 0.00/263M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/900 [00:00<?, ?B/s]

Downloading (…)cheduler_config.json:   0%|          | 0.00/211 [00:00<?, ?B/s]

ScoreSdeVePipeline {
  "_class_name": "ScoreSdeVePipeline",
  "_diffusers_version": "0.12.1",
  "scheduler": [
    "diffusers",
    "ScoreSdeVeScheduler"
  ],
  "unet": [
    "diffusers",
    "UNet2DModel"
  ]
}

In [6]:
# Show the module tree of the loaded U-Net as the cell output.
model

UNet2DModel(
  (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): GaussianFourierProjection()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=256, out_features=512, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=512, out_features=512, bias=True)
  )
  (down_blocks): ModuleList(
    (0): SkipDownBlock2D(
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=512, out_features=128, bias=True)
          (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
        (1): ResnetBlock2D(
          (norm1): GroupNorm(32, 128, eps=1e-06, affin

In [7]:
# Show the U-Net's configuration (sample size, channel counts, block types, ...).
model.config

FrozenDict([('sample_size', 256),
            ('in_channels', 3),
            ('out_channels', 3),
            ('center_input_sample', True),
            ('time_embedding_type', 'fourier'),
            ('freq_shift', 0),
            ('flip_sin_to_cos', True),
            ('down_block_types',
             ['SkipDownBlock2D',
              'SkipDownBlock2D',
              'SkipDownBlock2D',
              'SkipDownBlock2D',
              'AttnSkipDownBlock2D',
              'SkipDownBlock2D',
              'SkipDownBlock2D']),
            ('up_block_types',
             ['SkipUpBlock2D',
              'SkipUpBlock2D',
              'AttnSkipUpBlock2D',
              'SkipUpBlock2D',
              'SkipUpBlock2D',
              'SkipUpBlock2D',
              'SkipUpBlock2D']),
            ('block_out_channels', [128, 128, 256, 256, 256, 256, 256]),
            ('layers_per_block', 2),
            ('mid_block_scale_factor', 1.41421356237),
            ('downsample_padding', 1),
            

In [8]:
# Show the scheduler's configuration (number of timesteps, sigma range, corrector steps, ...).
sde_scheduler.config

FrozenDict([('num_train_timesteps', 2000),
            ('snr', 0.075),
            ('sigma_min', 0.01),
            ('sigma_max', 380),
            ('sampling_eps', 1e-05),
            ('correct_steps', 1),
            ('_class_name', 'ScoreSdeVeScheduler'),
            ('_diffusers_version', '0.1.1')])

In [9]:
# Draw a single standard-normal tensor sized from the model config:
# one item, `in_channels` channels, and a square `sample_size` x `sample_size` grid.
n_channels = model.config.in_channels
side = model.config.sample_size
noisy_sample = torch.randn(1, n_channels, side, side)
# Last expression: display the resulting shape.
noisy_sample.shape

torch.Size([1, 3, 256, 256])

In [10]:
def display_sample(sample, i, output_dir="/content/gdrive/MyDrive/sde_inference_steps/"):
    """Save and display the first image of a batch of generated samples.

    Args:
        sample: Tensor indexed as (batch, channels, height, width). Values are
            clamped to [0, 1] before conversion, so they are treated as already
            living on a 0..1 scale (not -1..1).
        i: Step number; used as the file name (``"<i>.jpg"``) and in the caption.
        output_dir: Directory the JPEG is written to. Created if it does not exist.

    Side effects:
        Writes ``<output_dir>/<i>.jpg`` and renders a caption plus the image via
        IPython's ``display``.
    """
    # Use the tensor that was passed in. Reading a notebook-global here would
    # silently ignore the argument and fail if that global is not defined yet.
    sample = sample.detach().clamp(0, 1)

    # (B, C, H, W) -> (B, H, W, C), which is the layout PIL expects per image.
    image = sample.cpu().permute(0, 2, 3, 1)
    # Scale 0..1 floats to 0..255 and truncate to 8-bit pixel values.
    image = image * 255
    image = image.numpy().astype(np.uint8)
    # Only the first element of the batch is rendered.
    image = Image.fromarray(image[0])

    # Ensure the target folder exists so the save below cannot fail on a fresh Drive.
    os.makedirs(output_dir, exist_ok=True)
    image_name = str(i) + ".jpg"
    path = os.path.join(output_dir, image_name)
    image.save(path)

    display(f"Image at step {i}")
    display(image)

In [11]:
# Put the network and the test noise tensor on the GPU.
# NOTE(review): `sde_pipeline` was built around this same `model` object and was
# already moved to "cuda", so the first call is presumably redundant — confirm.
model.to("cuda")
noisy_sample = noisy_sample.to("cuda")

In [12]:
# Shape of the tensor to sample: one item, sized from the loaded model's config
# rather than hard-coded literals, so it cannot drift from the checkpoint.
# For the checkpoint loaded above this evaluates to (1, 3, 256, 256).
shape = (1, model.config.in_channels, model.config.sample_size, model.config.sample_size)

In [None]:
# First entry of `shape`, i.e. the batch dimension.
shape[0]

1

In [13]:
# Starting point for sampling: standard-normal noise of the requested shape,
# scaled by the scheduler's initial noise sigma and then moved to the GPU.
initial_noise = randn_tensor(shape)
sample = (initial_noise * sde_scheduler.init_noise_sigma).to("cuda")

In [14]:
# Sanity-check the initial noise: entries are standard-normal draws multiplied by
# sde_scheduler.init_noise_sigma, so magnitudes in the hundreds are expected here.
print(sample)

tensor([[[[ 3.1410e+02, -9.1769e+01,  5.8462e+02,  ...,  1.9113e+02,
            8.8317e+02, -1.3143e+01],
          [ 2.4127e+02,  2.2784e+02, -3.9474e+01,  ...,  1.1346e+02,
            2.7603e+02, -8.6217e+01],
          [-1.8849e+02, -4.2146e+02, -1.8038e+02,  ...,  2.2226e+02,
            3.8622e+02,  1.1107e+02],
          ...,
          [ 9.0120e+01,  3.9847e+02, -4.3862e+02,  ...,  1.8115e+02,
            8.4654e+02,  2.9201e+02],
          [ 5.3998e+02, -2.6542e+02, -1.1382e+02,  ...,  2.3108e+02,
           -6.6965e+02, -8.8518e+02],
          [ 1.2579e+02, -3.9534e+02, -2.5669e+02,  ...,  1.6949e+02,
           -1.3026e+03,  2.9585e+02]],

         [[-9.2616e+01, -4.9523e+02, -8.6889e+01,  ...,  8.4802e+02,
            2.1607e+02, -1.1474e+01],
          [-3.2581e+02,  3.3974e+02, -1.5246e+02,  ...,  2.5194e+02,
            5.5711e+02,  4.3396e+02],
          [ 6.0067e+01, -3.3274e+02,  2.9332e+02,  ..., -2.5253e+02,
           -7.6968e+02, -4.4433e+02],
          ...,
     

In [15]:
# Confirm the scaled noise tensor has the shape that was requested.
sample.shape

torch.Size([1, 3, 256, 256])

In [16]:
# Build an Accelerate context configured for fp16 mixed precision with no
# gradient accumulation, and register a tracking run from the main process only.
# NOTE(review): `accelerator` is not referenced by the sampling cell visible in this
# notebook — confirm whether it is still needed.
accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision="fp16")

if accelerator.is_main_process:
    accelerator.init_trackers("train_example")

In [17]:
# Predictor-corrector sampling: at every scheduler timestep, run the configured
# number of corrector updates, then one predictor update, and periodically
# save/show an intermediate image.
SNAPSHOT_EVERY = 100  # save/display an intermediate image every N scheduler steps

# Inference only: no gradients are needed anywhere in the loop.
with torch.no_grad():
    for i, t in enumerate(tqdm(sde_scheduler.timesteps)):
        # Noise level of this step, repeated once per batch element and placed on
        # the same device as the sample being refined.
        sigma_t = sde_scheduler.sigmas[i] * torch.ones(sample.shape[0], device=sample.device)

        # correction step(s)
        for _ in range(sde_scheduler.config.correct_steps):
            model_output = model(sample, sigma_t).sample
            sample = sde_scheduler.step_correct(model_output, sample).prev_sample

        # prediction step
        model_output = model(sample, sigma_t).sample
        output = sde_scheduler.step_pred(model_output, t, sample)
        sample, sample_mean = output.prev_sample, output.prev_sample_mean

        if (i + 1) % SNAPSHOT_EVERY == 0:
            # Hand over the step's mean explicitly (rather than the noisy `sample`),
            # so the displayed tensor is the one named at the call site.
            display_sample(sample_mean, i + 1)

# After the loop, the final image can be obtained as an array with:
#   sample_mean.clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()

Output hidden; open in https://colab.research.google.com to view.