# Exploratory Data Analysis

Hypothesis: We can use the https://arxiv.org/pdf/2406.04313 method for increasing honesty

In [1]:
# autoreload your package
%load_ext autoreload
%autoreload 2
import adapter_overseer


In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

In [3]:
import warnings
# warnings.simplefilter("ignore")
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
# warnings.filterwarnings("ignore", ".*divide by zero.*")

## numeric, plotting
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = (7.0, 4)

## utils
from pathlib import Path
from tqdm.auto import tqdm
import logging, os, re
import collections, functools, itertools
from loguru import logger

from typing import List, Callable, Tuple, Dict, Optional
from jaxtyping import Float, Int
from torch import Tensor

# torch
# import pytorch_lightning as pl
from einops import rearrange, repeat, reduce
import torch
import torch.nn as nn


from baukit.nethook import get_module
from baukit import TraceDict

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from adapter_overseer.config import ExtractConfig

cfg = ExtractConfig()
cfg

ExtractConfig(datasets=('amazon_polarity',), datasets_ood='imdb', model='failspy/Llama-3-8B-Instruct-abliterated', collection_layers=('base_model.model.model.layers.10', 'base_model.model.model.layers.20'), batch_size=2, prompt_format=None, num_shots=2, max_length=776, max_examples=1000, seed=42, max_epochs=1)

## Load

In [5]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# https://huggingface.co/blog/mlabonne/orpo-llama-3
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install -qqq flash-attn
    attn_implementation = "flash_attention_2"
    torch_dtype = torch.bfloat16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16
torch_dtype, device

(torch.bfloat16, device(type='cuda', index=0))

In [7]:
# load model
# quantization_config = BitsAndBytesConfig(load_in_8bit=True)
quantization_config = BitsAndBytesConfig(load_in_4bit=True,     bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_use_double_quant=True,)
model = AutoModelForCausalLM.from_pretrained(cfg.model, device_map="auto", quantization_config=quantization_config,)
model

Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.60s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): Ll

In [8]:
tokenizer = AutoTokenizer.from_pretrained(cfg.model)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'
tokenizer.truncation_side = 'left'

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
# https://old.reddit.com/r/LocalLLaMA/comments/1coizjy/tokenizer_config_of_llama3_changed_by_meta_in_hf/
tokenizer.eos_token # it's good

'<|eot_id|>'

In [10]:
# from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
# \peft_config = LoraConfig(
#     task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
# )
# model = get_peft_model(model, peft_config)


In [11]:
# from peft import prepare_model_for_int8_training
# # we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in float32 for stability. We also cast the output of the last layer in float32 for the same reasons.
# model = prepare_model_for_int8_training(model, output_embedding_layer_name="proj_out")

In [12]:
from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model
# https://github.com/huggingface/peft/blob/main/src/peft/utils/constants.py
config = LoraConfig(
                        #r=32, lora_alpha=64, 
                    # target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
                    )


# LoRA config
# peft_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM",
#     target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
# )
from peft import prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)

model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424


## Get data

In [13]:
# perhaps use load_preproc_datasets from sdb_probes_are_lie_detectors repo... /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/src/prompts/prompt_loading.py

In [14]:
# load a dataset of paired prompts, to try and get the model to lie
from adapter_overseer.prompts.prompt_loading import load_preproc_datasets

N = cfg.max_examples
ds_tokens = load_preproc_datasets(
    cfg.datasets,
    tokenizer,
    N=N,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
)
ds_tokens


[32m2024-06-10 06:44:52.433[0m | [1mINFO    [0m | [36madapter_overseer.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m392[0m - [1mmedian token length: 375.0 for amazon_polarity. max_length=776[0m
[32m2024-06-10 06:44:52.435[0m | [1mINFO    [0m | [36madapter_overseer.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m396[0m - [1mtruncation rate: 0.00% on amazon_polarity[0m
[32m2024-06-10 06:44:52.704[0m | [1mINFO    [0m | [36madapter_overseer.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m405[0m - [1mnum_rows (after filtering out truncated rows) 3004=>3004[0m


Dataset({
    features: ['ds_string', 'example_i', 'answer', 'messages', 'answer_choices', 'template_name', 'label_true', 'label_instructed', 'instructed_to_lie', 'sys_instr_name', 'question', 'input_ids', 'attention_mask', 'truncated', 'length', 'prompt_truncated', 'choice_ids'],
    num_rows: 1001
})

## Train: transformers

https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb

In [21]:
# TODO change the loss function!
# we need to modify the forward pass, so that it returns a different loss function
# but to calculate this we will need to residuals now, and as they werre
# loss_bad = mse(repr_current, repr_target)

# from transformers import SFTTrainer
from trl.trainer import SFTTrainer, SFTConfig
import torch.nn.functional as F

from adapter_overseer.helpers.torch_helpers import clear_mem, switch
from adapter_overseer.helpers.scores import select_choices

class CustomSFTTrainer(SFTTrainer):
    """
    Custom SFTTrainer that orthoganalizes the repr of bad examples, and retains good repr of examples

    See: https://arxiv.org/pdf/2406.04313

    args:
        collection_layers: list of baukit layer names to collect
    """
    def __init__(self, *args, collection_layers: list, alpha=0.1, **kwargs):
        super(CustomSFTTrainer, self).__init__(*args, **kwargs)
        self.collection_layers = collection_layers
        self.alpha = alpha
        self.total_steps = self.args.max_steps

    def compute_loss(self, model, inputs, return_outputs=False):       

        batch = {'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask']}

        # collect the residuals of the model
        with model.disable_adapter():
            orig_outputs = model(**batch, output_hidden_states=True)
        outputs = model(**batch, output_hidden_states=True)

        def collect_hs(hidden_states):
            """The residual stream is the diff of the hs."""
            hs = [hidden_states[i] for i in self.collection_layers]
            return rearrange(hs, 'l b t h -> b l t h').diff(1)

        rep_adapt = collect_hs(outputs.hidden_states)
        rep_orig = collect_hs(orig_outputs.hidden_states)

        # so now we have a mixed batch of good and bad outputs
        # get probs of each choice
        # compare to labels to seperate into good and bad
        choice_ids = inputs['choice_ids'].detach().cpu().long()
        # label_instructed = inputs['label_true'] ^ inputs['instructed_to_lie']
        label_true = inputs['label_true']

        # does the underlying model get it right or wrong?
        end_logits = orig_outputs["logits"][:, -1]
        probs = torch.softmax(end_logits, -1)
        choice_probs = select_choices(probs, choice_ids).sum(2)
        binary_ans = choice_probs[:, 1] / (choice_probs.sum(1) + 1e-12)
        correct_truth_telling = switch(binary_ans, label_true)
        # correct_instruction_following = switch(binary_ans, label_instructed)

        mask_desired = correct_truth_telling>0.5


        # get coeffecient
        steps = self.state.global_step + 1
        c = torch.tensor(self.alpha * steps / (2 * self.total_steps)).to(rep_orig.dtype)
        loss_retain = F.mse_loss(rep_orig, rep_adapt, reduction='none' )[mask_desired]
        if loss_retain.numel() == 0:
            loss_retain = 0
        else:
            loss_retain = loss_retain.mean()
        loss_rr = F.relu(F.cosine_similarity(rep_orig, rep_adapt, dim=1))[~mask_desired]
        if loss_rr.numel() == 0:
            loss_rr = 0
        else:
            loss_rr = loss_rr.mean()
        loss = loss_rr * (1 - c) + c * loss_retain
        loss = loss
        logger.debug(f"steps: {steps}, c: {c}, loss_rr: {loss_rr:2.3f}, loss_retain: {loss_retain:2.3f}, loss={loss:2.3f}, mask_desired: {(mask_desired*1.0).mean():2.3f}")
        
        return (loss, outputs) if return_outputs else loss
    

# TODO make sure that multiple cols get passed into trainer
ds = ds_tokens.select_columns(['label_true', 'label_instructed' ,'instructed_to_lie', 'input_ids', 'attention_mask', 'choice_ids'])

import transformers

# see https://github.com/huggingface/trl/blob/main/trl/trainer/sft_trainer.py#L58
trainer = CustomSFTTrainer(
    model=model,
    train_dataset=ds,
    collection_layers=[10, 20],
    # max_seq_length=cfg.max_length,
    args=SFTConfig(
        # see https://github.com/huggingface/trl/blob/main/trl/trainer/sft_config.py#L21
        max_seq_length=cfg.max_length,
        per_device_train_batch_size=4, # 18GB/24GB
        gradient_accumulation_steps=6, # we want to accumulate the gradients to make the batch size larger, so we have sufficient examples of good and bad behaviour to learn from
        warmup_steps=10,
        max_steps=2000,
        learning_rate=3e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        remove_unused_columns=False,
    ),
    # data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
max_steps is given, it will override any value given in num_train_epochs
  1%|          | 11/2000 [04:02<12:09:38, 22.01s/it]
[32m2024-06-10 06:54:27.903[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 1, c: 2.499999936844688e-05, loss_rr: 0.988, loss_retain: 0.000, loss=0.988, mask_desired: 0.500[0m
[32m2024-06-10 06:54:33.321[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 1, c: 2.499999936844688e-05, loss_rr: 0.986, loss_retain: 0.000, loss=0.986, mask_desired: 0.750[0m
[32m2024-06-10 06:54:38.767[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 1, c: 2.499999936844688e-05, loss_rr: 0.993, loss_retain: 0.000, loss=0.993, mask_desired: 0.750[0m
[32m2024-06-10 06:54:44.236[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[

{'loss': 0.9884, 'grad_norm': 0.06136675924062729, 'learning_rate': 2.9999999999999997e-05, 'epoch': 0.02}


[32m2024-06-10 06:55:00.826[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 2, c: 4.999999873689376e-05, loss_rr: 0.988, loss_retain: 0.000, loss=0.988, mask_desired: 0.500[0m
[32m2024-06-10 06:55:06.371[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 2, c: 4.999999873689376e-05, loss_rr: 0.986, loss_retain: 0.000, loss=0.986, mask_desired: 0.500[0m
[32m2024-06-10 06:55:11.943[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 2, c: 4.999999873689376e-05, loss_rr: 0.986, loss_retain: 0.000, loss=0.986, mask_desired: 0.500[0m
[32m2024-06-10 06:55:17.520[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 2, c: 4.999999873689376e-05, loss_rr: 0.989, loss_retain: 0.000, loss=0.989, mask_desired: 0.500[0m
[32m2024-06-10 06:55:23.125[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 2

{'loss': 0.9872, 'grad_norm': 0.06063719466328621, 'learning_rate': 5.9999999999999995e-05, 'epoch': 0.05}


[32m2024-06-10 06:55:34.393[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 3, c: 7.500000356230885e-05, loss_rr: 0.988, loss_retain: 0.000, loss=0.987, mask_desired: 0.500[0m
[32m2024-06-10 06:55:40.011[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 3, c: 7.500000356230885e-05, loss_rr: 0.988, loss_retain: 0.000, loss=0.988, mask_desired: 0.750[0m
[32m2024-06-10 06:55:45.631[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 3, c: 7.500000356230885e-05, loss_rr: 0.000, loss_retain: 0.000, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 06:55:51.217[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 3, c: 7.500000356230885e-05, loss_rr: 0.978, loss_retain: 0.000, loss=0.978, mask_desired: 0.750[0m
[32m2024-06-10 06:55:56.851[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 3

{'loss': 0.8202, 'grad_norm': 0.075690358877182, 'learning_rate': 8.999999999999999e-05, 'epoch': 0.07}


[32m2024-06-10 06:56:08.166[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 4, c: 9.999999747378752e-05, loss_rr: 0.984, loss_retain: 0.000, loss=0.984, mask_desired: 0.000[0m
[32m2024-06-10 06:56:13.849[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 4, c: 9.999999747378752e-05, loss_rr: 0.982, loss_retain: 0.000, loss=0.982, mask_desired: 0.500[0m
[32m2024-06-10 06:56:19.483[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 4, c: 9.999999747378752e-05, loss_rr: 0.980, loss_retain: 0.000, loss=0.979, mask_desired: 0.250[0m
[32m2024-06-10 06:56:25.162[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 4, c: 9.999999747378752e-05, loss_rr: 0.000, loss_retain: 0.000, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 06:56:30.767[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 4

{'loss': 0.8176, 'grad_norm': 0.08097293227910995, 'learning_rate': 0.00011999999999999999, 'epoch': 0.1}


[32m2024-06-10 06:56:42.183[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 5, c: 0.0001250000059371814, loss_rr: 0.961, loss_retain: 0.000, loss=0.961, mask_desired: 0.500[0m
[32m2024-06-10 06:56:47.841[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 5, c: 0.0001250000059371814, loss_rr: 0.956, loss_retain: 0.001, loss=0.955, mask_desired: 0.750[0m
[32m2024-06-10 06:56:53.475[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 5, c: 0.0001250000059371814, loss_rr: 0.962, loss_retain: 0.000, loss=0.962, mask_desired: 0.250[0m
[32m2024-06-10 06:56:59.178[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 5, c: 0.0001250000059371814, loss_rr: 0.955, loss_retain: 0.001, loss=0.955, mask_desired: 0.750[0m
[32m2024-06-10 06:57:04.828[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 5

{'loss': 0.9584, 'grad_norm': 0.2663840353488922, 'learning_rate': 0.00015, 'epoch': 0.12}


[32m2024-06-10 06:57:16.183[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 6, c: 0.0001500000071246177, loss_rr: 0.942, loss_retain: 0.001, loss=0.942, mask_desired: 0.750[0m
[32m2024-06-10 06:57:21.847[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 6, c: 0.0001500000071246177, loss_rr: 0.930, loss_retain: 0.001, loss=0.930, mask_desired: 0.500[0m
[32m2024-06-10 06:57:27.519[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 6, c: 0.0001500000071246177, loss_rr: 0.000, loss_retain: 0.001, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 06:57:33.190[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 6, c: 0.0001500000071246177, loss_rr: 0.934, loss_retain: 0.001, loss=0.934, mask_desired: 0.500[0m
[32m2024-06-10 06:57:38.892[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 6

{'loss': 0.7814, 'grad_norm': 0.33154547214508057, 'learning_rate': 0.00017999999999999998, 'epoch': 0.14}


[32m2024-06-10 06:57:50.359[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 7, c: 0.00017499999376013875, loss_rr: 0.905, loss_retain: 0.001, loss=0.905, mask_desired: 0.250[0m
[32m2024-06-10 06:57:56.065[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 7, c: 0.00017499999376013875, loss_rr: 0.000, loss_retain: 0.002, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 06:58:01.742[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 7, c: 0.00017499999376013875, loss_rr: 0.000, loss_retain: 0.003, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 06:58:07.419[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 7, c: 0.00017499999376013875, loss_rr: 0.890, loss_retain: 0.002, loss=0.890, mask_desired: 0.500[0m
[32m2024-06-10 06:58:13.135[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.5987, 'grad_norm': 0.4056958854198456, 'learning_rate': 0.00020999999999999998, 'epoch': 0.17}


[32m2024-06-10 06:58:24.579[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 8, c: 0.00019999999494757503, loss_rr: 0.855, loss_retain: 0.006, loss=0.855, mask_desired: 0.500[0m
[32m2024-06-10 06:58:30.297[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 8, c: 0.00019999999494757503, loss_rr: 0.883, loss_retain: 0.005, loss=0.883, mask_desired: 0.750[0m
[32m2024-06-10 06:58:35.982[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 8, c: 0.00019999999494757503, loss_rr: 0.856, loss_retain: 0.005, loss=0.856, mask_desired: 0.250[0m
[32m2024-06-10 06:58:41.694[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 8, c: 0.00019999999494757503, loss_rr: 0.812, loss_retain: 0.005, loss=0.812, mask_desired: 0.750[0m
[32m2024-06-10 06:58:47.373[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.8569, 'grad_norm': 0.6575590372085571, 'learning_rate': 0.00023999999999999998, 'epoch': 0.19}


[32m2024-06-10 06:58:58.830[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 9, c: 0.00022499999613501132, loss_rr: 0.827, loss_retain: 0.009, loss=0.827, mask_desired: 0.250[0m
[32m2024-06-10 06:59:04.575[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 9, c: 0.00022499999613501132, loss_rr: 0.728, loss_retain: 0.009, loss=0.728, mask_desired: 0.750[0m
[32m2024-06-10 06:59:10.290[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 9, c: 0.00022499999613501132, loss_rr: 0.786, loss_retain: 0.007, loss=0.786, mask_desired: 0.250[0m
[32m2024-06-10 06:59:16.030[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 9, c: 0.00022499999613501132, loss_rr: 0.738, loss_retain: 0.008, loss=0.738, mask_desired: 0.250[0m
[32m2024-06-10 06:59:21.775[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.7825, 'grad_norm': 0.8301381468772888, 'learning_rate': 0.00027, 'epoch': 0.22}


[32m2024-06-10 06:59:33.271[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 10, c: 0.0002500000118743628, loss_rr: 0.713, loss_retain: 0.010, loss=0.713, mask_desired: 0.500[0m
[32m2024-06-10 06:59:38.990[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 10, c: 0.0002500000118743628, loss_rr: 0.740, loss_retain: 0.013, loss=0.740, mask_desired: 0.250[0m
[32m2024-06-10 06:59:44.713[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 10, c: 0.0002500000118743628, loss_rr: 0.752, loss_retain: 0.012, loss=0.752, mask_desired: 0.250[0m
[32m2024-06-10 06:59:50.451[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 10, c: 0.0002500000118743628, loss_rr: 0.720, loss_retain: 0.011, loss=0.720, mask_desired: 0.500[0m
[32m2024-06-10 06:59:56.192[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.7335, 'grad_norm': 0.6981897354125977, 'learning_rate': 0.0003, 'epoch': 0.24}


[32m2024-06-10 07:00:07.726[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 11, c: 0.0002749999985098839, loss_rr: 0.000, loss_retain: 0.014, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:00:13.441[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 11, c: 0.0002749999985098839, loss_rr: 0.746, loss_retain: 0.011, loss=0.746, mask_desired: 0.500[0m
[32m2024-06-10 07:00:19.177[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 11, c: 0.0002749999985098839, loss_rr: 0.715, loss_retain: 0.013, loss=0.715, mask_desired: 0.500[0m
[32m2024-06-10 07:00:24.933[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 11, c: 0.0002749999985098839, loss_rr: 0.713, loss_retain: 0.010, loss=0.713, mask_desired: 0.500[0m
[32m2024-06-10 07:00:30.686[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.6166, 'grad_norm': 0.4491898715496063, 'learning_rate': 0.0002998492462311557, 'epoch': 0.26}


[32m2024-06-10 07:00:42.221[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 12, c: 0.0003000000142492354, loss_rr: 0.723, loss_retain: 0.015, loss=0.723, mask_desired: 0.750[0m
[32m2024-06-10 07:00:47.980[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 12, c: 0.0003000000142492354, loss_rr: 0.701, loss_retain: 0.015, loss=0.701, mask_desired: 0.500[0m
[32m2024-06-10 07:00:53.726[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 12, c: 0.0003000000142492354, loss_rr: 0.689, loss_retain: 0.014, loss=0.689, mask_desired: 0.750[0m
[32m2024-06-10 07:00:59.466[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 12, c: 0.0003000000142492354, loss_rr: 0.746, loss_retain: 0.000, loss=0.746, mask_desired: 0.000[0m
[32m2024-06-10 07:01:05.204[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.7031, 'grad_norm': 0.8454946875572205, 'learning_rate': 0.00029969849246231153, 'epoch': 0.29}


[32m2024-06-10 07:01:16.784[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 13, c: 0.00032500000088475645, loss_rr: 0.643, loss_retain: 0.018, loss=0.643, mask_desired: 0.500[0m
[32m2024-06-10 07:01:22.549[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 13, c: 0.00032500000088475645, loss_rr: 0.712, loss_retain: 0.015, loss=0.712, mask_desired: 0.750[0m
[32m2024-06-10 07:01:28.311[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 13, c: 0.00032500000088475645, loss_rr: 0.665, loss_retain: 0.011, loss=0.664, mask_desired: 0.250[0m
[32m2024-06-10 07:01:34.087[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 13, c: 0.00032500000088475645, loss_rr: 0.705, loss_retain: 0.012, loss=0.705, mask_desired: 0.500[0m
[32m2024-06-10 07:01:39.827[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1m

{'loss': 0.5727, 'grad_norm': 0.6691225171089172, 'learning_rate': 0.0002995477386934673, 'epoch': 0.31}


[32m2024-06-10 07:01:51.374[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 14, c: 0.0003499999875202775, loss_rr: 0.617, loss_retain: 0.013, loss=0.617, mask_desired: 0.500[0m
[32m2024-06-10 07:01:57.158[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 14, c: 0.0003499999875202775, loss_rr: 0.634, loss_retain: 0.020, loss=0.634, mask_desired: 0.500[0m
[32m2024-06-10 07:02:02.956[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 14, c: 0.0003499999875202775, loss_rr: 0.710, loss_retain: 0.018, loss=0.710, mask_desired: 0.750[0m
[32m2024-06-10 07:02:08.692[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 14, c: 0.0003499999875202775, loss_rr: 0.740, loss_retain: 0.017, loss=0.739, mask_desired: 0.750[0m
[32m2024-06-10 07:02:14.433[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.6686, 'grad_norm': 0.7955325245857239, 'learning_rate': 0.0002993969849246231, 'epoch': 0.33}


[32m2024-06-10 07:02:26.006[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 15, c: 0.000375000003259629, loss_rr: 0.681, loss_retain: 0.021, loss=0.681, mask_desired: 0.750[0m
[32m2024-06-10 07:02:31.781[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 15, c: 0.000375000003259629, loss_rr: 0.654, loss_retain: 0.020, loss=0.653, mask_desired: 0.750[0m
[32m2024-06-10 07:02:37.556[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 15, c: 0.000375000003259629, loss_rr: 0.660, loss_retain: 0.022, loss=0.660, mask_desired: 0.500[0m
[32m2024-06-10 07:02:43.312[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 15, c: 0.000375000003259629, loss_rr: 0.696, loss_retain: 0.017, loss=0.696, mask_desired: 0.500[0m
[32m2024-06-10 07:02:49.052[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 1

{'loss': 0.5551, 'grad_norm': 0.8235176801681519, 'learning_rate': 0.0002992462311557789, 'epoch': 0.36}


[32m2024-06-10 07:03:00.574[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 16, c: 0.00039999998989515007, loss_rr: 0.566, loss_retain: 0.021, loss=0.566, mask_desired: 0.500[0m
[32m2024-06-10 07:03:06.310[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 16, c: 0.00039999998989515007, loss_rr: 0.591, loss_retain: 0.024, loss=0.591, mask_desired: 0.750[0m
[32m2024-06-10 07:03:12.035[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 16, c: 0.00039999998989515007, loss_rr: 0.633, loss_retain: 0.024, loss=0.633, mask_desired: 0.500[0m
[32m2024-06-10 07:03:17.766[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 16, c: 0.00039999998989515007, loss_rr: 0.567, loss_retain: 0.025, loss=0.567, mask_desired: 0.500[0m
[32m2024-06-10 07:03:23.501[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1m

{'loss': 0.6007, 'grad_norm': 1.1528459787368774, 'learning_rate': 0.00029909547738693465, 'epoch': 0.38}


[32m2024-06-10 07:03:35.013[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 17, c: 0.0004250000056345016, loss_rr: 0.562, loss_retain: 0.025, loss=0.562, mask_desired: 0.500[0m
[32m2024-06-10 07:03:40.759[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 17, c: 0.0004250000056345016, loss_rr: 0.524, loss_retain: 0.000, loss=0.524, mask_desired: 0.000[0m
[32m2024-06-10 07:03:46.519[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 17, c: 0.0004250000056345016, loss_rr: 0.499, loss_retain: 0.026, loss=0.498, mask_desired: 0.500[0m
[32m2024-06-10 07:03:52.291[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 17, c: 0.0004250000056345016, loss_rr: 0.553, loss_retain: 0.023, loss=0.553, mask_desired: 0.500[0m
[32m2024-06-10 07:03:58.045[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.537, 'grad_norm': 1.1799741983413696, 'learning_rate': 0.0002989447236180904, 'epoch': 0.41}


[32m2024-06-10 07:04:09.553[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 18, c: 0.00044999999227002263, loss_rr: 0.490, loss_retain: 0.031, loss=0.490, mask_desired: 0.500[0m
[32m2024-06-10 07:04:15.258[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 18, c: 0.00044999999227002263, loss_rr: 0.475, loss_retain: 0.033, loss=0.475, mask_desired: 0.500[0m
[32m2024-06-10 07:04:20.974[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 18, c: 0.00044999999227002263, loss_rr: 0.463, loss_retain: 0.034, loss=0.463, mask_desired: 0.250[0m
[32m2024-06-10 07:04:26.708[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 18, c: 0.00044999999227002263, loss_rr: 0.476, loss_retain: 0.030, loss=0.476, mask_desired: 0.750[0m
[32m2024-06-10 07:04:32.439[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1m

{'loss': 0.4831, 'grad_norm': 1.3564367294311523, 'learning_rate': 0.0002987939698492462, 'epoch': 0.43}


[32m2024-06-10 07:04:43.955[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 19, c: 0.00047500000800937414, loss_rr: 0.448, loss_retain: 0.035, loss=0.448, mask_desired: 0.750[0m
[32m2024-06-10 07:04:49.661[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 19, c: 0.00047500000800937414, loss_rr: 0.473, loss_retain: 0.035, loss=0.473, mask_desired: 0.750[0m
[32m2024-06-10 07:04:55.380[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 19, c: 0.00047500000800937414, loss_rr: 0.423, loss_retain: 0.034, loss=0.423, mask_desired: 0.250[0m
[32m2024-06-10 07:05:01.125[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 19, c: 0.00047500000800937414, loss_rr: 0.477, loss_retain: 0.037, loss=0.477, mask_desired: 0.250[0m
[32m2024-06-10 07:05:06.869[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1m

{'loss': 0.4521, 'grad_norm': 1.1842265129089355, 'learning_rate': 0.00029864321608040196, 'epoch': 0.45}


[32m2024-06-10 07:05:18.392[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 20, c: 0.0005000000237487257, loss_rr: 0.000, loss_retain: 0.038, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:05:24.100[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 20, c: 0.0005000000237487257, loss_rr: 0.428, loss_retain: 0.038, loss=0.428, mask_desired: 0.500[0m
[32m2024-06-10 07:05:29.810[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 20, c: 0.0005000000237487257, loss_rr: 0.427, loss_retain: 0.038, loss=0.426, mask_desired: 0.250[0m
[32m2024-06-10 07:05:35.533[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 20, c: 0.0005000000237487257, loss_rr: 0.420, loss_retain: 0.038, loss=0.420, mask_desired: 0.500[0m
[32m2024-06-10 07:05:41.277[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.3625, 'grad_norm': 1.104626178741455, 'learning_rate': 0.00029849246231155777, 'epoch': 0.48}


[32m2024-06-10 07:05:52.802[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 21, c: 0.0005249999812804163, loss_rr: 0.401, loss_retain: 0.000, loss=0.400, mask_desired: 0.000[0m
[32m2024-06-10 07:05:58.539[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 21, c: 0.0005249999812804163, loss_rr: 0.411, loss_retain: 0.040, loss=0.410, mask_desired: 0.750[0m
[32m2024-06-10 07:06:04.245[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 21, c: 0.0005249999812804163, loss_rr: 0.427, loss_retain: 0.040, loss=0.427, mask_desired: 0.500[0m
[32m2024-06-10 07:06:09.971[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 21, c: 0.0005249999812804163, loss_rr: 0.399, loss_retain: 0.040, loss=0.399, mask_desired: 0.750[0m
[32m2024-06-10 07:06:15.675[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.4172, 'grad_norm': 1.3811888694763184, 'learning_rate': 0.0002983417085427135, 'epoch': 0.5}


[32m2024-06-10 07:06:27.183[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 22, c: 0.0005499999970197678, loss_rr: 0.410, loss_retain: 0.041, loss=0.410, mask_desired: 0.750[0m
[32m2024-06-10 07:06:32.890[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 22, c: 0.0005499999970197678, loss_rr: 0.388, loss_retain: 0.042, loss=0.388, mask_desired: 0.750[0m
[32m2024-06-10 07:06:38.605[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 22, c: 0.0005499999970197678, loss_rr: 0.000, loss_retain: 0.041, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:06:44.337[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 22, c: 0.0005499999970197678, loss_rr: 0.402, loss_retain: 0.041, loss=0.402, mask_desired: 0.750[0m
[32m2024-06-10 07:06:50.071[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.3279, 'grad_norm': 1.4486515522003174, 'learning_rate': 0.00029819095477386933, 'epoch': 0.53}


[32m2024-06-10 07:07:01.573[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 23, c: 0.0005750000127591193, loss_rr: 0.394, loss_retain: 0.044, loss=0.393, mask_desired: 0.500[0m
[32m2024-06-10 07:07:07.281[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 23, c: 0.0005750000127591193, loss_rr: 0.368, loss_retain: 0.044, loss=0.368, mask_desired: 0.750[0m
[32m2024-06-10 07:07:13.018[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 23, c: 0.0005750000127591193, loss_rr: 0.371, loss_retain: 0.043, loss=0.371, mask_desired: 0.250[0m
[32m2024-06-10 07:07:18.761[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 23, c: 0.0005750000127591193, loss_rr: 0.380, loss_retain: 0.046, loss=0.380, mask_desired: 0.250[0m
[32m2024-06-10 07:07:24.501[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.3774, 'grad_norm': 1.7152425050735474, 'learning_rate': 0.00029804020100502514, 'epoch': 0.55}


[32m2024-06-10 07:07:36.037[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 24, c: 0.0006000000284984708, loss_rr: 0.354, loss_retain: 0.047, loss=0.354, mask_desired: 0.750[0m
[32m2024-06-10 07:07:41.782[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 24, c: 0.0006000000284984708, loss_rr: 0.373, loss_retain: 0.049, loss=0.373, mask_desired: 0.250[0m
[32m2024-06-10 07:07:47.571[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 24, c: 0.0006000000284984708, loss_rr: 0.367, loss_retain: 0.046, loss=0.367, mask_desired: 0.750[0m
[32m2024-06-10 07:07:53.341[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 24, c: 0.0006000000284984708, loss_rr: 0.365, loss_retain: 0.049, loss=0.364, mask_desired: 0.250[0m
[32m2024-06-10 07:07:59.125[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.363, 'grad_norm': 1.3496503829956055, 'learning_rate': 0.0002978894472361809, 'epoch': 0.57}


[32m2024-06-10 07:08:10.705[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 25, c: 0.0006249999860301614, loss_rr: 0.359, loss_retain: 0.052, loss=0.359, mask_desired: 0.500[0m
[32m2024-06-10 07:08:16.495[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 25, c: 0.0006249999860301614, loss_rr: 0.359, loss_retain: 0.051, loss=0.358, mask_desired: 0.750[0m
[32m2024-06-10 07:08:22.247[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 25, c: 0.0006249999860301614, loss_rr: 0.348, loss_retain: 0.050, loss=0.348, mask_desired: 0.500[0m
[32m2024-06-10 07:08:27.990[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 25, c: 0.0006249999860301614, loss_rr: 0.355, loss_retain: 0.051, loss=0.355, mask_desired: 0.250[0m
[32m2024-06-10 07:08:33.759[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.3578, 'grad_norm': 1.3774548768997192, 'learning_rate': 0.00029773869346733664, 'epoch': 0.6}


[32m2024-06-10 07:08:45.328[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 26, c: 0.0006500000017695129, loss_rr: 0.346, loss_retain: 0.052, loss=0.346, mask_desired: 0.750[0m
[32m2024-06-10 07:08:51.111[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 26, c: 0.0006500000017695129, loss_rr: 0.344, loss_retain: 0.057, loss=0.344, mask_desired: 0.500[0m
[32m2024-06-10 07:08:56.867[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 26, c: 0.0006500000017695129, loss_rr: 0.345, loss_retain: 0.054, loss=0.345, mask_desired: 0.250[0m
[32m2024-06-10 07:09:02.649[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 26, c: 0.0006500000017695129, loss_rr: 0.331, loss_retain: 0.055, loss=0.330, mask_desired: 0.750[0m
[32m2024-06-10 07:09:08.401[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.3394, 'grad_norm': 1.6528685092926025, 'learning_rate': 0.00029758793969849245, 'epoch': 0.62}


[32m2024-06-10 07:09:19.947[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 27, c: 0.0006750000175088644, loss_rr: 0.337, loss_retain: 0.056, loss=0.337, mask_desired: 0.250[0m
[32m2024-06-10 07:09:25.716[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 27, c: 0.0006750000175088644, loss_rr: 0.326, loss_retain: 0.055, loss=0.326, mask_desired: 0.250[0m
[32m2024-06-10 07:09:31.474[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 27, c: 0.0006750000175088644, loss_rr: 0.000, loss_retain: 0.056, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:09:37.216[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 27, c: 0.0006750000175088644, loss_rr: 0.348, loss_retain: 0.056, loss=0.348, mask_desired: 0.750[0m
[32m2024-06-10 07:09:42.976[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.2232, 'grad_norm': 0.9091976881027222, 'learning_rate': 0.0002974371859296482, 'epoch': 0.65}


[32m2024-06-10 07:09:54.539[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 28, c: 0.000699999975040555, loss_rr: 0.318, loss_retain: 0.058, loss=0.318, mask_desired: 0.250[0m
[32m2024-06-10 07:10:00.317[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 28, c: 0.000699999975040555, loss_rr: 0.328, loss_retain: 0.059, loss=0.328, mask_desired: 0.500[0m
[32m2024-06-10 07:10:06.055[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 28, c: 0.000699999975040555, loss_rr: 0.329, loss_retain: 0.061, loss=0.328, mask_desired: 0.500[0m
[32m2024-06-10 07:10:11.803[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 28, c: 0.000699999975040555, loss_rr: 0.310, loss_retain: 0.000, loss=0.309, mask_desired: 0.000[0m
[32m2024-06-10 07:10:17.538[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 2

{'loss': 0.3203, 'grad_norm': 1.7665152549743652, 'learning_rate': 0.000297286432160804, 'epoch': 0.67}


[32m2024-06-10 07:10:29.086[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 29, c: 0.0007249999907799065, loss_rr: 0.311, loss_retain: 0.059, loss=0.311, mask_desired: 0.250[0m
[32m2024-06-10 07:10:34.821[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 29, c: 0.0007249999907799065, loss_rr: 0.324, loss_retain: 0.059, loss=0.323, mask_desired: 0.500[0m
[32m2024-06-10 07:10:40.564[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 29, c: 0.0007249999907799065, loss_rr: 0.318, loss_retain: 0.060, loss=0.317, mask_desired: 0.500[0m
[32m2024-06-10 07:10:46.307[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 29, c: 0.0007249999907799065, loss_rr: 0.301, loss_retain: 0.060, loss=0.301, mask_desired: 0.750[0m
[32m2024-06-10 07:10:52.053[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.315, 'grad_norm': 2.5819449424743652, 'learning_rate': 0.00029713567839195976, 'epoch': 0.69}


[32m2024-06-10 07:11:03.622[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 30, c: 0.000750000006519258, loss_rr: 0.320, loss_retain: 0.061, loss=0.320, mask_desired: 0.500[0m
[32m2024-06-10 07:11:09.373[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 30, c: 0.000750000006519258, loss_rr: 0.315, loss_retain: 0.062, loss=0.314, mask_desired: 0.250[0m
[32m2024-06-10 07:11:15.100[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 30, c: 0.000750000006519258, loss_rr: 0.293, loss_retain: 0.062, loss=0.293, mask_desired: 0.750[0m
[32m2024-06-10 07:11:20.838[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 30, c: 0.000750000006519258, loss_rr: 0.305, loss_retain: 0.059, loss=0.305, mask_desired: 0.750[0m
[32m2024-06-10 07:11:26.571[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 3

{'loss': 0.3073, 'grad_norm': 3.651921033859253, 'learning_rate': 0.0002969849246231155, 'epoch': 0.72}


[32m2024-06-10 07:11:38.141[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 31, c: 0.0007750000222586095, loss_rr: 0.314, loss_retain: 0.062, loss=0.314, mask_desired: 0.500[0m
[32m2024-06-10 07:11:43.912[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 31, c: 0.0007750000222586095, loss_rr: 0.291, loss_retain: 0.063, loss=0.291, mask_desired: 0.250[0m
[32m2024-06-10 07:11:49.688[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 31, c: 0.0007750000222586095, loss_rr: 0.302, loss_retain: 0.061, loss=0.302, mask_desired: 0.250[0m
[32m2024-06-10 07:11:55.428[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 31, c: 0.0007750000222586095, loss_rr: 0.292, loss_retain: 0.059, loss=0.292, mask_desired: 0.250[0m
[32m2024-06-10 07:12:01.171[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.3034, 'grad_norm': 4.210335731506348, 'learning_rate': 0.0002968341708542713, 'epoch': 0.74}


[32m2024-06-10 07:12:12.714[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 32, c: 0.0007999999797903001, loss_rr: 0.297, loss_retain: 0.067, loss=0.297, mask_desired: 0.250[0m
[32m2024-06-10 07:12:18.473[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 32, c: 0.0007999999797903001, loss_rr: 0.297, loss_retain: 0.062, loss=0.297, mask_desired: 0.250[0m
[32m2024-06-10 07:12:24.241[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 32, c: 0.0007999999797903001, loss_rr: 0.301, loss_retain: 0.062, loss=0.301, mask_desired: 0.750[0m
[32m2024-06-10 07:12:29.995[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 32, c: 0.0007999999797903001, loss_rr: 0.291, loss_retain: 0.062, loss=0.291, mask_desired: 0.750[0m
[32m2024-06-10 07:12:35.766[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.2942, 'grad_norm': 5.933142185211182, 'learning_rate': 0.00029668341708542713, 'epoch': 0.76}


[32m2024-06-10 07:12:47.326[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 33, c: 0.0008249999955296516, loss_rr: 0.283, loss_retain: 0.000, loss=0.283, mask_desired: 0.000[0m
[32m2024-06-10 07:12:53.065[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 33, c: 0.0008249999955296516, loss_rr: 0.287, loss_retain: 0.061, loss=0.287, mask_desired: 0.250[0m
[32m2024-06-10 07:12:58.820[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 33, c: 0.0008249999955296516, loss_rr: 0.296, loss_retain: 0.062, loss=0.295, mask_desired: 0.250[0m
[32m2024-06-10 07:13:04.601[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 33, c: 0.0008249999955296516, loss_rr: 0.291, loss_retain: 0.062, loss=0.291, mask_desired: 0.750[0m
[32m2024-06-10 07:13:10.368[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

{'loss': 0.2917, 'grad_norm': 8.013928413391113, 'learning_rate': 0.0002965326633165829, 'epoch': 0.79}


[32m2024-06-10 07:13:21.942[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 34, c: 0.0008500000112690032, loss_rr: 0.283, loss_retain: 0.061, loss=0.283, mask_desired: 0.500[0m
[32m2024-06-10 07:13:27.703[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 34, c: 0.0008500000112690032, loss_rr: 0.281, loss_retain: 0.060, loss=0.281, mask_desired: 0.500[0m
[32m2024-06-10 07:13:33.446[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 34, c: 0.0008500000112690032, loss_rr: 0.280, loss_retain: 0.059, loss=0.280, mask_desired: 0.500[0m
[32m2024-06-10 07:13:39.215[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1msteps: 34, c: 0.0008500000112690032, loss_rr: 0.274, loss_retain: 0.061, loss=0.274, mask_desired: 0.750[0m
[32m2024-06-10 07:13:44.976[0m | [1mINFO    [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m77[0m - [1mstep

KeyboardInterrupt: 

In [22]:
# save
model.save_pretrained("../outputs/hs_adapter")



# Eval

In [36]:
from datasets import load_dataset
from torch.utils.data import DataLoader
dataset = load_dataset("EleutherAI/truthful_qa_binary", split="validation[:20]")
dataset

KeyboardInterrupt: 

In [None]:


def tokenization(example):
    return tokenizer(example["text"])
dataset = dataset.map(tokenization, batched=True)

dl = DataLoader(dataset, batch_size=4, num_workers=0)
for b in tqdm(dl):
    with torch.no_grad():
        with model.disable_adapter():
            out = model(**b)
        out = model(**b)

        logits = outputs["logits"][0, -1, answer_ids]
        logprobs = logits.log_softmax(dim=-1)
        lm_log_odds = logprobs[1] - logprobs[0]
        y_pred = F.logsigmoid(lm_log_odds.mean())
        #now just see if it matches the answer

        out.logits[:, choices]


In [24]:
from datasets import load_dataset
# import evaluate
from transformers import pipeline

# accuracy = evaluate.load("accuracy")

dataset_eval = load_dataset("EleutherAI/truthful_qa_binary", split="test")["text"]


[A
Downloading data: 100%|██████████| 84.5k/84.5k [00:00<00:00, 160kB/s]

[A
Generating validation split: 100%|██████████| 817/817 [00:00<00:00, 4721.48 examples/s]


ValueError: Unknown split "test". Should be one of ['validation'].

In [32]:
from evaluate import evaluator
from datasets import load_dataset
data = load_dataset("EleutherAI/truthful_qa_binary", split="validation[:20]")
task_evaluator = evaluator("text-classification")
# see https://huggingface.co/docs/evaluate/v0.4.0/en/package_reference/evaluator_classes#evaluate.TextClassificationEvaluator

pipe = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    # feature_extractor=feature_extractor,
    # device=device,
)


results = task_evaluator.compute(
    pipe,
    # tokenizer=tokenizer,
    data=data,
    metric="accuracy",
    input_column="question",
    # label_mapping={"LABEL_0": 0.0, "LABEL_1": 1.0},
    random_state=42
)

The model 'PeftModel' is not supported for text-classification. Supported models are ['AlbertForSequenceClassification', 'BartForSequenceClassification', 'BertForSequenceClassification', 'BigBirdForSequenceClassification', 'BigBirdPegasusForSequenceClassification', 'BioGptForSequenceClassification', 'BloomForSequenceClassification', 'CamembertForSequenceClassification', 'CanineForSequenceClassification', 'LlamaForSequenceClassification', 'ConvBertForSequenceClassification', 'CTRLForSequenceClassification', 'Data2VecTextForSequenceClassification', 'DebertaForSequenceClassification', 'DebertaV2ForSequenceClassification', 'DistilBertForSequenceClassification', 'ElectraForSequenceClassification', 'ErnieForSequenceClassification', 'ErnieMForSequenceClassification', 'EsmForSequenceClassification', 'FalconForSequenceClassification', 'FlaubertForSequenceClassification', 'FNetForSequenceClassification', 'FunnelForSequenceClassification', 'GemmaForSequenceClassification', 'GPT2ForSequenceClassif

KeyboardInterrupt: 

In [None]:

# def preprocess_function(examples):
#     return tokenizer(examples["text"], truncation=True)

# https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/pipelines#pipeline-batching
pipe = pipeline("text-classification", device="auto", model=model, tokenizer=tokenizer)
batch_size = 4
for out in tqdm(pipe(dataset_eval, batch_size=batch_size), total=len(dataset_eval)):
    1/0
    pass