In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from PIL import Image, ImageDraw, ImageFont
import json
import random

import clip
import os
from tqdm import tqdm
from utils.init_model import load_model, preprocess
from utils.initialize_font_data import retrieve_font_path, inclusive_attributes, all_gray_scale_image_file_dir, cj_font_dir, font_dir, train_json_path, validation_json_path, test_json_path, all_json, fox_text, fox_text_four_lines
from utils.transform_image import generate_all_fonts_embedded_images
from evals.evaluate_tools import generate_all_attribute_embedded_prompts, user_attribute_choices_count, compare_two_fonts, evaluate_attribute_comparison_task
from cj_fonts import inclusive_fonts, fifty_fonts
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomCrop, RandomRotation, RandomResizedCrop
from IPython.display import display
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.font_manager as font_manager
from matplotlib.font_manager import FontProperties

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Must set jit=False for training.
# NOTE: this rebinds `preprocess`, replacing the name imported from
# utils.init_model at the top of the cell.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
from torchvision.transforms.functional import pil_to_tensor, to_pil_image

loading JIT archive /home/yuki/.cache/clip/ViT-B-32.pt
exclusive_attributes:  []
loading JIT archive /home/yuki/.cache/clip/ViT-B-32.pt


In [2]:
# `fifty_fonts` is a newline-separated listing; every non-empty line is
# joined onto cj_font_dir to form a font path.
fifty_font_paths = [
    os.path.join(cj_font_dir, file_name)
    for file_name in fifty_fonts.split("\n")
    if file_name != ""
]

# Register every font file found under both directories with matplotlib's
# global font manager (font_dir first, then cj_font_dir).
for directory in (font_dir, cj_font_dir):
    for font in font_manager.findSystemFonts(directory):
        font_manager.fontManager.addfont(font)

# List of fonts known to the font manager after registration.
ttf_list = font_manager.fontManager.ttflist

# count font number
def _count_fonts(json_path):
    """Return the number of top-level keys of the JSON object stored at json_path.

    The file is opened in a `with` block so the handle is closed right away
    (the previous `json.load(open(...))` form left it open until garbage
    collection).
    """
    with open(json_path, "r") as json_file:
        return len(json.load(json_file).keys())


train_font_num = _count_fonts(train_json_path)
print(train_font_num)

validation_font_num = _count_fonts(validation_json_path)
print(validation_font_num)

test_font_num = _count_fonts(test_json_path)
print(test_font_num)

# Font names are the keys of all_json; each one is resolved to a font path
# under font_dir via retrieve_font_path.
# NOTE(review): `font_names` is rebound later in the notebook from
# '../similarity/fontNames.txt', while `font_paths` keeps the values computed
# here -- confirm the two listings are meant to agree.
font_names = list(all_json.keys())
font_paths = [
    retrieve_font_path(font_name, font_dir=font_dir) for font_name in font_names
]

# `fox_text` is used as imported; the former `fox_text = fox_text`
# self-assignment was a no-op and has been removed.
user_choices_count = user_attribute_choices_count


def retrieve_choices_with_font_name(user_choices, font_name):
    """Return the entries of `user_choices` that mention `font_name`.

    An entry is kept when its element at index 3 or at index 4 equals
    `font_name`. Input order is preserved and a new list is returned.

    NOTE(review): the `user_choices` rows built later in this notebook are
    [reference, font_a, font_b, votes_a, votes_b], where indices 3 and 4 are
    vote counts. This helper presumably targets a different row layout --
    confirm before using it on that list.
    """
    return [
        entry
        for entry in user_choices
        if entry[3] == font_name or entry[4] == font_name
    ]


# Names of all fonts keyed in all_json
# (alternative that was tried: extract_font_name_from_dir()).
target_font_names = [*all_json.keys()]

120
40
40


In [3]:
font_id_txt = '../similarity/fontNames.txt'
user_choices_csv = '../similarity/compsCount.csv'

# Raw comparison rows, one entry per line of the CSV (may include empty strings).
with open(user_choices_csv, 'r') as f:
    tmp_user_choices = f.read().split('\n')

# Font names, one per line; empty lines are dropped.
with open(font_id_txt, 'r') as f:
    tmp_font_names = f.read().split('\n')
# NOTE: this rebinds `font_names` (previously the keys of all_json); the
# integer ids in the CSV are later used as indices into this list.
font_names = [line for line in tmp_font_names if line != '']

In [4]:
# Turn each non-empty CSV row into
# [reference_font_name, font_a_name, font_b_name, vote_count_a, vote_count_b].
user_choices = []
for tmp_user_choice in tmp_user_choices:
    if tmp_user_choice == '':
        continue
    # Fields go through float() first so values such as "3.0" parse as 3;
    # empty fields (e.g. from a trailing comma) are ignored.
    fields = [int(float(e)) for e in tmp_user_choice.split(',') if e != '']
    reference_font_id, font_a_id, font_b_id, vote_count_a, vote_count_b = fields
    # The first three fields are indices into font_names.
    user_choices.append([
        font_names[reference_font_id],
        font_names[font_a_id],
        font_names[font_b_id],
        vote_count_a,
        vote_count_b,
    ])

In [5]:
def evaluate_for_target_font(font_name, embedded_images, user_choices=user_choices, device=device):
    """Check, per comparison, whether embedding similarity agrees with the vote majority.

    Args:
        font_name: only comparisons in which this font appears (as reference,
            A or B) are evaluated; pass None to evaluate every comparison.
        embedded_images: mapping from font name to an embedding tensor.
        user_choices: rows of
            [reference_font_name, font_a_name, font_b_name, vote_count_a, vote_count_b],
            e.g. ['Muli', 'ARSMaquetteWebOne', 'Kenia-Regular', 10, 4].
        device: device the embeddings are moved to before comparison.

    Returns:
        A list of tuples
        (reference_font_name, font_a_name, font_b_name, ground_truth, prediction)
        where ground_truth is True when A received strictly more votes than B
        and prediction is True when the model's choice matches ground_truth.
    """
    outcomes = []
    for reference_name, name_a, name_b, votes_a, votes_b in user_choices:
        # Skip comparisons that do not involve the requested font.
        if font_name is not None and font_name not in (reference_name, name_a, name_b):
            continue

        majority_prefers_a = votes_a > votes_b

        reference_vec = embedded_images[reference_name].to(device)
        vec_a = embedded_images[name_a].to(device)
        vec_b = embedded_images[name_b].to(device)

        # The model "prefers" whichever candidate is closer to the reference
        # in cosine similarity; .item() requires a single-element result.
        similarity_a = torch.cosine_similarity(reference_vec, vec_a, dim=-1).item()
        similarity_b = torch.cosine_similarity(reference_vec, vec_b, dim=-1).item()
        model_prefers_a = similarity_a > similarity_b

        agrees = majority_prefers_a == model_prefers_a
        outcomes.append((reference_name, name_a, name_b, majority_prefers_a, agrees))
    return outcomes

def evaluate_for_target_font_for_each_comparison(font_name, embedded_images, user_choices=user_choices, device=device):
    """Vote-weighted variant of the per-comparison evaluation.

    Args:
        font_name: only comparisons in which this font appears (as reference,
            A or B) are evaluated; pass None to evaluate every comparison.
        embedded_images: mapping from font name to an embedding tensor.
        user_choices: rows of
            [reference_font_name, font_a_name, font_b_name, vote_count_a, vote_count_b],
            e.g. ['Muli', 'ARSMaquetteWebOne', 'Kenia-Regular', 10, 4].
        device: device the embeddings are moved to before comparison.

    Returns:
        A list of tuples
        (reference_font_name, font_a_name, font_b_name, ground_truth,
         prediction, correct_num, total_num)
        where ground_truth is True when A received strictly more votes than B,
        prediction is True when the model's choice matches ground_truth,
        correct_num is the number of votes cast for the option the model
        picked, and total_num is vote_count_a + vote_count_b.
    """
    outcomes = []
    for reference_name, name_a, name_b, votes_a, votes_b in user_choices:
        # Skip comparisons that do not involve the requested font.
        if font_name is not None and font_name not in (reference_name, name_a, name_b):
            continue

        majority_prefers_a = votes_a > votes_b

        reference_vec = embedded_images[reference_name].to(device)
        vec_a = embedded_images[name_a].to(device)
        vec_b = embedded_images[name_b].to(device)

        # The model "prefers" whichever candidate is closer to the reference
        # in cosine similarity; .item() requires a single-element result.
        similarity_a = torch.cosine_similarity(reference_vec, vec_a, dim=-1).item()
        similarity_b = torch.cosine_similarity(reference_vec, vec_b, dim=-1).item()
        model_prefers_a = similarity_a > similarity_b

        agrees = majority_prefers_a == model_prefers_a
        # Votes that went to the model's pick: the majority count when the
        # model agrees with the majority, the minority count otherwise.
        votes_for_model_pick = votes_a if model_prefers_a else votes_b
        outcomes.append((
            reference_name,
            name_a,
            name_b,
            majority_prefers_a,
            agrees,
            votes_for_model_pick,
            votes_a + votes_b,
        ))
    return outcomes

In [8]:
# Vote-weighted agreement over all cross-validation folds: for each fold,
# load that fold's checkpoint, embed every font, and score the comparisons
# that involve the fold's held-out test fonts.
cross_validation_k = 100
correct_num = 0
total_num = 0
# for i in tqdm(range(cross_validation_k)):
for i in range(cross_validation_k):
    tmp_test_json_path = f"../attributeData/test_font_to_attribute_values_cross_validation_{cross_validation_k}_{i}.json"
    # Use a context manager so the file handle is closed on every iteration.
    with open(tmp_test_json_path, "r") as tmp_test_json_file:
        tmp_test_json = json.load(tmp_test_json_file)
    tmp_test_font_names = list(tmp_test_json.keys())

    # Earlier run, kept for reference (it used to be assigned and then
    # immediately overwritten by the line below):
    # signature = f'cross_validation_{cross_validation_k}_{i}_ViT-B_32_9101191011_batch64_aug140_use_negative_use_negative_loss1e-06_lr2e-05-0.1_image_file_dir'
    signature = f'cross_validation_{cross_validation_k}_{i}_ViT-B_32_cnn_based_vae_loss_weight_1.0vae_loss_kl_weight0.001_res_64_9101191011_batch64_aug150_lower_bound_of_scale0.35_use_negative_lr1e-05-0.1_image_file_dir'
    checkpoint_path = f'model_checkpoints/{signature}.pt'
    # NOTE: `model` is rebound on every fold, so each load starts from the
    # model object returned by the previous load.
    model = load_model(model, checkpoint_path)
    embedded_images = generate_all_fonts_embedded_images(font_paths, fox_text, model=model, preprocess=preprocess, image_file_dir=all_gray_scale_image_file_dir)

    for target_font_name in tmp_test_font_names:
        tmp_result = evaluate_for_target_font_for_each_comparison(target_font_name, embedded_images)
        tmp_correct_num = sum(e[-2] for e in tmp_result)
        tmp_total_num = sum(e[-1] for e in tmp_result)
        if tmp_total_num == 0:
            # A test font that appears in no comparison (or only in
            # comparisons without votes) would otherwise divide by zero.
            print(target_font_name, "skipped (no votes)")
            continue
        tmp_classification_rate = tmp_correct_num / tmp_total_num
        print(target_font_name, tmp_classification_rate)
        correct_num += tmp_correct_num
        total_num += tmp_total_num

# Guard the final ratio for the case where nothing was evaluated.
average = correct_num / total_num if total_num else float("nan")
print(average)

FanwoodText-Italic 0.702276707530648
ShareTech-Regular 0.7659574468085106
IstokWeb-Bold 0.7714285714285715
Lekton-Italic 0.7109704641350211
CabinCondensed 0.6708595387840671
PressStart2P-Regular 0.8218390804597702
ModernAntiqua-Regular 0.72
Arvo-Bold 0.7891566265060241
Satisfy 0.839851024208566
Muli 0.745958429561201
Palatino-Roman 0.7303370786516854
SourceCodePro-ExtraLight 0.780439121756487
Montez-Regular 0.768
Raleway-SemiBold 0.769825918762089
UbuntuCondensed-Regular 0.7127468581687613
Slackey 0.7433628318584071
BadScript-Regular 0.7518248175182481
AveriaLibre-LightItalic 0.714622641509434
Rambla-Regular 0.6963979416809606
Roboto-MediumItalic 0.725
Julee-Regular 0.7013274336283186
PTSerif-BoldItalic 0.7438692098092643
MervaleScript-Regular 0.8865979381443299
CrimsonText-Semibold 0.756043956043956
IM_FELL_English_Roman 0.7893835616438356
Amethysta-Regular 0.7755443886097152
TitilliumWeb-ThinItalic 0.6942675159235668
CantoraOne-Regular 0.7186234817813765
Raleway-Bold 0.73458445040214

In [7]:
# Checkpoints that have been evaluated with this cell. They used to be a chain
# of assignments in which only the last one took effect; they are kept here
# as a reference list instead.
#   'model_checkpoints/new_best_fox_negative_91011_1011_use_weight_image_file_dir_all_fonts_ex.pt'
#   'model_checkpoints/best_fox_negative_91011_1011_use_weight_ex_2e5.pt'
#   'model_checkpoints/new_best_fox_negative_91011_91011_multiple_3_1000_aug_ex.pt'
#   'model_checkpoints/91011_1011_use_weight_aug_ex.pt'
#   'model_checkpoints/67891011_891011_multiple_3_1000_image_file_dir_aug_ex.pt'
#   'model_checkpoints/91011_91011_multiple_3_1000_image_file_dir_aug_ex.pt'
#   'model_checkpoints/91011_1011_multiple_3_1000_aug_ex.pt'
#   'model_checkpoints/model_5590.pt'
#   'model_checkpoints/new_best_fox_negative_91011_1011_use_weight_image_file_dir_ex.pt'
#   'model_checkpoints/best_fox_negative_91011_1011_use_weight_aug_ex.pt'
#   'model_checkpoints/model_767.pt'
# NOTE(review): with checkpoint_path=None, load_model presumably leaves the
# weights of `model` untouched -- confirm in utils.init_model. `model` is also
# rebound by other cells, so the result depends on which cells ran before.
checkpoint_path = None
# Alternative that was tried: image_file_dir = None
image_file_dir = '../attributeData/images'
model = load_model(model, checkpoint_path)
embedded_images = generate_all_fonts_embedded_images(font_paths, fox_text, model=model, preprocess=preprocess, image_file_dir=image_file_dir)

KeyboardInterrupt: 

In [None]:
"""
result = evaluate_for_target_font(None, embedded_images, user_choices=user_choices)
classification_rate = sum([1 for r in result if r[-1]])/len(result)
print(classification_rate)
"""

# Vote-weighted agreement over every comparison (font_name=None disables
# the per-font filter): votes for the model's pick / all votes.
result = evaluate_for_target_font_for_each_comparison(None, embedded_images, user_choices=user_choices)
agreeing_votes = sum(row[-2] for row in result)
all_votes = sum(row[-1] for row in result)
classification_rate = agreeing_votes / all_votes
print(classification_rate)


0.6768399899522733


In [None]:
# Per-font agreement rate: the fraction of comparisons involving each font
# in which the embedding similarity agrees with the vote majority.
# Uses the `embedded_images` currently in the kernel; the model is not
# reloaded per font here.
result = []
for target_font_name in font_names:
    tmp_result = evaluate_for_target_font(target_font_name, embedded_images)
    if not tmp_result:
        # A font that appears in no comparison would otherwise divide by
        # zero; leave it out of `result` so it does not affect the mean.
        print(target_font_name, "skipped (no comparisons)")
        continue
    # The last element of each tuple is True when the model agreed.
    classification_rate = sum(e[-1] for e in tmp_result) / len(tmp_result)
    result.append((target_font_name, classification_rate))
    print(target_font_name, classification_rate)

ARSMaquetteWebOne 0.7777777777777778
Acme-Regular 0.7857142857142857
AdventPro-SemiBold 0.9428571428571428
Aldrich 0.6774193548387096
Alegreya-BoldItalic 0.8214285714285714
AllertaStencil-Regular 0.8055555555555556
Amethysta-Regular 0.75
Andada-Bold 0.7878787878787878
Andada-Italic 0.825
AndadaSC-Bold 0.717948717948718
AnonymousPro 0.6111111111111112
ArchivoNarrow-Regular 0.75
ArialRoundedMTBold 0.75
ArialUnicodeMS 0.8125
Arimo-Bold 0.7560975609756098
Arimo-BoldItalic 0.9459459459459459
Arizonia-Regular 0.9090909090909091
Arvo-Bold 0.9583333333333334
Arvo-BoldItalic 0.6666666666666666
Arvo-Italic 0.8205128205128205
Asap-Regular 0.8666666666666667
Asset 0.8
Astloch-Bold 0.7096774193548387
AveriaLibre-LightItalic 0.8064516129032258
AveriaSansLibre-Italic 0.723404255319149
AveriaSerifLibre-Italic 0.782608695652174
BadScript-Regular 0.8536585365853658
Bello-Pro 0.7777777777777778
BenchNine-Regular 0.8780487804878049
Bentham-Regular 0.75
Bevan 0.7948717948717948
BilboSwashCaps-Regular 0.972

In [None]:
# Unweighted mean of element 1 (the classification rate in the per-font
# (font_name, classification_rate) tuples) over all entries of `result`.
per_font_rates = [entry[1] for entry in result]
average = np.mean(per_font_rates)
print(average)

0.8053088222407141


In [None]:
# Unweighted mean of element 1 (the classification rate in the per-font
# (font_name, classification_rate) tuples) over all entries of `result`.
# The value depends on whichever `result` is currently in the kernel.
per_font_rates = [entry[1] for entry in result]
average = np.mean(per_font_rates)
print(average)

0.7904783004709554
