In [1]:
import os

import torch
from transformers import ViTImageProcessor, ViTForImageClassification, ViTModel, ViTConfig
from PIL import Image
import requests
import matplotlib.pyplot as plt

import numpy as np
import cv2
from datasets import load_dataset,load_metric
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

from accelerate import Accelerator

# Build the Accelerator once and take the device it selected for this process;
# the models and batches in the following cells are moved onto `device`.
device = (accelerator := Accelerator()).device


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Alternative sample images:
# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# url = "http://farm3.staticflickr.com/2066/1798910782_5536af8767_z.jpg"
# url = "http://farm1.staticflickr.com/184/399924547_98e6cef97a_z.jpg"
url = "http://farm1.staticflickr.com/128/318959350_1a39aae18c_z.jpg"
# A timeout keeps a stalled download from hanging the kernel, and
# raise_for_status() reports HTTP errors directly instead of letting an error
# page reach PIL as an unidentifiable "image".
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Force 3-channel RGB (the training data gets the same .convert("RGB") in
# `preprocess`) so grayscale/RGBA/CMYK sources cannot trip up normalization.
image = Image.open(response.raw).convert("RGB")

pretrained_name = 'google/vit-base-patch16-224'
# pretrained_name = 'vit-base-patch16-224-finetuned-imageneteval'
# pretrained_name = 'openai/clip-vit-base-patch32'
config = ViTConfig.from_pretrained(pretrained_name)
processor = ViTImageProcessor.from_pretrained(pretrained_name)
# get mean and std to unnormalize the processed images
mean, std = processor.image_mean, processor.image_std

pred_model = ViTForImageClassification.from_pretrained(pretrained_name)
pred_model.to(device)
# set to eval mode: the classifier is only used for inference here
pred_model.eval()

with torch.no_grad():
    # Keep the object returned by .to() rather than relying on in-place mutation.
    inputs = processor(images=image, return_tensors="pt").to(device)
    outputs = pred_model(**inputs, output_hidden_states=True)
    logits = outputs.logits
    # model predicts one of the 1000 ImageNet classes
    predicted_class_idx = logits.argmax(-1).item()
    print("Predicted class:", pred_model.config.id2label[predicted_class_idx])

Predicted class: golden retriever


In [3]:
from maskgen.models.vision_maskgen_model9 import MaskGeneratingModel

# Wrap the ViT classifier with the mask generator. Its hidden size and label
# count are read from the classifier's own config so the two stay consistent.
mask_gen_model = MaskGeneratingModel(
    pred_model,
    hidden_size=config.hidden_size,
    num_classes=config.num_labels,
)
mask_gen_model.to(device)
# The cell ends on print() (which returns None), so Jupyter has no value to echo.
print()




# Training

In [4]:
from torch.utils.data import DataLoader
def load_data(seed=42):
    """Deterministically split ``mrm8488/ImageNet1K-val`` into train/val/test.

    The dataset's single ``"train"`` split is divided twice with the same seed.
    First, 10% is held out as the test set. Then 10% of the remaining 90% becomes
    the validation set. The result is roughly an 81% / 9% / 10% train / val / test
    partition.

    Args:
        seed: Seed passed to both ``train_test_split`` calls so the split is
            reproducible.

    Returns:
        Tuple ``(train_ds, val_ds, test_ds)``.
    """
    full_ds = load_dataset("mrm8488/ImageNet1K-val")['train']
    holdout = full_ds.train_test_split(test_size=0.1, seed=seed)
    remainder = holdout['train'].train_test_split(test_size=0.1, seed=seed)
    return remainder['train'], remainder['test'], holdout['test']

# Seeded ~81/9/10 train/val/test split of the ImageNet1K-val data.
train_ds, val_ds, test_ds = load_data()
# The complete, unsplit dataset: a superset of train_ds, val_ds and test_ds.
dataset = load_dataset("mrm8488/ImageNet1K-val")["train"]

normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
# The processor describes its input size either as an explicit height/width
# pair or as a shortest (and optional longest) edge. `size` and `max_size` are
# not used by the pipeline below; only `crop_size` is.
if "height" in processor.size:
    size = (processor.size["height"], processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in processor.size:
    size = processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = processor.size.get("longest_edge")
else:
    # Fail here with a clear message rather than with a NameError on
    # `crop_size` when the pipeline is built.
    raise ValueError(f"Unsupported processor.size format: {processor.size!r}")

# Training-time augmentation: random resized crop and horizontal flip, then
# tensor conversion and normalization with the processor's mean/std.
transforms = Compose(
        [
            RandomResizedCrop(crop_size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

def preprocess(example_batch):
    """Attach augmented, normalized image tensors to a batch of examples.

    Meant for ``Dataset.set_transform``. Each image in
    ``example_batch["image"]`` (expected to be PIL images, given the
    ``.convert`` call) is converted to RGB and run through the module-level
    ``transforms`` pipeline. That pipeline is the random-crop/flip training
    augmentation, not a deterministic eval transform.

    The resulting tensors are stored under ``"pixel_values"`` on the same dict,
    which is then returned.
    """
    tensors = []
    for img in example_batch["image"]:
        tensors.append(transforms(img.convert("RGB")))
    example_batch["pixel_values"] = tensors
    return example_batch

def collate_fn(examples):
    """Stack a list of preprocessed examples into one model-ready batch.

    Args:
        examples: Sequence of dicts, each holding a ``"pixel_values"`` tensor
            and an integer ``"label"``.

    Returns:
        Dict containing:
            ``"pixel_values"``: the per-example tensors stacked along a new
                leading batch dimension.
            ``"labels"``: a 1-D tensor of the labels.
    """
    images, targets = [], []
    for ex in examples:
        images.append(ex["pixel_values"])
        targets.append(ex["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

# Attach the on-the-fly preprocessing to every split that may be batched with
# `collate_fn`. `collate_fn` reads "pixel_values", which only exists once this
# transform is set. val_ds needs it too, or loading it through `collate_fn`
# raises KeyError.
train_ds.set_transform(preprocess)
val_ds.set_transform(preprocess)
test_ds.set_transform(preprocess)
dataset.set_transform(preprocess)
# NOTE(review): `preprocess` applies random crop/flip augmentation, so val/test
# batches are augmented as well. Consider a deterministic eval pipeline.

Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.


# Training parameters

In [5]:
batch_size = 256
# NOTE(review): the loader is built from `dataset`, the complete unsplit
# ImageNet1K-val data, which contains every example of `val_ds` and `test_ds`.
# Confirm this is intended before reporting metrics on those splits.
# (Swap in `train_ds` or `test_ds` here to train on a single split instead.)
train_dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

# Forwarded to `mask_gen_model.train_one_batch` by the training loop.
n_steps = 2
n_samples = 5

# Names of the parameters with requires_grad=True, i.e. those handed to the
# optimizer in the training cell.
params_to_optimize = [
    pname for pname, p in mask_gen_model.named_parameters() if p.requires_grad
]
print("params to be optimized: ")
print(params_to_optimize)

params to be optimized: 
['similarity_measure.logit_scale', 'similarity_measure.pred_map.input_layer.weight', 'similarity_measure.pred_map.input_layer.bias', 'similarity_measure.pred_map.layers.0.0.weight', 'similarity_measure.pred_map.layers.0.0.bias', 'similarity_measure.pred_map.layers.0.3.weight', 'similarity_measure.pred_map.layers.0.3.bias', 'similarity_measure.pred_map.layers.0.6.weight', 'similarity_measure.pred_map.layers.0.6.bias', 'similarity_measure.pred_map.layers.1.0.weight', 'similarity_measure.pred_map.layers.1.0.bias', 'similarity_measure.pred_map.layers.1.3.weight', 'similarity_measure.pred_map.layers.1.3.bias', 'similarity_measure.pred_map.layers.1.6.weight', 'similarity_measure.pred_map.layers.1.6.bias', 'similarity_measure.pred_map.output_layer.weight', 'similarity_measure.pred_map.output_layer.bias', 'similarity_measure.explain_map.input_layer.weight', 'similarity_measure.explain_map.input_layer.bias', 'similarity_measure.explain_map.layers.0.0.weight', 'similarit

In [6]:
from tqdm import tqdm

# Only parameters with requires_grad=True are handed to the optimizer.
params_to_optimize = [param for param in mask_gen_model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3, weight_decay=1e-5)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.9)

N_EPOCHS = 10
SAVE_EVERY = 10  # steps between checkpoints (and progress-bar line breaks)
CHECKPOINT_DIR = "mask_gen_model"
# torch.save does not create parent directories. Make sure the checkpoint
# folder exists, otherwise a fresh run fails with FileNotFoundError at step 1.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(N_EPOCHS):
    pbar = tqdm(train_dataloader)
    for idx, data in enumerate(pbar):
        pixel_values = data['pixel_values'].to(device)
        loss_dict = mask_gen_model.train_one_batch(pixel_values, optimizer=optimizer, n_steps=n_steps, n_samples=n_samples)
        # scheduler.step()
        pbar.set_description(f"Epoch {epoch+1}, Step {idx+1}: Loss = {loss_dict['loss'].item():.4f}, " 
                             f"Reward Loss = {loss_dict['reward_loss'].item():.4f}, "
                            #  f"Regret Loss = {loss_dict['regret_loss'].item():.4f}, "
                             f"Mask Loss = {loss_dict['mask_loss'].item():.4f} "
                            #  f"alt_mask_loss = {loss_dict['alt_mask_loss'].item():.4f} "
                             f"mask_mean = {loss_dict['mask_mean'].item():.4f} "
                             f"prob_mean = {loss_dict['prob_mean'].item():.4f} "
                             )
        if idx % SAVE_EVERY == 0:
            # Emit a newline so the current progress-bar state stays in the output,
            # then checkpoint.
            print()
            torch.save(mask_gen_model.state_dict(),
                       os.path.join(CHECKPOINT_DIR, f'mask_gen_model_{epoch}_{idx}.pth'))

# Final checkpoint, named after the last epoch/step reached.
torch.save(mask_gen_model.state_dict(),
           os.path.join(CHECKPOINT_DIR, f'mask_gen_model_final_{epoch}_{idx}.pth'))





Epoch 1, Step 1: Loss = 0.5130, Reward Loss = -0.2764, Mask Loss = 0.7894 mask_mean = 0.4313 prob_mean = 0.5886 :   0%|          | 0/196 [00:09<?, ?it/s]




Epoch 1, Step 11: Loss = 0.1432, Reward Loss = -0.4318, Mask Loss = 0.5750 mask_mean = 0.3791 prob_mean = 0.4931 :   5%|▌         | 10/196 [01:37<27:11,  8.77s/it]




Epoch 1, Step 21: Loss = 0.0528, Reward Loss = -0.4230, Mask Loss = 0.4758 mask_mean = 0.3298 prob_mean = 0.4633 :  10%|█         | 20/196 [03:06<25:46,  8.79s/it]




Epoch 1, Step 31: Loss = 0.0136, Reward Loss = -0.4121, Mask Loss = 0.4257 mask_mean = 0.3241 prob_mean = 0.4321 :  15%|█▌        | 30/196 [04:34<24:12,  8.75s/it]




Epoch 1, Step 41: Loss = -0.0137, Reward Loss = -0.3738, Mask Loss = 0.3601 mask_mean = 0.2978 prob_mean = 0.3838 :  20%|██        | 40/196 [06:03<23:04,  8.88s/it]




Epoch 1, Step 51: Loss = 0.0171, Reward Loss = -0.4115, Mask Loss = 0.4286 mask_mean = 0.3172 prob_mean = 0.4341 :  26%|██▌       | 50/196 [07:31<21:18,  8.76s/it] 




Epoch 1, Step 61: Loss = -0.0387, Reward Loss = -0.4220, Mask Loss = 0.3833 mask_mean = 0.2960 prob_mean = 0.4225 :  31%|███       | 60/196 [09:00<19:51,  8.76s/it]




Epoch 1, Step 71: Loss = 0.0008, Reward Loss = -0.4388, Mask Loss = 0.4396 mask_mean = 0.3066 prob_mean = 0.4480 :  36%|███▌      | 70/196 [10:29<18:43,  8.91s/it] 




Epoch 1, Step 81: Loss = -0.0266, Reward Loss = -0.4550, Mask Loss = 0.4284 mask_mean = 0.2982 prob_mean = 0.4533 :  41%|████      | 80/196 [11:59<17:14,  8.92s/it]




Epoch 1, Step 91: Loss = -0.0197, Reward Loss = -0.4424, Mask Loss = 0.4227 mask_mean = 0.2821 prob_mean = 0.4592 :  46%|████▌     | 90/196 [13:27<15:31,  8.78s/it]




Epoch 1, Step 101: Loss = -0.0312, Reward Loss = -0.3872, Mask Loss = 0.3560 mask_mean = 0.2807 prob_mean = 0.3985 :  51%|█████     | 100/196 [14:56<14:02,  8.78s/it]




Epoch 1, Step 111: Loss = -0.0058, Reward Loss = -0.4441, Mask Loss = 0.4384 mask_mean = 0.2858 prob_mean = 0.4549 :  56%|█████▌    | 110/196 [16:24<12:36,  8.80s/it]




Epoch 1, Step 121: Loss = -0.0304, Reward Loss = -0.4294, Mask Loss = 0.3990 mask_mean = 0.2747 prob_mean = 0.4370 :  61%|██████    | 120/196 [17:53<11:15,  8.89s/it]




Epoch 1, Step 131: Loss = -0.0331, Reward Loss = -0.4639, Mask Loss = 0.4308 mask_mean = 0.2827 prob_mean = 0.4642 :  66%|██████▋   | 130/196 [19:23<09:46,  8.89s/it]




Epoch 1, Step 141: Loss = -0.0172, Reward Loss = -0.4494, Mask Loss = 0.4322 mask_mean = 0.2781 prob_mean = 0.4517 :  71%|███████▏  | 140/196 [20:50<08:10,  8.76s/it]




Epoch 1, Step 151: Loss = 0.0261, Reward Loss = -0.4682, Mask Loss = 0.4942 mask_mean = 0.2937 prob_mean = 0.4871 :  77%|███████▋  | 150/196 [22:18<06:40,  8.72s/it] 




Epoch 1, Step 161: Loss = 0.0092, Reward Loss = -0.4711, Mask Loss = 0.4803 mask_mean = 0.2929 prob_mean = 0.4782 :  82%|████████▏ | 160/196 [23:46<05:13,  8.72s/it] 




Epoch 1, Step 171: Loss = 0.0259, Reward Loss = -0.4939, Mask Loss = 0.5198 mask_mean = 0.2972 prob_mean = 0.4972 :  87%|████████▋ | 170/196 [25:12<03:42,  8.55s/it] 




Epoch 1, Step 181: Loss = -0.0099, Reward Loss = -0.4432, Mask Loss = 0.4333 mask_mean = 0.2793 prob_mean = 0.4561 :  92%|█████████▏| 180/196 [26:40<02:20,  8.79s/it]




Epoch 1, Step 191: Loss = -0.0179, Reward Loss = -0.4733, Mask Loss = 0.4554 mask_mean = 0.2909 prob_mean = 0.4719 :  97%|█████████▋| 190/196 [28:09<00:52,  8.79s/it]




Epoch 1, Step 196: Loss = -0.0224, Reward Loss = -0.4865, Mask Loss = 0.4641 mask_mean = 0.2932 prob_mean = 0.4762 : 100%|██████████| 196/196 [28:47<00:00,  8.82s/it]
Epoch 2, Step 1: Loss = -0.0275, Reward Loss = -0.4731, Mask Loss = 0.4456 mask_mean = 0.2841 prob_mean = 0.4715 :   0%|          | 0/196 [00:08<?, ?it/s]




Epoch 2, Step 11: Loss = 0.0023, Reward Loss = -0.4900, Mask Loss = 0.4924 mask_mean = 0.2933 prob_mean = 0.4880 :   5%|▌         | 10/196 [01:37<27:20,  8.82s/it]




Epoch 2, Step 21: Loss = 0.0063, Reward Loss = -0.5096, Mask Loss = 0.5159 mask_mean = 0.2857 prob_mean = 0.5070 :  10%|█         | 20/196 [03:05<25:37,  8.73s/it] 




Epoch 2, Step 31: Loss = -0.0139, Reward Loss = -0.4785, Mask Loss = 0.4646 mask_mean = 0.2875 prob_mean = 0.4778 :  15%|█▌        | 30/196 [04:33<24:13,  8.75s/it]




Epoch 2, Step 41: Loss = 0.0109, Reward Loss = -0.5390, Mask Loss = 0.5499 mask_mean = 0.3017 prob_mean = 0.5367 :  20%|██        | 40/196 [06:01<22:48,  8.77s/it] 




Epoch 2, Step 51: Loss = 0.0000, Reward Loss = -0.4668, Mask Loss = 0.4668 mask_mean = 0.2955 prob_mean = 0.4660 :  26%|██▌       | 50/196 [07:30<21:35,  8.87s/it] 




Epoch 2, Step 61: Loss = 0.0179, Reward Loss = -0.4890, Mask Loss = 0.5069 mask_mean = 0.3049 prob_mean = 0.4871 :  31%|███       | 60/196 [11:03<29:58, 13.22s/it]   




Epoch 2, Step 71: Loss = 0.0119, Reward Loss = -0.5075, Mask Loss = 0.5194 mask_mean = 0.3052 prob_mean = 0.5051 :  36%|███▌      | 70/196 [12:34<19:12,  9.15s/it] 




Epoch 2, Step 81: Loss = -0.0041, Reward Loss = -0.4588, Mask Loss = 0.4548 mask_mean = 0.2942 prob_mean = 0.4515 :  41%|████      | 80/196 [14:06<17:42,  9.16s/it]




Epoch 2, Step 91: Loss = 0.0303, Reward Loss = -0.4914, Mask Loss = 0.5217 mask_mean = 0.3126 prob_mean = 0.4918 :  46%|████▌     | 90/196 [15:39<16:06,  9.11s/it] 




Epoch 2, Step 101: Loss = 0.0036, Reward Loss = -0.4891, Mask Loss = 0.4927 mask_mean = 0.3005 prob_mean = 0.4831 :  51%|█████     | 100/196 [17:11<14:41,  9.18s/it]




Epoch 2, Step 111: Loss = 0.0144, Reward Loss = -0.5071, Mask Loss = 0.5215 mask_mean = 0.3036 prob_mean = 0.5082 :  56%|█████▌    | 110/196 [18:43<13:02,  9.10s/it] 




Epoch 2, Step 121: Loss = -0.0068, Reward Loss = -0.4378, Mask Loss = 0.4310 mask_mean = 0.2771 prob_mean = 0.4600 :  61%|██████    | 120/196 [20:14<11:21,  8.97s/it]




Epoch 2, Step 131: Loss = 0.0239, Reward Loss = -0.5207, Mask Loss = 0.5446 mask_mean = 0.2884 prob_mean = 0.5238 :  66%|██████▋   | 130/196 [21:45<10:00,  9.10s/it] 




Epoch 2, Step 141: Loss = -0.0279, Reward Loss = -0.4790, Mask Loss = 0.4511 mask_mean = 0.2790 prob_mean = 0.4817 :  71%|███████▏  | 140/196 [23:16<08:30,  9.12s/it]




Epoch 2, Step 151: Loss = -0.0123, Reward Loss = -0.5177, Mask Loss = 0.5054 mask_mean = 0.2823 prob_mean = 0.5131 :  77%|███████▋  | 150/196 [24:49<07:02,  9.19s/it]




Epoch 2, Step 161: Loss = 0.0111, Reward Loss = -0.4826, Mask Loss = 0.4938 mask_mean = 0.2929 prob_mean = 0.4772 :  82%|████████▏ | 160/196 [26:20<05:25,  9.05s/it] 




Epoch 2, Step 171: Loss = 0.0153, Reward Loss = -0.5001, Mask Loss = 0.5155 mask_mean = 0.2987 prob_mean = 0.4940 :  87%|████████▋ | 170/196 [27:52<03:54,  9.03s/it] 




Epoch 2, Step 181: Loss = 0.0162, Reward Loss = -0.5028, Mask Loss = 0.5189 mask_mean = 0.2901 prob_mean = 0.5011 :  92%|█████████▏| 180/196 [29:23<02:25,  9.06s/it] 




Epoch 2, Step 191: Loss = -0.0271, Reward Loss = -0.4972, Mask Loss = 0.4702 mask_mean = 0.2735 prob_mean = 0.4922 :  97%|█████████▋| 190/196 [30:54<00:54,  9.11s/it]




Epoch 2, Step 196: Loss = -0.0164, Reward Loss = -0.5197, Mask Loss = 0.5033 mask_mean = 0.2900 prob_mean = 0.5101 : 100%|██████████| 196/196 [31:34<00:00,  9.67s/it]
Epoch 3, Step 1: Loss = -0.0046, Reward Loss = -0.4957, Mask Loss = 0.4911 mask_mean = 0.2834 prob_mean = 0.4960 :   0%|          | 0/196 [00:09<?, ?it/s]




Epoch 3, Step 11: Loss = 0.0177, Reward Loss = -0.5082, Mask Loss = 0.5258 mask_mean = 0.2981 prob_mean = 0.5037 :   5%|▌         | 10/196 [01:39<28:00,  9.03s/it]




Epoch 3, Step 21: Loss = 0.0168, Reward Loss = -0.4958, Mask Loss = 0.5125 mask_mean = 0.3054 prob_mean = 0.4904 :  10%|█         | 20/196 [03:09<26:19,  8.97s/it]




Epoch 3, Step 31: Loss = 0.0192, Reward Loss = -0.5156, Mask Loss = 0.5348 mask_mean = 0.2938 prob_mean = 0.5131 :  15%|█▌        | 30/196 [04:38<24:10,  8.74s/it] 




Epoch 3, Step 41: Loss = -0.0155, Reward Loss = -0.5027, Mask Loss = 0.4872 mask_mean = 0.2869 prob_mean = 0.4949 :  20%|██        | 40/196 [06:08<22:59,  8.85s/it]




Epoch 3, Step 51: Loss = 0.0270, Reward Loss = -0.5568, Mask Loss = 0.5838 mask_mean = 0.2968 prob_mean = 0.5486 :  26%|██▌       | 50/196 [07:38<21:43,  8.93s/it] 




Epoch 3, Step 61: Loss = 0.0180, Reward Loss = -0.5284, Mask Loss = 0.5464 mask_mean = 0.2976 prob_mean = 0.5236 :  31%|███       | 60/196 [09:08<20:25,  9.01s/it] 




Epoch 3, Step 71: Loss = 0.0188, Reward Loss = -0.5016, Mask Loss = 0.5204 mask_mean = 0.3061 prob_mean = 0.4988 :  36%|███▌      | 70/196 [10:38<18:46,  8.94s/it] 




Epoch 3, Step 81: Loss = -0.0134, Reward Loss = -0.5078, Mask Loss = 0.4944 mask_mean = 0.2822 prob_mean = 0.5052 :  41%|████      | 80/196 [12:07<17:05,  8.84s/it]




Epoch 3, Step 91: Loss = 0.0077, Reward Loss = -0.5339, Mask Loss = 0.5416 mask_mean = 0.2935 prob_mean = 0.5334 :  46%|████▌     | 90/196 [13:36<15:35,  8.82s/it] 




Epoch 3, Step 101: Loss = 0.0296, Reward Loss = -0.5361, Mask Loss = 0.5657 mask_mean = 0.3038 prob_mean = 0.5342 :  51%|█████     | 100/196 [15:05<14:01,  8.77s/it]




Epoch 3, Step 111: Loss = 0.0051, Reward Loss = -0.5236, Mask Loss = 0.5287 mask_mean = 0.3032 prob_mean = 0.5164 :  56%|█████▌    | 110/196 [16:33<12:36,  8.79s/it] 




Epoch 3, Step 121: Loss = 0.0309, Reward Loss = -0.5345, Mask Loss = 0.5654 mask_mean = 0.2929 prob_mean = 0.5516 :  61%|██████    | 120/196 [18:03<11:14,  8.88s/it] 




Epoch 3, Step 131: Loss = 0.0212, Reward Loss = -0.5460, Mask Loss = 0.5672 mask_mean = 0.2996 prob_mean = 0.5416 :  66%|██████▋   | 130/196 [19:31<09:46,  8.89s/it] 




Epoch 3, Step 141: Loss = 0.0190, Reward Loss = -0.4966, Mask Loss = 0.5156 mask_mean = 0.3085 prob_mean = 0.4901 :  71%|███████▏  | 140/196 [20:59<08:07,  8.71s/it] 




Epoch 3, Step 151: Loss = 0.0166, Reward Loss = -0.5042, Mask Loss = 0.5208 mask_mean = 0.2990 prob_mean = 0.5006 :  77%|███████▋  | 150/196 [22:28<06:44,  8.79s/it] 




Epoch 3, Step 161: Loss = 0.0182, Reward Loss = -0.4773, Mask Loss = 0.4955 mask_mean = 0.3055 prob_mean = 0.4756 :  82%|████████▏ | 160/196 [23:57<05:16,  8.79s/it] 




Epoch 3, Step 171: Loss = 0.0204, Reward Loss = -0.5161, Mask Loss = 0.5365 mask_mean = 0.2980 prob_mean = 0.5104 :  87%|████████▋ | 170/196 [25:26<03:49,  8.81s/it] 




Epoch 3, Step 181: Loss = 0.0030, Reward Loss = -0.5273, Mask Loss = 0.5303 mask_mean = 0.2893 prob_mean = 0.5191 :  92%|█████████▏| 180/196 [26:54<02:20,  8.80s/it] 




Epoch 3, Step 191: Loss = -0.0049, Reward Loss = -0.5261, Mask Loss = 0.5212 mask_mean = 0.2866 prob_mean = 0.5227 :  97%|█████████▋| 190/196 [28:23<00:52,  8.80s/it]




Epoch 3, Step 196: Loss = 0.0141, Reward Loss = -0.5318, Mask Loss = 0.5458 mask_mean = 0.2960 prob_mean = 0.5254 : 100%|██████████| 196/196 [29:01<00:00,  8.88s/it] 
Epoch 4, Step 1: Loss = 0.0210, Reward Loss = -0.5159, Mask Loss = 0.5369 mask_mean = 0.2915 prob_mean = 0.5090 :   0%|          | 0/196 [00:08<?, ?it/s]




Epoch 4, Step 11: Loss = -0.0167, Reward Loss = -0.5221, Mask Loss = 0.5054 mask_mean = 0.2875 prob_mean = 0.5187 :   5%|▌         | 10/196 [01:37<27:27,  8.86s/it]




Epoch 4, Step 21: Loss = -0.0047, Reward Loss = -0.4881, Mask Loss = 0.4834 mask_mean = 0.2880 prob_mean = 0.4817 :  10%|█         | 20/196 [03:06<25:53,  8.83s/it]




Epoch 4, Step 31: Loss = 0.0231, Reward Loss = -0.5314, Mask Loss = 0.5545 mask_mean = 0.2961 prob_mean = 0.5255 :  15%|█▌        | 30/196 [04:34<24:18,  8.79s/it] 




Epoch 4, Step 41: Loss = 0.0046, Reward Loss = -0.4937, Mask Loss = 0.4983 mask_mean = 0.2919 prob_mean = 0.4888 :  20%|██        | 40/196 [06:04<23:51,  9.18s/it] 




Epoch 4, Step 51: Loss = 0.0215, Reward Loss = -0.5616, Mask Loss = 0.5832 mask_mean = 0.2927 prob_mean = 0.5591 :  26%|██▌       | 50/196 [07:33<21:38,  8.90s/it]




Epoch 4, Step 61: Loss = 0.0026, Reward Loss = -0.5147, Mask Loss = 0.5173 mask_mean = 0.2848 prob_mean = 0.5098 :  31%|███       | 60/196 [09:02<20:01,  8.83s/it] 




Epoch 4, Step 71: Loss = 0.0186, Reward Loss = -0.5406, Mask Loss = 0.5593 mask_mean = 0.2915 prob_mean = 0.5369 :  36%|███▌      | 70/196 [10:31<18:36,  8.86s/it] 




Epoch 4, Step 81: Loss = 0.0147, Reward Loss = -0.5410, Mask Loss = 0.5557 mask_mean = 0.2937 prob_mean = 0.5363 :  41%|████      | 80/196 [12:00<17:03,  8.83s/it] 




Epoch 4, Step 91: Loss = 0.0097, Reward Loss = -0.5170, Mask Loss = 0.5267 mask_mean = 0.2888 prob_mean = 0.5136 :  46%|████▌     | 90/196 [13:29<15:41,  8.88s/it]




Epoch 4, Step 101: Loss = 0.0147, Reward Loss = -0.5167, Mask Loss = 0.5314 mask_mean = 0.2922 prob_mean = 0.5104 :  51%|█████     | 100/196 [14:57<13:50,  8.65s/it]




Epoch 4, Step 111: Loss = 0.0148, Reward Loss = -0.4993, Mask Loss = 0.5141 mask_mean = 0.2877 prob_mean = 0.4995 :  56%|█████▌    | 110/196 [16:24<12:17,  8.57s/it] 




Epoch 4, Step 121: Loss = -0.0037, Reward Loss = -0.5075, Mask Loss = 0.5038 mask_mean = 0.2869 prob_mean = 0.5027 :  61%|██████    | 120/196 [17:53<11:11,  8.84s/it]




Epoch 4, Step 131: Loss = 0.0076, Reward Loss = -0.5424, Mask Loss = 0.5500 mask_mean = 0.2939 prob_mean = 0.5372 :  66%|██████▋   | 130/196 [19:22<09:46,  8.89s/it] 




Epoch 4, Step 141: Loss = 0.0095, Reward Loss = -0.5138, Mask Loss = 0.5233 mask_mean = 0.2808 prob_mean = 0.5117 :  71%|███████▏  | 140/196 [20:52<08:15,  8.84s/it]




Epoch 4, Step 151: Loss = 0.0303, Reward Loss = -0.5135, Mask Loss = 0.5438 mask_mean = 0.2852 prob_mean = 0.5249 :  77%|███████▋  | 150/196 [22:21<06:43,  8.78s/it] 




Epoch 4, Step 161: Loss = 0.0063, Reward Loss = -0.5072, Mask Loss = 0.5135 mask_mean = 0.2913 prob_mean = 0.5038 :  82%|████████▏ | 160/196 [23:49<05:16,  8.79s/it] 




Epoch 4, Step 171: Loss = 0.0114, Reward Loss = -0.5403, Mask Loss = 0.5517 mask_mean = 0.2844 prob_mean = 0.5348 :  87%|████████▋ | 170/196 [25:16<03:43,  8.60s/it] 




Epoch 4, Step 181: Loss = 0.0181, Reward Loss = -0.5213, Mask Loss = 0.5394 mask_mean = 0.2916 prob_mean = 0.5157 :  92%|█████████▏| 180/196 [26:42<02:17,  8.59s/it] 




Epoch 4, Step 191: Loss = 0.0085, Reward Loss = -0.5344, Mask Loss = 0.5428 mask_mean = 0.2889 prob_mean = 0.5292 :  97%|█████████▋| 190/196 [28:08<00:51,  8.60s/it]




Epoch 4, Step 196: Loss = -0.0490, Reward Loss = -0.5121, Mask Loss = 0.4631 mask_mean = 0.2709 prob_mean = 0.5015 : 100%|██████████| 196/196 [28:46<00:00,  8.81s/it]
Epoch 5, Step 1: Loss = 0.0094, Reward Loss = -0.4847, Mask Loss = 0.4941 mask_mean = 0.2987 prob_mean = 0.4837 :   0%|          | 0/196 [00:08<?, ?it/s]




Epoch 5, Step 11: Loss = 0.0153, Reward Loss = -0.5409, Mask Loss = 0.5562 mask_mean = 0.2996 prob_mean = 0.5344 :   5%|▌         | 10/196 [01:34<26:32,  8.56s/it]




Epoch 5, Step 21: Loss = 0.0077, Reward Loss = -0.5135, Mask Loss = 0.5212 mask_mean = 0.2861 prob_mean = 0.5141 :  10%|█         | 20/196 [03:00<25:08,  8.57s/it] 




Epoch 5, Step 31: Loss = 0.0218, Reward Loss = -0.5298, Mask Loss = 0.5516 mask_mean = 0.2865 prob_mean = 0.5259 :  15%|█▌        | 30/196 [04:26<23:49,  8.61s/it] 




Epoch 5, Step 41: Loss = 0.0596, Reward Loss = -0.5520, Mask Loss = 0.6117 mask_mean = 0.3082 prob_mean = 0.5513 :  20%|██        | 40/196 [05:52<22:15,  8.56s/it]




Epoch 5, Step 51: Loss = 0.0112, Reward Loss = -0.5537, Mask Loss = 0.5649 mask_mean = 0.2908 prob_mean = 0.5535 :  26%|██▌       | 50/196 [07:18<20:47,  8.55s/it] 




Epoch 5, Step 61: Loss = 0.0143, Reward Loss = -0.5188, Mask Loss = 0.5331 mask_mean = 0.2966 prob_mean = 0.5137 :  31%|███       | 60/196 [08:44<19:24,  8.56s/it]




Epoch 5, Step 71: Loss = 0.0402, Reward Loss = -0.5848, Mask Loss = 0.6250 mask_mean = 0.2997 prob_mean = 0.5783 :  36%|███▌      | 70/196 [10:10<17:59,  8.57s/it]




Epoch 5, Step 81: Loss = -0.0001, Reward Loss = -0.5177, Mask Loss = 0.5176 mask_mean = 0.2880 prob_mean = 0.5161 :  41%|████      | 80/196 [11:36<16:31,  8.55s/it]




Epoch 5, Step 91: Loss = 0.0123, Reward Loss = -0.5702, Mask Loss = 0.5825 mask_mean = 0.2885 prob_mean = 0.5642 :  46%|████▌     | 90/196 [13:02<15:05,  8.54s/it] 




Epoch 5, Step 101: Loss = 0.0017, Reward Loss = -0.5242, Mask Loss = 0.5259 mask_mean = 0.2928 prob_mean = 0.5193 :  51%|█████     | 100/196 [14:29<13:42,  8.57s/it]




Epoch 5, Step 111: Loss = 0.0383, Reward Loss = -0.5361, Mask Loss = 0.5744 mask_mean = 0.2977 prob_mean = 0.5318 :  56%|█████▌    | 110/196 [15:55<12:14,  8.54s/it]




Epoch 5, Step 121: Loss = 0.0118, Reward Loss = -0.5311, Mask Loss = 0.5429 mask_mean = 0.2961 prob_mean = 0.5264 :  61%|██████    | 120/196 [17:21<10:50,  8.56s/it] 




Epoch 5, Step 131: Loss = 0.0265, Reward Loss = -0.5602, Mask Loss = 0.5867 mask_mean = 0.2927 prob_mean = 0.5561 :  66%|██████▋   | 130/196 [18:47<09:23,  8.54s/it] 




Epoch 5, Step 141: Loss = -0.0099, Reward Loss = -0.5237, Mask Loss = 0.5138 mask_mean = 0.2811 prob_mean = 0.5176 :  71%|███████▏  | 140/196 [20:13<07:58,  8.55s/it]




Epoch 5, Step 151: Loss = 0.0363, Reward Loss = -0.5208, Mask Loss = 0.5571 mask_mean = 0.2943 prob_mean = 0.5195 :  77%|███████▋  | 150/196 [21:38<06:33,  8.56s/it] 




Epoch 5, Step 161: Loss = 0.0278, Reward Loss = -0.5097, Mask Loss = 0.5374 mask_mean = 0.2888 prob_mean = 0.5073 :  82%|████████▏ | 160/196 [23:04<05:07,  8.55s/it] 




Epoch 5, Step 171: Loss = 0.0127, Reward Loss = -0.5375, Mask Loss = 0.5502 mask_mean = 0.2995 prob_mean = 0.5321 :  87%|████████▋ | 170/196 [24:30<03:42,  8.57s/it] 




Epoch 5, Step 181: Loss = 0.0048, Reward Loss = -0.5280, Mask Loss = 0.5328 mask_mean = 0.2879 prob_mean = 0.5249 :  92%|█████████▏| 180/196 [25:56<02:16,  8.54s/it] 




Epoch 5, Step 191: Loss = -0.0034, Reward Loss = -0.5251, Mask Loss = 0.5217 mask_mean = 0.2766 prob_mean = 0.5279 :  97%|█████████▋| 190/196 [27:22<00:51,  8.54s/it]




Epoch 5, Step 196: Loss = -0.0547, Reward Loss = -0.4542, Mask Loss = 0.3995 mask_mean = 0.2430 prob_mean = 0.4419 : 100%|██████████| 196/196 [28:00<00:00,  8.57s/it]
Epoch 6, Step 1: Loss = 0.0061, Reward Loss = -0.5318, Mask Loss = 0.5379 mask_mean = 0.2810 prob_mean = 0.5288 :   0%|          | 0/196 [00:08<?, ?it/s]




Epoch 6, Step 11: Loss = 0.0106, Reward Loss = -0.5660, Mask Loss = 0.5766 mask_mean = 0.2802 prob_mean = 0.5622 :   5%|▌         | 10/196 [01:34<26:30,  8.55s/it]




Epoch 6, Step 21: Loss = 0.0323, Reward Loss = -0.5376, Mask Loss = 0.5698 mask_mean = 0.2894 prob_mean = 0.5342 :  10%|█         | 20/196 [03:00<25:10,  8.58s/it] 




Epoch 6, Step 31: Loss = 0.0082, Reward Loss = -0.5288, Mask Loss = 0.5370 mask_mean = 0.2826 prob_mean = 0.5259 :  15%|█▌        | 30/196 [04:26<23:40,  8.56s/it] 




Epoch 6, Step 41: Loss = -0.0142, Reward Loss = -0.5129, Mask Loss = 0.4986 mask_mean = 0.2831 prob_mean = 0.5112 :  20%|██        | 40/196 [05:51<22:08,  8.52s/it]




Epoch 6, Step 51: Loss = -0.0294, Reward Loss = -0.5041, Mask Loss = 0.4748 mask_mean = 0.2806 prob_mean = 0.4965 :  26%|██▌       | 50/196 [07:17<20:48,  8.55s/it]




Epoch 6, Step 61: Loss = 0.0210, Reward Loss = -0.5293, Mask Loss = 0.5503 mask_mean = 0.2954 prob_mean = 0.5262 :  31%|███       | 60/196 [08:43<19:24,  8.56s/it] 




Epoch 6, Step 71: Loss = 0.0094, Reward Loss = -0.5479, Mask Loss = 0.5573 mask_mean = 0.2816 prob_mean = 0.5453 :  36%|███▌      | 70/196 [10:09<17:57,  8.55s/it] 




Epoch 6, Step 81: Loss = -0.0178, Reward Loss = -0.5045, Mask Loss = 0.4867 mask_mean = 0.2723 prob_mean = 0.5025 :  41%|████      | 80/196 [11:35<16:33,  8.57s/it]




Epoch 6, Step 91: Loss = -0.0063, Reward Loss = -0.5096, Mask Loss = 0.5034 mask_mean = 0.2771 prob_mean = 0.5020 :  46%|████▌     | 90/196 [13:01<15:08,  8.57s/it]




Epoch 6, Step 101: Loss = 0.0218, Reward Loss = -0.5055, Mask Loss = 0.5273 mask_mean = 0.2953 prob_mean = 0.4976 :  51%|█████     | 100/196 [14:27<13:43,  8.58s/it] 




Epoch 6, Step 111: Loss = 0.0229, Reward Loss = -0.5222, Mask Loss = 0.5451 mask_mean = 0.2899 prob_mean = 0.5159 :  56%|█████▌    | 110/196 [15:53<12:17,  8.57s/it] 




Epoch 6, Step 121: Loss = 0.0514, Reward Loss = -0.5637, Mask Loss = 0.6150 mask_mean = 0.3125 prob_mean = 0.5660 :  61%|██████    | 120/196 [17:19<10:52,  8.58s/it]




Epoch 6, Step 131: Loss = 0.0142, Reward Loss = -0.5174, Mask Loss = 0.5316 mask_mean = 0.2921 prob_mean = 0.5098 :  66%|██████▋   | 130/196 [18:45<09:24,  8.55s/it]




Epoch 6, Step 141: Loss = 0.0280, Reward Loss = -0.5080, Mask Loss = 0.5361 mask_mean = 0.2995 prob_mean = 0.5067 :  71%|███████▏  | 140/196 [20:11<07:58,  8.54s/it] 




Epoch 6, Step 151: Loss = 0.0432, Reward Loss = -0.5620, Mask Loss = 0.6052 mask_mean = 0.2931 prob_mean = 0.5597 :  77%|███████▋  | 150/196 [21:37<06:32,  8.54s/it]




Epoch 6, Step 161: Loss = 0.0233, Reward Loss = -0.5210, Mask Loss = 0.5443 mask_mean = 0.2910 prob_mean = 0.5139 :  82%|████████▏ | 160/196 [23:03<05:06,  8.53s/it] 




Epoch 6, Step 171: Loss = 0.0203, Reward Loss = -0.5429, Mask Loss = 0.5632 mask_mean = 0.2846 prob_mean = 0.5385 :  87%|████████▋ | 170/196 [24:29<03:41,  8.53s/it]




Epoch 6, Step 181: Loss = -0.0070, Reward Loss = -0.5195, Mask Loss = 0.5125 mask_mean = 0.2795 prob_mean = 0.5140 :  92%|█████████▏| 180/196 [25:54<02:16,  8.51s/it]




Epoch 6, Step 191: Loss = 0.0221, Reward Loss = -0.5555, Mask Loss = 0.5776 mask_mean = 0.2878 prob_mean = 0.5492 :  97%|█████████▋| 190/196 [27:20<00:51,  8.56s/it] 




Epoch 6, Step 196: Loss = -0.0039, Reward Loss = -0.5674, Mask Loss = 0.5635 mask_mean = 0.3003 prob_mean = 0.5542 : 100%|██████████| 196/196 [27:58<00:00,  8.56s/it]
Epoch 7, Step 1: Loss = 0.0238, Reward Loss = -0.5880, Mask Loss = 0.6118 mask_mean = 0.2856 prob_mean = 0.5830 :   0%|          | 0/196 [00:08<?, ?it/s]




Epoch 7, Step 11: Loss = 0.0064, Reward Loss = -0.5229, Mask Loss = 0.5293 mask_mean = 0.2865 prob_mean = 0.5168 :   5%|▌         | 10/196 [01:34<26:24,  8.52s/it] 




Epoch 7, Step 21: Loss = 0.0099, Reward Loss = -0.5009, Mask Loss = 0.5108 mask_mean = 0.2907 prob_mean = 0.4960 :  10%|█         | 20/196 [03:00<25:05,  8.56s/it] 




Epoch 7, Step 31: Loss = 0.0179, Reward Loss = -0.5524, Mask Loss = 0.5703 mask_mean = 0.2843 prob_mean = 0.5505 :  15%|█▌        | 30/196 [04:26<23:44,  8.58s/it] 




Epoch 7, Step 41: Loss = 0.0115, Reward Loss = -0.5248, Mask Loss = 0.5363 mask_mean = 0.2924 prob_mean = 0.5189 :  20%|██        | 40/196 [05:52<22:11,  8.53s/it] 




Epoch 7, Step 51: Loss = 0.0080, Reward Loss = -0.5178, Mask Loss = 0.5258 mask_mean = 0.2902 prob_mean = 0.5105 :  26%|██▌       | 50/196 [07:18<20:48,  8.55s/it] 




Epoch 7, Step 61: Loss = 0.0051, Reward Loss = -0.5718, Mask Loss = 0.5769 mask_mean = 0.2828 prob_mean = 0.5679 :  31%|███       | 60/196 [08:44<19:21,  8.54s/it] 




Epoch 7, Step 71: Loss = 0.0512, Reward Loss = -0.5572, Mask Loss = 0.6084 mask_mean = 0.3055 prob_mean = 0.5483 :  36%|███▌      | 70/196 [10:10<17:52,  8.51s/it] 




Epoch 7, Step 81: Loss = 0.0431, Reward Loss = -0.5419, Mask Loss = 0.5850 mask_mean = 0.3035 prob_mean = 0.5366 :  41%|████      | 80/196 [11:36<16:32,  8.55s/it] 




Epoch 7, Step 91: Loss = 0.0181, Reward Loss = -0.5246, Mask Loss = 0.5426 mask_mean = 0.2870 prob_mean = 0.5219 :  46%|████▌     | 90/196 [13:02<15:07,  8.56s/it]




Epoch 7, Step 101: Loss = -0.0047, Reward Loss = -0.5075, Mask Loss = 0.5028 mask_mean = 0.2787 prob_mean = 0.5036 :  51%|█████     | 100/196 [14:27<13:36,  8.51s/it]




Epoch 7, Step 111: Loss = -0.0007, Reward Loss = -0.5154, Mask Loss = 0.5147 mask_mean = 0.2806 prob_mean = 0.5129 :  56%|█████▌    | 110/196 [15:54<12:16,  8.57s/it]




Epoch 7, Step 121: Loss = 0.0055, Reward Loss = -0.5256, Mask Loss = 0.5310 mask_mean = 0.2826 prob_mean = 0.5223 :  61%|██████    | 120/196 [17:19<10:48,  8.53s/it] 




Epoch 7, Step 131: Loss = 0.0213, Reward Loss = -0.5472, Mask Loss = 0.5686 mask_mean = 0.2990 prob_mean = 0.5398 :  66%|██████▋   | 130/196 [18:45<09:24,  8.55s/it]




Epoch 7, Step 141: Loss = 0.0094, Reward Loss = -0.5819, Mask Loss = 0.5913 mask_mean = 0.2911 prob_mean = 0.5724 :  71%|███████▏  | 140/196 [20:11<07:57,  8.53s/it]




Epoch 7, Step 151: Loss = 0.0399, Reward Loss = -0.5736, Mask Loss = 0.6136 mask_mean = 0.2966 prob_mean = 0.5717 :  77%|███████▋  | 150/196 [21:37<06:34,  8.58s/it]




Epoch 7, Step 161: Loss = 0.0213, Reward Loss = -0.5617, Mask Loss = 0.5830 mask_mean = 0.2952 prob_mean = 0.5589 :  82%|████████▏ | 160/196 [23:03<05:07,  8.54s/it] 




Epoch 7, Step 171: Loss = -0.0062, Reward Loss = -0.5369, Mask Loss = 0.5307 mask_mean = 0.2796 prob_mean = 0.5326 :  87%|████████▋ | 170/196 [24:29<03:42,  8.56s/it]




Epoch 7, Step 181: Loss = -0.0036, Reward Loss = -0.4815, Mask Loss = 0.4779 mask_mean = 0.2811 prob_mean = 0.4794 :  92%|█████████▏| 180/196 [25:55<02:16,  8.55s/it]




Epoch 7, Step 191: Loss = -0.0107, Reward Loss = -0.4910, Mask Loss = 0.4802 mask_mean = 0.2787 prob_mean = 0.4866 :  97%|█████████▋| 190/196 [27:21<00:51,  8.53s/it]




Epoch 7, Step 196: Loss = 0.0420, Reward Loss = -0.4664, Mask Loss = 0.5084 mask_mean = 0.3009 prob_mean = 0.4820 : 100%|██████████| 196/196 [27:58<00:00,  8.57s/it] 
Epoch 8, Step 1: Loss = 0.0031, Reward Loss = -0.5359, Mask Loss = 0.5390 mask_mean = 0.2840 prob_mean = 0.5300 :   0%|          | 0/196 [00:08<?, ?it/s]




Epoch 8, Step 11: Loss = 0.0078, Reward Loss = -0.4908, Mask Loss = 0.4986 mask_mean = 0.2793 prob_mean = 0.4933 :   5%|▌         | 10/196 [01:35<26:46,  8.64s/it]




Epoch 8, Step 21: Loss = 0.0134, Reward Loss = -0.5259, Mask Loss = 0.5393 mask_mean = 0.2776 prob_mean = 0.5213 :  10%|█         | 20/196 [03:00<25:02,  8.54s/it] 




Epoch 8, Step 31: Loss = 0.0359, Reward Loss = -0.5472, Mask Loss = 0.5830 mask_mean = 0.2952 prob_mean = 0.5410 :  15%|█▌        | 30/196 [04:26<23:39,  8.55s/it] 




Epoch 8, Step 41: Loss = 0.0157, Reward Loss = -0.5195, Mask Loss = 0.5352 mask_mean = 0.2947 prob_mean = 0.5139 :  20%|██        | 40/196 [05:52<22:16,  8.57s/it] 




Epoch 8, Step 51: Loss = 0.0180, Reward Loss = -0.5251, Mask Loss = 0.5431 mask_mean = 0.2848 prob_mean = 0.5196 :  26%|██▌       | 50/196 [07:18<20:50,  8.56s/it] 




Epoch 8, Step 61: Loss = 0.0121, Reward Loss = -0.5558, Mask Loss = 0.5679 mask_mean = 0.2911 prob_mean = 0.5490 :  31%|███       | 60/196 [08:44<19:18,  8.52s/it] 




Epoch 8, Step 71: Loss = 0.0243, Reward Loss = -0.5445, Mask Loss = 0.5688 mask_mean = 0.3040 prob_mean = 0.5376 :  36%|███▌      | 70/196 [10:10<17:58,  8.56s/it]




Epoch 8, Step 81: Loss = 0.0144, Reward Loss = -0.5588, Mask Loss = 0.5732 mask_mean = 0.2866 prob_mean = 0.5499 :  41%|████      | 80/196 [11:35<16:30,  8.54s/it] 




Epoch 8, Step 91: Loss = 0.0317, Reward Loss = -0.5396, Mask Loss = 0.5713 mask_mean = 0.2848 prob_mean = 0.5324 :  46%|████▌     | 90/196 [13:01<15:05,  8.55s/it] 




Epoch 8, Step 101: Loss = 0.0083, Reward Loss = -0.5574, Mask Loss = 0.5657 mask_mean = 0.2874 prob_mean = 0.5545 :  51%|█████     | 100/196 [14:28<13:45,  8.60s/it]




Epoch 8, Step 111: Loss = 0.0234, Reward Loss = -0.5428, Mask Loss = 0.5662 mask_mean = 0.2933 prob_mean = 0.5405 :  56%|█████▌    | 110/196 [15:53<12:12,  8.52s/it] 




Epoch 8, Step 121: Loss = 0.0117, Reward Loss = -0.5381, Mask Loss = 0.5499 mask_mean = 0.2911 prob_mean = 0.5296 :  61%|██████    | 120/196 [17:19<10:50,  8.56s/it] 




Epoch 8, Step 131: Loss = 0.0046, Reward Loss = -0.5683, Mask Loss = 0.5729 mask_mean = 0.2890 prob_mean = 0.5615 :  66%|██████▋   | 130/196 [18:45<09:23,  8.54s/it] 




Epoch 8, Step 141: Loss = -0.0017, Reward Loss = -0.5314, Mask Loss = 0.5297 mask_mean = 0.2786 prob_mean = 0.5278 :  71%|███████▏  | 140/196 [20:11<07:59,  8.56s/it]




Epoch 8, Step 151: Loss = 0.0107, Reward Loss = -0.5299, Mask Loss = 0.5406 mask_mean = 0.2926 prob_mean = 0.5269 :  77%|███████▋  | 150/196 [21:37<06:32,  8.54s/it] 




Epoch 8, Step 161: Loss = 0.0094, Reward Loss = -0.5290, Mask Loss = 0.5384 mask_mean = 0.2842 prob_mean = 0.5291 :  82%|████████▏ | 160/196 [23:03<05:10,  8.62s/it]




Epoch 8, Step 171: Loss = -0.0167, Reward Loss = -0.4922, Mask Loss = 0.4755 mask_mean = 0.2730 prob_mean = 0.4881 :  87%|████████▋ | 170/196 [24:29<03:42,  8.55s/it]




Epoch 8, Step 181: Loss = -0.0041, Reward Loss = -0.5273, Mask Loss = 0.5232 mask_mean = 0.2755 prob_mean = 0.5209 :  92%|█████████▏| 180/196 [25:55<02:17,  8.59s/it]




Epoch 8, Step 191: Loss = 0.0189, Reward Loss = -0.5315, Mask Loss = 0.5504 mask_mean = 0.2937 prob_mean = 0.5252 :  97%|█████████▋| 190/196 [27:21<00:51,  8.57s/it] 




Epoch 8, Step 196: Loss = 0.0130, Reward Loss = -0.4668, Mask Loss = 0.4797 mask_mean = 0.2849 prob_mean = 0.4693 : 100%|██████████| 196/196 [27:59<00:00,  8.57s/it]
Epoch 9, Step 1: Loss = 0.0288, Reward Loss = -0.5150, Mask Loss = 0.5438 mask_mean = 0.2896 prob_mean = 0.5146 :   0%|          | 0/196 [00:08<?, ?it/s]




Epoch 9, Step 11: Loss = -0.0125, Reward Loss = -0.4946, Mask Loss = 0.4821 mask_mean = 0.2754 prob_mean = 0.4892 :   5%|▌         | 10/196 [01:34<26:28,  8.54s/it]




Epoch 9, Step 21: Loss = 0.0032, Reward Loss = -0.5257, Mask Loss = 0.5288 mask_mean = 0.2766 prob_mean = 0.5216 :  10%|█         | 20/196 [03:00<25:10,  8.58s/it] 




Epoch 9, Step 31: Loss = 0.0215, Reward Loss = -0.5424, Mask Loss = 0.5639 mask_mean = 0.2949 prob_mean = 0.5404 :  15%|█▌        | 30/196 [04:26<23:44,  8.58s/it] 




Epoch 9, Step 41: Loss = 0.0081, Reward Loss = -0.5086, Mask Loss = 0.5167 mask_mean = 0.2834 prob_mean = 0.5013 :  20%|██        | 40/196 [05:52<22:17,  8.57s/it] 




Epoch 9, Step 51: Loss = -0.0065, Reward Loss = -0.4900, Mask Loss = 0.4835 mask_mean = 0.2804 prob_mean = 0.4858 :  26%|██▌       | 50/196 [07:18<20:47,  8.55s/it]




Epoch 9, Step 61: Loss = 0.0173, Reward Loss = -0.5236, Mask Loss = 0.5410 mask_mean = 0.2881 prob_mean = 0.5159 :  31%|███       | 60/196 [08:44<19:23,  8.55s/it] 




Epoch 9, Step 71: Loss = 0.0338, Reward Loss = -0.5303, Mask Loss = 0.5642 mask_mean = 0.2958 prob_mean = 0.5230 :  36%|███▌      | 70/196 [10:10<17:54,  8.53s/it] 




Epoch 9, Step 81: Loss = -0.0071, Reward Loss = -0.5126, Mask Loss = 0.5055 mask_mean = 0.2929 prob_mean = 0.5073 :  41%|████      | 80/196 [11:35<16:32,  8.56s/it]




Epoch 9, Step 91: Loss = 0.0039, Reward Loss = -0.5340, Mask Loss = 0.5379 mask_mean = 0.2888 prob_mean = 0.5271 :  46%|████▌     | 90/196 [13:01<15:03,  8.53s/it] 




Epoch 9, Step 101: Loss = 0.0098, Reward Loss = -0.5351, Mask Loss = 0.5449 mask_mean = 0.2950 prob_mean = 0.5322 :  51%|█████     | 100/196 [14:27<13:38,  8.53s/it] 




Epoch 9, Step 111: Loss = -0.0106, Reward Loss = -0.5443, Mask Loss = 0.5337 mask_mean = 0.2787 prob_mean = 0.5375 :  56%|█████▌    | 110/196 [15:53<12:15,  8.56s/it]




Epoch 9, Step 121: Loss = 0.0540, Reward Loss = -0.6028, Mask Loss = 0.6568 mask_mean = 0.2960 prob_mean = 0.6048 :  61%|██████    | 120/196 [17:19<10:48,  8.54s/it] 




Epoch 9, Step 131: Loss = 0.0280, Reward Loss = -0.5497, Mask Loss = 0.5778 mask_mean = 0.2931 prob_mean = 0.5459 :  66%|██████▋   | 130/196 [18:45<09:22,  8.52s/it] 




Epoch 9, Step 141: Loss = 0.0174, Reward Loss = -0.4935, Mask Loss = 0.5109 mask_mean = 0.2991 prob_mean = 0.4864 :  71%|███████▏  | 140/196 [20:11<07:58,  8.55s/it] 




Epoch 9, Step 151: Loss = 0.0258, Reward Loss = -0.5542, Mask Loss = 0.5801 mask_mean = 0.2928 prob_mean = 0.5517 :  77%|███████▋  | 150/196 [21:37<06:32,  8.54s/it] 




Epoch 9, Step 161: Loss = -0.0097, Reward Loss = -0.5077, Mask Loss = 0.4980 mask_mean = 0.2783 prob_mean = 0.5057 :  82%|████████▏ | 160/196 [23:03<05:07,  8.56s/it]




Epoch 9, Step 171: Loss = 0.0364, Reward Loss = -0.5553, Mask Loss = 0.5916 mask_mean = 0.2951 prob_mean = 0.5584 :  87%|████████▋ | 170/196 [24:28<03:41,  8.52s/it] 




Epoch 9, Step 181: Loss = 0.0007, Reward Loss = -0.5507, Mask Loss = 0.5514 mask_mean = 0.2961 prob_mean = 0.5464 :  92%|█████████▏| 180/196 [25:54<02:16,  8.54s/it] 




Epoch 9, Step 191: Loss = 0.0063, Reward Loss = -0.5314, Mask Loss = 0.5376 mask_mean = 0.2756 prob_mean = 0.5234 :  97%|█████████▋| 190/196 [27:20<00:51,  8.59s/it] 




Epoch 9, Step 196: Loss = -0.0120, Reward Loss = -0.4653, Mask Loss = 0.4533 mask_mean = 0.2794 prob_mean = 0.4577 : 100%|██████████| 196/196 [27:58<00:00,  8.56s/it]
Epoch 10, Step 1: Loss = -0.0033, Reward Loss = -0.4916, Mask Loss = 0.4883 mask_mean = 0.2769 prob_mean = 0.4877 :   0%|          | 0/196 [00:08<?, ?it/s]




Epoch 10, Step 11: Loss = 0.0064, Reward Loss = -0.5547, Mask Loss = 0.5610 mask_mean = 0.2881 prob_mean = 0.5496 :   5%|▌         | 10/196 [01:34<26:24,  8.52s/it]




Epoch 10, Step 21: Loss = 0.0136, Reward Loss = -0.5450, Mask Loss = 0.5587 mask_mean = 0.2891 prob_mean = 0.5402 :  10%|█         | 20/196 [03:00<25:10,  8.58s/it] 




Epoch 10, Step 31: Loss = 0.0169, Reward Loss = -0.5335, Mask Loss = 0.5504 mask_mean = 0.2935 prob_mean = 0.5276 :  15%|█▌        | 30/196 [04:26<23:44,  8.58s/it] 




Epoch 10, Step 41: Loss = 0.0379, Reward Loss = -0.5256, Mask Loss = 0.5635 mask_mean = 0.2952 prob_mean = 0.5185 :  20%|██        | 40/196 [05:52<22:11,  8.54s/it] 




Epoch 10, Step 51: Loss = 0.0375, Reward Loss = -0.5397, Mask Loss = 0.5771 mask_mean = 0.2909 prob_mean = 0.5404 :  26%|██▌       | 50/196 [07:18<20:49,  8.56s/it] 




Epoch 10, Step 61: Loss = 0.0053, Reward Loss = -0.4982, Mask Loss = 0.5035 mask_mean = 0.2830 prob_mean = 0.4931 :  31%|███       | 60/196 [08:44<19:29,  8.60s/it] 




Epoch 10, Step 71: Loss = 0.0473, Reward Loss = -0.5544, Mask Loss = 0.6017 mask_mean = 0.2950 prob_mean = 0.5535 :  36%|███▌      | 70/196 [10:10<17:54,  8.53s/it] 




Epoch 10, Step 81: Loss = 0.0154, Reward Loss = -0.5303, Mask Loss = 0.5457 mask_mean = 0.2897 prob_mean = 0.5239 :  41%|████      | 80/196 [11:36<16:34,  8.57s/it]




Epoch 10, Step 91: Loss = 0.0420, Reward Loss = -0.5705, Mask Loss = 0.6124 mask_mean = 0.2989 prob_mean = 0.5754 :  46%|████▌     | 90/196 [13:02<15:09,  8.58s/it]




Epoch 10, Step 101: Loss = 0.0445, Reward Loss = -0.5833, Mask Loss = 0.6277 mask_mean = 0.3124 prob_mean = 0.5764 :  51%|█████     | 100/196 [14:27<13:39,  8.54s/it]




Epoch 10, Step 111: Loss = 0.0054, Reward Loss = -0.5640, Mask Loss = 0.5694 mask_mean = 0.3010 prob_mean = 0.5538 :  56%|█████▌    | 110/196 [15:53<12:14,  8.54s/it] 




Epoch 10, Step 121: Loss = 0.0148, Reward Loss = -0.5342, Mask Loss = 0.5490 mask_mean = 0.2945 prob_mean = 0.5345 :  61%|██████    | 120/196 [17:20<10:49,  8.55s/it] 




Epoch 10, Step 131: Loss = -0.0026, Reward Loss = -0.5140, Mask Loss = 0.5114 mask_mean = 0.2883 prob_mean = 0.5092 :  66%|██████▋   | 130/196 [18:45<09:23,  8.54s/it]




Epoch 10, Step 141: Loss = -0.0031, Reward Loss = -0.5831, Mask Loss = 0.5800 mask_mean = 0.2923 prob_mean = 0.5740 :  71%|███████▏  | 140/196 [20:11<07:58,  8.54s/it]




Epoch 10, Step 151: Loss = 0.0044, Reward Loss = -0.5491, Mask Loss = 0.5535 mask_mean = 0.2807 prob_mean = 0.5443 :  77%|███████▋  | 150/196 [21:37<06:32,  8.53s/it] 




Epoch 10, Step 161: Loss = 0.0244, Reward Loss = -0.5269, Mask Loss = 0.5513 mask_mean = 0.2901 prob_mean = 0.5259 :  82%|████████▏ | 160/196 [23:03<05:07,  8.53s/it] 




Epoch 10, Step 171: Loss = 0.0182, Reward Loss = -0.5488, Mask Loss = 0.5670 mask_mean = 0.2820 prob_mean = 0.5444 :  87%|████████▋ | 170/196 [24:29<03:41,  8.53s/it] 




Epoch 10, Step 181: Loss = 0.0210, Reward Loss = -0.5423, Mask Loss = 0.5632 mask_mean = 0.2884 prob_mean = 0.5358 :  92%|█████████▏| 180/196 [25:55<02:16,  8.54s/it] 




Epoch 10, Step 191: Loss = 0.0488, Reward Loss = -0.5612, Mask Loss = 0.6099 mask_mean = 0.3046 prob_mean = 0.5654 :  97%|█████████▋| 190/196 [27:21<00:51,  8.58s/it]




Epoch 10, Step 196: Loss = 0.0552, Reward Loss = -0.5665, Mask Loss = 0.6217 mask_mean = 0.3053 prob_mean = 0.5602 : 100%|██████████| 196/196 [27:59<00:00,  8.57s/it]


In [7]:
3262

3262