In [1]:
import hydra
import hydra.experimental
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaTokenizer
from omegaconf import DictConfig

from torchfly.text.decode import TransformerDecoder
from torchfly.common import set_random_seed
from torchfly.text.rl import TextRLRewardFunc

from configure_dataloader import DataLoaderHandler, TextRLCollator

from model import TextGAILModel
from textgail_trainerloop import TextGAILTrainerLoop 

import logging

In [2]:
# Load the pretrained "roberta-base" tokenizer. The same `tokenizer` object is
# later registered on the decoding helper and passed to the reward function.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

In [3]:
# Single named seed for the whole run so it can be changed in one place.
# NOTE(review): presumably `set_random_seed` seeds the Python/NumPy/torch RNGs —
# confirm in torchfly.common.
SEED = 123
set_random_seed(SEED)

In [4]:
# Configure the root logger: INFO and above, each record prefixed with the
# timestamp, originating file, line number and level name.
logging.basicConfig(
    format="%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s",
    level=logging.INFO,
)

# Logger for messages emitted from this notebook itself.
logger = logging.getLogger(__name__)

In [5]:
# Initialize hydra with "config" as its configuration location, then compose the
# run configuration from "config.yaml". The resulting `config` is passed to the
# dataloader handler, model, decoder, reward function and trainer below.
# NOTE(review): "config" is presumably resolved relative to this notebook — confirm.
# NOTE(review): re-running this cell in a live kernel may fail if hydra is
# already initialized — confirm against the installed hydra version.
hydra.experimental.initialize("config")
config = hydra.experimental.compose("config.yaml")

In [7]:
# Build the project's dataloader handler from the composed config, then create
# a training dataloader from it.
# NOTE(review): `train_dataloader` is not referenced by the later cells visible
# here; the trainer is given the bound method `dataloader_handler.train_dataloader`
# instead, so this eagerly-built instance may only be useful for inspection.
dataloader_handler = DataLoaderHandler(config)
train_dataloader = dataloader_handler.train_dataloader(config)

2020-05-04 11:48:56,587 - configuration_utils.py[line:256] - INFO: loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json from cache at /home/wuqy1203/.cache/torch/transformers/e1a2a406b5a05063c31f4dfdee7608986ba7c6393f7f79db5e69dcd197208534.a7ab0e5de2d8321d6d6a15b199110f2c99be72976b7d151423cb8d8c261a13b6
2020-05-04 11:48:56,588 - configuration_utils.py[line:292] - INFO: Model config RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": null,
  "do_sample": false,
  "eos_token_ids": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-05,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position

In [8]:
# Checkpoint used to warm-start the generator, named once so it is easy to find
# and repoint.
# NOTE(review): absolute, user-specific path — consider deriving it from the
# hydra config (e.g. alongside task.data_dir) so the notebook runs elsewhere.
MLE_CHECKPOINT_PATH = "/home/wuqy1203/Desktop/Projects/TextGAIL/Experiments/MLE/outputs/CommonGEN/Checkpoints/iter_2439_model_state.pth"

model = TextGAILModel(config)

# Deserialize onto the CPU first: without map_location, torch.load restores each
# tensor onto the device it was saved from, which fails if that device is
# unavailable and otherwise keeps a second copy of the weights on the GPU.
model_weights = torch.load(MLE_CHECKPOINT_PATH, map_location="cpu")

# strict=False tolerates key mismatches between checkpoint and generator; report
# them so a wrong or partially matching checkpoint does not go unnoticed.
load_result = model.generator.load_state_dict(model_weights, strict=False)
if load_result.missing_keys:
    logger.warning("Generator parameters missing from checkpoint: %s", load_result.missing_keys)
if load_result.unexpected_keys:
    logger.warning("Checkpoint entries not used by the generator: %s", load_result.unexpected_keys)

# Move the whole model (generator and discriminator) to the GPU for training.
model = model.cuda()

2020-05-04 11:48:58,588 - configuration_utils.py[line:256] - INFO: loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json from cache at /home/wuqy1203/.cache/torch/transformers/e1a2a406b5a05063c31f4dfdee7608986ba7c6393f7f79db5e69dcd197208534.a7ab0e5de2d8321d6d6a15b199110f2c99be72976b7d151423cb8d8c261a13b6
2020-05-04 11:48:58,589 - configuration_utils.py[line:292] - INFO: Model config RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": null,
  "do_sample": false,
  "eos_token_ids": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-05,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position

In [9]:
# Write the state_dict of `model` to "test.pth" in the current working directory.
# NOTE(review): looks like a leftover scratch/debug dump — no cell visible here
# reads the file back; confirm whether it is still needed.
torch.save(model.state_dict(), "test.pth")

In [10]:
# Decoding helper configured from the `decode` section of the composed config.
decoder = TransformerDecoder(config.decode)
# Generation runs through the generator's `decoder` sub-module.
decoder.register_generator(model.generator.decoder)
decoder.register_tokenizer(tokenizer)
# Replace the helper's input-preparation hook on this instance with the
# generator's own implementation.
decoder.prepare_model_inputs_for_generation = model.generator.prepare_model_inputs_for_generation

In [11]:
# Reward function for the trainer, built from the config, the tokenizer and the
# model's discriminator.
# NOTE(review): `TextGAILDiscriminator` is not imported anywhere in the import
# cell visible in this notebook (only `TextRLRewardFunc` is), so on a fresh
# kernel this line raises NameError. Add the missing project import — the
# defining module needs to be confirmed.
reward_func = TextGAILDiscriminator(config, tokenizer, model.discriminator)

In [12]:
# Wire the model, decoding helper and reward function into the training loop.
# The dataloaders are supplied as factory callables (bound methods of the
# handler), not as already-built dataloader objects. The saved output of this
# cell shows the trainer attempting to restore the latest checkpoint on
# construction.
trainer = TextGAILTrainerLoop(
    config=config,
    reward_func=reward_func,
    decoder_helper=decoder,
    model=model,
    train_dataloader_fn=dataloader_handler.train_dataloader,
    valid_dataloader_fn=dataloader_handler.valid_dataloader,
)

2020-05-04 11:49:04,334 - checkpoint.py[line:88] - INFO: Try to restore the latest checkpoint
2020-05-04 11:49:04,338 - textrl_log_handler.py[line:76] - INFO: decode:
  bos_token_ids:
  - 0
  do_sample: true
  early_stopping: true
  eos_token_ids:
  - 2
  length_penalty: 1.0
  max_steps: 100
  num_beams: 1
  num_return_sequences: 1
  output_log_probs: true
  repetition_penalty: 1.0
  retition_penalty: 1.0
  temperature: 0.9
  top_k: -1
  top_p: 0.9
model:
  attn_pdrop: 0.0
  embd_pdrop: 0.0
  initializer_range: 0.02
  layer_norm_epsilon: 1.0e-05
  n_ctx: 1024
  n_embd: 768
  n_head: 12
  n_layer: 12
  n_positions: 1024
  name: roberta-tokenized-gpt2
  output_attentions: false
  output_hidden_states: false
  output_past: true
  pad_token_id: 1
  resid_pdrop: 0.0
  vocab_size: 50265
task:
  data_dir: /home/wuqy1203/Desktop/Projects/TextGAIL/data/${task.name}
  name: CommonGEN
text_gail:
  batch_size: None
  constant_human_demo_reward: true
  discriminator_pretrain_steps: 200
  mix_human_

In [13]:
# Run the training loop.
# NOTE(review): the saved output of this cell shows execution stopping in pdb at
# torchfly/training/trainer.py(335) in get_trainer_state() and ending with
# BdbQuit — presumably a leftover breakpoint in the library. It has to be
# removed before this notebook can complete a Restart & Run All.
trainer.train()

2020-05-04 11:49:04,341 - textrl_log_handler.py[line:85] - INFO: Training Starts!
2020-05-04 11:49:06,690 - textrl_log_handler.py[line:177] - INFO: Train Steps - 0      - [ 0.0000%] - Speed:  0.0 - Discriminator/loss:  0.6936 - Mix_human_demo_ratio:  0.3996 - Mean_reward:  0.4999
2020-05-04 11:49:09,056 - textrl_log_handler.py[line:177] - INFO: Train Steps - 1      - [ 0.1000%] - Speed: 13.5 - Discriminator/loss:  0.6935 - Mix_human_demo_ratio:  0.3996 - Mean_reward:  0.4999
2020-05-04 11:49:11,180 - textrl_log_handler.py[line:177] - INFO: Train Steps - 2      - [ 0.2000%] - Speed: 15.1 - Discriminator/loss:  0.6929 - Mix_human_demo_ratio:  0.3995 - Mean_reward:  0.4999
2020-05-04 11:49:13,377 - textrl_log_handler.py[line:177] - INFO: Train Steps - 3      - [ 0.3000%] - Speed: 14.6 - Discriminator/loss:  0.6931 - Mix_human_demo_ratio:  0.3995 - Mean_reward:  0.5000


> /home/wuqy1203/Desktop/Projects/TorchFly/torchfly/training/trainer.py(335)get_trainer_state()
-> "epochs_trained": self.epochs_trained + 1,


(Pdb)  self.optimizers[0]


AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 1e-05
    lr: 8.000000000000001e-06
    weight_decay: 0.01

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 1e-05
    lr: 8.000000000000001e-06
    weight_decay: 0.0
)


(Pdb)  self.optimizers[1]


AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 1e-05
    lr: 0.0
    weight_decay: 0.01

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 1e-05
    lr: 0.0
    weight_decay: 0.0
)


(Pdb)  self.schedulers[0]


<torchfly.training.optimization.warmup_scheduler.WarmupLinearSchedule object at 0x7f16e242ad90>


(Pdb)  self.schedulers[1]


<torchfly.training.optimization.warmup_scheduler.WarmupLinearSchedule object at 0x7f16e242ae50>


(Pdb)  [optimizer for optimizer in self.optimizers]


[AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 1e-05
    lr: 8.000000000000001e-06
    weight_decay: 0.01

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 1e-05
    lr: 8.000000000000001e-06
    weight_decay: 0.0
), AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 1e-05
    lr: 0.0
    weight_decay: 0.01

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 1e-05
    lr: 0.0
    weight_decay: 0.0
)]


(Pdb)  q


BdbQuit: 