In [21]:
# Re-import edited local modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [1]:
import os
from typing import List

import torch
from cog import BasePredictor, Input, Path
from diffusers import (
    StableDiffusionPipeline,
    PNDMScheduler,
    LMSDiscreteScheduler,
    DDIMScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)

import time
import pprofile


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from diffusers.pipeline_utils import DiffusionPipeline

In [3]:
# Run from the project root so relative paths such as "diffusers-cache/" resolve.
# The default matches the original hard-coded location; set the SRC_DIR
# environment variable to run the notebook from a different checkout.
SRC_DIR = os.environ.get("SRC_DIR", "/src/")
os.chdir(SRC_DIR)

In [2]:
# Hugging Face repo ids used below, and the local cache directory that the
# `from_pretrained(..., cache_dir=MODEL_CACHE, local_files_only=True)` calls read from.
MODEL_ID: str = "stabilityai/stable-diffusion-2-1"
SAFETY_MODEL_ID: str = "CompVis/stable-diffusion-safety-checker"
MODEL_CACHE: str = "diffusers-cache"

In [9]:
# Install tensorizer once if it is missing. Prefer `%pip` (installs into this
# kernel's environment) and pin a version, e.g. `%pip install tensorizer==<version>`.
# !pip install tensorizer

In [31]:

# Serialize safety checker
import torch
from tensorizer import TensorSerializer
from transformers import AutoModelForCausalLM

# Load the safety checker from the local Hugging Face cache and move it to the GPU.
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    SAFETY_MODEL_ID,
    cache_dir=MODEL_CACHE,
    local_files_only=True,
).to('cuda')

# Write the module's tensors to a single tensorizer file inside the cache dir
# (same string as before: "diffusers-cache/safety_checker.tensors").
path = os.path.join(MODEL_CACHE, 'safety_checker.tensors')
serializer = TensorSerializer(path)
try:
    serializer.write_module(safety_checker)
finally:
    # Release the output file even if write_module raises part-way through.
    serializer.close()

In [67]:
# Echo the configured safety-checker repo id.
SAFETY_MODEL_ID

'CompVis/stable-diffusion-safety-checker'

In [68]:
# Load tensorized safety checker

import time
import torch
from tensorizer import TensorDeserializer
from tensorizer.utils import no_init_or_tensor, convert_bytes, get_mem_usage

# Rebuild the safety checker, then fill its weights from the tensorizer file.
# Use an explicit path: the notebook-global `path` is reassigned by several
# component loops, so depending on it loads whichever file was touched last.
safety_checker_tensors_path = os.path.join(MODEL_CACHE, 'safety_checker.tensors')

with no_init_or_tensor():
    # from_pretrained requires the model id/location as its first argument.
    model = StableDiffusionSafetyChecker.from_pretrained(
        SAFETY_MODEL_ID,
        cache_dir=MODEL_CACHE,
        local_files_only=True,
    )

deserializer = TensorDeserializer(safety_checker_tensors_path, plaid_mode=True)
deserializer.load_into_module(model)

397

In [13]:
# Serialize every torch.nn.Module component of the pipeline (vae, text_encoder,
# unet, safety_checker) to its own "<name>.tensors" file under diffusers-cache/.
# Components that are not nn.Modules (tokenizer, scheduler, feature_extractor)
# are skipped.

pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    safety_checker=safety_checker,
    cache_dir=MODEL_CACHE,
    local_files_only=True,
).to("cuda")

tensorized_weights_base_path = "diffusers-cache/"
component_map = {}  # component name -> path of its .tensors file
for k, component in pipe.components.items():
    if not isinstance(component, torch.nn.Module):
        continue
    path = os.path.join(tensorized_weights_base_path, f"{k}.tensors")
    serializer = TensorSerializer(path)
    try:
        serializer.write_module(component)
    finally:
        # Release the output file even if serialization fails mid-way.
        serializer.close()
    component_map[k] = path


In [48]:
# AutoencoderKL is a diffusers class; CLIPTextModel belongs to transformers
# (diffusers does not export it, so importing it from diffusers fails).
from diffusers import AutoencoderKL
from transformers import CLIPTextModel


In [53]:
# Read-only inspection: print each pipeline component name and its type.
# `component_map` is intentionally left untouched (resetting it here would
# throw away the name -> .tensors mapping without repopulating it).
for k, v in pipe.components.items():
    print(k, type(v))

vae <class 'diffusers.models.vae.AutoencoderKL'>
text_encoder <class 'transformers.models.clip.modeling_clip.CLIPTextModel'>
tokenizer <class 'transformers.models.clip.tokenization_clip.CLIPTokenizer'>
unet <class 'diffusers.models.unet_2d_condition.UNet2DConditionModel'>
scheduler <class 'diffusers.schedulers.scheduling_ddim.DDIMScheduler'>
safety_checker <class 'diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker'>
feature_extractor <class 'transformers.models.clip.image_processing_clip.CLIPImageProcessor'>


In [70]:
# Load Stable Diffusion components individually.
#
# Entries with a `tensorized_weights` file are instantiated from their Hugging
# Face snapshot directory inside `no_init_or_tensor()` and then filled from the
# .tensors file via TensorDeserializer. Entries without one (tokenizer, feature
# extractor) are loaded with a plain `from_pretrained`.
# NOTE(review): the scheduler is not in this map, so it is not loaded here.

import diffusers
import transformers

# Snapshot directories inside the Hugging Face cache; the commit hashes pin the
# revision that was downloaded into MODEL_CACHE.
SD_SNAPSHOT_DIR = os.path.join(
    MODEL_CACHE,
    'models--stabilityai--stable-diffusion-2-1',
    'snapshots',
    '845609e6cf0a060d8cd837297e5c169df5bff72c',
)
SAFETY_SNAPSHOT_DIR = os.path.join(
    MODEL_CACHE,
    'models--CompVis--stable-diffusion-safety-checker',
    'snapshots',
    'cb41f3a270d63d454d385fc2e4f571c487c253c5',
)

component_map = {
    'vae': {
        'tensorized_weights': os.path.join(MODEL_CACHE, 'vae.tensors'),
        'path': os.path.join(SD_SNAPSHOT_DIR, 'vae'),
        'cls': diffusers.models.vae.AutoencoderKL,
    },
    'text_encoder': {
        'tensorized_weights': os.path.join(MODEL_CACHE, 'text_encoder.tensors'),
        'path': os.path.join(SD_SNAPSHOT_DIR, 'text_encoder'),
        'cls': transformers.models.clip.modeling_clip.CLIPTextModel,
    },
    'unet': {
        'tensorized_weights': os.path.join(MODEL_CACHE, 'unet.tensors'),
        'path': os.path.join(SD_SNAPSHOT_DIR, 'unet'),
        'cls': diffusers.models.unet_2d_condition.UNet2DConditionModel,
    },
    'safety_checker': {
        'tensorized_weights': os.path.join(MODEL_CACHE, 'safety_checker.tensors'),
        'path': SAFETY_SNAPSHOT_DIR,
        'cls': diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker,
    },
    'tokenizer': {
        'path': os.path.join(SD_SNAPSHOT_DIR, 'tokenizer'),
        'cls': transformers.models.clip.tokenization_clip.CLIPTokenizer,
    },
    'feature_extractor': {
        'path': os.path.join(SD_SNAPSHOT_DIR, 'feature_extractor'),
        'cls': transformers.models.clip.image_processing_clip.CLIPImageProcessor,
    },
}

components = {}
for k, spec in component_map.items():
    print(f'Loading {k}...')
    cls = spec['cls']
    # Dedicated local name: the notebook-global `path` is read by other cells
    # and must not be overwritten with a component directory here.
    component_dir = spec['path']
    tensorized_weights = spec.get('tensorized_weights')

    if tensorized_weights:
        with no_init_or_tensor():
            model = cls.from_pretrained(component_dir)
        deserializer = TensorDeserializer(tensorized_weights, plaid_mode=True)
        deserializer.load_into_module(model)
    else:
        model = cls.from_pretrained(component_dir)

    components[k] = model


Loading vae...
Loading text_encoder...
Loading unet...
Loading safety_checker...
Loading tokenizer...
Loading feature_extractor...


In [65]:
# Read-only inspection (same check as the earlier component-type cell): print
# each pipeline component name and its type without resetting `component_map`.
for k, v in pipe.components.items():
    print(k, type(v))

vae <class 'diffusers.models.vae.AutoencoderKL'>
text_encoder <class 'transformers.models.clip.modeling_clip.CLIPTextModel'>
tokenizer <class 'transformers.models.clip.tokenization_clip.CLIPTokenizer'>
unet <class 'diffusers.models.unet_2d_condition.UNet2DConditionModel'>
scheduler <class 'diffusers.schedulers.scheduling_ddim.DDIMScheduler'>
safety_checker <class 'diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker'>
feature_extractor <class 'transformers.models.clip.image_processing_clip.CLIPImageProcessor'>


In [None]:

from transformers import CLIPImageProcessor

# Reload the safety checker from its tensorizer file. Use an explicit path
# rather than the notebook-global `path`, which the component loops reassign
# (the loader loop leaves it pointing at the feature_extractor directory).
safety_checker_tensors_path = os.path.join(MODEL_CACHE, 'safety_checker.tensors')

with no_init_or_tensor():
    model = StableDiffusionSafetyChecker.from_pretrained(
        SAFETY_MODEL_ID,
        cache_dir=MODEL_CACHE,
        local_files_only=True,
    )

deserializer = TensorDeserializer(safety_checker_tensors_path, plaid_mode=True)
deserializer.load_into_module(model)

# Build the feature extractor from its preprocessing config. Call the class
# directly instead of relying on the `component` variable leaked from a loop.
CLIPImageProcessor.from_json_file('diffusers-cache/models--stabilityai--stable-diffusion-2-1/snapshots/845609e6cf0a060d8cd837297e5c169df5bff72c/feature_extractor/preprocessor_config.json')

{'vae': {'path': 'diffusers-cache/vae.tensors',
  'cls': diffusers.models.vae.AutoencoderKL},
 'text_encoder': {'path': 'diffusers-cache/text_encoder.tensors',
  'cls': transformers.models.clip.modeling_clip.CLIPTextModel},
 'unet': {'path': 'diffusers-cache/unet.tensors',
  'cls': diffusers.models.unet_2d_condition.UNet2DConditionModel},
 'safety_checker': {'path': 'diffusers-cache/safety_checker.tensors',
  'cls': diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker}}

In [43]:
# Dry run: map each torch.nn.Module pipeline component to the .tensors path it
# would be serialized to, plus its class. Nothing is written to disk here.
# A comprehension keeps the loop variables (including `path`) out of the
# notebook namespace, so other cells' globals are not clobbered.
tensorized_weights_base_path = "diffusers-cache/"
component_map = {
    name: {
        'path': os.path.join(tensorized_weights_base_path, f"{name}.tensors"),
        'cls': type(module),
    }
    for name, module in pipe.components.items()
    if isinstance(module, torch.nn.Module)
}

In [44]:
# Display the name -> {'path', 'cls'} map built by the previous cell.
component_map

{'vae': {'path': 'diffusers-cache/vae.tensors',
  'cls': diffusers.models.vae.AutoencoderKL},
 'text_encoder': {'path': 'diffusers-cache/text_encoder.tensors',
  'cls': transformers.models.clip.modeling_clip.CLIPTextModel},
 'unet': {'path': 'diffusers-cache/unet.tensors',
  'cls': diffusers.models.unet_2d_condition.UNet2DConditionModel},
 'safety_checker': {'path': 'diffusers-cache/safety_checker.tensors',
  'cls': diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker}}

In [37]:
# List the names of all pipeline components (both nn.Modules and non-modules).
pipe.components.keys()

dict_keys(['vae', 'text_encoder', 'tokenizer', 'unet', 'scheduler', 'safety_checker', 'feature_extractor'])

In [17]:
# `pipe.components` is a plain dict, not an nn.Module, so serialization has to
# walk its entries and handle each module component separately.
isinstance(pipe.components, torch.nn.Module)

False

vae <class 'diffusers.models.vae.AutoencoderKL'>
text_encoder <class 'transformers.models.clip.modeling_clip.CLIPTextModel'>
tokenizer <class 'transformers.models.clip.tokenization_clip.CLIPTokenizer'>
unet <class 'diffusers.models.unet_2d_condition.UNet2DConditionModel'>
scheduler <class 'diffusers.schedulers.scheduling_ddim.DDIMScheduler'>
safety_checker <class 'diffusers.pipelines.stable_diffusion.safety_checker.StableDiffusionSafetyChecker'>
feature_extractor <class 'transformers.models.clip.image_processing_clip.CLIPImageProcessor'>


In [20]:
# Duplicate of the earlier keys() check: list all pipeline component names.
pipe.components.keys()

dict_keys(['vae', 'text_encoder', 'tokenizer', 'unet', 'scheduler', 'safety_checker', 'feature_extractor'])