## Let's implement CCS from scratch.
This will deliberately be a simple (but less efficient) implementation to make everything as clear as possible.

In [1]:
from tqdm.auto import tqdm
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMaskedLM, AutoModelForCausalLM
from sklearn.linear_model import LogisticRegression

  from .autonotebook import tqdm as notebook_tqdm


## Model

In [2]:
# Here are a few different model options you can play around with:
model_name = "deberta"
model_name = "gpt-j"
# model_name = "t5"
model_name = "llama"
# model_name = "alpaca"
finetuned = None

model_options = dict(
    device_map="auto", 
    load_in_8bit=True,
    torch_dtype=torch.float16,
)


if model_name == "deberta":
    model_type = "encoder"
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xxlarge")
    model = AutoModelForMaskedLM.from_pretrained("microsoft/deberta-v2-xxlarge", **model_options)
elif model_name == "gpt-j":
    model_type = "decoder"
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", **model_options)
elif model_name == "t5":
    model_type = "encoder_decoder"
    tokenizer = AutoTokenizer.from_pretrained("t5-11b")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-11b", **model_options)
    model.parallelize()  # T5 is big enough that we may need to run it on multiple GPUs
elif ("llama" in model_name) or ("alpaca" in model_name):
    model_repo = "Neko-Institute-of-Science/LLaMA-7B-HF"
    lora_repo = "tloen/alpaca-lora-7b"
    model_type = "decoder"
    tokenizer = AutoTokenizer.from_pretrained(model_repo)
    model = AutoModelForCausalLM.from_pretrained(model_repo, **model_options)
    
    if "alpaca" in model_name:
        from peft import PeftModel
        model = PeftModel.from_pretrained(
            model, 
            lora_repo, 
            device_map='auto'#{'': 0}
        )
else:
    raise NotADirectoryError(model_name)
model


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home/wassname/miniforge3/envs/dlk2/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda117.so
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 7.5
CUDA SETUP: Detected CUDA version 117
CUDA SETUP: Loading binary /home/wassname/miniforge3/envs/dlk2/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...


  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
Either way, this might cause trouble in the future:
If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.
  warn(msg)
Loading checkpoint shards: 100%|███████████████████████████████████████████| 2/2 [00:06<00:00,  3.49s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear8bitLt(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSN

In [3]:
# get the tokens for 0 and 1, we will use these later...
id_0, id_1 = tokenizer('0')['input_ids'][-1], tokenizer('1')['input_ids'][-1]

## Dataset

In [4]:
# Let's just try IMDB for simplicity
data = load_dataset("amazon_polarity")["test"]

Found cached dataset amazon_polarity (/home/wassname/.cache/huggingface/datasets/amazon_polarity/amazon_polarity/3.0.0/a27b32b7e7b88eb274a8fa8ba0f654f1fe998a87c22547557317793b5d2772dc)
100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.87it/s]


In [5]:
def format_imdb(text, label):
    return f"""Review: "Whoever wrote the screenplay for this movie obviously never consulted any books about Lucille Ball, especially her autobiography. I've never seen so many mistakes in a biopic, ranging from her early years in Celoron and Jamestown to her later years with Desi. I could write a whole list of factual errors, but it would go on for pages. In all, I believe that Lucille Ball is one of those inimitable people who simply cannot be portrayed by anyone other than themselves. If I were Lucie Arnaz and Desi, Jr., I would be irate at how many mistakes were made in this film. The filmmakers tried hard, but the movie seems awfully sloppy to me."
###
Is this review negative? 1
###
Review: "This version of Anna Christie is in German. Greta Garbo again plays Anna Christie, but all of the other characters have different actors from the English version. Both were filmed back to back because Garbo had such a following in Germany. Garbo herself supposedly favored her Anna Christie in this version over the English version. It's a good tale and a must-see for Garbo fans."
###
Is this review negative? 0
###
Review: "I think this is a lovely family movie. There are plenty of hilarious scenes and heart-warming moments to be had throughout the movie. The actors are great and the effects well executed throughout. Danny Glover plays George Knox who manages the terrible baseball team 'The Angels' and is great throughout the film. Also fantastic are the young actors Joseph Gordon-Levitt and Milton Davis Jr. Christopher Lloyd is good as Al 'The Angel' and the effects are great in this top notch Disney movie. A touching and heart-warming movie which everyone should enjoy."
###
Is this review positive? 1
###
Review: "{text}"
###
Is this review {'positive' if label else 'negative'}? """



print(format_imdb("The movie was the worst.... not!", 0))

Review: "Whoever wrote the screenplay for this movie obviously never consulted any books about Lucille Ball, especially her autobiography. I've never seen so many mistakes in a biopic, ranging from her early years in Celoron and Jamestown to her later years with Desi. I could write a whole list of factual errors, but it would go on for pages. In all, I believe that Lucille Ball is one of those inimitable people who simply cannot be portrayed by anyone other than themselves. If I were Lucie Arnaz and Desi, Jr., I would be irate at how many mistakes were made in this film. The filmmakers tried hard, but the movie seems awfully sloppy to me."
###
Is this review negative? 1
###
Review: "This version of Anna Christie is in German. Greta Garbo again plays Anna Christie, but all of the other characters have different actors from the English version. Both were filmed back to back because Garbo had such a following in Germany. Garbo herself supposedly favored her Anna Christie in this version o

## First let's write code for extracting hidden states given a model and text. 
How we do this exactly will depend on the type of model.

In [6]:
# def get_encoder_hidden_states(model, tokenizer, input_text, layer=-1):
#     """
#     Given an encoder model and some text, gets the encoder hidden states (in a given layer, by default the last) 
#     on that input text (where the full text is given to the encoder).

#     Returns a numpy array of shape (hidden_dim,)
#     """
#     # tokenize
#     encoder_text_ids = tokenizer(input_text, truncation=True, return_tensors="pt").input_ids.to(model.device)

#     # forward pass
#     with torch.no_grad():
#         output = model(encoder_text_ids, output_hidden_states=True)

#     # get the appropriate hidden states
#     hs_tuple = output["hidden_states"]
    
#     hs = hs_tuple[layer][0, -1].detach().cpu().numpy()

#     return hs

# def get_encoder_decoder_hidden_states(model, tokenizer, input_text, layer=-1):
#     """
#     Given an encoder-decoder model and some text, gets the encoder hidden states (in a given layer, by default the last) 
#     on that input text (where the full text is given to the encoder).

#     Returns a numpy array of shape (hidden_dim,)
#     """
#     # tokenize
#     encoder_text_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
#     decoder_text_ids = tokenizer("", return_tensors="pt").input_ids.to(model.device)

#     # forward pass
#     with torch.no_grad():
#         output = model(encoder_text_ids, decoder_input_ids=decoder_text_ids, output_hidden_states=True)

#     # get the appropriate hidden states
#     hs_tuple = output["encoder_hidden_states"]
#     hs = hs_tuple[layer][0, -1].detach().cpu().numpy()

#     return hs

def get_decoder_hidden_states(model, tokenizer, input_text, layers=[2, -2]):
    """
    Given a decoder model and some text, gets the hidden states (in a given layer, by default the last) on that input text

    Returns a numpy array of shape (hidden_dim,)
    """
    # tokenize (adding the EOS token this time)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)

    # forward pass
    
    with torch.no_grad():
        # FIXME: should be a batch, to speed it up
        output = model(input_ids, 
                       output_hidden_states=True
#                        , output_attentions=True
                      )

    # the output is large, so we will just select what we want 1) the first token with[:, 0]
    # 2) selected layers with [layers]
#     output['attentions'] = [output['attentions'][i] for i in layers]
#     output['attentions'] = [v.detach().cpu()[:, -1] for v in output['attentions']]
#     output['attentions'] = torch.concat(output['attentions'])
    
    output['hidden_states'] = [output['hidden_states'][i] for i in layers]
    # dims [Batch, Token, Probs?]
    output['hidden_states'] = [v.detach().cpu()[:, -1] for v in output['hidden_states']]
    output['hidden_states'] = torch.concat(output['hidden_states'])
    
    o = output['logits'].detach().cpu().float().softmax(-1)
#     print(input_text)
#     print(tokenizer.decode(o.argmax(-1)[0]))
        
    prob_0, prob1 = o[0, 1][[id_0, id_1]]
    output['ans'] = (prob1/(prob_0+prob1)).item()

    return dict(hidden_states=output['hidden_states'], ans=output['ans']
#                 , attentions=output['attentions']
               )
#     hs_tuple = output["hidden_states"]
#     hs_tuple = (a.detach().cpu().numpy() for a in hs_tuple)
#     return output
#     hs = hs_tuple[layer][0, -1].detach().cpu().numpy()
    
    # FIXME pass full output, as I want to check model zero shot accuracy!

#     return hs

def get_hidden_states(model, tokenizer, input_text, layers=[2, -2], model_type="encoder"):
    fn = {
# "encoder": get_encoder_hidden_states, "encoder_decoder": get_encoder_decoder_hidden_states,
          "decoder": get_decoder_hidden_states}[model_type]

    return fn(model, tokenizer, input_text, layers=layers)

In [7]:
# print(format_imdb(text, 0))

In [8]:
# unit test
idx = 0
text, true_label = data[idx]["content"], data[idx]["label"]
neg_hs = get_hidden_states(model, tokenizer, format_imdb(text, 0), model_type=model_type)
neg_hs

{'hidden_states': tensor([[-0.0331,  0.0093, -0.0510,  ..., -0.0213,  0.0366,  0.0545],
         [-0.9800,  3.0742,  2.6406,  ...,  4.4609,  1.9453, -0.5400]],
        dtype=torch.float16),
 'ans': 0.6738358736038208}

In [9]:
# neg_hs[0]['hidden_states'][-1]

In [10]:
print(format_imdb(text, 1))

Review: "Whoever wrote the screenplay for this movie obviously never consulted any books about Lucille Ball, especially her autobiography. I've never seen so many mistakes in a biopic, ranging from her early years in Celoron and Jamestown to her later years with Desi. I could write a whole list of factual errors, but it would go on for pages. In all, I believe that Lucille Ball is one of those inimitable people who simply cannot be portrayed by anyone other than themselves. If I were Lucie Arnaz and Desi, Jr., I would be irate at how many mistakes were made in this film. The filmmakers tried hard, but the movie seems awfully sloppy to me."
###
Is this review negative? 1
###
Review: "This version of Anna Christie is in German. Greta Garbo again plays Anna Christie, but all of the other characters have different actors from the English version. Both were filmed back to back because Garbo had such a following in Germany. Garbo herself supposedly favored her Anna Christie in this version o

In [11]:
# # sceatch
# layer = -10
# input_ids = tokenizer(text + tokenizer.eos_token, return_tensors="pt").input_ids.to(model.device)
# # forward pass
# with torch.no_grad():
#     output = model(input_ids, output_hidden_states=True)

# # get the last layer, last token hidden states
# hs_tuple = output["hidden_states"]
# hs = hs_tuple[layer][0, -1].detach().cpu().numpy()
# hs, output['logits'], output['hidden_states']

## Now let's write code for formatting data and for getting all the hidden states.

In [12]:


def get_hidden_states_many_examples(model, tokenizer, data, model_type, n=100, layers=[2, -2]):
    """
    Given an encoder-decoder model, a list of data, computes the contrast hidden states on n random examples.
    Returns numpy arrays of shape (n, hidden_dim) for each candidate label, along with a boolean numpy array of shape (n,)
    with the ground truth labels
    
    This is deliberately simple so that it's easy to understand, rather than being optimized for efficiency
    """
    # setup
    model.eval()
    all_neg_hs, all_pos_hs, all_gt_labels = [], [], []
    all_neg_ans, all_pos_ans = [], []

    # loop
    for _ in tqdm(range(n), unit='examples', desc='get_hidden_states'):
        # for simplicity, sample a random example until we find one that's a reasonable length
        # (most examples should be a reasonable length, so this is just to make sure)
        while True:
            idx = np.random.randint(len(data))
            text, true_label = data[idx]["content"], data[idx]["label"]
            # the actual formatted input will be longer, so include a bit of a margin
            if len(tokenizer(text)) < 400:  
                break
                
        # get hidden states
#         print(format_imdb(text, 0))
        neg = get_hidden_states(model, tokenizer, format_imdb(text, 0), model_type=model_type, layers=layers)
        pos = get_hidden_states(model, tokenizer, format_imdb(text, 1), model_type=model_type, layers=layers)

        # collect
        all_neg_hs.append(neg['hidden_states'].flatten())
        all_pos_hs.append(pos['hidden_states'].flatten())
        all_pos_ans.append(pos['ans'])
        all_neg_ans.append(neg['ans'])
        all_gt_labels.append(true_label)

    all_neg_hs = np.stack(all_neg_hs)
    all_pos_hs = np.stack(all_pos_hs)
    all_gt_labels = np.stack(all_gt_labels)

    return all_neg_hs, all_pos_hs, all_gt_labels, np.array(all_neg_ans), np.array(all_pos_ans)

# Lets verify that the models answers are good

In [13]:
neg_hs, pos_hs, y, all_neg_ans, all_pos_ans = get_hidden_states_many_examples(model, tokenizer, data, model_type)

import gc
gc.collect()
torch.cuda.empty_cache()
gc.collect()

get_hidden_states: 100%|█████████████████████████████████████████| 100/100 [00:56<00:00,  1.77examples/s]


0

In [14]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score

pos_score = roc_auc_score(y, all_pos_ans)
neg_score = roc_auc_score(y, all_neg_ans)
pos_score, neg_score

(0.5366586538461539, 0.5759214743589743)

In [15]:

pos_score = accuracy_score(y, (all_pos_ans>0.)*1.0)
neg_score = accuracy_score(y, (all_neg_ans<0.5)*1.0)
pos_score, neg_score

(0.48, 0.52)

## Let's verify that the model's representations are good

Before trying CCS, let's make sure there exists a direction that classifies examples as true vs false with high accuracy; if supervised logistic regression accuracy is bad, there's no hope of unsupervised CCS doing well.

Note that because logistic regression is supervised we expect it to do better but to have worse generalisation that equivilent unsupervised methods. However in this case CSS is using a deeper model so it is more complicated.

In [16]:
# let's create a simple 50/50 train split (the data is already randomized)
n = len(y)

neg_hs2 = torch.from_numpy(np.stack([h.flatten() for h in neg_hs], 0))
pos_hs2 = torch.from_numpy(np.stack([h.flatten() for h in pos_hs], 0))

neg_hs_train, neg_hs_test = neg_hs2[:n//2], neg_hs2[n//2:]
pos_hs_train, pos_hs_test = pos_hs2[:n//2], pos_hs2[n//2:]
y_train, y_test = y[:n//2], y[n//2:]

# for simplicity we can just take the difference between positive and negative hidden states
# (concatenating also works fine)
x_train = neg_hs_train - pos_hs_train
x_test = neg_hs_test - pos_hs_test

lr = LogisticRegression(class_weight="balanced")
lr.fit(x_train, y_train)
print("Logistic regression accuracy: {} [TRAIN]".format(lr.score(x_train, y_train)))
print("Logistic regression accuracy: {} [TEST]".format(lr.score(x_test, y_test)))

Logistic regression accuracy: 1.0 [TRAIN]
Logistic regression accuracy: 0.96 [TEST]


## Now let's try CCS

In [17]:
class MLPProbe(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
#             nn.Linear(100, 100),
#             nn.ReLU(),
            nn.Linear(100, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)


## Train

In [18]:
# # Train CCS without any labels
# ccs = CCS(neg_hs_train, pos_hs_train, linear=False)
# ccs.repeated_train()

# # Evaluate
# ccs_acc = ccs.get_acc(neg_hs_train, pos_hs_train, y_train)
# print("CCS nonlinear train accuracy: {}".format(ccs_acc))

# ccs_acc = ccs.get_acc(neg_hs_test, pos_hs_test, y_test)
# print("CCS nonlinear test accuracy: {}".format(ccs_acc))

In [19]:
# # Train CCS without any labels
# ccs = CCS(neg_hs_train, pos_hs_train, linear=True)
# ccs.repeated_train()

# # Evaluate
# ccs_acc = ccs.get_acc(neg_hs_train, pos_hs_train, y_train)
# print("CCS train accuracy: {}".format(ccs_acc))

# ccs_acc = ccs.get_acc(neg_hs_test, pos_hs_test, y_test)
# print("CCS test accuracy: {}".format(ccs_acc))

# lightning

In [20]:
import lightning.pytorch as pl

## DataModule

In [21]:
from dataclasses import dataclass
from torch.utils.data import random_split, DataLoader, TensorDataset
from transformers.models.auto.modeling_auto import AutoModel
# from scipy.stats import zscore

from sklearn.preprocessing import RobustScaler

# def normalize(x):
#     """
#     Mean-normalizes the data x (of shape (n, d))
#     If self.var_normalize, also divides by the standard deviation
#     """
#     normalized_x = x - x.mean(axis=0, keepdims=True)
#     if self.var_normalize:
#         normalized_x /= normalized_x.std(axis=0, keepdims=True)

#     return normalized_x


class IMBDHSDataModule(pl.LightningDataModule):

    def __init__(self,
                 model: AutoModel,
                 tokenizer: AutoTokenizer,
                 model_type="decoder",
                 dataset_name="amazon_polarity",
                 batch_size=32,
                 n=200,
                ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.save_hyperparameters(ignore=["model", "tokenizer"])

    def setup(self, stage: str):

        self.dataset = load_dataset(self.hparams.dataset_name, split="test")

        neg_hs, pos_hs, y, all_neg_ans, all_pos_ans = get_hidden_states_many_examples(
            self.model, self.tokenizer, self.dataset, self.hparams.model_type, n=self.hparams.n, layers=[2, -2])

        # let's create a simple 50/50 train split (the data is already randomized)
        n = len(y)
        val_split = int(n * 0.5)
        test_split = int(n * 0.75)
        neg_hs_train, pos_hs_train, y_train = neg_hs[:
                                                     val_split], pos_hs[:
                                                                        val_split], y[:
                                                                                      val_split]
        neg_hs_val, pos_hs_val, y_val = neg_hs[val_split:test_split], pos_hs[
            val_split:test_split], y[val_split:test_split]
        neg_hs_test, pos_hs_test, y_test = neg_hs[test_split:], pos_hs[
            test_split:], y[test_split:]

        # for simplicity we can just take the difference between positive and negative hidden states
        # (concatenating also works fine)
        self.x_train = neg_hs_train - pos_hs_train
        self.x_val = neg_hs_val - pos_hs_val
        self.x_test = neg_hs_test - pos_hs_test

        # normalize
        self.scaler = RobustScaler()
        self.scaler.fit(self.x_train)
        self.x_train = self.scaler.transform(self.x_train)
        self.x_val = self.scaler.transform(self.x_val)
        self.x_test = self.scaler.transform(self.x_test)

        self.ds_train = TensorDataset(torch.from_numpy(neg_hs_train).float(),
                                      torch.from_numpy(pos_hs_train).float(),
                                      torch.from_numpy(y_train).float())

        self.ds_val = TensorDataset(torch.from_numpy(neg_hs_val).float(),
                                    torch.from_numpy(pos_hs_val).float(),
                                    torch.from_numpy(y_val).float())

        self.ds_test = TensorDataset(torch.from_numpy(neg_hs_test).float(),
                                     torch.from_numpy(pos_hs_test).float(),
                                     torch.from_numpy(y_test).float())

    def train_dataloader(self):
        return DataLoader(self.ds_train,
                          batch_size=self.hparams.batch_size,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=self.hparams.batch_size)

    def test_dataloader(self):
        return DataLoader(self.ds_test, batch_size=self.hparams.batch_size)


# test
dm = IMBDHSDataModule(model, tokenizer)
dm.setup('train')
dl = dm.val_dataloader()
b = next(iter(dl))
b

Found cached dataset amazon_polarity (/home/wassname/.cache/huggingface/datasets/amazon_polarity/amazon_polarity/3.0.0/a27b32b7e7b88eb274a8fa8ba0f654f1fe998a87c22547557317793b5d2772dc)
get_hidden_states:  90%|███████████████████████████████████▊    | 895/1000 [08:21<00:58,  1.78examples/s]


In [None]:
dm.x_test.shape

## LightningModel

In [None]:
from torch import optim

In [None]:


def get_loss(p0, p1):
    """
    Returns the CCS loss for two probabilities each of shape (n,1) or (n,)
    """
    informative_loss = (torch.min(p0, p1)**2).mean(0)
    consistent_loss = ((p0 - (1-p1))**2).mean(0)
    return informative_loss + consistent_loss


def get_acc(p0, p1, y):
    avg_confidence = 0.5*(p0 + (1-p1))
    predictions = (avg_confidence.detach().cpu().numpy() < 0.5).astype(int)[:, 0]
    
    # TODO f1
    conf = (avg_confidence.detach().cpu().numpy() )[:, 0]
    
    acc = (predictions == y.cpu().numpy()).mean()
    acc = max(acc, 1 - acc)
    return predictions, acc

def get_f1(p0, p1, y):
    avg_confidence = 0.5*(p0 + (1-p1))
    predictions = (avg_confidence.detach().cpu().numpy() < 0.5).astype(int)[:, 0]
    
    # TODO f1
    conf = (avg_confidence.detach().cpu().numpy() )[:, 0]
    auc = roc_auc_score(y.cpu().numpy(), predictions)
    
    auc = max(auc, 1 - auc)
    return predictions, auc

class CSS(pl.LightningModule):
    def __init__(self, d, max_epochs, lr=4e-3, weight_decay=1e-6):
        super().__init__()
        self.probe = MLPProbe(d)
        self.save_hyperparameters()
        
    def forward(self, x):
        return self.probe(x)
        
    def _step(self, batch, batch_idx, stage='train'):
        x0, x1, y = batch
        p0, p1 = self(x0), self(x1)
        
        loss = get_loss(p0, p1)
        
        self.log(f"{stage}/loss", loss)
        
        predictions, acc = get_acc(p0, p1, y)
        self.log(f"{stage}/acc", acc)
        predictions, f1 = get_f1(p0, p1, y)
        self.log(f"{stage}/f1", f1)
        return loss
    
    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)
    
    def validation_step(self, batch, batch_idx=0):
        return self._step(batch, batch_idx, stage='val')
    
    def prediction_step(self, batch, batch_idx):
        x0, x1, y = batch
        p0, p1 = self(x0), self(x1)
        predictions, acc = get_acc(p0, p1, y)
        return predictions 

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.hparams.max_epochs, eta_min=self.hparams.lr / 50
        )
        return [optimizer], [lr_scheduler]
    

In [None]:
# init the autoencoder
max_epochs = 1000
d = b[0].shape[-1]
net = CSS(d=d, max_epochs=max_epochs)

In [None]:
# train_loader = utils.data.DataLoader(dataset)

In [None]:
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = pl.Trainer(limit_train_batches=100, max_epochs=max_epochs)
trainer.fit(model=net, datamodule=dm)

In [None]:
%debug

# Read hist

In [None]:
# import pytorch_lightning as pl
from lightning.pytorch.loggers.csv_logs import CSVLogger
# from pytorch_lightning.loggers.csv_logs import CSVLogger as CSVLogger2
from pathlib import Path
import pandas as pd

def read_metrics_csv(metrics_file_path):
    df_hist = pd.read_csv(metrics_file_path)
    df_hist["epoch"] = df_hist["epoch"].ffill()
    df_histe = df_hist.set_index("epoch").groupby("epoch").mean()
    return df_histe


def read_hist(trainer: pl.Trainer):

    ts = [t for t in trainer.loggers if isinstance(t, CSVLogger)]
    print(ts)
    try:
        metrics_file_path = Path(ts[0].experiment.metrics_file_path)
        df_histe = read_metrics_csv(metrics_file_path)
        return df_histe
    except Exception as e:
        raise e
        print(e)

In [None]:
df_hist = read_hist(trainer).ffill().bfill()
df_hist

In [None]:
df_hist[['val/acc', 'train/acc']].plot()

df_hist[['val/f1', 'train/f1']].plot()

df_hist[['val/loss', 'train/loss']].plot()