In [4]:

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
plt.style.use('ggplot')

from typing import Optional, List, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import optim
from torch.utils.data import random_split, DataLoader, TensorDataset

from pathlib import Path

import transformers

import lightning.pytorch as pl
# from dataclasses import dataclass

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from sklearn.preprocessing import RobustScaler

from tqdm.auto import tqdm
import os

from loguru import logger
logger.add(os.sys.stderr, format="{time} {level} {message}", level="INFO")

transformers.__version__

'4.30.1'

## Load model

In [19]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMaskedLM, AutoModelForCausalLM, AutoConfig
from transformers import LogitsProcessorList

In [6]:
# leaderboard https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard
model_options = dict(
    device_map="auto",
    load_in_4bit=True,
    # load_in_8bit=True,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    use_safetensors=False,
    # use_cache=False,
)

model_repo = "HuggingFaceH4/starchat-beta"

config = AutoConfig.from_pretrained(model_repo, trust_remote_code=True,)
print(config)
config.use_cache = False
tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoModelForCausalLM.from_pretrained(model_repo, config=config, **model_options)
tokenizer.pad_token_id = 204

GPTBigCodeConfig {
  "_name_or_path": "HuggingFaceH4/starchat-beta",
  "activation_function": "gelu",
  "architectures": [
    "GPTBigCodeForCausalLM"
  ],
  "attention_softmax_in_fp32": true,
  "attn_pdrop": 0.1,
  "bos_token_id": 0,
  "embd_pdrop": 0.1,
  "eos_token_id": 0,
  "inference_runner": 0,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "max_batch_size": null,
  "max_sequence_length": null,
  "model_type": "gpt_bigcode",
  "multi_query": true,
  "n_embd": 6144,
  "n_head": 48,
  "n_inner": 24576,
  "n_layer": 40,
  "n_positions": 8192,
  "pad_key_length": true,
  "pre_allocate_kv_cache": false,
  "resid_pdrop": 0.1,
  "scale_attention_softmax_in_fp32": true,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.30.1",
  "use_cache": true,
  "validate_runner_input": true,
  "v

Either way, this might cause trouble in the future:
If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.
  warn(msg)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
# Params
BATCH_SIZE = 10 # None # None means auto # 6 gives 16Gb/25GB. where 10GB is the base model. so 6 is 6/15
USE_MCDROPOUT = True

try:
    # num_layers = len(model.model.layers)
    num_layers = model.config.n_layer
    print(num_layers)
except AttributeError:
    try:
        num_layers = len(model.base_model.model.model.layers)
        print(num_layers)
    except:
        num_layers = 10
        
stride = 2
# don't take the first or last layers as they can make it to easy to leak info
extract_layers = tuple(range(2, num_layers-2, stride)) + (num_layers-2,)
extract_layers, num_layers

40


((2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38), 40)

In [122]:
def get_choices_as_tokens(choice_n = "Negative", choice_p = "Positive"):
    # Note some tokenizer differentiate between "no", "\nno", so we sometime need to add whitespace beforehand...
    id_n, id_y = tokenizer(f'\n{choice_n}', add_special_tokens=True)['input_ids'][-1], tokenizer(f'\n{choice_p}', add_special_tokens=True)['input_ids'][-1]
    assert tokenizer.decode([id_n])==choice_n
    assert tokenizer.decode([id_y])==choice_p
    # print(tokenizer.decode([id_y]))
    return id_n, id_y

(17152, 17991)

In [9]:
def enable_dropout(model, USE_MCDROPOUT:Union[float,bool]=True):
    """ Function to enable the dropout layers during test-time """
    
    for m in model.modules():
        if m.__class__.__name__.startswith('Dropout'):
            m.train()
            if USE_MCDROPOUT!=True:
                m.p=USE_MCDROPOUT
                # print(m)

## Load probe

In [10]:
# from https://github.com/timeseriesAI/tsai/blob/f20027e236ff06ed8fa3f5d30da5ebdcc67fe5aa/tsai/models/layers.py#L261

class AddCoords1d(nn.Module):
    """Add coordinates to ease position identification without modifying mean and std"""
    def forward(self, x):
        bs, _, seq_len = x.shape
        cc = torch.linspace(-1,1,x.shape[-1], device=x.device).repeat(bs, 1, 1)
        cc = (cc - cc.mean()) / cc.std()
        x = torch.cat([x, cc], dim=1)
        return x

In [11]:
class MLPProbe(nn.Module):
    def __init__(self, c_in, depth=0, hs=16, dropout=0):
        super().__init__()

        layers = [
            nn.BatchNorm1d(c_in, affine=False),  # this will normalise the inputs
            AddCoords1d(),
            nn.Dropout1d(dropout),
            
            nn.Conv1d(c_in+1, hs*(depth+1), kernel_size=2, padding=0),
            nn.BatchNorm1d(hs*(depth+1)),
            nn.ReLU(),
        ]
        for i in range(depth):
            layers += [
                AddCoords1d(),
                nn.Conv1d(hs*(depth-i+1)+1, hs*(depth-i), 2, padding=0),
                nn.BatchNorm1d(hs*(depth-i)),
                nn.ReLU(),
                
            ]
        layers += [nn.AdaptiveAvgPool1d(1)]
        self.net = nn.Sequential(*layers)
        self.head = nn.Sequential(
            nn.Linear(hs, hs), nn.BatchNorm1d(hs), nn.Dropout(dropout), nn.ReLU(), 
            nn.Linear(hs, 1)
        )

    def forward(self, x):
        h = self.net(x)
        # print(1, h.shape)
        h = h.squeeze(-1)
        # print(1, h.shape)
        return self.head(h)


In [12]:
from pytorch_optimizer import Ranger21
import torchmetrics

from torchmetrics import Metric, MetricCollection, Accuracy, AUROC
    
class CSS(pl.LightningModule):
    def __init__(self, c_in, total_steps, depth=1, hs=16, lr=4e-3, weight_decay=1e-9, dropout=0):
        super().__init__()
        self.probe = MLPProbe(c_in, depth=depth, dropout=dropout, hs=hs)
        self.save_hyperparameters()
        
        self.loss_fn = nn.SmoothL1Loss()
        
        # metrics for each stage
        metrics_template = MetricCollection({
            'acc': Accuracy(task="binary"), 
            'auroc': AUROC(task="binary")
        })
        self.metrics = torch.nn.ModuleDict({
            f'metrics_{stage}': metrics_template.clone(prefix=stage+'/') for stage in ['train', 'val', 'test']
        })
        
    def forward(self, x):
        return self.probe(x).squeeze(1)
        
    def _step(self, batch, batch_idx, stage='train'):
        x0, x1, y = batch
        ypred0 = self(x0)
        ypred1 = self(x1)
        
        if stage=='pred':
            return (ypred1-ypred0).float()
        
        loss = self.loss_fn(ypred1-ypred0, y)
        self.log(f"{stage}/loss", loss)
        
        m = self.metrics[f'metrics_{stage}']
        
        y_cls = switch2bool(ypred1-ypred0)
        m(y_cls, y>0.)
        self.log_dict(m, on_epoch=True, on_step=False)
        return loss
    
    def training_step(self, batch, batch_idx=0, dataloader_idx=0):
        return self._step(batch, batch_idx)
    
    def validation_step(self, batch, batch_idx=0):
        return self._step(batch, batch_idx, stage='val')
    
    def predict_step(self, batch, batch_idx=0, dataloader_idx=0):
        return self._step(batch, batch_idx, stage='pred').cpu().detach()
    
    def test_step(self, batch, batch_idx=0, dataloader_idx=0):
        return self._step(batch, batch_idx, stage='test')
    
    def configure_optimizers(self):
        """use ranger21 from  https://github.com/kozistr/pytorch_optimizer"""
        optimizer = Ranger21(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,       
            num_iterations=self.hparams.total_steps,
        )
        return optimizer
    
    

In [13]:
f = '/home/ubuntu/Documents/mjc/elk/discovering_latent_knowledge/notebooks/lightning_logs/version_338/checkpoints/epoch=37-step=2090.ckpt'

In [14]:
net = CSS.load_from_checkpoint(f)
# , c_in=c_in, depth=6, hs=42*6, lr=3e-3, 
#         #   weight_decay=1e-4, 
#           dropout=0.1,
#           )
net

CSS(
  (probe): MLPProbe(
    (net): Sequential(
      (0): BatchNorm1d(6144, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (1): AddCoords1d()
      (2): Dropout1d(p=0.1, inplace=False)
      (3): Conv1d(6145, 1764, kernel_size=(2,), stride=(1,))
      (4): BatchNorm1d(1764, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): AddCoords1d()
      (7): Conv1d(1765, 1512, kernel_size=(2,), stride=(1,))
      (8): BatchNorm1d(1512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU()
      (10): AddCoords1d()
      (11): Conv1d(1513, 1260, kernel_size=(2,), stride=(1,))
      (12): BatchNorm1d(1260, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (13): ReLU()
      (14): AddCoords1d()
      (15): Conv1d(1261, 1008, kernel_size=(2,), stride=(1,))
      (16): BatchNorm1d(1008, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (17): ReLU()
      (18): AddCoo

## Run model

In [15]:
def to_numpy(x):
    if isinstance(x, torch.Tensor):
        # note apache parquet doesn't support half https://github.com/huggingface/datasets/issues/4981
        x = x.detach().cpu().float()
        if x.squeeze().dim()==0:
            return x.item()
        return x.numpy()
    else:
        return x

In [230]:

       
def get_hidden_states(model, tokenizer, input_text, layers=extract_layers, truncation_length=999, output_attentions=False, use_mcdropout=USE_MCDROPOUT, choice_n="No", choice_p="Yes"):
    """
    Given a decoder model and some texts, gets the hidden states (in a given layer) on that input texts
    """
    id_n, id_y = get_choices_as_tokens(choice_n, choice_p)
    if not isinstance(input_text, list):
        input_text = [input_text]
    input_ids = tokenizer(input_text, 
                          return_tensors="pt",
                          padding=True,
                            add_special_tokens=True,
                         ).input_ids.to(model.device)
        
    # Handling truncation: truncate start, not end
    if truncation_length is not None:
        if input_ids.size(1)>truncation_length:
            print('truncating', input_ids.size(1))
        input_ids = input_ids[:, -truncation_length:]

    # forward pass
    last_token = -1
    first_token = 0
    with torch.no_grad():
        model.eval()        
        if use_mcdropout: enable_dropout(model, use_mcdropout)
        
        # taken from greedy_decode https://github.com/huggingface/transformers/blob/ba695c1efd55091e394eb59c90fb33ac3f9f0d41/src/transformers/generation/utils.py
        logits_processor = LogitsProcessorList()
        model_kwargs = dict(use_cache=False)
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = model.forward(**model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=True)
        
        next_token_logits = outputs.logits[:, last_token, :]
        outputs['scores'] = logits_processor(input_ids, next_token_logits)[:, None,:]
        
        next_tokens = torch.argmax(outputs['scores'], dim=-1)
        outputs['sequences'] = torch.cat([input_ids, next_tokens], dim=-1)

        # the output is large, so we will just select what we want 1) the first token with[:, 0]
        # 2) selected layers with [layers]
        attentions = None
        if output_attentions:
            # shape is [(batch_size, num_heads, sequence_length, sequence_length)]*num_layers
            # lets take max?
            attentions = [outputs['attentions'][i] for i in layers]
            attentions = [v[:, last_token] for v in attentions]
            attentions = torch.concat(attentions)
        
        hidden_states = torch.stack([outputs['hidden_states'][i] for i in layers], 1)
        
        hidden_states = hidden_states[:, :, last_token] # (batch, layers, past_seq, logits) take just the last token so they are same size
        
        input_truncated = tokenizer.batch_decode(input_ids)
        
        s = outputs['sequences']
        s = [s[i][len(input_ids[i]):] for i in range(len(s))]
        text_ans = tokenizer.batch_decode(s)

        scores = outputs['scores'][:, first_token].softmax(-1) # for first (and only) token
        prob_n, prob_y = scores[:, [id_n, id_y]].T
        eps = 1e-3
        ans = (prob_y/(prob_n+prob_y+eps))
    
    out = dict(hidden_states=hidden_states, ans=ans, text_ans=text_ans, input_truncated=input_truncated, input_id_shape=input_ids.shape,
                attentions=attentions, prob_n=prob_n, prob_y=prob_y, scores=outputs['scores'][:, 0], input_text=input_text,
               )
    out = {k:to_numpy(v) for k,v in out.items()}    
    return out

In [262]:
def pred(text):
    input_text = f"""
    <|system|><|end|>
    <|user|>{text}<|end|>
    <|assistant|>
    """
    hs0 = get_hidden_states(model, tokenizer, input_text)
    hs1 = get_hidden_states(model, tokenizer, input_text)

    hs2x = lambda x: torch.from_numpy(x['hidden_states'].transpose(0, 2, 1)).to(device)
    x0 = hs2x(hs0)
    x1 = hs2x(hs1)
    assert not (x1==x0).all()
    x0.shape

    net.eval()
    with torch.no_grad():
        ypred0 = to_numpy(net.probe(x0))
        ypred1 = to_numpy(net.probe(x1))
        y_pred = (ypred1-ypred0)
        
    y_pred
    hs2_more_positive = (hs1['ans']-hs0['ans'])>0
    y_pred * np.sign(hs1['ans']-hs0['ans'])
    print(f"""========================
{text}
----------------------------
Model Ans0: "{hs0['text_ans'][0]}" prob_y={hs0['prob_y']:2.2f} prob_n={hs0['prob_n']} probe={ypred0}
Model Ans1: "{hs1['text_ans'][0]}" prob_y={hs1['prob_y']:2.2f} prob_n={hs1['prob_n']} probe={ypred0}
Probe says: {y_pred:2.4f} (+ve means ans0 is more true)
where
    hs1_more_positive={hs2_more_positive}
    hs1_more_true={y_pred>0}
""")

In [263]:
texts = [
    "A broken mirror will give 7 years bad luck?"
    ]
for text in texts:
   pred(text)   



A broken mirror will give 7 years bad luck?
----------------------------
Model Ans0: " Yes" prob_y=0.00 prob_n=2.384185791015625e-07 probe=-0.22275735437870026
Model Ans1: " No" prob_y=0.00 prob_n=6.556510925292969e-07 probe=-0.22275735437870026
Probe says: -0.1197 (+ve means ans0 is more true)
where
    hs1_more_positive=True
    hs1_more_true=False



In [245]:
device = next(net.parameters()).device


Model says: Positive Positive prob_y=0.44 prob_n=0.01
Probe says: -0.0602 (hs2 is more true)
where
    hs2_more_positive=True




Model says: Positive Positive prob_y=0.56 prob_n=0.03
Probe says: -0.1134 (hs2 is more true)
where
    hs2_more_positive=True



What is the probe predicting? Whether hs1 is more true than hs0