# Exploratory Data Analysis

Hypothesis: We can use the https://arxiv.org/pdf/2406.04313 method for increasing honesty

In [1]:
# autoreload your package
%load_ext autoreload
%autoreload 2
import adapter_overseer


In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

In [3]:
import warnings
# warnings.simplefilter("ignore")
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
# warnings.filterwarnings("ignore", ".*divide by zero.*")

## numeric, plotting
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = (7.0, 4)

## utils
from pathlib import Path
from tqdm.auto import tqdm
import logging, os, re
import collections, functools, itertools
from loguru import logger

from typing import List, Callable, Tuple, Dict, Optional
from jaxtyping import Float, Int
from torch import Tensor

# torch
# import pytorch_lightning as pl
from einops import rearrange, repeat, reduce
import torch
import torch.nn as nn


from baukit.nethook import get_module
from baukit import TraceDict

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from adapter_overseer.config import ExtractConfig

cfg = ExtractConfig()
cfg

ExtractConfig(datasets=('amazon_polarity',), datasets_ood='imdb', model='failspy/Llama-3-8B-Instruct-abliterated', collection_layers=('base_model.model.model.layers.10', 'base_model.model.model.layers.20'), batch_size=2, prompt_format=None, num_shots=2, max_length=776, max_examples=1000, seed=42, max_epochs=1)

## Load

In [5]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# https://huggingface.co/blog/mlabonne/orpo-llama-3
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install -qqq flash-attn
    attn_implementation = "flash_attention_2"
    torch_dtype = torch.bfloat16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16
torch_dtype, device

(torch.bfloat16, device(type='cuda', index=0))

In [7]:
# load model
# quantization_config = BitsAndBytesConfig(load_in_8bit=True)
quantization_config = BitsAndBytesConfig(load_in_4bit=True,     bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_use_double_quant=True,)
model = AutoModelForCausalLM.from_pretrained(cfg.model, device_map="auto", quantization_config=quantization_config,)
model

Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.72s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): Ll

In [8]:
tokenizer = AutoTokenizer.from_pretrained(cfg.model)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'
tokenizer.truncation_side = 'left'

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
# https://old.reddit.com/r/LocalLLaMA/comments/1coizjy/tokenizer_config_of_llama3_changed_by_meta_in_hf/
tokenizer.eos_token # it's good

'<|eot_id|>'

In [10]:
# from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
# \peft_config = LoraConfig(
#     task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
# )
# model = get_peft_model(model, peft_config)


In [11]:
# from peft import prepare_model_for_int8_training
# # we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in float32 for stability. We also cast the output of the last layer in float32 for the same reasons.
# model = prepare_model_for_int8_training(model, output_embedding_layer_name="proj_out")

In [12]:
from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model
# https://github.com/huggingface/peft/blob/main/src/peft/utils/constants.py
config = LoraConfig(
                        #r=32, lora_alpha=64, 
                    # target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
                    )


# LoRA config
# peft_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM",
#     target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
# )
from peft import prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)

model = get_peft_model(model, config)
model.print_trainable_parameters()

with model.disable_adapter():
    model.print_trainable_parameters()

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424


## Get data

In [13]:
# perhaps use load_preproc_datasets from sdb_probes_are_lie_detectors repo... /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/src/prompts/prompt_loading.py

In [14]:
# load a dataset of paired prompts, to try and get the model to lie
from adapter_overseer.prompts.prompt_loading import load_preproc_datasets

N = cfg.max_examples
ds_tokens = load_preproc_datasets(
    cfg.datasets,
    tokenizer,
    N=N,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
)
ds_tokens


[32m2024-06-10 07:40:28.142[0m | [1mINFO    [0m | [36madapter_overseer.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m392[0m - [1mmedian token length: 375.0 for amazon_polarity. max_length=776[0m
[32m2024-06-10 07:40:28.144[0m | [1mINFO    [0m | [36madapter_overseer.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m396[0m - [1mtruncation rate: 0.00% on amazon_polarity[0m
[32m2024-06-10 07:40:28.425[0m | [1mINFO    [0m | [36madapter_overseer.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m405[0m - [1mnum_rows (after filtering out truncated rows) 3004=>3004[0m


Dataset({
    features: ['ds_string', 'example_i', 'answer', 'messages', 'answer_choices', 'template_name', 'label_true', 'label_instructed', 'instructed_to_lie', 'sys_instr_name', 'question', 'input_ids', 'attention_mask', 'truncated', 'length', 'prompt_truncated', 'choice_ids'],
    num_rows: 1001
})

## Train: transformers

https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb

In [15]:
# TODO change the loss function!
# we need to modify the forward pass, so that it returns a different loss function
# but to calculate this we will need to residuals now, and as they werre
# loss_bad = mse(repr_current, repr_target)

# from transformers import SFTTrainer
from trl.trainer import SFTTrainer, SFTConfig
import torch.nn.functional as F

from adapter_overseer.helpers.torch_helpers import clear_mem, switch
from adapter_overseer.helpers.scores import select_choices

class CustomSFTTrainer(SFTTrainer):
    """
    Custom SFTTrainer that orthoganalizes the repr of bad examples, and retains good repr of examples

    See: https://arxiv.org/pdf/2406.04313

    args:
        collection_layers: list of baukit layer names to collect
    """
    def __init__(self, *args, collection_layers: list, alpha=0.1, **kwargs):
        super(CustomSFTTrainer, self).__init__(*args, **kwargs)
        self.collection_layers = collection_layers
        self.alpha = alpha
        self.total_steps = self.args.max_steps

    def compute_loss(self, model, inputs, return_outputs=False):       

        batch = {'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask']}

        # collect the residuals of the model
        with model.disable_adapter():
            orig_outputs = model(**batch, output_hidden_states=True)
        outputs = model(**batch, output_hidden_states=True)

        def collect_hs(hidden_states):
            """The residual stream is the diff of the hs."""
            hs = [hidden_states[i] for i in self.collection_layers]
            return rearrange(hs, 'l b t h -> b l t h').diff(1)

        rep_adapt = collect_hs(outputs.hidden_states)
        rep_orig = collect_hs(orig_outputs.hidden_states)

        # so now we have a mixed batch of good and bad outputs
        # get probs of each choice
        # compare to labels to seperate into good and bad
        choice_ids = inputs['choice_ids'].detach().cpu().long()
        # label_instructed = inputs['label_true'] ^ inputs['instructed_to_lie']
        label_true = inputs['label_true']

        # does the underlying model get it right or wrong?
        end_logits = orig_outputs["logits"][:, -1]
        probs = torch.softmax(end_logits, -1)
        choice_probs = select_choices(probs, choice_ids).sum(2)
        binary_ans = choice_probs[:, 1] / (choice_probs.sum(1) + 1e-12)
        correct_truth_telling = switch(binary_ans, label_true)
        # correct_instruction_following = switch(binary_ans, label_instructed)

        mask_desired = correct_truth_telling>0.5


        # get coeffecient
        steps = self.state.global_step + 1
        c = torch.tensor(self.alpha * steps / (2 * self.total_steps)).to(rep_orig.dtype)
        loss_retain = F.mse_loss(rep_orig, rep_adapt, reduction='none' )[mask_desired]
        if loss_retain.numel() == 0:
            loss_retain = 0
        else:
            loss_retain = loss_retain.mean()
        loss_rr = F.relu(F.cosine_similarity(rep_orig, rep_adapt, dim=1))[~mask_desired]
        if loss_rr.numel() == 0:
            loss_rr = 0
        else:
            loss_rr = loss_rr.mean()
        loss = loss_rr * (1 - c) + c * loss_retain
        loss = loss
        logger.debug(f"steps: {steps}, c: {c}, loss_rr: {loss_rr:2.3f}, loss_retain: {loss_retain:2.3f}, loss={loss:2.3f}, mask_desired: {(mask_desired*1.0).mean():2.3f}")
        
        return (loss, outputs) if return_outputs else loss
    

# TODO make sure that multiple cols get passed into trainer
ds = ds_tokens.select_columns(['label_true', 'label_instructed' ,'instructed_to_lie', 'input_ids', 'attention_mask', 'choice_ids'])

import transformers

# see https://github.com/huggingface/trl/blob/main/trl/trainer/sft_trainer.py#L58
trainer = CustomSFTTrainer(
    model=model,
    train_dataset=ds,
    collection_layers=[10, 20],
    # max_seq_length=cfg.max_length,
    args=SFTConfig(
        # see https://github.com/huggingface/trl/blob/main/trl/trainer/sft_config.py#L21
        max_seq_length=cfg.max_length,
        per_device_train_batch_size=4, # 18GB/24GB
        gradient_accumulation_steps=6, # we want to accumulate the gradients to make the batch size larger, so we have sufficient examples of good and bad behaviour to learn from
        warmup_steps=10,
        max_steps=2000,
        learning_rate=3e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        remove_unused_columns=False,
    ),
    # data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
max_steps is given, it will override any value given in num_train_epochs
[32m2024-06-10 07:40:31.910[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 1, c: 2.499999936844688e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.500[0m
[32m2024-06-10 07:40:37.111[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 1, c: 2.499999936844688e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.750[0m
[32m2024-06-10 07:40:42.259[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 1, c: 2.499999936844688e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.750[0m
[32m2024-06-10 07:40:47.427[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m

{'loss': 1.0, 'grad_norm': 3.3594227399902366e-10, 'learning_rate': 2.9999999999999997e-05, 'epoch': 0.02}


[32m2024-06-10 07:41:03.040[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 2, c: 4.999999873689376e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.500[0m
[32m2024-06-10 07:41:08.245[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 2, c: 4.999999873689376e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.500[0m
[32m2024-06-10 07:41:13.463[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 2, c: 4.999999873689376e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.500[0m
[32m2024-06-10 07:41:18.703[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 2, c: 4.999999873689376e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.500[0m
[32m2024-06-10 07:41:23.954[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[3

{'loss': 1.0, 'grad_norm': 6.474682595625225e-11, 'learning_rate': 5.9999999999999995e-05, 'epoch': 0.05}


[32m2024-06-10 07:41:34.526[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 3, c: 7.500000356230885e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.500[0m
[32m2024-06-10 07:41:39.828[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 3, c: 7.500000356230885e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.750[0m
[32m2024-06-10 07:41:45.140[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 3, c: 7.500000356230885e-05, loss_rr: 0.000, loss_retain: 0.000, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:41:50.460[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 3, c: 7.500000356230885e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.750[0m
[32m2024-06-10 07:41:55.798[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[3

{'loss': 0.8333, 'grad_norm': 2.1183223231080461e-10, 'learning_rate': 8.999999999999999e-05, 'epoch': 0.07}


[32m2024-06-10 07:42:06.528[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 4, c: 9.999999747378752e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.000[0m
[32m2024-06-10 07:42:11.895[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 4, c: 9.999999747378752e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.500[0m
[32m2024-06-10 07:42:17.270[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 4, c: 9.999999747378752e-05, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.250[0m
[32m2024-06-10 07:42:22.657[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 4, c: 9.999999747378752e-05, loss_rr: 0.000, loss_retain: 0.000, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:42:28.052[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[3

{'loss': 0.8332, 'grad_norm': 1.373788721670266e-10, 'learning_rate': 0.00011999999999999999, 'epoch': 0.1}


[32m2024-06-10 07:42:38.864[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 5, c: 0.0001250000059371814, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.500[0m
[32m2024-06-10 07:42:44.278[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 5, c: 0.0001250000059371814, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.750[0m
[32m2024-06-10 07:42:49.698[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 5, c: 0.0001250000059371814, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.250[0m
[32m2024-06-10 07:42:55.125[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 5, c: 0.0001250000059371814, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.750[0m
[32m2024-06-10 07:43:00.554[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[3

{'loss': 0.9999, 'grad_norm': 1.1267461269559575e-10, 'learning_rate': 0.00015, 'epoch': 0.12}


[32m2024-06-10 07:43:11.459[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 6, c: 0.0001500000071246177, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.750[0m
[32m2024-06-10 07:43:16.900[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 6, c: 0.0001500000071246177, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.500[0m
[32m2024-06-10 07:43:22.371[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 6, c: 0.0001500000071246177, loss_rr: 0.000, loss_retain: 0.000, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:43:27.821[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 6, c: 0.0001500000071246177, loss_rr: 1.000, loss_retain: 0.000, loss=1.000, mask_desired: 0.500[0m
[32m2024-06-10 07:43:33.290[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[3

{'loss': 0.8332, 'grad_norm': 0.00090835802257061, 'learning_rate': 0.00017999999999999998, 'epoch': 0.14}


[32m2024-06-10 07:43:44.245[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 7, c: 0.00017499999376013875, loss_rr: 1.000, loss_retain: 0.000, loss=0.999, mask_desired: 0.250[0m
[32m2024-06-10 07:43:49.766[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 7, c: 0.00017499999376013875, loss_rr: 0.000, loss_retain: 0.000, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:43:55.238[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 7, c: 0.00017499999376013875, loss_rr: 0.000, loss_retain: 0.000, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:44:00.712[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 7, c: 0.00017499999376013875, loss_rr: 0.999, loss_retain: 0.000, loss=0.999, mask_desired: 0.500[0m
[32m2024-06-10 07:44:06.232[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.6662, 'grad_norm': 0.004503950010985136, 'learning_rate': 0.00020999999999999998, 'epoch': 0.17}


[32m2024-06-10 07:44:17.276[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 8, c: 0.00019999999494757503, loss_rr: 0.999, loss_retain: 0.000, loss=0.999, mask_desired: 0.500[0m
[32m2024-06-10 07:44:22.809[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 8, c: 0.00019999999494757503, loss_rr: 0.999, loss_retain: 0.000, loss=0.999, mask_desired: 0.750[0m
[32m2024-06-10 07:44:28.327[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 8, c: 0.00019999999494757503, loss_rr: 0.999, loss_retain: 0.000, loss=0.999, mask_desired: 0.250[0m
[32m2024-06-10 07:44:33.882[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 8, c: 0.00019999999494757503, loss_rr: 0.998, loss_retain: 0.000, loss=0.998, mask_desired: 0.750[0m
[32m2024-06-10 07:44:39.405[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.9986, 'grad_norm': 0.004739789757877588, 'learning_rate': 0.00023999999999999998, 'epoch': 0.19}


[32m2024-06-10 07:44:50.500[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 9, c: 0.00022499999613501132, loss_rr: 0.996, loss_retain: 0.000, loss=0.996, mask_desired: 0.250[0m
[32m2024-06-10 07:44:56.058[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 9, c: 0.00022499999613501132, loss_rr: 0.994, loss_retain: 0.000, loss=0.994, mask_desired: 0.750[0m
[32m2024-06-10 07:45:01.579[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 9, c: 0.00022499999613501132, loss_rr: 0.994, loss_retain: 0.000, loss=0.994, mask_desired: 0.250[0m
[32m2024-06-10 07:45:07.183[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 9, c: 0.00022499999613501132, loss_rr: 0.994, loss_retain: 0.000, loss=0.993, mask_desired: 0.250[0m
[32m2024-06-10 07:45:12.755[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.9948, 'grad_norm': 0.01964593306183815, 'learning_rate': 0.00027, 'epoch': 0.22}


[32m2024-06-10 07:45:23.903[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 10, c: 0.0002500000118743628, loss_rr: 0.983, loss_retain: 0.000, loss=0.982, mask_desired: 0.500[0m
[32m2024-06-10 07:45:29.476[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 10, c: 0.0002500000118743628, loss_rr: 0.982, loss_retain: 0.000, loss=0.982, mask_desired: 0.250[0m
[32m2024-06-10 07:45:35.069[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 10, c: 0.0002500000118743628, loss_rr: 0.988, loss_retain: 0.000, loss=0.988, mask_desired: 0.250[0m
[32m2024-06-10 07:45:40.662[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 10, c: 0.0002500000118743628, loss_rr: 0.981, loss_retain: 0.000, loss=0.981, mask_desired: 0.500[0m
[32m2024-06-10 07:45:46.282[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.9831, 'grad_norm': 0.0802963376045227, 'learning_rate': 0.0003, 'epoch': 0.24}


[32m2024-06-10 07:45:57.468[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 11, c: 0.0002749999985098839, loss_rr: 0.000, loss_retain: 0.001, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:46:03.022[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 11, c: 0.0002749999985098839, loss_rr: 0.959, loss_retain: 0.000, loss=0.959, mask_desired: 0.500[0m
[32m2024-06-10 07:46:08.617[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 11, c: 0.0002749999985098839, loss_rr: 0.953, loss_retain: 0.001, loss=0.953, mask_desired: 0.500[0m
[32m2024-06-10 07:46:14.212[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 11, c: 0.0002749999985098839, loss_rr: 0.950, loss_retain: 0.000, loss=0.950, mask_desired: 0.500[0m
[32m2024-06-10 07:46:19.806[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.8004, 'grad_norm': 0.17458520829677582, 'learning_rate': 0.0002998492462311557, 'epoch': 0.26}


[32m2024-06-10 07:46:31.012[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 12, c: 0.0003000000142492354, loss_rr: 0.902, loss_retain: 0.004, loss=0.902, mask_desired: 0.750[0m
[32m2024-06-10 07:46:36.608[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 12, c: 0.0003000000142492354, loss_rr: 0.905, loss_retain: 0.004, loss=0.904, mask_desired: 0.500[0m
[32m2024-06-10 07:46:42.218[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 12, c: 0.0003000000142492354, loss_rr: 0.875, loss_retain: 0.003, loss=0.874, mask_desired: 0.750[0m
[32m2024-06-10 07:46:47.814[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 12, c: 0.0003000000142492354, loss_rr: 0.909, loss_retain: 0.000, loss=0.909, mask_desired: 0.000[0m
[32m2024-06-10 07:46:53.437[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.8855, 'grad_norm': 0.4783974289894104, 'learning_rate': 0.00029969849246231153, 'epoch': 0.29}


[32m2024-06-10 07:47:04.672[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 13, c: 0.00032500000088475645, loss_rr: 0.805, loss_retain: 0.007, loss=0.805, mask_desired: 0.500[0m
[32m2024-06-10 07:47:10.293[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 13, c: 0.00032500000088475645, loss_rr: 0.851, loss_retain: 0.006, loss=0.850, mask_desired: 0.750[0m
[32m2024-06-10 07:47:15.901[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 13, c: 0.00032500000088475645, loss_rr: 0.830, loss_retain: 0.004, loss=0.830, mask_desired: 0.250[0m
[32m2024-06-10 07:47:21.533[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 13, c: 0.00032500000088475645, loss_rr: 0.841, loss_retain: 0.004, loss=0.840, mask_desired: 0.500[0m
[32m2024-06-10 07:47:27.156[0m | [34m[1mDEBUG   [0m | [36m__main__

{'loss': 0.6971, 'grad_norm': 0.49628111720085144, 'learning_rate': 0.0002995477386934673, 'epoch': 0.31}


[32m2024-06-10 07:47:38.393[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 14, c: 0.0003499999875202775, loss_rr: 0.759, loss_retain: 0.007, loss=0.759, mask_desired: 0.500[0m
[32m2024-06-10 07:47:44.016[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 14, c: 0.0003499999875202775, loss_rr: 0.751, loss_retain: 0.010, loss=0.751, mask_desired: 0.500[0m
[32m2024-06-10 07:47:49.643[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 14, c: 0.0003499999875202775, loss_rr: 0.829, loss_retain: 0.009, loss=0.828, mask_desired: 0.750[0m
[32m2024-06-10 07:47:55.258[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 14, c: 0.0003499999875202775, loss_rr: 0.858, loss_retain: 0.010, loss=0.857, mask_desired: 0.750[0m
[32m2024-06-10 07:48:00.890[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.7928, 'grad_norm': 0.48608410358428955, 'learning_rate': 0.0002993969849246231, 'epoch': 0.33}


[32m2024-06-10 07:48:12.157[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 15, c: 0.000375000003259629, loss_rr: 0.816, loss_retain: 0.012, loss=0.816, mask_desired: 0.750[0m
[32m2024-06-10 07:48:17.778[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 15, c: 0.000375000003259629, loss_rr: 0.778, loss_retain: 0.012, loss=0.778, mask_desired: 0.750[0m
[32m2024-06-10 07:48:23.402[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 15, c: 0.000375000003259629, loss_rr: 0.782, loss_retain: 0.014, loss=0.782, mask_desired: 0.500[0m
[32m2024-06-10 07:48:29.031[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 15, c: 0.000375000003259629, loss_rr: 0.810, loss_retain: 0.010, loss=0.809, mask_desired: 0.500[0m
[32m2024-06-10 07:48:34.663[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[3

{'loss': 0.6599, 'grad_norm': 0.508230447769165, 'learning_rate': 0.0002992462311557789, 'epoch': 0.36}


[32m2024-06-10 07:48:45.966[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 16, c: 0.00039999998989515007, loss_rr: 0.715, loss_retain: 0.012, loss=0.715, mask_desired: 0.500[0m
[32m2024-06-10 07:48:51.602[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 16, c: 0.00039999998989515007, loss_rr: 0.764, loss_retain: 0.014, loss=0.764, mask_desired: 0.750[0m
[32m2024-06-10 07:48:57.228[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 16, c: 0.00039999998989515007, loss_rr: 0.775, loss_retain: 0.014, loss=0.775, mask_desired: 0.500[0m
[32m2024-06-10 07:49:02.859[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 16, c: 0.00039999998989515007, loss_rr: 0.717, loss_retain: 0.015, loss=0.717, mask_desired: 0.500[0m
[32m2024-06-10 07:49:08.498[0m | [34m[1mDEBUG   [0m | [36m__main__

{'loss': 0.7503, 'grad_norm': 0.5595538020133972, 'learning_rate': 0.00029909547738693465, 'epoch': 0.38}


[32m2024-06-10 07:49:19.773[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 17, c: 0.0004250000056345016, loss_rr: 0.743, loss_retain: 0.013, loss=0.742, mask_desired: 0.500[0m
[32m2024-06-10 07:49:25.408[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 17, c: 0.0004250000056345016, loss_rr: 0.706, loss_retain: 0.000, loss=0.706, mask_desired: 0.000[0m
[32m2024-06-10 07:49:31.061[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 17, c: 0.0004250000056345016, loss_rr: 0.685, loss_retain: 0.013, loss=0.685, mask_desired: 0.500[0m
[32m2024-06-10 07:49:36.706[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 17, c: 0.0004250000056345016, loss_rr: 0.749, loss_retain: 0.010, loss=0.748, mask_desired: 0.500[0m
[32m2024-06-10 07:49:42.342[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.7224, 'grad_norm': 0.5684277415275574, 'learning_rate': 0.0002989447236180904, 'epoch': 0.41}


[32m2024-06-10 07:49:53.643[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 18, c: 0.00044999999227002263, loss_rr: 0.715, loss_retain: 0.012, loss=0.715, mask_desired: 0.500[0m
[32m2024-06-10 07:49:59.290[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 18, c: 0.00044999999227002263, loss_rr: 0.673, loss_retain: 0.018, loss=0.673, mask_desired: 0.500[0m
[32m2024-06-10 07:50:04.930[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 18, c: 0.00044999999227002263, loss_rr: 0.677, loss_retain: 0.017, loss=0.677, mask_desired: 0.250[0m
[32m2024-06-10 07:50:10.581[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 18, c: 0.00044999999227002263, loss_rr: 0.691, loss_retain: 0.012, loss=0.691, mask_desired: 0.750[0m
[32m2024-06-10 07:50:16.217[0m | [34m[1mDEBUG   [0m | [36m__main__

{'loss': 0.6947, 'grad_norm': 0.6976051330566406, 'learning_rate': 0.0002987939698492462, 'epoch': 0.43}


[32m2024-06-10 07:50:27.511[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 19, c: 0.00047500000800937414, loss_rr: 0.645, loss_retain: 0.018, loss=0.645, mask_desired: 0.750[0m
[32m2024-06-10 07:50:33.145[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 19, c: 0.00047500000800937414, loss_rr: 0.731, loss_retain: 0.020, loss=0.730, mask_desired: 0.750[0m
[32m2024-06-10 07:50:38.784[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 19, c: 0.00047500000800937414, loss_rr: 0.618, loss_retain: 0.015, loss=0.618, mask_desired: 0.250[0m
[32m2024-06-10 07:50:44.438[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 19, c: 0.00047500000800937414, loss_rr: 0.708, loss_retain: 0.019, loss=0.708, mask_desired: 0.250[0m
[32m2024-06-10 07:50:50.094[0m | [34m[1mDEBUG   [0m | [36m__main__

{'loss': 0.6654, 'grad_norm': 0.9397184252738953, 'learning_rate': 0.00029864321608040196, 'epoch': 0.45}


[32m2024-06-10 07:51:01.423[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 20, c: 0.0005000000237487257, loss_rr: 0.000, loss_retain: 0.023, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:51:07.056[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 20, c: 0.0005000000237487257, loss_rr: 0.615, loss_retain: 0.022, loss=0.615, mask_desired: 0.500[0m
[32m2024-06-10 07:51:12.705[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 20, c: 0.0005000000237487257, loss_rr: 0.594, loss_retain: 0.022, loss=0.594, mask_desired: 0.250[0m
[32m2024-06-10 07:51:18.364[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 20, c: 0.0005000000237487257, loss_rr: 0.579, loss_retain: 0.022, loss=0.578, mask_desired: 0.500[0m
[32m2024-06-10 07:51:24.012[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.5125, 'grad_norm': 1.075516700744629, 'learning_rate': 0.00029849246231155777, 'epoch': 0.48}


[32m2024-06-10 07:51:35.344[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 21, c: 0.0005249999812804163, loss_rr: 0.527, loss_retain: 0.000, loss=0.527, mask_desired: 0.000[0m
[32m2024-06-10 07:51:41.003[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 21, c: 0.0005249999812804163, loss_rr: 0.551, loss_retain: 0.026, loss=0.551, mask_desired: 0.750[0m
[32m2024-06-10 07:51:46.642[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 21, c: 0.0005249999812804163, loss_rr: 0.580, loss_retain: 0.027, loss=0.580, mask_desired: 0.500[0m
[32m2024-06-10 07:51:52.291[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 21, c: 0.0005249999812804163, loss_rr: 0.545, loss_retain: 0.028, loss=0.545, mask_desired: 0.750[0m
[32m2024-06-10 07:51:57.934[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.5615, 'grad_norm': 0.9591206312179565, 'learning_rate': 0.0002983417085427135, 'epoch': 0.5}


[32m2024-06-10 07:52:09.222[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 22, c: 0.0005499999970197678, loss_rr: 0.523, loss_retain: 0.030, loss=0.522, mask_desired: 0.750[0m
[32m2024-06-10 07:52:14.872[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 22, c: 0.0005499999970197678, loss_rr: 0.498, loss_retain: 0.031, loss=0.497, mask_desired: 0.750[0m
[32m2024-06-10 07:52:20.509[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 22, c: 0.0005499999970197678, loss_rr: 0.000, loss_retain: 0.032, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:52:26.138[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 22, c: 0.0005499999970197678, loss_rr: 0.515, loss_retain: 0.031, loss=0.515, mask_desired: 0.750[0m
[32m2024-06-10 07:52:31.779[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.4211, 'grad_norm': 0.8680790066719055, 'learning_rate': 0.00029819095477386933, 'epoch': 0.53}


[32m2024-06-10 07:52:43.084[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 23, c: 0.0005750000127591193, loss_rr: 0.499, loss_retain: 0.033, loss=0.498, mask_desired: 0.500[0m
[32m2024-06-10 07:52:48.727[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 23, c: 0.0005750000127591193, loss_rr: 0.462, loss_retain: 0.034, loss=0.462, mask_desired: 0.750[0m
[32m2024-06-10 07:52:54.370[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 23, c: 0.0005750000127591193, loss_rr: 0.466, loss_retain: 0.034, loss=0.466, mask_desired: 0.250[0m
[32m2024-06-10 07:53:00.024[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 23, c: 0.0005750000127591193, loss_rr: 0.472, loss_retain: 0.035, loss=0.472, mask_desired: 0.250[0m
[32m2024-06-10 07:53:05.680[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.4744, 'grad_norm': 1.3056137561798096, 'learning_rate': 0.00029804020100502514, 'epoch': 0.55}


[32m2024-06-10 07:53:16.995[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 24, c: 0.0006000000284984708, loss_rr: 0.429, loss_retain: 0.036, loss=0.428, mask_desired: 0.750[0m
[32m2024-06-10 07:53:22.642[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 24, c: 0.0006000000284984708, loss_rr: 0.471, loss_retain: 0.033, loss=0.470, mask_desired: 0.250[0m
[32m2024-06-10 07:53:28.290[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 24, c: 0.0006000000284984708, loss_rr: 0.445, loss_retain: 0.037, loss=0.445, mask_desired: 0.750[0m
[32m2024-06-10 07:53:33.935[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 24, c: 0.0006000000284984708, loss_rr: 0.450, loss_retain: 0.039, loss=0.450, mask_desired: 0.250[0m
[32m2024-06-10 07:53:39.590[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.4446, 'grad_norm': 1.1451607942581177, 'learning_rate': 0.0002978894472361809, 'epoch': 0.57}


[32m2024-06-10 07:53:50.893[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 25, c: 0.0006249999860301614, loss_rr: 0.432, loss_retain: 0.039, loss=0.432, mask_desired: 0.500[0m
[32m2024-06-10 07:53:56.544[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 25, c: 0.0006249999860301614, loss_rr: 0.434, loss_retain: 0.040, loss=0.433, mask_desired: 0.750[0m
[32m2024-06-10 07:54:02.192[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 25, c: 0.0006249999860301614, loss_rr: 0.416, loss_retain: 0.038, loss=0.416, mask_desired: 0.500[0m
[32m2024-06-10 07:54:07.838[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 25, c: 0.0006249999860301614, loss_rr: 0.431, loss_retain: 0.039, loss=0.431, mask_desired: 0.250[0m
[32m2024-06-10 07:54:13.497[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.432, 'grad_norm': 0.7832686305046082, 'learning_rate': 0.00029773869346733664, 'epoch': 0.6}


[32m2024-06-10 07:54:24.801[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 26, c: 0.0006500000017695129, loss_rr: 0.399, loss_retain: 0.042, loss=0.399, mask_desired: 0.750[0m
[32m2024-06-10 07:54:30.450[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 26, c: 0.0006500000017695129, loss_rr: 0.406, loss_retain: 0.043, loss=0.406, mask_desired: 0.500[0m
[32m2024-06-10 07:54:36.105[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 26, c: 0.0006500000017695129, loss_rr: 0.400, loss_retain: 0.043, loss=0.400, mask_desired: 0.250[0m
[32m2024-06-10 07:54:41.768[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 26, c: 0.0006500000017695129, loss_rr: 0.381, loss_retain: 0.043, loss=0.381, mask_desired: 0.750[0m
[32m2024-06-10 07:54:47.413[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.3913, 'grad_norm': 1.0241844654083252, 'learning_rate': 0.00029758793969849245, 'epoch': 0.62}


[32m2024-06-10 07:54:58.729[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 27, c: 0.0006750000175088644, loss_rr: 0.393, loss_retain: 0.043, loss=0.393, mask_desired: 0.250[0m
[32m2024-06-10 07:55:04.396[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 27, c: 0.0006750000175088644, loss_rr: 0.377, loss_retain: 0.044, loss=0.376, mask_desired: 0.250[0m
[32m2024-06-10 07:55:10.058[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 27, c: 0.0006750000175088644, loss_rr: 0.000, loss_retain: 0.045, loss=0.000, mask_desired: 1.000[0m
[32m2024-06-10 07:55:15.697[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 27, c: 0.0006750000175088644, loss_rr: 0.405, loss_retain: 0.044, loss=0.404, mask_desired: 0.750[0m
[32m2024-06-10 07:55:21.339[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.2591, 'grad_norm': 0.8636786341667175, 'learning_rate': 0.0002974371859296482, 'epoch': 0.65}


[32m2024-06-10 07:55:32.626[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 28, c: 0.000699999975040555, loss_rr: 0.368, loss_retain: 0.048, loss=0.368, mask_desired: 0.250[0m
[32m2024-06-10 07:55:38.292[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 28, c: 0.000699999975040555, loss_rr: 0.367, loss_retain: 0.049, loss=0.367, mask_desired: 0.500[0m
[32m2024-06-10 07:55:43.992[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 28, c: 0.000699999975040555, loss_rr: 0.375, loss_retain: 0.048, loss=0.375, mask_desired: 0.500[0m
[32m2024-06-10 07:55:49.651[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 28, c: 0.000699999975040555, loss_rr: 0.352, loss_retain: 0.000, loss=0.352, mask_desired: 0.000[0m
[32m2024-06-10 07:55:55.316[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[3

{'loss': 0.3645, 'grad_norm': 1.2795997858047485, 'learning_rate': 0.000297286432160804, 'epoch': 0.67}


[32m2024-06-10 07:56:06.642[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 29, c: 0.0007249999907799065, loss_rr: 0.365, loss_retain: 0.050, loss=0.364, mask_desired: 0.250[0m
[32m2024-06-10 07:56:12.305[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 29, c: 0.0007249999907799065, loss_rr: 0.383, loss_retain: 0.050, loss=0.382, mask_desired: 0.500[0m
[32m2024-06-10 07:56:17.966[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 29, c: 0.0007249999907799065, loss_rr: 0.355, loss_retain: 0.051, loss=0.355, mask_desired: 0.500[0m
[32m2024-06-10 07:56:23.619[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 29, c: 0.0007249999907799065, loss_rr: 0.347, loss_retain: 0.052, loss=0.347, mask_desired: 0.750[0m
[32m2024-06-10 07:56:29.281[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.3611, 'grad_norm': 3.3233532905578613, 'learning_rate': 0.00029713567839195976, 'epoch': 0.69}


[32m2024-06-10 07:56:40.615[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 30, c: 0.000750000006519258, loss_rr: 0.352, loss_retain: 0.057, loss=0.351, mask_desired: 0.500[0m
[32m2024-06-10 07:56:46.279[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 30, c: 0.000750000006519258, loss_rr: 0.346, loss_retain: 0.056, loss=0.346, mask_desired: 0.250[0m
[32m2024-06-10 07:56:51.946[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 30, c: 0.000750000006519258, loss_rr: 0.335, loss_retain: 0.057, loss=0.335, mask_desired: 0.750[0m
[32m2024-06-10 07:56:57.597[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 30, c: 0.000750000006519258, loss_rr: 0.337, loss_retain: 0.055, loss=0.336, mask_desired: 0.750[0m
[32m2024-06-10 07:57:03.252[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[3

{'loss': 0.3427, 'grad_norm': 1.471731185913086, 'learning_rate': 0.0002969849246231155, 'epoch': 0.72}


[32m2024-06-10 07:57:14.574[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 31, c: 0.0007750000222586095, loss_rr: 0.342, loss_retain: 0.059, loss=0.342, mask_desired: 0.500[0m
[32m2024-06-10 07:57:20.235[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 31, c: 0.0007750000222586095, loss_rr: 0.323, loss_retain: 0.056, loss=0.323, mask_desired: 0.250[0m
[32m2024-06-10 07:57:25.897[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 31, c: 0.0007750000222586095, loss_rr: 0.330, loss_retain: 0.059, loss=0.329, mask_desired: 0.250[0m
[32m2024-06-10 07:57:31.566[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 31, c: 0.0007750000222586095, loss_rr: 0.328, loss_retain: 0.058, loss=0.327, mask_desired: 0.250[0m
[32m2024-06-10 07:57:37.230[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.3334, 'grad_norm': 1.5648552179336548, 'learning_rate': 0.0002968341708542713, 'epoch': 0.74}


[32m2024-06-10 07:57:48.571[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 32, c: 0.0007999999797903001, loss_rr: 0.323, loss_retain: 0.059, loss=0.323, mask_desired: 0.250[0m
[32m2024-06-10 07:57:54.241[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 32, c: 0.0007999999797903001, loss_rr: 0.323, loss_retain: 0.061, loss=0.323, mask_desired: 0.250[0m
[32m2024-06-10 07:57:59.910[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 32, c: 0.0007999999797903001, loss_rr: 0.323, loss_retain: 0.059, loss=0.323, mask_desired: 0.750[0m
[32m2024-06-10 07:58:05.557[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 32, c: 0.0007999999797903001, loss_rr: 0.321, loss_retain: 0.058, loss=0.321, mask_desired: 0.750[0m
[32m2024-06-10 07:58:11.208[0m | [34m[1mDEBUG   [0m | [36m__main__[0m

{'loss': 0.3204, 'grad_norm': 1.7377934455871582, 'learning_rate': 0.00029668341708542713, 'epoch': 0.76}


[32m2024-06-10 07:58:22.537[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 33, c: 0.0008249999955296516, loss_rr: 0.309, loss_retain: 0.000, loss=0.309, mask_desired: 0.000[0m
[32m2024-06-10 07:58:28.213[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 33, c: 0.0008249999955296516, loss_rr: 0.308, loss_retain: 0.060, loss=0.307, mask_desired: 0.250[0m
[32m2024-06-10 07:58:33.875[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 33, c: 0.0008249999955296516, loss_rr: 0.316, loss_retain: 0.059, loss=0.315, mask_desired: 0.250[0m
[32m2024-06-10 07:58:39.540[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mcompute_loss[0m:[36m78[0m - [34m[1msteps: 33, c: 0.0008249999955296516, loss_rr: 0.311, loss_retain: 0.061, loss=0.310, mask_desired: 0.750[0m


KeyboardInterrupt: 

In [None]:
# save
model.save_pretrained("../outputs/hs_adapter")

# Eval

In [37]:
from datasets import load_dataset
# multiple_choice
from torch.utils.data import DataLoader
# dataset = load_dataset("truthfulqa/truthful_qa", "multiple_choice")

# HACK it was stalling for hours, so I loaded it locally
dataset = load_dataset("../data/truthful_qa")['validation']
dataset

Dataset({
    features: ['question', 'mc1_targets', 'mc2_targets'],
    num_rows: 817
})

In [96]:

# print(row)

def format_prompt(row):
    prompt = f"Q: {row['question']}\n"
    for i, choice in enumerate(row['mc1_targets']['choices']):
        prompt += f"{i+1}. {choice}\n"

    choices = [str(i) for i in range(len(row['mc1_targets']['labels']))]
    return {'text': prompt, 
            'label': [np.argmax(row['mc1_targets']['labels'])],
            'choices': choices,
            'num_choices': len(choices),
            }

dataset1 = dataset.map(format_prompt)

Map: 100%|██████████| 817/817 [00:00<00:00, 6091.51 examples/s]


In [103]:
max([len(r['labels']) for r in dataset['mc1_targets']])

13

In [108]:
# get our choice ids
choices = [str(i) for i in range(13)]
choice_ids = [tokenizer(c, add_special_tokens=False).input_ids[0] for c in choices]
choice_ids

[15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 605, 806, 717]

In [109]:

def tokenization(example):
    o = tokenizer(example["text"], padding="max_length", truncation=True, max_length=cfg.max_length, return_tensors="pt")
    return o

dataset2 = dataset1.map(tokenization, batched=True).select_columns([ 'label', 'input_ids', 'attention_mask', 
                                                                    'num_choices'
                                                                    ]).with_format("torch")
dataset2

Map: 100%|██████████| 817/817 [00:00<00:00, 4213.33 examples/s]


Dataset({
    features: ['label', 'input_ids', 'attention_mask', 'num_choices'],
    num_rows: 817
})

In [168]:
# https://github.dev/sylinrl/TruthfulQA/blob/fdd8ad1c0d00a478cf8b0bb41a3ad8378c16293b/truthfulqa/models.py#L311


probs = []
base_probs = []

dl = DataLoader(
    dataset2, batch_size=4, num_workers=0)
for b in tqdm(dl):
    inputs = {'input_ids': b['input_ids'], 'attention_mask': b['attention_mask']}
    with torch.no_grad():
        with model.disable_adapter():
            out_base = model(**inputs)
        out = model(**inputs)

        for j in range(len(out["logits"])):
            n = b['num_choices'][j]
            b_choice_ids = choice_ids[:n]
            label = b['label'][j, 0]

            choice_probs_base = out_base["logits"][j, -1, b_choice_ids].softmax(dim=-1)
            choice_probs = out["logits"][j, -1, b_choice_ids].softmax(dim=-1)
            prob = choice_probs[label].item()
            prob_base = choice_probs_base[label].item()
            probs.append(prob)
            base_probs.append(prob_base)
        

  1%|          | 2/205 [00:12<20:23,  6.03s/it]


KeyboardInterrupt: 

In [172]:
choice_probs

tensor([0.0395, 0.5892, 0.2594, 0.0377, 0.0161, 0.0065, 0.0130, 0.0232, 0.0068,
        0.0022, 0.0064])

In [173]:
choice_probs_base

tensor([0.0395, 0.5892, 0.2594, 0.0377, 0.0161, 0.0065, 0.0130, 0.0232, 0.0068,
        0.0022, 0.0064])

In [169]:
acc = ((torch.tensor(probs)>0.5)*1.0).mean()
base_acc = ((torch.tensor(base_probs)>0.5)*1.0).mean()
acc, base_acc

(tensor(0.), tensor(0.))

In [171]:
prob_correct = torch.tensor(probs).mean()
prob_base_correct = torch.tensor(base_probs).mean()
prob_correct, prob_base_correct

(tensor(0.0139), tensor(0.0139))

torch.Size([4, 13])

In [112]:
b

{'label': tensor([[0],
         [0],
         [0],
         [0]]),
 'input_ids': tensor([[128009, 128009, 128009,  ...,   3723,   4273,    627],
         [128009, 128009, 128009,  ...,     13,   8494,    627],
         [128009, 128009, 128009,  ...,    559,   9949,    627],
         [128009, 128009, 128009,  ...,    304,  16759,    627]]),
 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1]]),
 'num_choices': tensor([4, 5, 4, 4])}

In [None]:
from datasets import load_dataset
# import evaluate
from transformers import pipeline

# accuracy = evaluate.load("accuracy")

dataset_eval = load_dataset("EleutherAI/truthful_qa_binary", split="test")["text"]

In [None]:
from evaluate import evaluator
from datasets import load_dataset
data = load_dataset("EleutherAI/truthful_qa_binary", split="validation[:20]")
task_evaluator = evaluator("text-classification")
# see https://huggingface.co/docs/evaluate/v0.4.0/en/package_reference/evaluator_classes#evaluate.TextClassificationEvaluator

pipe = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    # feature_extractor=feature_extractor,
    # device=device,
)


results = task_evaluator.compute(
    pipe,
    # tokenizer=tokenizer,
    data=data,
    metric="accuracy",
    input_column="question",
    # label_mapping={"LABEL_0": 0.0, "LABEL_1": 1.0},
    random_state=42
)

In [None]:

# def preprocess_function(examples):
#     return tokenizer(examples["text"], truncation=True)

# https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/pipelines#pipeline-batching
pipe = pipeline("text-classification", device="auto", model=model, tokenizer=tokenizer)
batch_size = 4
for out in tqdm(pipe(dataset_eval, batch_size=batch_size), total=len(dataset_eval)):
    1/0
    pass

In [176]:
model.train()
model.print_trainable_parameters()

trainable params: 0 || all params: 8,033,669,120 || trainable%: 0.0000


In [177]:
with model.disable_adapter():
    model.print_trainable_parameters()

trainable params: 0 || all params: 8,033,669,120 || trainable%: 0.0000
