In [2]:
from PIL import Image
from dotenv import load_dotenv
import re
import pyszuru
import glob
import time
from tqdm.auto import tqdm
import os
from requests.exceptions import ConnectTimeout

In [3]:
# Read settings from a local .env file (if one exists) into the process
# environment before they are looked up below. load_dotenv() leaves variables
# that are already set untouched, so on systems that predefine USER the OS
# value takes precedence over the .env entry.
load_dotenv()

# Credentials and base URL for the booru instance come from the environment so
# they never have to be written into the notebook.
USER = os.getenv('USER')
PASS = os.getenv('PASS')
LINK = os.getenv('LINK')

# Shared API client used by the upload code further down.
mybooru = pyszuru.API(
    LINK,
    username=USER,
    password=PASS
)

In [4]:
def txt_to_tagset(txtfile):
    """Read a tag list file and return its tags as a set of lowercase strings.

    Each line contributes at most one tag: ``**`` is treated like a comma and
    only the text before the first comma is kept. Blank lines, and lines whose
    first field is empty, are ignored so the result never contains an empty
    tag.

    Args:
        txtfile: Path of the text file to read.

    Returns:
        set[str]: Lowercase, whitespace-trimmed tags found in the file.
    """
    with open(txtfile, 'r') as tx:
        candidates = (
            line.replace('**', ',').split(',')[0].strip().lower()
            for line in tx.read().splitlines()
        )
        # Drop empty entries: an empty tag would match the empty strings that
        # appear when a prompt is split on commas.
        return {tag for tag in candidates if tag}

def charfile_to_tagset(txtfile):
    """Collect the weighted tag named on each line of a wildcard file.

    Square brackets are removed from every line, then the first
    ``(name:D?D)`` group (``D`` a digit, ``?`` any single character) is taken
    and its ``name`` part is trimmed and lowercased. Lines without such a
    group are reported on stdout; blank lines are skipped silently.

    Args:
        txtfile: Path of the wildcard file to read.

    Returns:
        set[str]: The lowercase tag names found in the file.
    """
    # Raw string so that '\(' and '\d' reach the regex engine unchanged instead
    # of being treated as (invalid) Python string escapes.
    pattern = re.compile(r'\((.*?):\d.\d\)')
    tagset = set()
    with open(txtfile, 'r') as tx:
        for line in tx.read().splitlines():
            if not line.strip():
                # An empty line is not a malformed entry; nothing to report.
                continue
            x = line.replace('[', '').replace(']', '')
            found = pattern.search(x)
            if found:
                tagset.add(found.group(1).strip().lower())
            else:
                print(f'Problem in file {txtfile}! Offending line is:')
                print(line)
    return tagset

In [5]:
# Directory that processed files are moved into, and the images queued for upload.
donepath = "../done/"
# Create the destination up front; otherwise os.replace() would only fail
# after a file has already been uploaded.
os.makedirs(donepath, exist_ok=True)
files = glob.glob("../tmp/*.png")

# Root of the Umi-AI wildcard files. Set the UMI_WILDCARD_DIR environment
# variable to run the notebook against a different installation.
WILDCARD_DIR = os.getenv(
    "UMI_WILDCARD_DIR",
    "E:/stable-diffusion-webui/extensions/Umi-AI-Embeds/wildcards",
)
char_files = (
    glob.glob(f"{WILDCARD_DIR}/characters/*.txt")
    + glob.glob(f"{WILDCARD_DIR}/characters/*.tags")
    + [f"{WILDCARD_DIR}/Franchise Girls.txt"]
)
species_files = glob.glob(f"{WILDCARD_DIR}/species/*.txt")
comm_file = f"{WILDCARD_DIR}/Common Tags.txt"

# Known tag vocabularies, merged across all wildcard files of each kind.
# set().union(*...) also yields an empty set when no files were found.
characters = set().union(*(charfile_to_tagset(cfile) for cfile in char_files))
species = set().union(*(charfile_to_tagset(sfile) for sfile in species_files))
common_tags = txt_to_tagset(comm_file)

In [6]:
def extract_tags(im):
    """Derive booru tags from the generation parameters stored in an image.

    The text under ``im.info['parameters']`` is scanned line by line, ignoring
    its last two lines (presumably the negative prompt and the settings line
    -- confirm against the generator's output format). Weighted tokens of the
    form ``(name:D?D)`` are matched against the module-level ``characters``
    and ``species`` sets; every comma-separated fragment is matched against
    ``common_tags``.

    Args:
        im: Object exposing an ``info`` mapping, e.g. an opened PIL image.

    Returns:
        tuple[list[str], list[str], list[str]]: ``(character_tags,
        species_tags, comm_tags)`` with spaces replaced by underscores (and
        backslashes removed from character/species tags). All three lists are
        empty when the image carries no parameters text.
    """
    parameters = im.info.get('parameters')
    if not parameters:
        # Not every PNG carries this metadata; let the caller treat it as
        # "no tags detected" instead of failing with KeyError.
        return [], [], []

    weighted = re.compile(r'\((.*?):\d.\d\)')
    tokens = set()
    ctags = set()
    for line in parameters.splitlines()[:-2]:
        # Match per comma-separated fragment so a lazy match cannot span
        # several prompt entries.
        for sub in line.split(","):
            tokens.update(match.strip().lower() for match in weighted.findall(sub))
        flattened = line.replace('(', ' ').replace(')', ', ').replace(':', ', ').lower()
        ctags.update(piece.strip() for piece in flattened.split(','))
    # Splitting on commas leaves empty fragments; never report them as tags.
    ctags.discard('')

    character_tags = [x.replace('\\', '').replace(' ', '_') for x in tokens.intersection(characters)]
    species_tags = [x.replace('\\', '').replace(' ', '_') for x in tokens.intersection(species)]
    comm_tags = [x.replace(' ', '_') for x in common_tags.intersection(ctags)]

    return character_tags, species_tags, comm_tags

def _is_transient_error(err):
    """Return True when ``err`` looks like a temporary connectivity failure."""
    message = str(err)
    if isinstance(err, pyszuru.SzurubooruHTTPError):
        return "Failed to connect to szurubooru REST API" in message
    if isinstance(err, ConnectTimeout):
        return "Max retries exceeded with url" in message
    return "ConnectionResetError" in message


def _create_or_fetch_post(f, all_tags, update_posts):
    """Upload ``f`` as a new post, or resolve the existing post if it is a duplicate.

    Returns ``(post, already_existed)``; ``post`` is None when the file is a
    duplicate and ``update_posts`` is false.
    """
    with open(f, "rb") as to_upload:
        file_token = mybooru.upload_file(to_upload)

    print(f'Attempting to upload file {f}')
    print(f'With tags {all_tags}')
    try:
        return mybooru.createPost(file_token, "safe"), False
    except pyszuru.SzurubooruHTTPError as e:
        if 'PostAlreadyUploadedError' not in str(e):
            # Anything else must reach the caller; continuing would leave us
            # without a post object to tag.
            raise
        print(e)
        print("For file ", f)
        if not update_posts:
            print('Skipping...')
            return None, True
        # The id of the existing post is taken from between the parentheses
        # of the error message.
        post_id = re.split(r'\(|\)', str(e))[1]
        print('Will try to update tags...')
        return mybooru.getPost(post_id), True


def _ensure_tags_exist(post, tags, category, new_tags):
    """Create every tag in ``tags`` that the server reports as missing."""
    for tag in tags:
        try:
            # Assigning a tag name is used as an existence probe: the original
            # code relies on it raising TagNotFoundError for unknown tags.
            post.tags = [tag]
        except pyszuru.SzurubooruHTTPError as e:
            if 'TagNotFoundError' not in str(e) or tag not in str(e):
                raise
            print(e)
            created = mybooru.createTag(tag)
            if category is not None:
                created.category = category
            created.push()
            new_tags.append(tag)
            print(f'Created tag {tag}')


def upload_file(
        f, character_tags, species_tags, comm_tags,
        update_posts=True, sleep_time=30, updated=0, skipped=0, new_tags=None
        ):
    """Upload one file to the booru and tag it, creating missing tags on the way.

    Temporary connection failures are retried indefinitely, waiting
    ``sleep_time`` seconds between attempts. A post created (or fetched) in an
    earlier attempt is reused on retry so the file is not uploaded twice.

    Args:
        f: Path of the file to upload.
        character_tags: Tag names created under the 'character' category when missing.
        species_tags: Tag names created under the 'species' category when missing.
        comm_tags: Tag names created without an explicit category when missing.
        update_posts: When the file already exists on the server, re-tag the
            existing post (True) or skip the file (False).
        sleep_time: Seconds to wait before retrying after a transient error.
        updated: Running count of re-tagged existing posts.
        skipped: Running count of skipped files.
        new_tags: List that names of newly created tags are appended to, in
            place. A fresh list is used when omitted.

    Returns:
        tuple[int, int, list]: ``(updated, skipped, new_tags)`` with the
        counters advanced for this file; ``new_tags`` is the list passed in.

    Raises:
        Exception: Any non-transient error is re-raised with its original
            type and traceback.
    """
    if new_tags is None:
        # A literal [] default would be shared between unrelated calls.
        new_tags = []
    all_tags = character_tags + species_tags + comm_tags

    post = None
    already_existed = False
    while True:
        try:
            if post is None:
                post, already_existed = _create_or_fetch_post(f, all_tags, update_posts)
                if post is None:
                    return updated, skipped + 1, new_tags

            _ensure_tags_exist(post, character_tags, 'character', new_tags)
            _ensure_tags_exist(post, species_tags, 'species', new_tags)
            _ensure_tags_exist(post, comm_tags, None, new_tags)

            post.tags = all_tags
            post.push()

            print(f'Successfully uploaded file {f}')
            print(f'With tags {all_tags}')
            return updated + (1 if already_existed else 0), skipped, new_tags
        except Exception as e:
            if not _is_transient_error(e):
                raise
            print(e)
            print(f'Timed out! Will wait for {sleep_time}s and try again...')
            time.sleep(sleep_time)

In [7]:
# When False, files that already exist on the server are skipped instead of re-tagged.
update_posts = False

# Running totals for the summary printed at the end of this cell.
updated = 0
skipped = 0
new_tags = []
print(f'Going to process {len(files)} images...')
for f in tqdm(files):
    with Image.open(f) as im:
        character_tags, species_tags, comm_tags = extract_tags(im)

    if not (character_tags or comm_tags or species_tags):
        print("Warning! Could not auto detect tags")
        print(f"for image {f}")
        print('Skipping...')
        skipped += 1
        continue
    # Pass skipped along so the running total is not reset on every upload.
    # upload_file appends to new_tags in place, so the list it returns is
    # deliberately not assigned back.
    updated, skipped, _returned_tags = upload_file(
        f, character_tags, species_tags, comm_tags,
        update_posts=update_posts, updated=updated, skipped=skipped, new_tags=new_tags
        )
    os.replace(f, os.path.join(donepath, os.path.basename(f)))

print(f'Processed {len(files)} images with {updated} reprocessed and {skipped} rejections.')
if new_tags:
    print(f'Created new tags {new_tags}')

Going to process 0 images...


0it [00:00, ?it/s]

Processed 0 images with 0 reprocessed and 0 rejections.


In [8]:
# Upload everything under ../to_upload (searched recursively) without tags,
# then move each uploaded file into the done directory.
# NOTE(review): files are moved by basename only, so two files with the same
# name in different subfolders end up at the same destination -- confirm this
# is acceptable.
flist = glob.glob("../to_upload/**/*.*", recursive=True)
for f in tqdm(flist):
    print(f"trying to upload {f}")
    # The glob pattern can also match directories; only regular files are uploaded.
    if not os.path.isfile(os.path.abspath(f)):
        continue
    # Size is converted from bytes to MiB (1048576 bytes); anything above 1000 is skipped.
    if os.path.getsize(os.path.abspath(f)) / 1048576 > 1000:
        print(f"{f} is too large. skipping...")
        continue
    updated, skipped, new_tags = upload_file(f, [], [], [], update_posts=False)
    os.replace(f, donepath + os.path.basename(f))

0it [00:00, ?it/s]