In [1]:
# Notebook setup: enable IPython's autoreload extension in mode 2 so edits to
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import os
# Switch the working directory to /src/; relative paths used later in the
# notebook (e.g. "diffusers-cache") are resolved against it.
# NOTE(review): absolute path is environment-specific — presumably the container's source root; confirm.
os.chdir('/src/')

In [2]:
# !pip install tensorizer

Initialize the pipeline the standard way:

In [3]:
import os
from typing import List

import torch
from cog import BasePredictor, Input, Path
from diffusers import (
    StableDiffusionPipeline,
    PNDMScheduler,
    LMSDiscreteScheduler,
    DDIMScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)

import time
import pprofile

from diffusers.pipeline_utils import DiffusionPipeline

import time
import torch
from tensorizer import TensorDeserializer
from tensorizer.utils import no_init_or_tensor, convert_bytes, get_mem_usage

import diffusers
import transformers

  from .autonotebook import tqdm as notebook_tqdm


In [3]:


# Hugging Face identifiers and the local cache directory they are read from.
MODEL_ID = "stabilityai/stable-diffusion-2-1"
MODEL_CACHE = "diffusers-cache"
SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"

# Both loaders read exclusively from the local cache (no network access).
local_cache_kwargs = dict(cache_dir=MODEL_CACHE, local_files_only=True)

# Baseline measurement: wall-clock time of the regular from_pretrained path,
# including the transfer of the pipeline to the GPU.
st = time.time()

safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    SAFETY_MODEL_ID, **local_cache_kwargs
)

pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID, safety_checker=safety_checker, **local_cache_kwargs
)
pipe = pipe.to("cuda")

print(f'Initialized pipeline in {time.time() - st}')

  from .autonotebook import tqdm as notebook_tqdm


Initialized pipeline in 9.377107620239258


# Prepare and load tensorized model weights

Serialize nn.Module components for faster loading:

In [None]:
# TensorSerializer is not imported by any earlier cell, so bring it into scope
# here; without it this cell fails with NameError on a fresh kernel.
from tensorizer import TensorSerializer

# Write every torch.nn.Module component of the pipeline to
# "<base path>/<component name>.tensors" and remember where each one went.
# Non-module components (tokenizer, scheduler, feature extractor) are skipped.
tensorized_weights_base_path = "diffusers-cache/"
component_map = {}
for k, component in pipe.components.items():
    if not isinstance(component, torch.nn.Module):
        continue

    path = os.path.join(tensorized_weights_base_path, f"{k}.tensors")
    serializer = TensorSerializer(path)
    try:
        serializer.write_module(component)
    finally:
        # Always release the output file, even if serialization fails midway.
        serializer.close()

    # 'path' here is the location of the .tensors file for this component.
    component_map[k] = {'path': path, 'cls': type(component)}

In [13]:
# How to load just the safety checker, for example.

import time
import torch
from tensorizer import TensorDeserializer
from tensorizer.utils import no_init_or_tensor, convert_bytes, get_mem_usage

# Location of the serialized safety-checker weights.
path = './diffusers-cache/safety_checker.tensors'

# Build the model object inside no_init_or_tensor(); the real weights are
# filled in from the .tensors file below.
with no_init_or_tensor():
    model = StableDiffusionSafetyChecker.from_pretrained(
        SAFETY_MODEL_ID,
        cache_dir=MODEL_CACHE,
        local_files_only=True,
    )

deserializer = TensorDeserializer(path, plaid_mode=True)
try:
    num_tensors_loaded = deserializer.load_into_module(model)
finally:
    # Release the underlying file handle once the weights are in the module.
    deserializer.close()

# Keep the loaded-tensor count as the cell's displayed result.
num_tensors_loaded

397

In [53]:
# Start from an empty map, then show each pipeline component's name next to
# its concrete class (used to fill in the hand-written map further down).
component_map = dict()
for name, component in pipe.components.items():
    print(name, type(component))

vae <class 'diffusers.models.vae.AutoencoderKL'>
text_encoder <class 'transformers.models.clip.modeling_clip.CLIPTextModel'>
tokenizer <class 'transformers.models.clip.tokenization_clip.CLIPTokenizer'>
unet <class 'diffusers.models.unet_2d_condition.UNet2DConditionModel'>
scheduler <class 'diffusers.schedulers.scheduling_ddim.DDIMScheduler'>
safety_checker <class 'diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker'>
feature_extractor <class 'transformers.models.clip.image_processing_clip.CLIPImageProcessor'>


In [2]:
import time
import torch
from tensorizer import TensorDeserializer
from tensorizer.utils import no_init_or_tensor, convert_bytes, get_mem_usage

import diffusers
import transformers


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Per-component loading recipe for the pipeline:
#   'cls'                - class whose from_pretrained() builds the component
#   'path'               - local snapshot directory passed to from_pretrained()
#   'tensorized_weights' - optional .tensors file holding the serialized weights;
#                          entries without it are loaded by from_pretrained() alone
component_map = {
    # Components with tensorized weights
    'vae': {
        'tensorized_weights': 'diffusers-cache/vae.tensors',
        'path': 'diffusers-cache/models--stabilityai--stable-diffusion-2-1/snapshots/845609e6cf0a060d8cd837297e5c169df5bff72c/vae',
        'cls': diffusers.models.vae.AutoencoderKL,
    },
    'text_encoder': {
        'tensorized_weights': 'diffusers-cache/text_encoder.tensors',
        'path': 'diffusers-cache/models--stabilityai--stable-diffusion-2-1/snapshots/845609e6cf0a060d8cd837297e5c169df5bff72c/text_encoder',
        'cls': transformers.models.clip.modeling_clip.CLIPTextModel,
    },
    'unet': {
        'tensorized_weights': 'diffusers-cache/unet.tensors',
        'path': 'diffusers-cache/models--stabilityai--stable-diffusion-2-1/snapshots/845609e6cf0a060d8cd837297e5c169df5bff72c/unet',
        'cls': diffusers.models.unet_2d_condition.UNet2DConditionModel,
    },
    'safety_checker': {
        'tensorized_weights': 'diffusers-cache/safety_checker.tensors',
        'path': 'diffusers-cache/models--CompVis--stable-diffusion-safety-checker/snapshots/cb41f3a270d63d454d385fc2e4f571c487c253c5',
        'cls': diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker,
    },
    # Components without tensorized weights
    'tokenizer': {
        'path': 'diffusers-cache/models--stabilityai--stable-diffusion-2-1/snapshots/845609e6cf0a060d8cd837297e5c169df5bff72c/tokenizer',
        'cls': transformers.models.clip.tokenization_clip.CLIPTokenizer,
    },
    'feature_extractor': {
        'path': 'diffusers-cache/models--stabilityai--stable-diffusion-2-1/snapshots/845609e6cf0a060d8cd837297e5c169df5bff72c/feature_extractor',
        'cls': transformers.models.clip.image_processing_clip.CLIPImageProcessor,
    },
}

# Rebuild the pipeline component by component, timing each load and the total.
# Components with a 'tensorized_weights' entry are constructed inside
# no_init_or_tensor() and then filled from their .tensors file; the rest are
# loaded with a plain from_pretrained().
st = time.time()

# The pipeline needs a scheduler *instance*; passing the bare class lets the
# constructor succeed but breaks as soon as the pipeline is called. Build it
# from the scheduler config stored in the same model snapshot.
scheduler_path = 'diffusers-cache/models--stabilityai--stable-diffusion-2-1/snapshots/845609e6cf0a060d8cd837297e5c169df5bff72c/scheduler'
components = {
    "scheduler": diffusers.schedulers.scheduling_ddim.DDIMScheduler.from_pretrained(scheduler_path)
}

for k, spec in component_map.items():
    print(f'Loading {k}...')
    cls = spec.get('cls')
    path = spec.get('path')
    tensorized_weights = spec.get('tensorized_weights', None)
    st_i = time.time()

    if tensorized_weights:
        with no_init_or_tensor():
            model = cls.from_pretrained(path)

        deserializer = TensorDeserializer(tensorized_weights, plaid_mode=True)
        try:
            deserializer.load_into_module(model)
        finally:
            # Release the file handle for this component before moving on.
            deserializer.close()
    else:
        model = cls.from_pretrained(path)

    components[k] = model
    print(f'  Loaded in {time.time() - st_i}')

pipe = diffusers.StableDiffusionPipeline(**components)
print(f'Initialized pipeline in {time.time() - st}')


Loading vae...
  Loaded in 1.9844210147857666
Loading text_encoder...
  Loaded in 1.2234342098236084
Loading unet...
  Loaded in 2.93279767036438
Loading safety_checker...
  Loaded in 1.114502191543579
Loading tokenizer...
  Loaded in 0.22099900245666504
Loading feature_extractor...
  Loaded in 0.000606536865234375
Initialized pipeline in 7.478455305099487
