In [None]:
import os
import pickle
import shutil
import time
from collections import defaultdict

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from tqdm.auto import tqdm

import nlp_utils # our utils

In [None]:
results_dir = '/home/dcor/roeyron/TCIE/results/celeba_conditioned_embeddings/'

# Only regular files are results; skip subdirectories (e.g. notebook checkpoint dirs),
# whose names would not parse as result files below.
results_fnames = sorted(
    fname for fname in os.listdir(results_dir)
    if os.path.isfile(os.path.join(results_dir, fname))
)
results_fpaths = [os.path.join(results_dir, fname) for fname in results_fnames]

# Group result files by question index: the third-from-last '_'-separated field of the
# *file name*. Parsing the basename rather than the full path keeps underscores in the
# directory name (e.g. 'celeba_conditioned_embeddings') from ever being picked up.
files_per_q = defaultdict(list)
for fname, fpath in zip(results_fnames, results_fpaths):
    q_ind = int(fname.split('_')[-3])
    files_per_q[q_ind].append(fpath)
files_per_q = dict(files_per_q)
files_per_q



In [None]:
def load_df(fpaths):
    """Unpickle one DataFrame per path and stack them into a single frame.

    Args:
        fpaths: iterable of paths to pickled DataFrames, concatenated in the given order.

    Returns:
        The concatenation of all loaded frames. The original per-file index labels are
        kept, so labels may repeat across files; use positional (``.iloc``) access.

    Note:
        ``pickle.load`` can execute arbitrary code; only use this on trusted result files.
    """
    def _unpickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    frames = [_unpickle(path) for path in fpaths]
    return pd.concat(frames)

In [None]:
# Concatenate all result files for question index 0 and peek at the first rows.
df_q0 = load_df(fpaths=files_per_q[0])
df_q0.head(n=5)

In [None]:
# Preview the image at row position 349 of question 0's results.
Image.open(df_q0['image_path'].iloc[349])

In [None]:
###########################
i_layer = -1  # <= 32; first index into each row's hidden_states
i_token = -1  # second index, applied after i_layer, to pick the embedding vector
###########################
k = 9  # neighbours to retrieve per query; k + 1 are requested because the query finds itself
###########################


ddl = defaultdict(list)

# Fixed showcase queries plus 30 reproducible (seeded) random ones, as row positions in df_q0.
query_image_inds = [200, 1153, 1011, 300, 18] + list(np.random.RandomState(50).permutation(len(df_q0))[:30])

for q_ind, fpaths in tqdm(files_per_q.items()):
    df = load_df(fpaths)

    # Every file group is expected to hold results for exactly one question.
    questions = df.question.unique()
    assert len(questions) == 1, f'expected 1 question for q={q_ind}, found {len(questions)}'
    question = questions[0]

    # Query positions were drawn from df_q0, so they only make sense if every question
    # covers the same number of images.
    assert len(df) == len(df_q0), f'q={q_ind} has {len(df)} rows but df_q0 has {len(df_q0)}'

    # The embedding matrix and the neighbour index depend only on the question, not on the
    # query, so build them once per question instead of once per query.
    X = np.array([hs[i_layer][i_token] for hs in df['hidden_states']])
    neighbors = NearestNeighbors(n_neighbors=k + 1, metric='cosine')
    neighbors.fit(X)

    for query_ind in query_image_inds:
        _, indices = neighbors.kneighbors([X[query_ind]])
        indices = indices[0, :]

        ddl['query_ind'].append(query_ind)
        ddl['result_inds'].append(indices)
        ddl['question'].append(question)
        # Resolve neighbour paths against *this* question's frame, so display code does not
        # depend on whichever `df` happens to be left over after the loop.
        ddl['result_image_paths'].append(list(df['image_path'].iloc[indices]))

df_search = pd.DataFrame(ddl)
df_search.head()

In [None]:
# Distinct questions covered by the search results.
df_search['question'].unique()

In [None]:

# For each query image, show one strip of thumbnails per question: the query's nearest
# neighbours in that question's embedding space (the query itself normally comes first).
for query_ind, df_query in df_search.groupby('query_ind', sort=False):
    for _, row in df_query.iterrows():
        print(row.result_inds)
        print(f' ######### Prompt: {row.question} [{query_ind}]')
        if 'result_image_paths' in row.index:
            # Paths resolved against the same question's frame the indices came from.
            image_paths = row.result_image_paths
        else:
            # NOTE(review): fallback for df_search frames built without per-row paths. It uses
            # `df` leaked from the search loop (the last question's frame), which is only
            # correct if every question's frame lists images in the same order.
            image_paths = [df['image_path'].iloc[nn_ind] for nn_ind in row.result_inds]
        images = [Image.open(path) for path in image_paths]
        display(Image.fromarray(np.concatenate([img.resize((180, 180)) for img in images], axis=1)))
    

In [None]:
# Hand-picked teaser examples: each prompt maps to the df_q0 row positions of the images
# exported for it (first entry is the shared query image, 18).
prompts_and_images = dict([
    ("describe the hair of the person in the image", [18, 17, 853, 552]),
    ("describe the expression of the person in the image", [18, 331, 958, 1114]),
    ("describe the background color of the image", [18, 848, 937, 404]),
])

In [None]:
# Export the hand-picked teaser images (row positions in df_q0) as <prompt>_<i>.png files.
dir_path = '/home/dcor/roeyron/TCIE/results/images_dir_for_paper_teaser'

# Start from an empty directory so files from earlier selections don't linger.
# NOTE: this deletes everything under dir_path.
if os.path.exists(dir_path):
    shutil.rmtree(dir_path)
os.makedirs(dir_path)

for prompt, image_ids in prompts_and_images.items():
    prefix = prompt.replace(' ', '_')
    for i, image_id in enumerate(image_ids):
        src_path = df_q0.iloc[image_id].image_path
        dst_path = os.path.join(dir_path, f'{prefix}_{i}.png')
        # Re-encode instead of byte-copying: the destination is named .png, so its contents
        # must really be PNG even when the source is another format (e.g. JPEG); tools that
        # infer the format from the extension otherwise fail to read the file.
        with Image.open(src_path) as img:
            img.save(dst_path, format='PNG')
        