In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from PIL import Image, ImageDraw, ImageFont
import json
import random

import clip
import os
from tqdm import tqdm
from utils.init_model import load_model, preprocess
from utils.initialize_font_data import retrieve_font_path, inclusive_attributes, gray_scale_image_file_dir, font_dir, cj_font_dir, train_json_path, validation_json_path, test_json_path, all_json, fox_text, fox_text_four_lines
from utils.transform_image import generate_all_fonts_embedded_images
from utils.evaluate_tools import generate_all_attribute_embedded_prompts, user_attribute_choices_count, compare_two_fonts, evaluate_attribute_comparison_task



from cj_fonts import inclusive_fonts, fifty_fonts
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomCrop, RandomRotation, RandomResizedCrop
from torchvision.transforms.functional import pil_to_tensor, to_pil_image

from IPython.display import display
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.font_manager as font_manager
from matplotlib.font_manager import FontProperties

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# CLIP backbone under evaluation ("ViT-L/14" is the larger alternative that was tried).
model_name = "ViT-B/32"

# jit=False: the original note says this is required for training; later cells
# also pass this model to load_model together with fine-tuned checkpoints.
# NOTE: this rebinds `preprocess`, shadowing the one imported from utils.init_model.
model, preprocess = clip.load(model_name, device=device, jit=False)

loading JIT archive /home/yuki/.cache/clip/ViT-B-32.pt
exclusive_attributes:  []
loading JIT archive /home/yuki/.cache/clip/ViT-B-32.pt


In [2]:
# Paths of the fifty CJ fonts listed (one file name per line) in cj_fonts.fifty_fonts.
fifty_font_paths = [os.path.join(cj_font_dir, f) for f in fifty_fonts.split('\n') if f != '']

# Make every font file in both font directories available to matplotlib's FontManager.
for registered_font_dir in (font_dir, cj_font_dir):
    for font in font_manager.findSystemFonts(registered_font_dir):
        font_manager.fontManager.addfont(font)

ttf_list = font_manager.fontManager.ttflist


def _count_json_keys(json_path):
    """Return the number of top-level keys in the JSON file at ``json_path``.

    The file is opened in a ``with`` block so the handle is closed (the original
    ``json.load(open(...))`` left it open).
    """
    with open(json_path, 'r') as json_file:
        return len(json.load(json_file))


# Number of fonts (top-level JSON keys) in each split, printed as train / validation / test.
train_font_num = _count_json_keys(train_json_path)
print(train_font_num)

validation_font_num = _count_json_keys(validation_json_path)
print(validation_font_num)

test_font_num = _count_json_keys(test_json_path)
print(test_font_num)

font_names = list(all_json.keys())
font_paths = [retrieve_font_path(font_name, font_dir=font_dir) for font_name in font_names]

# `fox_text` is used exactly as imported from utils.initialize_font_data.
user_choices_count = user_attribute_choices_count

def retrieve_choices_with_font_name(user_choices, font_name):
    """Return the user choices whose field at index 3 or index 4 equals ``font_name``.

    Order of ``user_choices`` is preserved.

    NOTE(review): entries of ``user_choices_count`` in this notebook are
    ``((attribute, font_a, font_b), votes)`` pairs, so this helper presumably
    expects a different, flatter record format — confirm before calling it on
    ``user_choices_count``.
    """
    return [
        choice
        for choice in user_choices
        if choice[3] == font_name or choice[4] == font_name
    ]

# Fonts to evaluate: every key of all_json, in its stored order.
target_font_names = list(all_json)

120
40
40


In [4]:
# Debug check of the container type (saved output: `list`); not used by the analysis below.
type(user_choices_count)

list

In [9]:
# Distinct attributes rated by users, in order of first appearance.
# Each entry of user_choices_count is ((attribute, font_a, font_b), votes),
# so entry[0][0] is the attribute name.
appearance_attributes = list(dict.fromkeys(
    entry[0][0] for entry in user_choices_count
))
print(len(appearance_attributes))
print(appearance_attributes)

31
['angular', 'artistic', 'attention-grabbing', 'attractive', 'bad', 'boring', 'calm', 'charming', 'clumsy', 'complex', 'delicate', 'disorderly', 'dramatic', 'formal', 'fresh', 'friendly', 'gentle', 'graceful', 'happy', 'legible', 'modern', 'playful', 'pretentious', 'sharp', 'sloppy', 'soft', 'strong', 'technical', 'thin', 'warm', 'wide']


In [3]:
def _involves_font(font_name, font_a_name, font_b_name):
    """Return True when ``font_name`` is None (no filter) or is one of the two compared fonts."""
    return font_name is None or font_a_name == font_name or font_b_name == font_name


def _vote_counts(votes):
    """Return ``(more, less)`` vote counts from a votes mapping such as ``{'more': 4, 'less': 3}``.

    A missing key counts as zero votes. NOTE(review): the documented example
    always has both keys; the default only matters if a unanimous comparison
    is stored without the losing key.
    """
    return votes.get('more', 0), votes.get('less', 0)


def _majority_label(more, less):
    """Return the majority answer; ties resolve to 'less' (strict '>' for 'more')."""
    return 'more' if more > less else 'less'


def evaluate_for_target_font(font_name, embedded_prompts, embedded_images,
                             user_choices_count=None, min_margin=0):
    """Score the model on every pairwise attribute comparison involving ``font_name``.

    Args:
        font_name: Font to restrict to, or None to use every comparison.
        embedded_prompts, embedded_images: Forwarded unchanged to ``compare_two_fonts``.
        user_choices_count: Iterable of ``((attribute, font_a, font_b), votes)``
            entries, e.g. ``(('angular', 'ARSMaquetteWebOne', 'Kenia-Regular'),
            {'more': 4, 'less': 3})``. Defaults to ``user_attribute_choices_count``,
            resolved when the function is called rather than when this cell runs.
        min_margin: Skip comparisons whose ``|more - less|`` vote difference is
            below this value. 0 (default) keeps every comparison; 2 reproduces
            the "drop near-ties" variant that used to sit in a dead string literal.

    Returns:
        List of ``(attribute, font_a, font_b, ground_truth, prediction)`` where
        ``ground_truth`` is the majority answer and ``prediction`` is what
        ``compare_two_fonts`` returned (callers count truthy values as correct).
    """
    if user_choices_count is None:
        user_choices_count = user_attribute_choices_count

    result = []
    for (attribute, font_a_name, font_b_name), votes in user_choices_count:
        if not _involves_font(font_name, font_a_name, font_b_name):
            continue
        more, less = _vote_counts(votes)
        if abs(more - less) < min_margin:
            continue
        ground_truth = _majority_label(more, less)
        prediction = compare_two_fonts(attribute, font_a_name, font_b_name, ground_truth,
                                       embedded_prompts, embedded_images)
        result.append((attribute, font_a_name, font_b_name, ground_truth, prediction))
    return result


def evaluate_for_target_font_for_each_comparison(font_name, embedded_prompts, embedded_images,
                                                 user_choices_count=None):
    """Score the model per annotator vote on comparisons involving ``font_name``.

    Args:
        font_name: Font to restrict to, or None to use every comparison.
        embedded_prompts, embedded_images: Forwarded unchanged to ``compare_two_fonts``.
        user_choices_count: Iterable of ``((attribute, font_a, font_b), votes)``
            entries. Defaults to ``user_attribute_choices_count``, resolved when
            the function is called.

    Returns:
        List of ``(attribute, font_a, font_b, ground_truth, correct_num, total_num)``.
        ``total_num`` is the number of votes on the comparison. ``correct_num`` is
        how many annotators agree with the model: the majority count when
        ``compare_two_fonts`` returns a truthy value, otherwise the remaining votes.
        Summing both columns over rows gives an annotator-level agreement rate.
    """
    if user_choices_count is None:
        user_choices_count = user_attribute_choices_count

    result = []
    for (attribute, font_a_name, font_b_name), votes in user_choices_count:
        if not _involves_font(font_name, font_a_name, font_b_name):
            continue
        more, less = _vote_counts(votes)
        total_num = more + less
        ground_truth = _majority_label(more, less)
        majority_votes = more if ground_truth == 'more' else less

        prediction = compare_two_fonts(attribute, font_a_name, font_b_name, ground_truth,
                                       embedded_prompts, embedded_images)
        # Agreeing with the majority earns the majority's votes; disagreeing earns the rest.
        correct_num = majority_votes if prediction else total_num - majority_votes

        result.append((attribute, font_a_name, font_b_name, ground_truth, correct_num, total_num))
    return result

In [4]:
# Pooled annotator-agreement rate over 100-fold cross-validation: each fold's
# checkpoint is scored on that fold's held-out test fonts.
# NOTE: the saved output shows the checkpoint for fold 0 is missing locally.
cross_validation_k = 100
correct_num = 0
total_num = 0
for i in range(cross_validation_k):
    tmp_test_json_path = f"../attributeData/test_font_to_attribute_values_cross_validation_{cross_validation_k}_{i}.json"
    with open(tmp_test_json_path, "r") as tmp_test_json_file:
        tmp_test_json = json.load(tmp_test_json_file)
    tmp_test_font_names = list(tmp_test_json.keys())

    # Only the last of several experiment signatures was ever used (the others
    # were immediately overwritten), so just that one is kept.
    signature = f'cross_validation_{cross_validation_k}_{i}_ViT-B_32_cnn_based_vae_loss_weight_1.0vae_loss_kl_weight0.001_res_64_9101191011_batch64_aug150_lower_bound_of_scale0.35_use_negative_lr1e-05-0.1_image_file_dir'
    checkpoint_path = f'model_checkpoints/{signature}.pt'
    model = load_model(model, checkpoint_path)
    embedded_prompts = generate_all_attribute_embedded_prompts(inclusive_attributes, model=model)
    embedded_images = generate_all_fonts_embedded_images(font_paths, fox_text, model=model, preprocess=preprocess, image_file_dir=gray_scale_image_file_dir)

    for target_font_name in tmp_test_font_names:
        tmp_result = evaluate_for_target_font_for_each_comparison(target_font_name, embedded_prompts, embedded_images)
        tmp_correct_num = sum(r[-2] for r in tmp_result)
        tmp_total_num = sum(r[-1] for r in tmp_result)
        if tmp_total_num == 0:
            # Avoid ZeroDivisionError for a font with no annotated comparisons.
            print(target_font_name, 'no annotated comparisons')
            continue
        tmp_classification_rate = tmp_correct_num / tmp_total_num
        print(target_font_name, tmp_classification_rate)
        correct_num += tmp_correct_num
        total_num += tmp_total_num

average = correct_num / total_num
print(average)

FileNotFoundError: [Errno 2] No such file or directory: 'model_checkpoints/cross_validation_100_0_ViT-B_32_cnn_based_vae_loss_weight_1.0vae_loss_kl_weight0.001_res_64_9101191011_batch64_aug150_lower_bound_of_scale0.35_use_negative_lr1e-05-0.1_image_file_dir.pt'

In [5]:
# NOTE(review): hidden state. This reads `correct_num` / `total_num` left in the kernel
# by an earlier run. The cell above resets both to 0 and then fails while loading
# fold 0's checkpoint, so on a fresh kernel this raises ZeroDivisionError. The saved
# value did not come from that run.
correct_num / total_num

0.6341244239631336

In [6]:
# Pooled annotator-agreement rate for the 40-fold cross-validation checkpoints.
cross_validation_k = 40
# The original loop ended with a debugging `break`, so only fold 0 was ever
# evaluated and the printed "average" is fold 0 alone. That is now explicit;
# set this to cross_validation_k to run the full cross-validation.
n_folds_to_run = 1
correct_num = 0
total_num = 0
for i in tqdm(range(min(n_folds_to_run, cross_validation_k))):
    tmp_test_json_path = f"../attributeData/test_font_to_attribute_values_cross_validation_{cross_validation_k}_{i}.json"
    with open(tmp_test_json_path, "r") as tmp_test_json_file:
        tmp_test_json = json.load(tmp_test_json_file)
    tmp_test_font_names = list(tmp_test_json.keys())

    signature = f'cross_validation_{cross_validation_k}_{i}_ViT-B_32_9101191011_batch64_aug140_use_negative_use_negative_loss1e-06_lr2e-05-0.1_image_file_dir'
    checkpoint_path = f'model_checkpoints/{signature}.pt'
    model = load_model(model, checkpoint_path)
    embedded_prompts = generate_all_attribute_embedded_prompts(inclusive_attributes, model=model)
    embedded_images = generate_all_fonts_embedded_images(font_paths, fox_text, model=model, preprocess=preprocess, image_file_dir=gray_scale_image_file_dir)

    for target_font_name in tmp_test_font_names:
        tmp_result = evaluate_for_target_font_for_each_comparison(target_font_name, embedded_prompts, embedded_images)
        tmp_correct_num = sum(r[-2] for r in tmp_result)
        tmp_total_num = sum(r[-1] for r in tmp_result)
        if tmp_total_num == 0:
            # Avoid ZeroDivisionError for a font with no annotated comparisons.
            print(target_font_name, 'no annotated comparisons')
            continue
        tmp_classification_rate = tmp_correct_num / tmp_total_num
        print(target_font_name, tmp_classification_rate)
        correct_num += tmp_correct_num
        total_num += tmp_total_num

average = correct_num / total_num
print(average)

  0%|          | 0/40 [00:00<?, ?it/s]

FanwoodText-Italic 0.6440092165898618
ShareTech-Regular 0.5961981566820277
IstokWeb-Bold 0.6278801843317973
Lekton-Italic 0.6779953917050692
CabinCondensed 0.6347926267281107
0.6361751152073732





In [4]:
# Evaluate without a fine-tuned checkpoint. The many previously tried checkpoint
# paths were all immediately overwritten by this final value, so they are dropped.
# NOTE(review): presumably load_model keeps the pretrained CLIP weights when given
# None — confirm in utils.init_model.
checkpoint_path = None
model = load_model(model, checkpoint_path)
# Image directory passed to generate_all_fonts_embedded_images (the CV cells use
# gray_scale_image_file_dir instead).
image_file_dir = '../attributeData/images'
embedded_prompts = generate_all_attribute_embedded_prompts(inclusive_attributes, model=model)
embedded_images = generate_all_fonts_embedded_images(font_paths, fox_text, model=model, preprocess=preprocess, image_file_dir=image_file_dir)

In [5]:
# Pooled annotator-agreement rate over all comparisons (no target-font filter):
# total agreeing votes divided by total votes.
result = evaluate_for_target_font_for_each_comparison(None, embedded_prompts, embedded_images)
agreeing_votes = sum(r[-2] for r in result)
all_votes = sum(r[-1] for r in result)
classification_rate = agreeing_votes / all_votes
print(classification_rate)

0.5186578341013824


In [10]:
# Per-font accuracy on the majority-vote comparison task, then the unweighted
# mean over fonts. `result` holds (font_name, rate) pairs.
result = []
for target_font_name in font_names:
    tmp_result = evaluate_for_target_font(target_font_name, embedded_prompts, embedded_images)
    if not tmp_result:
        # Avoid ZeroDivisionError for a font that appears in no comparison.
        print(target_font_name, 'no annotated comparisons')
        continue
    classification_rate = sum(1 for r in tmp_result if r[-1]) / len(tmp_result)
    result.append((target_font_name, classification_rate))
    print(target_font_name, classification_rate)

average = sum(r[1] for r in result) / len(result)
print(average)

FanwoodText-Italic 0.4879032258064516
ShareTech-Regular 0.41935483870967744
IstokWeb-Bold 0.5080645161290323
Lekton-Italic 0.5120967741935484
CabinCondensed 0.4717741935483871
PressStart2P-Regular 0.5645161290322581
ModernAntiqua-Regular 0.40725806451612906
Arvo-Bold 0.5
Satisfy 0.5887096774193549
Muli 0.5282258064516129
Palatino-Roman 0.5725806451612904
SourceCodePro-ExtraLight 0.3790322580645161
Montez-Regular 0.5645161290322581
Raleway-SemiBold 0.4596774193548387
UbuntuCondensed-Regular 0.42338709677419356
Slackey 0.5887096774193549
BadScript-Regular 0.6814516129032258
AveriaLibre-LightItalic 0.532258064516129
Rambla-Regular 0.5
Roboto-MediumItalic 0.5564516129032258
Julee-Regular 0.6129032258064516
PTSerif-BoldItalic 0.33064516129032256
MervaleScript-Regular 0.43548387096774194
CrimsonText-Semibold 0.5403225806451613
IM_FELL_English_Roman 0.5282258064516129
Amethysta-Regular 0.5443548387096774
TitilliumWeb-ThinItalic 0.4596774193548387
CantoraOne-Regular 0.4959677419354839
Raleway-

In [8]:
# Leave-one-out evaluation: for each font, load its "one_leave_out" checkpoint
# (presumably trained with that font held out) and score only the comparisons
# involving that font. `result` holds (font_name, rate) pairs.
result = []
for target_font_name in target_font_names:
    # `retrieve_one_leave_out_model_path` is neither defined nor imported in this
    # notebook, so the cell failed on a fresh kernel. The path pattern below is
    # reconstructed from this cell's saved output
    # (e.g. model_checkpoints/one_leave_out_CrimsonText-Semibold.pt).
    checkpoint_path = f'model_checkpoints/one_leave_out_{target_font_name}.pt'
    print(checkpoint_path)
    model = load_model(model, checkpoint_path)
    embedded_prompts = generate_all_attribute_embedded_prompts(inclusive_attributes, model=model)
    # NOTE(review): unlike the other cells, no image_file_dir is passed here — confirm intended.
    embedded_images = generate_all_fonts_embedded_images(font_paths, fox_text, model=model, preprocess=preprocess)
    tmp_result = evaluate_for_target_font(target_font_name, embedded_prompts, embedded_images)
    if not tmp_result:
        # Avoid ZeroDivisionError for a font that appears in no comparison.
        print(target_font_name, 'no annotated comparisons')
        continue
    classification_rate = sum(1 for r in tmp_result if r[-1]) / len(tmp_result)
    result.append((target_font_name, classification_rate))
    print(target_font_name, classification_rate)


model_checkpoints/one_leave_out_CrimsonText-Semibold.pt
CrimsonText-Semibold 0.7338709677419355
model_checkpoints/one_leave_out_Amethysta-Regular.pt
Amethysta-Regular 0.6895161290322581
model_checkpoints/one_leave_out_Tinos-Bold.pt
Tinos-Bold 0.7540322580645161
model_checkpoints/one_leave_out_Rosario-Bold.pt
Rosario-Bold 0.7338709677419355
model_checkpoints/one_leave_out_MavenProBlack.pt
MavenProBlack 0.75
model_checkpoints/one_leave_out_Parisienne-Regular.pt


KeyboardInterrupt: 

In [16]:
# Attribute-comparison accuracy on the validation split.
# `validation_font_names` was never defined in this notebook. It is built here
# from the validation JSON's top-level keys, the same keys cell 2 counts as
# validation fonts. Also depends on `model` and `image_file_dir` from earlier cells.
with open(validation_json_path, 'r') as validation_json_file:
    validation_font_names = list(json.load(validation_json_file).keys())
cr = evaluate_attribute_comparison_task(validation_font_names, fox_text, model=model, image_file_dir=image_file_dir)
print(cr)

0.7262147221414024


In [16]:

# Unweighted mean of the per-font rates stored in `result` as (font_name, rate) pairs.
# NOTE(review): depends on whichever cell last rebound `result`.
rates = [entry[1] for entry in result]
average = sum(rates) / len(result)
print(average)


0.7456451612903231


In [29]:
# Same computation as the preceding mean-rate cell. Its saved output differs
# (0.8223 vs 0.7456) only because `result` was rebound between the two runs.
rate_total = sum(entry[1] for entry in result)
average = rate_total / len(result)
print(average)

0.8222989405791508
