In [1]:
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, DistilBertConfig
from datasets import load_dataset,load_metric
import numpy as np


from accelerate import Accelerator


accelerator = Accelerator()
device = accelerator.device




  from .autonotebook import tqdm as notebook_tqdm
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


#  Prediction model example

In [21]:
texts = "I like this movie!"

pretrained_name = "distilbert-base-uncased-finetuned-sst-2-english"

pred_config = DistilBertConfig.from_pretrained(pretrained_name)
pred_tokenizer = DistilBertTokenizer.from_pretrained(pretrained_name)
pred_model = DistilBertForSequenceClassification.from_pretrained(pretrained_name).to(device)

inputs = pred_tokenizer(texts, return_tensors="pt")
with torch.no_grad():
    inputs = {key:val.to(device) for key,val in inputs.items()}
    logits = pred_model(**inputs).logits

predicted_class_id = logits.argmax().item()
print(pred_model.config.id2label[predicted_class_id])

print(inputs)

POSITIVE
{'input_ids': tensor([[ 101, 1045, 2066, 2023, 3185,  999,  102]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}


# Explaination model example

In [28]:
from transformers import CLIPTextModel, CLIPTokenizer

explain_pretrained_name = "openai/clip-vit-large-patch14"

explain_tokenizer = CLIPTokenizer.from_pretrained(explain_pretrained_name)
explain_model = CLIPTextModel.from_pretrained(explain_pretrained_name).to(device)
explain_config = explain_model.config

explain_inputs = explain_tokenizer(texts, return_tensors="pt")

# print(explain_inputs)

# with torch.no_grad():
#     explain_inputs = {key:val.to(device) for key,val in explain_inputs.items()}
#     explain_logits = explain_model(**explain_inputs).logits
print(explain_tokenizer)

<|endoftext|>


In [24]:
pred_hidden_dim = pred_model.config.dim
num_labels = pred_model.config.num_labels
explain_hidden_dim = explain_config.projection_dim
# clip_model.config
print(pred_hidden_dim, explain_hidden_dim, num_labels)

768 768 2


In [6]:
import torch 
import torch.nn as nn
from maskgen.utils import idx_to_selector
import torch.nn.functional as F
import numpy as np
from maskgen.models import MLP
import math
from transformers import CLIPTextModel, CLIPTokenizer
from typing import List


class SimilarityMeasure(nn.Module):
    def __init__(self, pred_hidden_size, explain_hidden_size, embed_size=512):
        super(SimilarityMeasure, self).__init__()

        self.pred_map = MLP(pred_hidden_size, 128, embed_size, num_blocks=2, bottleneck_dim=64)
        self.explain_map = MLP(explain_hidden_size, 128, embed_size, num_blocks=2, bottleneck_dim=64)

        self.logit_scale = nn.Parameter(torch.tensor(1.0))
    
    def forward(self, pred_feature, explain_features):
        """
        Forward pass of the model.

        Args:
            q (torch.Tensor): Query tensor of shape [N, pred_hidden_size].
            k (torch.Tensor): Key tensor of shape [N, L, explain_hidden_size].

        Returns:
            torch.Tensor: Similarity tensor of shape [N, L].
        """
        pred_feature = F.normalize(self.pred_map(pred_feature), p=2, dim=-1).unsqueeze(1)  # [N, 1, embed_size]
        explain_features = F.normalize(self.explain_map(explain_features), p=2, dim=-1)  # [N, L, embed_size]


        logit_scale = self.logit_scale.exp()

        similarity = torch.matmul(explain_features, pred_feature.transpose(-1, -2)).squeeze(-1) * logit_scale  # [N, L]

        return similarity  # [N, L]


class MaskGeneratingModel(nn.Module):
    def __init__(self, pred_model: nn.Module, hidden_size, num_classes):
        super().__init__()

        self.explain_model = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
        self.pred_model = pred_model
        self.num_classes = num_classes

        explain_hidden_size = self.explain_model.config.projection_dim
        pred_hidden_size = hidden_size
        self.similarity_measure = SimilarityMeasure(pred_hidden_size, explain_hidden_size)

        self.mask_token = 103

        self.bce_loss = nn.BCELoss(reduction='none')

        self.freeze_params()


    def freeze_params(self):
        """
        Freezes the parameters of the ViT and prediction model.

        This method sets the `requires_grad` attribute of all parameters in the ViT and prediction model to False,
        effectively freezing them and preventing them from being updated during training.
        """
        for param in self.explain_model.parameters():
            param.requires_grad = False
        for param in self.pred_model.parameters():
            param.requires_grad = False
    
    def get_interpretable_text_features(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        """
        Extract interpretable features using the text part of CLIP model.

        Args:
            x: An image tensor of shape [N, C, H, W].

        Returns:
            Interpretable features of shape [N, L, d].
        """
        # get the output of the ViT model
        print(input_ids, attention_mask)
        output = self.explain_model(input_ids, attention_mask)
        # get the last hidden state, exclude the cls token
        hidden_states = output['last_hidden_state'][:, 1:-1, :]  # [N, L, d], the first and last tokens are [CLS] and [SEP]
        return hidden_states
    
    def get_original_text_feature(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, predicted_class_selector: torch.Tensor):
        """
        Extract the original feature using the original prediction model.

        Args:
            input_ids: The input tensor of shape [N, L].

        Returns:
            Original feature of shape [N, d].
        """
        # get the output of the prediction model
        output = self.pred_model.distilbert(input_ids, attention_mask)
        # get the first hidden state, which corresponds to the cls token
        hidden_state = output[0] # [N, d]
        pooled_output = hidden_state[:, 0, :] # [N, d]
        pooled_output = self.pred_model.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)

        W = self.pred_model.classifier.weight # [n_classes, d]
        original_feature = pooled_output.unsqueeze(1) * W.unsqueeze(0)   # [N, n_classes, d]
        original_feature = (original_feature * predicted_class_selector.unsqueeze(-1)).sum(1) # [N, d]
        return original_feature

    def get_predicted_class_selector(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
            """
            Returns the predicted class selector for the given input tensor, with the predicted class set to 1 and all other classes set to 0.

            Args:
                x (torch.Tensor): The input tensor of shape [N, C, H, W].

            Returns:
                torch.Tensor: The predicted class selector tensor of shape [N, C], where N is the batch size and C is the number of classes.
            """
            logits = self.pred_model(input_ids, attention_mask).logits # [N, n_classes]
            predicted_class_idx = logits.argmax(-1) # [N, 1]
            predicted_class_selector = idx_to_selector(predicted_class_idx, self.num_classes) # [N, n_classes]
            return predicted_class_selector


    def loss_func(self, sim, mask_list, reverse_mask_list, masked_probs_list, reverse_masked_probs_list):
        """Calculate the loss for the given mask.

        Args:
            sim (Tensor): The similarity tensor of shape [N, L].
            mask (Tensor): The generated mask tensor of shape [N, L].
            probs (Tensor): The probability tensor of shape [N,]. Obtained by feeding the randomly generated mask to the prediction model.

        Returns:
            Tensor: The loss tensor of shape [N, L].
        """
        bce_loss = nn.BCELoss(reduction='none')
        n_steps = len(mask_list)
        L = sim.shape[-1]

        # generating probability
        mask_prob = torch.sigmoid(sim).unsqueeze(1).expand(-1, n_steps, -1) # [N, n_steps, L]
        reverse_mask_prob = 1 - mask_prob # [N, n_steps, L]
        # generated mask samples
        mask_samples = torch.stack(mask_list, dim=1) # [N, n_steps, L]
        reverse_mask_samples = torch.stack(reverse_mask_list, dim=1) # [N, n_steps, L]

        # the prediction probability of the masked input
        mask_sample_probs = torch.stack(masked_probs_list, dim=1) # [N, n_steps, 1]
        # the prediction probability of the reverse masked input
        reverse_mask_sample_probs = torch.stack(reverse_masked_probs_list, dim=1) # [N, n_steps, 1]

        # reward loss, if mask_sample_probs is higher, we want to optimize the probability of generating the masks
        reward_loss = bce_loss(mask_prob , mask_samples) # [N, n_steps, L]
        reward_loss = (reward_loss * mask_sample_probs).mean() # [N, n_steps, L]
        # regret loss, if reverse_mask_sample_probs is higher, we want to optimize the probability of generating the reverse masks
        regret_loss = bce_loss(mask_prob, reverse_mask_samples) * reverse_mask_samples # [N, n_steps, L]
        regret_loss = (regret_loss * torch.relu(reverse_mask_sample_probs - 0.1)).mean() 

        # regret_loss = (regret_loss * torch.relu(0.1 - reverse_mask_sample_probs)).mean() # [N, n_steps, L]
        
        # total_reward_loss = reward_loss.sum() / (mask_selector.sum() + 1e-5)  + regret_loss.sum() / ((1 - mask_selector).sum() + 1e-5) # [1]

        # mask_loss
        mask_loss = (mask_prob * mask_samples).sum() / mask_samples.sum() # [1]

        alt_mask_loss = ((0.1 - mask_prob.mean([-1, -2]) - mask_sample_probs.mean([-1, -2])) ** 2).mean()
        # alt_mask_loss_pos = torch.relu(0.1 - mask_prob.mean([-1, -2]) - mask_sample_probs.mean([-1, -2]))
        # alt_mask_loss_neg = torch.relu(mask_prob.mean([-1, -2]) + mask_sample_probs.mean([-1, -2]) - 0.1)
        # alt_mask_loss = (alt_mask_loss_pos + alt_mask_loss_neg).mean()
        

        loss =  reward_loss + 0.01 * alt_mask_loss
        mask_mean = mask_prob.mean([1, 2]) # [N]
        prob_mean = mask_sample_probs.mean([1, 2]) # [N]

        # print("mask_prob: ", mask_prob[0, 0])
        # print("mask_samples: ", mask_samples[0, 0])


        return {'loss': loss,
                'reward_loss': reward_loss,
                'regret_loss': regret_loss,
                'mask_mean': mask_mean.mean(),
                'prob_mean': prob_mean.mean(),
                'mask_loss': mask_loss,
                'alt_mask_loss': alt_mask_loss}

    def generate_mask(self, sim):
        """Generate a mask based on the similarity tensor. [generate action based on policy]

        Args:
            sim (Tensor): The similarity tensor of shape [N, L].

        Returns:
            Tensor: The generated mask tensor of shape [N, L].
        """
        with torch.no_grad():
            mask_prob = torch.sigmoid(sim)
            # scaler = torch.rand(sim.shape[0], 1, device=sim.device) 

            # mask_prob = (mask_prob - 0) / (1e-5 + mask_prob.max(dim=-1, keepdim=True)[0])
            # mask_prob = mask_prob * scaler
            # mask_prob = torch.clamp(mask_prob, 0.2, 1.0) # prevent the mask_prob from being too close to 0 or 1

            # sample a mask (action) based on the mask probability (policy)
            mask = torch.bernoulli(mask_prob) # [N, L]
        
        return mask # [N, L]
    
    def get_mask_probs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, mask: torch.Tensor, predicted_class_selector: torch.Tensor):
        # No gradients upon the parameters of the prediction model
        with torch.no_grad():
            masked_input_ids = input_ids * mask + (1 - mask) * self.mask_token  # [N, L]
            masked_input_ids = masked_input_ids.long()

            masked_probs = torch.softmax(self.pred_model(masked_input_ids, attention_mask).logits, dim=-1) # [N, n_classes]
            masked_probs = (masked_probs * predicted_class_selector).sum(-1, keepdim=True) # [N, 1]

        return masked_probs


    def sample_one_step(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, sim: torch.Tensor, predicted_class_selector: torch.Tensor):
        with torch.no_grad():
            mask = self.generate_mask(sim)
            mask_probs = self.get_mask_probs(input_ids, attention_mask, mask, predicted_class_selector)
            # reverse_mask = 1.0 - mask
            # reverse_mask_probs = self.get_mask_probs(x, reverse_mask, predicted_class_selector)
        return mask, mask_probs #, reverse_mask, reverse_mask_probs


    def train_one_batch(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, optimizer: torch.optim.Optimizer, n_steps=10):
        self.train()
        optimizer.zero_grad()
        predicted_class_selector = self.get_predicted_class_selector(input_ids, attention_mask)
        outputs = self.forward(input_ids, attention_mask, predicted_class_selector)
        sim = outputs['sim']
        

        mask_list, reverse_mask_list = [], []
        masked_probs_list, reverse_masked_probs_list = [], []
        for idx in range(n_steps):

            mask, masked_probs = self.sample_one_step(input_ids, attention_mask, sim, predicted_class_selector)
            reverse_mask, reverse_masked_probs = self.sample_one_step(input_ids, attention_mask, -sim, predicted_class_selector)

            mask_list.append(mask)
            reverse_mask_list.append(reverse_mask)
            masked_probs_list.append(masked_probs)
            reverse_masked_probs_list.append(reverse_masked_probs)
        
        loss_dict = self.loss_func(sim, mask_list, reverse_mask_list, masked_probs_list, reverse_masked_probs_list)
        # loss_dict = self.loss_func(sim, mask_list, masked_probs_list)

        loss = loss_dict['loss']
        loss.backward()
        optimizer.step()
        return loss_dict

    
    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, predicted_class_selector: torch.Tensor):
        original_feature = self.get_original_text_feature(input_ids, attention_mask, predicted_class_selector) # [N, d]

        interpretable_features = self.get_interpretable_text_features(input_ids, attention_mask) # [N, L, d]

        
        sim = self.similarity_measure(pred_feature=original_feature, explain_features=interpretable_features) # [N, L]
        return {'sim': sim}

    
    def attribute_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        with torch.no_grad():
            predicted_class_selector = self.get_predicted_class_selector(input_ids, attention_mask)
            outputs = self.forward(input_ids, attention_mask, predicted_class_selector)
            probs = torch.sigmoid(outputs['sim'])
        return probs
    

In [7]:
# from maskgen.models.mask_generating_model12 import MaskGeneratingModel

pred_hidden_dim = pred_model.config.dim
num_labels = pred_model.config.num_labels

mask_gen_model = MaskGeneratingModel(pred_model, hidden_size=pred_hidden_dim, num_classes=num_labels)
mask_gen_model.to(device)
print()




# Load dataset

In [8]:
# from datasets import load_dataset
# imdb = load_dataset("imdb")
# texts = imdb["test"][0]['text']
# print(texts)

# inputs = tokenizer(texts, return_tensors="pt")
# with torch.no_grad():
#     inputs = {key:val.to(device) for key,val in inputs.items()}
#     logits = pred_model(**inputs).logits

# predicted_class_id = logits.argmax().item()
# pred_model.config.id2label[predicted_class_id]

In [9]:
mask_gen_model.attribute_text(inputs['input_ids'], inputs['attention_mask'])

tensor([[ 101, 1045,  103, 2023, 3185,  999,  102]], device='cuda:0') tensor([[1, 1, 1, 1, 1, 1, 1]], device='cuda:0')


tensor([[0.4772, 0.4838, 0.4667, 0.4982, 0.4786]], device='cuda:0')

# training

In [10]:
from torch.utils.data import DataLoader
def load_data(seed=42): 
    dataset = load_dataset("mrm8488/ImageNet1K-val")
    dataset = dataset['train']
    splits = dataset.train_test_split(test_size=0.1, seed=seed)
    test_ds = splits['test']
    splits = splits['train'].train_test_split(test_size=0.1, seed=seed)
    train_ds = splits['train']
    val_ds = splits['test']
    return train_ds, val_ds, test_ds

train_ds, val_ds, test_ds = load_data()

normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
if "height" in processor.size:
    size = (processor.size["height"], processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in processor.size:
    size = processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = processor.size.get("longest_edge")

transforms = Compose(
        [
            RandomResizedCrop(crop_size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

def preprocess(example_batch):
    """Apply val_transforms across a batch."""
    example_batch["pixel_values"] = [transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

train_ds.set_transform(preprocess)


Repo card metadata block was not found. Setting CardData to empty.


NameError: name 'Normalize' is not defined

# training params

In [None]:
batch_size = 256
num_workers = 8
train_dataloader = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=num_workers)
n_steps = 10

params_to_optimize = [name for name, param in mask_gen_model.named_parameters() if param.requires_grad]
print("params to be optimized: ")
print(params_to_optimize)

params to be optimized: 
['similarity_measure.logit_scale', 'similarity_measure.pred_map.input_layer.weight', 'similarity_measure.pred_map.input_layer.bias', 'similarity_measure.pred_map.layers.0.0.weight', 'similarity_measure.pred_map.layers.0.0.bias', 'similarity_measure.pred_map.layers.0.3.weight', 'similarity_measure.pred_map.layers.0.3.bias', 'similarity_measure.pred_map.layers.0.6.weight', 'similarity_measure.pred_map.layers.0.6.bias', 'similarity_measure.pred_map.layers.1.0.weight', 'similarity_measure.pred_map.layers.1.0.bias', 'similarity_measure.pred_map.layers.1.3.weight', 'similarity_measure.pred_map.layers.1.3.bias', 'similarity_measure.pred_map.layers.1.6.weight', 'similarity_measure.pred_map.layers.1.6.bias', 'similarity_measure.pred_map.output_layer.weight', 'similarity_measure.pred_map.output_layer.bias', 'similarity_measure.explain_map.input_layer.weight', 'similarity_measure.explain_map.input_layer.bias', 'similarity_measure.explain_map.layers.0.0.weight', 'similarit

In [None]:
from tqdm import tqdm

params_to_optimize = [param for param in mask_gen_model.parameters() if param.requires_grad]
# optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3, weight_decay=1e-5)
optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3, weight_decay=1e-5)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.9)


print()

for epoch in range(10):
    pbar = tqdm(train_dataloader)
    for idx, data in enumerate(pbar):
        pixel_values = data['pixel_values'].to(device)
        loss_dict = mask_gen_model.train_one_batch(pixel_values, optimizer=optimizer, n_steps=n_steps)
        # scheduler.step()
        pbar.set_description(f"Epoch {epoch+1}, Step {idx+1}: Loss = {loss_dict['loss'].item():.4f}, " 
                             f"Reward Loss = {loss_dict['reward_loss'].item():.4f}, "
                             f"Regret Loss = {loss_dict['regret_loss'].item():.4f}, "
                             f"Mask Loss = {loss_dict['mask_loss'].item():.4f} "
                             f"alt_mask_loss = {loss_dict['alt_mask_loss'].item():.4f}"
                             f"mask_mean = {loss_dict['mask_mean'].item():.4f} "
                             f"prob_mean = {loss_dict['prob_mean'].item():.4f} "
                             )
        if idx % 10 == 0:
            print()
        if (idx) % 30 == 0:
            
            torch.save(mask_gen_model.state_dict(), f'mask_gen_model/mask_gen_model_{epoch}_{idx}.pth') 



torch.save(mask_gen_model.state_dict(), f'mask_gen_model/mask_gen_model_final_{epoch}_{idx}.pth') 





Epoch 1, Step 1: Loss = 0.4140, Reward Loss = 0.4030, Regret Loss = 0.1692, Mask Loss = 0.5087 alt_mask_loss = 1.1055mask_mean = 0.5072 prob_mean = 0.5827 :   0%|          | 0/159 [00:16<?, ?it/s]




Epoch 1, Step 11: Loss = 0.1257, Reward Loss = 0.1238, Regret Loss = 0.8926, Mask Loss = 0.1967 alt_mask_loss = 0.1860mask_mean = 0.1910 prob_mean = 0.2486 :   7%|▋         | 11/159 [02:51<37:28, 15.19s/it]




Epoch 1, Step 21: Loss = 0.1666, Reward Loss = 0.1637, Regret Loss = 0.7182, Mask Loss = 0.2436 alt_mask_loss = 0.2834mask_mean = 0.2319 prob_mean = 0.2978 :  13%|█▎        | 21/159 [05:21<34:32, 15.01s/it]




Epoch 1, Step 31: Loss = 0.2207, Reward Loss = 0.2165, Regret Loss = 0.6554, Mask Loss = 0.2860 alt_mask_loss = 0.4217mask_mean = 0.2576 prob_mean = 0.3772 :  19%|█▉        | 30/159 [07:51<32:15, 15.00s/it]




Epoch 1, Step 41: Loss = 0.1214, Reward Loss = 0.1194, Regret Loss = 0.9819, Mask Loss = 0.2219 alt_mask_loss = 0.2056mask_mean = 0.1826 prob_mean = 0.2432 :  26%|██▌       | 41/159 [10:53<36:15, 18.44s/it]




Epoch 1, Step 51: Loss = 0.1972, Reward Loss = 0.1934, Regret Loss = 0.7126, Mask Loss = 0.3066 alt_mask_loss = 0.3843mask_mean = 0.2377 prob_mean = 0.3591 :  32%|███▏      | 51/159 [13:23<27:12, 15.12s/it]




Epoch 1, Step 61: Loss = 0.1737, Reward Loss = 0.1703, Regret Loss = 0.7753, Mask Loss = 0.2869 alt_mask_loss = 0.3375mask_mean = 0.2195 prob_mean = 0.3282 :  38%|███▊      | 60/159 [15:53<24:47, 15.03s/it]




Epoch 1, Step 71: Loss = 0.1025, Reward Loss = 0.1009, Regret Loss = 0.9814, Mask Loss = 0.2353 alt_mask_loss = 0.1653mask_mean = 0.1765 prob_mean = 0.2131 :  45%|████▍     | 71/159 [18:56<22:44, 15.50s/it]




Epoch 1, Step 81: Loss = 0.2228, Reward Loss = 0.2184, Regret Loss = 0.7646, Mask Loss = 0.3091 alt_mask_loss = 0.4400mask_mean = 0.2400 prob_mean = 0.4097 :  51%|█████     | 81/159 [21:26<19:31, 15.02s/it]




Epoch 1, Step 91: Loss = 0.1615, Reward Loss = 0.1586, Regret Loss = 0.8333, Mask Loss = 0.2764 alt_mask_loss = 0.2930mask_mean = 0.2117 prob_mean = 0.3126 :  57%|█████▋    | 90/159 [24:23<23:02, 20.04s/it]




Epoch 1, Step 101: Loss = 0.1987, Reward Loss = 0.1947, Regret Loss = 0.8592, Mask Loss = 0.3046 alt_mask_loss = 0.3957mask_mean = 0.2211 prob_mean = 0.3818 :  64%|██████▎   | 101/159 [26:58<14:42, 15.22s/it]




Epoch 1, Step 111: Loss = 0.1892, Reward Loss = 0.1855, Regret Loss = 0.7670, Mask Loss = 0.3096 alt_mask_loss = 0.3671mask_mean = 0.2307 prob_mean = 0.3551 :  70%|██████▉   | 111/159 [29:55<15:37, 19.54s/it]




Epoch 1, Step 121: Loss = 0.1331, Reward Loss = 0.1308, Regret Loss = 0.8926, Mask Loss = 0.2541 alt_mask_loss = 0.2290mask_mean = 0.1885 prob_mean = 0.2702 :  75%|███████▌  | 120/159 [32:52<11:56, 18.38s/it]




Epoch 1, Step 131: Loss = 0.1791, Reward Loss = 0.1757, Regret Loss = 0.8903, Mask Loss = 0.2871 alt_mask_loss = 0.3425mask_mean = 0.2129 prob_mean = 0.3482 :  82%|████████▏ | 131/159 [35:26<07:02, 15.10s/it]




Epoch 1, Step 141: Loss = 0.1965, Reward Loss = 0.1925, Regret Loss = 0.7628, Mask Loss = 0.3267 alt_mask_loss = 0.4003mask_mean = 0.2316 prob_mean = 0.3712 :  89%|████████▊ | 141/159 [37:56<04:30, 15.00s/it]




Epoch 1, Step 151: Loss = 0.1660, Reward Loss = 0.1628, Regret Loss = 0.9118, Mask Loss = 0.2874 alt_mask_loss = 0.3131mask_mean = 0.2006 prob_mean = 0.3332 :  94%|█████████▍| 150/159 [41:01<02:35, 17.27s/it]




Epoch 1, Step 159: Loss = 0.1528, Reward Loss = 0.1498, Regret Loss = 0.8862, Mask Loss = 0.2899 alt_mask_loss = 0.3010mask_mean = 0.2032 prob_mean = 0.3029 : 100%|██████████| 159/159 [43:14<00:00, 16.32s/it]
Epoch 2, Step 1: Loss = 0.1720, Reward Loss = 0.1686, Regret Loss = 0.9114, Mask Loss = 0.2959 alt_mask_loss = 0.3338mask_mean = 0.2090 prob_mean = 0.3375 :   0%|          | 0/159 [00:17<?, ?it/s]




Epoch 2, Step 11: Loss = 0.1928, Reward Loss = 0.1888, Regret Loss = 0.8265, Mask Loss = 0.3369 alt_mask_loss = 0.4017mask_mean = 0.2268 prob_mean = 0.3725 :   7%|▋         | 11/159 [02:51<37:03, 15.02s/it]




Epoch 2, Step 21: Loss = 0.2474, Reward Loss = 0.2417, Regret Loss = 0.7374, Mask Loss = 0.3941 alt_mask_loss = 0.5681mask_mean = 0.2574 prob_mean = 0.4780 :  13%|█▎        | 21/159 [05:59<47:45, 20.76s/it]




Epoch 2, Step 31: Loss = 0.1648, Reward Loss = 0.1614, Regret Loss = 0.9199, Mask Loss = 0.3176 alt_mask_loss = 0.3357mask_mean = 0.2001 prob_mean = 0.3411 :  19%|█▉        | 30/159 [08:28<32:41, 15.20s/it]




Epoch 2, Step 41: Loss = 0.1194, Reward Loss = 0.1173, Regret Loss = 1.0752, Mask Loss = 0.2424 alt_mask_loss = 0.2096mask_mean = 0.1667 prob_mean = 0.2626 :  26%|██▌       | 41/159 [11:02<29:31, 15.01s/it]




Epoch 2, Step 50: Loss = 0.2399, Reward Loss = 0.2347, Regret Loss = 0.7584, Mask Loss = 0.3739 alt_mask_loss = 0.5208mask_mean = 0.2553 prob_mean = 0.4573 :  31%|███▏      | 50/159 [13:17<27:11, 14.96s/it]

In [None]:
3262

3262