From d5ec8f4bc7a22a2d209c4ff5297f58fb6c98f435 Mon Sep 17 00:00:00 2001 From: Guangsen Wang Date: Fri, 9 Dec 2022 19:26:45 +0800 Subject: [PATCH 1/2] updated app --- app/__init__.py | 10 +- app/backend/caption_backend.py | 107 ++++++++++ app/backend/multimodal_search_backend.py | 198 +++++++++++++++++ app/backend/txt2image_backend.py | 199 ++++++++++++++++++ app/calculate_coco_features.py | 7 - app/caption.py | 62 ++++-- app/caption_front_end.py | 105 +++++++++ app/classification.py | 142 +++++++------ ...dataset_browser.py => dataset_explorer.py} | 64 ++++-- app/image_text_match.py | 23 +- app/main.py | 33 ++- app/multimodal_search.py | 172 ++++++++------- app/multimodal_search_front_end.py | 112 ++++++++++ app/multipage.py | 7 - app/style_override.css | 57 +++++ app/txt2image_front_end.py | 62 ++++++ app/utils.py | 38 +++- app/vqa.py | 33 +-- 18 files changed, 1177 insertions(+), 254 deletions(-) create mode 100644 app/backend/caption_backend.py create mode 100644 app/backend/multimodal_search_backend.py create mode 100644 app/backend/txt2image_backend.py create mode 100644 app/caption_front_end.py rename app/{dataset_browser.py => dataset_explorer.py} (79%) create mode 100644 app/multimodal_search_front_end.py create mode 100644 app/style_override.css create mode 100644 app/txt2image_front_end.py diff --git a/app/__init__.py b/app/__init__.py index 05522040..3a108837 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,10 +1,3 @@ -""" - # Copyright (c) 2022, salesforce.com, inc. - # All rights reserved. - # SPDX-License-Identifier: BSD-3-Clause - # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" - from PIL import Image import requests @@ -24,3 +17,6 @@ def load_demo_image(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") cache_root = "/export/home/.cache/lavis/" +pending_job_path = "app/task_queues/pending_jobs/" +finished_job_path = "app/task_queues/finished_jobs/" +job_output_path = "app/task_queues/outputs/" diff --git a/app/backend/caption_backend.py b/app/backend/caption_backend.py new file mode 100644 index 00000000..8b4f99ca --- /dev/null +++ b/app/backend/caption_backend.py @@ -0,0 +1,107 @@ +from app import device, load_demo_image +from app.utils import load_model_cache, get_pending_jobs, create_uniq_user_job_name +from app import job_output_path, finished_job_path, pending_job_path +from lavis.processors import load_processor +from PIL import Image + +import random +import numpy as np +import torch +import os, shutil, time + +job_type = 'caption' + +if torch.cuda.is_available(): + torch.cuda.set_device(0) + device = "cuda" +else: + device = "cpu" + +def setup_seed(seed): + random.seed(seed) + np.random.seed(int(seed)) + torch.manual_seed(seed) + + if torch.cuda.is_available(): + import torch.backends.cudnn as cudnn + cudnn.benchmark = False + cudnn.deterministic = True + +def back_end(): + vis_processor = load_processor("blip_image_eval").build(image_size=384) + blip_large_model = load_model_cache( + "blip_caption", + model_type=f"large_coco", + is_eval=True, + device=device, + ) + blip_base_model = load_model_cache( + "blip_caption", + model_type=f"base_coco", + is_eval=True, + device=device, + ) + os.makedirs(os.path.join(finished_job_path, job_type), exist_ok=True) + while True: + pending_jobs = get_pending_jobs(job_type) + for job in pending_jobs: + while True: + with open(job) as f: + content = f.readline().rstrip(' \n') + if len(content.split('\t')) == 5: break + time_stamp, 
blip_type, sampling_method, num_captions, seed = content.split('\t') + outpath = os.path.join(job_output_path, job_type) + os.makedirs(outpath, exist_ok=True) + img_file = outpath+'/{}_raw_image.pt'.format(create_uniq_user_job_name(time_stamp, sampling_method)) + while True: + if os.path.exists(img_file): + break + time.sleep(1) + img = torch.load(outpath+'/{}_raw_image.pt'.format(create_uniq_user_job_name(time_stamp, sampling_method)),map_location=torch.device(device)) + if blip_type == 'large': + model = blip_large_model + else: + model = blip_base_model + use_nucleus_sampling = False + if sampling_method == 'Nucleus sampling': + use_nucleus_sampling = True + setup_seed(int(seed)) + captions = generate_caption(model, img, use_nucleus_sampling, int(num_captions)) + caption_result = outpath+'/{}_result.txt'.format(create_uniq_user_job_name(time_stamp, sampling_method)) + with open(caption_result,'w') as f: + for caption in captions: + f.write(caption+'\n') + shutil.move(job, os.path.join(finished_job_path, job_type)) + os.remove(img_file) + + +def generate_caption( + model, image, use_nucleus_sampling=False, num_captions = 1, num_beams=3, max_length=40, min_length=5 +): + samples = {"image": image} + + captions = [] + if use_nucleus_sampling: + #for _ in range(5): + captions = model.generate( + samples, + use_nucleus_sampling=True, + max_length=max_length, + min_length=min_length, + top_p=0.9, + num_captions=num_captions + ) + #captions.append(caption[0]) + else: + caption = model.generate( + samples, + use_nucleus_sampling=False, + num_beams=num_beams, + max_length=max_length, + min_length=min_length, + num_captions=1 + ) + captions.append(caption[0]) + return captions +if __name__ == "__main__": + back_end() diff --git a/app/backend/multimodal_search_backend.py b/app/backend/multimodal_search_backend.py new file mode 100644 index 00000000..17fab829 --- /dev/null +++ b/app/backend/multimodal_search_backend.py @@ -0,0 +1,198 @@ +import os, shutil + +import numpy as np +import streamlit as st +import torch +import torch.nn.functional as F +from app import cache_root, device, job_output_path, finished_job_path +from app.utils import ( + getAttMap, + init_bert_tokenizer, + load_blip_itm_model, + read_img, + resize_img, + get_pending_jobs, + create_uniq_user_job_name +) +from lavis.models import BlipFeatureExtractor, load_model +from lavis.processors import load_processor + +if torch.cuda.is_available(): + torch.cuda.set_device(0) + device = "cuda" +else: + device = "cpu" + +job_type = 'search' + +def load_feat(): + from lavis.common.utils import download_url + + dirname = os.path.join(os.path.dirname(__file__), "assets") + filename = "path2feat_coco_train2014.pth" + filepath = os.path.join(dirname, filename) + url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth" + + if not os.path.exists(filepath): + download_url(url=url, root=dirname, filename="path2feat_coco_train2014.pth") + + path2feat = torch.load(filepath) + paths = sorted(path2feat.keys()) + + all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device) + + return path2feat, paths, all_img_feats + +def load_feature_extractor_model(device): + model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth" + + model = load_model("blip_feature_extractor", model_type="base", is_eval=True, device=device) + model.load_from_pretrained(model_url) + + return model + +def search(time_stamp, user_question, feature_extractor, 
vis_processor, raw_user_question, num_display, itm_model): + sample = {"text_input": user_question} + with torch.no_grad(): + text_feature = feature_extractor.extract_features( + sample, mode="text").text_embeds_proj[0, 0] + + path2feat, paths, all_img_feats = load_feat() + all_img_feats.to(device) + all_img_feats = F.normalize(all_img_feats, dim=1) + + num_cols = 4 + num_rows = int(num_display) // num_cols + + similarities = text_feature @ all_img_feats.T + indices = torch.argsort(similarities, descending=True)[:num_display] + + top_paths = [paths[ind.detach().cpu().item()] for ind in indices] + sorted_similarities = [similarities[idx] for idx in indices] + file_root = os.path.join(cache_root, "coco/images/train2014/") + filenames = [os.path.join(file_root, p) for p in top_paths] + outpath = os.path.join(job_output_path, job_type) + os.makedirs(outpath, exist_ok=True) + + bsz = 8 # max number of images to avoid cuda oom + + #itm_model = load_blip_itm_model("cuda", model_type=blip_type) + + tokenizer = init_bert_tokenizer() + queries_batch = [user_question] * bsz + queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to("cpu") + + num_batches = int(num_display / bsz) + + avg_gradcams = [] + all_raw_images = [] + itm_scores = [] + + for i in range(num_batches): + filenames_in_batch = filenames[i * bsz : (i + 1) * bsz] + raw_images, images = read_and_process_images(filenames_in_batch, vis_processor) + gradcam, itm_output = compute_gradcam_batch( + itm_model, images, queries_batch, queries_tok_batch + ) + + all_raw_images.extend([resize_img(r_img) for r_img in raw_images]) + norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images] + + for norm_img, grad_cam in zip(norm_imgs, gradcam): + avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True) + avg_gradcams.append(avg_gradcam) + + with torch.no_grad(): + itm_score = torch.nn.functional.softmax(itm_output, dim=1) + + itm_scores.append(itm_score) + + #avg_gradcams = torch.cat(avg_gradcams) + #all_raw_images = torch.cat(all_raw_images) + + itm_scores = torch.cat(itm_scores)[:, 1] + torch.save(itm_scores, outpath+'/{}_itm.pt'.format(create_uniq_user_job_name(time_stamp,raw_user_question))) + np.save(outpath+'/{}_avg_gradcams.npy'.format(create_uniq_user_job_name(time_stamp,raw_user_question)), avg_gradcams, allow_pickle=True) + np.save(outpath+'/{}_all_raw_images.npy'.format(create_uniq_user_job_name(time_stamp,raw_user_question)),all_raw_images,allow_pickle=True) + + search_result = outpath+'/{}_result.txt'.format(create_uniq_user_job_name(time_stamp,raw_user_question)) + with open(search_result,'w') as f: + for filename in filenames: + f.write(filename+'\n') + +def back_end(): + # === event === + vis_processor = load_processor("blip_image_eval").build(image_size=384) + text_processor = load_processor("blip_caption") + feature_extractor = load_feature_extractor_model(device) + os.makedirs("{}/{}/".format(finished_job_path, job_type), exist_ok=True) + large_itm_model = load_blip_itm_model(device, model_type='large') + base_itm_model = load_blip_itm_model(device, model_type='base') + + while True: + pending_jobs = get_pending_jobs(job_type) + for job in pending_jobs: + while True: + with open(job) as f: + content = f.readline().rstrip(' \n') + if len(content.split('\t')) == 4: break + time_stamp, raw_user_question, num_display, blip_type = content.split('\t') + user_question = text_processor(raw_user_question) + if blip_type == 'large': + search(time_stamp, user_question, feature_extractor, vis_processor, raw_user_question, 
int(num_display), large_itm_model) + else: + search(time_stamp, user_question, feature_extractor, vis_processor, raw_user_question, int(num_display), base_itm_model) + shutil.move(job, "{}/{}/".format(finished_job_path,job_type)) + + +def read_and_process_images(image_paths, vis_processor): + raw_images = [read_img(path) for path in image_paths] + images = [vis_processor(r_img) for r_img in raw_images] + images_tensors = torch.stack(images).to(device) + + return raw_images, images_tensors + + +def compute_gradcam_batch(model, visual_input, text_input, tokenized_text, block_num=6): + model.text_encoder.base_model.base_model.encoder.layer[ + block_num + ].crossattention.self.save_attention = True + + output = model(visual_input, text_input, match_head="itm") + loss = output[:, 1].sum() + + model.zero_grad() + loss.backward() + with torch.no_grad(): + mask = tokenized_text.attention_mask.view( + tokenized_text.attention_mask.size(0), 1, -1, 1, 1 + ).to(device=device) # (bsz,1,token_len, 1,1) + token_length = mask.sum() - 2 + token_length = token_length.cpu() + # grads and cams [bsz, num_head, seq_len, image_patch] + grads = model.text_encoder.base_model.base_model.encoder.layer[ + block_num + ].crossattention.self.get_attn_gradients() + cams = model.text_encoder.base_model.base_model.encoder.layer[ + block_num + ].crossattention.self.get_attention_map() + + # assume using vit large with 576 num image patch + cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask + grads = ( + grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24) + * mask + ) + + gradcam = cams * grads + # [enc token gradcam, average gradcam across token, gradcam for individual token] + # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :])) + gradcam = gradcam.mean(1).cpu().detach() + gradcam = ( + gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True) / token_length + ) + + return gradcam, output + +if __name__ == '__main__': + back_end() diff --git a/app/backend/txt2image_backend.py b/app/backend/txt2image_backend.py new file mode 100644 index 00000000..655159d7 --- /dev/null +++ b/app/backend/txt2image_backend.py @@ -0,0 +1,199 @@ +import os, shutil, subprocess +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +import time +from pytorch_lightning import seed_everything +from torch import autocast + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.plms import PLMSSampler + +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor + +from app.utils import ( + get_pending_jobs, + create_uniq_user_job_name +) + +from app import job_output_path, finished_job_path + +job_type = 'txt2image' +if torch.cuda.is_available(): + torch.cuda.set_device(1) + device = "cuda" +else: + device = "cpu" + + +# load safety model +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) +safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. 
+ """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + +def load_model_from_config(config, ckpt, verbose=False): + #print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + if not os.path.exists('prompts'): + os.makedirs('prompts', exist_ok=True) + else: + shutil.rmtree('prompts') + os.makedirs('prompts', exist_ok=True) + + model.half() + if torch.cuda.is_available(): + model.cuda() + model.eval() + return model + + +def load_replacement(x): + try: + hwc = x.shape + y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0])) + y = (np.array(y)/255.0).astype(x.dtype) + assert y.shape == x.shape + return y + except Exception: + return x + + +def check_safety(x_image): + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + assert x_checked_image.shape[0] == len(has_nsfw_concept) + for i in range(len(has_nsfw_concept)): + if has_nsfw_concept[i]: + x_checked_image[i] = load_replacement(x_checked_image[i]) + return x_checked_image, has_nsfw_concept + +def back_end(): + config = 'stable-diffusion/configs/stable-diffusion/v1-inference.yaml' + ckpt = 'stable-diffusion/sd-v1-4.ckpt' + config = OmegaConf.load(f"{config}") + model = load_model_from_config(config, f"{ckpt}") + print(device) + model = model.to(device) + outpath = os.path.join(job_output_path, job_type) + sample_path = os.path.join(outpath, "samples") + if not os.path.exists(sample_path): + subprocess.run(['mkdir', '-p', sample_path], shell=False) + finished_path = os.path.join(finished_job_path,job_type) + if not os.path.exists(finished_path): + subprocess.run(['mkdir', '-p',finished_path], shell=False) + while True: + pending_jobs = get_pending_jobs(job_type) + for job in pending_jobs: + while True: + with open(job) as f: + content = f.readline().rstrip(' \n') + if len(content.split('\t')) == 4: break + random_seed, time_stamp, user_prompt, num_images = content.split('\t') + generate_image(model, int(random_seed), time_stamp, user_prompt, int(num_images), sample_path) + shutil.move(job, finished_path) + + +def generate_image(model, random_seed, time_stamp, user_prompt, num_images, sample_path): + scale = 7.5 + num_latent_channels = 4 + down_sample_factor = 8 + H, W = 512, 512 + ddim_steps = 50 + skip_grid = False + skip_save = False + n_rows = num_images + precision_scope = autocast + data = [[user_prompt] * num_images] + seed_everything(random_seed) + model = model.to(device) + sampler = PLMSSampler(model) + with torch.no_grad(): + with precision_scope(device): + #with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(1, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if scale != 1.0: + uc = model.get_learned_conditioning(num_images * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [num_latent_channels, H // down_sample_factor, W // down_sample_factor] + samples_ddim, _ = 
sampler.sample(S=ddim_steps, + conditioning=c, + batch_size= num_images, + shape=shape, + verbose=False, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc, + eta=0.0, + x_T=None) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() + + x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim) + + x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + + if not skip_save: + count = 1 + for x_sample in x_checked_image_torch: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + #img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, "{}_{}.png".format(create_uniq_user_job_name(str(time_stamp), user_prompt), count))) + count += 1 + if not skip_grid: + all_samples.append(x_checked_image_torch) + + if not skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + img = Image.fromarray(grid.astype(np.uint8)) + #img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, "{}_grid.png".format(create_uniq_user_job_name(str(time_stamp), user_prompt)))) +if __name__ == '__main__': + back_end() diff --git a/app/calculate_coco_features.py b/app/calculate_coco_features.py index 168e8503..19d3614f 100644 --- a/app/calculate_coco_features.py +++ b/app/calculate_coco_features.py @@ -1,10 +1,3 @@ -""" - # Copyright (c) 2022, salesforce.com, inc. - # All rights reserved. - # SPDX-License-Identifier: BSD-3-Clause - # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" - from PIL import Image import requests import torch diff --git a/app/caption.py b/app/caption.py index ad118988..b8c1ecd5 100644 --- a/app/caption.py +++ b/app/caption.py @@ -1,24 +1,38 @@ -""" - # Copyright (c) 2022, salesforce.com, inc. - # All rights reserved. - # SPDX-License-Identifier: BSD-3-Clause - # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" - import streamlit as st from app import device, load_demo_image from app.utils import load_model_cache from lavis.processors import load_processor from PIL import Image +import random +import numpy as np +import torch + + +def setup_seeds(seed): + random.seed(seed) + np.random.seed(int(seed)) + torch.manual_seed(seed) + + if torch.cuda.is_available(): + import torch.backends.cudnn as cudnn + cudnn.benchmark = False + cudnn.deterministic = True + def app(): # ===== layout ===== - model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"]) + model_type = st.sidebar.selectbox("Model:", ["BLIP_large", "BLIP_base"]) sampling_method = st.sidebar.selectbox( "Sampling method:", ["Beam search", "Nucleus sampling"] ) + num_captions = 1 + if sampling_method == "Nucleus sampling": + random_seed = st.sidebar.text_input("Seed:", 1024, help="Set random seed to reproduce the image description") + setup_seeds(random_seed) + num_captions = st.sidebar.slider("Choose number of captions to generate", + min_value=1, max_value=5, step=1) st.markdown( "
<h1 style='text-align: center;'>Image Description Generation</h1>
", @@ -44,8 +58,8 @@ def app(): resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor))) col1.image(resized_image, use_column_width=True) - col2.header("Description") - + #col2.header("Description") + #with col2: cap_button = st.button("Generate") # ==== event ==== @@ -63,28 +77,39 @@ def app(): img = vis_processor(raw_img).unsqueeze(0).to(device) captions = generate_caption( - model=model, image=img, use_nucleus_sampling=not use_beam + model=model, image=img, use_nucleus_sampling=not use_beam, num_captions=num_captions ) - - col2.write("\n\n".join(captions), use_column_width=True) + + #with col2: + # for caption in captions: + # caption_md = '
{}
'.format(caption) + # st.markdown(caption_md, unsafe_allow_html=True) + with col2: + st.header("Description") + #with col2: + for caption in captions: + caption_md = '
{}
'.format(caption) + st.markdown(caption_md, unsafe_allow_html=True) + #col2.write("\n\n".join(captions), use_column_width=True) def generate_caption( - model, image, use_nucleus_sampling=False, num_beams=3, max_length=40, min_length=5 + model, image, use_nucleus_sampling=False, num_captions = 1, num_beams=3, max_length=40, min_length=5 ): samples = {"image": image} captions = [] if use_nucleus_sampling: - for _ in range(5): - caption = model.generate( + #for _ in range(5): + captions = model.generate( samples, use_nucleus_sampling=True, max_length=max_length, min_length=min_length, top_p=0.9, - ) - captions.append(caption[0]) + num_captions=num_captions + ) + #captions.append(caption[0]) else: caption = model.generate( samples, @@ -92,6 +117,7 @@ def generate_caption( num_beams=num_beams, max_length=max_length, min_length=min_length, + num_captions=1 ) captions.append(caption[0]) diff --git a/app/caption_front_end.py b/app/caption_front_end.py new file mode 100644 index 00000000..9d906c94 --- /dev/null +++ b/app/caption_front_end.py @@ -0,0 +1,105 @@ +import streamlit as st +from app import load_demo_image, job_output_path, pending_job_path +from app.utils import create_uniq_user_job_name +from lavis.processors import load_processor +from PIL import Image + +import os, time, subprocess +import random +import numpy as np +import torch + +job_type = 'caption' + +def setup_seeds(seed): + random.seed(seed) + np.random.seed(int(seed)) + torch.manual_seed(seed) + + if torch.cuda.is_available(): + import torch.backends.cudnn as cudnn + cudnn.benchmark = False + cudnn.deterministic = True + + +def app(): + # ===== layout ===== + model_type = st.sidebar.selectbox("Model:", ["BLIP_large", "BLIP_base"]) + + sampling_method = st.sidebar.selectbox( + "Sampling method:", ["Beam search", "Nucleus sampling"] + ) + num_captions = 1 + if sampling_method == "Nucleus sampling": + random_seed = st.sidebar.text_input("Seed:", 1024, help="Set random seed to reproduce the image description") + setup_seeds(random_seed) + num_captions = st.sidebar.slider("Choose number of captions to generate", + min_value=1, max_value=5, step=1) + + st.markdown( + "
<h1 style='text-align: center;'>Image Description Generation</h1>
", + unsafe_allow_html=True, + ) + + instructions = """Try the provided image or upload your own:""" + file = st.file_uploader(instructions) + + use_beam = sampling_method == "Beam search" + + col1, col2 = st.columns(2) + + if file: + raw_img = Image.open(file).convert("RGB") + else: + raw_img = load_demo_image() + + col1.header("Image") + + w, h = raw_img.size + scaling_factor = 720 / w + resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor))) + + col1.image(resized_image, use_column_width=True) + + vis_processor = load_processor("blip_image_eval").build(image_size=384) + img = vis_processor(raw_img).unsqueeze(0) + + col2.header("Description") + #with col2: + cap_button = st.button("Generate") + blip_type = model_type.split("_")[1].lower() + + if cap_button: + time_stamp = time.time() + pending_jobs = os.path.join(pending_job_path, job_type) + if not os.path.exists(pending_jobs): + #os.makedirs(pending_jobs) + subprocess.run(['mkdir', '-p', pending_jobs], shell=False) + file_name = '{}_result.txt'.format(create_uniq_user_job_name(time_stamp, sampling_method)) + with open(os.path.join(pending_jobs, file_name),'w') as new_job: + line = str(time_stamp)+'\t'+blip_type+'\t'+str(sampling_method)+'\t'+str(num_captions) + new_job.write(line+'\n') + new_job.close() + + num_pending_jobs = len(os.listdir(pending_jobs)) + outpath = os.path.join(job_output_path,job_type) + if not os.path.exists(outpath): + subprocess.run(['mkdir', '-p', outpath], shell=False) + search_result = outpath+'/{}_result.txt'.format(create_uniq_user_job_name(time_stamp, sampling_method)) + torch.save(img, outpath+'/{}_raw_image.pt'.format(create_uniq_user_job_name(time_stamp, sampling_method))) + + with st.spinner("Queuing (#{} in line)".format(num_pending_jobs)): + while True: + if os.path.exists(search_result): + time.sleep(1) + with open(search_result) as f: + count = 0 + with col2: + #st.header("Description") + for caption in f: + caption = caption.rstrip(' \n') + if count < num_captions: + caption_md = '
{}
'.format(caption) + st.markdown(caption_md, unsafe_allow_html=True) + count += 1 + break diff --git a/app/classification.py b/app/classification.py index 2b5bd896..df9e93a6 100644 --- a/app/classification.py +++ b/app/classification.py @@ -1,21 +1,15 @@ -""" - # Copyright (c) 2022, salesforce.com, inc. - # All rights reserved. - # SPDX-License-Identifier: BSD-3-Clause - # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" - import plotly.graph_objects as go import requests import streamlit as st import torch -from lavis.models import load_model +from lavis.models import BlipFeatureExtractor, load_model +from lavis.models.blip_models.blip_image_text_matching import BlipITM from lavis.processors import load_processor from lavis.processors.blip_processors import BlipCaptionProcessor from PIL import Image +import numpy as np from app import device, load_demo_image -from app.utils import load_blip_itm_model from lavis.processors.clip_processors import ClipImageEvalProcessor @@ -36,6 +30,7 @@ def load_demo_image(img_url=None): allow_output_mutation=True, ) def load_model_cache(model_type, device): + if model_type == "blip": model = load_model( "blip_feature_extractor", model_type="base", is_eval=True, device=device @@ -60,56 +55,77 @@ def load_model_cache(model_type, device): return model +@st.cache( + hash_funcs={ + torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach() + .cpu() + .numpy() + }, + allow_output_mutation=True, +) +def load_blip_itm_model(device): + pretrained_path = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth" + model = BlipITM(pretrained=pretrained_path, vit="base") + model.eval() + model = model.to(device) + return model + + def app(): model_type = st.sidebar.selectbox( "Model:", - ["ALBEF", "BLIP_Base", "CLIP_ViT-B-32", "CLIP_ViT-B-16", "CLIP_ViT-L-14"], + ["BLIP_Base", "CLIP_ViT-B-32", "CLIP_ViT-B-16", "CLIP_ViT-L-14"], ) score_type = st.sidebar.selectbox("Score type:", ["Cosine", "Multimodal"]) # ===== layout ===== st.markdown( - "
<h1 style='text-align: center;'>Zero-shot Classification</h1>",
+        "<h1 style='text-align: center;'>Zero-shot Classification</h1>
", unsafe_allow_html=True, ) instructions = """Try the provided image or upload your own:""" file = st.file_uploader(instructions) - st.header("Image") + col1,col2 = st.columns(2) + col1.header("Image") + #col2.header("Categories") + row2_col1,row2_col2 = st.columns(2) if file: raw_img = Image.open(file).convert("RGB") + st.session_state.new_image = 'yes' + st.session_state.category = 'yes' else: raw_img = load_demo_image() - st.image(raw_img) # , use_column_width=True) - - col1, col2 = st.columns(2) - - col1.header("Categories") - - cls_0 = col1.text_input("category 1", value="merlion") - cls_1 = col1.text_input("category 2", value="sky") - cls_2 = col1.text_input("category 3", value="giraffe") - cls_3 = col1.text_input("category 4", value="fountain") - cls_4 = col1.text_input("category 5", value="marina bay") - - cls_names = [cls_0, cls_1, cls_2, cls_3, cls_4] - cls_names = [cls_nm for cls_nm in cls_names if len(cls_nm) > 0] + w, h = raw_img.size + scaling_factor = 700 / w + resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor))) - if len(cls_names) != len(set(cls_names)): - st.error("Please provide unique class names") - return + row2_col1.image(resized_image, use_column_width=True) + cls_names = [] button = st.button("Submit") - - col2.header("Prediction") - + if 'cls_names' not in st.session_state or st.session_state.cls_names == '' or not button: + col2.header("Categories") + with row2_col2: + cls_0 = st.text_input("category 1", value="merlion") + cls_1 = st.text_input("category 2", value="elephant") + cls_2 = st.text_input("category 3", value="giraffe") + cls_3 = st.text_input("category 4", value="fountain") + cls_4 = st.text_input("category 5", value="marina bay") + cls_names = [cls_0, cls_1, cls_2, cls_3, cls_4 ] + cls_names = [cls_nm for cls_nm in cls_names if len(cls_nm) > 0] + st.session_state.cls_names = ','.join(cls_names) + + if len(cls_names) != len(set(cls_names)): + st.error("Please provide unique class names") + return # ===== event ===== - if button: if model_type.startswith("BLIP"): text_processor = BlipCaptionProcessor(prompt="A picture of ") + cls_names = st.session_state['cls_names'].split(',') cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names] if score_type == "Cosine": @@ -143,32 +159,6 @@ def app(): sims = torch.nn.Softmax(dim=0)(sims) inv_sims = [sim * 100 for sim in sims.tolist()[::-1]] - elif model_type.startswith("ALBEF"): - vis_processor = load_processor("blip_image_eval").build(image_size=224) - img = vis_processor(raw_img).unsqueeze(0).to(device) - - text_processor = BlipCaptionProcessor(prompt="A picture of ") - cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names] - - feature_extractor = load_model_cache(model_type="albef", device=device) - - sample = {"image": img, "text_input": cls_prompt} - - with torch.no_grad(): - image_features = feature_extractor.extract_features( - sample, mode="image" - ).image_embeds_proj[:, 0] - text_features = feature_extractor.extract_features( - sample, mode="text" - ).text_embeds_proj[:, 0] - - st.write(image_features.shape) - st.write(text_features.shape) - - sims = (image_features @ text_features.t())[0] / feature_extractor.temp - - sims = torch.nn.Softmax(dim=0)(sims) - inv_sims = [sim * 100 for sim in sims.tolist()[::-1]] elif model_type.startswith("CLIP"): if model_type == "CLIP_ViT-B-32": @@ -193,6 +183,9 @@ def app(): image_features = clip_features.image_embeds_proj text_features = clip_features.text_embeds_proj + image_features /= image_features.norm(dim=-1, 
keepdim=True) + text_features /= text_features.norm(dim=-1, keepdim=True) + sims = (100.0 * image_features @ text_features.T)[0].softmax(dim=-1) inv_sims = sims.tolist()[::-1] else: @@ -202,15 +195,38 @@ def app(): fig = go.Figure( go.Bar( x=inv_sims, - y=cls_names[::-1], - text=["{:.2f}".format(s) for s in inv_sims], + y=[c+' ' for c in cls_names[::-1]], + text=["{:.2f}%".format(s) for s in inv_sims], orientation="h", ) ) fig.update_traces( - textfont_size=12, + textfont_size=16, textangle=0, textposition="outside", cliponaxis=False, + marker_color="#0176D3" ) - col2.plotly_chart(fig, use_container_width=True) + fig.add_vline(x=0, line_width=1, line_color="#C9C9C9") + fig.add_hline(y=-0.6, line_width=1, line_color="#C9C9C9") + fig.update_layout(font=dict(family="Salesforce Sans", size=25, color="#032D60")) + fig.update_layout( + xaxis = dict( + tickmode='linear', + tickfont = dict(size=16), + tick0=0, + dtick=20, + ticksuffix="%" + ), + yaxis = dict(tickfont = dict(size=16)), + plot_bgcolor= "rgba(0, 0, 0, 0)", + title="Zero-shot image classification", + title_font_family="Salesforce Sans", + title_font_size=28, + title_font_color="#032D60", + paper_bgcolor= "rgba(0, 0, 0, 0)",) + fig.update_xaxes(fixedrange=True, dtick=20) + fig.update_xaxes(range=[0, 100], ticks="outside", tickson="boundaries", ticklen=6) + col2.header("Prediction") + with row2_col2: + st.plotly_chart(fig, use_container_width=True) diff --git a/app/dataset_browser.py b/app/dataset_explorer.py similarity index 79% rename from app/dataset_browser.py rename to app/dataset_explorer.py index 6b761d89..db67ce2a 100644 --- a/app/dataset_browser.py +++ b/app/dataset_explorer.py @@ -1,21 +1,17 @@ -""" - # Copyright (c) 2022, salesforce.com, inc. - # All rights reserved. - # SPDX-License-Identifier: BSD-3-Clause - # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" - -import random from collections import OrderedDict from functools import reduce +import os from tkinter import N - import streamlit as st -from lavis.common.registry import registry -from lavis.datasets.builders import dataset_zoo, load_dataset -from lavis.datasets.builders.base_dataset_builder import load_dataset_config +import streamlit.components.v1 as components + +import random from PIL import Image +from lavis.datasets.builders import load_dataset, dataset_zoo +from lavis.datasets.builders.base_dataset_builder import load_dataset_config +from lavis.common.registry import registry + IMAGE_LAYOUT = 3, 4 VIDEO_LAYOUT = 1, 2 @@ -29,6 +25,18 @@ def sample_dataset(dataset, indices): return samples +# def create_gif_from_video(video_path): +# import imageio +# import os + +# video = imageio.get_reader(video_path) +# fps = video.get_meta_data()["fps"] +# images = [] +# for i in range(video.get_length()): +# images.append(video.get_data(i)) +# imageio.mimsave(os.path.splitext(video_path)[0] + ".gif", images, fps=fps) + + def get_concat_v(im1, im2): margin = 5 @@ -137,6 +145,16 @@ def show_samples(dataset, offset=0, is_next=False): col.markdown( "![Alt Text](https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif)" ) + # col.video(open(next(visual_info), "rb")) + # video_path = "/export/share/dongxuli/data/msrvtt_retrieval/videos/video0.mp4" + # col.video(next(visual_info)) + # col.markdown( + # f"""""", + # unsafe_allow_html=True, + # ) except StopIteration: break @@ -148,6 +166,22 @@ def show_samples(dataset, offset=0, is_next=False): st.session_state.n_display = n_samples +def show_dataset_card(): + builder 
= registry.get_builder_class(dataset_name) + cfg_path = builder.default_config_path() + config = load_dataset_config(cfg_path) + data_card = config.get("dataset_card", None) + + if data_card is None: + st.warning(f"No dataset card found for {dataset_name}.") + else: + img_path = data_card.replace("md", "png") + img = resize_img_w(Image.open(img_path), new_w=672) + st.image(img) + + st.markdown(open(data_card).read()) + + if __name__ == "__main__": st.set_page_config( page_title="LAVIS Dataset Explorer", @@ -157,9 +191,9 @@ def show_samples(dataset, offset=0, is_next=False): dataset_name = st.sidebar.selectbox("Dataset:", dataset_zoo.get_names()) - function = st.sidebar.selectbox("Function:", ["Browser"], index=0) + function = st.sidebar.selectbox("Function:", ["Dataset Card", "Explorer"], index=0) - if function == "Browser": + if function == "Explorer": shuffle = st.sidebar.selectbox("Shuffled:", [True, False], index=0) dataset = load_dataset_cache(dataset_name) @@ -238,3 +272,5 @@ def show_samples(dataset, offset=0, is_next=False): offset=st.session_state.last_start - st.session_state.start_idx, is_next=True, ) + elif function == "Dataset Card": + show_dataset_card() diff --git a/app/image_text_match.py b/app/image_text_match.py index e7957384..1b7a35fb 100644 --- a/app/image_text_match.py +++ b/app/image_text_match.py @@ -1,23 +1,18 @@ -""" - # Copyright (c) 2022, salesforce.com, inc. - # All rights reserved. - # SPDX-License-Identifier: BSD-3-Clause - # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" - import numpy as np import streamlit as st import torch -from lavis.models.blip_models.blip_image_text_matching import compute_gradcam +from app import device, load_demo_image +from app.utils import ( + init_bert_tokenizer, + getAttMap, + load_blip_itm_model, +) from lavis.processors import load_processor from PIL import Image - -from app import device, load_demo_image -from app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model - +from lavis.models.blip_models.blip_image_text_matching import compute_gradcam def app(): - model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"]) + model_type = st.sidebar.selectbox("Model:", ["BLIP_large", "BLIP_base"]) if model_type.startswith("BLIP"): blip_type = model_type.split("_")[1] @@ -74,7 +69,7 @@ def app(): qry_tok = tokenizer(qry, return_tensors="pt").to(device) gradcam, output = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num) - avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True) + avg_gradcam = getAttMap(norm_img, gradcam[1], blur=True) col2.image(avg_gradcam, use_column_width=True, clamp=True) # output = model(img, question) diff --git a/app/main.py b/app/main.py index 108c46f8..6ee06a7d 100644 --- a/app/main.py +++ b/app/main.py @@ -1,25 +1,36 @@ -""" - # Copyright (c) 2022, salesforce.com, inc. - # All rights reserved. 
- # SPDX-License-Identifier: BSD-3-Clause - # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" - from app.multipage import MultiPage from app import vqa, caption +#from app import caption_front_end as caption from app import image_text_match as itm from app import text_localization as tl -from app import multimodal_search as ms +#from app import multimodal_search as ms +from app import multimodal_search_front_end as ms from app import classification as cl +from PIL import Image +import streamlit as st +from app import txt2image_front_end as ig if __name__ == "__main__": app = MultiPage() + logo = Image.open("app/logo_color.png") + st.sidebar.image(logo.resize((592, 157))) + + # add Salesforce Logo on top right + st.markdown( + "", + unsafe_allow_html=True, + ) + + # load custom css + with open("app/style_override.css") as f: + st.markdown(f'', unsafe_allow_html=True) + app.add_page("Image Description Generation", caption.app) app.add_page("Multimodal Search", ms.app) app.add_page("Visual Question Answering", vqa.app) - app.add_page("Image Text Matching", itm.app) - app.add_page("Text Localization", tl.app) - app.add_page("Classification", cl.app) + app.add_page("Zero-shot Image Classification", cl.app) + app.add_page("Text-to-Image Generation", ig.app) + app.add_page("Text Visualization", tl.app) app.run() diff --git a/app/multimodal_search.py b/app/multimodal_search.py index ffc97664..43f6ba6b 100644 --- a/app/multimodal_search.py +++ b/app/multimodal_search.py @@ -1,10 +1,3 @@ -""" - # Copyright (c) 2022, salesforce.com, inc. - # All rights reserved. - # SPDX-License-Identifier: BSD-3-Clause - # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" - import os import numpy as np @@ -19,7 +12,7 @@ read_img, resize_img, ) -from lavis.models import load_model +from lavis.models import BlipFeatureExtractor, load_model from lavis.processors import load_processor @@ -32,17 +25,12 @@ allow_output_mutation=True, ) def load_feat(): - from lavis.common.utils import download_url - - dirname = os.path.join(os.path.dirname(__file__), "assets") - filename = "path2feat_coco_train2014.pth" - filepath = os.path.join(dirname, filename) - url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth" - - if not os.path.exists(filepath): - download_url(url=url, root=dirname, filename="path2feat_coco_train2014.pth") - - path2feat = torch.load(filepath) + path2feat = torch.load( + os.path.join( + os.path.dirname(__file__), + "/export/home/.cache/lavis/path2feat_coco_train2014.pth", + ) + ) paths = sorted(path2feat.keys()) all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device) @@ -61,9 +49,7 @@ def load_feat(): def load_feature_extractor_model(device): model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth" - model = load_model( - "blip_feature_extractor", model_type="base", is_eval=True, device=device - ) + model = load_model("blip_feature_extractor", is_eval=True, device=device) model.load_from_pretrained(model_url) return model @@ -91,93 +77,101 @@ def app(): vis_processor = load_processor("blip_image_eval").build(image_size=384) text_processor = load_processor("blip_caption") - user_question = st.text_input( - "Search query", "A dog running on the grass.", help="Type something to search." 
- ) - user_question = text_processor(user_question) - feature_extractor = load_feature_extractor_model(device) + row1_1, row1_spacer1, row1_2, row1_spacer2 = st.columns((15.5, .1, 3.5, 0.1)) + with row1_1: + user_question = st.text_input( + "Search query", "A dog running on the grass.", help="Type something to search." + ) + with row1_2: + st.markdown("") + st.markdown("") + search_button = st.button("Search") - # ======= ITC ========= - sample = {"text_input": user_question} + if search_button: + user_question = text_processor(user_question) + feature_extractor = load_feature_extractor_model(device) - with torch.no_grad(): - text_feature = feature_extractor.extract_features( - sample, mode="text" - ).text_embeds_proj[0, 0] + # ======= ITC ========= + sample = {"text_input": user_question} - path2feat, paths, all_img_feats = load_feat() - all_img_feats.to(device) - all_img_feats = F.normalize(all_img_feats, dim=1) + with torch.no_grad(): + text_feature = feature_extractor.extract_features( + sample, mode="text" + ).text_features[0, 0] - num_cols = 4 - num_rows = int(num_display / num_cols) + path2feat, paths, all_img_feats = load_feat() + all_img_feats.to(device) + all_img_feats = F.normalize(all_img_feats, dim=1) - similarities = text_feature @ all_img_feats.T - indices = torch.argsort(similarities, descending=True)[:num_display] + num_cols = 4 + num_rows = int(num_display / num_cols) - top_paths = [paths[ind.detach().cpu().item()] for ind in indices] - sorted_similarities = [similarities[idx] for idx in indices] - filenames = [os.path.join(file_root, p) for p in top_paths] + similarities = text_feature @ all_img_feats.T + indices = torch.argsort(similarities, descending=True)[:num_display] - # ========= ITM and GradCam ========== - bsz = 4 # max number of images to avoid cuda oom - if model_type.startswith("BLIP"): - blip_type = model_type.split("_")[1] + top_paths = [paths[ind.detach().cpu().item()] for ind in indices] + sorted_similarities = [similarities[idx] for idx in indices] + filenames = [os.path.join(file_root, p) for p in top_paths] - itm_model = load_blip_itm_model(device, model_type=blip_type) + # ========= ITM and GradCam ========== + bsz = 4 # max number of images to avoid cuda oom + if model_type.startswith("BLIP"): + blip_type = model_type.split("_")[1] - tokenizer = init_bert_tokenizer() - queries_batch = [user_question] * bsz - queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(device) + itm_model = load_blip_itm_model(device, model_type=blip_type) - num_batches = int(num_display / bsz) + tokenizer = init_bert_tokenizer() + queries_batch = [user_question] * bsz + queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(device) - avg_gradcams = [] - all_raw_images = [] - itm_scores = [] + num_batches = int(num_display / bsz) - for i in range(num_batches): - filenames_in_batch = filenames[i * bsz : (i + 1) * bsz] - raw_images, images = read_and_process_images(filenames_in_batch, vis_processor) - gradcam, itm_output = compute_gradcam_batch( - itm_model, images, queries_batch, queries_tok_batch - ) + avg_gradcams = [] + all_raw_images = [] + itm_scores = [] - all_raw_images.extend([resize_img(r_img) for r_img in raw_images]) - norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images] + for i in range(num_batches): + filenames_in_batch = filenames[i * bsz : (i + 1) * bsz] + raw_images, images = read_and_process_images(filenames_in_batch, vis_processor) + gradcam, itm_output = compute_gradcam_batch( + itm_model, images, queries_batch, 
queries_tok_batch + ) - for norm_img, grad_cam in zip(norm_imgs, gradcam): - avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True) - avg_gradcams.append(avg_gradcam) + all_raw_images.extend([resize_img(r_img) for r_img in raw_images]) + norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images] - with torch.no_grad(): - itm_score = torch.nn.functional.softmax(itm_output, dim=1) + for norm_img, grad_cam in zip(norm_imgs, gradcam): + avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True) + avg_gradcams.append(avg_gradcam) + + with torch.no_grad(): + itm_score = torch.nn.functional.softmax(itm_output, dim=1) - itm_scores.append(itm_score) + itm_scores.append(itm_score) - # ========= ITM re-ranking ========= - itm_scores = torch.cat(itm_scores)[:, 1] - if itm_ranking: - itm_scores_sorted, indices = torch.sort(itm_scores, descending=True) + # ========= ITM re-ranking ========= + itm_scores = torch.cat(itm_scores)[:, 1] + if itm_ranking: + itm_scores_sorted, indices = torch.sort(itm_scores, descending=True) - avg_gradcams_sorted = [] - all_raw_images_sorted = [] - for idx in indices: - avg_gradcams_sorted.append(avg_gradcams[idx]) - all_raw_images_sorted.append(all_raw_images[idx]) + avg_gradcams_sorted = [] + all_raw_images_sorted = [] + for idx in indices: + avg_gradcams_sorted.append(avg_gradcams[idx]) + all_raw_images_sorted.append(all_raw_images[idx]) - avg_gradcams = avg_gradcams_sorted - all_raw_images = all_raw_images_sorted + avg_gradcams = avg_gradcams_sorted + all_raw_images = all_raw_images_sorted - if show_gradcam: - images_to_show = iter(avg_gradcams) - else: - images_to_show = iter(all_raw_images) + if show_gradcam: + images_to_show = iter(avg_gradcams) + else: + images_to_show = iter(all_raw_images) - for _ in range(num_rows): - with st.container(): - for col in st.columns(num_cols): - col.image(next(images_to_show), use_column_width=True, clamp=True) + for _ in range(num_rows): + with st.container(): + for col in st.columns(num_cols): + col.image(next(images_to_show), use_column_width=True, clamp=True) def read_and_process_images(image_paths, vis_processor): @@ -193,7 +187,7 @@ def compute_gradcam_batch(model, visual_input, text_input, tokenized_text, block block_num ].crossattention.self.save_attention = True - output = model({"image": visual_input, "text_input": text_input}, match_head="itm") + output = model(visual_input, text_input, match_head="itm") loss = output[:, 1].sum() model.zero_grad() diff --git a/app/multimodal_search_front_end.py b/app/multimodal_search_front_end.py new file mode 100644 index 00000000..3fb51d9a --- /dev/null +++ b/app/multimodal_search_front_end.py @@ -0,0 +1,112 @@ +import os, time, subprocess + +import numpy as np +import streamlit as st +import torch +from app import cache_root, device, pending_job_path, job_output_path +from app.utils import create_uniq_user_job_name + + +job_type = 'search' + +def app(): + # === layout === + model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"]) + if model_type.startswith("BLIP"): + blip_type = model_type.split("_")[1] + file_root = os.path.join(cache_root, "coco/images/train2014/") + + values = [16, 24, 48] + default_layer_num = values.index(24) + num_display = st.sidebar.selectbox( + "Number of images:", values, index=default_layer_num + ) + show_gradcam = st.sidebar.selectbox("Show GradCam:", [True, False], index=1) + itm_ranking = st.sidebar.selectbox("Multimodal re-ranking:", [True, False], index=0) + + st.markdown( + "
<h1 style='text-align: center;'>Multimodal Search</h1>
", unsafe_allow_html=True + ) + + row1_1, row1_spacer1, row1_2, row1_spacer2 = st.columns((15.5, .1, 3.5, 0.1)) + with row1_1: + user_question = st.text_input( + "Search query", "A dog running on the grass.", help="Type something to search." + ) + with row1_2: + st.markdown("") + st.markdown("") + search_button = st.button("Search") + + if search_button: + time_stamp = time.time() + pending_path = os.path.join(pending_job_path, job_type) + if not os.path.exists(pending_path): + subprocess.run(['mkdir', '-p', pending_path], shell=False) + file_name = '{}_result.txt'.format(create_uniq_user_job_name(str(time_stamp), user_question)) + with open(os.path.join(pending_path, file_name),'w') as new_job: + line = str(time_stamp)+'\t'+user_question+'\t'+str(num_display)+'\t'+blip_type + new_job.write(line+'\n') + new_job.close() + + num_pending_jobs = len(os.listdir(pending_path)) + outpath = os.path.join(job_output_path, job_type) + search_result = os.path.join(outpath, file_name) + + filenames = [] + with st.spinner("Queuing (#{} in line)".format(num_pending_jobs)): + while True: + if os.path.exists(search_result): + time.sleep(1) + with open(search_result) as f: + count = 0 + for line in f: + if count < num_display: + p = os.path.join(file_root, line.rstrip('\n')) + filenames.append(p) + count += 1 + break + # ========= ITM and GradCam ========== + + itm_scores_pt = outpath+'/{}_itm.pt'.format(create_uniq_user_job_name(str(time_stamp), user_question)) + itm_scores = torch.load(itm_scores_pt, map_location=torch.device('cpu')) + os.remove(itm_scores_pt) + + avg_gradcams_pt = outpath+'/{}_avg_gradcams.npy'.format(create_uniq_user_job_name(str(time_stamp), user_question)) + avg_gradcams = np.load(avg_gradcams_pt, allow_pickle=True) + + os.remove(avg_gradcams_pt) + + all_raw_images_pt = outpath+'/{}_all_raw_images.npy'.format(create_uniq_user_job_name(str(time_stamp), user_question)) + all_raw_images = np.load(all_raw_images_pt, allow_pickle=True) + os.remove(all_raw_images_pt) + + # ========= ITM re-ranking ========= + if itm_ranking: + itm_scores_sorted, indices = torch.sort(itm_scores, descending=True) + + avg_gradcams_sorted = [] + all_raw_images_sorted = [] + for idx in indices: + avg_gradcams_sorted.append(avg_gradcams[idx]) + all_raw_images_sorted.append(all_raw_images[idx]) + + avg_gradcams = avg_gradcams_sorted + all_raw_images = all_raw_images_sorted + + if show_gradcam: + images_to_show = iter(avg_gradcams) + else: + images_to_show = iter(all_raw_images) + + num_cols = 4 + num_rows = int(num_display / num_cols) + for _ in range(num_rows): + with st.container(): + for col in st.columns(num_cols): + col.image(next(images_to_show), use_column_width=True, clamp=True) + + + + + diff --git a/app/multipage.py b/app/multipage.py index 040f76eb..70912656 100644 --- a/app/multipage.py +++ b/app/multipage.py @@ -1,10 +1,3 @@ -""" - # Copyright (c) 2022, salesforce.com, inc. - # All rights reserved. - # SPDX-License-Identifier: BSD-3-Clause - # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" - """ This file is the framework for generating multiple Streamlit applications through an object oriented framework. 
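For reference, the new *_front_end.py pages and the scripts under app/backend/ talk to each other only through the file-based queue rooted at the pending_job_path, job_output_path, and finished_job_path directories defined in app/__init__.py: the front end writes a one-line, tab-separated job file and then polls for a result file, while the backend picks up the oldest pending job (get_pending_jobs in app/utils.py), writes its outputs, and moves the job file to finished_jobs. A minimal sketch of that handshake is below; the helper names are illustrative and not part of the patch.

import glob, os, shutil, time

PENDING = "app/task_queues/pending_jobs"
FINISHED = "app/task_queues/finished_jobs"
OUTPUTS = "app/task_queues/outputs"

def submit_job(job_type, fields):
    # Front end: drop a one-line, tab-separated job description into the pending queue.
    job_dir = os.path.join(PENDING, job_type)
    os.makedirs(job_dir, exist_ok=True)
    time_stamp = str(time.time())
    job_file = os.path.join(job_dir, time_stamp.replace(".", "_") + "_result.txt")
    with open(job_file, "w") as f:
        f.write("\t".join([time_stamp] + [str(x) for x in fields]) + "\n")
    return time_stamp

def pending_jobs(job_type):
    # Backend: list job files oldest-first, mirroring get_pending_jobs() in app/utils.py.
    jobs = glob.glob(os.path.join(PENDING, job_type, "*.txt"))
    return sorted(filter(os.path.isfile, jobs), key=os.path.getmtime)

def finish_job(job_type, job_file, result_lines):
    # Backend: write the result file the front end is polling for, then retire the job.
    out_dir = os.path.join(OUTPUTS, job_type)
    done_dir = os.path.join(FINISHED, job_type)
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(done_dir, exist_ok=True)
    with open(os.path.join(out_dir, os.path.basename(job_file)), "w") as f:
        f.write("\n".join(result_lines) + "\n")
    shutil.move(job_file, done_dir)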
diff --git a/app/style_override.css b/app/style_override.css new file mode 100644 index 00000000..04f8a292 --- /dev/null +++ b/app/style_override.css @@ -0,0 +1,57 @@ +body, div, span { + font-family: 'Salesforce Sans' !important; +} + +section.main { + align-items: self-start; + margin-left: 2rem; +} + +button[kind="header"], header { + display: none !important; +} + +section[data-testid="stSidebar"] > div:nth-child(1) > div:nth-child(2) { + padding-top: 0; +} +div[data-testid="stImage"] img { + margin-bottom: 2rem; +} + +div.block-container { + padding-top: 1rem; + max-width: 98%; +} + +div[data-testid="stMarkdownContainer"] h1 { + text-align: left !important; + padding-top: 15px; +} + +h1, h2 { + color: #032D60; +} + +section[data-testid="stFileUploadDropzone"] { + max-width: 30rem; +} + +.stButton { + margin-top: 15px; +} +.stButton button { + background-color: #0176D3 !important; + color: white !important; +} +.stButton button:hover { + background-color: #032D60 !important; +} + +div[data-testid="stHorizontalBlock"] { + gap: 4rem; + padding-top: 2rem; +} +div[data-testid="stHorizontalBlock"] div[data-testid="column"] { + width: calc(50% - 2rem); + flex: 1 1 calc(50% - 2rem); +} diff --git a/app/txt2image_front_end.py b/app/txt2image_front_end.py new file mode 100644 index 00000000..2a3910b0 --- /dev/null +++ b/app/txt2image_front_end.py @@ -0,0 +1,62 @@ +import streamlit as st +from PIL import Image +import os, time, subprocess +from app import pending_job_path, job_output_path +from app.utils import create_uniq_user_job_name + +job_type='txt2image' + +def compute_grid(num_images): + cols = [] + for i in range(num_images): + cols.append(st.columns()) + return cols + +def app(): + num_images = st.sidebar.slider("Choose number of images to generate", + min_value=1, max_value=3, step=1) + random_seed = st.sidebar.text_input("Seed:", 1024, + help="Set random seed to reproduce the generated images") + st.markdown( + "
<h1 style='text-align: center;'>Image Generation</h1>
", unsafe_allow_html=True + ) + row1_1, row1_spacer1, row1_2, row1_spacer2 = st.columns((15.5, .1, 3.5, 0.1)) + with row1_1: + user_prompt = st.text_input( + "Describe the image you would like to generate", + "a painting of Singapore Garden By the Bay in the style of Vincent Van Gogh", + help="Try something creative." + ) + with row1_2: + st.write("") + st.write("") + generation_button = st.button("Generate") + + if generation_button: + time_stamp = str(time.time()) + file_name = str(random_seed)+'\t'+str(time_stamp)+'\t'+user_prompt[:50]+'\t'+str(num_images)+'.txt' + pending_path = os.path.join(pending_job_path, job_type) + if not os.path.exists(pending_path): + subprocess.run(['mkdir', '-p', pending_path], shell=False) + with open(os.path.join(pending_path,file_name),'w') as new_job: + line = str(random_seed)+'\t'+time_stamp+'\t'+user_prompt+'\t'+str(num_images) + new_job.write(line+'\n') + new_job.close() + outpath = os.path.join(job_output_path, job_type) + if not os.path.exists(outpath): + os.mkdir(outpath) + sample_path = os.path.join(outpath, "samples") + if not os.path.exists(sample_path): + subprocess.run(['mkdir', '-p', sample_path], shell=False) + generated_image = sample_path+'/{}_grid.png'.format(create_uniq_user_job_name(time_stamp, user_prompt)) + num_pending_jobs = len(os.listdir(pending_path)) + with st.spinner("Queuing (#{} in line) for generation".format(num_pending_jobs)): + while True: + if os.path.exists(generated_image): + time.sleep(1) + cols = st.columns(num_images) + for i in range(1, num_images+1): + img = sample_path+'/{}_{}.png'.format(create_uniq_user_job_name(time_stamp, user_prompt), i) + with cols[i-1]: + st.image(Image.open(img)) + break diff --git a/app/utils.py b/app/utils.py index 5a4f209d..7617fa8a 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,19 +1,25 @@ -""" - # Copyright (c) 2022, salesforce.com, inc. - # All rights reserved. 
- # SPDX-License-Identifier: BSD-3-Clause - # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" - -import numpy as np import streamlit as st import torch -from lavis.models import BlipBase, load_model -from matplotlib import pyplot as plt +import numpy as np +from app import pending_job_path + +import os, glob + from PIL import Image + +from matplotlib import pyplot as plt from scipy.ndimage import filters from skimage import transform as skimage_transform +from lavis.models import BlipBase, load_model + +def get_pending_jobs(job_type): + list_of_prompts = filter(os.path.isfile, + glob.glob('{}/{}/*.txt'.format(pending_job_path, job_type) )) + # Sort list of files based on last modification time in ascending order + list_of_prompts = sorted(list_of_prompts, + key = os.path.getmtime) + return list(list_of_prompts) def resize_img(raw_img): w, h = raw_img.size @@ -46,6 +52,7 @@ def init_bert_tokenizer(): return tokenizer + def getAttMap(img, attMap, blur=True, overlap=True): attMap -= attMap.min() if attMap.max() > 0: @@ -75,7 +82,18 @@ def getAttMap(img, attMap, blur=True, overlap=True): allow_output_mutation=True, ) def load_blip_itm_model(device, model_type="base"): + # if model_type == "large": + # pretrained_path = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth" + # else: + # pretrained_path = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth" + # model = BlipITM(pretrained=pretrained_path, vit=model_type) + # model.eval() + # model = model.to(device) + # return model model = load_model( "blip_image_text_matching", model_type, is_eval=True, device=device ) return model + +def create_uniq_user_job_name(time_stamp, user_info): + return str(time_stamp).replace('.','_') + '_' + '_'.join(user_info.split(' ')[:20]) diff --git a/app/vqa.py b/app/vqa.py index c505a985..4ade6a1c 100644 --- a/app/vqa.py +++ b/app/vqa.py @@ -1,10 +1,3 @@ -""" - # Copyright (c) 2022, salesforce.com, inc. - # All rights reserved. - # SPDX-License-Identifier: BSD-3-Clause - # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" - import streamlit as st from app import load_demo_image, device from app.utils import load_model_cache @@ -13,7 +6,7 @@ def app(): - model_type = st.sidebar.selectbox("Model:", ["BLIP"]) + model_type = st.sidebar.selectbox("Model:", ["BLIP", "BLIP_aokvqa", "BLIP_okvqa"]) # ===== layout ===== st.markdown( @@ -39,10 +32,14 @@ def app(): col1.image(resized_image, use_column_width=True) col2.header("Question") - user_question = col2.text_input("Input your question!", "What are objects there?") - qa_button = st.button("Submit") - - col2.header("Answer") + with col2: + #user_question = col2.text_input("Input your question!", "What are objects there?") + #question = '
Input your question!
' + #st.markdown(question, unsafe_allow_html=True) + user_question = col2.text_area("Input your question!", "What are objects there?") + qa_button = st.button("Answer my question") + + #col2.header("Answer") # ===== event ===== vis_processor = load_processor("blip_image_eval").build(image_size=480) @@ -53,11 +50,19 @@ def app(): model = load_model_cache( "blip_vqa", model_type="vqav2", is_eval=True, device=device ) - + if model_type == "BLIP_aokvqa": + model.load_from_pretrained("https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_aokvqa.pth") + elif model_type == "BLIP_okvqa": + model.load_from_pretrained("https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_okvqa.pth") img = vis_processor(raw_img).unsqueeze(0).to(device) question = text_processor(user_question) vqa_samples = {"image": img, "text_input": [question]} answers = model.predict_answers(vqa_samples, inference_method="generate") - col2.write("\n".join(answers), use_column_width=True) + with col2: + st.header("Answer") + for answer in answers: + answer_md = '
{}
'.format(answer) + st.markdown(answer_md, unsafe_allow_html=True) + #col2.write("\n".join(answers), use_column_width=True) From 3d7b0d9cc3209e0210dc5a57fc72e452d80a41f1 Mon Sep 17 00:00:00 2001 From: Guangsen Wang Date: Fri, 9 Dec 2022 19:27:58 +0800 Subject: [PATCH 2/2] push files needed for GCP deployment --- Dockerfile | 53 ++++++++++++++++++++++++++++++++ docker/blip_pod_a100_docker.yaml | 48 +++++++++++++++++++++++++++++ docker/blip_service_a100.yaml | 12 ++++++++ docker/create_docker_image.sh | 3 ++ requirements-app.txt | 40 ++++++++++++++++++++++++ run_scripts/run_local.sh | 14 +++++++++ run_scripts/start_lavis_app.sh | 5 +++ 7 files changed, 175 insertions(+) create mode 100644 Dockerfile create mode 100644 docker/blip_pod_a100_docker.yaml create mode 100644 docker/blip_service_a100.yaml create mode 100644 docker/create_docker_image.sh create mode 100644 requirements-app.txt create mode 100755 run_scripts/run_local.sh create mode 100755 run_scripts/start_lavis_app.sh diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..47d7045d --- /dev/null +++ b/Dockerfile @@ -0,0 +1,53 @@ +FROM nvcr.io/nvidia/pytorch:21.06-py3 + +COPY requirements-app.txt requirements_gpu.txt + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + locales \ + wget \ + build-essential \ + vim \ + htop \ + curl \ + git less ssh cmake \ + zip unzip gzip bzip2 \ + python3-tk gcc g++ libpq-dev + +RUN apt -y install openssh-server openssh-client +# BLIP-specific commands +RUN apt-get install -y libxtst6 +RUN pip3 uninstall -y torch +RUN pip3 uninstall -y torchtext +RUN pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html +RUN pip3 install omegaconf +RUN pip3 install ipython +RUN pip3 install pycocoevalcap +RUN pip3 install pycocotools +RUN pip3 install timm==0.4.12 +RUN pip3 install fairscale==0.4.4 +RUN apt install -y default-jre +RUN apt install -y openjdk-11-jre-headless +RUN apt install -y openjdk-8-jre-headless +RUN pip uninstall opencv-python +RUN pip uninstall opencv-contrib-python +RUN pip uninstall opencv-contrib-python-headless + + +RUN pip3 install -r requirements_gpu.txt + + +COPY . 
/lavis_app +WORKDIR /lavis_app + +RUN wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt +RUN mv sd-v1-4.ckpt /lavis_app/stable-diffusion/sd-v1-4.ckpt + +ENV PYTHONPATH="${PYTHONPATH}:./:/lavis_app:/lavis_app/stable-diffusion" + +EXPOSE 8080 +RUN chmod +x /lavis_app/run_scripts/start_lavis_app.sh +ENTRYPOINT ["/lavis_app/run_scripts/start_lavis_app.sh" ] + + diff --git a/docker/blip_pod_a100_docker.yaml b/docker/blip_pod_a100_docker.yaml new file mode 100644 index 00000000..99ab1300 --- /dev/null +++ b/docker/blip_pod_a100_docker.yaml @@ -0,0 +1,48 @@ +## cramaiah_nodot=`echo ${USER} | sed s/[.]/-/g`; sed "s/cramaiah/${cramaiah_nodot}/g" sfr-pod-cramaiah.yaml > sfr-pod-${cramaiah_nodot}.yaml +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: sfr-pod-blip-server-a100-docker + namespace: sfr-ns-guangsen-wang + labels: + app: detector-blip-server-a100-docker +spec: + replicas: 1 + selector: + matchLabels: + app: detector-blip-server-a100-docker + template: + metadata: + labels: + app: detector-blip-server-a100-docker + spec: + volumes: + - name: sfr-home-pv-guangsen-wang + persistentVolumeClaim: + claimName: sfr-home-pvc-guangsen-wang + - name: sfr-share-pv-guangsen-wang + persistentVolumeClaim: + claimName: sfr-share-pvc-guangsen-wang + containers: + - name: lavis-pytorch + image: "gcr.io/salesforce-research-internal/lavis_streamlit_gpu" + ports: + - containerPort: 8080 + resources: + limits: + nvidia.com/gpu: 2 + cpu: "23" + memory: 150G + volumeMounts: + - name: sfr-home-pv-guangsen-wang + mountPath: "/export/home" + - name: sfr-share-pv-guangsen-wang + mountPath: "/export/share" + nodeSelector: + cloud.google.com/gke-accelerator: nvidia-tesla-a100 + tolerations: + - key: "gpu_num" + operator: "Equal" + value: "2" + effect: "NoSchedule" diff --git a/docker/blip_service_a100.yaml b/docker/blip_service_a100.yaml new file mode 100644 index 00000000..66f8b220 --- /dev/null +++ b/docker/blip_service_a100.yaml @@ -0,0 +1,12 @@ +apiVersion: v1 +kind: Service +metadata: + name: sfr-pod-blip-server-a100-docker +spec: + type: LoadBalancer + selector: + app: detector-blip-server-a100-docker + ports: + - name: detector-blip-server-a100-docker + port: 8080 + targetPort: 8080 diff --git a/docker/create_docker_image.sh b/docker/create_docker_image.sh new file mode 100644 index 00000000..97affdf3 --- /dev/null +++ b/docker/create_docker_image.sh @@ -0,0 +1,3 @@ +#!/bin/bash +TAG="gcr.io/salesforce-research-internal/lavis_streamlit_gpu" +gcloud builds submit . 
-t=$TAG --machine-type=n1-highcpu-32 --timeout=9000 diff --git a/requirements-app.txt b/requirements-app.txt new file mode 100644 index 00000000..c8988daf --- /dev/null +++ b/requirements-app.txt @@ -0,0 +1,40 @@ +contexttimer +decord>=0.6.0 +einops>=0.4.1 +fairscale==0.4.4 +ftfy +iopath +ipython +omegaconf>=2.1.2 +opencv-python==4.5.5.64 +opendatasets +packaging +pandas +plotly +pre-commit +pycocoevalcap +pycocotools +python-magic +timm==0.4.12 +torch==1.10.0 +torchvision==0.11.1 +setuptools==59.5.0 +tqdm +webdataset +wheel +torchtext==0.11.0 +albumentations==0.4.3 +diffusers +pudb==2019.2 +invisible-watermark +imageio==2.9.0 +imageio-ffmpeg==0.4.2 +pytorch-lightning==1.4.2 +test-tube>=0.7.5 +streamlit>=0.73.1 +torch-fidelity==0.3.0 +transformers==4.19.2 +torchmetrics==0.6.0 +kornia==0.6 +-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers +-e git+https://github.com/openai/CLIP.git@main#egg=clip diff --git a/run_scripts/run_local.sh b/run_scripts/run_local.sh new file mode 100755 index 00000000..6f372cdd --- /dev/null +++ b/run_scripts/run_local.sh @@ -0,0 +1,14 @@ +#!/bin/bash +conda create -n lavis_local python=3.8 +conda init bash +# YOU MAY NEED TO RESTART THE SHELL FOR CONDA INIT TO TAKE EFFECT +echo "Restart shell and continue the rest" +exit 1 +conda activate lavis_local +git clone https://github.com/MetaMind/LAVIS.git +cd LAVIS +git clone https://github.com/CompVis/stable-diffusion.git +git checkout lavis_diffusion +pip install -r requirements-dev.txt +export PYTHONPATH=./:$PYTHONPATH:./stable-diffusion +streamlit run --server.port 8080 app/main.py diff --git a/run_scripts/start_lavis_app.sh b/run_scripts/start_lavis_app.sh new file mode 100755 index 00000000..7db0e239 --- /dev/null +++ b/run_scripts/start_lavis_app.sh @@ -0,0 +1,5 @@ +#!/bin/sh +nohup python /lavis_app/app/backend/multimodal_search_backend.py > app/backend/search.log & +nohup python /lavis_app/app/backend/txt2image_backend.py > app/backend/imagen.log & +nohup python /lavis_app/app/backend/caption_backend.py > app/backend/caption.log & +streamlit run --server.port 8080 /lavis_app/app/main.py