In [None]:
import json
import logging
import os

import hydra
import hydra.experimental
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from omegaconf import DictConfig
from transformers import RobertaTokenizer

from torchfly.common import set_random_seed, move_to_device
from torchfly.text.decode import TransformerDecoder

from configure_dataloader import DataLoaderHandler
from model import Generator, TextGAILModel

In [None]:
# Single seed for the whole notebook. torchfly's helper presumably seeds the
# Python / NumPy / torch RNGs (TODO confirm); the same value is also embedded
# in every generation output file name.
random_seed = 1
set_random_seed(random_seed)

In [None]:
# Point Hydra at the local `config` directory so the `compose` call in the next
# cell can find `config.yaml`.
# NOTE(review): the `hydra.experimental` API changed across Hydra releases
# (it later moved to `hydra.initialize`); pin the Hydra version this notebook
# was written against.
hydra.experimental.initialize("config")

In [None]:
# Compose the run configuration from config/config.yaml and echo it, so the
# settings used for this generation run are recorded in the notebook output.
# NOTE(review): `.pretty()` was deprecated in OmegaConf 2.0 in favour of
# `OmegaConf.to_yaml(config)`; keep whichever matches the pinned version.
config = hydra.experimental.compose("config.yaml")
print(config.pretty())

In [None]:
# RoBERTa BPE tokenizer: registered with the decoder below and used to turn
# sampled token ids back into text in the generation cells.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

In [None]:
# Build the project's data loaders from the composed config.
# NOTE(review): `test_dataloader` is not read by any cell in this notebook
# (the generation cells never iterate over it) — drop it if nothing else
# relies on it, since it may load the test split for no reason.
dataloader_handler = DataLoaderHandler(config)
test_dataloader = dataloader_handler.test_dataloader(config)

In [None]:
# Notebook-wide compute device: use the GPU when one is available, otherwise
# fall back to CPU so the notebook can still run (slowly) on CPU-only hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the full TextGAIL model and move it to the notebook-wide `device`
# (defined in the previous cell) instead of hard-coding `.cuda()`, so device
# selection lives in exactly one place.
model = TextGAILModel(config).to(device)

In [None]:
# Sampling helper configured from `config.decode`. It is given the generator's
# decoder module and the tokenizer.
# NOTE(review): `register_generator` presumably keeps a reference (not a copy)
# to `model.generator.decoder`, which is why loading new weights into `model`
# later in the notebook takes effect without re-registering — confirm.
decoder = TransformerDecoder(config.decode)
decoder.register_generator(model.generator.decoder)
decoder.register_tokenizer(tokenizer)

In [None]:
# Directory that receives one generation file per (model, temperature, seed).
# Created up front so the `open(..., "w")` calls in the generation cells cannot
# fail on a missing parent directory; `exist_ok` keeps the cell re-runnable.
os.makedirs(config.task.name, exist_ok=True)

In [None]:
# Sanity check before loading: both checkpoint paths must be set in the task
# config — the MLE weights are loaded into the generator, the TextGAIL weights
# into the full model.
print(config.task.mle_weights_path, config.task.textgail_weights_path, sep="\n")

In [None]:
# Sampling temperatures to sweep: 0.2, 0.4, 0.6, 0.8, 1.0. Each value is also
# interpolated into the output file names of the generation cells.
temperatures = (np.arange(5) + 1) / 5.0
print(temperatures)

# Prompt batch passed to every `decoder.generate` call (it was referenced but
# never defined). The generation cells drop the first token of each output,
# which suggests sampling starts from a lone RoBERTa BOS token `<s>`.
# NOTE(review): BOS-only prompt and batch size are assumptions — confirm
# against the training setup and tune the batch size to GPU memory.
GENERATION_BATCH_SIZE = 100
input_ids = torch.full(
    (GENERATION_BATCH_SIZE, 1),
    tokenizer.bos_token_id,
    dtype=torch.long,
    device=device,
)

## MLE Generation

In [None]:
# Load the MLE-pretrained weights into the generator sub-module.
# `map_location="cpu"` lets GPU-saved checkpoints load on CPU-only hosts and
# avoids a second GPU copy; `load_state_dict` copies the tensors onto the
# model's own device. torch.load unpickles the file — only load trusted
# checkpoints.
mle_weights = torch.load(config.task.mle_weights_path, map_location="cpu")
model.generator.load_state_dict(mle_weights)

In [None]:
# Sample from the MLE generator at every temperature and write one
# JSON-encoded string per line to <task.name>/mle_<temperature>_<seed>.txt.
# Switch to inference mode (dropout etc. off); load_state_dict does not do this.
model.eval()
for temperature in temperatures:
    # Re-seed per file so each output depends only on (weights, temperature,
    # seed), not on which cells happened to run before this one.
    set_random_seed(random_seed)
    output_path = f"{config.task.name}/mle_{temperature}_{random_seed}.txt"
    # `with` closes the file even if generation raises; no_grad skips the
    # autograd bookkeeping that sampling never needs.
    with open(output_path, "w") as f_write, torch.no_grad():
        for _ in tqdm.trange(100, desc=f"MLE T={temperature}"):
            results = decoder.generate(input_ids, temperature=temperature)
            for sequences in results["tokens"]:
                # Take the first returned sequence for this prompt and drop its
                # first and last token (presumably <s> and </s> — TODO confirm;
                # a sequence truncated without </s> loses its final real token).
                text = tokenizer.decode(sequences[0][1:-1].tolist())
                f_write.write(json.dumps(text) + "\n")

## TextGAIL Generation

In [None]:
# Load the TextGAIL checkpoint; unlike the MLE weights it is a state dict for
# the full model, not just the generator. `map_location="cpu"` lets GPU-saved
# checkpoints load anywhere and avoids a second GPU copy; `load_state_dict`
# copies onto the model's own device. torch.load unpickles the file — only
# load trusted checkpoints.
textgail_weights = torch.load(config.task.textgail_weights_path, map_location="cpu")
model.load_state_dict(textgail_weights)

In [None]:
# Sample from the TextGAIL-trained generator at every temperature and write one
# JSON-encoded string per line to <task.name>/textgail_<temperature>_<seed>.txt.
# Switch to inference mode (dropout etc. off); load_state_dict does not do this.
model.eval()
for temperature in temperatures:
    # Re-seed per file so each output depends only on (weights, temperature,
    # seed), not on whether the MLE cells ran first.
    set_random_seed(random_seed)
    output_path = f"{config.task.name}/textgail_{temperature}_{random_seed}.txt"
    # `with` closes the file even if generation raises; no_grad skips the
    # autograd bookkeeping that sampling never needs.
    with open(output_path, "w") as f_write, torch.no_grad():
        for _ in tqdm.trange(100, desc=f"TextGAIL T={temperature}"):
            results = decoder.generate(input_ids, temperature=temperature)
            for sequences in results["tokens"]:
                # Take the first returned sequence for this prompt and drop its
                # first and last token (presumably <s> and </s> — TODO confirm;
                # a sequence truncated without </s> loses its final real token).
                text = tokenizer.decode(sequences[0][1:-1].tolist())
                f_write.write(json.dumps(text) + "\n")