## Let's implement CCS from scratch.
This will deliberately be a simple (but less efficient) implementation to make everything as clear as possible.

In [1]:
from tqdm.auto import tqdm
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
os.environ["HF_DATASETS_OFFLINE"] = "1"
from datasets import load_dataset
import datasets
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMaskedLM, AutoModelForCausalLM
from sklearn.linear_model import LogisticRegression

import lightning.pytorch as pl
from dataclasses import dataclass
from torch.utils.data import random_split, DataLoader, TensorDataset
from transformers.models.auto.modeling_auto import AutoModel
# from scipy.stats import zscore
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from sklearn.preprocessing import RobustScaler
import gc

import os

  from .autonotebook import tqdm as notebook_tqdm


## Model

In [2]:
# from transformers import LlamaTokenizer, LlamaForCausalLM
from transformers import LlamaForCausalLM, LlamaTokenizer

In [3]:
# Pick ONE model to run. The alternatives stay here commented out so they are
# easy to switch to (previously every line was live and each assignment simply
# overwrote the one before it).
# model_name = "deberta"
# model_name = "gpt-j"
# model_name = "t5"
# model_name = "llama"
model_name = "alpaca"
finetuned = None

# Keyword arguments shared by every `from_pretrained` call below.
model_options = dict(
    device_map="auto",
    load_in_8bit=True,
    torch_dtype=torch.float16,
)


if model_name == "deberta":
    model_type = "encoder"
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xxlarge")
    model = AutoModelForMaskedLM.from_pretrained("microsoft/deberta-v2-xxlarge", **model_options)
elif model_name == "gpt-j":
    model_type = "decoder"
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", **model_options)
elif model_name == "t5":
    model_type = "encoder_decoder"
    tokenizer = AutoTokenizer.from_pretrained("t5-11b")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-11b", **model_options)
    model.parallelize()  # T5 is big enough that we may need to run it on multiple GPUs
elif ("llama" in model_name) or ("alpaca" in model_name):
    # https://github.com/deep-diver/LLM-As-Chatbot/blob/216abb559d00a0555f41a1426ac9db6c1abc24f3/models/alpaca.py
    model_repo = "Neko-Institute-of-Science/LLaMA-7B-HF"
    lora_repo = "tloen/alpaca-lora-7b"
    model_type = "decoder"
    tokenizer = LlamaTokenizer.from_pretrained(model_repo)
    model = LlamaForCausalLM.from_pretrained(model_repo, **model_options)

    if "alpaca" in model_name:
        # "alpaca" = the LLaMA base weights plus the LoRA adapter from `lora_repo`
        from peft import PeftModel
        model = PeftModel.from_pretrained(
            model,
            lora_repo,
            device_map='auto'#{'': 0}
        )

    # Pad with token id 0, on the left. The id has to be set through
    # `pad_token_id`; `pad_token` holds the token *string*, so assigning the
    # integer 0 to it does not select id 0.
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"
else:
    raise NotImplementedError(f"Unsupported model_name: {model_name!r}")
model


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home/wassname/miniforge3/envs/dlk2/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda117.so
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 7.5
CUDA SETUP: Detected CUDA version 117
CUDA SETUP: Loading binary /home/wassname/miniforge3/envs/dlk2/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...


  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
Either way, this might cause trouble in the future:
If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.
  warn(msg)
Loading checkpoint shards: 100%|██████████████| 2/2 [00:07<00:00,  3.78s/it]


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 4096, padding_idx=0)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): Linear8bitLt(
                in_features=4096, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): Linear8bitLt(
                in_features=4096, out_features=4096, bias=

In [4]:
# tokenizer = AutoTokenizer.from_pretrained(model_repo)
# tokenizer.truncation_side='Left'
# tokenizer

In [5]:
# Vocabulary ids of the answer tokens "0" and "1" (the last id of each
# encoding); get_decoder_hidden_states reads these globals later.
id_0 = tokenizer('0')['input_ids'][-1]
id_1 = tokenizer('1')['input_ids'][-1]
id_0, id_1

(29900, 29896)

## Dataset

In [6]:
# debug: make the `datasets` library log at INFO level
datasets.logging.set_verbosity_info()


# Load the amazon_polarity sentiment dataset. (The helper names below still say
# "imdb", but the data loaded here is amazon_polarity.)
data = load_dataset("amazon_polarity")
# data = load_dataset("/home/wassname/.cache/huggingface/datasets/amazon_polarity/amazon_polarity/3.0.0/a27b32b7e7b88eb274a8fa8ba0f654f1fe998a87c22547557317793b5d2772dc/amazon_polarity-train-00003-of-00004.arrow")
data

Using the latest cached version of the module from /home/wassname/.cache/huggingface/modules/datasets_modules/datasets/amazon_polarity/a27b32b7e7b88eb274a8fa8ba0f654f1fe998a87c22547557317793b5d2772dc (last modified on Sat May  6 07:52:43 2023) since it couldn't be found locally at amazon_polarity.
Found cached dataset amazon_polarity (/home/wassname/.cache/huggingface/datasets/amazon_polarity/amazon_polarity/3.0.0/a27b32b7e7b88eb274a8fa8ba0f654f1fe998a87c22547557317793b5d2772dc)
100%|█████████████████████████████████████████| 2/2 [00:00<00:00,  3.96it/s]


DatasetDict({
    train: Dataset({
        features: ['label', 'title', 'content'],
        num_rows: 3600000
    })
    test: Dataset({
        features: ['label', 'title', 'content'],
        num_rows: 400000
    })
})

In [22]:
def format_imdb(text, label):
    """
    Build the one-shot prompt for a single review.

    The prompt holds one fixed worked example followed by `text`, and asks
    whether the review is 'positive' (truthy `label`) or 'negative' (falsy
    `label`). It ends with an empty "Answer: " line for the model to complete.
    """
    return f"""Review: "I think this is a lovely family movie. There are plenty of hilarious scenes and heart-warming moments to be had throughout the movie. The actors are great and the effects well executed throughout. Danny Glover plays George Knox who manages the terrible baseball team 'The Angels' and is great throughout the film. Also fantastic are the young actors Joseph Gordon-Levitt and Milton Davis Jr. Christopher Lloyd is good as Al 'The Angel' and the effects are great in this top notch Disney movie. A touching and heart-warming movie which everyone should enjoy."
Question: Is this review positive? 
Answer: 1
---
Review: "{text}"
Question: Is this review {'positive' if label else 'negative'}?
Answer: 
"""

def format_imdbs(texts, labels):
    """
    Build one prompt per review in `texts`.

    `labels` is either a single value applied to every text (how the rest of
    this notebook calls it) or a sequence with one label per text. Previously a
    sequence was handed whole to every prompt, so any non-empty list counted as
    "positive".

    Raises ValueError if a label sequence does not match `texts` in length.
    """
    texts = list(texts)
    try:
        n_labels = len(labels)
    except TypeError:
        # unsized value (int, bool, 0-d array/tensor): use it for every text
        labels = [labels] * len(texts)
    else:
        if n_labels != len(texts):
            raise ValueError(
                f"got {n_labels} labels for {len(texts)} texts"
            )
    return [format_imdb(t, l) for t, l in zip(texts, labels)]

print(format_imdb("The movie was the worst.... not!", 0))

Review: "I think this is a lovely family movie. There are plenty of hilarious scenes and heart-warming moments to be had throughout the movie. The actors are great and the effects well executed throughout. Danny Glover plays George Knox who manages the terrible baseball team 'The Angels' and is great throughout the film. Also fantastic are the young actors Joseph Gordon-Levitt and Milton Davis Jr. Christopher Lloyd is good as Al 'The Angel' and the effects are great in this top notch Disney movie. A touching and heart-warming movie which everyone should enjoy."
Question: Is this review positive? 
Answer: 1
---
Review: "The movie was the worst.... not!"
Question: Is this review negative?
Answer: 



In [23]:
# Number of tokens in one formatted prompt
len(tokenizer(format_imdb("The movie was the worst.... not!", 0))['input_ids'])

174

## First check the model's text output

In [28]:
# gen_config_raw = {
#     "temperature": temperature,
#     "top_p": top_p,
#     "top_k": top_k,
#     "repetition_penalty": repetition_penalty,
#     "max_new_tokens": max_new_tokens,
#     "num_beams": num_beams,
#     "use_cache": use_cache,
#     "do_sample": do_sample,
#     "eos_token_id": eos_token_id, 
#     "pad_token_id": pad_token_id
# }

def get_output(model, tokenizer, input_text, add_bos_token=False, truncation_length=400, max_new_tokens=32):
    """
    Debug helper: generate a continuation for a prompt and print both the
    (possibly truncated) prompt and the generated text.

    Args:
        model: model exposing `.generate` and `.device`
        tokenizer: matching tokenizer
        input_text: a prompt string or a list of prompts. No padding is
            requested, so a list only works when all prompts tokenize to the
            same length.
        add_bos_token: keep a leading BOS token if the tokenizer added one
        truncation_length: keep only the last N prompt tokens (None disables)
        max_new_tokens: number of tokens to generate after the prompt

    Only the first prompt of the batch is printed. Returns None.
    """
    if not isinstance(input_text, list):
        input_text = [input_text]
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)

    # remove bos token? https://github.com/oobabooga/text-generation-webui/blob/1b52bddfcc70d2db88257d36f1c6d182573588c4/modules/text_generation.py#L36
    if not add_bos_token and input_ids[0][0] == tokenizer.bos_token_id:
        input_ids = input_ids[:, 1:]

    # Llama adds this extra token when the first character is '\n', and this
    # compromises the stopping criteria, so we just remove it
    if type(tokenizer) is LlamaTokenizer and input_ids[0][0] == 29871:
        print('removed extra \n token')
        input_ids = input_ids[:, 1:]

    # Handling truncation: keep the END of the prompt (the question being asked)
    if truncation_length is not None:
        input_ids = input_ids[:, -truncation_length:]

    # forward pass
    # `max_length` counts the prompt as well, so the old `max_length=400` left
    # no room to generate once the prompt filled the 400-token truncation
    # window. Budget the new tokens explicitly instead.
    with torch.no_grad():
        output = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)

    text_q = tokenizer.batch_decode(input_ids, skip_special_tokens=False)
    text_ans = tokenizer.batch_decode(output, skip_special_tokens=False)
    print(text_q[0])
    print('-'*40+'answ'+'-'*40)
    print(text_ans[0])


In [29]:
# Batches are padded on the left using token id 0.
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 0

In [30]:
# Pick one test review and its ground-truth label to try the prompt on.
idx = 1
example = data['test'][idx]
text, true_label = example["content"], example["label"]

In [31]:
# Print the model's continuation of the "positive" prompt for the review picked above.
input_text = [format_imdb(text, 1)]
# input_text = [i + tokenizer.eos_token for i in input_text]
get_output(model, tokenizer, input_text)

Review: "I think this is a lovely family movie. There are plenty of hilarious scenes and heart-warming moments to be had throughout the movie. The actors are great and the effects well executed throughout. Danny Glover plays George Knox who manages the terrible baseball team 'The Angels' and is great throughout the film. Also fantastic are the young actors Joseph Gordon-Levitt and Milton Davis Jr. Christopher Lloyd is good as Al 'The Angel' and the effects are great in this top notch Disney movie. A touching and heart-warming movie which everyone should enjoy."
Question: Is this review positive? 
Answer: 1
---
Review: "Despite the fact that I have only played a small portion of the game, the music I heard (plus the connection to Chrono Trigger which was great as well) led me to purchase the soundtrack, and it remains one of my favorite albums. There is an incredible mix of fun, epic, and emotional songs. Those sad and beautiful tracks I especially like, as there's not too many of those

## Write code for extracting hidden states given a model and text. 
How we do this exactly will depend on the type of model.

In [41]:


def get_decoder_hidden_states(model, tokenizer, input_text, layers=[2, -2], add_bos_token=False, truncation_length=400, ans_position=-2):
    """
    Run a decoder-only model on a prompt (or list of prompts) and collect the
    hidden states of selected layers plus the model's preference for answering
    "1" over "0".

    Args:
        model: causal LM returning `hidden_states` and `logits`
        tokenizer: matching tokenizer (padding is requested, so a pad token
            must be configured)
        input_text: a prompt string or a list of prompts
        layers: indices into the model's tuple of hidden states
        add_bos_token: keep a leading BOS token if the tokenizer added one
        truncation_length: keep only the last N tokens (None disables)
        ans_position: sequence position whose next-token distribution is read
            as the answer. EOS is appended to every prompt, so -2 is the last
            prompt token (assumes the EOS string encodes to a single final
            token — TODO confirm for the tokenizer in use).

    Returns a dict with
        hidden_states: CPU tensor [batch, len(layers), seq, hidden]
        ans: CPU tensor [batch], P("1") / (P("0") + P("1")) at `ans_position`
        text_ans: per prompt, the decoded argmax prediction at every position
        text_q: per prompt, the decoded input ids actually fed to the model

    Uses the notebook globals `id_0` / `id_1` (token ids of "0" and "1").
    """
    if not isinstance(input_text, list):
        input_text = [input_text]
    # tokenize (adding the EOS token this time)
    input_text = [i + tokenizer.eos_token for i in input_text]
    # keep only the last 1000 characters, i.e. the end of the prompt
    input_text = [i[-1000:] for i in input_text]
    encoded = tokenizer(input_text,
                        return_tensors="pt",
                        padding=True,
                       )
    input_ids = encoded["input_ids"].to(model.device)
    # Padding is on, so the mask has to reach the model or the pad tokens are
    # attended to like real text. Every slice of input_ids below is mirrored here.
    attention_mask = encoded["attention_mask"].to(model.device)

    # remove bos token? https://github.com/oobabooga/text-generation-webui/blob/1b52bddfcc70d2db88257d36f1c6d182573588c4/modules/text_generation.py#L36
    # NOTE(review): only the first row is inspected; in a left-padded batch the
    # other rows may hold a pad token in this column — confirm intended behaviour.
    if not add_bos_token and input_ids[0][0] == tokenizer.bos_token_id:
        input_ids = input_ids[:, 1:]
        attention_mask = attention_mask[:, 1:]

    # Llama adds this extra token when the first character is '\n', and this
    # compromises the stopping criteria, so we just remove it
    if type(tokenizer) is LlamaTokenizer and input_ids[0][0] == 29871:
        print('removed extra \n token')
        input_ids = input_ids[:, 1:]
        attention_mask = attention_mask[:, 1:]

    # Handling truncation: keep the END of the sequence
    if truncation_length is not None:
        input_ids = input_ids[:, -truncation_length:]
        attention_mask = attention_mask[:, -truncation_length:]

    # forward pass (a single pass, so the key/value cache is not needed)
    with torch.no_grad():
        output = model(input_ids,
                       attention_mask=attention_mask,
                       output_hidden_states=True,
                       use_cache=False,
                      )

    # the output is large, so keep only the selected layers
    # dims [Batch, Layers, Seq_Token, Hidden] e.g. torch.Size([3, 2, 284, 4096])
    hidden_states = torch.stack([output['hidden_states'][i] for i in layers], 1).detach().cpu()

    # dims [Batch, Seq_Token, Vocab] e.g. torch.Size([3, 284, 32000])
    probs = output['logits'].detach().cpu().float().softmax(-1)

    text_q = [tokenizer.decode(ids) for ids in input_ids]
    text_ans = [tokenizer.decode(ids) for ids in probs.argmax(-1)]

    # Relative probability of "1" vs "0" as the next token at `ans_position`.
    # This used to read position 0, i.e. the prediction made after the very
    # first token of the sequence, which under left padding may be a pad token.
    prob_0, prob_1 = probs[:, ans_position][:, [id_0, id_1]].T
    ans = prob_1 / (prob_0 + prob_1)
    return dict(hidden_states=hidden_states, ans=ans, text_ans=text_ans, text_q=text_q)

def get_hidden_states(model, tokenizer, input_text, layers=[2, -2], model_type="decoder"):
    """
    Dispatch hidden-state extraction on the model architecture.

    Only "decoder" models are implemented in this notebook, so that is now the
    default (the previous default, "encoder", always failed with a bare
    KeyError).

    Returns whatever get_decoder_hidden_states returns.
    Raises KeyError for any other `model_type`.
    """
    if model_type != "decoder":
        raise KeyError(
            f"Unsupported model_type {model_type!r}: only 'decoder' is implemented"
        )
    return get_decoder_hidden_states(model, tokenizer, input_text, layers=layers)

In [42]:
# input_text = [format_imdb(text, 0)]
# input_text = [i + tokenizer.eos_token for i in input_text]
# input_ids = tokenizer(input_text, 
#                       return_tensors="pt",
#                       truncation=True, 
#                       padding=True,
#                       max_length=300).input_ids.to(model.device)
# print(tokenizer.decode(input_ids[0]))

In [43]:
# Smoke test: extract hidden states for the "negative" (0) and "positive" (1)
# phrasing of the same test review.
idx = 0
example = data['test'][idx]
text, true_label = example["content"], example["label"]
neg_hs, pos_hs = (
    get_hidden_states(model, tokenizer, format_imdb(text, question_label), model_type=model_type)
    for question_label in (0, 1)
)
# neg_hs

In [44]:
# Show, for each phrasing, the text that was fed in and the decoded predictions.
input_rule = '-'*40+'input'+'-'*40
answer_rule = '-'*40+'answ'+'-'*40
for shown, closing_rule in ((neg_hs, '='*80), (pos_hs, '-'*80)):
    print(input_rule)
    print(shown['text_q'][0])
    print(answer_rule)
    print(shown['text_ans'][0])
    print(closing_rule)

----------------------------------------input----------------------------------------
 hout. Danny Glover plays George Knox who manages the terrible baseball team 'The Angels' and is great throughout the film. Also fantastic are the young actors Joseph Gordon-Levitt and Milton Davis Jr. Christopher Lloyd is good as Al 'The Angel' and the effects are great in this top notch Disney movie. A touching and heart-warming movie which everyone should enjoy."
Question: Is this review positive? 
Answer: 1
---
Review: "My lovely Pat has one of the GREAT voices of her generation. I have listened to this CD for YEARS and I still LOVE IT. When I'm in a good mood it makes me feel better. A bad mood just evaporates like sugar in the rain. This CD just oozes LIFE. Vocals are jusat STUUNNING and lyrics just kill. One of life's hidden gems. This is a desert isle CD in my book. Why she never made it big is just beyond me. Everytime I play this, no matter black, white, young, old, male, female EVERYBODY sa

In [45]:
# # unit tests
# idx = 0
# n=10
# batch_size=3
# ds_subset = data['test'].shuffle(42).select(range(n))
# dl = DataLoader(ds_subset, batch_size=batch_size, shuffle=True)
# batch = next(iter(dl))

# texts, true_labels = batch["content"], batch["label"]
# neg_hs = get_hidden_states(model, tokenizer, format_imdbs(texts, 0), model_type=model_type)
# neg_hs
# for k,v in neg_hs.items():
#     print(k, v.shape)

## Now let's write code for formatting data and for getting all the hidden states.

In [18]:


def get_hidden_states_many_examples(model, tokenizer, data, model_type, n=100, layers=[2, -2], batch_size=3):
    """
    Compute contrast-pair features for `n` examples of the test split.

    Each review is formatted twice, once with the "negative" question (label 0)
    and once with the "positive" question (label 1), and run through the model.

    Args:
        data: a mapping with a 'test' split (e.g. what `load_dataset(name)`
            returns) or a single split passed directly
        n: number of examples, taken after shuffling with seed 42
        layers: hidden-state layers to keep
        batch_size: prompts per forward pass

    Returns a list of five numpy arrays, in order:
        neg_hs, pos_hs: (n, len(layers) * hidden_dim), last-token hidden states
        labels: (n,), ground-truth labels
        neg_ans, pos_ans: (n,), the `ans` value for each phrasing

    This is deliberately simple so that it's easy to understand, rather than being optimized for efficiency
    """
    # setup
    model.eval()

    # Take the 'test' split from a mapping; otherwise assume `data` already is a split.
    test_split = data['test'] if isinstance(data, dict) else data
    ds_subset = test_split.shuffle(42).select(range(n))
    # The subset is already shuffled with a fixed seed, so iterate in order to
    # keep the result (and the later train/test split) reproducible.
    dl = DataLoader(ds_subset, batch_size=batch_size, shuffle=False)

    res = []
    for batch in tqdm(dl):
        text, true_label = batch["content"], batch["label"]
        neg = get_hidden_states(model, tokenizer, format_imdbs(text, 0), model_type=model_type, layers=layers)
        pos = get_hidden_states(model, tokenizer, format_imdbs(text, 1), model_type=model_type, layers=layers)

        # collect
        b = len(text)
        # hidden_states is [batch, layers, seq, hidden]; keep only the last
        # token. Sequence length differs between batches, so flattening every
        # token gave rows of different width that cannot be concatenated.
        res.append([
            neg['hidden_states'][:, :, -1].reshape((b, -1)),
            pos['hidden_states'][:, :, -1].reshape((b, -1)),
            true_label,
            neg['ans'],
            pos['ans'],
        ])

    # transpose the per-batch records into one array per field
    return [np.concatenate(r) for r in zip(*res)]

# Let's verify that the model's answers are good

In [19]:
# neg_hs, pos_hs, y, all_neg_ans, all_pos_ans = get_hidden_states_many_examples(model, tokenizer, data, model_type, n=10)

Speed

- 60 seconds for 100 examples without batching (about 1.7 examples/s)

In [20]:
# Free unreferenced Python objects and ask torch to empty its CUDA cache before the extraction run.
gc.collect()
torch.cuda.empty_cache()
gc.collect()

0

In [21]:
# Extract contrast-pair hidden states and answer scores using the function defaults (n=100, batch_size=3).
neg_hs, pos_hs, y, all_neg_ans, all_pos_ans = get_hidden_states_many_examples(model, tokenizer, data, model_type)


# Free memory again after the forward passes.
gc.collect()
torch.cuda.empty_cache()
gc.collect()

Loading cached shuffled indices for dataset at /home/wassname/.cache/huggingface/datasets/amazon_polarity/amazon_polarity/3.0.0/a27b32b7e7b88eb274a8fa8ba0f654f1fe998a87c22547557317793b5d2772dc/cache-0a5d0b47b5e8dfc6.arrow
100%|███████████████████████████████████████| 34/34 [00:27<00:00,  1.22it/s]


In [None]:
# all_pos_ans

In [None]:
# roc_auc_score
# all_pos_ans / all_neg_ans are the model's relative probability of answering
# "1" to the "positive" / "negative" question respectively.
# NOTE(review): the "negative" score is not inverted here, whereas the accuracy
# cell below uses `< 0.5` for it, so an informative model would presumably give
# neg_score below 0.5 — confirm whether `1 - all_neg_ans` was intended.
pos_score = roc_auc_score(y, all_pos_ans)
neg_score = roc_auc_score(y, all_neg_ans)
pos_score, neg_score

In [None]:
# accuracy_score
# The answer scores lie in [0, 1], so the decision boundary is 0.5 for both
# phrasings. The positive prompt used to be thresholded at `> 0.`, which
# labels practically every example as positive.
# "positive" question: a score above 0.5 means the model answers "1" -> label 1
# "negative" question: a score below 0.5 means the model answers "0" -> label 1
pos_score = accuracy_score(y, (all_pos_ans>0.5)*1.0)
neg_score = accuracy_score(y, (all_neg_ans<0.5)*1.0)
pos_score, neg_score

## Let's verify that the model's representations are good

Before trying CCS, let's make sure there exists a direction that classifies examples as true vs false with high accuracy; if supervised logistic regression accuracy is bad, there's no hope of unsupervised CCS doing well.

Note that because logistic regression is supervised, we expect it to do better but to generalise worse than equivalent unsupervised methods. However, in this case CCS is using a deeper model, so the comparison is more complicated.

In [None]:
# Supervised baseline: split the examples in half (first half train, second
# half test; the data is already randomized) and fit logistic regression.
n = len(y)
half = n // 2
train_rows, test_rows = slice(None, half), slice(half, None)

# one flat feature vector per example
neg_hs2 = torch.from_numpy(np.stack([row.flatten() for row in neg_hs], 0))
pos_hs2 = torch.from_numpy(np.stack([row.flatten() for row in pos_hs], 0))

neg_hs_train, pos_hs_train, y_train = neg_hs2[train_rows], pos_hs2[train_rows], y[train_rows]
neg_hs_test, pos_hs_test, y_test = neg_hs2[test_rows], pos_hs2[test_rows], y[test_rows]

# The feature is simply the difference between the two phrasings' hidden
# states (concatenating also works fine).
x_train = neg_hs_train - pos_hs_train
x_test = neg_hs_test - pos_hs_test

lr = LogisticRegression(class_weight="balanced")
lr.fit(x_train, y_train)
train_accuracy = lr.score(x_train, y_train)
print("Logistic regression accuracy: {} [TRAIN]".format(train_accuracy))
test_accuracy = lr.score(x_test, y_test)
print("Logistic regression accuracy: {} [TEST]".format(test_accuracy))

## Now let's try CCS

In [None]:
class MLPProbe(nn.Module):
    """
    Small MLP probe: three ReLU hidden layers of width 100 followed by a
    single sigmoid output, mapping a d-dimensional feature to a value in (0, 1).
    """

    def __init__(self, d):
        super().__init__()
        hidden = 100
        widths = [d, hidden, hidden, hidden]

        # Linear+ReLU for each consecutive pair of widths, then the output head.
        # Layers are created in the same order as before, so parameter names
        # (net.0, net.2, net.4, net.6) and seeded initialisation are unchanged.
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.append(nn.Linear(hidden, 1))
        stack.append(nn.Sigmoid())

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        return self.net(x)


## Train

In [None]:
# # Train CCS without any labels
# ccs = CCS(neg_hs_train, pos_hs_train, linear=False)
# ccs.repeated_train()

# # Evaluate
# ccs_acc = ccs.get_acc(neg_hs_train, pos_hs_train, y_train)
# print("CCS nonlinear train accuracy: {}".format(ccs_acc))

# ccs_acc = ccs.get_acc(neg_hs_test, pos_hs_test, y_test)
# print("CCS nonlinear test accuracy: {}".format(ccs_acc))

In [None]:
# # Train CCS without any labels
# ccs = CCS(neg_hs_train, pos_hs_train, linear=True)
# ccs.repeated_train()

# # Evaluate
# ccs_acc = ccs.get_acc(neg_hs_train, pos_hs_train, y_train)
# print("CCS train accuracy: {}".format(ccs_acc))

# ccs_acc = ccs.get_acc(neg_hs_test, pos_hs_test, y_test)
# print("CCS test accuracy: {}".format(ccs_acc))

# lightning

## DataModule

In [None]:

# def normalize(x):
#     """
#     Mean-normalizes the data x (of shape (n, d))
#     If self.var_normalize, also divides by the standard deviation
#     """
#     normalized_x = x - x.mean(axis=0, keepdims=True)
#     if self.var_normalize:
#         normalized_x /= normalized_x.std(axis=0, keepdims=True)

#     return normalized_x


class IMBDHSDataModule(pl.LightningDataModule):
    """
    Lightning data module serving contrast-pair hidden states.

    `setup` runs the language model over `n` test examples (via
    get_hidden_states_many_examples) and splits the result 50% / 25% / 25%
    into train / val / test. Every batch is a (neg_hs, pos_hs, y) triple of
    float tensors.

    Args:
        model, tokenizer: the language model to probe and its tokenizer
        model_type: forwarded to the hidden-state extraction helper
        dataset_name: Hugging Face dataset to load (its "test" split is used)
        batch_size: batch size of all three dataloaders
        n: number of examples to extract
    """

    def __init__(self,
                 model: AutoModel,
                 tokenizer: AutoTokenizer,
                 model_type="decoder",
                 dataset_name="amazon_polarity",
                 batch_size=32,
                 n=200,
                ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.save_hyperparameters(ignore=["model", "tokenizer"])

    @staticmethod
    def _to_dataset(neg_hs, pos_hs, y):
        """Wrap one split's numpy arrays as a float TensorDataset of (neg, pos, label)."""
        return TensorDataset(torch.from_numpy(neg_hs).float(),
                             torch.from_numpy(pos_hs).float(),
                             torch.from_numpy(y).float())

    def setup(self, stage: str):
        # setup can run more than once (the notebook calls it by hand and the
        # Trainer calls it again). `stage` is not used, and the extraction
        # below is expensive, so do the work only the first time.
        if hasattr(self, "ds_train"):
            return

        self.dataset = load_dataset(self.hparams.dataset_name, split="test")

        # The extraction helper looks up data['test'], so hand it a mapping
        # rather than the bare split (indexing the split itself with 'test' fails).
        neg_hs, pos_hs, y, all_neg_ans, all_pos_ans = get_hidden_states_many_examples(
            self.model, self.tokenizer, {"test": self.dataset}, self.hparams.model_type,
            n=self.hparams.n, layers=[2, -2])

        # 50% train / 25% val / 25% test (the data is already randomized)
        n = len(y)
        val_split = int(n * 0.5)
        test_split = int(n * 0.75)
        train_rows = slice(None, val_split)
        val_rows = slice(val_split, test_split)
        test_rows = slice(test_split, None)

        neg_hs_train, pos_hs_train, y_train = neg_hs[train_rows], pos_hs[train_rows], y[train_rows]
        neg_hs_val, pos_hs_val, y_val = neg_hs[val_rows], pos_hs[val_rows], y[val_rows]
        neg_hs_test, pos_hs_test, y_test = neg_hs[test_rows], pos_hs[test_rows], y[test_rows]

        # for simplicity we can just take the difference between positive and negative hidden states
        # (concatenating also works fine)
        self.x_train = neg_hs_train - pos_hs_train
        self.x_val = neg_hs_val - pos_hs_val
        self.x_test = neg_hs_test - pos_hs_test

        # normalize the difference features; the scaler is fitted on train only.
        # These x_* arrays are kept as attributes for inspection; the
        # dataloaders below serve the unscaled neg/pos pairs.
        self.scaler = RobustScaler()
        self.scaler.fit(self.x_train)
        self.x_train = self.scaler.transform(self.x_train)
        self.x_val = self.scaler.transform(self.x_val)
        self.x_test = self.scaler.transform(self.x_test)

        self.ds_train = self._to_dataset(neg_hs_train, pos_hs_train, y_train)
        self.ds_val = self._to_dataset(neg_hs_val, pos_hs_val, y_val)
        self.ds_test = self._to_dataset(neg_hs_test, pos_hs_test, y_test)

    def train_dataloader(self):
        return DataLoader(self.ds_train,
                          batch_size=self.hparams.batch_size,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=self.hparams.batch_size)

    def test_dataloader(self):
        return DataLoader(self.ds_test, batch_size=self.hparams.batch_size)


# test
# Smoke test: build the data module, run setup by hand and pull one validation
# batch. setup() does not look at its `stage` argument, so the value passed
# here has no effect.
# `b` holds the (neg_hs, pos_hs, y) tensors of that batch; it is used again
# further down to size the probe.
dm = IMBDHSDataModule(model, tokenizer)
dm.setup('train')
dl = dm.val_dataloader()
b = next(iter(dl))
b

In [None]:
# Shape of the scaled (neg - pos) difference features of the test split
dm.x_test.shape

## LightningModel

In [None]:
from torch import optim

In [None]:


def get_loss(p0, p1):
    """
    CCS loss for a contrast pair of probabilities, each of shape (n,1) or (n,).

    Sum of two terms, each averaged over the first dimension:
      * informative: min(p0, p1)^2
      * consistent:  (p0 - (1 - p1))^2
    """
    smaller = torch.min(p0, p1)
    informative_loss = smaller.pow(2).mean(dim=0)

    disagreement = p0 - (1 - p1)
    consistent_loss = disagreement.pow(2).mean(dim=0)

    return informative_loss + consistent_loss


def get_acc(p0, p1, y):
    """
    Accuracy of the probe's predictions against labels `y`.

    p0 and p1 are the probe outputs for the two phrasings, each of shape (n,1)
    or (n,) (the shapes get_loss accepts); y has n labels. The averaged
    confidence 0.5 * (p0 + (1 - p1)) is thresholded: below 0.5 -> class 1.

    Returns (predictions, acc): an int numpy array of shape (n,) and
    max(acc, 1 - acc). The max is taken because this computes the better of the
    two possible sign conventions; note that it is applied to whatever batch is
    passed in.
    """
    avg_confidence = 0.5*(p0 + (1-p1))
    # flatten so that both (n, 1) and (n,) inputs work
    confidence = avg_confidence.detach().cpu().numpy().reshape(-1)
    predictions = (confidence < 0.5).astype(int)
    labels = y.detach().cpu().numpy().reshape(-1)

    acc = (predictions == labels).mean()
    acc = max(acc, 1 - acc)
    return predictions, acc

def get_f1(p0, p1, y):
    """
    Despite the name, this returns the ROC AUC of the thresholded predictions,
    not an F1 score (TODO f1).

    p0 and p1 are the probe outputs for the two phrasings, each of shape (n,1)
    or (n,); y has n labels. Predictions are formed as in get_acc (averaged
    confidence below 0.5 -> class 1).

    Returns (predictions, auc) with auc = max(auc, 1 - auc). When `y` holds a
    single class the AUC is undefined and NaN is returned instead of letting
    roc_auc_score raise in the middle of a training run.
    """
    avg_confidence = 0.5*(p0 + (1-p1))
    # flatten so that both (n, 1) and (n,) inputs work
    predictions = (avg_confidence.detach().cpu().numpy().reshape(-1) < 0.5).astype(int)
    labels = y.detach().cpu().numpy().reshape(-1)

    if len(set(labels.tolist())) < 2:
        return predictions, float("nan")

    auc = roc_auc_score(labels, predictions)
    auc = max(auc, 1 - auc)
    return predictions, auc

class CSS(pl.LightningModule):
    """
    Lightning module that trains an MLPProbe with the unsupervised CCS loss.

    Batches are (x0, x1, y) triples: the features of the two phrasings and the
    ground-truth label. The label is only used for the logged metrics, never
    for the loss.

    Args:
        d: input feature dimension of the probe
        max_epochs: length of the cosine learning-rate schedule (T_max)
        lr: initial learning rate; the schedule decays to lr / 50
        weight_decay: AdamW weight decay
    """

    def __init__(self, d, max_epochs, lr=4e-3, weight_decay=1e-6):
        super().__init__()
        self.probe = MLPProbe(d)
        self.save_hyperparameters()

    def forward(self, x):
        return self.probe(x)

    def _step(self, batch, batch_idx, stage='train'):
        """Shared train/val/test step: logs `<stage>/loss|acc|f1`, returns the loss."""
        x0, x1, y = batch
        p0, p1 = self(x0), self(x1)

        loss = get_loss(p0, p1)

        self.log(f"{stage}/loss", loss)

        predictions, acc = get_acc(p0, p1, y)
        self.log(f"{stage}/acc", acc)
        # logged as "f1", but get_f1 computes its value with roc_auc_score
        predictions, f1 = get_f1(p0, p1, y)
        self.log(f"{stage}/f1", f1)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)

    def validation_step(self, batch, batch_idx=0):
        return self._step(batch, batch_idx, stage='val')

    def test_step(self, batch, batch_idx=0):
        # lets Trainer.test use the data module's test_dataloader
        return self._step(batch, batch_idx, stage='test')

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return hard 0/1 predictions for a batch (the hook Trainer.predict calls)."""
        x0, x1, y = batch
        p0, p1 = self(x0), self(x1)
        predictions, acc = get_acc(p0, p1, y)
        return predictions

    def prediction_step(self, batch, batch_idx):
        # Original method name, kept for existing callers. Lightning's hook is
        # `predict_step`, so this name was never reached through Trainer.predict.
        return self.predict_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.hparams.max_epochs, eta_min=self.hparams.lr / 50
        )
        return [optimizer], [lr_scheduler]
    

In [None]:
# Build the CCS probe module; `d` is the feature width of the batch `b` fetched from the data module above.
max_epochs = 1000
d = b[0].shape[-1]
net = CSS(d=d, max_epochs=max_epochs)

In [None]:
# train_loader = utils.data.DataLoader(dataset)

In [None]:
# Train the probe on the data module. `limit_train_batches=100` caps how many
# training batches are used per epoch.
trainer = pl.Trainer(limit_train_batches=100, max_epochs=max_epochs)
trainer.fit(model=net, datamodule=dm)

In [None]:
# %debug  # post-mortem debugger: run by hand after an exception; left enabled it blocks "Run All"

# Read hist

In [None]:
# import pytorch_lightning as pl
from lightning.pytorch.loggers.csv_logs import CSVLogger
# from pytorch_lightning.loggers.csv_logs import CSVLogger as CSVLogger2
from pathlib import Path
import pandas as pd

def read_metrics_csv(metrics_file_path):
    """
    Load a metrics CSV and average it per epoch.

    Missing values in the "epoch" column are forward-filled first, then every
    other column is averaged over the rows of each epoch. The result is indexed
    by "epoch".
    """
    history = pd.read_csv(metrics_file_path)
    history["epoch"] = history["epoch"].ffill()
    per_epoch = history.groupby("epoch").mean()
    return per_epoch


def read_hist(trainer: pl.Trainer):
    """
    Return the per-epoch metric history written by the trainer's CSVLogger.

    Uses the first CSVLogger among `trainer.loggers` and parses its metrics
    file with read_metrics_csv.

    Raises IndexError if the trainer has no CSVLogger; errors from reading or
    parsing the file propagate unchanged.
    """
    csv_loggers = [t for t in trainer.loggers if isinstance(t, CSVLogger)]
    if not csv_loggers:
        raise IndexError("trainer has no CSVLogger, so there is no metrics file to read")

    metrics_file_path = Path(csv_loggers[0].experiment.metrics_file_path)
    return read_metrics_csv(metrics_file_path)

In [None]:
# Per-epoch metric history; missing values are filled forward, then backward.
df_hist = read_hist(trainer).ffill().bfill()
df_hist

In [None]:
# Validation vs. training curves, one figure per metric.
df_hist[['val/acc', 'train/acc']].plot()

# NOTE: the 'f1' columns hold the value returned by get_f1, which is computed with roc_auc_score.
df_hist[['val/f1', 'train/f1']].plot()

df_hist[['val/loss', 'train/loss']].plot()