In [None]:
import hydra
import hydra.experimental
import numpy as np
import tqdm
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaTokenizer
from omegaconf import DictConfig

from torchfly.text.decode import TransformerDecoder
from torchfly.common import set_random_seed, move_to_device

from configure_dataloader import DataLoaderHandler
from model import Generator, TextGAILModel

import logging
import matplotlib.pyplot as plt
import matplotlib

In [None]:
# Please specify the weight path for evaluation
# Both values are handed to `torch.load` in the evaluation sections below, so
# they must be replaced with real checkpoint paths before those cells are run.
textgail_weights_path = ""  # loaded into the whole model via `model.load_state_dict`
mle_weights_path = ""  # loaded into `model.generator` only

In [None]:
# Load the pretrained "roberta-base" tokenizer.
# NOTE(review): no visible cell references `tokenizer` afterwards — confirm it is still needed.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

In [None]:
# Build the test-split dataloader from the project configuration.
# NOTE(review): `config` is not defined in any visible cell (the `hydra` imports
# are otherwise unused, so it was presumably composed in a cell that is no
# longer here). Unless it is defined outside the visible source, a fresh kernel
# will raise NameError on the next line — confirm and restore the config cell.
dataloader_handler = DataLoaderHandler(config)
test_dataloader = dataloader_handler.test_dataloader(config)
# collate_fn = dataloader_handler.collator.sample_collate

In [None]:
# Target device for the evaluation batches. Prefer the GPU, but fall back to
# the CPU when CUDA is unavailable so `device` is always usable as a target.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Instantiate the full TextGAIL model from the config and place it on the
# shared `device` — the same target the evaluation batches are moved to —
# instead of hard-coding a separate `.cuda()` call.
model = TextGAILModel(config)
model = model.to(device)

In [None]:
# Echo the MLE checkpoint path so the output records which file is evaluated below.
print(mle_weights_path)

In [None]:
# Sampling temperatures to sweep: 0.4, 0.5, ..., 1.0 (seven values, step 0.1).
temperatures = np.arange(4, 11) / 10.0
print(temperatures)

## MLE Evaluation

In [None]:
# Deserialize the MLE checkpoint onto the CPU first; `load_state_dict` then
# copies the tensors into the generator's existing parameters. This avoids a
# transient second copy of the weights on the GPU and does not depend on the
# device the checkpoint was saved from.
# NOTE: `torch.load` unpickles the file — only point it at trusted checkpoints.
mle_weights = torch.load(mle_weights_path, map_location="cpu")
model.generator.load_state_dict(mle_weights)

In [None]:
# Perplexity of the MLE-initialised generator, one value per entry of
# `temperatures` (same order).
mle = []

# Evaluation only: switch to eval mode (e.g. disables dropout) and skip autograd
# bookkeeping, in case `model.predict` does not already do both itself.
model.eval()
with torch.no_grad():
    for temperature in temperatures:
        for batch in tqdm.tqdm(test_dataloader):
            batch = move_to_device(batch, device)
            # The temperature is passed to the model through the batch dict.
            batch["temperature"] = temperature

            model.predict(batch)
        # Read the metrics gathered over this full pass and reset them before
        # the next temperature.
        metrics = model.get_metrics(reset=True)
        mle.append(metrics['perplexity'])

In [None]:
# Show the MLE perplexities, one per swept temperature.
print(mle)

## TextGAIL Evaluation

In [None]:
# Echo the TextGAIL checkpoint path so the output records which file is evaluated below.
print(textgail_weights_path)

In [None]:
# Deserialize the TextGAIL checkpoint onto the CPU first; `load_state_dict`
# then copies the tensors into the model's existing parameters. This avoids a
# transient second copy of the weights on the GPU and does not depend on the
# device the checkpoint was saved from. Unlike the MLE checkpoint, this one is
# loaded into the whole model rather than only `model.generator`.
# NOTE: `torch.load` unpickles the file — only point it at trusted checkpoints.
textgail_weights = torch.load(textgail_weights_path, map_location="cpu")
model.load_state_dict(textgail_weights)

In [None]:
# Perplexity of the TextGAIL-trained model, one value per entry of
# `temperatures` (same order).
textgail = []

# Evaluation only: switch to eval mode (e.g. disables dropout) and skip autograd
# bookkeeping, in case `model.predict` does not already do both itself.
model.eval()
with torch.no_grad():
    for temperature in temperatures:
        for batch in tqdm.tqdm(test_dataloader):
            batch = move_to_device(batch, device)
            # The temperature is passed to the model through the batch dict.
            batch["temperature"] = temperature

            model.predict(batch)
        # Read the metrics gathered over this full pass and reset them before
        # the next temperature.
        metrics = model.get_metrics(reset=True)
        textgail.append(metrics['perplexity'])

In [None]:
# Show the TextGAIL perplexities, one per swept temperature.
print(textgail)

In [None]:
# # Store the intermediate results (uncomment to write the file that the loading cell below reads back)
# with open(f"{config.task.name}_perplexity.txt", "w") as f:
#     line = json.dumps({"mle": mle, "textgail": textgail})
#     f.write(line)
#     f.write("\n")

In [None]:
# Override the task name in place; the cells below use it to build the cached
# results file name and the output figure name.
config.task.name = "EMNLP_NEWS"

In [None]:
# Reload the cached perplexity lists for the current task. This replaces any
# `mle` / `textgail` values still in memory from the evaluation cells.
results_path = f"{config.task.name}_perplexity.txt"
with open(results_path) as results_file:
    data = json.load(results_file)
mle, textgail = data["mle"], data["textgail"]

In [None]:
# Base font size for every figure drawn after this cell.
matplotlib.rcParams["font.size"] = 16

In [None]:
# Plot perplexity against sampling temperature for both models and save the
# figure next to the notebook.
fig, ax = plt.subplots()

# Labels are attached to the lines themselves so the legend cannot drift out of
# sync with the plotting order.
ax.plot(temperatures, mle, marker="o", color="r", ls="--", label="GPT-2+MLE")
ax.plot(temperatures, textgail, marker="*", color="b", label="TextGAIL")
ax.legend()
ax.set_xlabel("Temperature", fontsize=18)
ax.set_ylabel("Perplexity", fontsize=18)
ax.set_title("Perplexity vs. Temperature")

# major_ticks = np.arange(10, 40, 5)
# ax.set_yticks(major_ticks)
ax.set_xticks(temperatures)  # one tick per evaluated temperature

# `grid` takes a boolean visibility flag; pass True rather than the legacy
# string "on".
ax.grid(True)

# Save through the figure object rather than pyplot's implicit current figure.
fig.savefig(f"{config.task.name} perplexity.png", dpi=300, pad_inches=0.1, bbox_inches='tight')