Skip to content

Commit

Permalink
Made sure GFPGAN and RealESRGAN are on server_state. (#1319)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeroCool940711 committed Sep 26, 2022
1 parent 5d1558f commit 1fd28ee
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 88 deletions.
39 changes: 22 additions & 17 deletions scripts/sd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,13 +686,15 @@ def load_GFPGAN():
sys.path.append(os.path.abspath(st.session_state['defaults'].general.GFPGAN_dir))
from gfpgan import GFPGANer

if st.session_state['defaults'].general.gfpgan_cpu or st.session_state['defaults'].general.extra_models_cpu:
instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
elif st.session_state['defaults'].general.extra_models_gpu:
instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gfpgan_gpu}"))
else:
instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}"))
return instance
with server_state_lock['GFPGAN']:
if st.session_state['defaults'].general.gfpgan_cpu or st.session_state['defaults'].general.extra_models_cpu:
server_state['GFPGAN'] = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
elif st.session_state['defaults'].general.extra_models_gpu:
server_state['GFPGAN'] = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gfpgan_gpu}"))
else:
server_state['GFPGAN'] = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}"))

return server_state['GFPGAN']

@retry(tries=5)
def load_RealESRGAN(model_name: str):
Expand All @@ -709,17 +711,18 @@ def load_RealESRGAN(model_name: str):
sys.path.append(os.path.abspath(st.session_state['defaults'].general.RealESRGAN_dir))
from realesrgan import RealESRGANer

if st.session_state['defaults'].general.esrgan_cpu or st.session_state['defaults'].general.extra_models_cpu:
instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half
instance.device = torch.device('cpu')
instance.model.to('cpu')
elif st.session_state['defaults'].general.extra_models_gpu:
instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.esrgan_gpu}"))
else:
instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}"))
instance.model.name = model_name
with server_state_lock['RealESRGAN']:
if st.session_state['defaults'].general.esrgan_cpu or st.session_state['defaults'].general.extra_models_cpu:
server_state['RealESRGAN'] = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half
server_state['RealESRGAN'].device = torch.device('cpu')
server_state['RealESRGAN'].model.to('cpu')
elif st.session_state['defaults'].general.extra_models_gpu:
server_state['RealESRGAN'] = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.esrgan_gpu}"))
else:
server_state['RealESRGAN'] = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}"))
server_state['RealESRGAN'].model.name = model_name

return instance
return server_state['RealESRGAN']

#
@retry(tries=5)
Expand All @@ -728,6 +731,7 @@ def load_LDSR(checking=False):
yaml_name = 'project'
model_path = os.path.join(st.session_state['defaults'].general.LDSR_dir, 'experiments/pretrained_models', model_name + '.ckpt')
yaml_path = os.path.join(st.session_state['defaults'].general.LDSR_dir, 'experiments/pretrained_models', yaml_name + '.yaml')

if not os.path.isfile(model_path):
raise Exception("LDSR model not found at path "+model_path)
if not os.path.isfile(yaml_path):
Expand All @@ -738,6 +742,7 @@ def load_LDSR(checking=False):
sys.path.append(os.path.abspath(st.session_state['defaults'].general.LDSR_dir))
from LDSR import LDSR
LDSRObject = LDSR(model_path, yaml_path)

return LDSRObject

#
Expand Down
142 changes: 71 additions & 71 deletions scripts/txt2vid.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ class plugin_info():


if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
GFPGAN_available = True
server_state["GFPGAN_available"] = True
else:
GFPGAN_available = False
server_state["GFPGAN_available"] = False

if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].txt2vid.RealESRGAN_model}.pth")):
RealESRGAN_available = True
server_state["RealESRGAN_available"] = True
else:
RealESRGAN_available = False
server_state["RealESRGAN_available"] = False

#
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -484,7 +484,7 @@ def txt2vid(
with autocast("cuda"):
image = diffuse(server_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta)

if st.session_state["save_individual_images"] and not server_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]:
if st.session_state["save_individual_images"] and not st.session_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]:
#im = Image.fromarray(image)
outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
image.save(outpath, quality=quality)
Expand All @@ -498,8 +498,8 @@ def txt2vid(

#
#try:
#if server_state["use_GFPGAN"] and server_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]:
if server_state["use_GFPGAN"] and server_state["GFPGAN"] is not None:
#if st.session_state["use_GFPGAN"] and server_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]:
if st.session_state["use_GFPGAN"] and server_state["GFPGAN"] is not None:
#print("Running GFPGAN on image ...")
st.session_state["progress_bar_text"].text("Running GFPGAN on image ...")
#skip_save = True # #287 >_>
Expand Down Expand Up @@ -714,12 +714,12 @@ def layout():
help="Do loop")
st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.")

if GFPGAN_available:
server_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
if server_state["GFPGAN_available"]:
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
else:
server_state["use_GFPGAN"] = False
st.session_state["use_GFPGAN"] = False

if RealESRGAN_available:
if server_state["RealESRGAN_available"]:
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN,
help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
Expand All @@ -743,9 +743,9 @@ def layout():
if generate_button:
#print("Loading models")
# load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
#load_models(False, server_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])
#load_models(False, st.session_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])

if server_state["use_GFPGAN"]:
if st.session_state["use_GFPGAN"]:
if "GFPGAN" in st.session_state:
print("GFPGAN already loaded")
else:
Expand All @@ -762,63 +762,63 @@ def layout():
if "GFPGAN" in st.session_state:
del server_state["GFPGAN"]

try:
# run video generation
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
num_inference_steps=st.session_state.num_inference_steps,
cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
seeds=seed, quality=100, eta=0.0, width=width,
height=height, weights_path=custom_model, scheduler=scheduler_name,
disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
beta_schedule=beta_scheduler_type, starting_image=None)
#message.success('Done!', icon="✅")
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
#if 'latestVideos' in st.session_state:
#for i in video:
##push the new image to the list of latest images and remove the oldest one
##remove the last index from the list\
#st.session_state['latestVideos'].pop()
##add the new image to the start of the list
#st.session_state['latestVideos'].insert(0, i)
#PlaceHolder.empty()
#with PlaceHolder.container():
#col1, col2, col3 = st.columns(3)
#col1_cont = st.container()
#col2_cont = st.container()
#col3_cont = st.container()
#with col1_cont:
#with col1:
#st.image(st.session_state['latestVideos'][0])
#st.image(st.session_state['latestVideos'][3])
#st.image(st.session_state['latestVideos'][6])
#with col2_cont:
#with col2:
#st.image(st.session_state['latestVideos'][1])
#st.image(st.session_state['latestVideos'][4])
#st.image(st.session_state['latestVideos'][7])
#with col3_cont:
#with col3:
#st.image(st.session_state['latestVideos'][2])
#st.image(st.session_state['latestVideos'][5])
#st.image(st.session_state['latestVideos'][8])
#historyGallery = st.empty()
## check if output_images length is the same as seeds length
#with gallery_tab:
#st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]

except (StopException, KeyError):
print(f"Received Streamlit StopException")
#try:
# run video generation
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
num_inference_steps=st.session_state.num_inference_steps,
cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
seeds=seed, quality=100, eta=0.0, width=width,
height=height, weights_path=custom_model, scheduler=scheduler_name,
disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
beta_schedule=beta_scheduler_type, starting_image=None)

#message.success('Done!', icon="✅")
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")

#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']

#if 'latestVideos' in st.session_state:
#for i in video:
##push the new image to the list of latest images and remove the oldest one
##remove the last index from the list\
#st.session_state['latestVideos'].pop()
##add the new image to the start of the list
#st.session_state['latestVideos'].insert(0, i)
#PlaceHolder.empty()

#with PlaceHolder.container():
#col1, col2, col3 = st.columns(3)
#col1_cont = st.container()
#col2_cont = st.container()
#col3_cont = st.container()

#with col1_cont:
#with col1:
#st.image(st.session_state['latestVideos'][0])
#st.image(st.session_state['latestVideos'][3])
#st.image(st.session_state['latestVideos'][6])
#with col2_cont:
#with col2:
#st.image(st.session_state['latestVideos'][1])
#st.image(st.session_state['latestVideos'][4])
#st.image(st.session_state['latestVideos'][7])
#with col3_cont:
#with col3:
#st.image(st.session_state['latestVideos'][2])
#st.image(st.session_state['latestVideos'][5])
#st.image(st.session_state['latestVideos'][8])
#historyGallery = st.empty()

## check if output_images length is the same as seeds length
#with gallery_tab:
#st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)


#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]

#except (StopException, KeyError):
#print(f"Received Streamlit StopException")


0 comments on commit 1fd28ee

Please sign in to comment.