In [1]:
import os
import json 
import logging

logging.basicConfig(
    filename='log/app.log',            # Specify the log file name
    level=logging.DEBUG,           # Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format='%(asctime)s - %(levelname)s - %(message)s'  # Set the log format
)

# Load the environment configuration JSON data
json_path = 'env_config.json'
with open(json_path, 'r') as file:
    env_config = json.load(file)

hf_home = env_config['HF_HOME']
# Set the HF_HOME environment variable
os.environ['HF_HOME'] = hf_home
# Set the access token to huggingface hub
access_token = env_config['access_token']
os.environ['HUGGINGFACE_HUB_TOKEN'] = access_token

In [2]:
import transformers 
print(transformers.__version__)

from transformers import pipeline
import torch

from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from transformers import LlamaTokenizerFast
import torch.nn.functional as F


accelerator = Accelerator()
device = accelerator.device

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# model_id = "meta-llama/Meta-Llama-3-8B"  # non-instruct version

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=access_token,
)

tokenizer = LlamaTokenizerFast.from_pretrained(model_id, token=access_token)
tokenizer.pad_token = tokenizer.eos_token

  from .autonotebook import tqdm as notebook_tqdm


4.41.0


Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'PreTrainedTokenizerFast'. 
The class this function is called from is 'LlamaTokenizerFast'.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
from llmexp.helper import LlmExpHelper
from datasets import load_dataset
from torch.utils.data import DataLoader

# imdb = load_dataset("imdb")
ds = load_dataset("rajpurkar/squad")
train_ds = ds['train']
test_ds = ds['validation']
# from datasets import load_dataset

# ds = load_dataset("stanfordnlp/sst2")
# train_ds = ds['train']
llm_exp_helper = LlmExpHelper(tokenizer, 'squad')
collate_fn = llm_exp_helper.get_collate_fun()

# Define batch size here!
batch_size = 16
train_dataloader = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)

In [4]:
# next(iter(train_dataloader))
train_ds[0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}

In [5]:
from llmexp.squad_model import MaskGeneratingModel

mask_gen_model = MaskGeneratingModel(hidden_size=4096, mlp_hidden_dim=1024, mlp_bottleneck_dim=768, mlp_num_blocks=2)
mask_gen_model.to(device)


from tqdm import tqdm

# Set pad_token_id if it is not set
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

pad_token_id = tokenizer.pad_token_id

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

optimizer = torch.optim.Adam(mask_gen_model.parameters(), lr=1e-5)

In [13]:
mask_gen_model.train()
for epoch in range(1):
    pbar = tqdm(train_dataloader)
    for idx, data in enumerate(pbar):
        input_ids = data['input_ids'].to(device)
        attention_mask = data['attention_mask'].to(device)
        context_mask = data['context_mask'].to(device)
        # get generated texts
        gen_outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=128,
            eos_token_id=terminators,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            return_dict_in_generate=True,
            output_scores=True,
        )
        gen_tokens = gen_outputs.sequences
        pad_length = gen_tokens.size(1) - input_ids.size(1)
        # get the attention mask for the generated tokens, and also mask the padding tokens
        gen_attention_mask = F.pad(attention_mask, (0, pad_length), mode='constant', value=1)
        # (gen_tokens != pad_token_id).long() is the tokens mask, 1 for real tokens and 0 for padding tokens
        unpaded_token_mask = (gen_tokens != pad_token_id).long()
        unpaded_token_mask[:, :-pad_length] = 1
        gen_attention_mask = gen_attention_mask * unpaded_token_mask
        # get the response mask, which is the mask for the generated tokens (the user inputs are masked with 0)
        response_mask = gen_attention_mask.clone()
        response_mask[:, :-pad_length] = 0 # TODO: 有问题. 有问题吗？

        context_mask = F.pad(context_mask, (0, pad_length), mode='constant', value=0)

        # Get the last hidden state for the prompt + response sequence
        with torch.no_grad():
            full_outputs = model(input_ids=gen_tokens, attention_mask=gen_attention_mask, output_hidden_states=True, return_dict=True)
            last_hidden_state = full_outputs.hidden_states[-1]
            last_hidden_state = last_hidden_state.float()
        
        mask_logits = mask_gen_model(last_hidden_state)

        mask_gen_outputs = mask_gen_model.loss_func(model, gen_tokens, gen_attention_mask, context_mask, mask_logits, response_mask, 
                                                                           num_samples=5)
        loss, reward_loss, mask_loss, mask_mean, mean_reward = mask_gen_outputs['loss'], mask_gen_outputs['reward_loss'], mask_gen_outputs['mask_loss'], mask_gen_outputs['mask_mean'], mask_gen_outputs['mean_reward']
        log = (f"Epoch {epoch+1}, Step {idx+1}: Loss = {loss.item():.4f}, " 
                             f"Reward Loss = {reward_loss.item():.4f}, "
                             f"Mean Reward = {mean_reward.mean().item():.4f},"
                             f"Mask_loss = {mask_loss.item():.4f} "
                             f"mask_mean = {mask_mean.item():.4f}"
        )
        pbar.set_description(log)
        logging.debug(log)
    
        # the parameters before updating
        params_before = mask_gen_model.state_dict()

        # Train the model (inner loop)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # the mask_prob after the updates
        with torch.no_grad():
            mask_logits_after = mask_gen_model(last_hidden_state)

            mask_gen_outputs_after = mask_gen_model.loss_func(model, gen_tokens, gen_attention_mask, context_mask, mask_logits_after, response_mask, 
                                                                            num_samples=5)
            loss_after, reward_loss_after, mask_loss_after, mask_mean_after, mean_reward_after = mask_gen_outputs_after['loss'], mask_gen_outputs_after['reward_loss'], mask_gen_outputs_after['mask_loss'], mask_gen_outputs_after['mask_mean'], mask_gen_outputs_after['mean_reward']
            mask_prob_after = (torch.sigmoid(mask_logits_after) * context_mask).clone().detach()
            mean_reward_after = mean_reward_after.clone().detach()

        # load the parameters before the updates
        mask_gen_model.load_state_dict(params_before)
        mask_logits_before = mask_gen_model(last_hidden_state)

        mask_gen_outputs_before = mask_gen_model.loss_func(model, gen_tokens, gen_attention_mask, context_mask, mask_logits_before, response_mask, 
                                                                           num_samples=5)
        loss_before, reward_loss_before, mask_loss_before, mask_mean_before, mean_reward_before = mask_gen_outputs_before['loss'], mask_gen_outputs_before['reward_loss'], mask_gen_outputs_before['mask_loss'], mask_gen_outputs_before['mask_mean'], mask_gen_outputs_before['mean_reward']
        mask_prob_before = (torch.sigmoid(mask_logits_before) * context_mask)

        # calculate the ratio of the mask probabilities before and after the updates
        ratio = mask_prob_after / (mask_prob_before + 1e-6)

        # 定义PPO的损失函数，假设clip_param是你定义的剪切参数
        clip_param = 0.2
        advantage = (mean_reward_after - mean_reward_before).unsqueeze(-1)  # 计算优势函数（advantage），这是根据任务定义的
        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 1 - clip_param, 1 + clip_param) * advantage
        ppo_loss = -torch.min(surr1, surr2).mean()

        # 更新模型参数
        optimizer.zero_grad()
        ppo_loss.backward()
        optimizer.step()

        if idx % 10 == 0:
            print()
        if idx % 200 == 0 and idx != 0:
            torch.save(mask_gen_model.state_dict(), f'saved_model/mask_gen_model_{epoch}_{idx}.pth') 
            print()
            # break

Epoch 1, Step 1: Loss = 0.2883, Reward Loss = 0.2800, Mean Reward = 0.4903,Mask_loss = 0.1661 mask_mean = 0.5699:   0%|          | 1/5475 [00:10<15:13:49, 10.02s/it]




Epoch 1, Step 11: Loss = 0.1777, Reward Loss = 0.1748, Mean Reward = 0.3219,Mask_loss = 0.0578 mask_mean = 0.2001:   0%|          | 11/5475 [01:50<17:12:01, 11.33s/it]




Epoch 1, Step 21: Loss = 0.2929, Reward Loss = 0.2784, Mean Reward = 0.4837,Mask_loss = 0.2917 mask_mean = 0.7479:   0%|          | 21/5475 [03:41<17:32:33, 11.58s/it]




Epoch 1, Step 31: Loss = 0.2095, Reward Loss = 0.2071, Mean Reward = 0.3650,Mask_loss = 0.0488 mask_mean = 0.3160:   1%|          | 31/5475 [05:38<17:09:34, 11.35s/it]




Epoch 1, Step 41: Loss = 0.2735, Reward Loss = 0.2679, Mean Reward = 0.4563,Mask_loss = 0.1118 mask_mean = 0.5527:   1%|          | 41/5475 [07:24<15:30:27, 10.27s/it]




Epoch 1, Step 51: Loss = 0.2697, Reward Loss = 0.2588, Mean Reward = 0.4674,Mask_loss = 0.2172 mask_mean = 0.6638:   1%|          | 51/5475 [08:59<15:12:35, 10.10s/it]




Epoch 1, Step 61: Loss = 0.2454, Reward Loss = 0.2439, Mean Reward = 0.4322,Mask_loss = 0.0296 mask_mean = 0.3583:   1%|          | 61/5475 [10:29<13:58:37,  9.29s/it]




Epoch 1, Step 71: Loss = 0.2525, Reward Loss = 0.2491, Mean Reward = 0.4436,Mask_loss = 0.0687 mask_mean = 0.4041:   1%|▏         | 71/5475 [12:09<13:49:02,  9.20s/it]




Epoch 1, Step 81: Loss = 0.2504, Reward Loss = 0.2464, Mean Reward = 0.4339,Mask_loss = 0.0805 mask_mean = 0.4991:   1%|▏         | 81/5475 [13:54<15:15:24, 10.18s/it]




Epoch 1, Step 91: Loss = 0.2671, Reward Loss = 0.2615, Mean Reward = 0.4539,Mask_loss = 0.1102 mask_mean = 0.5355:   2%|▏         | 91/5475 [15:37<15:53:30, 10.63s/it]




Epoch 1, Step 101: Loss = 0.2407, Reward Loss = 0.2385, Mean Reward = 0.4119,Mask_loss = 0.0432 mask_mean = 0.3322:   2%|▏         | 101/5475 [17:20<15:45:11, 10.55s/it]




Epoch 1, Step 111: Loss = 0.2347, Reward Loss = 0.2336, Mean Reward = 0.4180,Mask_loss = 0.0222 mask_mean = 0.2967:   2%|▏         | 111/5475 [18:50<12:27:09,  8.36s/it]




Epoch 1, Step 121: Loss = 0.2859, Reward Loss = 0.2798, Mean Reward = 0.4841,Mask_loss = 0.1229 mask_mean = 0.5403:   2%|▏         | 121/5475 [20:27<12:56:56,  8.71s/it]




Epoch 1, Step 131: Loss = 0.2250, Reward Loss = 0.2235, Mean Reward = 0.3937,Mask_loss = 0.0311 mask_mean = 0.3632:   2%|▏         | 131/5475 [22:07<15:36:22, 10.51s/it]




Epoch 1, Step 141: Loss = 0.2622, Reward Loss = 0.2589, Mean Reward = 0.4391,Mask_loss = 0.0642 mask_mean = 0.4520:   3%|▎         | 141/5475 [23:46<15:14:26, 10.29s/it]




Epoch 1, Step 151: Loss = 0.2553, Reward Loss = 0.2512, Mean Reward = 0.4221,Mask_loss = 0.0811 mask_mean = 0.4980:   3%|▎         | 151/5475 [25:25<15:09:00, 10.24s/it]




Epoch 1, Step 161: Loss = 0.2042, Reward Loss = 0.2023, Mean Reward = 0.3587,Mask_loss = 0.0379 mask_mean = 0.3316:   3%|▎         | 161/5475 [27:09<14:29:12,  9.81s/it]




Epoch 1, Step 171: Loss = 0.2568, Reward Loss = 0.2552, Mean Reward = 0.4437,Mask_loss = 0.0327 mask_mean = 0.3692:   3%|▎         | 171/5475 [28:49<15:44:08, 10.68s/it]




Epoch 1, Step 181: Loss = 0.2159, Reward Loss = 0.2147, Mean Reward = 0.3819,Mask_loss = 0.0240 mask_mean = 0.3148:   3%|▎         | 181/5475 [30:37<14:52:23, 10.11s/it]




Epoch 1, Step 191: Loss = 0.2739, Reward Loss = 0.2689, Mean Reward = 0.4550,Mask_loss = 0.0999 mask_mean = 0.5226:   3%|▎         | 191/5475 [32:07<13:22:40,  9.11s/it]




Epoch 1, Step 201: Loss = 0.2892, Reward Loss = 0.2803, Mean Reward = 0.4839,Mask_loss = 0.1778 mask_mean = 0.5968:   4%|▎         | 200/5475 [33:34<13:01:57,  8.89s/it]




Epoch 1, Step 201: Loss = 0.2892, Reward Loss = 0.2803, Mean Reward = 0.4839,Mask_loss = 0.1778 mask_mean = 0.5968:   4%|▎         | 201/5475 [33:39<13:25:53,  9.17s/it]




Epoch 1, Step 211: Loss = 0.2911, Reward Loss = 0.2849, Mean Reward = 0.4871,Mask_loss = 0.1256 mask_mean = 0.5468:   4%|▍         | 211/5475 [35:16<14:22:07,  9.83s/it]




Epoch 1, Step 221: Loss = 0.1883, Reward Loss = 0.1859, Mean Reward = 0.3385,Mask_loss = 0.0487 mask_mean = 0.2511:   4%|▍         | 221/5475 [37:04<14:20:49,  9.83s/it]




Epoch 1, Step 231: Loss = 0.2768, Reward Loss = 0.2687, Mean Reward = 0.4592,Mask_loss = 0.1624 mask_mean = 0.6096:   4%|▍         | 231/5475 [38:38<13:21:22,  9.17s/it]




Epoch 1, Step 241: Loss = 0.2917, Reward Loss = 0.2854, Mean Reward = 0.4800,Mask_loss = 0.1254 mask_mean = 0.5477:   4%|▍         | 241/5475 [40:11<13:10:36,  9.06s/it]




Epoch 1, Step 251: Loss = 0.2725, Reward Loss = 0.2684, Mean Reward = 0.4482,Mask_loss = 0.0820 mask_mean = 0.4777:   5%|▍         | 251/5475 [41:49<13:15:22,  9.14s/it]




Epoch 1, Step 261: Loss = 0.2981, Reward Loss = 0.2942, Mean Reward = 0.4824,Mask_loss = 0.0784 mask_mean = 0.4707:   5%|▍         | 261/5475 [43:19<12:32:11,  8.66s/it]




Epoch 1, Step 271: Loss = 0.1955, Reward Loss = 0.1931, Mean Reward = 0.3507,Mask_loss = 0.0484 mask_mean = 0.2532:   5%|▍         | 271/5475 [45:05<15:52:23, 10.98s/it]




Epoch 1, Step 281: Loss = 0.1724, Reward Loss = 0.1692, Mean Reward = 0.3204,Mask_loss = 0.0646 mask_mean = 0.1879:   5%|▌         | 281/5475 [47:00<18:41:14, 12.95s/it]




Epoch 1, Step 291: Loss = 0.2825, Reward Loss = 0.2794, Mean Reward = 0.4695,Mask_loss = 0.0624 mask_mean = 0.4439:   5%|▌         | 291/5475 [48:45<14:49:46, 10.30s/it]




Epoch 1, Step 301: Loss = 0.2989, Reward Loss = 0.2934, Mean Reward = 0.4777,Mask_loss = 0.1099 mask_mean = 0.5376:   5%|▌         | 301/5475 [50:27<14:45:14, 10.27s/it]




Epoch 1, Step 311: Loss = 0.2770, Reward Loss = 0.2719, Mean Reward = 0.4607,Mask_loss = 0.1022 mask_mean = 0.4977:   6%|▌         | 311/5475 [52:05<13:15:32,  9.24s/it]




Epoch 1, Step 321: Loss = 0.2706, Reward Loss = 0.2678, Mean Reward = 0.4586,Mask_loss = 0.0555 mask_mean = 0.4089:   6%|▌         | 321/5475 [53:33<12:43:36,  8.89s/it]




Epoch 1, Step 331: Loss = 0.2927, Reward Loss = 0.2860, Mean Reward = 0.4774,Mask_loss = 0.1331 mask_mean = 0.5776:   6%|▌         | 331/5475 [55:08<13:38:20,  9.55s/it]




Epoch 1, Step 341: Loss = 0.2547, Reward Loss = 0.2527, Mean Reward = 0.4298,Mask_loss = 0.0403 mask_mean = 0.3758:   6%|▌         | 341/5475 [56:53<13:36:23,  9.54s/it]




Epoch 1, Step 351: Loss = 0.2687, Reward Loss = 0.2645, Mean Reward = 0.4469,Mask_loss = 0.0837 mask_mean = 0.4535:   6%|▋         | 351/5475 [58:41<15:51:57, 11.15s/it]




Epoch 1, Step 361: Loss = 0.1802, Reward Loss = 0.1774, Mean Reward = 0.3308,Mask_loss = 0.0558 mask_mean = 0.2073:   7%|▋         | 361/5475 [1:00:16<14:08:59,  9.96s/it]




Epoch 1, Step 371: Loss = 0.2876, Reward Loss = 0.2827, Mean Reward = 0.4660,Mask_loss = 0.0967 mask_mean = 0.5118:   7%|▋         | 371/5475 [1:02:02<14:25:35, 10.18s/it]




Epoch 1, Step 381: Loss = 0.2628, Reward Loss = 0.2615, Mean Reward = 0.4416,Mask_loss = 0.0260 mask_mean = 0.3641:   7%|▋         | 381/5475 [1:03:45<14:23:11, 10.17s/it]




Epoch 1, Step 391: Loss = 0.2540, Reward Loss = 0.2509, Mean Reward = 0.4192,Mask_loss = 0.0617 mask_mean = 0.4413:   7%|▋         | 391/5475 [1:05:19<12:37:14,  8.94s/it]




Epoch 1, Step 401: Loss = 0.1996, Reward Loss = 0.1987, Mean Reward = 0.3627,Mask_loss = 0.0186 mask_mean = 0.2622:   7%|▋         | 400/5475 [1:06:52<13:31:33,  9.59s/it]




Epoch 1, Step 401: Loss = 0.1996, Reward Loss = 0.1987, Mean Reward = 0.3627,Mask_loss = 0.0186 mask_mean = 0.2622:   7%|▋         | 401/5475 [1:06:58<13:45:33,  9.76s/it]




Epoch 1, Step 411: Loss = 0.1612, Reward Loss = 0.1578, Mean Reward = 0.2874,Mask_loss = 0.0682 mask_mean = 0.2394:   8%|▊         | 411/5475 [1:08:32<13:43:23,  9.76s/it]




Epoch 1, Step 421: Loss = 0.2595, Reward Loss = 0.2576, Mean Reward = 0.4419,Mask_loss = 0.0378 mask_mean = 0.3643:   8%|▊         | 421/5475 [1:10:10<13:11:05,  9.39s/it]




Epoch 1, Step 431: Loss = 0.2625, Reward Loss = 0.2615, Mean Reward = 0.4564,Mask_loss = 0.0194 mask_mean = 0.3302:   8%|▊         | 431/5475 [1:11:49<13:01:35,  9.30s/it]




Epoch 1, Step 441: Loss = 0.2824, Reward Loss = 0.2726, Mean Reward = 0.4676,Mask_loss = 0.1976 mask_mean = 0.6672:   8%|▊         | 441/5475 [1:13:31<13:31:18,  9.67s/it]




Epoch 1, Step 451: Loss = 0.2859, Reward Loss = 0.2790, Mean Reward = 0.4634,Mask_loss = 0.1374 mask_mean = 0.5813:   8%|▊         | 451/5475 [1:15:07<13:22:27,  9.58s/it]




Epoch 1, Step 461: Loss = 0.2121, Reward Loss = 0.2103, Mean Reward = 0.3634,Mask_loss = 0.0370 mask_mean = 0.3581:   8%|▊         | 461/5475 [1:17:07<14:56:17, 10.73s/it]




Epoch 1, Step 471: Loss = 0.2682, Reward Loss = 0.2660, Mean Reward = 0.4453,Mask_loss = 0.0447 mask_mean = 0.3793:   9%|▊         | 471/5475 [1:18:37<12:58:05,  9.33s/it]




Epoch 1, Step 481: Loss = 0.3021, Reward Loss = 0.2930, Mean Reward = 0.5000,Mask_loss = 0.1808 mask_mean = 0.6181:   9%|▉         | 481/5475 [1:20:10<13:09:27,  9.48s/it]




Epoch 1, Step 491: Loss = 0.2109, Reward Loss = 0.2095, Mean Reward = 0.3742,Mask_loss = 0.0290 mask_mean = 0.2714:   9%|▉         | 491/5475 [1:21:50<13:13:08,  9.55s/it]




Epoch 1, Step 501: Loss = 0.2062, Reward Loss = 0.2048, Mean Reward = 0.3733,Mask_loss = 0.0278 mask_mean = 0.2863:   9%|▉         | 501/5475 [1:23:32<13:33:21,  9.81s/it]




Epoch 1, Step 511: Loss = 0.2242, Reward Loss = 0.2229, Mean Reward = 0.4058,Mask_loss = 0.0259 mask_mean = 0.2827:   9%|▉         | 511/5475 [1:25:05<12:51:55,  9.33s/it]




Epoch 1, Step 521: Loss = 0.2072, Reward Loss = 0.2056, Mean Reward = 0.3744,Mask_loss = 0.0308 mask_mean = 0.2625:  10%|▉         | 521/5475 [1:26:46<13:24:02,  9.74s/it]




Epoch 1, Step 531: Loss = 0.2814, Reward Loss = 0.2711, Mean Reward = 0.4773,Mask_loss = 0.2052 mask_mean = 0.6522:  10%|▉         | 531/5475 [1:28:29<15:32:20, 11.31s/it]




Epoch 1, Step 541: Loss = 0.2739, Reward Loss = 0.2611, Mean Reward = 0.4750,Mask_loss = 0.2566 mask_mean = 0.7214:  10%|▉         | 541/5475 [1:30:19<15:15:20, 11.13s/it]




Epoch 1, Step 551: Loss = 0.2893, Reward Loss = 0.2756, Mean Reward = 0.4971,Mask_loss = 0.2728 mask_mean = 0.7211:  10%|█         | 551/5475 [1:32:05<14:09:36, 10.35s/it]




Epoch 1, Step 561: Loss = 0.2077, Reward Loss = 0.2061, Mean Reward = 0.3573,Mask_loss = 0.0324 mask_mean = 0.2907:  10%|█         | 561/5475 [1:33:45<13:40:38, 10.02s/it]




Epoch 1, Step 571: Loss = 0.2898, Reward Loss = 0.2833, Mean Reward = 0.4670,Mask_loss = 0.1295 mask_mean = 0.5747:  10%|█         | 571/5475 [1:35:24<13:55:17, 10.22s/it]




Epoch 1, Step 581: Loss = 0.2951, Reward Loss = 0.2914, Mean Reward = 0.4714,Mask_loss = 0.0737 mask_mean = 0.4858:  11%|█         | 581/5475 [1:37:08<14:04:21, 10.35s/it]




Epoch 1, Step 591: Loss = 0.2818, Reward Loss = 0.2786, Mean Reward = 0.4573,Mask_loss = 0.0635 mask_mean = 0.4653:  11%|█         | 591/5475 [1:38:54<15:39:08, 11.54s/it]




Epoch 1, Step 601: Loss = 0.2490, Reward Loss = 0.2475, Mean Reward = 0.4188,Mask_loss = 0.0290 mask_mean = 0.3620:  11%|█         | 600/5475 [1:40:26<13:58:05, 10.31s/it]




Epoch 1, Step 601: Loss = 0.2490, Reward Loss = 0.2475, Mean Reward = 0.4188,Mask_loss = 0.0290 mask_mean = 0.3620:  11%|█         | 601/5475 [1:40:31<14:00:57, 10.35s/it]




Epoch 1, Step 611: Loss = 0.2717, Reward Loss = 0.2698, Mean Reward = 0.4442,Mask_loss = 0.0370 mask_mean = 0.4192:  11%|█         | 611/5475 [1:42:15<14:27:50, 10.71s/it]




Epoch 1, Step 621: Loss = 0.2805, Reward Loss = 0.2762, Mean Reward = 0.4489,Mask_loss = 0.0843 mask_mean = 0.5066:  11%|█▏        | 621/5475 [1:43:50<11:41:18,  8.67s/it]




Epoch 1, Step 631: Loss = 0.2895, Reward Loss = 0.2869, Mean Reward = 0.4692,Mask_loss = 0.0522 mask_mean = 0.4325:  12%|█▏        | 631/5475 [1:45:28<13:41:53, 10.18s/it]




Epoch 1, Step 641: Loss = 0.2958, Reward Loss = 0.2901, Mean Reward = 0.4772,Mask_loss = 0.1140 mask_mean = 0.5411:  12%|█▏        | 641/5475 [1:47:03<12:13:32,  9.10s/it]




Epoch 1, Step 651: Loss = 0.2944, Reward Loss = 0.2901, Mean Reward = 0.4750,Mask_loss = 0.0858 mask_mean = 0.4974:  12%|█▏        | 651/5475 [1:48:39<12:45:13,  9.52s/it]




Epoch 1, Step 661: Loss = 0.1549, Reward Loss = 0.1517, Mean Reward = 0.2870,Mask_loss = 0.0629 mask_mean = 0.2053:  12%|█▏        | 661/5475 [1:50:20<14:44:52, 11.03s/it]




Epoch 1, Step 671: Loss = 0.2442, Reward Loss = 0.2430, Mean Reward = 0.4198,Mask_loss = 0.0234 mask_mean = 0.3410:  12%|█▏        | 671/5475 [1:52:02<13:23:22, 10.03s/it]




Epoch 1, Step 681: Loss = 0.2881, Reward Loss = 0.2843, Mean Reward = 0.4671,Mask_loss = 0.0772 mask_mean = 0.4841:  12%|█▏        | 681/5475 [1:53:43<13:35:36, 10.21s/it]




Epoch 1, Step 691: Loss = 0.2873, Reward Loss = 0.2820, Mean Reward = 0.4700,Mask_loss = 0.1055 mask_mean = 0.5338:  13%|█▎        | 691/5475 [1:55:20<14:37:36, 11.01s/it]




Epoch 1, Step 701: Loss = 0.2630, Reward Loss = 0.2614, Mean Reward = 0.4376,Mask_loss = 0.0309 mask_mean = 0.3934:  13%|█▎        | 701/5475 [1:57:03<13:57:26, 10.52s/it]




Epoch 1, Step 711: Loss = 0.2847, Reward Loss = 0.2807, Mean Reward = 0.4566,Mask_loss = 0.0800 mask_mean = 0.4756:  13%|█▎        | 711/5475 [1:58:40<13:49:32, 10.45s/it]




Epoch 1, Step 721: Loss = 0.2932, Reward Loss = 0.2901, Mean Reward = 0.4712,Mask_loss = 0.0626 mask_mean = 0.4536:  13%|█▎        | 721/5475 [2:00:18<13:13:26, 10.01s/it]




Epoch 1, Step 731: Loss = 0.2708, Reward Loss = 0.2660, Mean Reward = 0.4307,Mask_loss = 0.0973 mask_mean = 0.5383:  13%|█▎        | 731/5475 [2:02:08<14:54:15, 11.31s/it]




Epoch 1, Step 741: Loss = 0.2499, Reward Loss = 0.2482, Mean Reward = 0.4127,Mask_loss = 0.0346 mask_mean = 0.4095:  14%|█▎        | 741/5475 [2:03:51<14:12:13, 10.80s/it]




Epoch 1, Step 751: Loss = 0.2897, Reward Loss = 0.2862, Mean Reward = 0.4665,Mask_loss = 0.0685 mask_mean = 0.4668:  14%|█▎        | 751/5475 [2:05:25<11:44:27,  8.95s/it]




Epoch 1, Step 761: Loss = 0.2887, Reward Loss = 0.2840, Mean Reward = 0.4723,Mask_loss = 0.0924 mask_mean = 0.5063:  14%|█▍        | 761/5475 [2:07:14<15:21:57, 11.73s/it]




Epoch 1, Step 771: Loss = 0.2983, Reward Loss = 0.2857, Mean Reward = 0.4992,Mask_loss = 0.2524 mask_mean = 0.6950:  14%|█▍        | 771/5475 [2:08:53<12:21:17,  9.46s/it]




Epoch 1, Step 781: Loss = 0.2780, Reward Loss = 0.2757, Mean Reward = 0.4522,Mask_loss = 0.0472 mask_mean = 0.4478:  14%|█▍        | 781/5475 [2:10:29<12:03:18,  9.25s/it]




Epoch 1, Step 791: Loss = 0.2669, Reward Loss = 0.2641, Mean Reward = 0.4328,Mask_loss = 0.0559 mask_mean = 0.4523:  14%|█▍        | 791/5475 [2:12:16<15:13:13, 11.70s/it]




Epoch 1, Step 801: Loss = 0.2889, Reward Loss = 0.2814, Mean Reward = 0.4701,Mask_loss = 0.1491 mask_mean = 0.6033:  15%|█▍        | 800/5475 [2:13:57<13:15:39, 10.21s/it]




Epoch 1, Step 801: Loss = 0.2889, Reward Loss = 0.2814, Mean Reward = 0.4701,Mask_loss = 0.1491 mask_mean = 0.6033:  15%|█▍        | 801/5475 [2:14:02<12:59:56, 10.01s/it]




Epoch 1, Step 811: Loss = 0.2830, Reward Loss = 0.2778, Mean Reward = 0.4509,Mask_loss = 0.1044 mask_mean = 0.5321:  15%|█▍        | 811/5475 [2:15:37<12:08:23,  9.37s/it]




Epoch 1, Step 821: Loss = 0.2430, Reward Loss = 0.2417, Mean Reward = 0.4080,Mask_loss = 0.0261 mask_mean = 0.3838:  15%|█▍        | 821/5475 [2:17:20<13:06:09, 10.14s/it]




Epoch 1, Step 831: Loss = 0.2657, Reward Loss = 0.2647, Mean Reward = 0.4466,Mask_loss = 0.0203 mask_mean = 0.3593:  15%|█▌        | 831/5475 [2:18:53<11:50:18,  9.18s/it]




Epoch 1, Step 841: Loss = 0.2264, Reward Loss = 0.2252, Mean Reward = 0.3864,Mask_loss = 0.0248 mask_mean = 0.3596:  15%|█▌        | 841/5475 [2:20:32<14:05:12, 10.94s/it]




Epoch 1, Step 851: Loss = 0.2546, Reward Loss = 0.2524, Mean Reward = 0.4148,Mask_loss = 0.0427 mask_mean = 0.4146:  16%|█▌        | 851/5475 [2:22:02<12:19:35,  9.60s/it]




Epoch 1, Step 861: Loss = 0.1522, Reward Loss = 0.1484, Mean Reward = 0.2901,Mask_loss = 0.0760 mask_mean = 0.1649:  16%|█▌        | 861/5475 [2:23:44<12:54:13, 10.07s/it]




Epoch 1, Step 871: Loss = 0.2672, Reward Loss = 0.2642, Mean Reward = 0.4349,Mask_loss = 0.0607 mask_mean = 0.4369:  16%|█▌        | 871/5475 [2:25:26<14:07:06, 11.04s/it]




Epoch 1, Step 881: Loss = 0.2163, Reward Loss = 0.2147, Mean Reward = 0.3670,Mask_loss = 0.0310 mask_mean = 0.3506:  16%|█▌        | 881/5475 [2:26:59<11:54:56,  9.34s/it]




Epoch 1, Step 891: Loss = 0.2796, Reward Loss = 0.2735, Mean Reward = 0.4575,Mask_loss = 0.1203 mask_mean = 0.5678:  16%|█▋        | 891/5475 [2:28:28<11:53:23,  9.34s/it]




Epoch 1, Step 901: Loss = 0.2879, Reward Loss = 0.2819, Mean Reward = 0.4697,Mask_loss = 0.1209 mask_mean = 0.5496:  16%|█▋        | 901/5475 [2:30:06<12:20:12,  9.71s/it]




Epoch 1, Step 911: Loss = 0.2822, Reward Loss = 0.2789, Mean Reward = 0.4594,Mask_loss = 0.0663 mask_mean = 0.4563:  17%|█▋        | 911/5475 [2:31:46<13:15:06, 10.45s/it]




Epoch 1, Step 921: Loss = 0.2308, Reward Loss = 0.2298, Mean Reward = 0.4005,Mask_loss = 0.0205 mask_mean = 0.2964:  17%|█▋        | 921/5475 [2:33:29<13:22:41, 10.58s/it]




Epoch 1, Step 931: Loss = 0.1613, Reward Loss = 0.1580, Mean Reward = 0.3141,Mask_loss = 0.0653 mask_mean = 0.1614:  17%|█▋        | 931/5475 [2:35:06<12:30:27,  9.91s/it]




Epoch 1, Step 941: Loss = 0.2152, Reward Loss = 0.2136, Mean Reward = 0.3946,Mask_loss = 0.0332 mask_mean = 0.2495:  17%|█▋        | 941/5475 [2:36:55<13:30:51, 10.73s/it]




Epoch 1, Step 951: Loss = 0.2725, Reward Loss = 0.2702, Mean Reward = 0.4504,Mask_loss = 0.0447 mask_mean = 0.3986:  17%|█▋        | 951/5475 [2:38:38<12:52:46, 10.25s/it]




Epoch 1, Step 961: Loss = 0.2609, Reward Loss = 0.2579, Mean Reward = 0.4193,Mask_loss = 0.0583 mask_mean = 0.4573:  18%|█▊        | 961/5475 [2:40:24<13:16:19, 10.58s/it]




Epoch 1, Step 971: Loss = 0.2642, Reward Loss = 0.2619, Mean Reward = 0.4381,Mask_loss = 0.0460 mask_mean = 0.4062:  18%|█▊        | 971/5475 [2:41:58<11:16:02,  9.01s/it]




Epoch 1, Step 981: Loss = 0.2118, Reward Loss = 0.2098, Mean Reward = 0.3529,Mask_loss = 0.0405 mask_mean = 0.3485:  18%|█▊        | 981/5475 [2:43:34<12:27:25,  9.98s/it]




Epoch 1, Step 991: Loss = 0.2825, Reward Loss = 0.2783, Mean Reward = 0.4533,Mask_loss = 0.0844 mask_mean = 0.5003:  18%|█▊        | 991/5475 [2:45:16<13:46:50, 11.06s/it]




Epoch 1, Step 1001: Loss = 0.2781, Reward Loss = 0.2639, Mean Reward = 0.4871,Mask_loss = 0.2827 mask_mean = 0.7376:  18%|█▊        | 1000/5475 [2:46:47<11:45:03,  9.45s/it]




Epoch 1, Step 1001: Loss = 0.2781, Reward Loss = 0.2639, Mean Reward = 0.4871,Mask_loss = 0.2827 mask_mean = 0.7376:  18%|█▊        | 1001/5475 [2:46:53<12:22:35,  9.96s/it]




Epoch 1, Step 1011: Loss = 0.3062, Reward Loss = 0.3021, Mean Reward = 0.4905,Mask_loss = 0.0817 mask_mean = 0.4752:  18%|█▊        | 1011/5475 [2:48:26<12:27:16, 10.04s/it]




Epoch 1, Step 1021: Loss = 0.2022, Reward Loss = 0.2002, Mean Reward = 0.3487,Mask_loss = 0.0403 mask_mean = 0.2906:  19%|█▊        | 1021/5475 [2:50:01<11:44:19,  9.49s/it]




Epoch 1, Step 1031: Loss = 0.2213, Reward Loss = 0.2203, Mean Reward = 0.3852,Mask_loss = 0.0189 mask_mean = 0.3009:  19%|█▉        | 1031/5475 [2:51:39<13:23:35, 10.85s/it]




Epoch 1, Step 1041: Loss = 0.2282, Reward Loss = 0.2267, Mean Reward = 0.3982,Mask_loss = 0.0285 mask_mean = 0.3302:  19%|█▉        | 1041/5475 [2:53:18<12:17:52,  9.98s/it]




Epoch 1, Step 1051: Loss = 0.2932, Reward Loss = 0.2898, Mean Reward = 0.4680,Mask_loss = 0.0675 mask_mean = 0.4740:  19%|█▉        | 1051/5475 [2:55:00<12:47:15, 10.41s/it]




Epoch 1, Step 1061: Loss = 0.2640, Reward Loss = 0.2622, Mean Reward = 0.4445,Mask_loss = 0.0366 mask_mean = 0.3510:  19%|█▉        | 1061/5475 [2:56:33<10:57:12,  8.93s/it]




Epoch 1, Step 1071: Loss = 0.2353, Reward Loss = 0.2345, Mean Reward = 0.4024,Mask_loss = 0.0156 mask_mean = 0.3229:  20%|█▉        | 1071/5475 [2:58:14<11:27:39,  9.37s/it]




Epoch 1, Step 1081: Loss = 0.2251, Reward Loss = 0.2239, Mean Reward = 0.3899,Mask_loss = 0.0252 mask_mean = 0.2975:  20%|█▉        | 1081/5475 [2:59:47<10:56:56,  8.97s/it]




Epoch 1, Step 1091: Loss = 0.2944, Reward Loss = 0.2873, Mean Reward = 0.4668,Mask_loss = 0.1428 mask_mean = 0.6027:  20%|█▉        | 1091/5475 [3:01:24<12:02:16,  9.89s/it]




Epoch 1, Step 1101: Loss = 0.2971, Reward Loss = 0.2876, Mean Reward = 0.4870,Mask_loss = 0.1898 mask_mean = 0.6291:  20%|██        | 1101/5475 [3:03:01<11:19:49,  9.33s/it]




Epoch 1, Step 1111: Loss = 0.3040, Reward Loss = 0.2965, Mean Reward = 0.4901,Mask_loss = 0.1512 mask_mean = 0.5789:  20%|██        | 1111/5475 [3:04:44<12:54:31, 10.65s/it]




Epoch 1, Step 1121: Loss = 0.2683, Reward Loss = 0.2659, Mean Reward = 0.4405,Mask_loss = 0.0471 mask_mean = 0.4081:  20%|██        | 1121/5475 [3:06:14<10:52:25,  8.99s/it]




Epoch 1, Step 1131: Loss = 0.2733, Reward Loss = 0.2715, Mean Reward = 0.4474,Mask_loss = 0.0354 mask_mean = 0.4138:  21%|██        | 1131/5475 [3:07:45<10:33:14,  8.75s/it]




Epoch 1, Step 1141: Loss = 0.2747, Reward Loss = 0.2730, Mean Reward = 0.4542,Mask_loss = 0.0323 mask_mean = 0.3842:  21%|██        | 1141/5475 [3:09:21<11:18:29,  9.39s/it]




Epoch 1, Step 1151: Loss = 0.3059, Reward Loss = 0.3029, Mean Reward = 0.4867,Mask_loss = 0.0599 mask_mean = 0.4452:  21%|██        | 1151/5475 [3:10:58<10:53:01,  9.06s/it]




Epoch 1, Step 1161: Loss = 0.3077, Reward Loss = 0.3034, Mean Reward = 0.4908,Mask_loss = 0.0872 mask_mean = 0.4820:  21%|██        | 1161/5475 [3:12:33<10:49:03,  9.03s/it]




Epoch 1, Step 1171: Loss = 0.2699, Reward Loss = 0.2683, Mean Reward = 0.4486,Mask_loss = 0.0321 mask_mean = 0.3614:  21%|██▏       | 1171/5475 [3:14:12<11:09:31,  9.33s/it]




Epoch 1, Step 1181: Loss = 0.2433, Reward Loss = 0.2422, Mean Reward = 0.4218,Mask_loss = 0.0201 mask_mean = 0.3126:  22%|██▏       | 1181/5475 [3:15:56<12:43:10, 10.66s/it]




Epoch 1, Step 1191: Loss = 0.1832, Reward Loss = 0.1809, Mean Reward = 0.3388,Mask_loss = 0.0464 mask_mean = 0.2051:  22%|██▏       | 1191/5475 [3:17:41<12:56:52, 10.88s/it]




Epoch 1, Step 1201: Loss = 0.1983, Reward Loss = 0.1965, Mean Reward = 0.3618,Mask_loss = 0.0352 mask_mean = 0.2464:  22%|██▏       | 1200/5475 [3:19:15<11:12:37,  9.44s/it]




Epoch 1, Step 1201: Loss = 0.1983, Reward Loss = 0.1965, Mean Reward = 0.3618,Mask_loss = 0.0352 mask_mean = 0.2464:  22%|██▏       | 1201/5475 [3:19:24<13:44:06, 11.57s/it]




Epoch 1, Step 1211: Loss = 0.1968, Reward Loss = 0.1957, Mean Reward = 0.3603,Mask_loss = 0.0217 mask_mean = 0.2347:  22%|██▏       | 1211/5475 [3:20:55<11:25:49,  9.65s/it]




Epoch 1, Step 1221: Loss = 0.3012, Reward Loss = 0.2980, Mean Reward = 0.4690,Mask_loss = 0.0645 mask_mean = 0.4652:  22%|██▏       | 1221/5475 [3:22:25<9:53:44,  8.37s/it] 




Epoch 1, Step 1231: Loss = 0.2932, Reward Loss = 0.2896, Mean Reward = 0.4554,Mask_loss = 0.0729 mask_mean = 0.4795:  22%|██▏       | 1231/5475 [3:24:03<11:59:20, 10.17s/it]




Epoch 1, Step 1241: Loss = 0.2541, Reward Loss = 0.2526, Mean Reward = 0.4283,Mask_loss = 0.0285 mask_mean = 0.3376:  23%|██▎       | 1241/5475 [3:25:43<12:49:34, 10.91s/it]




Epoch 1, Step 1251: Loss = 0.2341, Reward Loss = 0.2328, Mean Reward = 0.3923,Mask_loss = 0.0263 mask_mean = 0.3457:  23%|██▎       | 1251/5475 [3:27:20<12:18:04, 10.48s/it]




Epoch 1, Step 1261: Loss = 0.1752, Reward Loss = 0.1731, Mean Reward = 0.3308,Mask_loss = 0.0419 mask_mean = 0.2113:  23%|██▎       | 1261/5475 [3:29:05<12:31:50, 10.71s/it]




Epoch 1, Step 1271: Loss = 0.2109, Reward Loss = 0.2100, Mean Reward = 0.3892,Mask_loss = 0.0183 mask_mean = 0.2263:  23%|██▎       | 1271/5475 [3:30:45<11:14:44,  9.63s/it]




Epoch 1, Step 1281: Loss = 0.1901, Reward Loss = 0.1878, Mean Reward = 0.3449,Mask_loss = 0.0458 mask_mean = 0.2269:  23%|██▎       | 1281/5475 [3:32:20<10:51:01,  9.31s/it]




Epoch 1, Step 1291: Loss = 0.3065, Reward Loss = 0.3033, Mean Reward = 0.4978,Mask_loss = 0.0635 mask_mean = 0.4311:  24%|██▎       | 1291/5475 [3:33:59<10:47:22,  9.28s/it]




Epoch 1, Step 1301: Loss = 0.3010, Reward Loss = 0.2956, Mean Reward = 0.4805,Mask_loss = 0.1083 mask_mean = 0.5400:  24%|██▍       | 1301/5475 [3:35:49<12:15:52, 10.58s/it]




Epoch 1, Step 1311: Loss = 0.3020, Reward Loss = 0.2929, Mean Reward = 0.4937,Mask_loss = 0.1812 mask_mean = 0.6158:  24%|██▍       | 1311/5475 [3:37:32<11:37:19, 10.05s/it]




Epoch 1, Step 1321: Loss = 0.2549, Reward Loss = 0.2534, Mean Reward = 0.4172,Mask_loss = 0.0303 mask_mean = 0.3943:  24%|██▍       | 1321/5475 [3:39:17<12:01:10, 10.42s/it]




Epoch 1, Step 1331: Loss = 0.2913, Reward Loss = 0.2891, Mean Reward = 0.4666,Mask_loss = 0.0445 mask_mean = 0.4132:  24%|██▍       | 1331/5475 [3:40:57<10:39:44,  9.26s/it]




Epoch 1, Step 1341: Loss = 0.1461, Reward Loss = 0.1418, Mean Reward = 0.2725,Mask_loss = 0.0855 mask_mean = 0.1739:  24%|██▍       | 1341/5475 [3:42:44<12:57:56, 11.29s/it]




Epoch 1, Step 1351: Loss = 0.2178, Reward Loss = 0.2163, Mean Reward = 0.3843,Mask_loss = 0.0287 mask_mean = 0.2807:  25%|██▍       | 1351/5475 [3:44:25<11:04:34,  9.67s/it]




Epoch 1, Step 1361: Loss = 0.2235, Reward Loss = 0.2217, Mean Reward = 0.3857,Mask_loss = 0.0363 mask_mean = 0.3035:  25%|██▍       | 1361/5475 [3:46:03<10:48:26,  9.46s/it]




Epoch 1, Step 1371: Loss = 0.1817, Reward Loss = 0.1791, Mean Reward = 0.3341,Mask_loss = 0.0517 mask_mean = 0.2154:  25%|██▌       | 1371/5475 [3:47:27<9:38:27,  8.46s/it] 




Epoch 1, Step 1381: Loss = 0.2809, Reward Loss = 0.2782, Mean Reward = 0.4468,Mask_loss = 0.0527 mask_mean = 0.4313:  25%|██▌       | 1381/5475 [3:49:03<10:23:45,  9.14s/it]




Epoch 1, Step 1391: Loss = 0.2610, Reward Loss = 0.2602, Mean Reward = 0.4497,Mask_loss = 0.0163 mask_mean = 0.3233:  25%|██▌       | 1391/5475 [3:50:38<10:40:51,  9.42s/it]




Epoch 1, Step 1401: Loss = 0.2815, Reward Loss = 0.2784, Mean Reward = 0.4639,Mask_loss = 0.0627 mask_mean = 0.4425:  26%|██▌       | 1400/5475 [3:52:14<11:24:58, 10.09s/it]




Epoch 1, Step 1401: Loss = 0.2815, Reward Loss = 0.2784, Mean Reward = 0.4639,Mask_loss = 0.0627 mask_mean = 0.4425:  26%|██▌       | 1401/5475 [3:52:20<11:41:19, 10.33s/it]




Epoch 1, Step 1411: Loss = 0.3017, Reward Loss = 0.2978, Mean Reward = 0.4849,Mask_loss = 0.0773 mask_mean = 0.4790:  26%|██▌       | 1411/5475 [3:54:06<11:42:09, 10.37s/it]




Epoch 1, Step 1421: Loss = 0.2893, Reward Loss = 0.2854, Mean Reward = 0.4591,Mask_loss = 0.0784 mask_mean = 0.4938:  26%|██▌       | 1421/5475 [3:55:36<11:24:32, 10.13s/it]




Epoch 1, Step 1431: Loss = 0.2213, Reward Loss = 0.2201, Mean Reward = 0.3930,Mask_loss = 0.0230 mask_mean = 0.2574:  26%|██▌       | 1431/5475 [3:57:09<10:55:54,  9.73s/it]




Epoch 1, Step 1441: Loss = 0.1781, Reward Loss = 0.1760, Mean Reward = 0.3184,Mask_loss = 0.0416 mask_mean = 0.2754:  26%|██▋       | 1441/5475 [3:58:56<13:16:46, 11.85s/it]




Epoch 1, Step 1451: Loss = 0.2528, Reward Loss = 0.2521, Mean Reward = 0.4247,Mask_loss = 0.0148 mask_mean = 0.3355:  27%|██▋       | 1451/5475 [4:00:42<11:31:36, 10.31s/it]




Epoch 1, Step 1461: Loss = 0.2763, Reward Loss = 0.2747, Mean Reward = 0.4536,Mask_loss = 0.0317 mask_mean = 0.3988:  27%|██▋       | 1461/5475 [4:02:15<11:03:02,  9.91s/it]




Epoch 1, Step 1471: Loss = 0.2375, Reward Loss = 0.2371, Mean Reward = 0.4198,Mask_loss = 0.0089 mask_mean = 0.2868:  27%|██▋       | 1471/5475 [4:03:52<10:47:26,  9.70s/it]




Epoch 1, Step 1481: Loss = 0.2674, Reward Loss = 0.2662, Mean Reward = 0.4592,Mask_loss = 0.0239 mask_mean = 0.3357:  27%|██▋       | 1481/5475 [4:05:20<9:42:27,  8.75s/it] 




Epoch 1, Step 1491: Loss = 0.2461, Reward Loss = 0.2450, Mean Reward = 0.4291,Mask_loss = 0.0217 mask_mean = 0.2914:  27%|██▋       | 1491/5475 [4:07:00<10:58:43,  9.92s/it]




Epoch 1, Step 1501: Loss = 0.2438, Reward Loss = 0.2425, Mean Reward = 0.4182,Mask_loss = 0.0250 mask_mean = 0.3198:  27%|██▋       | 1501/5475 [4:08:35<10:44:00,  9.72s/it]




Epoch 1, Step 1511: Loss = 0.2907, Reward Loss = 0.2889, Mean Reward = 0.4736,Mask_loss = 0.0353 mask_mean = 0.3903:  28%|██▊       | 1511/5475 [4:10:14<11:21:01, 10.31s/it]




Epoch 1, Step 1521: Loss = 0.2705, Reward Loss = 0.2688, Mean Reward = 0.4381,Mask_loss = 0.0348 mask_mean = 0.3878:  28%|██▊       | 1521/5475 [4:12:01<12:20:22, 11.23s/it]




Epoch 1, Step 1531: Loss = 0.1975, Reward Loss = 0.1967, Mean Reward = 0.3528,Mask_loss = 0.0169 mask_mean = 0.2703:  28%|██▊       | 1531/5475 [4:13:46<11:27:41, 10.46s/it]




Epoch 1, Step 1541: Loss = 0.2879, Reward Loss = 0.2853, Mean Reward = 0.4533,Mask_loss = 0.0522 mask_mean = 0.4344:  28%|██▊       | 1541/5475 [4:15:40<11:54:39, 10.90s/it]




Epoch 1, Step 1551: Loss = 0.2459, Reward Loss = 0.2436, Mean Reward = 0.4118,Mask_loss = 0.0452 mask_mean = 0.3599:  28%|██▊       | 1551/5475 [4:17:20<11:11:51, 10.27s/it]




Epoch 1, Step 1561: Loss = 0.2146, Reward Loss = 0.2137, Mean Reward = 0.3890,Mask_loss = 0.0184 mask_mean = 0.2633:  29%|██▊       | 1561/5475 [4:18:57<11:20:08, 10.43s/it]




Epoch 1, Step 1571: Loss = 0.2100, Reward Loss = 0.2090, Mean Reward = 0.3821,Mask_loss = 0.0184 mask_mean = 0.2576:  29%|██▊       | 1571/5475 [4:20:40<12:13:39, 11.28s/it]




Epoch 1, Step 1581: Loss = 0.2715, Reward Loss = 0.2694, Mean Reward = 0.4362,Mask_loss = 0.0426 mask_mean = 0.3859:  29%|██▉       | 1581/5475 [4:22:23<10:40:09,  9.86s/it]




Epoch 1, Step 1591: Loss = 0.2762, Reward Loss = 0.2728, Mean Reward = 0.4357,Mask_loss = 0.0673 mask_mean = 0.4631:  29%|██▉       | 1591/5475 [4:24:06<10:26:06,  9.67s/it]




Epoch 1, Step 1601: Loss = 0.1893, Reward Loss = 0.1869, Mean Reward = 0.3488,Mask_loss = 0.0484 mask_mean = 0.2132:  29%|██▉       | 1600/5475 [4:25:44<11:06:34, 10.32s/it]




Epoch 1, Step 1601: Loss = 0.1893, Reward Loss = 0.1869, Mean Reward = 0.3488,Mask_loss = 0.0484 mask_mean = 0.2132:  29%|██▉       | 1601/5475 [4:25:48<10:34:50,  9.83s/it]




Epoch 1, Step 1611: Loss = 0.1797, Reward Loss = 0.1773, Mean Reward = 0.3369,Mask_loss = 0.0483 mask_mean = 0.2149:  29%|██▉       | 1611/5475 [4:27:31<9:57:57,  9.29s/it] 




Epoch 1, Step 1621: Loss = 0.2550, Reward Loss = 0.2536, Mean Reward = 0.4249,Mask_loss = 0.0285 mask_mean = 0.3486:  30%|██▉       | 1621/5475 [4:29:10<10:21:33,  9.68s/it]




Epoch 1, Step 1631: Loss = 0.3097, Reward Loss = 0.3058, Mean Reward = 0.4809,Mask_loss = 0.0766 mask_mean = 0.4826:  30%|██▉       | 1631/5475 [4:30:46<10:07:50,  9.49s/it]




Epoch 1, Step 1641: Loss = 0.2368, Reward Loss = 0.2346, Mean Reward = 0.3908,Mask_loss = 0.0435 mask_mean = 0.3748:  30%|██▉       | 1641/5475 [4:32:22<10:18:01,  9.67s/it]




Epoch 1, Step 1651: Loss = 0.2895, Reward Loss = 0.2807, Mean Reward = 0.4731,Mask_loss = 0.1760 mask_mean = 0.6319:  30%|███       | 1651/5475 [4:34:03<10:20:30,  9.74s/it]




Epoch 1, Step 1661: Loss = 0.3050, Reward Loss = 0.2965, Mean Reward = 0.4853,Mask_loss = 0.1698 mask_mean = 0.6173:  30%|███       | 1661/5475 [4:35:52<11:08:28, 10.52s/it]




Epoch 1, Step 1671: Loss = 0.3055, Reward Loss = 0.2984, Mean Reward = 0.4754,Mask_loss = 0.1430 mask_mean = 0.5931:  31%|███       | 1671/5475 [4:37:23<9:31:22,  9.01s/it] 




Epoch 1, Step 1681: Loss = 0.2848, Reward Loss = 0.2833, Mean Reward = 0.4630,Mask_loss = 0.0300 mask_mean = 0.3835:  31%|███       | 1681/5475 [4:38:53<9:06:40,  8.65s/it]




Epoch 1, Step 1691: Loss = 0.2799, Reward Loss = 0.2777, Mean Reward = 0.4626,Mask_loss = 0.0439 mask_mean = 0.4152:  31%|███       | 1691/5475 [4:40:33<10:45:23, 10.23s/it]




Epoch 1, Step 1701: Loss = 0.3079, Reward Loss = 0.3015, Mean Reward = 0.4925,Mask_loss = 0.1291 mask_mean = 0.5503:  31%|███       | 1701/5475 [4:42:05<10:58:54, 10.48s/it]




Epoch 1, Step 1711: Loss = 0.2880, Reward Loss = 0.2832, Mean Reward = 0.4413,Mask_loss = 0.0978 mask_mean = 0.5299:  31%|███▏      | 1711/5475 [4:43:42<10:51:33, 10.39s/it]




Epoch 1, Step 1721: Loss = 0.2753, Reward Loss = 0.2742, Mean Reward = 0.4606,Mask_loss = 0.0219 mask_mean = 0.3578:  31%|███▏      | 1721/5475 [4:45:24<10:04:34,  9.66s/it]




Epoch 1, Step 1731: Loss = 0.1340, Reward Loss = 0.1275, Mean Reward = 0.2498,Mask_loss = 0.1301 mask_mean = 0.1389:  32%|███▏      | 1731/5475 [4:47:12<12:26:34, 11.96s/it]




Epoch 1, Step 1741: Loss = 0.1887, Reward Loss = 0.1866, Mean Reward = 0.3472,Mask_loss = 0.0427 mask_mean = 0.2235:  32%|███▏      | 1741/5475 [4:48:41<8:53:18,  8.57s/it] 




Epoch 1, Step 1751: Loss = 0.1980, Reward Loss = 0.1953, Mean Reward = 0.3501,Mask_loss = 0.0525 mask_mean = 0.2263:  32%|███▏      | 1751/5475 [4:50:26<10:46:03, 10.41s/it]




Epoch 1, Step 1761: Loss = 0.2471, Reward Loss = 0.2458, Mean Reward = 0.4206,Mask_loss = 0.0257 mask_mean = 0.3242:  32%|███▏      | 1761/5475 [4:52:02<9:23:03,  9.10s/it] 




Epoch 1, Step 1771: Loss = 0.2523, Reward Loss = 0.2497, Mean Reward = 0.4198,Mask_loss = 0.0525 mask_mean = 0.3490:  32%|███▏      | 1771/5475 [4:53:41<9:40:25,  9.40s/it] 




Epoch 1, Step 1781: Loss = 0.1941, Reward Loss = 0.1927, Mean Reward = 0.3772,Mask_loss = 0.0285 mask_mean = 0.1899:  33%|███▎      | 1781/5475 [4:55:23<10:21:35, 10.10s/it]




Epoch 1, Step 1791: Loss = 0.2874, Reward Loss = 0.2835, Mean Reward = 0.4513,Mask_loss = 0.0791 mask_mean = 0.4749:  33%|███▎      | 1791/5475 [4:57:08<10:43:59, 10.49s/it]




Epoch 1, Step 1801: Loss = 0.2839, Reward Loss = 0.2789, Mean Reward = 0.4491,Mask_loss = 0.0996 mask_mean = 0.5103:  33%|███▎      | 1800/5475 [4:58:43<10:22:11, 10.16s/it]




Epoch 1, Step 1801: Loss = 0.2839, Reward Loss = 0.2789, Mean Reward = 0.4491,Mask_loss = 0.0996 mask_mean = 0.5103:  33%|███▎      | 1801/5475 [4:58:50<11:48:17, 11.57s/it]




Epoch 1, Step 1811: Loss = 0.2947, Reward Loss = 0.2912, Mean Reward = 0.4708,Mask_loss = 0.0693 mask_mean = 0.4691:  33%|███▎      | 1811/5475 [5:00:32<9:59:13,  9.81s/it] 




Epoch 1, Step 1821: Loss = 0.3148, Reward Loss = 0.3073, Mean Reward = 0.4986,Mask_loss = 0.1501 mask_mean = 0.5832:  33%|███▎      | 1821/5475 [5:02:16<10:25:57, 10.28s/it]




Epoch 1, Step 1831: Loss = 0.2875, Reward Loss = 0.2780, Mean Reward = 0.4650,Mask_loss = 0.1906 mask_mean = 0.6555:  33%|███▎      | 1831/5475 [5:03:56<9:54:02,  9.78s/it] 




Epoch 1, Step 1841: Loss = 0.2919, Reward Loss = 0.2880, Mean Reward = 0.4552,Mask_loss = 0.0770 mask_mean = 0.4838:  34%|███▎      | 1841/5475 [5:05:39<10:33:17, 10.46s/it]




Epoch 1, Step 1851: Loss = 0.2680, Reward Loss = 0.2650, Mean Reward = 0.4300,Mask_loss = 0.0599 mask_mean = 0.4239:  34%|███▍      | 1851/5475 [5:07:24<10:29:37, 10.42s/it]




Epoch 1, Step 1861: Loss = 0.2093, Reward Loss = 0.2071, Mean Reward = 0.3624,Mask_loss = 0.0442 mask_mean = 0.2872:  34%|███▍      | 1861/5475 [5:09:00<10:36:23, 10.57s/it]




Epoch 1, Step 1871: Loss = 0.2276, Reward Loss = 0.2262, Mean Reward = 0.3977,Mask_loss = 0.0280 mask_mean = 0.2979:  34%|███▍      | 1871/5475 [5:10:31<9:13:15,  9.21s/it] 




Epoch 1, Step 1881: Loss = 0.2277, Reward Loss = 0.2269, Mean Reward = 0.4075,Mask_loss = 0.0164 mask_mean = 0.2458:  34%|███▍      | 1881/5475 [5:12:04<9:34:16,  9.59s/it]




Epoch 1, Step 1891: Loss = 0.1743, Reward Loss = 0.1714, Mean Reward = 0.3210,Mask_loss = 0.0575 mask_mean = 0.2102:  35%|███▍      | 1891/5475 [5:13:41<9:58:55, 10.03s/it] 




Epoch 1, Step 1901: Loss = 0.1562, Reward Loss = 0.1525, Mean Reward = 0.2928,Mask_loss = 0.0730 mask_mean = 0.1880:  35%|███▍      | 1901/5475 [5:15:18<9:30:11,  9.57s/it] 




Epoch 1, Step 1911: Loss = 0.2642, Reward Loss = 0.2631, Mean Reward = 0.4395,Mask_loss = 0.0219 mask_mean = 0.3380:  35%|███▍      | 1911/5475 [5:16:48<9:35:46,  9.69s/it]




Epoch 1, Step 1921: Loss = 0.1879, Reward Loss = 0.1863, Mean Reward = 0.3313,Mask_loss = 0.0317 mask_mean = 0.2646:  35%|███▌      | 1921/5475 [5:18:29<9:10:56,  9.30s/it] 




Epoch 1, Step 1931: Loss = 0.2531, Reward Loss = 0.2525, Mean Reward = 0.4352,Mask_loss = 0.0120 mask_mean = 0.3070:  35%|███▌      | 1931/5475 [5:20:07<10:08:42, 10.31s/it]




Epoch 1, Step 1941: Loss = 0.2535, Reward Loss = 0.2523, Mean Reward = 0.4057,Mask_loss = 0.0246 mask_mean = 0.3802:  35%|███▌      | 1941/5475 [5:21:39<9:09:38,  9.33s/it] 




Epoch 1, Step 1951: Loss = 0.2709, Reward Loss = 0.2656, Mean Reward = 0.4262,Mask_loss = 0.1068 mask_mean = 0.5491:  36%|███▌      | 1951/5475 [5:23:15<9:13:51,  9.43s/it] 




Epoch 1, Step 1961: Loss = 0.2926, Reward Loss = 0.2908, Mean Reward = 0.4658,Mask_loss = 0.0366 mask_mean = 0.4032:  36%|███▌      | 1961/5475 [5:25:04<11:23:58, 11.68s/it]




Epoch 1, Step 1971: Loss = 0.2225, Reward Loss = 0.2205, Mean Reward = 0.3862,Mask_loss = 0.0398 mask_mean = 0.2971:  36%|███▌      | 1971/5475 [5:26:42<9:37:46,  9.89s/it] 




Epoch 1, Step 1981: Loss = 0.2920, Reward Loss = 0.2868, Mean Reward = 0.4573,Mask_loss = 0.1045 mask_mean = 0.5247:  36%|███▌      | 1981/5475 [5:28:19<10:32:18, 10.86s/it]




Epoch 1, Step 1991: Loss = 0.2016, Reward Loss = 0.2003, Mean Reward = 0.3650,Mask_loss = 0.0258 mask_mean = 0.2545:  36%|███▋      | 1991/5475 [5:29:59<9:30:03,  9.82s/it] 




Epoch 1, Step 2001: Loss = 0.2535, Reward Loss = 0.2519, Mean Reward = 0.4251,Mask_loss = 0.0321 mask_mean = 0.3447:  37%|███▋      | 2000/5475 [5:31:39<10:03:30, 10.42s/it]




Epoch 1, Step 2001: Loss = 0.2535, Reward Loss = 0.2519, Mean Reward = 0.4251,Mask_loss = 0.0321 mask_mean = 0.3447:  37%|███▋      | 2001/5475 [5:31:44<10:01:32, 10.39s/it]




Epoch 1, Step 2011: Loss = 0.2595, Reward Loss = 0.2583, Mean Reward = 0.4243,Mask_loss = 0.0249 mask_mean = 0.3931:  37%|███▋      | 2011/5475 [5:33:23<10:27:16, 10.86s/it]




Epoch 1, Step 2021: Loss = 0.3043, Reward Loss = 0.2991, Mean Reward = 0.4919,Mask_loss = 0.1052 mask_mean = 0.5140:  37%|███▋      | 2021/5475 [5:34:53<8:49:11,  9.19s/it] 




Epoch 1, Step 2031: Loss = 0.2894, Reward Loss = 0.2860, Mean Reward = 0.4611,Mask_loss = 0.0685 mask_mean = 0.4771:  37%|███▋      | 2031/5475 [5:36:42<9:49:36, 10.27s/it] 




Epoch 1, Step 2041: Loss = 0.2620, Reward Loss = 0.2599, Mean Reward = 0.4245,Mask_loss = 0.0431 mask_mean = 0.4242:  37%|███▋      | 2041/5475 [5:38:22<9:03:46,  9.50s/it] 




Epoch 1, Step 2051: Loss = 0.2441, Reward Loss = 0.2431, Mean Reward = 0.4248,Mask_loss = 0.0198 mask_mean = 0.3018:  37%|███▋      | 2051/5475 [5:40:08<10:52:27, 11.43s/it]




Epoch 1, Step 2061: Loss = 0.2619, Reward Loss = 0.2600, Mean Reward = 0.4201,Mask_loss = 0.0381 mask_mean = 0.4045:  38%|███▊      | 2061/5475 [5:41:42<8:50:02,  9.32s/it] 




Epoch 1, Step 2071: Loss = 0.2389, Reward Loss = 0.2379, Mean Reward = 0.3962,Mask_loss = 0.0201 mask_mean = 0.3468:  38%|███▊      | 2071/5475 [5:43:23<9:33:26, 10.11s/it] 




Epoch 1, Step 2081: Loss = 0.2526, Reward Loss = 0.2501, Mean Reward = 0.4108,Mask_loss = 0.0482 mask_mean = 0.4074:  38%|███▊      | 2081/5475 [5:44:59<9:43:39, 10.32s/it]




Epoch 1, Step 2091: Loss = 0.2911, Reward Loss = 0.2886, Mean Reward = 0.4607,Mask_loss = 0.0512 mask_mean = 0.4283:  38%|███▊      | 2091/5475 [5:46:46<9:15:08,  9.84s/it] 




Epoch 1, Step 2101: Loss = 0.2586, Reward Loss = 0.2574, Mean Reward = 0.4189,Mask_loss = 0.0243 mask_mean = 0.3898:  38%|███▊      | 2101/5475 [5:48:29<9:05:20,  9.70s/it] 




Epoch 1, Step 2111: Loss = 0.2284, Reward Loss = 0.2277, Mean Reward = 0.3959,Mask_loss = 0.0142 mask_mean = 0.2985:  39%|███▊      | 2111/5475 [5:50:13<9:49:16, 10.51s/it] 




Epoch 1, Step 2121: Loss = 0.2480, Reward Loss = 0.2470, Mean Reward = 0.4254,Mask_loss = 0.0193 mask_mean = 0.3160:  39%|███▊      | 2121/5475 [5:52:01<10:55:59, 11.74s/it]




Epoch 1, Step 2131: Loss = 0.2355, Reward Loss = 0.2344, Mean Reward = 0.3942,Mask_loss = 0.0230 mask_mean = 0.3394:  39%|███▉      | 2131/5475 [5:53:49<10:43:01, 11.54s/it]




Epoch 1, Step 2141: Loss = 0.1845, Reward Loss = 0.1830, Mean Reward = 0.3501,Mask_loss = 0.0301 mask_mean = 0.2287:  39%|███▉      | 2141/5475 [5:55:22<8:58:47,  9.70s/it] 




Epoch 1, Step 2151: Loss = 0.2764, Reward Loss = 0.2737, Mean Reward = 0.4416,Mask_loss = 0.0555 mask_mean = 0.4413:  39%|███▉      | 2151/5475 [5:56:55<9:02:27,  9.79s/it]




Epoch 1, Step 2161: Loss = 0.2922, Reward Loss = 0.2899, Mean Reward = 0.4605,Mask_loss = 0.0459 mask_mean = 0.4247:  39%|███▉      | 2161/5475 [5:58:44<9:01:17,  9.80s/it] 




Epoch 1, Step 2171: Loss = 0.2719, Reward Loss = 0.2703, Mean Reward = 0.4468,Mask_loss = 0.0312 mask_mean = 0.3987:  40%|███▉      | 2171/5475 [6:00:28<9:03:24,  9.87s/it] 




Epoch 1, Step 2181: Loss = 0.1822, Reward Loss = 0.1798, Mean Reward = 0.3540,Mask_loss = 0.0479 mask_mean = 0.1856:  40%|███▉      | 2181/5475 [6:02:07<8:31:14,  9.31s/it]




Epoch 1, Step 2191: Loss = 0.1496, Reward Loss = 0.1455, Mean Reward = 0.2885,Mask_loss = 0.0805 mask_mean = 0.1574:  40%|████      | 2191/5475 [6:03:38<8:28:59,  9.30s/it]




Epoch 1, Step 2201: Loss = 0.2476, Reward Loss = 0.2461, Mean Reward = 0.4048,Mask_loss = 0.0307 mask_mean = 0.3749:  40%|████      | 2200/5475 [6:05:17<8:54:18,  9.79s/it] 




Epoch 1, Step 2201: Loss = 0.2476, Reward Loss = 0.2461, Mean Reward = 0.4048,Mask_loss = 0.0307 mask_mean = 0.3749:  40%|████      | 2201/5475 [6:05:23<9:04:20,  9.98s/it]




Epoch 1, Step 2211: Loss = 0.2323, Reward Loss = 0.2307, Mean Reward = 0.3856,Mask_loss = 0.0312 mask_mean = 0.3305:  40%|████      | 2211/5475 [6:07:06<9:21:09, 10.32s/it]




Epoch 1, Step 2221: Loss = 0.2380, Reward Loss = 0.2369, Mean Reward = 0.4109,Mask_loss = 0.0213 mask_mean = 0.3181:  41%|████      | 2221/5475 [6:08:42<9:23:34, 10.39s/it] 




Epoch 1, Step 2231: Loss = 0.2583, Reward Loss = 0.2575, Mean Reward = 0.4513,Mask_loss = 0.0157 mask_mean = 0.2987:  41%|████      | 2231/5475 [6:10:22<8:28:43,  9.41s/it] 




Epoch 1, Step 2241: Loss = 0.2468, Reward Loss = 0.2460, Mean Reward = 0.4327,Mask_loss = 0.0156 mask_mean = 0.2997:  41%|████      | 2241/5475 [6:12:01<8:49:22,  9.82s/it]




Epoch 1, Step 2251: Loss = 0.2795, Reward Loss = 0.2781, Mean Reward = 0.4735,Mask_loss = 0.0275 mask_mean = 0.3633:  41%|████      | 2251/5475 [6:13:45<8:42:38,  9.73s/it] 




Epoch 1, Step 2261: Loss = 0.2409, Reward Loss = 0.2406, Mean Reward = 0.4402,Mask_loss = 0.0064 mask_mean = 0.2660:  41%|████▏     | 2261/5475 [6:15:35<8:50:43,  9.91s/it] 




Epoch 1, Step 2271: Loss = 0.2901, Reward Loss = 0.2884, Mean Reward = 0.4629,Mask_loss = 0.0344 mask_mean = 0.3902:  41%|████▏     | 2271/5475 [6:17:22<8:51:15,  9.95s/it] 




Epoch 1, Step 2281: Loss = 0.2598, Reward Loss = 0.2585, Mean Reward = 0.4310,Mask_loss = 0.0250 mask_mean = 0.3787:  42%|████▏     | 2281/5475 [6:18:58<8:09:58,  9.20s/it]




Epoch 1, Step 2291: Loss = 0.2909, Reward Loss = 0.2890, Mean Reward = 0.4844,Mask_loss = 0.0379 mask_mean = 0.3856:  42%|████▏     | 2291/5475 [6:20:44<9:35:23, 10.84s/it]




Epoch 1, Step 2301: Loss = 0.2867, Reward Loss = 0.2848, Mean Reward = 0.4713,Mask_loss = 0.0362 mask_mean = 0.3833:  42%|████▏     | 2301/5475 [6:22:26<8:43:14,  9.89s/it]




Epoch 1, Step 2311: Loss = 0.2091, Reward Loss = 0.2069, Mean Reward = 0.3668,Mask_loss = 0.0439 mask_mean = 0.2975:  42%|████▏     | 2311/5475 [6:24:05<8:12:06,  9.33s/it]




Epoch 1, Step 2321: Loss = 0.2887, Reward Loss = 0.2875, Mean Reward = 0.4661,Mask_loss = 0.0245 mask_mean = 0.3677:  42%|████▏     | 2321/5475 [6:25:36<8:04:44,  9.22s/it]




Epoch 1, Step 2331: Loss = 0.3058, Reward Loss = 0.3028, Mean Reward = 0.4946,Mask_loss = 0.0589 mask_mean = 0.4316:  43%|████▎     | 2331/5475 [6:27:21<9:41:09, 11.09s/it]




Epoch 1, Step 2341: Loss = 0.3008, Reward Loss = 0.2968, Mean Reward = 0.4664,Mask_loss = 0.0794 mask_mean = 0.4874:  43%|████▎     | 2341/5475 [6:28:52<8:00:17,  9.19s/it]




Epoch 1, Step 2351: Loss = 0.2445, Reward Loss = 0.2435, Mean Reward = 0.4140,Mask_loss = 0.0190 mask_mean = 0.3142:  43%|████▎     | 2351/5475 [6:30:35<8:27:09,  9.74s/it]




Epoch 1, Step 2361: Loss = 0.2786, Reward Loss = 0.2772, Mean Reward = 0.4541,Mask_loss = 0.0276 mask_mean = 0.3657:  43%|████▎     | 2361/5475 [6:32:14<8:08:51,  9.42s/it]




Epoch 1, Step 2371: Loss = 0.2806, Reward Loss = 0.2790, Mean Reward = 0.4534,Mask_loss = 0.0307 mask_mean = 0.3854:  43%|████▎     | 2371/5475 [6:34:00<7:57:02,  9.22s/it] 




Epoch 1, Step 2381: Loss = 0.2322, Reward Loss = 0.2316, Mean Reward = 0.4091,Mask_loss = 0.0133 mask_mean = 0.2788:  43%|████▎     | 2381/5475 [6:35:43<9:10:02, 10.67s/it] 




Epoch 1, Step 2391: Loss = 0.2970, Reward Loss = 0.2947, Mean Reward = 0.4649,Mask_loss = 0.0450 mask_mean = 0.4391:  44%|████▎     | 2391/5475 [6:37:31<8:28:39,  9.90s/it] 




Epoch 1, Step 2401: Loss = 0.2618, Reward Loss = 0.2605, Mean Reward = 0.4414,Mask_loss = 0.0257 mask_mean = 0.3140:  44%|████▍     | 2400/5475 [6:38:59<7:50:29,  9.18s/it]




Epoch 1, Step 2401: Loss = 0.2618, Reward Loss = 0.2605, Mean Reward = 0.4414,Mask_loss = 0.0257 mask_mean = 0.3140:  44%|████▍     | 2401/5475 [6:39:04<8:08:57,  9.54s/it]




Epoch 1, Step 2411: Loss = 0.2578, Reward Loss = 0.2567, Mean Reward = 0.4120,Mask_loss = 0.0218 mask_mean = 0.3991:  44%|████▍     | 2411/5475 [6:40:34<7:29:24,  8.80s/it]




Epoch 1, Step 2421: Loss = 0.2639, Reward Loss = 0.2628, Mean Reward = 0.4388,Mask_loss = 0.0218 mask_mean = 0.3327:  44%|████▍     | 2421/5475 [6:42:06<7:50:26,  9.24s/it]




Epoch 1, Step 2431: Loss = 0.2988, Reward Loss = 0.2929, Mean Reward = 0.4683,Mask_loss = 0.1187 mask_mean = 0.5495:  44%|████▍     | 2431/5475 [6:43:47<7:34:57,  8.97s/it]




Epoch 1, Step 2441: Loss = 0.1625, Reward Loss = 0.1600, Mean Reward = 0.3078,Mask_loss = 0.0494 mask_mean = 0.2114:  45%|████▍     | 2441/5475 [6:45:23<7:41:45,  9.13s/it]




Epoch 1, Step 2451: Loss = 0.2240, Reward Loss = 0.2230, Mean Reward = 0.3977,Mask_loss = 0.0200 mask_mean = 0.2794:  45%|████▍     | 2451/5475 [6:47:10<8:58:24, 10.68s/it]




Epoch 1, Step 2461: Loss = 0.2421, Reward Loss = 0.2413, Mean Reward = 0.4201,Mask_loss = 0.0167 mask_mean = 0.2976:  45%|████▍     | 2461/5475 [6:48:55<9:34:49, 11.44s/it] 




Epoch 1, Step 2471: Loss = 0.2154, Reward Loss = 0.2144, Mean Reward = 0.3883,Mask_loss = 0.0218 mask_mean = 0.2671:  45%|████▌     | 2471/5475 [6:50:31<8:30:24, 10.19s/it]




Epoch 1, Step 2481: Loss = 0.2233, Reward Loss = 0.2224, Mean Reward = 0.3849,Mask_loss = 0.0187 mask_mean = 0.3154:  45%|████▌     | 2481/5475 [6:52:14<8:42:52, 10.48s/it]




Epoch 1, Step 2491: Loss = 0.2383, Reward Loss = 0.2374, Mean Reward = 0.4053,Mask_loss = 0.0178 mask_mean = 0.3181:  45%|████▌     | 2491/5475 [6:54:06<8:31:22, 10.28s/it] 




Epoch 1, Step 2501: Loss = 0.2429, Reward Loss = 0.2419, Mean Reward = 0.4027,Mask_loss = 0.0203 mask_mean = 0.3431:  46%|████▌     | 2501/5475 [6:55:41<7:51:22,  9.51s/it]




Epoch 1, Step 2511: Loss = 0.2290, Reward Loss = 0.2281, Mean Reward = 0.4109,Mask_loss = 0.0181 mask_mean = 0.2542:  46%|████▌     | 2511/5475 [6:57:40<9:26:59, 11.48s/it] 




Epoch 1, Step 2521: Loss = 0.2342, Reward Loss = 0.2336, Mean Reward = 0.4020,Mask_loss = 0.0120 mask_mean = 0.3372:  46%|████▌     | 2521/5475 [6:59:20<7:53:12,  9.61s/it]




Epoch 1, Step 2530: Loss = 0.1483, Reward Loss = 0.1442, Mean Reward = 0.2899,Mask_loss = 0.0816 mask_mean = 0.1463:  46%|████▌     | 2530/5475 [7:00:56<8:09:59,  9.98s/it]


KeyboardInterrupt: 

In [None]:
ratio.shape


torch.Size([16, 417])

In [14]:
import numpy as np
idx = 0
from captum.attr import visualization as viz
import torch.nn.functional as F

mask_gen_model.eval()

# tokens = tokenizer.convert_ids_to_tokens(gen_tokens[idx])
# texts = "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."
# texts = "I did not like this movie. Some of the actors were good, but overall the movie was boring."
# texts = "I hate that I love you."
# texts = "I don't like this movie."
# texts = "I really love this film."
# texts = "I really love this film. The acting was great, and the story was amazing. I would recommend this movie to everyone."
# # texts = "I don't like this movie. The acting was terrible, and the story was boring. I would not recommend this movie to anyone."
# messages_lambda = lambda texts: [
#             {"role": "system", "content": "Answer the question based on the context."},
#             # {"role": "system", "content": "You are a chatbot for sentimate analysis."},
#             {"role": "user", "content": texts},
#         ]
# messages = messages_lambda(texts)
# messages_with_template_applied = tokenizer.apply_chat_template(
#             messages,
#             tokenize=False,
#             add_generation_prompt=True,
#         )

# # test_text = [{"text": texts, "label": None}]
# test_text = [{"sentence": texts, "label": None}]
# test_inputs = collate_fn(test_text).to(device)

test_inputs = next(iter(test_dataloader)).to(device)
tokens = tokenizer.convert_ids_to_tokens(test_inputs['input_ids'][idx])

# generate the answer for the test inputs
gen_outputs = model.generate(
            input_ids=test_inputs['input_ids'],
            attention_mask=test_inputs['attention_mask'],
            max_new_tokens=128,
            eos_token_id=terminators,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            return_dict_in_generate=True,
            output_scores=True,
        )
input_ids = test_inputs['input_ids']
attention_mask = test_inputs['attention_mask']
gen_tokens = gen_outputs.sequences
pad_length = gen_tokens.size(1) - input_ids.size(1)
# get the attention mask for the generated tokens, and also mask the padding tokens
gen_attention_mask = F.pad(attention_mask, (0, pad_length), mode='constant', value=1)
context_mask = F.pad(test_inputs['context_mask'], (0, pad_length), mode='constant', value=0)
# (gen_tokens != pad_token_id).long() is the tokens mask, 1 for real tokens and 0 for padding tokens
unpaded_token_mask = (gen_tokens != pad_token_id).long()
unpaded_token_mask[:, :-pad_length] = 1
gen_attention_mask = gen_attention_mask * unpaded_token_mask

with torch.no_grad():
    # prompt_outputs = model(input_ids=test_inputs['input_ids'], attention_mask=test_inputs['attention_mask'], output_hidden_states=True, return_dict=True)
    prompt_outputs = model(input_ids=gen_tokens, attention_mask=gen_attention_mask, output_hidden_states=True, return_dict=True)

    last_hidden_state = prompt_outputs.hidden_states[-1].float()
    mask_logits = mask_gen_model(last_hidden_state)




In [17]:
idx = 0
test_ids = gen_tokens[idx]
test_mask = gen_attention_mask[idx]
test_mask_prob = torch.sigmoid(mask_logits[idx])
test_context_mask = context_mask[idx]

In [18]:
test_tokens = tokenizer.convert_ids_to_tokens(test_ids)
scores = test_mask_prob * test_context_mask

# remove special tokens
filtered_token_scores = [(token, score) for token, score in zip(test_tokens, scores) if token not in tokenizer.all_special_tokens]

# combine subwords
merged_tokens_scores = []
current_token = ""
current_score = 0
count = 0

for token, score in filtered_token_scores:
    if token.startswith("Ġ"):
        if current_token:
            merged_tokens_scores.append((current_token, current_score / count))
            # merged_tokens_scores.append((" ", 0))  # 添加空格
        current_token = token[1:] # remove the speical character
        current_score = score.detach().cpu().numpy()
        count = 1
    elif token.endswith("Ċ"):
        if current_token:
            merged_tokens_scores.append((current_token, current_score / count))
        merged_tokens_scores.append(("<br><br>", 0))  # 添加换行符
        current_token = ""
        current_score = 0
        count = 0
    else:
        current_token += token
        current_score += score.detach().cpu().numpy()
        count += 1

if current_token:
    merged_tokens_scores.append((current_token, current_score / count))


# 根据分数高亮文本（示例中使用HTML标签）
highlighted_text = ""
for token, score in merged_tokens_scores:
    # 动态设置背景颜色：score为0时为白色，score为1时为绿色
    red = int((1 - score) * 255)
    green = 255
    blue = int((1 - score) * 255)
    color = f'rgb({red}, {green}, {blue})'
    highlighted_text += f'<span style="background-color: {color}; color: black;">{token}</span> '

# 打印高亮后的文本
from IPython.display import display, HTML
display(HTML(highlighted_text.strip()))

In [19]:
merged_tokens_scores

[('<|start_header_id|>system<|end_header_id|>', 0.0),
 ('<br><br>', 0),
 ('You', 0.0),
 ('are', 0.0),
 ('a', 0.0),
 ('chatbot', 0.0),
 ('for', 0.0),
 ('answering', 0.0),
 ('questions.', 0.0),
 ('You', 0.0),
 ('can', 0.0),
 ('help', 0.0),
 ('users', 0.0),
 ('with', 0.0),
 ('their', 0.0),
 ('questions', 0.0),
 ('via', 0.0),
 ('concise', 0.0),
 ('responses.<|start_header_id|>user<|end_header_id|>', 0.0),
 ('<br><br>', 0),
 ('Context:', 0.2531610056757927),
 ('Super', 0.032565828412771225),
 ('Bowl', 0.07368159294128418),
 ('50', 0.09878716617822647),
 ('was', 0.04960312694311142),
 ('an', 0.20547136664390564),
 ('American', 0.10978242754936218),
 ('football', 0.10689886659383774),
 ('game', 0.11067824810743332),
 ('to', 0.18042917549610138),
 ('determine', 0.19631139934062958),
 ('the', 0.21517875790596008),
 ('champion', 0.09768622368574142),
 ('of', 0.24110572040081024),
 ('the', 0.14212559163570404),
 ('National', 0.09662216156721115),
 ('Football', 0.04373268038034439),
 ('League', 0.

In [None]:
# Set print options for precision
torch.set_printoptions(precision=4)
np.set_printoptions(precision=4)
print(expl)

[0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00
 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00
 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00
 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00
 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00
 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00
 0.0000e+00 2.4742e-02 6.0288e-01 4.3501e-01 3.8011e-01 7.7301e-02
 4.1161e-04 9.5759e-05 8.4586e-01 4.7034e-01 7.6526e-01 9.9955e-03
 4.9102e-01 5.7602e-01 6.6518e-01 8.1205e-01 6.8847e-01 1.8136e-03
 1.8811e-04 9.5622e-01 9.5475e-01 2.4792e-01 1.1541e-01 1.3485e-01
 5.3585e-03 1.0000e+00 0.0000e+00 5.5876e-05 0.0000e+00 0.0000e+00
 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00]


In [None]:
expl_raw

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.     

In [None]:
mask_prob = torch.sigmoid(mask_logits)
(mask_prob * context_mask).sum(-1) / context_mask.sum(dim=-1)

tensor([0.1241, 0.2500, 0.7787, 0.0909, 0.1908, 0.1361, 0.3316, 0.5104, 0.2696,
        0.1001, 0.5931, 0.0812, 0.4139, 0.3334, 0.4853, 0.2987],
       device='cuda:0')

In [None]:
mask_prob

tensor([[9.9998e-01, 1.3412e-02, 2.1168e-05, 7.2758e-04, 9.5831e-01, 2.8712e-06,
         1.9530e-03, 6.8251e-04, 9.6814e-01, 7.6582e-04, 2.2164e-06, 1.7009e-02,
         2.4596e-09, 1.4519e-04, 3.0463e-06, 7.9092e-04, 4.0336e-06, 5.2874e-05,
         1.4755e-07, 5.3417e-09, 1.0859e-07, 8.2496e-07, 1.0534e-05, 2.1604e-06,
         6.6980e-10, 1.8266e-01, 5.4030e-07, 4.0081e-02, 3.5185e-04, 3.1933e-01,
         3.2836e-09, 1.5197e-06, 8.2054e-01, 8.0839e-01, 5.4567e-01, 4.8829e-01,
         1.0000e+00, 4.3639e-05, 3.5097e-07, 3.8482e-09, 1.9672e-04, 1.1242e-07,
         2.7982e-13, 5.4683e-09, 4.5730e-01, 9.5505e-01, 6.7815e-04, 7.4368e-02,
         9.9924e-01, 6.6332e-03, 1.0508e-08, 1.8774e-01]], device='cuda:0')

In [None]:
test_inputs['input_ids'][idx]

tensor([128000, 128006,   9125, 128007,    271,   2675,    527,    264,   6369,
          6465,    369,  27065,   6492,     13,   1472,    649,   1520,   3932,
           449,    872,   4860,   4669,  64694,  14847,    315,  27592,  45450,
            11,    477,  85165,  24093,     13, 128009, 128006,    882, 128007,
           271,   2028,   5818,    574,    279,   1888,   5818,    358,    617,
          3596,   3970,      0,   1063,  16451,   1051,  27873,     11,    719,
         15718,    574,   2294,     13, 128009, 128006,  78191, 128007,    271],
       device='cuda:0')

In [None]:
tokenizer.convert_ids_to_tokens(271)

'ĊĊ'

[0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.
 0.     0.     0.     0.     0.     0.     0.     0.6432 0.8566 0.5179
 0.2417 0.     0.1211 0.3355 0.618  0.5345 0.1401 0.3401 0.4729 0.3531
 0.661  0.7049 0.0297 0.1724 0.9905 1.     0.1606 0.1107 0.2363 0.2891
 0.116  0.0777 0.     0.     0.     0.     0.     0.     0.     0.    ]


In [None]:
(torch.sigmoid(mask_logits) * context_mask)[idx]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.8527, 0.1770, 0.1719, 0.1660, 0.1613, 0.1549, 0.1581, 0.1622, 0.1666,
        0.1635, 0.1568, 0.1589, 0.1618, 0.1634, 0.1668, 0.1647, 0.1546, 0.1576,
        0.1788, 0.1779, 0.1592, 0.1581, 0.1607, 0.1619, 0.1586, 0.1583, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       device='cuda:0')

In [None]:
torch.sigmoid(mask_logits)

tensor([[0.4822, 0.4821, 0.4816,  ..., 0.4920, 0.4874, 0.4901],
        [0.4843, 0.4851, 0.4855,  ..., 0.4941, 0.4861, 0.4891],
        [0.4785, 0.4805, 0.4805,  ..., 0.4918, 0.4823, 0.4864],
        ...,
        [0.4753, 0.4758, 0.4761,  ..., 0.4847, 0.4761, 0.4802],
        [0.4876, 0.4883, 0.4882,  ..., 0.4887, 0.4880, 0.4943],
        [0.4843, 0.4851, 0.4852,  ..., 0.4946, 0.4853, 0.4947]],
       device='cuda:0')

In [None]:
mask_gen_model

MaskGeneratingModel(
  (explain_map): MLP(
    (input_layer): Linear(in_features=4096, out_features=1024, bias=True)
    (attention_layers): ModuleList(
      (0-1): 2 x MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
      )
    )
    (layers): ModuleList(
      (0-1): 2 x Sequential(
        (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (1): PReLU(num_parameters=1)
        (2): Linear(in_features=1024, out_features=1024, bias=True)
      )
    )
    (output_layer): Linear(in_features=1024, out_features=1, bias=True)
  )
)

In [None]:
print(tokens[35])
print(expl[35])

<|end_header_id|>
0.0


In [None]:
texts = "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."
# texts = "I really didn't like this movie. Some of the actors were good, but overall the movie was boring."
# texts = "I hate that I love you."
# texts = "I don't like this movie."
# texts = "I really love this film."
messages_lambda = lambda texts: [
    {"role": "system", "content": "You are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, or NEGATIVE."},
    # {"role": "system", "content": "You are a chatbot for sentimate analysis."},
    {"role": "user", "content": texts},
]
messages = messages_lambda(texts)
messages_with_template_applied = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
print(messages_with_template_applied)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a chatbot for sentimate analysis. You can help users with their questions via concise responses of POSITIVE, or NEGATIVE.<|eot_id|><|start_header_id|>user<|end_header_id|>

This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great.<|eot_id|><|start_header_id|>assistant<|end_header_id|>


