<a href="https://colab.research.google.com/github/rsn870/rank_and_bias_gen/blob/main/consistency_models/junhss_consistency_models_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Consistency Models** 🌃
*...using `consistency`*

**Consistency Models** are a new family of generative models that achieve high sample quality without adversarial training. They support *fast one-step generation* by design, and still allow few-step sampling that spends extra compute for better samples. This notebook trains one on the combined CIFAR-10 and MNIST training sets.
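Following the Consistency Models paper (Song et al., 2023), a *consistency function* maps every point of a probability-flow ODE trajectory back to that trajectory's starting point. For a trained network $f_\theta$, with $\epsilon$ the smallest and $T$ the largest time, this means self-consistency along each trajectory plus a boundary condition:

$$
f_\theta(\mathbf{x}_t, t) = f_\theta(\mathbf{x}_{t'}, t') \quad \text{for all } t, t' \in [\epsilon, T] \text{ on the same trajectory}, \qquad f_\theta(\mathbf{x}_\epsilon, \epsilon) = \mathbf{x}_\epsilon .
$$

One-step generation is then a single evaluation $f_\theta(\mathbf{x}_T, T)$ on noise $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0}, T^2 \mathbf{I})$.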

### Setup

Please make sure you are using a GPU runtime to run this notebook. If the following command fails, use the `Runtime` menu above and select `Change runtime type`.

In [None]:
!nvidia-smi

Sun Jun 18 03:29:37 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   56C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!pip install datasets wandb huggingface-hub consistency==0.3.1

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.13.0-py3-none-any.whl (485 kB)
Collecting wandb
  Downloading wandb-0.15.4-py3-none-any.whl (2.1 MB)
Collecting huggingface-hub
  Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
Collecting consistency==0.3.1
  Downloading consistency-0.3.1-py3-none-any.whl (7.3 kB)
Collecting pytorch-lightning (from consistency==0.3.1)
  Downloading pytorch_lightning-2.0.3-py3-none-any.whl (720 kB)

In [None]:
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
!git config --global credential.helper store
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|
    
    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Token: 
Add token as git credential? (Y/n) Y
Token is valid (permission: write).
Your token has been saved in your

In [None]:
DATASET_NAME = "cifar10"
DATASET2_NAME = "mnist"
RESOLUTION = 32
BATCH_SIZE = 128
MAX_EPOCHS = 200
LEARNING_RATE = 1e-4
MODEL_ID = f"cm-{DATASET_NAME}-{DATASET2_NAME}-{RESOLUTION}"

SAMPLES_PATH = "./samples"
NUM_SAMPLES = 64
SAMPLE_STEPS = 5  # Set this value larger if you want higher sample quality.

DATA_PATH = './data'

In [None]:
import torch
from torch.utils.data import ConcatDataset, DataLoader
from torchvision import datasets, transforms


In [None]:
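# MNIST: 28x28 grayscale images, resized to RESOLUTION and given 3 identical
# channels so they have the same shape as CIFAR-10 images.
tf_grayscale_data = transforms.Compose(
    [
        transforms.Resize((RESOLUTION, RESOLUTION)),
        transforms.Grayscale(num_output_channels=3),  # 1 channel -> 3 identical channels
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
    ]
)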


In [None]:
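# CIFAR-10: images are already 32x32 RGB.
tf_rgb_data = transforms.Compose(
    [
        transforms.Resize((RESOLUTION, RESOLUTION)),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
    ]
)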
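Both pipelines end with `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`, which computes $x' = (x - 0.5)/0.5$ per channel. Because `ToTensor` yields values in $[0, 1]$, training images land in $[-1, 1]$. The plain-Python sketch below shows the mapping and its inverse on single pixel values; `normalize` and `denormalize` are illustrative helper names, not torchvision functions.

```python
def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-pixel version of transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std


def denormalize(y: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Inverse mapping, e.g. to bring a model output back to [0, 1] for display."""
    return y * std + mean


for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
# 0.0 -> -1.0
# 0.5 -> 0.0
# 1.0 -> 1.0
```

The inverse is what you would apply to raw model outputs if you display them yourself.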


In [None]:
def load_torchvision_dataset(name, data_path):
    """Load the training split of a supported torchvision dataset as 3x32x32 tensors."""
    if name == "cifar10":
        return datasets.CIFAR10(data_path, train=True, download=True, transform=tf_rgb_data)
    if name == "mnist":
        return datasets.MNIST(data_path, train=True, download=True, transform=tf_grayscale_data)
    raise ValueError(f"Unsupported dataset: {name!r} (expected 'cifar10' or 'mnist')")


def create_concat_dataset(dataset_name, dataset2_name, data_path, batch_size=BATCH_SIZE):
    dataset = ConcatDataset([
        load_torchvision_dataset(dataset_name, data_path),
        load_torchvision_dataset(dataset2_name, data_path),
    ])

    # Keep only the image tensors and drop the class labels (training is unconditional).
    # This materializes every image in memory: for CIFAR-10 + MNIST that is
    # 110,000 x 3 x 32 x 32 float32 values, about 1.35 GB.
    images = torch.stack([dataset[i][0] for i in range(len(dataset))])

    # Indexing a tensor yields single images, so each batch is a (B, 3, 32, 32) tensor.
    return DataLoader(images, batch_size=batch_size, shuffle=True, num_workers=0)
In [None]:
dataloader = create_concat_dataset(DATASET_NAME,DATASET2_NAME,DATA_PATH)

Files already downloaded and verified
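`ConcatDataset` does not copy its parts; it stores their cumulative sizes and maps a global index to a (dataset, local index) pair with a binary search. The plain-Python sketch below mirrors that idea for non-negative indices, using the CIFAR-10 and MNIST training-set sizes (50,000 and 60,000 images); `locate` is a hypothetical name used only for illustration.

```python
import bisect
from itertools import accumulate


def locate(index: int, sizes: list[int]) -> tuple[int, int]:
    """Map a global index to (dataset_idx, local_idx) for concatenated datasets."""
    cumulative = list(accumulate(sizes))  # e.g. [50000, 110000]
    if not 0 <= index < cumulative[-1]:
        raise IndexError(index)
    dataset_idx = bisect.bisect_right(cumulative, index)
    offset = cumulative[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, index - offset


sizes = [50_000, 60_000]  # CIFAR-10 train, MNIST train
print(locate(0, sizes))        # (0, 0)
print(locate(49_999, sizes))   # (0, 49999)
print(locate(50_000, sizes))   # (1, 0)
print(locate(109_999, sizes))  # (1, 59999)
```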


In [None]:
"""
import torch

from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms

class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset_name: str, dataset_config_name=None):
        self.dataset = load_dataset(
            dataset_name,
            dataset_config_name,
            split="train",
        )
        self.image_key = [
            key for key in ("image", "img") if key in self.dataset[0]
        ][0]
        self.augmentations = transforms.Compose(
    [
        transforms.Resize(
            RESOLUTION,
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        transforms.CenterCrop(RESOLUTION),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.augmentations(self.dataset[index][self.image_key].convert("RGB"))

dataloader = DataLoader(
    Dataset(DATASET_NAME),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

"""

### Define Models

`Consistency` accepts any UNet-like model as its backbone.
We recommend `UNet2DModel` from 🤗 `diffusers` as a default option.
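In the Consistency Models paper, the backbone output $F_\theta$ is not returned directly; it is combined with the input as $f_\theta(\mathbf{x}, t) = c_{\text{skip}}(t)\,\mathbf{x} + c_{\text{out}}(t)\,F_\theta(\mathbf{x}, t)$, so the boundary condition $f_\theta(\mathbf{x}, \epsilon) = \mathbf{x}$ holds for any backbone. The sketch below evaluates the paper's coefficients on a scalar "pixel", assuming $\sigma_{\text{data}} = 0.5$ and $\epsilon = 0.002$. Whether `consistency` uses exactly these formulas is an assumption, not something this notebook shows.

```python
import math

SIGMA_DATA = 0.5  # assumed std of the normalized data (paper default)
EPS = 0.002       # assumed smallest time / noise level (paper default)


def c_skip(t: float) -> float:
    return SIGMA_DATA**2 / ((t - EPS) ** 2 + SIGMA_DATA**2)


def c_out(t: float) -> float:
    return SIGMA_DATA * (t - EPS) / math.sqrt(SIGMA_DATA**2 + t**2)


def consistency_output(x: float, t: float, backbone_out: float) -> float:
    """f(x, t) = c_skip(t) * x + c_out(t) * F(x, t) for a single scalar value."""
    return c_skip(t) * x + c_out(t) * backbone_out


# At t = EPS the backbone output is ignored entirely, so f(x, EPS) == x.
print(consistency_output(0.3, EPS, backbone_out=123.0))  # 0.3
# At large t the skip term almost vanishes and the backbone dominates.
print(c_skip(80.0), c_out(80.0))
```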

In [None]:
from diffusers import UNet2DModel
from consistency import Consistency
from consistency.loss import PerceptualLoss

consistency = Consistency(
    model=UNet2DModel(
        sample_size=RESOLUTION,
        in_channels=3,
        out_channels=3,
        layers_per_block=1,
        block_out_channels=(128, 128, 256, 256),
        down_block_types=(
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
            "DownBlock2D"
        ),
        up_block_types=(
            "UpBlock2D",
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
        ),
    ),
    # LPIPS perceptual loss. Several backbone nets can be combined;
    # the recommended setting is "squeeze" + "vgg":
    # loss_fn=PerceptualLoss(net_type=("squeeze", "vgg"))
    # See https://github.com/richzhang/PerceptualSimilarity
    loss_fn=PerceptualLoss(net_type="squeeze"),
    learning_rate=LEARNING_RATE,
    samples_path=SAMPLES_PATH,
    save_samples_every_n_epoch=1,
    num_samples=NUM_SAMPLES,
    sample_steps=SAMPLE_STEPS,
    use_ema=True,
    sample_seed=42,
    model_id=MODEL_ID,
)

Downloading: "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth" to /root/.cache/torch/hub/checkpoints/squeezenet1_1-b8a52dc0.pth
100%|██████████| 4.73M/4.73M [00:00<00:00, 74.0MB/s]
  self.channels = model.in_channels
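With `use_ema=True`, training keeps a second copy of the weights (listed as `model_ema` in the model summary printed when training starts) that tracks an exponential moving average of the online weights, $\theta^- \leftarrow \mu\,\theta^- + (1-\mu)\,\theta$. The decay used by `consistency` is not shown in this notebook; the sketch below assumes a fixed $\mu = 0.9$ on one scalar weight purely for illustration.

```python
def ema_update(ema_param: float, param: float, decay: float) -> float:
    """One EMA step: keep `decay` of the old average and blend in (1 - decay) of the new value."""
    return decay * ema_param + (1.0 - decay) * param


ema = 0.0
for step in range(3):
    online = 1.0  # pretend the online weight jumped to 1.0 and stays there
    ema = ema_update(ema, online, decay=0.9)
    print(step, round(ema, 4))
# 0 0.1
# 1 0.19
# 2 0.271
```

The averaged weights move slowly toward the online weights, which smooths out step-to-step noise in training.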



In [None]:
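def get_activation_value(cmodel, images, timestep):
    """
    Return the bottleneck (mid-block) activations of the Hugging Face UNet2DModel
    wrapped inside a `Consistency` module.

    `cmodel` is the `Consistency` wrapper; its backbone is reached via `cmodel.model`.
    `timestep` is a long tensor of shape (batch_size,). Choose its values according to
    the diffusion schedule used in training (the Karras schedule, as in the paper).
    """
    activation = {}
    name = "bottleneck"

    def get_activation(name):
        # Forward hooks are called as hook(module, inputs, output).
        def hook(module, inputs, output):
            activation[name] = output.detach()
        return hook

    handle = cmodel.model.mid_block.register_forward_hook(get_activation(name))
    try:
        with torch.no_grad():
            _ = cmodel(images, timestep)
    finally:
        # Always remove the hook so repeated calls do not stack hooks.
        handle.remove()

    if name not in activation:
        raise RuntimeError(
            "The hook on cmodel.model.mid_block never fired; the forward pass "
            "may have used a different network (e.g. the EMA copy)."
        )
    return activation[name]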
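The docstring above points to the Karras et al. schedule for choosing `timestep` values. A minimal sketch of that schedule in plain Python, assuming the Consistency Models paper's defaults ($t_{\min} = 0.002$, $t_{\max} = 80$, $\rho = 7$) rather than values read from the `consistency` package; `karras_times` is a hypothetical helper name.

```python
def karras_times(n: int, t_min: float = 0.002, t_max: float = 80.0, rho: float = 7.0) -> list[float]:
    """n increasing noise levels from t_min to t_max, evenly spaced in t ** (1 / rho)."""
    if n < 2:
        raise ValueError("need at least two levels")
    lo, hi = t_min ** (1 / rho), t_max ** (1 / rho)
    return [(lo + i / (n - 1) * (hi - lo)) ** rho for i in range(n)]


times = karras_times(5)
print([round(t, 3) for t in times])  # dense near t_min, sparse near t_max
```

Whether `Consistency` expects continuous noise levels like these or the long tensor mentioned in the docstring is worth checking against the package source before building a `timestep` tensor from them.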


### Training

Generated samples are saved to `SAMPLES_PATH` after every epoch (`save_samples_every_n_epoch=1`) and also appear in your **W&B workspace** as training progresses.

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.wandb import WandbLogger

trainer = Trainer(
    accelerator="auto",  # use the GPU when one is available
    logger=WandbLogger(project="consistency", log_model=True),
    callbacks=[
        # Keep the 3 checkpoints with the lowest logged "loss".
        ModelCheckpoint(
            dirpath="ckpt",
            save_top_k=3,
            monitor="loss",
        )
    ],
    max_epochs=MAX_EPOCHS,
    precision=16 if torch.cuda.is_available() else 32,  # mixed precision on GPU
    log_every_n_steps=30,
    # Rescale gradients so their global L2 norm is at most 1.0.
    gradient_clip_algorithm="norm",
    gradient_clip_val=1.0,
)

trainer.fit(consistency, dataloader)

  rank_zero_warn(
  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/content/cm-cifar10-mnist-32 is already a clone of https://huggingface.co/rsn243/cm-cifar10-mnist-32. Make sure you pull the latest changes with `repo.git_pull()`.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name               | Type           | Params
------------------------------------------------------
0 | model              | UNet2DModel    | 18.9 M
1 | model_ema          | UNet2DModel    | 18.9 M
2 | loss_fn            | PerceptualLoss | 724 K 
3 | _lo

Training: 0it [00:00, ?it/s]

Several commits (2) will be pushed upstream.
Several commits (3) will be pushed upstream.
Several commits (4) will be pushed upstream.
Several commits (5) will be pushed upstream.
Several commits (6) will be pushed upstream.
Several commits (7) will be pushed upstream.
Several commits (8) will be pushed upstream.
Several commits (9) will be pushed upstream.
Several commits (10) will be pushed upstream.
Several commits (11) will be pushed upstream.
Several commits (12) will be pushed upstream.
Several commits (13) will be pushed upstream.
Several commits (14) will be pushed upstream.
Several commits (15) will be pushed upstream.
Several commits (16) will be pushed upstream.
Several commits (17) will be pushed upstream.
Several commits (18) will be pushed upstream.
Several commits (19) will be pushed upstream.
Several commits (20) will be pushed upstream.
Several commits (21) will be pushed upstream.
Several commits (22) will be pushed upstream.
Several commits (23) will be pushed upstre
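The `Trainer` above sets `gradient_clip_algorithm="norm"` with `gradient_clip_val=1.0`: when the global L2 norm of all gradients exceeds 1.0, every gradient is scaled down by the same factor. A plain-Python sketch of that rule on a flat list of gradient values (PyTorch's `clip_grad_norm_` works on tensors and adds a small epsilon to the denominator); `clip_by_global_norm` is an illustrative name.

```python
import math


def clip_by_global_norm(grads: list[float], max_norm: float) -> list[float]:
    """Scale all gradients by min(1, max_norm / total_norm)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)  # already small enough: leave unchanged
    return [g * max_norm / total_norm for g in grads]


print(clip_by_global_norm([3.0, 4.0], max_norm=1.0))  # [0.6, 0.8]
print(clip_by_global_norm([0.3, 0.4], max_norm=1.0))  # [0.3, 0.4]
```

Scaling every entry by one shared factor keeps the gradient's direction and only shortens it.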

### Generate samples

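Once training has finished, you can `sample` high-quality images! 🎉 The call below requests 64 images with `sample_steps=20`; more steps cost more compute but can improve quality.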

In [None]:
consistency.sample(64, sample_steps=20)
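`sample_steps` controls how many times the model is evaluated. In the paper's multistep sampler, the first step maps pure noise at the largest time $T$ straight to an image; each further step re-noises that image to a smaller time $\tau$ with $\mathbf{x}_\tau = \mathbf{x} + \sqrt{\tau^2 - \epsilon^2}\,\mathbf{z}$ and maps it back again. The sketch below runs that loop on a single scalar "pixel" with a hypothetical toy consistency function; the time points and `toy_f` are illustrative assumptions, not what `consistency.sample` uses internally.

```python
import math
import random
from typing import Callable

EPS, T = 0.002, 80.0  # assumed smallest and largest times
MU = 0.25             # toy "dataset": every sample equals 0.25


def toy_f(x: float, t: float) -> float:
    # Exact consistency function for a point-mass data distribution at MU
    # (PF ODE with sigma(t) = t): it satisfies toy_f(x, EPS) == x.
    return MU + (x - MU) * (EPS / t)


def multistep_sample(f: Callable[[float, float], float], taus: list[float], rng: random.Random) -> float:
    """Multistep consistency sampling on a scalar; taus must be decreasing and > EPS."""
    x = f(rng.gauss(0.0, T), T)  # one-step generation from pure noise
    for tau in taus:
        z = rng.gauss(0.0, 1.0)
        x_tau = x + math.sqrt(tau**2 - EPS**2) * z  # re-noise to time tau
        x = f(x_tau, tau)  # map back to the data end of the trajectory
    return x


rng = random.Random(0)
print(multistep_sample(toy_f, [], rng))                  # 1 model evaluation
print(multistep_sample(toy_f, [40.0, 10.0, 2.0], rng))   # 4 model evaluations
```

With a real network, the extra steps let later evaluations correct errors of the first one, which is the compute/quality trade-off `sample_steps` exposes.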