In [None]:
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
import torch
import json
import gc
import os
import pandas as pd
from tqdm import tqdm
from clip_eval import CLIPEvaluator
from transformers import AutoImageProcessor, AutoModel
from transformers import CLIPTextModel
from matplotlib import pyplot as plt
import math
import numpy as np
import random
from PIL import Image


def init_generative_model(args):
    """
    Build the base Stable Diffusion pipeline and move it to the target device
    params:
        args: object exposing a `device` attribute (e.g. "cuda:0")
    returns:
        the StableDiffusionPipeline loaded from the hard-coded base checkpoint
    """
    # NOTE(review): the base checkpoint is v1-4 although load_trained_weights
    # refers to a 'sd-v1-5' model name -- confirm which one is intended.
    base_model = "CompVis/stable-diffusion-v1-4"
    return StableDiffusionPipeline.from_pretrained(base_model).to(args.device)


def check_identifier_token_dir(checkpoint, concept_str):
    """
    Resolve the directory that holds the weights of one concept
    params:
        checkpoint: str, a directory or a template such as "./snapshot/{}"
        concept_str: str, the concept (or comma-joined concepts) name
    returns:
        the first existing candidate among `checkpoint/concept_str` and
        `checkpoint.format(concept_str)`; `checkpoint` itself otherwise
    """
    nested = os.path.join(checkpoint, concept_str)
    if os.path.exists(nested):
        return nested

    formatted = checkpoint.format(concept_str)
    if os.path.exists(formatted):
        return formatted

    # Neither layout exists: fall back to the path exactly as it was given.
    return checkpoint


def load_trained_weights(args, pipe, checkpoint, concepts_str, c_identifier):
    """
    Load the fine-tuned weights of a concept into the pipeline
    params:
        args: object exposing `model_name` ('sd-v1-5' or 'dreambooth')
        pipe: StableDiffusionPipeline
        checkpoint: str, where the additional unique identifier weights are saved
        concepts_str: str, the comma-separated concepts used to generate images
        c_identifier: dict, concepts and identifiers, e.g. {'cat': '<cute-cat>'}
    returns:
        pipe: the same pipeline object, updated in place
    raises:
        ValueError: if `args.model_name` is not a known model name
    """
    print(f"loading unique identifier weights for {concepts_str} ...")
    if args.model_name == 'sd-v1-5':
        # Plain base model: nothing extra to load.
        return pipe
    if args.model_name != 'dreambooth':
        raise ValueError(f'Unknown model name {args.model_name}')

    weights_dir = check_identifier_token_dir(checkpoint, concepts_str)
    pipe.unet = UNet2DConditionModel.from_pretrained(
        os.path.join(weights_dir, 'unet')
    ).to(pipe.device)

    text_encoder_dir = os.path.join(weights_dir, 'text_encoder')
    if os.path.exists(text_encoder_dir):
        # Assign to the real `text_encoder` attribute; the previous misspelled
        # attribute name meant the fine-tuned encoder was never used.
        pipe.text_encoder = CLIPTextModel.from_pretrained(text_encoder_dir).to(pipe.device)

    # Optional textual-inversion embeddings, one "<identifier>.bin" per concept.
    for concept in concepts_str.split(','):
        embedding_path = os.path.join(weights_dir, f"{c_identifier[concept]}.bin")
        if os.path.exists(embedding_path):
            print(f"loading unique identifier weights from {c_identifier[concept]}.bin ...")
            pipe.load_textual_inversion(embedding_path)
    return pipe

def check_mk_file_dir(file_name):
    """
    Make sure the parent folder of `file_name` exists, creating it if needed
    params:
        file_name: str, path of a file that is about to be written
    """
    # os.path.dirname understands the platform separator and returns '' for a
    # bare file name, where slicing on rindex("/") used to raise ValueError.
    parent = os.path.dirname(file_name)
    if parent:
        os.makedirs(parent, exist_ok=True)
    
def check_mkdir(dir_name):
    """
    Create the folder `dir_name` (and any missing parents) if it does not exist yet.
    params:
        dir_name: str, folder to create
    """
    # exist_ok avoids the window between an explicit exists() check and the
    # creation, during which another process could create the same folder.
    os.makedirs(dir_name, exist_ok=True)

def save_img(im, prompts, save_dir, name):
    """
    Write one image to `<save_dir>/<prompts>/<name>.jpg`
    params:
        im: numpy array or PIL image
        prompts: value formatted into the sub-folder name (the caller passes the prompt string)
        save_dir: str
        name: value formatted into the file name
    raises:
        TypeError: if `im` is neither a numpy array nor a PIL image
    """
    target = os.path.join(save_dir, f'{prompts}/{name}.jpg')
    check_mk_file_dir(target)

    # Normalise to a PIL image so there is a single save call below.
    if isinstance(im, np.ndarray):
        im = Image.fromarray(im)
    elif not isinstance(im, Image.Image):
        raise TypeError(f'Unknown type {type(im)}') 
    im.save(target)

def get_reference_images(concepts_str, src_img_dir):
    """
    Load the reference (source) images of every concept
    params:
        concepts_str: str, comma-separated concept names
        src_img_dir: str, folder containing one sub-folder per concept
    returns:
        src_imgs: dict, {concept: [PIL images]}
    """
    src_imgs = {}
    for concept in concepts_str.split(','):
        concept_dir = os.path.join(src_img_dir, concept)
        images = []
        # os.listdir returns entries in arbitrary order; sort for stable results.
        for file_name in sorted(os.listdir(concept_dir)):
            img = Image.open(os.path.join(concept_dir, file_name))
            # Image.open is lazy and keeps the file open; load() reads the
            # pixels now so the handle is released instead of piling up.
            img.load()
            images.append(img)
        src_imgs[concept] = images
    return src_imgs


def edit_original_prompt(prompt, c_identifier, keys, mode='replace'):
    """
    Rewrite a plain prompt so that it mentions the concept identifiers
    params:
        prompt: str
        c_identifier: dict, concepts and identifiers, e.g. {'cat': '<cute-cat>'}
        keys: concept list, e.g. ['cat', 'dog']; underscores stand for spaces
        mode: str, 'replace' (identifier instead of the concept words),
              'insert' (identifier in front of the concept words) or
              'none' (prompt left as is)
    returns:
        prompt: str, stripped of surrounding whitespace
    raises:
        ValueError: if a concept does not occur in the prompt (unless mode is 'none')
    """
    if mode == 'none':
        return prompt.strip()

    for concept in keys:
        # Concept names use '_' where the prompt text uses a space.
        phrase = concept.replace('_', ' ')
        if phrase not in prompt:
            raise ValueError(f'{phrase} not in prompt {prompt}')
        if mode == 'replace':
            prompt = prompt.replace(phrase, c_identifier[concept])
        elif mode == 'insert':
            prompt = prompt.replace(phrase, f'{c_identifier[concept]} {phrase}')
    return prompt.strip()


def generate_images(args, c_p, c_identifier):
    """
    Generate images from the optimization-based model
    params:
        args: object exposing `device`, `model_name`, `checkpoint`, `edit_mode`,
              `num_per_prompt` and `img_save_dir`
        c_p: dict, concepts and prompts, e.g. {'cat': ['a photo of a cat']}
        c_identifier: dict, concepts and identifiers, e.g. {'cat': '<cute-cat>'}
    """
    for concepts_str, prompts in c_p.items():
        # A fresh pipeline per concept, so weights of one concept never leak
        # into the images of the next one.
        pipe = init_generative_model(args)
        pipe = load_trained_weights(args, pipe, args.checkpoint, concepts_str, c_identifier)
        concept_keys = concepts_str.split(',')
        concept_save_dir = os.path.join(args.img_save_dir, concepts_str)

        for prompt in prompts:
            edited_prompt = edit_original_prompt(prompt, c_identifier, concept_keys, mode=args.edit_mode)
            print("Generating images for prompt: ", edited_prompt)
            # Re-seed for every prompt so each prompt starts from the same noise.
            generator = torch.manual_seed(8888)
            for idx in range(args.num_per_prompt):
                im = pipe(edited_prompt, num_inference_steps=50, guidance_scale=7.5, generator=generator).images[0]
                # Images are filed under the ORIGINAL prompt, which is also the
                # folder name the evaluation step looks for.
                save_img(im, prompt, concept_save_dir, idx)

        # Release this concept's pipeline before the next one is built;
        # otherwise two full pipelines sit on the device at the same time.
        del pipe
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    print("all images saved in ", args.img_save_dir)


def check_dir(img_save_dir, concepts_str, prompt):
    """
    Return `<img_save_dir>/<concepts_str>/<prompt>` after checking it exists
    params:
        img_save_dir: str, root folder of the generated images
        concepts_str: str, sub-folder of the concept(s)
        prompt: str, sub-folder of the prompt
    raises:
        ValueError: if the path does not exist
    """
    target = os.path.join(img_save_dir, concepts_str, prompt)
    if not os.path.exists(target):
        raise ValueError(f'no such dir: {target}')
    return target


def eval_alignment(concepts_list, evaluator, c_p, img_save_dir, src_img_dir):
    """
    Evaluate the alignment between the generated images and the source images/texts
    params:
        concepts_list: list of concept entries (dicts with at least 'class_prompt'
            and 'caption'); a dict already keyed by concept name is accepted too
        evaluator: CLIPEvaluator
        c_p: dict, concepts and prompts
        img_save_dir: str, where the generated images are saved
        src_img_dir: str, where the source images are saved
    returns:
        dict with two entries:
            'img': {concepts_str: {concept: mean similarity between the generated
                    images and the source images of that concept}}
            'text': {concepts_str: mean similarity between the generated images
                     and their prompts}
    """
    # The concept entries come from a JSON list, so they cannot be indexed by
    # concept name directly: build a lookup keyed by class prompt first.
    if isinstance(concepts_list, dict):
        concept_info = concepts_list
    else:
        concept_info = {entry['class_prompt']: entry for entry in concepts_list}

    img_img_sim = {}   # {c_str: {c1:[s1, s2, s3], c2:[s1, s2, s3]}}
    text_img_sim = {}  # {c_str: [s1, s2, s3]}
    overfitting = {}   # {c_str: [0/1 per generated image]}

    for concepts_str, prompts in c_p.items():
        print(f'evaluating {concepts_str}...')
        img_img_sim[concepts_str] = {}
        text_img_sim[concepts_str] = []
        overfitting[concepts_str] = []

        src_imgs = get_reference_images(concepts_str, src_img_dir)  # {c1: [I1, I2...], c2: [I1, I2...]}
        src_img_features = {}  # {c1: [f1, f2, f3...], c2: [f1, f2, f3...]}
        for concept, concept_src_imgs in src_imgs.items():
            src_img_features[concept] = [evaluator.get_image_features(src_img) for src_img in concept_src_imgs]
            img_img_sim[concepts_str][concept] = []

        # presumably each 'caption' value is a list of caption strings -- TODO confirm
        src_captions = {concept: concept_info[concept]["caption"] for concept in concepts_str.split(',')}
        src_cap_features = {  # {c1: [f1, f2, f3...], c2: [f1, f2, f3...]}
            concept: [evaluator.get_text_features(src_caption) for src_caption in concept_src_captions]
            for concept, concept_src_captions in src_captions.items()
        }

        for prompt in prompts:
            text_feature = evaluator.get_text_features(prompt)
            prompt_dir = check_dir(img_save_dir, concepts_str, prompt)
            for img_path in os.listdir(prompt_dir):
                # The context manager closes the image file once it is scored.
                with Image.open(os.path.join(prompt_dir, img_path)) as img:
                    # text alignment
                    # NOTE(review): 2.5 looks like a fixed rescaling weight for
                    # the text similarity -- confirm its origin.
                    text_score = 2.5 * evaluator.txt_to_img_similarity(img, text_features=text_feature).cpu().numpy()
                    text_img_sim[concepts_str].append(text_score)

                    # image alignment
                    for concept, concept_img_features in src_img_features.items():
                        img_img_sim[concepts_str][concept].extend(
                            [evaluator.img_to_img_similarity(img, src_img_features=src_img_feature).cpu().numpy()
                             for src_img_feature in concept_img_features]
                        )

                    # overfitting: the image matches a source caption better
                    # than the prompt it was generated from
                    overfitting_flag = False
                    for concept, concept_cap_features in src_cap_features.items():
                        cap_img_sim = [2.5 * evaluator.txt_to_img_similarity(img, text_features=src_cap_feature).cpu().numpy()
                                       for src_cap_feature in concept_cap_features]
                        if not cap_img_sim:
                            # no caption for this concept: nothing to compare against
                            continue
                        best_cap_sim = np.max(cap_img_sim)
                        if text_score < best_cap_sim:
                            print(f'overfitting: {prompt} {img_path} < {src_captions[concept][np.argmax(cap_img_sim)]} {text_score} {best_cap_sim}')
                            overfitting_flag = True

                overfitting[concepts_str].append(1 if overfitting_flag else 0)

    img_img_sim_mean = {}
    text_img_sim_mean = {}
    overfitting_mean = {}
    for concepts_str, sim_list in img_img_sim.items():
        img_img_sim_mean[concepts_str] = {concept: np.mean(sims) for concept, sims in sim_list.items()}
        text_img_sim_mean[concepts_str] = np.mean(text_img_sim[concepts_str])
        overfitting_mean[concepts_str] = np.mean(overfitting[concepts_str])
        print(f'{concepts_str}: image alignment -- {img_img_sim_mean[concepts_str]} | text alignment -- {text_img_sim_mean[concepts_str]}')
        # Only reported here; the returned dict keeps its original two entries.
        print(f'{concepts_str}: overfitting rate -- {overfitting_mean[concepts_str]}')
    return {'img': img_img_sim_mean, 'text': text_img_sim_mean}


def get_placeholders(concepts_list):
    """
    Map every concept to its placeholder token
    params:
        concepts_list: list, the concepts list; every entry provides
            'class_prompt' and 'placeholder'
    returns:
        dict, {class_prompt: placeholder}; a later entry overrides an earlier
        one with the same class_prompt
    """
    return {entry['class_prompt']: entry['placeholder'] for entry in concepts_list}

def load_prompts_from_file(file):
    """
    Read the prompts stored one per line in a text file
    params:
        file: str, path of the prompt file
    returns:
        prompts: list of str, without surrounding whitespace; blank lines are skipped
    """
    with open(file, 'r', encoding='utf-8') as f:
        # A trailing newline or an empty line must not turn into an empty
        # prompt: no concept can be found in it further down the pipeline.
        return [line.strip() for line in f.read().splitlines() if line.strip()]

def get_concept_prompts(concepts_list):
    """
    Collect the test prompts of every concept
    params:
        concepts_list: list, the concepts list; every entry provides
            'class_prompt' and 'test_prompts' (path of a prompt file)
    returns:
        c_p: dict, {concept_str: [prompts]}, e.g. {'cat': ['a cat in front of a desk']}
    """
    # First pass: prompt -> concepts that use it, without duplicates.
    prompt_to_concepts = {}
    for entry in concepts_list:
        for prompt in load_prompts_from_file(entry['test_prompts']):
            owners = prompt_to_concepts.setdefault(prompt, [])
            if entry['class_prompt'] not in owners:
                owners.append(entry['class_prompt'])

    # Second pass: invert the mapping into concept -> prompts.
    concept_to_prompts = {}
    for prompt, owners in prompt_to_concepts.items():
        for owner in owners:
            bucket = concept_to_prompts.setdefault(owner, [])
            if prompt not in bucket:
                bucket.append(prompt)
    return concept_to_prompts

def set_seed(seed = 8888):
    """
    Seed every random number generator used in the notebook, for REPRODUCIBILITY.
    params:
        seed: int, the seed applied to numpy, random and torch (CPU and CUDA)
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    # When running on the CuDNN backend, two further options must be set:
    # deterministic kernels on, and the benchmark auto-tuner OFF, because the
    # auto-tuner may pick different algorithms from one run to the next.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed. The running interpreter has already
    # fixed its own hash seed, so this only reaches child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)

def generate(args):
    """
    Generate the images of every concept / prompt pair
    params:
        args: object exposing `concepts_list_path` plus everything
              generate_images reads from it
    """
    set_seed(seed=8888)
    # prepare the prompts
    # (the context manager closes the JSON file as soon as it is parsed)
    with open(args.concepts_list_path, "r") as f:
        concepts_list = json.load(f)
    c_p = get_concept_prompts(concepts_list)   # dict: {concept_str: [prompts]}
    c_identifier = get_placeholders(concepts_list)
    # generate images
    generate_images(args, c_p, c_identifier)


def evaluate(args, eval_clip_score=True, eval_coco_coi=False):
    """
    Evaluate the generated images
    params:
        args: object exposing `concepts_list_path`, `device`, `img_save_dir`,
              `src_img_dir` and whatever save_results reads from it
        eval_clip_score: bool, whether to evaluate the CLIP score (including image alignment and text-image alignment)
        eval_coco_coi: bool, whether to evaluate the coco CoI score; accepted so
              that callers passing it keep working, but no CoI scoring is
              implemented in this function and the flag is only reported
    """
    results = {}
    # The context manager closes the JSON file as soon as it is parsed.
    with open(args.concepts_list_path, "r") as f:
        concepts_list = json.load(f)
    c_p = get_concept_prompts(concepts_list)   # dict: {concept_str: [prompts]}

    if eval_clip_score:
        # evaluate the image alignment and text-image alignment
        evaluator = CLIPEvaluator(args.device)
        clip_score = eval_alignment(concepts_list, evaluator, c_p, args.img_save_dir, args.src_img_dir)
        results.update(clip_score)

    if eval_coco_coi:
        # Say so explicitly instead of silently dropping the request.
        print("eval_coco_coi was requested but the coco CoI score is not implemented here; skipping")

    save_results(args, results)

def run_and_test(args, **kwargs):
    """
    Generate the images, then evaluate them
    params:
        args: object handed to both generate and evaluate
        **kwargs: extra keyword arguments, forwarded unchanged to evaluate
    """
    generate(args)
    evaluate(args, **kwargs)
    
def save_results(args, scores):
    """ 
    Save the results to a csv file, one row per call
    params:
        args: object exposing `results_path` and `get_dict()` (the run settings,
              written as the first columns of the row)
        scores: dict, {score_name: {concept_str: score}}; a score may itself be
                a dict {concept: score}, flattened into one column per concept
    """
    # Work on a copy: get_dict() may hand back the live settings dict, and the
    # score columns must not leak into it (or into the next call's row).
    results = dict(args.get_dict())

    for score_name, score_dict in scores.items():
        for concept_str, score in score_dict.items():
            if isinstance(score, dict):
                for concept, s in score.items():
                    results[f'{concept_str}_{concept}_{score_name}'] = s
            else:
                results[f'{concept_str}_{score_name}'] = score

    new_row = pd.DataFrame(results, index=[0])
    if os.path.exists(args.results_path):
        # Append to the rows of previous runs; columns missing on either side
        # are left empty by the concatenation.
        table = pd.concat([pd.read_csv(args.results_path), new_row], ignore_index=True)
    else:
        table = new_row
    table.to_csv(args.results_path, index=False)
    print("results saved in ", args.results_path)

class Dict2Class(object):
    """
    Expose the entries of a dict as attributes, e.g. obj.device for d['device'].
    The original dict object stays reachable through get_dict().
    """

    def __init__(self, mydict):
        # Keep a reference to the dict itself (not a copy).
        self.dict = mydict
        for name, value in mydict.items():
            setattr(self, name, value)

    def get_dict(self):
        """Return the dict this object was built from."""
        return self.dict

In [None]:
# Settings of this run; every entry is also written to the results csv.
config = {
    "model_name": "dreambooth",  
    "device": "cuda:0",
    "edit_mode": "insert",
    "num_per_prompt": 2,  # number of generated images per prompt
    "concepts_list_path": "./data/concepts_list_object.json",
    "checkpoint": "./snapshot/gal_obj/{}",   
    "src_img_dir": "./data/object/",
    "results_path": "results.csv",
}
# Generated images are grouped by model name.
config["img_save_dir"] = "./samples/{}".format(config["model_name"])

# Attribute-style access (args.device, ...) is what the functions above expect.
args = Dict2Class(config)
run_and_test(args, eval_clip_score=True, eval_coco_coi=True)