In [None]:
import hydra
import hydra.experimental
import numpy as np
import tqdm
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaTokenizer
from omegaconf import DictConfig

from torchfly.text.decode import TransformerDecoder
from torchfly.common import set_random_seed, move_to_device

from configure_dataloader import DataLoaderHandler
from model import Generator, TextGAILModel

import logging
import matplotlib.pyplot as plt
import matplotlib

In [None]:
# Seed the RNGs once, before anything stochastic runs.
# NOTE(review): presumably torchfly's helper seeds python/numpy/torch together — confirm.
SEED = 12
set_random_seed(SEED)

In [None]:
# Initialize Hydra with the "config" path (presumably the directory holding config.yaml),
# so the compose() call in the next cell can resolve it.
# NOTE(review): re-running this cell may fail if Hydra's global state is already
# initialized in this kernel — confirm for the installed Hydra version.
hydra.experimental.initialize("config")

In [None]:
# Compose the experiment config and echo it, so the settings used are visible in the output.
config = hydra.experimental.compose("config.yaml")
config_yaml = config.pretty()
print(config_yaml)

In [None]:
# Tokenizer for the "roberta-base" vocabulary.
# NOTE(review): `tokenizer` is not referenced by any other cell in this notebook;
# it looks removable unless something relies on it being loaded here.
TOKENIZER_NAME = "roberta-base"
tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER_NAME)

In [None]:
# Build the test split. The dataset's collate_fn is pulled out because the
# evaluation loops apply it to each raw batch themselves.
dataloader_handler = DataLoaderHandler(config)
test_dataloader = dataloader_handler.test_dataloader(config)
test_dataset = test_dataloader.dataset
collate_fn = test_dataset.collate_fn

In [None]:
# Evaluation device: use the GPU when one is available. Otherwise fall back to
# the CPU, so the notebook can still run (slowly) on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Evaluating on {device}")

In [None]:
# Build the model and place it on `device`, the same device that batches are
# moved to with move_to_device(batch, device). This keeps a single source of
# truth for placement instead of a hard-coded .cuda().
model = TextGAILModel(config)
model = model.to(device)

In [None]:
# Sampling temperatures to sweep: 0.4, 0.5, ..., 1.0.
temperatures = np.arange(4, 11) / 10.0
print(temperatures)

## MLE

In [None]:
# Load the MLE-trained weights into the generator sub-module only; the rest of
# the model keeps its current parameters.
print(config.task.mle_weights_path)
# map_location places the tensors on the evaluation device, so a checkpoint
# saved from a different GPU (or loaded on a CPU-only machine) still loads.
# NOTE(review): torch.load unpickles the file — only point this at trusted checkpoints.
mle_weights = torch.load(config.task.mle_weights_path, map_location=device)
model.generator.load_state_dict(mle_weights)

In [None]:
# Test-set perplexity of the MLE generator at each sampling temperature.
# eval() switches modules such as dropout to inference behavior, and
# no_grad() avoids building autograd graphs during pure evaluation.
model.eval()
mle = []
with torch.no_grad():
    for temperature in temperatures:
        for batch in tqdm.tqdm(test_dataloader):
            batch = collate_fn(batch)
            batch = move_to_device(batch, device)
            batch["temperature"] = temperature

            model.predict(batch)
        # reset=True presumably clears the accumulated metrics, so each
        # temperature is scored independently — confirm in TextGAILModel.
        metrics = model.get_metrics(reset=True)
        mle.append(metrics['perplexity'])

In [None]:
# MLE perplexity per temperature, in the same order as `temperatures`.
print(repr(mle))

## TextGAIL

In [None]:
# Echo which TextGAIL checkpoint is about to be loaded.
textgail_weights_path = config.task.textgail_weights_path
print(textgail_weights_path)

In [None]:
# Load the full TextGAIL checkpoint into the whole model (not just the generator).
# map_location places the tensors on the evaluation device regardless of where
# the checkpoint was saved.
# NOTE(review): torch.load unpickles the file — only point this at trusted checkpoints.
textgail_weights = torch.load(config.task.textgail_weights_path, map_location=device)
model.load_state_dict(textgail_weights)

In [None]:
# Test-set perplexity of the TextGAIL model at each sampling temperature.
# eval() switches modules such as dropout to inference behavior, and
# no_grad() avoids building autograd graphs during pure evaluation.
model.eval()
textgail = []
with torch.no_grad():
    for temperature in temperatures:
        for batch in tqdm.tqdm(test_dataloader):
            batch = collate_fn(batch)
            batch = move_to_device(batch, device)
            batch["temperature"] = temperature

            model.predict(batch)
        # reset=True presumably clears the accumulated metrics, so each
        # temperature is scored independently — confirm in TextGAILModel.
        metrics = model.get_metrics(reset=True)
        textgail.append(metrics['perplexity'])

In [None]:
# TextGAIL perplexity per temperature, in the same order as `temperatures`.
print(repr(textgail))

## Plot

In [None]:
# Persist the perplexity curves, so the plot can be regenerated later without
# re-running evaluation. This is off by default so an existing results file is
# not overwritten; set it to True after a fresh evaluation run.
# NOTE(review): json.dumps needs the perplexity values to be plain Python numbers.
WRITE_RESULTS = False
if WRITE_RESULTS:
    with open(f"{config.task.name}_perplexity.txt", "w") as f:
        line = json.dumps({"mle": mle, "textgail": textgail})
        f.write(line)
        f.write("\n")

In [None]:
# Force the task name used to build the results and figure filenames below.
# NOTE(review): this overrides whatever task config.yaml selected — confirm
# the evaluated checkpoints really belong to this task.
PLOT_TASK_NAME = "CommonGEN"
config.task.name = PLOT_TASK_NAME

In [None]:
# Reload previously saved curves (one JSON object on the first line of the file).
# NOTE(review): when the file exists, this replaces the `mle` / `textgail` lists
# computed above, so the plot shows the saved results rather than this run's.
results_path = f"{config.task.name}_perplexity.txt"
try:
    with open(results_path) as f:
        saved_results = json.loads(f.readline())
except FileNotFoundError:
    # No saved results yet: keep the curves computed in this session instead
    # of aborting a fresh Restart & Run All.
    print(f"{results_path} not found; plotting the freshly computed curves.")
else:
    mle = saved_results["mle"]
    textgail = saved_results["textgail"]

# Both curves are plotted against `temperatures`, so their lengths must match.
assert len(mle) == len(textgail) == len(temperatures), (
    f"expected {len(temperatures)} points per curve, "
    f"got mle={len(mle)}, textgail={len(textgail)}"
)

In [None]:
from matplotlib.ticker import MultipleLocator, FormatStrFormatter

In [None]:
# Perplexity-vs-temperature curves for both models, saved as a 300-dpi PNG.
# Set the global font size *before* creating the figure. Text objects resolve
# their size when they are created, so tick labels built with the axes would
# otherwise keep the previous default.
matplotlib.rcParams.update({'font.size': 16})
fig, ax = plt.subplots()

ax.plot(temperatures, mle, marker="o", color="r", ls="--", label="GPT-2+MLE")
ax.plot(temperatures, textgail, marker="*", color="b", label="TextGAIL")
ax.legend()
ax.set_xlabel("Temperature", fontsize=18)
ax.set_ylabel("Perplexity", fontsize=18)
ax.set_title("Perplexity vs. Temperature")

# Fixed y ticks at 10, 30, 50, 70, 90; one x tick per evaluated temperature.
# NOTE(review): the y ticks are hard-coded and won't cover perplexities outside ~10-90.
major_ticks = np.arange(10, 100, 20)
ax.set_yticks(major_ticks)
ax.set_xticks(temperatures)

ax.grid(True)

fig.savefig(f"{config.task.name} perplexity.png", dpi=300, pad_inches=0.1, bbox_inches='tight')