In [10]:
import os
import time
from PIL import Image
import re
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_tensor, to_pil_image
from IPython.display import display
from tqdm import tqdm

from clip.simple_tokenizer import SimpleTokenizer
from utils.initialize_font_data import (
    font_dir,
    fox_text_four_lines,
    train_json_path,
    exclusive_attributes,
    gray_scale_image_file_dir,
    all_gray_scale_image_file_dir,
    fox_text,
    train_json,
)
from utils.init_model import (
    model,
    preprocess,
    my_preprocess,
    preprocess_for_single_character,
    device,
    convert_weights,
    model_name,
    _download,
    _MODELS,
    load,
)
from utils.dataset import MyDataset, PairedImageDataset, set_image_tensors
from utils.lora_clip import LoRACLIP

# Shared tokenizer instance; the validation loop uses tokenizer.decode() to turn
# a sample's token ids back into prompt text.
tokenizer = SimpleTokenizer()

In [2]:
def extract_attributes_from_decoded_text(text: str) -> "list[str]":
    """Recover the signed attribute names from a decoded prompt string.

    Judging from the replacements performed here, the expected input looks like
    ``<|startoftext|>wide , not bold , attention - grabbing font <|endoftext|>!!!``:
    comma-separated attributes followed by ``font <|endoftext|>`` and a run of
    ``!`` characters (presumably decoded padding tokens -- TODO confirm).

    Args:
        text: Prompt text as produced by decoding the token ids.

    Returns:
        The attributes in prompt order, e.g.
        ``["wide", "not bold", "attention-grabbing"]``. Surrounding whitespace
        is removed and empty fragments are dropped, so an empty prompt yields
        ``[]``.
    """
    # The return annotation is quoted so that the subscripted builtin is not
    # evaluated on interpreters that predate ``list[str]`` support.
    cleaned = (
        text.replace("<|startoftext|>", " ")
        .replace("font <|endoftext|>", "")
        .replace("!", "")
    )

    attributes = []
    for fragment in cleaned.split(","):
        # The tokenizer spaces out hyphens: 'attention - grabbing' -> 'attention-grabbing'.
        fragment = fragment.replace(" - ", "-")
        # strip() instead of slicing [1:-1]: slicing assumed exactly one padding
        # character on each side and silently ate real letters otherwise.
        fragment = fragment.strip()
        if fragment:
            attributes.append(fragment)

    return attributes

In [15]:
# Dataset under test: the validation loop further down cross-checks the batches
# it yields, and the mask matrix it computes, against train_json.
dataset = MyDataset(
    font_dir=font_dir,
    json_path=train_json_path,
    texts_for_font_image=[fox_text_four_lines],
    char_size=150,
    # The validation loop compares train_json scores against this same value (50).
    attribute_threshold=50,
    use_negative=True,
    use_weight=False,
    use_negative_loss=True,
    preprocess=my_preprocess,
    use_multiple_attributes=True,
    use_random_attributes=True,
    # NOTE(review): presumably the number of attribute slots per sample that the
    # validation loop compares against -- TODO confirm in utils.dataset.
    max_sample_num=3,
    random_prompts_num=100,
    exclusive_attributes=exclusive_attributes,
    image_file_dir=all_gray_scale_image_file_dir,
    dump_image=True,
    single_character=False,
    sample_num_each_epoch=100,
)

100%|██████████| 120/120 [00:01<00:00, 100.95it/s]


In [16]:
# Shuffled batches of 4 samples; the validation loop unpacks each batch as
# (images, texts, font_indices, attribute_indices).
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

In [18]:
# Sanity-check one full pass over data_loader in two ways:
#   1. the attribute indices of every sample match the attributes named in its
#      tokenized prompt;
#   2. mask_matrix[i][j] == 1 exactly when the font of sample i agrees with
#      every signed attribute of sample j according to train_json.

# Same cutoff as attribute_threshold=50 passed to MyDataset.
ATTRIBUTE_THRESHOLD = 50
# Index 0 is the filler value: it pads short attribute lists and is skipped
# when attributes are looked up.
PADDING_INDEX = 0
NEGATION_PREFIX = 'not '


def _attribute_signs(font_name, signed_attribute):
    """Return (attribute, sign from train_json, sign claimed by the prompt).

    A sign is +1 when the font has the attribute and -1 when it does not.
    The train_json sign is +1 when the stored score is >= ATTRIBUTE_THRESHOLD;
    the prompt sign is -1 when the signed attribute starts with 'not '.
    """
    negated = signed_attribute.startswith(NEGATION_PREFIX)
    # Strip only the leading prefix; str.replace would also remove any later
    # occurrence of 'not ' inside the attribute name.
    attribute = signed_attribute[len(NEGATION_PREFIX):] if negated else signed_attribute
    score = float(train_json[font_name][attribute])
    attribute_sign = 1 if score >= ATTRIBUTE_THRESHOLD else -1
    ground_truth = -1 if negated else 1
    return attribute, attribute_sign, ground_truth


start = time.perf_counter()
for batch in tqdm(data_loader):
    images, texts, font_indices, attribute_indices = batch
    mask_matrix = dataset.mask_font_idx_signed_attribute_matrix_ground_truth_fast(
        font_indices, attribute_indices
    )
    batch_len = len(images)

    # Check 1: attribute_indices agree with the decoded prompt.
    for i in range(batch_len):
        token_ids = [int(token) for token in texts[i]]
        decoded_text = tokenizer.decode(token_ids)
        extracted_attributes = extract_attributes_from_decoded_text(decoded_text)
        extracted_attribute_indices = [
            dataset.signed_attribute_to_index(extracted_attribute)
            for extracted_attribute in extracted_attributes
        ]
        # Pad with the filler index up to the width of this sample's index row
        # instead of a hard-coded 3.
        missing = len(attribute_indices[i]) - len(extracted_attribute_indices)
        extracted_attribute_indices += [PADDING_INDEX] * missing
        # torch.tensor keeps the indices integral; torch.Tensor(...) would cast
        # them to float32.
        expected_indices = torch.tensor(extracted_attribute_indices)
        assert torch.all(attribute_indices[i] == expected_indices), \
            f'{attribute_indices[i]} != {extracted_attribute_indices}'

    # Check 2: the mask matrix agrees with the ground truth in train_json.
    for i in range(batch_len):
        font_name = dataset.font_names[font_indices[i]]
        for j in range(batch_len):
            checks = [
                _attribute_signs(
                    font_name, dataset.index_to_signed_attribute(attribute_idx)
                )
                for attribute_idx in attribute_indices[j]
                if attribute_idx != PADDING_INDEX
            ]
            mismatches = [check for check in checks if check[1] != check[2]]
            if mask_matrix[i][j] == 1:
                # Mask 1: the font must agree with every attribute of sample j.
                assert not mismatches, (
                    f'{i} {j} {font_name}: mask is 1 but these '
                    f'(attribute, attribute_sign, ground_truth) disagree: {mismatches}'
                )
            else:
                # Any other mask value: at least one attribute must disagree.
                # The message uses only names that are always bound, even when
                # sample j holds nothing but padding indices.
                assert mismatches, (
                    f'{i} {j} {font_name}: mask is {mask_matrix[i][j]} but every '
                    f'(attribute, attribute_sign, ground_truth) agrees: {checks}'
                )

end = time.perf_counter()
print(end - start)

100%|██████████| 3000/3000 [00:03<00:00, 804.52it/s] 

3.7323505878448486





In [None]:
# NOTE(review): bare literal in an unexecuted cell -- presumably an elapsed time
# pasted from an earlier run of the validation loop; confirm and delete.
6.805548429489136

In [None]:
# NOTE(review): leftover inspection cell. `attribute_indices` only exists here
# because it leaks from the last iteration of the validation loop above;
# consider deleting this cell.
attribute_indices[1][0]