# 現実写真以外を含む、文字なし、固有名詞全部なし、キャプション、ほかの画像の大喜利

In [1]:
from utils import *

In [2]:
# Experiment configuration for this notebook run.

# Adjust to the machine this notebook runs on.
NUM_WORKERS = 16
# Whether to rebuild the dataset files even when they already exist.
REST_DATA = False

EXPERIENCE_NUMBER = "002"

# Whether to use images other than real photographs.
USE_UNREAL_IMAGE = True
# Whether to use images that contain text.
USE_WORD_IMAGE = False
# Whether to use bokes (ogiri punchlines) that contain proper nouns.
USE_UNIQUE_NOUN_BOKE = False
# Whether to use captions as negative examples.
USE_CAPTION = True
# Whether to use bokes written for other images as negative examples.
USE_MISS_BOKE = True
# How many other-image bokes to use, as a multiple of the positive examples.
NUM_RATIO_MISS_BOKE = 1

RESULT_DIR = f"../../results/Boke_Judge/{EXPERIENCE_NUMBER}/"
# makedirs creates every missing parent directory in one call and, with
# exist_ok=True, does not fail when the directory is already there, so no
# separate existence checks are needed.
os.makedirs(RESULT_DIR, exist_ok=True)

# データセットの作成

In [3]:
# Collect the boke entries and the caption entries from the per-image JSON
# files in DATA_DIR. Skipped when the test caption split already exists in
# RESULT_DIR, unless REST_DATA forces a rebuild.
if not os.path.exists(f"{RESULT_DIR}test_caption_datas.json") or REST_DATA:

    boke_datas = list()
    caption_datas = list()

    # Largest number of bokes attached to a single kept image.
    max_num_boke = 0
    for JP in tqdm(os.listdir(DATA_DIR)):
        # The part of the file name before the first "." is the image number.
        N = int(JP.split(".")[0])

        # Read as UTF-8 explicitly instead of relying on the platform default
        # encoding (assumes the files are UTF-8 or ASCII JSON -- TODO confirm).
        with open(f"{DATA_DIR}{JP}", "r", encoding="utf-8") as f:
            a = json.load(f)

        image_information = a["image_information"]
        is_photographic_probability = image_information["is_photographic_probability"]
        # Not used further in this cell; kept so the name stays defined.
        ja_caption = image_information["ja_caption"]
        ocr = image_information["ocr"]

        # Drop images that are not real photographs.
        if not USE_UNREAL_IMAGE and is_photographic_probability < 0.8:
            continue

        # Drop images that contain text.
        if not USE_WORD_IMAGE and len(ocr) != 0:
            continue

        bokes = a["bokes"]

        max_num_boke = max(max_num_boke, len(bokes))
        for i, B in enumerate(bokes):

            # Drop bokes that contain proper nouns.
            if not USE_UNIQUE_NOUN_BOKE and len(B["unique_nouns"]) != 0:
                continue

            # boke_number is the index of the boke within its image's list.
            boke_datas.append({
                "boke_number": i,
                "image_number": N
            })

        # One caption entry per kept image, added even when every boke of the
        # image was filtered out above.
        caption_datas.append({
            "caption_number": N,
            "image_number": N
        })

    # A bare expression inside an `if` body is never displayed by the
    # notebook, so print the counts explicitly.
    print(len(boke_datas), len(caption_datas))

In [4]:
# Split the collected boke and caption entries into train / test parts
# (test_size = 0.01) and store the four resulting lists as JSON files in
# RESULT_DIR. Runs only when the test caption file is missing or a rebuild
# is requested through REST_DATA.
if REST_DATA or not os.path.exists(f"{RESULT_DIR}test_caption_datas.json"):

    train_boke_datas, test_boke_datas = train_test_split(boke_datas, test_size = 0.01)
    train_caption_datas, test_caption_datas = train_test_split(caption_datas, test_size = 0.01)

    # Written in the same order as before: train files first, then test files.
    for out_path, entries in (
        (f"{RESULT_DIR}train_boke_datas.json", train_boke_datas),
        (f"{RESULT_DIR}train_caption_datas.json", train_caption_datas),
        (f"{RESULT_DIR}test_boke_datas.json", test_boke_datas),
        (f"{RESULT_DIR}test_caption_datas.json", test_caption_datas),
    ):
        with open(out_path, "w") as out_file:
            json.dump(entries, out_file)

# モデルの学習

In [5]:
def _load_json(path):
    """Return the parsed content of the JSON file at *path*."""
    with open(path, "r") as json_file:
        return json.load(json_file)


# Reload the four splits from RESULT_DIR.
train_boke_datas = _load_json(f"{RESULT_DIR}train_boke_datas.json")
train_caption_datas = _load_json(f"{RESULT_DIR}train_caption_datas.json")

test_boke_datas = _load_json(f"{RESULT_DIR}test_boke_datas.json")
test_caption_datas = _load_json(f"{RESULT_DIR}test_caption_datas.json")

# Last expression of the cell: the notebook displays the four sizes.
tuple(len(split) for split in (train_boke_datas, train_caption_datas, test_boke_datas, test_caption_datas))

(2297194, 251799, 23204, 2544)

In [None]:
# Train BokeJudgeModel for EPOCH epochs, evaluate on the test split after each
# epoch, save the weights whenever the test accuracy is the best so far, and
# finally store the per-epoch metric histories as JSON in RESULT_DIR.
train_loss_history = []
train_accuracy_history = []
test_loss_history = []
test_accuracy_history = []

model = BokeJudgeModel()
# LEARNING_RATO is not defined in this cell; the spelling is kept to match
# its definition elsewhere.
optimizer = optim.AdamW(model.parameters(), lr = LEARNING_RATO)

for epoch in range(EPOCH):
    # train
    train_loss_obj = 0.0
    train_accuracy_obj = 0.0
    model.train()
    # The dataloader is rebuilt at the start of every epoch.
    train_dataloader = make_dataloader(train_boke_datas, train_caption_datas,
                                       use_caption = USE_CAPTION, use_miss_boke = USE_MISS_BOKE, num_ratio_miss_boke = NUM_RATIO_MISS_BOKE,
                                       num_workers = NUM_WORKERS)
    pb = tqdm(train_dataloader, desc = f"Epoch {epoch+1}/{EPOCH}")

    # Count the processed batches explicitly: when tqdm wraps an iterable,
    # `pb.n` is only refreshed at display updates, so dividing by `pb.n + 1`
    # gives a skewed running average.
    for step, (CIF, CSF, LSF, TS) in enumerate(pb, start = 1):
        loss, accuracy = train_step(model, optimizer, (CIF, CSF, LSF), TS)
        train_loss_obj += loss
        train_accuracy_obj += accuracy
        pb.set_postfix({"train_loss": train_loss_obj / step, "train_accuracy": train_accuracy_obj / step})

    # Epoch metrics are the mean of the per-batch values.
    train_loss = train_loss_obj / len(train_dataloader)
    train_accuracy = train_accuracy_obj / len(train_dataloader)

    # test
    test_loss_obj = 0.0
    test_accuracy_obj = 0.0
    model.eval()
    test_dataloader = make_dataloader(test_boke_datas, test_caption_datas,
                                      use_caption = USE_CAPTION, use_miss_boke = USE_MISS_BOKE, num_ratio_miss_boke = NUM_RATIO_MISS_BOKE,
                                      num_workers = NUM_WORKERS)
    pb = tqdm(test_dataloader, desc = "Evaluating")

    # Same explicit batch counter as in the training loop.
    for step, (CIF, CSF, LSF, TS) in enumerate(pb, start = 1):
        loss, accuracy = evaluate(model, (CIF, CSF, LSF), TS)
        test_loss_obj += loss
        test_accuracy_obj += accuracy
        pb.set_postfix({"test_loss": test_loss_obj / step, "test_accuracy": test_accuracy_obj / step})

    test_loss = test_loss_obj / len(test_dataloader)
    test_accuracy = test_accuracy_obj / len(test_dataloader)

    print(f"Epoch: {epoch+1}/{EPOCH}, "
          f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, "
          f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

    train_loss_history.append(train_loss)
    train_accuracy_history.append(train_accuracy)
    test_loss_history.append(test_loss)
    test_accuracy_history.append(test_accuracy)

    # Save the weights when this epoch reaches the best test accuracy so far
    # (a tie with an earlier epoch overwrites the file as well).
    if max(test_accuracy_history) == test_accuracy:
        torch.save(model.state_dict(), f"{RESULT_DIR}best_model_weights.pth")

# Written once, after the last epoch has finished.
with open(f"{RESULT_DIR}history.json", "w") as f:
    json.dump({
        "train_loss": train_loss_history,
        "train_accuracy": train_accuracy_history,
        "test_loss": test_loss_history,
        "test_accuracy": test_accuracy_history
    }, f)

2297194it [00:01, 1173016.86it/s]


num data: 4846187


Epoch 1/25:  19%|█▉        | 14575/75722 [02:53<12:00, 84.89it/s, train_loss=0.629, train_accuracy=0.615]

In [None]:
# Draw the per-epoch histories side by side: loss on the left panel,
# accuracy on the right panel, each with a train and a test curve.
fig = plt.figure(figsize = (10, 5))
panels = (
    ("loss", train_loss_history, test_loss_history),
    ("accuracy", train_accuracy_history, test_accuracy_history),
)
for position, (metric, train_values, test_values) in enumerate(panels, start = 1):
    ax = fig.add_subplot(1, 2, position)
    ax.plot(train_values, label = "train")
    ax.plot(test_values, label = "test")
    ax.set_xlabel("epoch")
    ax.set_ylabel(metric)
    ax.legend()
    ax.grid()