In [1]:
import logging
import os

import hydra
import hydra.experimental
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaTokenizerFast
from omegaconf import DictConfig

from torchfly.text.decode import TransformerDecoder
from torchfly.common import set_random_seed

from model import Seq2Seq
from configure_dataloader import DataLoaderHandler, TextRLCollator
from textrl_trainerloop import TextRLExperienceBuffer, TextRewardFunc, TextRLTrainerLoop

In [2]:
# Root logging: only WARNING and above, tagged with file and line for quick triage.
LOG_FORMAT = '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=logging.WARNING)

logger = logging.getLogger(__name__)

In [3]:
# Global seed handed to torchfly's set_random_seed so runs are repeatable.
SEED = 123
set_random_seed(SEED)

In [4]:
# Hydra's experimental (non-decorator) API: register the "config" location,
# then compose config.yaml into a config object used by every cell below.
# Arguments stay positional because their keyword names differ across Hydra versions.
CONFIG_DIR = "config"
CONFIG_FILE = "config.yaml"

hydra.experimental.initialize(CONFIG_DIR)
config = hydra.experimental.compose(CONFIG_FILE)

In [5]:
import nltk
from nltk import metrics, stem, tokenize
 
# Shared Porter stemmer; CommonGENReward stems both concept words and generated words with it.
stemmer = stem.porter.PorterStemmer()

class CommonGENReward(TextRewardFunc):
    """Binary keyword-coverage reward for CommonGen-style generation.

    A generated sequence earns reward 1 when every concept word from its
    source (after Porter stemming) appears among the stemmed word-punct tokens
    of the decoded generation, and 0 otherwise.
    """

    def __init__(self, tokenizer=None):
        """
        Args:
            tokenizer: object exposing ``decode(list_of_ids) -> str``; used to
                turn generated token ids back into text. ``calc_reward`` needs
                it despite the ``None`` default.
        """
        self.tokenizer = tokenizer

    def calc_reward(self, states, actions, is_human_demo=False):
        """Score each generated sequence in a batch.

        Args:
            states: indexable of dicts; ``states[i]["source_text"]`` is a
                "#"-separated concept string. Only the part of each item before
                the first "_" is used as the concept word.
            actions: mapping with key ``"tokens"``; ``actions["tokens"][i]`` must
                support slicing and ``.tolist()`` (e.g. a tensor). The first and
                last ids are dropped before decoding (presumably BOS/EOS - TODO
                confirm against the decoder's output format).
            is_human_demo: unused; kept for interface compatibility.

        Returns:
            np.ndarray with one int reward (0 or 1) per sequence.
        """
        rewards = []

        for idx in range(len(actions['tokens'])):
            generated_text = self.tokenizer.decode(actions['tokens'][idx][1:-1].tolist())

            # Normalise concepts the same way as the generation (strip + lower)
            # instead of relying on the stemmer's implicit case handling; a
            # concept with stray whitespace could otherwise never match.
            key_words = [
                stemmer.stem(item.split("_")[0].strip().lower())
                for item in states[idx]["source_text"].split("#")
            ]

            # A set makes each concept lookup O(1) instead of scanning the token list.
            generated_words = {
                stemmer.stem(word)
                for word in tokenize.wordpunct_tokenize(generated_text.lower().strip())
            }

            # Reward only full coverage: every concept must appear in the generation.
            rewards.append(int(all(key_word in generated_words for key_word in key_words)))

        return np.array(rewards)

In [6]:
# Fast RoBERTa tokenizer; shared by the collator, reward function and decoder below.
TOKENIZER_NAME = "roberta-base"
tokenizer = RobertaTokenizerFast.from_pretrained(TOKENIZER_NAME)



In [7]:
# Build the dataloader handler from the Hydra config. Its train/valid loader
# factories are handed to the trainer later; one training loader is created
# here so a sample batch can be inspected in the next cells.
dataloader_handler = DataLoaderHandler(config)
train_dataloader = dataloader_handler.train_dataloader(config)

In [8]:
# Collator passed to TextRLTrainerLoop below; presumably batches RL samples
# for model updates - TODO confirm in configure_dataloader.TextRLCollator.
collator = TextRLCollator(config, tokenizer)

In [9]:
# Pull a single training batch for interactive inspection.
train_batches = iter(train_dataloader)
batch = next(train_batches)

In [10]:
# Keyword-coverage reward defined above; it needs the tokenizer to decode generated ids.
reward_func = CommonGENReward(tokenizer)

In [11]:
# MLE-pretrained checkpoint that RL fine-tuning starts from. Set the
# TEXTGAIL_MLE_CHECKPOINT environment variable rather than editing this
# machine-specific absolute path.
CHECKPOINT_PATH = os.environ.get(
    "TEXTGAIL_MLE_CHECKPOINT",
    "/home/wuqy1203/Desktop/Projects/TextGAIL/Experiments/MLE/outputs/CommonGEN/Checkpoints/iter_2439_model_state.pth",
)

model = Seq2Seq(config)
# Load tensors onto the CPU: the model is still on the CPU here and is moved
# to the GPU as a whole below, so GPU-side copies of the checkpoint are wasted memory.
model_weights = torch.load(CHECKPOINT_PATH, map_location="cpu")
# strict=False tolerates key mismatches; report them instead of silently
# leaving those parameters at their initial values.
incompatible_keys = model.load_state_dict(model_weights, strict=False)
if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
    logger.warning(
        "Checkpoint %s does not match the model exactly: missing=%s unexpected=%s",
        CHECKPOINT_PATH,
        incompatible_keys.missing_keys,
        incompatible_keys.unexpected_keys,
    )
model = model.cuda()

File exists: /home/wuqy1203/.cache/torchfly/models/roberta-tokenized-gpt2.pth
_IncompatibleKeys(missing_keys=['lm_head.weight'], unexpected_keys=['lm_head.decoder.weight'])


In [12]:
# Generation helper configured from the `decode` section of the Hydra config.
decoder = TransformerDecoder(config.decode)

In [13]:
# Wire the Seq2Seq model's decoder module and the RoBERTa tokenizer into the TransformerDecoder.
decoder.register_generator(model.decoder)
decoder.register_tokenizer(tokenizer)

In [14]:
# Monkey-patch: make the decoder build its generation inputs via the Seq2Seq
# model's own prepare_model_inputs_for_generation (a bound method of `model`).
decoder.prepare_model_inputs_for_generation = model.prepare_model_inputs_for_generation

In [15]:
# Assemble the RL training loop from the pieces built above. Dataloaders are
# passed as factory callables rather than as instantiated loaders.
trainer = TextRLTrainerLoop(
    config=config,
    model=model,
    decoder=decoder,
    collator=collator,
    reward_func=reward_func,
    train_dataloader_fn=dataloader_handler.train_dataloader,
    valid_dataloader_fn=dataloader_handler.valid_dataloader,
)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [16]:
# NOTE(review): the saved output of this cell shows training halted at a leftover
# breakpoint() in textrl_trainerloop.py (collect_samples, line 165). Remove it there
# before running this cell unattended.
trainer.train()

--Return--
> /home/wuqy1203/Desktop/Projects/TextGAIL/Experiments/Text RL/textrl_trainerloop.py(165)collect_samples()->None
-> breakpoint()


(Pdb)  q


BdbQuit: 

In [None]:
#results = decoder.generate(input_ids=batch["source_token_ids"])

In [None]:
# NOTE(review): `tokens` is not defined in any visible cell, so this scratch cell
# raises NameError on a fresh kernel. TODO: define it explicitly (e.g. from a
# decoder.generate result) or delete the cell.
tokenizer.decode(tokens[2][0][1:-1].tolist())

In [None]:
# NOTE(review): key_words is a local variable of CommonGENReward.calc_reward and is
# not assigned at top level in any visible cell; this raises NameError on a fresh
# kernel. TODO: remove this leftover debug cell.
key_words

In [None]:
# Inspect the source token ids of the sample batch (index 0, presumably the first example).
batch['source_token_ids'][0]

In [None]:
"attend" in generated_text

In [17]:
# Scratch check: ceiling of 100 / 3 (a float, 34.0).
np.ceil(np.divide(100, 3))

34.0

In [27]:
from datetime import datetime

In [32]:
# Current local wall-clock time as HH:MM:SS (format spec delegates to strftime).
f"{datetime.now():%H:%M:%S}"

'02:52:04'

In [33]:
from datetime import datetime
import pytz

# Capture ONE instant and convert it, so both clocks show the same moment.
# Two separate datetime.now() calls could straddle a second boundary.
now_utc = datetime.now(pytz.utc)

tz_NY = pytz.timezone('America/New_York')
datetime_NY = now_utc.astimezone(tz_NY)

tz_London = pytz.timezone('Europe/London')
datetime_London = now_utc.astimezone(tz_London)

for city, local_time in (("NY", datetime_NY), ("London", datetime_London)):
    print(f"{city} time:", local_time.strftime("%H:%M:%S"))

NY time: 05:52:25
London time: 10:52:25
