# init cell

In [1]:
import torch
import torchvision.transforms as transforms
import json
# load model
model1 = torch.hub.load('/home/qzlzdy/Python/RF5_danbooru-pretrained_master', 'resnet50', source='local')
model1.eval()
model2 = torch.hub.load('/home/qzlzdy/Python/RF5_danbooru-pretrained_master', 'resnet50', source='local')
model2.eval()
# preprocess function
preprocess = transforms.Compose([
    transforms.Resize(360),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979])
])
# load categories
with open('class_names_6000.json', 'r') as f:
    class_names = json.load(f)

# function define

In [2]:
from pathlib import Path
from PIL import Image
import os
import threading

def clear_temporary_directory():
    temp = Path('../datasets/danbooru-images/temp/')
    all_temp_files = temp.glob('*')
    print('clearing temporary dictory...')
    for path in all_temp_files:
        os.remove(str(path))

class ConvertRGB(threading.Thread):
    def __init__(self, index):
        threading.Thread.__init__(self)
        self.load_path = '../datasets/danbooru-images/danbooru-images/0{:03d}/'.format(index)
        self.name = "Convertor {}".format(index)

    def run(self):
        data_root = Path(self.load_path)
        all_images = data_root.glob('*')
        all_images = [str(path) for path in all_images]
        save_path = '../datasets/danbooru-images/temp/{}'
        
        for path in all_images:
            Image.open(path).convert('RGB').save(save_path.format(path[49:]))     

def fill_temporary_directory(index):
    beg = 4 * index
    end = beg + 4
    threads = [ConvertRGB(i) for i in range(beg, end)]
    print('filling temporary directory...')
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

mutex = threading.Lock() # tag_file
class Tagger(threading.Thread):
    def __init__(self, name, model, all_images):
        threading.Thread.__init__(self)
        self.name = name
        self.all_images = all_images
        self.model = model
        self.total = len(all_images)
    
    def run(self):
        for i, path in enumerate(self.all_images):
            image = Image.open(path)
            image = preprocess(image)
            image = image.unsqueeze(0)
            with torch.no_grad():
                probs = self.model(image)
            probs = torch.sigmoid(probs[0])
            tmp = probs[probs > 0.25]
            inds = probs.argsort(descending=True)
            tags = [class_names[i] for i in inds[0:len(tmp)]]
            tags = {'id': path[33:], 'tags': tags}
            tags = json.dumps(tags)
            mutex.acquire()
            tag_file.write(tags + '\n')
            mutex.release()
            if i % 100 == 0:
                print('{}: tagging............({}/{})'.format(self.name, i, self.total))

def tag_temporary_directory():
    data_root = Path('../datasets/danbooru-images/temp/')
    all_images = data_root.glob('*')
    all_images = [str(path) for path in all_images]
    half = len(all_images) // 2
    threads = [
        Tagger('Tagger1', model1, all_images[:half]),
        Tagger('Tagger2', model2, all_images[half:])
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

def single_thread_tagger():
    data_root = Path('../datasets/danbooru-images/temp/')
    all_images = data_root.glob('*')
    all_images = [str(path) for path in all_images]
    thread = Tagger('Tagger1', model1, all_images)
    thread.start()
    thread.join()

# main loop

In [3]:
for i in range(37, 38):
    print('starting round', i)
    clear_temporary_directory()
    fill_temporary_directory(i)
    b = 4 * i
    e = b + 3
    tag_file = open('./danbooru-tags/{}-{}.json'.format(b, e), 'w')
    tag_temporary_directory()
    tag_file.close()

starting round 37
clearing temporary dictory...
filling temporary directory...
Tagger1: tagging............(0/3319)
Tagger2: tagging............(0/3320)
Tagger2: tagging............(100/3320)
Tagger1: tagging............(100/3319)
Tagger2: tagging............(200/3320)
Tagger1: tagging............(200/3319)
Tagger2: tagging............(300/3320)
Tagger1: tagging............(300/3319)
Tagger2: tagging............(400/3320)
Tagger1: tagging............(400/3319)
Tagger2: tagging............(500/3320)
Tagger1: tagging............(500/3319)
Tagger2: tagging............(600/3320)
Tagger1: tagging............(600/3319)
Tagger2: tagging............(700/3320)
Tagger1: tagging............(700/3319)
Tagger2: tagging............(800/3320)
Tagger1: tagging............(800/3319)
Tagger2: tagging............(900/3320)
Tagger1: tagging............(900/3319)
Tagger2: tagging............(1000/3320)
Tagger1: tagging............(1000/3319)
Tagger2: tagging............(1100/3320)
Tagger1: tagging.........

In [4]:
# debug
print(all_images[0][33:])
print(len('../datasets/danbooru-images/temp/'))

273027.jpg
33


# alternative loop

In [5]:
i = 6
clear_temporary_direcory()
fill_temporary_directory(i)
b = 4 * i
e = b + 3
tag_file = open('./danbooru-tags/{}-{}.json'.format(b, e), 'w')
single_thread_tagger(ind)
tag_file.close()

Tagger2: tagging............(0/4443)
Tagger1: tagging............(0/4442)
Tagger1: tagging............(100/4442)
Tagger2: tagging............(100/4443)
Tagger1: tagging............(200/4442)
Tagger2: tagging............(200/4443)
Tagger1: tagging............(300/4442)
Tagger2: tagging............(300/4443)
Tagger1: tagging............(400/4442)
Tagger2: tagging............(400/4443)
Tagger1: tagging............(500/4442)
Tagger2: tagging............(500/4443)
Tagger1: tagging............(600/4442)
Tagger2: tagging............(600/4443)
Tagger1: tagging............(700/4442)
Tagger2: tagging............(700/4443)
Tagger1: tagging............(800/4442)
Tagger2: tagging............(800/4443)
Tagger1: tagging............(900/4442)
Tagger2: tagging............(900/4443)
Tagger1: tagging............(1000/4442)
Tagger2: tagging............(1000/4443)
Tagger1: tagging............(1100/4442)
Tagger2: tagging............(1100/4443)
Tagger1: tagging............(1200/4442)
Tagger2: tagging........