In [1]:
# model loader

import sys
sys.path.append(".")
sys.path.append("./latent-diffusion")
sys.path.append('./taming-transformers')

import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config; ``config.model`` is passed to
            ``instantiate_from_config`` to build the (untrained) model.
        ckpt: Path to a checkpoint file whose top-level dict has a
            ``"state_dict"`` entry.
        verbose: If True, print the keys reported by ``load_state_dict`` as
            missing from / unexpected in the checkpoint. Weights are loaded with
            ``strict=False``, so mismatches are otherwise silent.

    Returns:
        Tuple ``(model, sd, pl_sd)``:
            model -- the model in eval mode, on CUDA when available, else CPU;
            sd    -- the state dict taken from the checkpoint (tensors on CPU);
            pl_sd -- the full checkpoint dict as loaded.
    """
    print(f"Loading model from {ckpt}")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): weights_only=False unpickles arbitrary objects, which can execute
    # code -- only load trusted checkpoints. Presumably required because the checkpoint
    # holds non-tensor entries besides "state_dict".
    pl_sd = torch.load(ckpt, weights_only=False, map_location=torch.device('cpu'))
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # Load straight from the in-memory dict. Saving it to './tmp_sd' and re-reading it
    # (onto `device`) only duplicated the weights on disk and in device memory.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if missing:
            print(f"missing keys: {missing}")
        if unexpected:
            print(f"unexpected keys: {unexpected}")
    model.to(device)
    model.eval()
    return model, sd, pl_sd


def get_model(model_config_path, model_ckpt_path):
    """Read the YAML config at ``model_config_path`` and load the checkpoint.

    Returns the ``(model, sd, pl_sd)`` tuple produced by
    ``load_model_from_config``: the loaded model, the checkpoint's state dict,
    and the full checkpoint dict.
    """
    return load_model_from_config(OmegaConf.load(model_config_path), model_ckpt_path)

In [2]:
# load imagenet pretrained model
import os

# Both the config and the checkpoint live inside the latent-diffusion checkout;
# build the paths from one root instead of a no-op single-argument os.path.join.
LDM_ROOT = './latent-diffusion'
cin_model_config_path = os.path.join(LDM_ROOT, 'configs', 'latent-diffusion', 'cin256-v2.yaml')
cin_model_ckpt_path = os.path.join(LDM_ROOT, 'models', 'ldm', 'cin256-v2', 'model.ckpt')

# load model; the raw state dict and full checkpoint dict are not used later,
# so drop those references to let their memory be reclaimed.
model, sd, pl_sd = get_model(cin_model_config_path, cin_model_ckpt_path)
del sd, pl_sd

Loading model from ./latent-diffusion/models/ldm/cin256-v2/model.ckpt
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 400.92 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 3, 64, 64) = 12288 dimensions.
making attention of type 'vanilla' with 512 in_channels


In [3]:
# set ldm classifier

from ldm_classifier import LdmClassifier
# Wrap the latent-diffusion `model` loaded in the previous cell in the classifier helper.
# NOTE(review): judging from its printed "mean pred errors", it presumably scores each
# class hypothesis by diffusion prediction error -- confirm in ldm_classifier.
ldm_clf = LdmClassifier(model)

In [4]:
import sys
# Make the latent-diffusion data helpers (imagenet_mini) importable. Guard the append
# so re-running this cell does not keep adding duplicate entries to sys.path.
LDM_DATA_DIR = './latent-diffusion/ldm/data/'
if LDM_DATA_DIR not in sys.path:
    sys.path.append(LDM_DATA_DIR)
from torchvision import transforms
import torch
from matplotlib import pyplot as plt
import numpy as np

from imagenet_mini import ImageNetMiniSubset, ImageNetMiniDataset

# define imagenet-mini subset for classification

val_dir = './data/imagenet-mini/validation'
# Class indices to discriminate between; later cells reuse this list as the
# classifier's class hypotheses.
subset_classes = [0, 1]
# NOTE: the variable name (sic, "imagnet") is kept because later cells reference it.
# size=256 presumably matches the cin256-v2 model resolution -- confirm.
imagnet_subset = ImageNetMiniSubset(data_dir=val_dir, labels_file='validation_set.csv', size=256, classes=subset_classes)

print(f"number of samples in dataset: {len(imagnet_subset)}")

Data source: validation_set.csv
    Class tench, Tinca tinca: 48.6%
    Class goldfish, Carassius auratus: 51.4%
number of samples in dataset: 70


In [5]:
# verify ability to classify single sample:

# Sanity check on the first item of the (unshuffled) subset before the full run.
loader = torch.utils.data.DataLoader(imagnet_subset, batch_size=1, shuffle=False)
batch = next(iter(loader))
# Use the same class hypotheses as the dataset subset instead of a hard-coded copy,
# so editing `subset_classes` keeps this check consistent with the full-dataset run.
x0, c_hypotheses = ldm_clf.get_latent_batch(batch, classes=subset_classes)

l2_label_pred, l1_label_pred = ldm_clf.classify_batch(x0, c_hypotheses)
print(f"true label: {batch['class_label']}")
print(f"L2 classification: {l2_label_pred}")
print(f"L1 classification: {l1_label_pred}")

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

mean pred errors: tensor([0.1780, 0.2576])
mean L2, L1 pred errors: (tensor(0.1780), tensor(0.2576))


diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

mean pred errors: tensor([0.1773, 0.2565])
mean L2, L1 pred errors: (tensor(0.1773), tensor(0.2565))
true label: tensor([1])
L2 classification: tensor([1], device='cuda:0')
L1 classification: tensor([1], device='cuda:0')


In [11]:
# run classification for the entire dataset
#   batch_size=1         -- one sample per forward pass
#   n_trials=1           -- a single trial per class hypothesis
#   t_sampling_stride=5  -- presumably evaluates every 5th diffusion timestep
#                           (consistent with the 200-step progress bars); confirm in LdmClassifier
l2_labels_pred, l1_labels_pred, true_labels = ldm_clf.classify_dataset(
    dataset=imagnet_subset,
    batch_size=1,
    n_trials=1,
    t_sampling_stride=5,
    classes=subset_classes,
)

dataset samples:   0%|          | 0/70 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

class hypothsis:   0%|          | 0/2 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

diffusion sampling:   0%|          | 0/200 [00:00<?, ?it/s]

In [13]:
# Score both prediction sets (L2- and L1-error criteria) against the true labels.
l2_acc, l1_acc = (
    ldm_clf.get_classification_accuracy(preds, true_labels)
    for preds in (l2_labels_pred, l1_labels_pred)
)

print(f"L2 accuracy: {l2_acc}\nL1 accuracy: {l1_acc}")

L2 accuracy: 98.57142857142858
L1 accuracy: 98.57142857142858
