In [None]:
import json
import logging
import os

import hydra
import hydra.experimental
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from omegaconf import DictConfig
from transformers import RobertaTokenizer
from torchfly.common import set_random_seed, move_to_device

from configure_dataloader import DataLoaderHandler, TextRLCollator
from configure_dataloader import DataLoaderHandler
from model import Generator, TextGAILModel

In [None]:
# Global seed for the run; it is also embedded in the sampled-output filenames.
random_seed = 1
set_random_seed(random_seed)

In [None]:
# Point hydra at the local "config" directory.
config_dir = "config"
hydra.experimental.initialize(config_dir)

In [None]:
# Compose the run configuration and echo it for the record.
config = hydra.experimental.compose("config.yaml")
config_yaml = config.pretty()
print(config_yaml)

In [None]:
# Tokenizer used both for collation and for decoding generated ids back to text.
tokenizer_name = "roberta-base"
tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)

In [None]:
dataloader_handler = DataLoaderHandler(config)
test_dataloader = dataloader_handler.test_dataloader(config)

# The generation loops collate each raw batch by hand, so keep the dataset's collate_fn.
collate_fn = test_dataloader.dataset.collate_fn

# Second-stage collator; its `sample_collate` is applied after the batch is on the device.
tr = TextRLCollator(config, tokenizer)

In [None]:
# Target device for every batch passed to the decoder.
device = torch.device(type="cuda")

In [None]:
# Build the TextGAIL model on the GPU and put it in inference mode.
# This notebook only generates, so training-only behaviour (e.g. dropout layers, if the model has
# any) should be off. Nothing else here calls `.eval()`, and `load_state_dict` in later cells
# does not change the mode.
# NOTE(review): if TransformerDecoder already switches the model to eval mode, this is a harmless no-op.
model = TextGAILModel(config).cuda().eval()

In [None]:
# `TransformerDecoder` was used here without ever being imported, which raised a NameError.
# NOTE(review): module path assumed from torchfly's package layout — confirm against the installed torchfly.
from torchfly.text.decode import TransformerDecoder

# Wire the decoding helper to the generator's decoder network and the shared tokenizer.
decoder = TransformerDecoder(config.decode)
decoder.register_generator(model.generator.decoder)
decoder.register_tokenizer(tokenizer)
# Reuse the generator's own input-preparation hook (presumably so decoding feeds inputs in the
# format the generator expects — verify against Generator.prepare_model_inputs_for_generation).
decoder.prepare_model_inputs_for_generation = model.generator.prepare_model_inputs_for_generation

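## MLE Generation

Generate outputs for every batch with the MLE-pretrained generator at each sampling temperature (0.1–1.0). Each result file holds one JSON `[ground_truth, generated]` pair per line.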

In [None]:
# Per-task results directory; creating it is a no-op when it already exists.
results_dir = config.task.name
os.makedirs(results_dir, exist_ok=True)

In [None]:
# Show which MLE checkpoint is about to be loaded.
mle_weights_path = config.task.mle_weights_path
print(mle_weights_path)

In [None]:
# The MLE checkpoint holds weights for the generator sub-module only, not the full model.
# (torch.load unpickles the file — only point this at trusted checkpoints.)
mle_checkpoint_path = config.task.mle_weights_path
mle_weights = torch.load(mle_checkpoint_path)
model.generator.load_state_dict(mle_weights)

In [None]:
# Sampling temperatures to sweep: 0.1, 0.2, ..., 1.0 (each computed as k / 10.0).
temperatures = np.arange(1, 11) / 10.0

In [None]:
# Sample with the MLE-pretrained generator at each temperature. Each run writes
# <task name>/mle_<temperature>_<seed>.txt with one JSON [ground_truth, generated] pair per line.
# Iterates `test_dataloader`, the only dataloader built in this notebook. The earlier
# `valid_dataloader` name was never defined and raised a NameError.
for temperature in temperatures:
    output_path = f"{config.task.name}/mle_{temperature}_{random_seed}.txt"
    # `with` guarantees the file is closed even if decoding raises mid-run.
    with open(output_path, "w") as f_write:
        for batch in tqdm.tqdm(test_dataloader):
            batch = collate_fn(batch)
            batch = move_to_device(batch, device)
            batch = tr.sample_collate(batch)

            ground_truth = batch["target_text"]

            # Inference only: no autograd graph is needed while generating.
            with torch.no_grad():
                results = decoder.generate(batch["source_token_ids"], temperature=temperature)

            # Take the first sequence for each example and drop its first and last token.
            # NOTE(review): presumably BOS/EOS; a sequence cut at max length would lose a real token.
            generated = [
                tokenizer.decode(sequences[0][1:-1].tolist())
                for sequences in results["tokens"]
            ]

            for gt, gen in zip(ground_truth, generated):
                f_write.write(json.dumps([gt, gen]) + "\n")

### MLE Beam Search (disabled)

In [None]:
# NOTE(review): disabled MLE beam-search run (4 beams, no sampling). Before re-enabling:
#   - it iterates `valid_dataloader`, which this notebook never defines; the loader built
#     earlier is `test_dataloader`.
#   - unlike the other generation loops here, it skips `tr.sample_collate(batch)`, so
#     `batch["source_token_ids"]` may not be present. Compare with the TextGAIL beam-search cell.
# f_write = open(f"{config.task.name}/mle_beam_4.txt", "w")

# for batch in tqdm.tqdm(valid_dataloader):
#     batch = collate_fn(batch)
#     batch = move_to_device(batch, device)

#     ground_truth = batch["target_text"]

#     results = decoder.generate(batch["source_token_ids"], do_sample=False, num_beams=4)
#     generated = []

#     for i in range(len(results["tokens"])):
#         res = tokenizer.decode(results["tokens"][i][0][1:-1].tolist())
#         generated.append(res)

#     for gt, gen in zip(ground_truth, generated):
#         f_write.write(json.dumps([gt, gen]))
#         f_write.write("\n")        

# f_write.close()

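## TextGAIL Generation

Load the TextGAIL checkpoint into the full model, then repeat the temperature sweep and run 4-beam search.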

In [None]:
# Show which TextGAIL checkpoint is about to be loaded.
textgail_weights_path = config.task.textgail_weights_path
print(textgail_weights_path)

In [None]:
# The TextGAIL checkpoint is loaded into the whole model, unlike the generator-only MLE checkpoint.
# (torch.load unpickles the file — only point this at trusted checkpoints.)
textgail_checkpoint_path = config.task.textgail_weights_path
textgail_weights = torch.load(textgail_checkpoint_path)
model.load_state_dict(textgail_weights)

In [None]:
# Sample with the TextGAIL-trained model at each temperature. Each run writes
# <task name>/textgail_<temperature>_<seed>.txt with one JSON [ground_truth, generated] pair per line.
# Iterates `test_dataloader`, the only dataloader built in this notebook. The earlier
# `valid_dataloader` name was never defined and raised a NameError.
for temperature in temperatures:
    output_path = f"{config.task.name}/textgail_{temperature}_{random_seed}.txt"
    # `with` guarantees the file is closed even if decoding raises mid-run.
    with open(output_path, "w") as f_write:
        for batch in tqdm.tqdm(test_dataloader):
            batch = collate_fn(batch)
            batch = move_to_device(batch, device)
            batch = tr.sample_collate(batch)

            ground_truth = batch["target_text"]

            # Inference only: no autograd graph is needed while generating.
            with torch.no_grad():
                results = decoder.generate(batch["source_token_ids"], temperature=temperature)

            # Take the first sequence for each example and drop its first and last token.
            # NOTE(review): presumably BOS/EOS; a sequence cut at max length would lose a real token.
            generated = [
                tokenizer.decode(sequences[0][1:-1].tolist())
                for sequences in results["tokens"]
            ]

            for gt, gen in zip(ground_truth, generated):
                f_write.write(json.dumps([gt, gen]) + "\n")

### TextGAIL Beam Search

In [None]:
# Deterministic 4-beam decoding (no sampling) with the TextGAIL-trained model. Writes one JSON
# [ground_truth, generated] pair per line.
# NOTE(review): the filename carries a "no_pretrain2" tag that the sampling outputs do not.
# Confirm it matches the checkpoint in config.task.textgail_weights_path.
# Iterates `test_dataloader`, the only dataloader built in this notebook. The earlier
# `valid_dataloader` name was never defined and raised a NameError.
beam_output_path = f"{config.task.name}/textgail_no_pretrain2_beam_4.txt"

# `with` guarantees the file is closed even if decoding raises mid-run.
with open(beam_output_path, "w") as f_write:
    for batch in tqdm.tqdm(test_dataloader):
        batch = collate_fn(batch)
        batch = move_to_device(batch, device)
        batch = tr.sample_collate(batch)

        ground_truth = batch["target_text"]

        # Inference only: no autograd graph is needed while generating.
        with torch.no_grad():
            results = decoder.generate(batch["source_token_ids"], do_sample=False, num_beams=4)

        # Take the first sequence for each example (presumably the top beam) and drop its first
        # and last token (presumably BOS/EOS — confirm).
        generated = [
            tokenizer.decode(sequences[0][1:-1].tolist())
            for sequences in results["tokens"]
        ]

        for gt, gen in zip(ground_truth, generated):
            f_write.write(json.dumps([gt, gen]) + "\n")