In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

import torch.nn as nn
import torch.nn.functional as F
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class ZeroShotTester:
    """Probe a causal LM's hidden states with a logistic-regression classifier.

    For a chosen layer and token position, one hidden-state vector is taken
    per sentence; a ``LogisticRegression`` is then fitted to separate two
    groups of sentences and scored on those same vectors.
    """

    def __init__(self, model_name: str, device=None, dtype: torch.dtype = torch.float16):
        """Load the tokenizer and model named ``model_name``.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            device: Stored as ``self.device``; defaults to ``"cuda"`` when
                available, else ``"cpu"``. Model placement itself is decided
                by ``device_map="auto"``.
            dtype: Passed as ``torch_dtype`` when loading the model.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # get_hidden_state() tokenizes with padding=True, which needs a pad
        # token; many causal-LM tokenizers do not define one, so reuse EOS.
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto"
        )
        self.model.eval()

    @staticmethod
    def _progress(iterable):
        """Wrap ``iterable`` in a tqdm progress bar when tqdm is importable."""
        try:
            from tqdm.auto import tqdm
        except ImportError:
            return iterable
        return tqdm(iterable)

    @torch.no_grad()
    def get_hidden_state(self, text, layer_index: int):
        """Run the model on ``text`` and return one layer's hidden states.

        Args:
            text: Input handed to the tokenizer.
            layer_index: Index into ``outputs.hidden_states``.
                NOTE(review): in Hugging Face models index 0 is usually the
                embedding output rather than a transformer block — confirm.

        Returns:
            The first batch item of the selected layer,
            i.e. a ``[seq_len, hidden_dim]`` tensor.
        """
        # Send inputs to wherever the model actually lives: device_map="auto"
        # decides placement, which need not match self.device.
        target_device = getattr(self.model, "device", self.device)
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        ).to(target_device)
        outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
        return outputs.hidden_states[layer_index][0]  # [seq_len, hidden_dim]

    def get_token_representation(self, text, layer: int, position: int):
        """Return the hidden-state vector of one token of ``text``.

        Args:
            text: Input handed to the tokenizer.
            layer: Index into the model's hidden states.
            position: Token index along the sequence axis (negative indices
                count from the end).
        """
        hidden = self.get_hidden_state(text, layer)
        return hidden[position]

    def extract_representations(self, sentences, layer: int, position: int):
        """Stack the token representation of every sentence into one tensor.

        Each sentence triggers its own forward pass; the resulting vectors are
        moved to the CPU before stacking.

        Returns:
            A CPU tensor with one row per sentence.
        """
        reps = [
            self.get_token_representation(sentence, layer, position).cpu()
            for sentence in self._progress(sentences)
        ]
        return torch.stack(reps)

    def evaluate(
        self,
        x0_sentences,
        x1_sentences,
        layers=(5, 10, 20, 30),
        positions=(-1,),
        verbose: bool = True,
    ):
        """Fit and score a linear probe for every (layer, position) pair.

        Sentences in ``x0_sentences`` are labelled 0 and those in
        ``x1_sentences`` are labelled 1.

        Args:
            x0_sentences: Sentences of class 0.
            x1_sentences: Sentences of class 1.
            layers: Hidden-state indices to probe.
            positions: Token positions to probe.
            verbose: Print progress and per-pair accuracy when true.

        Returns:
            Dict mapping ``(layer, position)`` to accuracy. The accuracy is
            measured on the same samples the classifier was fitted on, so it
            is a training accuracy, not a held-out estimate.
        """
        results = {}
        # NOTE: every (layer, position) pair re-runs the model on all
        # sentences, because extraction is done per pair.
        for layer in layers:
            for pos in positions:
                if verbose:
                    print(f"Extracting layer {layer}, pos {pos}...")
                x0_repr = self.extract_representations(x0_sentences, layer, pos)
                x1_repr = self.extract_representations(x1_sentences, layer, pos)

                # Cast to float64 before leaving torch: NumPy cannot represent
                # every torch dtype (e.g. bfloat16), so .numpy() would fail.
                X = torch.cat([x0_repr, x1_repr]).to(torch.float64).numpy()
                y = np.array([0] * len(x0_repr) + [1] * len(x1_repr))

                clf = LogisticRegression().fit(X, y)
                acc = clf.score(X, y)
                results[(layer, pos)] = acc

                if verbose:
                    print(f"✅ Layer {layer:2d}, Pos {pos:2d} → Accuracy: {acc:.3f}")
        return results