Skip to content

Commit

Permalink
Add support for scripting from colab
Browse files Browse the repository at this point in the history
  • Loading branch information
pauloday committed Dec 20, 2021
1 parent c160202 commit 106bb59
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 17 deletions.
17 changes: 11 additions & 6 deletions src/artbot.py
@@ -1,10 +1,7 @@
#!/bin/python
import threading, os, shutil, ffpb
from sys import argv
from re import search
from torch.cuda import device_count
import shutil
from runner import run_prompts
from parse import parse_yaml
from parse import parse_yaml, flatten_array, expand_iteration
from tqdm import tqdm
from output import write_video

Expand All @@ -14,4 +11,12 @@ def run_yaml(yaml, out_dir, image_writer, status_writer, tqdm):
outputs = run_prompts(settings, prompts, image_prompts, out_dir, image_writer=image_writer, status_writer=status_writer, tqdm=tqdm)
if 'video' in settings.keys() and settings['video']:
write_video(out_dir, name, outputs, settings['fps'], tqdm=self.tqdm)
return outputs[-1]
shutil.copy(outputs[-1], out_dir)

def run_array(settings, prompts, image_prompts, out_dir, image_writer, status_writer, tqdm):
status_writer(False, 0)
parsed_prompts = flatten_array(map(expand_iteration, prompts))
outputs = run_prompts(settings, parsed_prompts, image_prompts, out_dir, image_writer=image_writer, status_writer=status_writer, tqdm=tqdm)
if 'video' in settings.keys() and settings['video']:
write_video(out_dir, name, outputs, settings['fps'], tqdm=self.tqdm)
shutil.copy(outputs[-1], out_dir)
2 changes: 1 addition & 1 deletion src/output.py
Expand Up @@ -16,7 +16,7 @@ def obj_hash(obj):
return str_hash(encoded)

def image_name(out_dir, i, settings):
image_name = f'{out_dir}/{i}_{settings["size"][0]}x{settings["size"][1]}.jpg'
image_name = f'{out_dir}/{settings["title"]}/{i}_{settings["size"][0]}x{settings["size"][1]}.jpg'
dump_args = settings.copy();
# these args shouldn't ever change individual images when modified
del dump_args['video']
Expand Down
4 changes: 2 additions & 2 deletions src/parse.py
Expand Up @@ -30,7 +30,7 @@ def get_settings(run):
def expand_iteration(line):
if mult_tok in line:
parts = line.split(mult_tok)
n = parts[0]
n = int(parts[0])
prompt = parts[1]
return [prompt] * n
return [line]
Expand All @@ -48,6 +48,6 @@ def flatten_array(t):
def parse_yaml(yaml):
parsed = load(yaml, Loader=FullLoader)
settings = get_settings(parsed['settings'])
prompts = map(expand_iteration, parsed['prompts'])
prompts = flatten_array(map(expand_iteration, parsed['prompts']))
image_prompts = [] #flatten_array(map(expand_iteration, parsed['image_prompts']))
return (settings, prompts, image_prompts)
17 changes: 9 additions & 8 deletions src/runner.py
Expand Up @@ -259,16 +259,15 @@ def ascend_txt():
for prompt in pMs:
result.append(prompt(iii))
return result

iterations = len(prompts) * args['time_scale']
def train(i):
set_prompts(i)
opt.zero_grad()
lossAll = ascend_txt()
display_freq = math.floor(len(prompts)/args['images'])
display_freq = math.floor(iterations/args['images'])
out_path = False
if (i % display_freq == 0 and i != 0) or i == len(prompts):
if (i % display_freq == 0 and i != 0) or i == iterations:
out_path = image_name(output_dir, i, args)
checkin(lossAll, out_path)
checkin(i, lossAll, out_path)
loss = sum(lossAll)
loss.backward()
opt.step()
Expand All @@ -279,9 +278,11 @@ def train(i):
out_paths = []
i = 0
try:
with tqdm(i) as pbar:
for prompt in prompts:
path = train(i + 1) # have i start at 1 without making pbar bigger
with tqdm(total=iterations) as pbar:
while i < iterations:
if (i % args['time_scale'] == 0):
set_prompts(int(i / args['time_scale']))
path = train(i) # have i start at 1 without making pbar bigger
if path:
out_paths.append(path)
if image_writer:
Expand Down

0 comments on commit 106bb59

Please sign in to comment.