In [11]:
import torch
import json

import clip
import os
from utils.init_model import load_model, preprocess
from utils.initialize_font_data import (
    retrieve_font_path,
    inclusive_attributes,
    all_gray_scale_image_file_dir,
    cj_font_dir,
    font_dir,
    train_json_path,
    validation_json_path,
    test_json_path,
    validation_font_names,
    all_json,
    fox_text,
    fox_text_four_lines,
)
from utils.evaluate_tools import (
    generate_all_attribute_embedded_prompts,
    user_attribute_choices_count,
    compare_two_fonts,
    evaluate_attribute_comparison_task,
    evaluate_similarity_comparison_task,
    user_similarity_choices,
)

def _count_json_keys(json_path):
    """Return the number of top-level keys in the JSON object stored at ``json_path``."""
    # The context manager closes the file handle immediately instead of
    # leaving it open until the garbage collector gets to it.
    with open(json_path, "r") as json_file:
        return len(json.load(json_file))


# Run on the first CUDA device when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Must set jit=False for training.
# NOTE: this rebinds `preprocess`, shadowing the name imported from
# utils.init_model above; later cells see the CLIP transform returned here.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Count the fonts (top-level JSON keys) in each split.
train_font_num = _count_json_keys(train_json_path)
print(train_font_num)

validation_font_num = _count_json_keys(validation_json_path)
print(validation_font_num)

test_font_num = _count_json_keys(test_json_path)
print(test_font_num)

# Every font named in `all_json`, plus the file path resolved for each one.
font_names = list(all_json.keys())
font_paths = [
    retrieve_font_path(font_name, font_dir=font_dir) for font_name in font_names
]

# `fox_text` is used exactly as imported from utils.initialize_font_data.

# Independent list with the same contents as `font_names`.
target_font_names = list(all_json.keys())

loading JIT archive /home/yuki/.cache/clip/ViT-B-32.pt
120
40
40


In [12]:
from utils.lora_multiheadattention import LoRAConfig
from utils.dataset import TestImageDataset, TestTextDataset

In [20]:
# Keyword options for the validation image dataset, collected in one place.
val_image_options = {
    "dump_image": True,
    "image_file_dir": all_gray_scale_image_file_dir,
    "preprocess": preprocess,
}
# Positional arguments: font directory, validation JSON path, sample text.
val_image_dataset = TestImageDataset(
    font_dir, validation_json_path, fox_text, **val_image_options
)

# Attributes evaluated in the cells below.
target_attributes = ["angular", "monospace", "capitals", "serif"]
# context_length=77 presumably matches the CLIP tokenizer default -- TODO confirm.
text_dataset = TestTextDataset(target_attributes=target_attributes, context_length=77)

In [21]:
# Read the validation annotations; the context manager closes the file promptly.
with open(validation_json_path, "r") as validation_json_file:
    val_attributes_raw = json.load(validation_json_file)

# Map each validation font to its scores for the target attributes.
# Values are taken in `target_attributes` order (not JSON key order) so that
# row i of `ground_truth_attributes` always corresponds to target_attributes[i].
# A font lacking one of the target attributes raises KeyError here rather than
# silently producing a shorter row.
val_font_to_attributes = {
    font_name: [float(attributes[attribute]) for attribute in target_attributes]
    for font_name, attributes in val_attributes_raw.items()
    if font_name in validation_font_names
}

# Stack fonts in `validation_font_names` order, then transpose so the result is
# (number of target attributes, number of validation fonts).
# torch.tensor keeps the float32 dtype; the tensor is created on the CPU without
# gradients, so it converts to NumPy directly.
ground_truth_attributes = torch.tensor(
    [val_font_to_attributes[font_name] for font_name in validation_font_names]
).T.numpy()

In [22]:
# Sanity check: rows are attributes, columns are validation fonts.
ground_truth_attributes.shape

(4, 40)

In [23]:
# Display the full ground-truth matrix for visual inspection.
ground_truth_attributes

array([[ 37.49,  33.49,  20.24,  92.26,  22.21,  15.75,  50.42,  35.11,
         72.65,  26.48,  24.37,  41.02,  29.62,  74.69,  23.62,  24.18,
         38.97,  32.17,  20.18,  58.97,  52.11,  68.08,  34.52,  39.88,
         95.05,  24.64,  76.59,  30.56,  46.88,  21.16,  38.32,  15.69,
         32.67,  20.92,  23.61,  32.41,  49.54,  14.98,  48.24,  51.87],
       [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,
          0.  ,   0.  ,   0.  , 100.  ,   0.  ,   0.  ,   0.  ,   0.  ,
          0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,
        100.  ,   0.  , 100.  ,   0.  , 100.  ,   0.  ,   0.  ,   0.  ,
          0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ],
       [  0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,
          0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,
          0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  , 100.  ,   0.  ,
          0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  ,   0.  

In [24]:
# LoRA settings shared by the vision and text configurations; only `alpha` differs.
shared_lora_options = dict(
    r=256,
    bias=False,
    learnable_alpha=False,
    apply_q=True,
    apply_k=True,
    apply_v=True,
    apply_out=True,
)
lora_config_vision = LoRAConfig(alpha=512.0, **shared_lora_options)
lora_config_text = LoRAConfig(alpha=1024.0, **shared_lora_options)

# Checkpoint to load, relative to the notebook's working directory.
checkpoint_path = "model_checkpoints/cv_20_0_task_for_validation_ViT-B_32_bce_lora_t-qkvo_256-1024.0_91011_batch64_aug250_lbound_of_scale0.35_max_attr_num_3_random_p_num_70000_geta0.2_use_negative_til1.0_lr2e-05-0.1_image_file_dir.pt"

# Keyword options handed to `load_model`, grouped by the feature they configure.
load_model_options = {
    "model_name": "ViT-B/32",
    # Learnable prompt / vision-prompt settings.
    "learnable_prompt": False,
    "learnable_vision": False,
    "precontext_length": 48,
    "precontext_vision_length": 0,
    "precontext_dropout_rate": 0,
    "vpt_applied_layers": None,
    # OFT is disabled for both towers.
    "use_oft_vision": False,
    "use_oft_text": False,
    "oft_config_vision": None,
    "oft_config_text": None,
    # LoRA injection is enabled with the configurations built above.
    "inject_lora": True,
    "lora_config_vision": lora_config_vision,
    "lora_config_text": lora_config_text,
}
tmp_model = load_model(model, checkpoint_path, **load_model_options)