In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from PIL import Image, ImageDraw, ImageFont
import json
import random
import gradio as gr

import clip
import os
from tqdm import tqdm
from fine_tune import draw_text_with_new_lines, MyDataset, TestDataset, calculate_corr, load_model, evaluate, all_attributes
from IPython.display import display
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.font_manager as font_manager

# If using GPU then use mixed precision training.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Must set jit=False for training
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

from torchvision.transforms.functional import pil_to_tensor, to_pil_image

In [68]:
char_size = 150
font_dir = '../gwfonts'
cj_font_dir = '../all-fonts'
font_paths = [os.path.join(font_dir, f) for f in os.listdir(font_dir)]
cj_font_paths = [os.path.join(cj_font_dir, f) for f in os.listdir(cj_font_dir)]

# add font
for font in font_manager.findSystemFonts(font_dir):
    font_manager.fontManager.addfont(font)

for font in font_manager.findSystemFonts(cj_font_dir):
    font_manager.fontManager.addfont(font)

ttf_list = font_manager.fontManager.ttflist

predicted_attributes = json.load(open('../attributeData/predicted_cj_font_attribute.json', 'r'))
attribute_to_indexes = {attribute: all_attributes.index(attribute) for attribute in all_attributes}

In [69]:
def draw_text_with_new_lines(text, font, img_width, img_height):
    image = Image.new('RGB', (img_width, img_height), color=(255, 255, 255))
    draw = ImageDraw.Draw(image)
    lines = text.split('\n')
    y_text = 0
    for line in lines:
        line_width, line_height = font.getsize(line)
        draw.text(((img_width - line_width) / 2, y_text),
                  line, font=font, fill=(0, 0, 0))
        y_text += line_height
    return image

In [70]:
font_to_indexes = {}
for attribute in all_attributes:
    cos_sims = np.array([predicted_attributes[os.path.splitext(os.path.basename(font_path))[0]][attribute] for font_path in cj_font_paths])
    sorted_index = np.argsort(-cos_sims)
    for i, font_path in enumerate(cj_font_paths):
        font_name = os.path.splitext(os.path.basename(font_path))[0]
        if font_name not in font_to_indexes:
            font_to_indexes[font_name] = []
        font_to_indexes[font_name].append(sorted_index[i])

In [71]:
def choose_closest_font(target_attributes, target_attribute_scores):
    assert len(target_attributes) == len(target_attribute_scores)

    predicted_target_attribute_scores = []
    for attribute, attribute_score in zip(target_attributes, target_attribute_scores):
        tmp = []
        attribute_index = attribute_to_indexes[attribute]
        for i, font_path in enumerate(cj_font_paths):
            font_name = os.path.splitext(os.path.basename(font_path))[0]
            tmp.append(abs(font_to_indexes[font_name][attribute_index] - attribute_score))
        predicted_target_attribute_scores.append(np.array(tmp))
    distances = np.sum(np.array(predicted_target_attribute_scores), axis=0)
    min_index = np.argmin(distances)
    return cj_font_paths[min_index]

In [72]:
target_attributes = ['angular', 'happy',]
target_attribute_scores = [10, 10]
font_path = choose_closest_font(target_attributes, target_attribute_scores)
font_name = os.path.splitext(os.path.basename(font_path))[0]
for attribute in target_attributes:
    attribute_index = attribute_to_indexes[attribute]
    index = font_to_indexes[font_name][attribute_index]
    print(index)

9
6


In [74]:
target_attributes = ['formal', 'artistic', 'italic']
def image_builder(text, *target_attribute_scores):
    font_path = choose_closest_font(target_attributes, target_attribute_scores)
    font = ImageFont.truetype(font_path, char_size)
    line_num = text.count('\n') + 1
    width = int(char_size * len(text) / line_num)
    height = (char_size + int(char_size/3)) * line_num
    image = draw_text_with_new_lines(text, font, width, height)
    return image
sliders = []
with gr.Blocks() as demo:
    with gr.Row():
        text = gr.Textbox(value='A')
    for attribute in target_attributes:
        with gr.Row():
            sliders.append(gr.Slider(0, 301, 150, label=attribute))
    with gr.Row():
        values = [slider.value for slider in sliders]
        output_img1 = gr.Image(value=image_builder(text.value, *values), show_label=False)

    for i, attribute in enumerate(target_attributes):
        sliders[i].change(fn=image_builder, inputs=[text, *sliders], outputs=[output_img1], show_progress=False)
    text.change(fn=image_builder, inputs=[text, *sliders], outputs=[output_img1], show_progress=False)


demo.launch(debug=True, share=True)

  import sys


Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://7c0f30338a045be41a.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces


  import sys


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://7c0f30338a045be41a.gradio.live




In [44]:
char_size

50