In [1]:
import os
import json
import torch
from tqdm import tqdm

from utils.cross_lingual_font_retrieval import inclusive_cjk_font_paths
from utils.init_model import model, preprocess, load_model, device
from utils.initialize_font_data import retrieve_font_path, fox_text_four_lines
from utils.transform_image import draw_text_with_new_lines, generate_all_fonts_embedded_images

exclusive_attributes:  ['capitals', 'cursive', 'display', 'italic', 'monospace', 'serif']


In [3]:
# Locations of the font collections and the per-annotator judgement files.
roman_font_dir = '../gwfonts'
cjk_font_dir = '../all-fonts'
annotation_file_dir = '../attributeData/cross-lingual-outputs/'
# Sample text rendered for every CJK font (Roman fonts are rendered with
# `fox_text_four_lines` in the embedding cell).
cjk_text = '夏林\n火山'
char_size = 150  # NOTE(review): not referenced in the visible cells — confirm it is still needed
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so
# the font order (and hence embedding/progress order) is reproducible.
roman_font_paths = [os.path.join(roman_font_dir, font_name) for font_name in sorted(os.listdir(roman_font_dir))]
cjk_font_paths = [os.path.join(cjk_font_dir, font_name) for font_name in sorted(os.listdir(cjk_font_dir))]

In [9]:
# Checkpoint under evaluation. An earlier experiment used
# 'model_checkpoints/new_best_fox_negative_91011_1011_use_weight_image_file_dir_ex.pt';
# substitute it here to compare.
checkpoint_path = 'model_checkpoints/best_ViT-B_32_cnn_based_vae_loss_weight_3.0_vae_loss_kl_weight0.001_res_64_9101191011_batch64_aug200_lower_bound_of_scale0.35_use_negative_lr1e-05-0.1.pt'
model = load_model(model, checkpoint_path)

In [10]:
# Embed each font collection once up front; the evaluation function below
# only indexes these results by font name.
_embedding_kwargs = {'model': model, 'preprocess': preprocess}
roman_embedded_images = generate_all_fonts_embedded_images(
  roman_font_paths, fox_text_four_lines, **_embedding_kwargs)
cjk_embedded_images = generate_all_fonts_embedded_images(
  cjk_font_paths, cjk_text, **_embedding_kwargs)
del _embedding_kwargs

In [11]:
def evaluate_with_annotation_json(name, mode='Roman'):
  """Score the model against one annotator's cross-lingual font judgements.

  Loads ``<annotation_file_dir>/<name>-<mode>-50.json``, with spaces in
  ``name`` replaced by underscores. Each record is unpacked as
  ``(ref_font_name, font1_name, font2_name, choice)``. The model "predicts"
  whichever option has the higher cosine similarity to the reference
  embedding. That prediction ('Font A' or 'Font B') is compared with the
  annotator's ``choice``.

  Args:
    name: Annotator name, used to build the annotation file name.
    mode: ``'Roman'`` compares Roman reference fonts against CJK options.
      ``'CJK'`` compares CJK reference fonts against Roman options.

  Returns:
    Fraction of records where the model's prediction equals ``choice``.

  Raises:
    ValueError: if ``mode`` is neither ``'Roman'`` nor ``'CJK'``, or the
      annotation file contains no records.
  """
  if mode == 'Roman':
    ref_embedded_images, option_embedded_images = roman_embedded_images, cjk_embedded_images
  elif mode == 'CJK':
    ref_embedded_images, option_embedded_images = cjk_embedded_images, roman_embedded_images
  else:
    # Previously any other value silently fell through to the CJK branch.
    raise ValueError(f"mode must be 'Roman' or 'CJK', got {mode!r}")

  # The '-50' suffix is part of the annotation file naming scheme.
  basename = f"{name.replace(' ', '_')}-{mode}-50.json"
  data_path = os.path.join(annotation_file_dir, basename)
  # Context manager closes the handle deterministically; JSON is UTF-8 by spec.
  with open(data_path, 'r', encoding='utf-8') as annotation_file:
    data = json.load(annotation_file)
  if not data:
    raise ValueError(f'no annotation records in {data_path}')

  correct_count = 0
  # Iterating the list directly lets tqdm show a total.
  for ref_font_name, font1_name, font2_name, choice in tqdm(data, desc=basename):
    ref_embedded_image = ref_embedded_images[ref_font_name].to(device)
    embedded_image1 = option_embedded_images[font1_name].to(device)
    embedded_image2 = option_embedded_images[font2_name].to(device)
    cos_sim1 = torch.cosine_similarity(ref_embedded_image, embedded_image1, dim=-1)
    cos_sim2 = torch.cosine_similarity(ref_embedded_image, embedded_image2, dim=-1)
    # Strict '>' means ties are predicted as 'Font B'.
    prediction = 'Font A' if cos_sim1 > cos_sim2 else 'Font B'
    if prediction == choice:
      correct_count += 1
  return correct_count / len(data)

In [12]:
# Annotators ("raters") whose judgement files are loaded by
# evaluate_with_annotation_json. Each entry must match the file name prefix
# exactly. NOTE(review): 'Akihiri' may be a misspelling, but it must stay in
# sync with the file on disk.
rator_names = [
  'Akihiri_Kiuchi',
  'Atsushi_Maruyama',
  'Sotaro_Kanazawa',
  'Sodai_Furuoka',
  'Kento_Shiraki',
]

In [13]:
# Mean model/annotator agreement over all raters, for each comparison direction.
if not rator_names:
  raise ValueError('rator_names is empty; nothing to average')
mean_accuracy_by_mode = {}
for mode in ('Roman', 'CJK'):
  # `result` keeps the per-rater accuracies of the last mode ('CJK'), as before.
  result = [evaluate_with_annotation_json(rator_name, mode=mode) for rator_name in rator_names]
  mean_accuracy_by_mode[mode] = sum(result) / len(result)
  print(f'{mode}: {mean_accuracy_by_mode[mode]}')

100it [00:00, 9134.14it/s]
100it [00:00, 11884.24it/s]
100it [00:00, 11967.65it/s]
100it [00:00, 11792.02it/s]
100it [00:00, 11533.27it/s]


0.6340000000000001


100it [00:00, 12023.58it/s]
100it [00:00, 12063.00it/s]
100it [00:00, 12118.41it/s]
100it [00:00, 12164.45it/s]
100it [00:00, 12131.73it/s]

0.6020000000000001



