In [16]:
# ---- User settings -----------------------------------------------------------

# Checkpoint directory to load; it is expected to contain config.json and
# model/checkpoint.pt.
path_to_model = "../models/hrqvae_paralex/"

# Root directory holding the datasets and vocabularies.
DATA_PATH = "../data/"

# Inputs to paraphrase; each entry supplies its text under the 'sem_input' key.
examples = [
    {"sem_input": "what is the weight of an average moose?"},
    {"sem_input": "which country are moose from?"},
]

# Number of paraphrases to generate for each input.
top_k = 3

# **** You shouldn't need to change anything below here ****

# Generating multiple paraphrases from the model takes three steps:
# First, we encode the inputs to get z_sem.
# Then, we predict the top-k sketches {q_i} for each input with the code prediction network
# (blocking the input's own codes).
# Finally, we run the decoder once per sketch, using that sketch and z_sem as input.

import json, jsonlines, sacrebleu, logging, copy
from tqdm import tqdm
from torchseq.agents.para_agent import ParaphraseAgent
from torchseq.datasets.json_loader import JsonDataLoader
from torchseq.utils.config import Config

import torch

# Load the saved training config and build the agent (model) from it.
# os.path.join is used instead of string concatenation so the path is correct
# whether or not path_to_model ends with a separator (the original code
# produced '.../hrqvae_paralex//config.json').
import os

with open(os.path.join(path_to_model, "config.json"), encoding="utf-8") as f:
    cfg_dict = json.load(f)

logger = logging.getLogger('TorchSeq')

config = Config(cfg_dict)
checkpoint_path = os.path.join(path_to_model, "model", "checkpoint.pt")
agent = ParaphraseAgent(
    config=config,
    run_id=None,
    output_path=None,
    data_path=DATA_PATH,
    silent=False,
    verbose=False,
    training_mode=False,
)

# Load the trained weights and put the model into inference (eval) mode.
agent.load_checkpoint(checkpoint_path)
agent.model.eval()


# Step 1: encode every input. We keep the separated encodings
# (sep_encoding_1 is the one used alone when the code predictor is sem_only)
# plus the input's own VQ codes, which are blocked later when predicting sketches.
logger.info("Generating encodings for eval set")

# Dataset config: a JSON dataset whose s2, template and s1 fields are all
# copies of sem_input.
config_gen_eval = copy.deepcopy(config.data)
config_gen_eval["dataset"] = "json"
config_gen_eval["json_dataset"] = {
    "path": config.eval.metrics.sep_ae.eval_dataset,
    "field_map": [
        {"type": "copy", "from": "sem_input", "to": target_field}
        for target_field in ("s2", "template", "s1")
    ],
}
config_gen_eval["eval"]["topk"] = 1

data_loader = JsonDataLoader(
    data_path=agent.data_path,
    config=Config(config_gen_eval),
    dev_samples=examples,
)

# Output sampling is switched off for this encoding pass.
config.eval.data["sample_outputs"] = False

predictor_cfg = agent.config.bottleneck.code_predictor
# Read the post-bottleneck encodings when the code predictor config asks for them.
post_bottleneck = "_after_bottleneck" if predictor_cfg.get("post_bottleneck", False) else ""
encoding_keys = [f"sep_encoding_1{post_bottleneck}", f"sep_encoding_2{post_bottleneck}"]

_, _, (_, _, _), memory_eval = agent.inference(
    data_loader.valid_loader,
    memory_keys_to_return=encoding_keys + ["vq_codes"],
)

# Code predictor input: index 0 along dim 1 of each encoding (presumably the
# first sequence position), concatenated along dim 1 unless sem_only is set.
if predictor_cfg.get("sem_only", False):
    X_eval = memory_eval[encoding_keys[0]][:, 0, :]
else:
    X_eval = torch.cat([memory_eval[key][:, 0, :] for key in encoding_keys], dim=1)
y_eval = memory_eval["vq_codes"]

# Step 2: predict the top_k best code sequences (sketches) for each input,
# blocking the input's own codes (y_eval).

logger.info("Running code predictor")

# Make sure the beam is at least top_k wide so top_k candidates can be returned.
if agent.model.code_predictor.config.get("beam_width", 0) < top_k:
    agent.model.code_predictor.config.data["beam_width"] = top_k

pred_codes = []
# TODO: batchify! (currently one code-predictor call per input)
# The code predictor is called directly here, not through agent.inference, so
# disable autograd explicitly; infer() may not do this itself.
with torch.no_grad():
    for ix, x_batch in enumerate(tqdm(X_eval, desc="Predicting codes")):
        curr_codes = agent.model.code_predictor.infer(
            x_batch.unsqueeze(0).to(agent.device),
            {},
            outputs_to_block=y_eval[ix].unsqueeze(0),
            top_k=top_k,
        )
        pred_codes.append(curr_codes)

# pred_codes[i, k, :] is later read as input i's k-th best code sequence.
pred_codes = torch.cat(pred_codes, dim=0)

# Step 3: decode once per rank k, forcing each input's k-th best predicted
# codes (its sketch) while still feeding sem_input as the s1/s2/template fields.
config_pred_diversity = copy.deepcopy(config.data)
config_pred_diversity["dataset"] = "json"
config_pred_diversity["json_dataset"] = {
    "path": config.eval.metrics.sep_ae.eval_dataset,
    "field_map": [
        {"type": "copy", "from": "sem_input", "to": "s2"},
        {"type": "copy", "from": "sem_input", "to": "template"},
        {"type": "copy", "from": "sem_input", "to": "s1"},
        {"type": "copy", "from": "forced_codes", "to": "forced_codes"},
    ],
}
config_pred_diversity["eval"]["topk"] = 1

# Built only to obtain the loader's validation samples, which are re-used below.
data_loader = JsonDataLoader(
    data_path=agent.data_path,
    config=Config(config_pred_diversity),
    dev_samples=examples,
)

# Re-enable output sampling for the generation passes.
config.eval.data["sample_outputs"] = True


def _ordinal(n):
    """Return ``n`` with its English ordinal suffix, e.g. 1 -> '1st', 12 -> '12th'."""
    if 11 <= n % 100 <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"


# topk_outputs[k] holds the outputs decoded with each input's (k+1)-th best codes,
# in the order of the validation samples.
topk_outputs = []

for k in range(top_k):
    logger.info("Running generation with %s best codes", _ordinal(k + 1))

    samples = [
        {**x, "forced_codes": pred_codes[i, k, :].tolist()}
        for i, x in enumerate(data_loader._valid.samples)
    ]
    forced_loader = JsonDataLoader(
        data_path=agent.data_path, config=Config(config_pred_diversity), dev_samples=samples
    )

    _, _, (output, _, _), _ = agent.inference(forced_loader.valid_loader)

    topk_outputs.append(output)


print(topk_outputs)



Validating after 12 epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.89it/s]
Predicting codes: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 85.13it/s]
Validating after 12 epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.20it/s]
Validating after 12 epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.50it/s]
Validating after 12 epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.08it/s]

[['how much does a moose weight?', 'where do moose live?'], ['what is the weight of a moose?', 'what country do moose come from?'], ['what is the weight of moose?', 'what country do moose live in?']]



