diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 00000000..47d7045d
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,53 @@
+FROM nvcr.io/nvidia/pytorch:21.06-py3
+
+COPY requirements-app.txt requirements_gpu.txt
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ locales \
+ wget \
+ build-essential \
+ vim \
+ htop \
+ curl \
+ git less ssh cmake \
+ zip unzip gzip bzip2 \
+ python3-tk gcc g++ libpq-dev
+
+RUN apt -y install openssh-server openssh-client
+# BLIP-specific commands
+RUN apt-get install -y libxtst6
+RUN pip3 uninstall -y torch
+RUN pip3 uninstall -y torchtext
+RUN pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
+RUN pip3 install omegaconf
+RUN pip3 install ipython
+RUN pip3 install pycocoevalcap
+RUN pip3 install pycocotools
+RUN pip3 install timm==0.4.12
+RUN pip3 install fairscale==0.4.4
+RUN apt install -y default-jre
+RUN apt install -y openjdk-11-jre-headless
+RUN apt install -y openjdk-8-jre-headless
+RUN pip uninstall -y opencv-python
+RUN pip uninstall -y opencv-contrib-python
+RUN pip uninstall -y opencv-contrib-python-headless
+
+
+RUN pip3 install -r requirements_gpu.txt
+
+
+COPY . /lavis_app
+WORKDIR /lavis_app
+
+RUN wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
+RUN mv sd-v1-4.ckpt /lavis_app/stable-diffusion/sd-v1-4.ckpt
+
+ENV PYTHONPATH="${PYTHONPATH}:./:/lavis_app:/lavis_app/stable-diffusion"
+
+EXPOSE 8080
+RUN chmod +x /lavis_app/run_scripts/start_lavis_app.sh
+ENTRYPOINT ["/lavis_app/run_scripts/start_lavis_app.sh"]
+
+
diff --git a/app/__init__.py b/app/__init__.py
index 05522040..3a108837 100644
--- a/app/__init__.py
+++ b/app/__init__.py
@@ -1,10 +1,3 @@
-"""
- # Copyright (c) 2022, salesforce.com, inc.
- # All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
from PIL import Image
import requests
@@ -24,3 +17,6 @@ def load_demo_image():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cache_root = "/export/home/.cache/lavis/"
+pending_job_path = "app/task_queues/pending_jobs/"
+finished_job_path = "app/task_queues/finished_jobs/"
+job_output_path = "app/task_queues/outputs/"
diff --git a/app/backend/caption_backend.py b/app/backend/caption_backend.py
new file mode 100644
index 00000000..8b4f99ca
--- /dev/null
+++ b/app/backend/caption_backend.py
@@ -0,0 +1,107 @@
+from app import device, load_demo_image
+from app.utils import load_model_cache, get_pending_jobs, create_uniq_user_job_name
+from app import job_output_path, finished_job_path, pending_job_path
+from lavis.processors import load_processor
+from PIL import Image
+
+import random
+import numpy as np
+import torch
+import os, shutil, time
+
+job_type = 'caption'
+
+if torch.cuda.is_available():
+ torch.cuda.set_device(0)
+ device = "cuda"
+else:
+ device = "cpu"
+
+def setup_seed(seed):
+ random.seed(seed)
+ np.random.seed(int(seed))
+ torch.manual_seed(seed)
+
+ if torch.cuda.is_available():
+ import torch.backends.cudnn as cudnn
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+
+def back_end():
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
+ blip_large_model = load_model_cache(
+ "blip_caption",
+ model_type=f"large_coco",
+ is_eval=True,
+ device=device,
+ )
+ blip_base_model = load_model_cache(
+ "blip_caption",
+ model_type=f"base_coco",
+ is_eval=True,
+ device=device,
+ )
+ os.makedirs(os.path.join(finished_job_path, job_type), exist_ok=True)
+ while True:
+ pending_jobs = get_pending_jobs(job_type)
+ for job in pending_jobs:
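+            # re-read the job file until the front end has finished writing the full 5-field record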
+ while True:
+ with open(job) as f:
+ content = f.readline().rstrip(' \n')
+ if len(content.split('\t')) == 5: break
+ time_stamp, blip_type, sampling_method, num_captions, seed = content.split('\t')
+ outpath = os.path.join(job_output_path, job_type)
+ os.makedirs(outpath, exist_ok=True)
+ img_file = outpath+'/{}_raw_image.pt'.format(create_uniq_user_job_name(time_stamp, sampling_method))
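+            # wait for the front end to finish saving the preprocessed image tensor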
+ while True:
+ if os.path.exists(img_file):
+ break
+ time.sleep(1)
+            img = torch.load(img_file, map_location=torch.device(device))
+ if blip_type == 'large':
+ model = blip_large_model
+ else:
+ model = blip_base_model
+ use_nucleus_sampling = False
+ if sampling_method == 'Nucleus sampling':
+ use_nucleus_sampling = True
+ setup_seed(int(seed))
+ captions = generate_caption(model, img, use_nucleus_sampling, int(num_captions))
+ caption_result = outpath+'/{}_result.txt'.format(create_uniq_user_job_name(time_stamp, sampling_method))
+ with open(caption_result,'w') as f:
+ for caption in captions:
+ f.write(caption+'\n')
+ shutil.move(job, os.path.join(finished_job_path, job_type))
+ os.remove(img_file)
+
+
+def generate_caption(
+ model, image, use_nucleus_sampling=False, num_captions = 1, num_beams=3, max_length=40, min_length=5
+):
+ samples = {"image": image}
+
+ captions = []
+ if use_nucleus_sampling:
+ #for _ in range(5):
+ captions = model.generate(
+ samples,
+ use_nucleus_sampling=True,
+ max_length=max_length,
+ min_length=min_length,
+ top_p=0.9,
+ num_captions=num_captions
+ )
+ #captions.append(caption[0])
+ else:
+ caption = model.generate(
+ samples,
+ use_nucleus_sampling=False,
+ num_beams=num_beams,
+ max_length=max_length,
+ min_length=min_length,
+ num_captions=1
+ )
+ captions.append(caption[0])
+ return captions
+if __name__ == "__main__":
+ back_end()
diff --git a/app/backend/multimodal_search_backend.py b/app/backend/multimodal_search_backend.py
new file mode 100644
index 00000000..17fab829
--- /dev/null
+++ b/app/backend/multimodal_search_backend.py
@@ -0,0 +1,198 @@
+import os, shutil
+
+import numpy as np
+import streamlit as st
+import torch
+import torch.nn.functional as F
+from app import cache_root, device, job_output_path, finished_job_path
+from app.utils import (
+ getAttMap,
+ init_bert_tokenizer,
+ load_blip_itm_model,
+ read_img,
+ resize_img,
+ get_pending_jobs,
+ create_uniq_user_job_name
+)
+from lavis.models import BlipFeatureExtractor, load_model
+from lavis.processors import load_processor
+
+if torch.cuda.is_available():
+ torch.cuda.set_device(0)
+ device = "cuda"
+else:
+ device = "cpu"
+
+job_type = 'search'
+
+def load_feat():
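+    # download (if missing) and load precomputed BLIP image features for COCO train2014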
+ from lavis.common.utils import download_url
+
+ dirname = os.path.join(os.path.dirname(__file__), "assets")
+ filename = "path2feat_coco_train2014.pth"
+ filepath = os.path.join(dirname, filename)
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth"
+
+ if not os.path.exists(filepath):
+ download_url(url=url, root=dirname, filename="path2feat_coco_train2014.pth")
+
+ path2feat = torch.load(filepath)
+ paths = sorted(path2feat.keys())
+
+ all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device)
+
+ return path2feat, paths, all_img_feats
+
+def load_feature_extractor_model(device):
+ model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
+
+ model = load_model("blip_feature_extractor", model_type="base", is_eval=True, device=device)
+ model.load_from_pretrained(model_url)
+
+ return model
+
+def search(time_stamp, user_question, feature_extractor, vis_processor, raw_user_question, num_display, itm_model):
+ sample = {"text_input": user_question}
+ with torch.no_grad():
+ text_feature = feature_extractor.extract_features(
+ sample, mode="text").text_embeds_proj[0, 0]
+
+ path2feat, paths, all_img_feats = load_feat()
+ all_img_feats.to(device)
+ all_img_feats = F.normalize(all_img_feats, dim=1)
+
+ num_cols = 4
+ num_rows = int(num_display) // num_cols
+
+ similarities = text_feature @ all_img_feats.T
+ indices = torch.argsort(similarities, descending=True)[:num_display]
+
+ top_paths = [paths[ind.detach().cpu().item()] for ind in indices]
+ sorted_similarities = [similarities[idx] for idx in indices]
+ file_root = os.path.join(cache_root, "coco/images/train2014/")
+ filenames = [os.path.join(file_root, p) for p in top_paths]
+ outpath = os.path.join(job_output_path, job_type)
+ os.makedirs(outpath, exist_ok=True)
+
+ bsz = 8 # max number of images to avoid cuda oom
+
+ #itm_model = load_blip_itm_model("cuda", model_type=blip_type)
+
+ tokenizer = init_bert_tokenizer()
+ queries_batch = [user_question] * bsz
+ queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to("cpu")
+
+ num_batches = int(num_display / bsz)
+
+ avg_gradcams = []
+ all_raw_images = []
+ itm_scores = []
+
+ for i in range(num_batches):
+ filenames_in_batch = filenames[i * bsz : (i + 1) * bsz]
+ raw_images, images = read_and_process_images(filenames_in_batch, vis_processor)
+ gradcam, itm_output = compute_gradcam_batch(
+ itm_model, images, queries_batch, queries_tok_batch
+ )
+
+ all_raw_images.extend([resize_img(r_img) for r_img in raw_images])
+ norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
+
+ for norm_img, grad_cam in zip(norm_imgs, gradcam):
+ avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True)
+ avg_gradcams.append(avg_gradcam)
+
+ with torch.no_grad():
+ itm_score = torch.nn.functional.softmax(itm_output, dim=1)
+
+ itm_scores.append(itm_score)
+
+ #avg_gradcams = torch.cat(avg_gradcams)
+ #all_raw_images = torch.cat(all_raw_images)
+
+ itm_scores = torch.cat(itm_scores)[:, 1]
+ torch.save(itm_scores, outpath+'/{}_itm.pt'.format(create_uniq_user_job_name(time_stamp,raw_user_question)))
+ np.save(outpath+'/{}_avg_gradcams.npy'.format(create_uniq_user_job_name(time_stamp,raw_user_question)), avg_gradcams, allow_pickle=True)
+ np.save(outpath+'/{}_all_raw_images.npy'.format(create_uniq_user_job_name(time_stamp,raw_user_question)),all_raw_images,allow_pickle=True)
+
+ search_result = outpath+'/{}_result.txt'.format(create_uniq_user_job_name(time_stamp,raw_user_question))
+ with open(search_result,'w') as f:
+ for filename in filenames:
+ f.write(filename+'\n')
+
+def back_end():
+ # === event ===
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
+ text_processor = load_processor("blip_caption")
+ feature_extractor = load_feature_extractor_model(device)
+ os.makedirs("{}/{}/".format(finished_job_path, job_type), exist_ok=True)
+ large_itm_model = load_blip_itm_model(device, model_type='large')
+ base_itm_model = load_blip_itm_model(device, model_type='base')
+
+ while True:
+ pending_jobs = get_pending_jobs(job_type)
+ for job in pending_jobs:
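+            # re-read the job file until the front end has finished writing the full 4-field record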
+ while True:
+ with open(job) as f:
+ content = f.readline().rstrip(' \n')
+ if len(content.split('\t')) == 4: break
+ time_stamp, raw_user_question, num_display, blip_type = content.split('\t')
+ user_question = text_processor(raw_user_question)
+ if blip_type == 'large':
+ search(time_stamp, user_question, feature_extractor, vis_processor, raw_user_question, int(num_display), large_itm_model)
+ else:
+ search(time_stamp, user_question, feature_extractor, vis_processor, raw_user_question, int(num_display), base_itm_model)
+ shutil.move(job, "{}/{}/".format(finished_job_path,job_type))
+
+
+def read_and_process_images(image_paths, vis_processor):
+ raw_images = [read_img(path) for path in image_paths]
+ images = [vis_processor(r_img) for r_img in raw_images]
+ images_tensors = torch.stack(images).to(device)
+
+ return raw_images, images_tensors
+
+
+def compute_gradcam_batch(model, visual_input, text_input, tokenized_text, block_num=6):
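+    # keep the attention map of the selected cross-attention block so Grad-CAM can be computed from it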
+ model.text_encoder.base_model.base_model.encoder.layer[
+ block_num
+ ].crossattention.self.save_attention = True
+
+ output = model(visual_input, text_input, match_head="itm")
+ loss = output[:, 1].sum()
+
+ model.zero_grad()
+ loss.backward()
+ with torch.no_grad():
+ mask = tokenized_text.attention_mask.view(
+ tokenized_text.attention_mask.size(0), 1, -1, 1, 1
+ ).to(device=device) # (bsz,1,token_len, 1,1)
+ token_length = mask.sum() - 2
+ token_length = token_length.cpu()
+ # grads and cams [bsz, num_head, seq_len, image_patch]
+ grads = model.text_encoder.base_model.base_model.encoder.layer[
+ block_num
+ ].crossattention.self.get_attn_gradients()
+ cams = model.text_encoder.base_model.base_model.encoder.layer[
+ block_num
+ ].crossattention.self.get_attention_map()
+
+ # assume using vit large with 576 num image patch
+ cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
+ grads = (
+ grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24)
+ * mask
+ )
+
+ gradcam = cams * grads
+ # [enc token gradcam, average gradcam across token, gradcam for individual token]
+ # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
+ gradcam = gradcam.mean(1).cpu().detach()
+ gradcam = (
+ gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True) / token_length
+ )
+
+ return gradcam, output
+
+if __name__ == '__main__':
+ back_end()
diff --git a/app/backend/txt2image_backend.py b/app/backend/txt2image_backend.py
new file mode 100644
index 00000000..655159d7
--- /dev/null
+++ b/app/backend/txt2image_backend.py
@@ -0,0 +1,199 @@
+import os, shutil, subprocess
+import torch
+import numpy as np
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from itertools import islice
+from einops import rearrange
+from torchvision.utils import make_grid
+import time
+from pytorch_lightning import seed_everything
+from torch import autocast
+
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.plms import PLMSSampler
+
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from transformers import AutoFeatureExtractor
+
+from app.utils import (
+ get_pending_jobs,
+ create_uniq_user_job_name
+)
+
+from app import job_output_path, finished_job_path
+
+job_type = 'txt2image'
+if torch.cuda.is_available():
+ torch.cuda.set_device(1)
+ device = "cuda"
+else:
+ device = "cpu"
+
+
+# load safety model
+safety_model_id = "CompVis/stable-diffusion-safety-checker"
+safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+
+def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+def load_model_from_config(config, ckpt, verbose=False):
+ #print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ if not os.path.exists('prompts'):
+ os.makedirs('prompts', exist_ok=True)
+ else:
+ shutil.rmtree('prompts')
+ os.makedirs('prompts', exist_ok=True)
+
+ model.half()
+ if torch.cuda.is_available():
+ model.cuda()
+ model.eval()
+ return model
+
+
+def load_replacement(x):
+ try:
+ hwc = x.shape
+ y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
+ y = (np.array(y)/255.0).astype(x.dtype)
+ assert y.shape == x.shape
+ return y
+ except Exception:
+ return x
+
+
+def check_safety(x_image):
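+    # run the diffusers NSFW safety checker; flagged images are replaced with the placeholder in assets/rick.jpeg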
+ safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
+ x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+ assert x_checked_image.shape[0] == len(has_nsfw_concept)
+ for i in range(len(has_nsfw_concept)):
+ if has_nsfw_concept[i]:
+ x_checked_image[i] = load_replacement(x_checked_image[i])
+ return x_checked_image, has_nsfw_concept
+
+def back_end():
+ config = 'stable-diffusion/configs/stable-diffusion/v1-inference.yaml'
+ ckpt = 'stable-diffusion/sd-v1-4.ckpt'
+ config = OmegaConf.load(f"{config}")
+ model = load_model_from_config(config, f"{ckpt}")
+ print(device)
+ model = model.to(device)
+ outpath = os.path.join(job_output_path, job_type)
+ sample_path = os.path.join(outpath, "samples")
+ if not os.path.exists(sample_path):
+ subprocess.run(['mkdir', '-p', sample_path], shell=False)
+ finished_path = os.path.join(finished_job_path,job_type)
+ if not os.path.exists(finished_path):
+ subprocess.run(['mkdir', '-p',finished_path], shell=False)
+ while True:
+ pending_jobs = get_pending_jobs(job_type)
+ for job in pending_jobs:
+ while True:
+ with open(job) as f:
+ content = f.readline().rstrip(' \n')
+ if len(content.split('\t')) == 4: break
+ random_seed, time_stamp, user_prompt, num_images = content.split('\t')
+ generate_image(model, int(random_seed), time_stamp, user_prompt, int(num_images), sample_path)
+ shutil.move(job, finished_path)
+
+
+def generate_image(model, random_seed, time_stamp, user_prompt, num_images, sample_path):
+ scale = 7.5
+ num_latent_channels = 4
+ down_sample_factor = 8
+ H, W = 512, 512
+ ddim_steps = 50
+ skip_grid = False
+ skip_save = False
+ n_rows = num_images
+ precision_scope = autocast
+ data = [[user_prompt] * num_images]
+ seed_everything(random_seed)
+ model = model.to(device)
+ sampler = PLMSSampler(model)
+ with torch.no_grad():
+ with precision_scope(device):
+ #with precision_scope("cuda"):
+ with model.ema_scope():
+ tic = time.time()
+ all_samples = list()
+ for n in trange(1, desc="Sampling"):
+ for prompts in tqdm(data, desc="data"):
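+                        # classifier-free guidance: also condition on an empty prompt when scale != 1.0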
+ uc = None
+ if scale != 1.0:
+ uc = model.get_learned_conditioning(num_images * [""])
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+ c = model.get_learned_conditioning(prompts)
+ shape = [num_latent_channels, H // down_sample_factor, W // down_sample_factor]
+ samples_ddim, _ = sampler.sample(S=ddim_steps,
+ conditioning=c,
+ batch_size= num_images,
+ shape=shape,
+ verbose=False,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=uc,
+ eta=0.0,
+ x_T=None)
+
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+
+ x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
+
+ x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+
+ if not skip_save:
+ count = 1
+ for x_sample in x_checked_image_torch:
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ img = Image.fromarray(x_sample.astype(np.uint8))
+ #img = put_watermark(img, wm_encoder)
+ img.save(os.path.join(sample_path, "{}_{}.png".format(create_uniq_user_job_name(str(time_stamp), user_prompt), count)))
+ count += 1
+ if not skip_grid:
+ all_samples.append(x_checked_image_torch)
+
+ if not skip_grid:
+ # additionally, save as grid
+ grid = torch.stack(all_samples, 0)
+ grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+ grid = make_grid(grid, nrow=n_rows)
+
+ # to image
+ grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+ img = Image.fromarray(grid.astype(np.uint8))
+ #img = put_watermark(img, wm_encoder)
+ img.save(os.path.join(sample_path, "{}_grid.png".format(create_uniq_user_job_name(str(time_stamp), user_prompt))))
+if __name__ == '__main__':
+ back_end()
diff --git a/app/calculate_coco_features.py b/app/calculate_coco_features.py
index 168e8503..19d3614f 100644
--- a/app/calculate_coco_features.py
+++ b/app/calculate_coco_features.py
@@ -1,10 +1,3 @@
-"""
- # Copyright (c) 2022, salesforce.com, inc.
- # All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
from PIL import Image
import requests
import torch
diff --git a/app/caption.py b/app/caption.py
index ad118988..b8c1ecd5 100644
--- a/app/caption.py
+++ b/app/caption.py
@@ -1,24 +1,38 @@
-"""
- # Copyright (c) 2022, salesforce.com, inc.
- # All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
import streamlit as st
from app import device, load_demo_image
from app.utils import load_model_cache
from lavis.processors import load_processor
from PIL import Image
+import random
+import numpy as np
+import torch
+
+
+def setup_seeds(seed):
+ random.seed(seed)
+ np.random.seed(int(seed))
+ torch.manual_seed(seed)
+
+ if torch.cuda.is_available():
+ import torch.backends.cudnn as cudnn
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+
def app():
# ===== layout =====
- model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
+ model_type = st.sidebar.selectbox("Model:", ["BLIP_large", "BLIP_base"])
sampling_method = st.sidebar.selectbox(
"Sampling method:", ["Beam search", "Nucleus sampling"]
)
+ num_captions = 1
+ if sampling_method == "Nucleus sampling":
+ random_seed = st.sidebar.text_input("Seed:", 1024, help="Set random seed to reproduce the image description")
+        setup_seeds(int(random_seed))
+ num_captions = st.sidebar.slider("Choose number of captions to generate",
+ min_value=1, max_value=5, step=1)
     st.markdown(
         "<h1 style='text-align: center;'>Image Description Generation</h1>",
@@ -44,8 +58,8 @@ def app():
resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
col1.image(resized_image, use_column_width=True)
- col2.header("Description")
-
+ #col2.header("Description")
+ #with col2:
cap_button = st.button("Generate")
# ==== event ====
@@ -63,28 +77,39 @@ def app():
img = vis_processor(raw_img).unsqueeze(0).to(device)
captions = generate_caption(
- model=model, image=img, use_nucleus_sampling=not use_beam
+ model=model, image=img, use_nucleus_sampling=not use_beam, num_captions=num_captions
)
-
- col2.write("\n\n".join(captions), use_column_width=True)
+
+ #with col2:
+ # for caption in captions:
+    #        caption_md = '<p>{}</p>'.format(caption)
+ # st.markdown(caption_md, unsafe_allow_html=True)
+ with col2:
+ st.header("Description")
+ #with col2:
+ for caption in captions:
+            caption_md = '<p>{}</p>'.format(caption)
+ st.markdown(caption_md, unsafe_allow_html=True)
+ #col2.write("\n\n".join(captions), use_column_width=True)
def generate_caption(
- model, image, use_nucleus_sampling=False, num_beams=3, max_length=40, min_length=5
+ model, image, use_nucleus_sampling=False, num_captions = 1, num_beams=3, max_length=40, min_length=5
):
samples = {"image": image}
captions = []
if use_nucleus_sampling:
- for _ in range(5):
- caption = model.generate(
+ #for _ in range(5):
+ captions = model.generate(
samples,
use_nucleus_sampling=True,
max_length=max_length,
min_length=min_length,
top_p=0.9,
- )
- captions.append(caption[0])
+ num_captions=num_captions
+ )
+ #captions.append(caption[0])
else:
caption = model.generate(
samples,
@@ -92,6 +117,7 @@ def generate_caption(
num_beams=num_beams,
max_length=max_length,
min_length=min_length,
+ num_captions=1
)
captions.append(caption[0])
diff --git a/app/caption_front_end.py b/app/caption_front_end.py
new file mode 100644
index 00000000..9d906c94
--- /dev/null
+++ b/app/caption_front_end.py
@@ -0,0 +1,105 @@
+import streamlit as st
+from app import load_demo_image, job_output_path, pending_job_path
+from app.utils import create_uniq_user_job_name
+from lavis.processors import load_processor
+from PIL import Image
+
+import os, time, subprocess
+import random
+import numpy as np
+import torch
+
+job_type = 'caption'
+
+def setup_seeds(seed):
+ random.seed(seed)
+ np.random.seed(int(seed))
+ torch.manual_seed(seed)
+
+ if torch.cuda.is_available():
+ import torch.backends.cudnn as cudnn
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+
+
+def app():
+ # ===== layout =====
+ model_type = st.sidebar.selectbox("Model:", ["BLIP_large", "BLIP_base"])
+
+ sampling_method = st.sidebar.selectbox(
+ "Sampling method:", ["Beam search", "Nucleus sampling"]
+ )
+    num_captions = 1
+    random_seed = 1024
+    if sampling_method == "Nucleus sampling":
+        random_seed = st.sidebar.text_input("Seed:", 1024, help="Set random seed to reproduce the image description")
+        setup_seeds(int(random_seed))
+        num_captions = st.sidebar.slider("Choose number of captions to generate",
+                                         min_value=1, max_value=5, step=1)
+
+ st.markdown(
+ "Image Description Generation
",
+ unsafe_allow_html=True,
+ )
+
+ instructions = """Try the provided image or upload your own:"""
+ file = st.file_uploader(instructions)
+
+ use_beam = sampling_method == "Beam search"
+
+ col1, col2 = st.columns(2)
+
+ if file:
+ raw_img = Image.open(file).convert("RGB")
+ else:
+ raw_img = load_demo_image()
+
+ col1.header("Image")
+
+ w, h = raw_img.size
+ scaling_factor = 720 / w
+ resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
+
+ col1.image(resized_image, use_column_width=True)
+
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
+ img = vis_processor(raw_img).unsqueeze(0)
+
+ col2.header("Description")
+ #with col2:
+ cap_button = st.button("Generate")
+ blip_type = model_type.split("_")[1].lower()
+
+ if cap_button:
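+        # enqueue the request as a tab-separated job file for the caption backend and save the preprocessed image it needs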
+ time_stamp = time.time()
+ pending_jobs = os.path.join(pending_job_path, job_type)
+ if not os.path.exists(pending_jobs):
+ #os.makedirs(pending_jobs)
+ subprocess.run(['mkdir', '-p', pending_jobs], shell=False)
+ file_name = '{}_result.txt'.format(create_uniq_user_job_name(time_stamp, sampling_method))
+ with open(os.path.join(pending_jobs, file_name),'w') as new_job:
+            line = str(time_stamp)+'\t'+blip_type+'\t'+str(sampling_method)+'\t'+str(num_captions)+'\t'+str(random_seed)
+ new_job.write(line+'\n')
+ new_job.close()
+
+ num_pending_jobs = len(os.listdir(pending_jobs))
+ outpath = os.path.join(job_output_path,job_type)
+ if not os.path.exists(outpath):
+ subprocess.run(['mkdir', '-p', outpath], shell=False)
+ search_result = outpath+'/{}_result.txt'.format(create_uniq_user_job_name(time_stamp, sampling_method))
+ torch.save(img, outpath+'/{}_raw_image.pt'.format(create_uniq_user_job_name(time_stamp, sampling_method)))
+
+ with st.spinner("Queuing (#{} in line)".format(num_pending_jobs)):
+ while True:
+ if os.path.exists(search_result):
+ time.sleep(1)
+ with open(search_result) as f:
+ count = 0
+ with col2:
+ #st.header("Description")
+ for caption in f:
+ caption = caption.rstrip(' \n')
+ if count < num_captions:
+                                caption_md = '<p>{}</p>'.format(caption)
+ st.markdown(caption_md, unsafe_allow_html=True)
+ count += 1
+ break
diff --git a/app/classification.py b/app/classification.py
index 2b5bd896..df9e93a6 100644
--- a/app/classification.py
+++ b/app/classification.py
@@ -1,21 +1,15 @@
-"""
- # Copyright (c) 2022, salesforce.com, inc.
- # All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
import plotly.graph_objects as go
import requests
import streamlit as st
import torch
-from lavis.models import load_model
+from lavis.models import BlipFeatureExtractor, load_model
+from lavis.models.blip_models.blip_image_text_matching import BlipITM
from lavis.processors import load_processor
from lavis.processors.blip_processors import BlipCaptionProcessor
from PIL import Image
+import numpy as np
from app import device, load_demo_image
-from app.utils import load_blip_itm_model
from lavis.processors.clip_processors import ClipImageEvalProcessor
@@ -36,6 +30,7 @@ def load_demo_image(img_url=None):
allow_output_mutation=True,
)
def load_model_cache(model_type, device):
+
if model_type == "blip":
model = load_model(
"blip_feature_extractor", model_type="base", is_eval=True, device=device
@@ -60,56 +55,77 @@ def load_model_cache(model_type, device):
return model
+@st.cache(
+ hash_funcs={
+ torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
+ .cpu()
+ .numpy()
+ },
+ allow_output_mutation=True,
+)
+def load_blip_itm_model(device):
+ pretrained_path = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"
+ model = BlipITM(pretrained=pretrained_path, vit="base")
+ model.eval()
+ model = model.to(device)
+ return model
+
+
def app():
model_type = st.sidebar.selectbox(
"Model:",
- ["ALBEF", "BLIP_Base", "CLIP_ViT-B-32", "CLIP_ViT-B-16", "CLIP_ViT-L-14"],
+ ["BLIP_Base", "CLIP_ViT-B-32", "CLIP_ViT-B-16", "CLIP_ViT-L-14"],
)
score_type = st.sidebar.selectbox("Score type:", ["Cosine", "Multimodal"])
# ===== layout =====
st.markdown(
- "Zero-shot Classification
",
+ " Zero-shot Classification
",
unsafe_allow_html=True,
)
instructions = """Try the provided image or upload your own:"""
file = st.file_uploader(instructions)
- st.header("Image")
+ col1,col2 = st.columns(2)
+ col1.header("Image")
+ #col2.header("Categories")
+ row2_col1,row2_col2 = st.columns(2)
if file:
raw_img = Image.open(file).convert("RGB")
+ st.session_state.new_image = 'yes'
+ st.session_state.category = 'yes'
else:
raw_img = load_demo_image()
- st.image(raw_img) # , use_column_width=True)
-
- col1, col2 = st.columns(2)
-
- col1.header("Categories")
-
- cls_0 = col1.text_input("category 1", value="merlion")
- cls_1 = col1.text_input("category 2", value="sky")
- cls_2 = col1.text_input("category 3", value="giraffe")
- cls_3 = col1.text_input("category 4", value="fountain")
- cls_4 = col1.text_input("category 5", value="marina bay")
-
- cls_names = [cls_0, cls_1, cls_2, cls_3, cls_4]
- cls_names = [cls_nm for cls_nm in cls_names if len(cls_nm) > 0]
+ w, h = raw_img.size
+ scaling_factor = 700 / w
+ resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
- if len(cls_names) != len(set(cls_names)):
- st.error("Please provide unique class names")
- return
+ row2_col1.image(resized_image, use_column_width=True)
+ cls_names = []
button = st.button("Submit")
-
- col2.header("Prediction")
-
+ if 'cls_names' not in st.session_state or st.session_state.cls_names == '' or not button:
+ col2.header("Categories")
+ with row2_col2:
+ cls_0 = st.text_input("category 1", value="merlion")
+ cls_1 = st.text_input("category 2", value="elephant")
+ cls_2 = st.text_input("category 3", value="giraffe")
+ cls_3 = st.text_input("category 4", value="fountain")
+ cls_4 = st.text_input("category 5", value="marina bay")
+ cls_names = [cls_0, cls_1, cls_2, cls_3, cls_4 ]
+ cls_names = [cls_nm for cls_nm in cls_names if len(cls_nm) > 0]
+ st.session_state.cls_names = ','.join(cls_names)
+
+ if len(cls_names) != len(set(cls_names)):
+ st.error("Please provide unique class names")
+ return
# ===== event =====
-
if button:
if model_type.startswith("BLIP"):
text_processor = BlipCaptionProcessor(prompt="A picture of ")
+ cls_names = st.session_state['cls_names'].split(',')
cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]
if score_type == "Cosine":
@@ -143,32 +159,6 @@ def app():
sims = torch.nn.Softmax(dim=0)(sims)
inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]
- elif model_type.startswith("ALBEF"):
- vis_processor = load_processor("blip_image_eval").build(image_size=224)
- img = vis_processor(raw_img).unsqueeze(0).to(device)
-
- text_processor = BlipCaptionProcessor(prompt="A picture of ")
- cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]
-
- feature_extractor = load_model_cache(model_type="albef", device=device)
-
- sample = {"image": img, "text_input": cls_prompt}
-
- with torch.no_grad():
- image_features = feature_extractor.extract_features(
- sample, mode="image"
- ).image_embeds_proj[:, 0]
- text_features = feature_extractor.extract_features(
- sample, mode="text"
- ).text_embeds_proj[:, 0]
-
- st.write(image_features.shape)
- st.write(text_features.shape)
-
- sims = (image_features @ text_features.t())[0] / feature_extractor.temp
-
- sims = torch.nn.Softmax(dim=0)(sims)
- inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]
elif model_type.startswith("CLIP"):
if model_type == "CLIP_ViT-B-32":
@@ -193,6 +183,9 @@ def app():
image_features = clip_features.image_embeds_proj
text_features = clip_features.text_embeds_proj
+ image_features /= image_features.norm(dim=-1, keepdim=True)
+ text_features /= text_features.norm(dim=-1, keepdim=True)
+
sims = (100.0 * image_features @ text_features.T)[0].softmax(dim=-1)
inv_sims = sims.tolist()[::-1]
else:
@@ -202,15 +195,38 @@ def app():
fig = go.Figure(
go.Bar(
x=inv_sims,
- y=cls_names[::-1],
- text=["{:.2f}".format(s) for s in inv_sims],
+ y=[c+' ' for c in cls_names[::-1]],
+ text=["{:.2f}%".format(s) for s in inv_sims],
orientation="h",
)
)
fig.update_traces(
- textfont_size=12,
+ textfont_size=16,
textangle=0,
textposition="outside",
cliponaxis=False,
+ marker_color="#0176D3"
)
- col2.plotly_chart(fig, use_container_width=True)
+ fig.add_vline(x=0, line_width=1, line_color="#C9C9C9")
+ fig.add_hline(y=-0.6, line_width=1, line_color="#C9C9C9")
+ fig.update_layout(font=dict(family="Salesforce Sans", size=25, color="#032D60"))
+ fig.update_layout(
+ xaxis = dict(
+ tickmode='linear',
+ tickfont = dict(size=16),
+ tick0=0,
+ dtick=20,
+ ticksuffix="%"
+ ),
+ yaxis = dict(tickfont = dict(size=16)),
+ plot_bgcolor= "rgba(0, 0, 0, 0)",
+ title="Zero-shot image classification",
+ title_font_family="Salesforce Sans",
+ title_font_size=28,
+ title_font_color="#032D60",
+ paper_bgcolor= "rgba(0, 0, 0, 0)",)
+ fig.update_xaxes(fixedrange=True, dtick=20)
+ fig.update_xaxes(range=[0, 100], ticks="outside", tickson="boundaries", ticklen=6)
+ col2.header("Prediction")
+ with row2_col2:
+ st.plotly_chart(fig, use_container_width=True)
diff --git a/app/dataset_browser.py b/app/dataset_explorer.py
similarity index 79%
rename from app/dataset_browser.py
rename to app/dataset_explorer.py
index 6b761d89..db67ce2a 100644
--- a/app/dataset_browser.py
+++ b/app/dataset_explorer.py
@@ -1,21 +1,17 @@
-"""
- # Copyright (c) 2022, salesforce.com, inc.
- # All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
-import random
from collections import OrderedDict
from functools import reduce
+import os
from tkinter import N
-
import streamlit as st
-from lavis.common.registry import registry
-from lavis.datasets.builders import dataset_zoo, load_dataset
-from lavis.datasets.builders.base_dataset_builder import load_dataset_config
+import streamlit.components.v1 as components
+
+import random
from PIL import Image
+from lavis.datasets.builders import load_dataset, dataset_zoo
+from lavis.datasets.builders.base_dataset_builder import load_dataset_config
+from lavis.common.registry import registry
+
IMAGE_LAYOUT = 3, 4
VIDEO_LAYOUT = 1, 2
@@ -29,6 +25,18 @@ def sample_dataset(dataset, indices):
return samples
+# def create_gif_from_video(video_path):
+# import imageio
+# import os
+
+# video = imageio.get_reader(video_path)
+# fps = video.get_meta_data()["fps"]
+# images = []
+# for i in range(video.get_length()):
+# images.append(video.get_data(i))
+# imageio.mimsave(os.path.splitext(video_path)[0] + ".gif", images, fps=fps)
+
+
def get_concat_v(im1, im2):
margin = 5
@@ -137,6 +145,16 @@ def show_samples(dataset, offset=0, is_next=False):
col.markdown(
"![Alt Text](https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif)"
)
+ # col.video(open(next(visual_info), "rb"))
+ # video_path = "/export/share/dongxuli/data/msrvtt_retrieval/videos/video0.mp4"
+ # col.video(next(visual_info))
+ # col.markdown(
+ # f"""""",
+ # unsafe_allow_html=True,
+ # )
except StopIteration:
break
@@ -148,6 +166,22 @@ def show_samples(dataset, offset=0, is_next=False):
st.session_state.n_display = n_samples
+def show_dataset_card():
+ builder = registry.get_builder_class(dataset_name)
+ cfg_path = builder.default_config_path()
+ config = load_dataset_config(cfg_path)
+ data_card = config.get("dataset_card", None)
+
+ if data_card is None:
+ st.warning(f"No dataset card found for {dataset_name}.")
+ else:
+ img_path = data_card.replace("md", "png")
+ img = resize_img_w(Image.open(img_path), new_w=672)
+ st.image(img)
+
+ st.markdown(open(data_card).read())
+
+
if __name__ == "__main__":
st.set_page_config(
page_title="LAVIS Dataset Explorer",
@@ -157,9 +191,9 @@ def show_samples(dataset, offset=0, is_next=False):
dataset_name = st.sidebar.selectbox("Dataset:", dataset_zoo.get_names())
- function = st.sidebar.selectbox("Function:", ["Browser"], index=0)
+ function = st.sidebar.selectbox("Function:", ["Dataset Card", "Explorer"], index=0)
- if function == "Browser":
+ if function == "Explorer":
shuffle = st.sidebar.selectbox("Shuffled:", [True, False], index=0)
dataset = load_dataset_cache(dataset_name)
@@ -238,3 +272,5 @@ def show_samples(dataset, offset=0, is_next=False):
offset=st.session_state.last_start - st.session_state.start_idx,
is_next=True,
)
+ elif function == "Dataset Card":
+ show_dataset_card()
diff --git a/app/image_text_match.py b/app/image_text_match.py
index e7957384..1b7a35fb 100644
--- a/app/image_text_match.py
+++ b/app/image_text_match.py
@@ -1,23 +1,18 @@
-"""
- # Copyright (c) 2022, salesforce.com, inc.
- # All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
import numpy as np
import streamlit as st
import torch
-from lavis.models.blip_models.blip_image_text_matching import compute_gradcam
+from app import device, load_demo_image
+from app.utils import (
+ init_bert_tokenizer,
+ getAttMap,
+ load_blip_itm_model,
+)
from lavis.processors import load_processor
from PIL import Image
-
-from app import device, load_demo_image
-from app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model
-
+from lavis.models.blip_models.blip_image_text_matching import compute_gradcam
def app():
- model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
+ model_type = st.sidebar.selectbox("Model:", ["BLIP_large", "BLIP_base"])
if model_type.startswith("BLIP"):
blip_type = model_type.split("_")[1]
@@ -74,7 +69,7 @@ def app():
qry_tok = tokenizer(qry, return_tensors="pt").to(device)
gradcam, output = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num)
- avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True)
+ avg_gradcam = getAttMap(norm_img, gradcam[1], blur=True)
col2.image(avg_gradcam, use_column_width=True, clamp=True)
# output = model(img, question)
diff --git a/app/main.py b/app/main.py
index 108c46f8..6ee06a7d 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,25 +1,36 @@
-"""
- # Copyright (c) 2022, salesforce.com, inc.
- # All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
from app.multipage import MultiPage
from app import vqa, caption
+#from app import caption_front_end as caption
from app import image_text_match as itm
from app import text_localization as tl
-from app import multimodal_search as ms
+#from app import multimodal_search as ms
+from app import multimodal_search_front_end as ms
from app import classification as cl
+from PIL import Image
+import streamlit as st
+from app import txt2image_front_end as ig
if __name__ == "__main__":
app = MultiPage()
+ logo = Image.open("app/logo_color.png")
+ st.sidebar.image(logo.resize((592, 157)))
+
+ # add Salesforce Logo on top right
+ st.markdown(
+ "
",
+ unsafe_allow_html=True,
+ )
+
+ # load custom css
+ with open("app/style_override.css") as f:
+        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
+
app.add_page("Image Description Generation", caption.app)
app.add_page("Multimodal Search", ms.app)
app.add_page("Visual Question Answering", vqa.app)
- app.add_page("Image Text Matching", itm.app)
- app.add_page("Text Localization", tl.app)
- app.add_page("Classification", cl.app)
+ app.add_page("Zero-shot Image Classification", cl.app)
+ app.add_page("Text-to-Image Generation", ig.app)
+ app.add_page("Text Visualization", tl.app)
app.run()
diff --git a/app/multimodal_search.py b/app/multimodal_search.py
index ffc97664..43f6ba6b 100644
--- a/app/multimodal_search.py
+++ b/app/multimodal_search.py
@@ -1,10 +1,3 @@
-"""
- # Copyright (c) 2022, salesforce.com, inc.
- # All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
import os
import numpy as np
@@ -19,7 +12,7 @@
read_img,
resize_img,
)
-from lavis.models import load_model
+from lavis.models import BlipFeatureExtractor, load_model
from lavis.processors import load_processor
@@ -32,17 +25,12 @@
allow_output_mutation=True,
)
def load_feat():
- from lavis.common.utils import download_url
-
- dirname = os.path.join(os.path.dirname(__file__), "assets")
- filename = "path2feat_coco_train2014.pth"
- filepath = os.path.join(dirname, filename)
- url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth"
-
- if not os.path.exists(filepath):
- download_url(url=url, root=dirname, filename="path2feat_coco_train2014.pth")
-
- path2feat = torch.load(filepath)
+ path2feat = torch.load(
+ os.path.join(
+ os.path.dirname(__file__),
+ "/export/home/.cache/lavis/path2feat_coco_train2014.pth",
+ )
+ )
paths = sorted(path2feat.keys())
all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device)
@@ -61,9 +49,7 @@ def load_feat():
def load_feature_extractor_model(device):
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
- model = load_model(
- "blip_feature_extractor", model_type="base", is_eval=True, device=device
- )
+ model = load_model("blip_feature_extractor", is_eval=True, device=device)
model.load_from_pretrained(model_url)
return model
@@ -91,93 +77,101 @@ def app():
vis_processor = load_processor("blip_image_eval").build(image_size=384)
text_processor = load_processor("blip_caption")
- user_question = st.text_input(
- "Search query", "A dog running on the grass.", help="Type something to search."
- )
- user_question = text_processor(user_question)
- feature_extractor = load_feature_extractor_model(device)
+ row1_1, row1_spacer1, row1_2, row1_spacer2 = st.columns((15.5, .1, 3.5, 0.1))
+ with row1_1:
+ user_question = st.text_input(
+ "Search query", "A dog running on the grass.", help="Type something to search."
+ )
+ with row1_2:
+ st.markdown("")
+ st.markdown("")
+ search_button = st.button("Search")
- # ======= ITC =========
- sample = {"text_input": user_question}
+ if search_button:
+ user_question = text_processor(user_question)
+ feature_extractor = load_feature_extractor_model(device)
- with torch.no_grad():
- text_feature = feature_extractor.extract_features(
- sample, mode="text"
- ).text_embeds_proj[0, 0]
+ # ======= ITC =========
+ sample = {"text_input": user_question}
- path2feat, paths, all_img_feats = load_feat()
- all_img_feats.to(device)
- all_img_feats = F.normalize(all_img_feats, dim=1)
+ with torch.no_grad():
+ text_feature = feature_extractor.extract_features(
+ sample, mode="text"
+ ).text_features[0, 0]
- num_cols = 4
- num_rows = int(num_display / num_cols)
+ path2feat, paths, all_img_feats = load_feat()
+ all_img_feats.to(device)
+ all_img_feats = F.normalize(all_img_feats, dim=1)
- similarities = text_feature @ all_img_feats.T
- indices = torch.argsort(similarities, descending=True)[:num_display]
+ num_cols = 4
+ num_rows = int(num_display / num_cols)
- top_paths = [paths[ind.detach().cpu().item()] for ind in indices]
- sorted_similarities = [similarities[idx] for idx in indices]
- filenames = [os.path.join(file_root, p) for p in top_paths]
+ similarities = text_feature @ all_img_feats.T
+ indices = torch.argsort(similarities, descending=True)[:num_display]
- # ========= ITM and GradCam ==========
- bsz = 4 # max number of images to avoid cuda oom
- if model_type.startswith("BLIP"):
- blip_type = model_type.split("_")[1]
+ top_paths = [paths[ind.detach().cpu().item()] for ind in indices]
+ sorted_similarities = [similarities[idx] for idx in indices]
+ filenames = [os.path.join(file_root, p) for p in top_paths]
- itm_model = load_blip_itm_model(device, model_type=blip_type)
+ # ========= ITM and GradCam ==========
+ bsz = 4 # max number of images to avoid cuda oom
+ if model_type.startswith("BLIP"):
+ blip_type = model_type.split("_")[1]
- tokenizer = init_bert_tokenizer()
- queries_batch = [user_question] * bsz
- queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(device)
+ itm_model = load_blip_itm_model(device, model_type=blip_type)
- num_batches = int(num_display / bsz)
+ tokenizer = init_bert_tokenizer()
+ queries_batch = [user_question] * bsz
+ queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(device)
- avg_gradcams = []
- all_raw_images = []
- itm_scores = []
+ num_batches = int(num_display / bsz)
- for i in range(num_batches):
- filenames_in_batch = filenames[i * bsz : (i + 1) * bsz]
- raw_images, images = read_and_process_images(filenames_in_batch, vis_processor)
- gradcam, itm_output = compute_gradcam_batch(
- itm_model, images, queries_batch, queries_tok_batch
- )
+ avg_gradcams = []
+ all_raw_images = []
+ itm_scores = []
- all_raw_images.extend([resize_img(r_img) for r_img in raw_images])
- norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
+ for i in range(num_batches):
+ filenames_in_batch = filenames[i * bsz : (i + 1) * bsz]
+ raw_images, images = read_and_process_images(filenames_in_batch, vis_processor)
+ gradcam, itm_output = compute_gradcam_batch(
+ itm_model, images, queries_batch, queries_tok_batch
+ )
- for norm_img, grad_cam in zip(norm_imgs, gradcam):
- avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True)
- avg_gradcams.append(avg_gradcam)
+ all_raw_images.extend([resize_img(r_img) for r_img in raw_images])
+ norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
- with torch.no_grad():
- itm_score = torch.nn.functional.softmax(itm_output, dim=1)
+ for norm_img, grad_cam in zip(norm_imgs, gradcam):
+ avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True)
+ avg_gradcams.append(avg_gradcam)
+
+ with torch.no_grad():
+ itm_score = torch.nn.functional.softmax(itm_output, dim=1)
- itm_scores.append(itm_score)
+ itm_scores.append(itm_score)
- # ========= ITM re-ranking =========
- itm_scores = torch.cat(itm_scores)[:, 1]
- if itm_ranking:
- itm_scores_sorted, indices = torch.sort(itm_scores, descending=True)
+ # ========= ITM re-ranking =========
+ itm_scores = torch.cat(itm_scores)[:, 1]
+ if itm_ranking:
+ itm_scores_sorted, indices = torch.sort(itm_scores, descending=True)
- avg_gradcams_sorted = []
- all_raw_images_sorted = []
- for idx in indices:
- avg_gradcams_sorted.append(avg_gradcams[idx])
- all_raw_images_sorted.append(all_raw_images[idx])
+ avg_gradcams_sorted = []
+ all_raw_images_sorted = []
+ for idx in indices:
+ avg_gradcams_sorted.append(avg_gradcams[idx])
+ all_raw_images_sorted.append(all_raw_images[idx])
- avg_gradcams = avg_gradcams_sorted
- all_raw_images = all_raw_images_sorted
+ avg_gradcams = avg_gradcams_sorted
+ all_raw_images = all_raw_images_sorted
- if show_gradcam:
- images_to_show = iter(avg_gradcams)
- else:
- images_to_show = iter(all_raw_images)
+ if show_gradcam:
+ images_to_show = iter(avg_gradcams)
+ else:
+ images_to_show = iter(all_raw_images)
- for _ in range(num_rows):
- with st.container():
- for col in st.columns(num_cols):
- col.image(next(images_to_show), use_column_width=True, clamp=True)
+ for _ in range(num_rows):
+ with st.container():
+ for col in st.columns(num_cols):
+ col.image(next(images_to_show), use_column_width=True, clamp=True)
def read_and_process_images(image_paths, vis_processor):
@@ -193,7 +187,7 @@ def compute_gradcam_batch(model, visual_input, text_input, tokenized_text, block
block_num
].crossattention.self.save_attention = True
- output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
+ output = model(visual_input, text_input, match_head="itm")
loss = output[:, 1].sum()
model.zero_grad()
diff --git a/app/multimodal_search_front_end.py b/app/multimodal_search_front_end.py
new file mode 100644
index 00000000..3fb51d9a
--- /dev/null
+++ b/app/multimodal_search_front_end.py
@@ -0,0 +1,112 @@
+import os, time, subprocess
+
+import numpy as np
+import streamlit as st
+import torch
+from app import cache_root, device, pending_job_path, job_output_path
+from app.utils import create_uniq_user_job_name
+
+
+job_type = 'search'
+
+def app():
+ # === layout ===
+ model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
+ if model_type.startswith("BLIP"):
+ blip_type = model_type.split("_")[1]
+ file_root = os.path.join(cache_root, "coco/images/train2014/")
+
+ values = [16, 24, 48]
+ default_layer_num = values.index(24)
+ num_display = st.sidebar.selectbox(
+ "Number of images:", values, index=default_layer_num
+ )
+ show_gradcam = st.sidebar.selectbox("Show GradCam:", [True, False], index=1)
+ itm_ranking = st.sidebar.selectbox("Multimodal re-ranking:", [True, False], index=0)
+
+ st.markdown(
+ "Multimodal Search
", unsafe_allow_html=True
+ )
+
+ row1_1, row1_spacer1, row1_2, row1_spacer2 = st.columns((15.5, .1, 3.5, 0.1))
+ with row1_1:
+ user_question = st.text_input(
+ "Search query", "A dog running on the grass.", help="Type something to search."
+ )
+ with row1_2:
+ st.markdown("")
+ st.markdown("")
+ search_button = st.button("Search")
+
+ if search_button:
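+        # enqueue the query as a job file for the search backend, then poll for its result files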
+ time_stamp = time.time()
+ pending_path = os.path.join(pending_job_path, job_type)
+ if not os.path.exists(pending_path):
+ subprocess.run(['mkdir', '-p', pending_path], shell=False)
+ file_name = '{}_result.txt'.format(create_uniq_user_job_name(str(time_stamp), user_question))
+ with open(os.path.join(pending_path, file_name),'w') as new_job:
+ line = str(time_stamp)+'\t'+user_question+'\t'+str(num_display)+'\t'+blip_type
+ new_job.write(line+'\n')
+ new_job.close()
+
+ num_pending_jobs = len(os.listdir(pending_path))
+ outpath = os.path.join(job_output_path, job_type)
+ search_result = os.path.join(outpath, file_name)
+
+ filenames = []
+ with st.spinner("Queuing (#{} in line)".format(num_pending_jobs)):
+ while True:
+ if os.path.exists(search_result):
+ time.sleep(1)
+ with open(search_result) as f:
+ count = 0
+ for line in f:
+ if count < num_display:
+ p = os.path.join(file_root, line.rstrip('\n'))
+ filenames.append(p)
+ count += 1
+ break
+ # ========= ITM and GradCam ==========
+
+ itm_scores_pt = outpath+'/{}_itm.pt'.format(create_uniq_user_job_name(str(time_stamp), user_question))
+ itm_scores = torch.load(itm_scores_pt, map_location=torch.device('cpu'))
+ os.remove(itm_scores_pt)
+
+ avg_gradcams_pt = outpath+'/{}_avg_gradcams.npy'.format(create_uniq_user_job_name(str(time_stamp), user_question))
+ avg_gradcams = np.load(avg_gradcams_pt, allow_pickle=True)
+
+ os.remove(avg_gradcams_pt)
+
+ all_raw_images_pt = outpath+'/{}_all_raw_images.npy'.format(create_uniq_user_job_name(str(time_stamp), user_question))
+ all_raw_images = np.load(all_raw_images_pt, allow_pickle=True)
+ os.remove(all_raw_images_pt)
+
+ # ========= ITM re-ranking =========
+ if itm_ranking:
+ itm_scores_sorted, indices = torch.sort(itm_scores, descending=True)
+
+ avg_gradcams_sorted = []
+ all_raw_images_sorted = []
+ for idx in indices:
+ avg_gradcams_sorted.append(avg_gradcams[idx])
+ all_raw_images_sorted.append(all_raw_images[idx])
+
+ avg_gradcams = avg_gradcams_sorted
+ all_raw_images = all_raw_images_sorted
+
+ if show_gradcam:
+ images_to_show = iter(avg_gradcams)
+ else:
+ images_to_show = iter(all_raw_images)
+
+ num_cols = 4
+ num_rows = int(num_display / num_cols)
+ for _ in range(num_rows):
+ with st.container():
+ for col in st.columns(num_cols):
+ col.image(next(images_to_show), use_column_width=True, clamp=True)
+
+
+
+
+
diff --git a/app/multipage.py b/app/multipage.py
index 040f76eb..70912656 100644
--- a/app/multipage.py
+++ b/app/multipage.py
@@ -1,10 +1,3 @@
-"""
- # Copyright (c) 2022, salesforce.com, inc.
- # All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
"""
This file is the framework for generating multiple Streamlit applications
through an object oriented framework.
diff --git a/app/style_override.css b/app/style_override.css
new file mode 100644
index 00000000..04f8a292
--- /dev/null
+++ b/app/style_override.css
@@ -0,0 +1,57 @@
+body, div, span {
+ font-family: 'Salesforce Sans' !important;
+}
+
+section.main {
+ align-items: self-start;
+ margin-left: 2rem;
+}
+
+button[kind="header"], header {
+ display: none !important;
+}
+
+section[data-testid="stSidebar"] > div:nth-child(1) > div:nth-child(2) {
+ padding-top: 0;
+}
+div[data-testid="stImage"] img {
+ margin-bottom: 2rem;
+}
+
+div.block-container {
+ padding-top: 1rem;
+ max-width: 98%;
+}
+
+div[data-testid="stMarkdownContainer"] h1 {
+ text-align: left !important;
+ padding-top: 15px;
+}
+
+h1, h2 {
+ color: #032D60;
+}
+
+section[data-testid="stFileUploadDropzone"] {
+ max-width: 30rem;
+}
+
+.stButton {
+ margin-top: 15px;
+}
+.stButton button {
+ background-color: #0176D3 !important;
+ color: white !important;
+}
+.stButton button:hover {
+ background-color: #032D60 !important;
+}
+
+div[data-testid="stHorizontalBlock"] {
+ gap: 4rem;
+ padding-top: 2rem;
+}
+div[data-testid="stHorizontalBlock"] div[data-testid="column"] {
+ width: calc(50% - 2rem);
+ flex: 1 1 calc(50% - 2rem);
+}
diff --git a/app/txt2image_front_end.py b/app/txt2image_front_end.py
new file mode 100644
index 00000000..2a3910b0
--- /dev/null
+++ b/app/txt2image_front_end.py
@@ -0,0 +1,62 @@
+import streamlit as st
+from PIL import Image
+import os, time, subprocess
+from app import pending_job_path, job_output_path
+from app.utils import create_uniq_user_job_name
+
+job_type='txt2image'
+
+def compute_grid(num_images):
+ cols = []
+ for i in range(num_images):
+ cols.append(st.columns())
+ return cols
+
+def app():
+ num_images = st.sidebar.slider("Choose number of images to generate",
+ min_value=1, max_value=3, step=1)
+ random_seed = st.sidebar.text_input("Seed:", 1024,
+ help="Set random seed to reproduce the generated images")
+ st.markdown(
+ "Image Generation
", unsafe_allow_html=True
+ )
+ row1_1, row1_spacer1, row1_2, row1_spacer2 = st.columns((15.5, .1, 3.5, 0.1))
+ with row1_1:
+ user_prompt = st.text_input(
+ "Describe the image you would like to generate",
+ "a painting of Singapore Garden By the Bay in the style of Vincent Van Gogh",
+ help="Try something creative."
+ )
+ with row1_2:
+ st.write("")
+ st.write("")
+ generation_button = st.button("Generate")
+
+ if generation_button:
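+        # enqueue the prompt as a job file for the txt2image backend, then poll for the generated image grid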
+ time_stamp = str(time.time())
+ file_name = str(random_seed)+'\t'+str(time_stamp)+'\t'+user_prompt[:50]+'\t'+str(num_images)+'.txt'
+ pending_path = os.path.join(pending_job_path, job_type)
+ if not os.path.exists(pending_path):
+ subprocess.run(['mkdir', '-p', pending_path], shell=False)
+ with open(os.path.join(pending_path,file_name),'w') as new_job:
+ line = str(random_seed)+'\t'+time_stamp+'\t'+user_prompt+'\t'+str(num_images)
+ new_job.write(line+'\n')
+ new_job.close()
+ outpath = os.path.join(job_output_path, job_type)
+ if not os.path.exists(outpath):
+            os.makedirs(outpath, exist_ok=True)
+ sample_path = os.path.join(outpath, "samples")
+ if not os.path.exists(sample_path):
+ subprocess.run(['mkdir', '-p', sample_path], shell=False)
+ generated_image = sample_path+'/{}_grid.png'.format(create_uniq_user_job_name(time_stamp, user_prompt))
+ num_pending_jobs = len(os.listdir(pending_path))
+ with st.spinner("Queuing (#{} in line) for generation".format(num_pending_jobs)):
+ while True:
+ if os.path.exists(generated_image):
+ time.sleep(1)
+ cols = st.columns(num_images)
+ for i in range(1, num_images+1):
+ img = sample_path+'/{}_{}.png'.format(create_uniq_user_job_name(time_stamp, user_prompt), i)
+ with cols[i-1]:
+ st.image(Image.open(img))
+ break
diff --git a/app/utils.py b/app/utils.py
index 5a4f209d..7617fa8a 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -1,19 +1,25 @@
-"""
- # Copyright (c) 2022, salesforce.com, inc.
- # All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
-import numpy as np
import streamlit as st
import torch
-from lavis.models import BlipBase, load_model
-from matplotlib import pyplot as plt
+import numpy as np
+from app import pending_job_path
+
+import os, glob
+
from PIL import Image
+
+from matplotlib import pyplot as plt
from scipy.ndimage import filters
from skimage import transform as skimage_transform
+from lavis.models import BlipBase, load_model
+
+def get_pending_jobs(job_type):
+    # Collect this job type's pending prompt files, oldest first (by modification time).
+    list_of_prompts = filter(os.path.isfile,
+                             glob.glob('{}/{}/*.txt'.format(pending_job_path, job_type)))
+    list_of_prompts = sorted(list_of_prompts, key=os.path.getmtime)
+    return list(list_of_prompts)
+
def resize_img(raw_img):
w, h = raw_img.size
@@ -46,6 +52,7 @@ def init_bert_tokenizer():
return tokenizer
+
def getAttMap(img, attMap, blur=True, overlap=True):
attMap -= attMap.min()
if attMap.max() > 0:
@@ -75,7 +82,18 @@ def getAttMap(img, attMap, blur=True, overlap=True):
allow_output_mutation=True,
)
def load_blip_itm_model(device, model_type="base"):
+ # if model_type == "large":
+ # pretrained_path = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
+ # else:
+ # pretrained_path = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"
+ # model = BlipITM(pretrained=pretrained_path, vit=model_type)
+ # model.eval()
+ # model = model.to(device)
+ # return model
model = load_model(
"blip_image_text_matching", model_type, is_eval=True, device=device
)
return model
+
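+# Build a filesystem-safe identifier from the request timestamp and the first 20 words of the
+# user input; the front ends use the same string to locate the corresponding backend outputs.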
+def create_uniq_user_job_name(time_stamp, user_info):
+ return str(time_stamp).replace('.','_') + '_' + '_'.join(user_info.split(' ')[:20])
diff --git a/app/vqa.py b/app/vqa.py
index c505a985..4ade6a1c 100644
--- a/app/vqa.py
+++ b/app/vqa.py
@@ -1,10 +1,3 @@
-"""
- # Copyright (c) 2022, salesforce.com, inc.
- # All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
import streamlit as st
from app import load_demo_image, device
from app.utils import load_model_cache
@@ -13,7 +6,7 @@
def app():
- model_type = st.sidebar.selectbox("Model:", ["BLIP"])
+ model_type = st.sidebar.selectbox("Model:", ["BLIP", "BLIP_aokvqa", "BLIP_okvqa"])
# ===== layout =====
st.markdown(
@@ -39,10 +32,14 @@ def app():
col1.image(resized_image, use_column_width=True)
col2.header("Question")
- user_question = col2.text_input("Input your question!", "What are objects there?")
- qa_button = st.button("Submit")
-
- col2.header("Answer")
+ with col2:
+ #user_question = col2.text_input("Input your question!", "What are objects there?")
+        #question = 'Input your question!'
+        #st.markdown(question, unsafe_allow_html=True)
+ user_question = col2.text_area("Input your question!", "What are objects there?")
+ qa_button = st.button("Answer my question")
+
+ #col2.header("Answer")
# ===== event =====
vis_processor = load_processor("blip_image_eval").build(image_size=480)
@@ -53,11 +50,19 @@ def app():
model = load_model_cache(
"blip_vqa", model_type="vqav2", is_eval=True, device=device
)
-
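+    # Swap in the fine-tuned A-OKVQA / OK-VQA checkpoints when those options are selected;
+    # otherwise the default VQAv2 weights loaded above are used as-is.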
+ if model_type == "BLIP_aokvqa":
+ model.load_from_pretrained("https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_aokvqa.pth")
+ elif model_type == "BLIP_okvqa":
+ model.load_from_pretrained("https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_okvqa.pth")
img = vis_processor(raw_img).unsqueeze(0).to(device)
question = text_processor(user_question)
vqa_samples = {"image": img, "text_input": [question]}
answers = model.predict_answers(vqa_samples, inference_method="generate")
- col2.write("\n".join(answers), use_column_width=True)
+ with col2:
+ st.header("Answer")
+ for answer in answers:
+                answer_md = '{}'.format(answer)
+ st.markdown(answer_md, unsafe_allow_html=True)
+ #col2.write("\n".join(answers), use_column_width=True)
diff --git a/docker/blip_pod_a100_docker.yaml b/docker/blip_pod_a100_docker.yaml
new file mode 100644
index 00000000..99ab1300
--- /dev/null
+++ b/docker/blip_pod_a100_docker.yaml
@@ -0,0 +1,48 @@
+## cramaiah_nodot=`echo ${USER} | sed s/[.]/-/g`; sed "s/cramaiah/${cramaiah_nodot}/g" sfr-pod-cramaiah.yaml > sfr-pod-${cramaiah_nodot}.yaml
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ name: sfr-pod-blip-server-a100-docker
+ namespace: sfr-ns-guangsen-wang
+ labels:
+ app: detector-blip-server-a100-docker
+spec:
+ replicas: 1
+ selector:
+ matchLabels:
+ app: detector-blip-server-a100-docker
+ template:
+ metadata:
+ labels:
+ app: detector-blip-server-a100-docker
+ spec:
+ volumes:
+ - name: sfr-home-pv-guangsen-wang
+ persistentVolumeClaim:
+ claimName: sfr-home-pvc-guangsen-wang
+ - name: sfr-share-pv-guangsen-wang
+ persistentVolumeClaim:
+ claimName: sfr-share-pvc-guangsen-wang
+ containers:
+ - name: lavis-pytorch
+ image: "gcr.io/salesforce-research-internal/lavis_streamlit_gpu"
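+        # Image built and pushed by docker/create_docker_image.sh; it serves the Streamlit
+        # app on port 8080 (started via run_scripts/start_lavis_app.sh).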
+ ports:
+ - containerPort: 8080
+ resources:
+ limits:
+ nvidia.com/gpu: 2
+ cpu: "23"
+ memory: 150G
+ volumeMounts:
+ - name: sfr-home-pv-guangsen-wang
+ mountPath: "/export/home"
+ - name: sfr-share-pv-guangsen-wang
+ mountPath: "/export/share"
+ nodeSelector:
+ cloud.google.com/gke-accelerator: nvidia-tesla-a100
+ tolerations:
+ - key: "gpu_num"
+ operator: "Equal"
+ value: "2"
+ effect: "NoSchedule"
diff --git a/docker/blip_service_a100.yaml b/docker/blip_service_a100.yaml
new file mode 100644
index 00000000..66f8b220
--- /dev/null
+++ b/docker/blip_service_a100.yaml
@@ -0,0 +1,12 @@
+apiVersion: v1
+kind: Service
+metadata:
+ name: sfr-pod-blip-server-a100-docker
+spec:
+ type: LoadBalancer
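+  # Exposes port 8080 of the blip-server deployment (matched by the app label selector below).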
+ selector:
+ app: detector-blip-server-a100-docker
+ ports:
+ - name: detector-blip-server-a100-docker
+ port: 8080
+ targetPort: 8080
diff --git a/docker/create_docker_image.sh b/docker/create_docker_image.sh
new file mode 100644
index 00000000..97affdf3
--- /dev/null
+++ b/docker/create_docker_image.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+TAG="gcr.io/salesforce-research-internal/lavis_streamlit_gpu"
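+# Build with Cloud Build using the Dockerfile in the current directory and push the result to $TAG.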
+gcloud builds submit . -t=$TAG --machine-type=n1-highcpu-32 --timeout=9000
diff --git a/requirements-app.txt b/requirements-app.txt
new file mode 100644
index 00000000..c8988daf
--- /dev/null
+++ b/requirements-app.txt
@@ -0,0 +1,40 @@
+contexttimer
+decord>=0.6.0
+einops>=0.4.1
+fairscale==0.4.4
+ftfy
+iopath
+ipython
+omegaconf>=2.1.2
+opencv-python==4.5.5.64
+opendatasets
+packaging
+pandas
+plotly
+pre-commit
+pycocoevalcap
+pycocotools
+python-magic
+timm==0.4.12
+torch==1.10.0
+torchvision==0.11.1
+setuptools==59.5.0
+tqdm
+webdataset
+wheel
+torchtext==0.11.0
+albumentations==0.4.3
+diffusers
+pudb==2019.2
+invisible-watermark
+imageio==2.9.0
+imageio-ffmpeg==0.4.2
+pytorch-lightning==1.4.2
+test-tube>=0.7.5
+streamlit>=0.73.1
+torch-fidelity==0.3.0
+transformers==4.19.2
+torchmetrics==0.6.0
+kornia==0.6
+-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+-e git+https://github.com/openai/CLIP.git@main#egg=clip
diff --git a/run_scripts/run_local.sh b/run_scripts/run_local.sh
new file mode 100755
index 00000000..6f372cdd
--- /dev/null
+++ b/run_scripts/run_local.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+conda create -n lavis_local python=3.8
+conda init bash
+# You may need to restart the shell for `conda init` to take effect. The script stops here
+# on purpose; after restarting, run the remaining commands below manually.
+echo "Restart your shell, then run the remaining commands in this script manually."
+exit 1
+conda activate lavis_local
+git clone https://github.com/MetaMind/LAVIS.git
+cd LAVIS
+git clone https://github.com/CompVis/stable-diffusion.git
+git checkout lavis_diffusion
+pip install -r requirements-dev.txt
+export PYTHONPATH=./:$PYTHONPATH:./stable-diffusion
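+# Note: this only starts the Streamlit front end; the job-queue backends under app/backend/
+# are not launched here and need to be started separately for generation requests to complete.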
+streamlit run --server.port 8080 app/main.py
diff --git a/run_scripts/start_lavis_app.sh b/run_scripts/start_lavis_app.sh
new file mode 100755
index 00000000..7db0e239
--- /dev/null
+++ b/run_scripts/start_lavis_app.sh
@@ -0,0 +1,5 @@
+#!/bin/sh
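+# Launch the three job-queue backends in the background, then run the Streamlit front end
+# in the foreground so the container keeps running; logs are written to app/backend/*.log.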
+nohup python /lavis_app/app/backend/multimodal_search_backend.py > app/backend/search.log &
+nohup python /lavis_app/app/backend/txt2image_backend.py > app/backend/imagen.log &
+nohup python /lavis_app/app/backend/caption_backend.py > app/backend/caption.log &
+streamlit run --server.port 8080 /lavis_app/app/main.py