In [1]:
# !pip install --upgrade calflops

In [2]:
import sys
sys.path.append('..')
import model_loader
from transformers import CLIPTokenizer

import torch

DEVICE = "cpu"

ALLOW_CUDA = True
ALLOW_MPS = False

if torch.cuda.is_available() and ALLOW_CUDA:
    DEVICE = "cuda"
elif (torch.has_mps or torch.backends.mps.is_available()) and ALLOW_MPS:
    DEVICE = "mps"
print(f"Using device: {DEVICE}")

tokenizer = CLIPTokenizer("../../data/vocab.json", merges_file="../../data/merges.txt")
model_file = "../../data/v1-5-pruned-emaonly.ckpt"
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)


Using device: cuda


2024-08-15 19:19:25.928010: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
models.keys()

dict_keys(['clip', 'encoder', 'decoder', 'diffusion'])

In [4]:
encoder = models["encoder"].to(DEVICE)
decoder = models["decoder"].to(DEVICE)
clip = models["clip"].to(DEVICE)
diffusion = models["diffusion"].to(DEVICE)

In [5]:
type(encoder), type(decoder), type(clip), type(diffusion)

(encoder.VAE_Encoder, decoder.VAE_Decoder, clip.CLIP, diffusion.Diffusion)

In [6]:
from PIL import Image

prompt = "A trafic sign on a beautiful beach."
uncond_prompt = ""  # Also known as negative prompt
do_cfg = True
cfg_scale = 8  # min: 1, max: 14

## IMAGE TO IMAGE

# Comment to disable image to image
image_path = "../../images/dog.png"
input_image = Image.open(image_path)

In [7]:
from utils import rescale, get_time_embedding
from config import WIDTH, HEIGHT, LATENTS_WIDTH, LATENTS_HEIGHT
from ddpm import DDPMSampler
import numpy as np

In [8]:
cond_tokens = tokenizer.batch_encode_plus(
    [prompt], padding="max_length", max_length=77
)["input_ids"]
# (Batch_Size, Seq_Len)
cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=DEVICE)
# (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
cond_context = clip(cond_tokens)
# Convert into a list of length Seq_Len=77
uncond_tokens = tokenizer.batch_encode_plus(
    [uncond_prompt], padding="max_length", max_length=77
)["input_ids"]
# (Batch_Size, Seq_Len)
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=DEVICE)
# (Batch_Size, Seq_Len) -> (Batch_Size, Seq_Len, Dim)
uncond_context = clip(uncond_tokens)
# (Batch_Size, Seq_Len, Dim) + (Batch_Size, Seq_Len, Dim) -> (2 * Batch_Size, Seq_Len, Dim)
context = torch.cat([cond_context, uncond_context])

timestep = 1 
time_embedding = get_time_embedding(timestep).to(DEVICE)


In [9]:
generator = torch.Generator(device=DEVICE)
strength = 0.48
n_inference_steps = 50
sampler = DDPMSampler(generator)
sampler.set_inference_timesteps(n_inference_steps)

latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)

encoder = encoder
encoder.to(DEVICE)
input_image_tensor = input_image.resize((WIDTH, HEIGHT))
input_image_tensor = np.array(input_image_tensor)
input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=DEVICE)
input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
input_image_tensor = input_image_tensor.unsqueeze(0)
input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)

encoder_noise = torch.randn(latents_shape, generator=generator, device=DEVICE)
latents = encoder(input_image_tensor, encoder_noise)

sampler.set_strength(strength=strength)
latents = sampler.add_noise(latents, sampler.timesteps[0])

In [10]:
from calflops import calculate_flops
batch_size = 1
max_seq_length = 77
flops, macs, params = calculate_flops(model=clip, 
                                    input_shape=(batch_size, max_seq_length),
                                    transformer_tokenizer=tokenizer)

print(f" CLIPTokenizer FLOPS: {flops}, MACs: {macs}, Params: {params}")


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.



------------------------------------- Calculate Flops Results -------------------------------------
Notations:
number of parameters (Params), number of multiply-accumulate operations(MACs),
number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),
fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),
default model backpropagation takes 2.00 times as much computation as forward propagation.

Total Training Params:                                                  123.06 M
fwd MACs:                                                               6.65 GMACs
fwd FLOPs:                                                              13.31 GFLOPS
fwd+bwd MACs:                                                           19.95 GMACs
fwd+bwd FLOPs:                                                          39.92 GFLOPS

-------------------------------- Detailed Calculated FLOPs Results --------------------------------
Each module cacu

In [11]:
encoder = encoder
batch_size = 1
flops, macs, params = calculate_flops(model=encoder, 
                                    args = [input_image_tensor, encoder_noise],
                                    output_as_string=True,
                                    output_precision=4)

print("Encoder FLOPs:%s   MACs:%s   Params:%s \n" %(flops, macs, params))



------------------------------------- Calculate Flops Results -------------------------------------
Notations:
number of parameters (Params), number of multiply-accumulate operations(MACs),
number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),
fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),
default model backpropagation takes 2.00 times as much computation as forward propagation.

Total Training Params:                                                  34.16 M 
fwd MACs:                                                               558.329 GMACs
fwd FLOPs:                                                              1.1185 TFLOPS
fwd+bwd MACs:                                                           1.675 TMACs
fwd+bwd FLOPs:                                                          3.3554 TFLOPS

-------------------------------- Detailed Calculated FLOPs Results --------------------------------
Each module

In [12]:
decoder = decoder
decoder.to(DEVICE)

flops, macs, params = calculate_flops(model=decoder, 
                                      args = [latents],
                                      output_as_string=True,
                                      output_precision=4)

print("Decoder FLOPs:%s   MACs:%s   Params:%s \n" %(flops, macs, params))


------------------------------------- Calculate Flops Results -------------------------------------
Notations:
number of parameters (Params), number of multiply-accumulate operations(MACs),
number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),
fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),
default model backpropagation takes 2.00 times as much computation as forward propagation.

Total Training Params:                                                  49.49 M 
fwd MACs:                                                               1.2573 TMACs
fwd FLOPs:                                                              2.5179 TFLOPS
fwd+bwd MACs:                                                           3.7718 TMACs
fwd+bwd FLOPs:                                                          7.5536 TFLOPS

-------------------------------- Detailed Calculated FLOPs Results --------------------------------
Each module

In [13]:
diffusion = diffusion
diffusion.to(DEVICE)

batch_size = 1
input_shape = (batch_size, 4, LATENTS_HEIGHT, LATENTS_WIDTH)  # Input shape for the diffusion
flops, macs, params = calculate_flops(model=diffusion,
                                      args=[latents, context, time_embedding],
                                      output_as_string=True,
                                      output_precision=4)

print("Diffusion FLOPs:%s   MACs:%s   Params:%s \n" %(flops, macs, params))


------------------------------------- Calculate Flops Results -------------------------------------
Notations:
number of parameters (Params), number of multiply-accumulate operations(MACs),
number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),
fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),
default model backpropagation takes 2.00 times as much computation as forward propagation.

Total Training Params:                                                  859.52 M
fwd MACs:                                                               401.334 GMACs
fwd FLOPs:                                                              803.958 GFLOPS
fwd+bwd MACs:                                                           1.204 TMACs
fwd+bwd FLOPs:                                                          2.4119 TFLOPS

-------------------------------- Detailed Calculated FLOPs Results --------------------------------
Each modul