In [1]:
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, StableDiffusionPipeline, \
    EulerDiscreteScheduler
import torchvision.transforms as torch_transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
from diffusion.utils import  get_formatstr

  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
  warn(


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"


In [12]:
MODEL_IDS = {
    '1-1': "CompVis/stable-diffusion-v1-1",
    '1-2': "CompVis/stable-diffusion-v1-2",
    '1-3': "CompVis/stable-diffusion-v1-3",
    '1-4': "CompVis/stable-diffusion-v1-4",
    '1-5': "runwayml/stable-diffusion-v1-5",
    '2-0': "stabilityai/stable-diffusion-2-base",
    '2-1': "stabilityai/stable-diffusion-2-1-base"
}

def get_sd_model(version = '1-1' , dtype='float32'):
    if dtype == 'float32':
        dtype = torch.float32
    elif dtype == 'float16':
        dtype = torch.float16
    else:
        raise NotImplementedError

    assert version in MODEL_IDS.keys()
    model_id = MODEL_IDS[version]
    scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
    pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=dtype)
    pipe.enable_xformers_memory_efficient_attention()
    vae = pipe.vae
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    unet = pipe.unet
    return vae, tokenizer, text_encoder, unet, scheduler

In [23]:
def get_scheduler_config(version= '1-1'):
    #if version in {'1-1', '1-2', '1-3', '1-4', '1-5'}:
    config = {
        "_class_name": "EulerDiscreteScheduler",
        "_diffusers_version": "0.14.0",
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "interpolation_type": "linear",
        "num_train_timesteps": 1000,
        "prediction_type": "epsilon",
        "set_alpha_to_one": False,
        "skip_prk_steps": True,
        "steps_offset": 1,
        "trained_betas": None
    }
    # elif version in {'2-0', '2-1'}:
    #     config = {
    #         "_class_name": "EulerDiscreteScheduler",
    #         "_diffusers_version": "0.10.2",
    #         "beta_end": 0.012,
    #         "beta_schedule": "scaled_linear",
    #         "beta_start": 0.00085,
    #         "clip_sample": False,
    #         "num_train_timesteps": 1000,
    #         "prediction_type": "epsilon",
    #         "set_alpha_to_one": False,
    #         "skip_prk_steps": True,
    #         "steps_offset": 1,  # todo
    #         "trained_betas": None
    #     }
    # else:
    #     raise NotImplementedError

    return config

Zero-shot Classification with Stable Diffusion
python eval_prob_adaptive.py --dataset cifar10 --split test --n_trials 1 \
  --to_keep 5 1 --n_samples 50 500 --loss l1 \
  --prompt_path prompts/cifar10_prompts.csv

In [24]:

import torch.nn.functional as F
from tqdm import tqdm

def eval_error(unet, scheduler, latent, all_noise, ts, noise_idxs,
               text_embeds, text_embed_idxs, batch_size=32, dtype='float32', loss='l2'):
    assert len(ts) == len(noise_idxs) == len(text_embed_idxs)
    pred_errors = torch.zeros(len(ts), device='cpu')
    idx = 0
    with torch.inference_mode():
        for _ in tqdm.trange(len(ts) // batch_size + int(len(ts) % batch_size != 0), leave=False):
            batch_ts = torch.tensor(ts[idx: idx + batch_size])
            noise = all_noise[noise_idxs[idx: idx + batch_size]]
            noised_latent = latent * (scheduler.alphas_cumprod[batch_ts] ** 0.5).view(-1, 1, 1, 1).to(device) + \
                            noise * ((1 - scheduler.alphas_cumprod[batch_ts]) ** 0.5).view(-1, 1, 1, 1).to(device)
            t_input = batch_ts.to(device).half() if dtype == 'float16' else batch_ts.to(device)
            text_input = text_embeds[text_embed_idxs[idx: idx + batch_size]]
            noise_pred = unet(noised_latent, t_input, encoder_hidden_states=text_input).sample
            if loss == 'l2':
                error = F.mse_loss(noise, noise_pred, reduction='none').mean(dim=(1, 2, 3))
            elif loss == 'l1':
                error = F.l1_loss(noise, noise_pred, reduction='none').mean(dim=(1, 2, 3))
            elif loss == 'huber':
                error = F.huber_loss(noise, noise_pred, reduction='none').mean(dim=(1, 2, 3))
            else:
                raise NotImplementedError
            pred_errors[idx: idx + len(batch_ts)] = error.detach().cpu()
            idx += len(batch_ts)
    return pred_errors


In [25]:
import os
import os.path as osp
import torch
import pandas as pd
import numpy as np
import tqdm


def eval_prob_adaptive(unet, latent, text_embeds, scheduler, n_samples, to_keep, batch_size, dtype, loss, latent_size=64, all_noise=None, n_trials=1):
    scheduler_config = get_scheduler_config(scheduler)
    T = scheduler_config['num_train_timesteps']
    max_n_samples = max(n_samples)

    if all_noise is None:
        all_noise = torch.randn((max_n_samples * n_trials, 4, latent_size, latent_size), device=latent.device)
    if dtype == 'float16':
        all_noise = all_noise.half()
        scheduler.alphas_cumprod = scheduler.alphas_cumprod.half()

    data = dict()
    t_evaluated = set()
    remaining_prmpt_idxs = list(range(len(text_embeds)))
    start = T // max_n_samples // 2
    t_to_eval = list(range(start, T, T // max_n_samples))[:max_n_samples]

    for n_samples_val, n_to_keep_val in zip(n_samples, to_keep):
        ts = []
        noise_idxs = []
        text_embed_idxs = []
        curr_t_to_eval = t_to_eval[len(t_to_eval) // n_samples_val // 2::len(t_to_eval) // n_samples_val][:n_samples_val]
        curr_t_to_eval = [t for t in curr_t_to_eval if t not in t_evaluated]
        for prompt_i in remaining_prmpt_idxs:
            for t_idx, t in enumerate(curr_t_to_eval, start=len(t_evaluated)):
                ts.extend([t] * n_trials)
                noise_idxs.extend(list(range(n_trials * t_idx, n_trials * (t_idx + 1))))
                text_embed_idxs.extend([prompt_i] * n_trials)
        t_evaluated.update(curr_t_to_eval)
        pred_errors = eval_error(unet, scheduler, latent, all_noise, ts, noise_idxs,
                                 text_embeds, text_embed_idxs, batch_size, dtype, loss)
        # match up computed errors to the data
        for prompt_i in remaining_prmpt_idxs:
            mask = torch.tensor(text_embed_idxs) == prompt_i
            prompt_ts = torch.tensor(ts)[mask]
            prompt_pred_errors = pred_errors[mask]
            if prompt_i not in data:
                data[prompt_i] = dict(t=prompt_ts, pred_errors=prompt_pred_errors)
            else:
                data[prompt_i]['t'] = torch.cat([data[prompt_i]['t'], prompt_ts])
                data[prompt_i]['pred_errors'] = torch.cat([data[prompt_i]['pred_errors'], prompt_pred_errors])

        # compute the next remaining idxs
        errors = [-data[prompt_i]['pred_errors'].mean() for prompt_i in remaining_prmpt_idxs]
        best_idxs = torch.topk(torch.tensor(errors), k=n_to_keep_val, dim=0).indices.tolist()
        remaining_prmpt_idxs = [remaining_prmpt_idxs[i] for i in best_idxs]

    # organize the output
    assert len(remaining_prmpt_idxs) == 1
    pred_idx = remaining_prmpt_idxs[0]

    return pred_idx, data

In [26]:
def _convert_image_to_rgb(image):
    return image.convert("RGB")

def get_transform(interpolation=InterpolationMode.BICUBIC, size=512):
    transform = torch_transforms.Compose([
        torch_transforms.Resize(size, interpolation=interpolation),
        torch_transforms.CenterCrop(size),
        _convert_image_to_rgb,
        torch_transforms.ToTensor(),
        torch_transforms.Normalize([0.5], [0.5])
    ])
    return transform


INTERPOLATIONS = {
    'bilinear': InterpolationMode.BILINEAR,
    'bicubic': InterpolationMode.BICUBIC,
    'lanczos': InterpolationMode.LANCZOS,
}


In [28]:
import os
import os.path as osp
import argparse
import torch
import pandas as pd
import numpy as np
import tqdm
from diffusion.datasets import get_target_dataset
import gc
LOG_DIR = './Logs'




def evaluate_model(dataset='cifar10', split='test', version='2-0', img_size=512, batch_size=32, n_trials=1,
                   prompt_path='prompts/cifar10_prompts.csv', noise_path=None, subset_path=None, dtype='float16',
                   interpolation='bicubic', extra=None, n_workers=12, worker_idx=0, load_stats=False, loss='l2',
                   to_keep=[5, 1], n_samples=[50, 500]):
    
    assert len(to_keep) == len(n_samples)

    # Make run output folder
    name = f"v{version}_{n_trials}trials_"
    name += '_'.join(map(str, to_keep)) + 'keep_'
    name += '_'.join(map(str, n_samples)) + 'samples'
    if interpolation != 'bicubic':
        name += f'_{interpolation}'
    if loss == 'l1':
        name += '_l1'
    elif loss == 'huber':
        name += '_huber'
    if img_size != 512:
        name += f'_{img_size}'
    if extra is not None:
        run_folder = osp.join(LOG_DIR, dataset + '_' + extra, name)
    else:
        run_folder = osp.join(LOG_DIR, dataset, name)
    os.makedirs(run_folder, exist_ok=True)
    print(f'Run folder: {run_folder}')

    # Set up dataset and prompts
    interpolation_func = INTERPOLATIONS[interpolation] #defined above
    transform = get_transform(interpolation_func, img_size)
    latent_size = img_size // 8
    target_dataset = get_target_dataset(dataset, train=split == 'train', transform=transform)
    print(f'Current path {os.getcwd()}')
    # Corrected line: os.getcwd() instead of os.getcwd
    prompts_df = pd.read_csv(osp.join(os.getcwd(), prompt_path))

    # Load pretrained models
    vae, tokenizer, text_encoder, unet, scheduler = get_sd_model(version, dtype)
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    unet = unet.to(device)
    torch.backends.cudnn.benchmark = True

    # Load noise
    if noise_path is not None:
        all_noise = torch.load(noise_path).to(device)
        print('Loaded noise from', noise_path)
    else:
        all_noise = None

    # Prepare text embeddings
    text_input = tokenizer(prompts_df.prompt.tolist(), padding="max_length",
                           max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    embeddings = []
    with torch.inference_mode():
        for i in range(0, len(text_input.input_ids), 100):
            text_embeddings = text_encoder(
                text_input.input_ids[i: i + 100].to(device),
            )[0]
            embeddings.append(text_embeddings)
    text_embeddings = torch.cat(embeddings, dim=0)
    assert len(text_embeddings) == len(prompts_df)

    # Subset of dataset to evaluate
    if subset_path is not None:
        idxs = np.load(subset_path).tolist()
    else:
        idxs = list(range(len(target_dataset)))
    idxs_to_eval = idxs[worker_idx::n_workers]

    formatstr = get_formatstr(len(target_dataset) - 1)
    correct = 0
    total = 0
    pbar = tqdm.tqdm(idxs_to_eval)
    for i in pbar:
        if total > 0:
            pbar.set_description(f'Acc: {100 * correct / total:.2f}%')
        fname = osp.join(run_folder, formatstr.format(i) + '.pt')
        if os.path.exists(fname):
            print('Skipping', i)
            if load_stats:
                data = torch.load(fname)
                correct += int(data['pred'] == data['label'])
                total += 1
            continue
        image, label = target_dataset[i]
        with torch.no_grad():
            img_input = image.to(device).unsqueeze(0)
            if dtype == 'float16':
                img_input = img_input.half()
            x0 = vae.encode(img_input).latent_dist.mean
            x0 *= 0.18215
        pred_idx, pred_errors = eval_prob_adaptive(unet, x0, text_embeddings, scheduler, n_samples, to_keep, batch_size, dtype, loss, latent_size, all_noise, n_trials)
        pred = prompts_df.classidx[pred_idx]
        torch.save(dict(errors=pred_errors, pred=pred, label=label), fname)
        if pred == label:
            correct += 1
        total += 1
import os

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

torch.cuda.empty_cache()
gc.collect()



# Example usage in a Jupyter notebook
evaluate_model(
    dataset='cifar10', 
    split='test', 
    n_trials=1, 
    to_keep=[5, 1], 
    n_samples=[50, 500], 
    loss='l1', 
    prompt_path='promps/cifar10_prompts.csv',
    batch_size=8
)

Run folder: ./Logs/cifar10/v2-0_1trials_5_1keep_50_500samples_l1
Files already downloaded and verified
Current path /home/zow/Gene-DRAIL


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/834 [00:00<?, ?it/s]

Skipping 0


Acc: 87.86%: 100%|██████████| 834/834 [23:09:22<00:00, 99.95s/it]    


In [None]:
print(torch.cuda.memory_summary(device=None, abbreviated=False))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   2492 MiB |   7272 MiB |  26593 MiB |  24101 MiB |
|       from large pool |   2425 MiB |   7205 MiB |  26384 MiB |  23958 MiB |
|       from small pool |     66 MiB |     67 MiB |    209 MiB |    143 MiB |
|---------------------------------------------------------------------------|
| Active memory         |   2492 MiB |   7272 MiB |  26593 MiB |  24101 MiB |
|       from large pool |   2425 MiB |   7205 MiB |  26384 MiB |  23958 MiB |
|       from small pool |     66 MiB |     67 MiB |    209 MiB |    143 MiB |
|---------------------------------------------------------------