In [1]:
import os
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from scripts.dataset_creators.read_internal_states import HiddenStatesDataset
from scripts.eval.run_token_scoring import score_predictions_labels
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
#### load data 
foldtrain_dataer_path = '/scratch/ramprasad.sa/probing_summarization_factuality/internal_states/GPT_annotated/XSUM/mistral7b/document_context_gpt/'
hs_dataset = HiddenStatesDataset()
, test_data, class_weights = hs_dataset.make_data(folder_path, 
                                                            hidden_state_idx = 32)



100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [02:56<00:00,  1.76s/it]


16 63


In [21]:
write_dir = '/scratch/ramprasad.sa/probing_summarization_factuality/probes/linear_probe/GPT_annotated/XSUM/mistral7b'




In [24]:
#### train model and run validation 

import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple linear probe model
class LogisticRegressionProbe(nn.Module):
    def __init__(self, input_size):
        super(LogisticRegressionProbe, self).__init__()
        self.linear = nn.Linear(input_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.linear(x)
        return self.sigmoid(x)

    
def run_model(hstate,
             model):
    hstate, tokens, labels = dat
    nonzero_rows_mask = torch.any(hstate != 0, dim=1)
    hstate_filtered = hstate[nonzero_rows_mask] 
    outputs = model(hstate_filtered.float())
    
    labels = labels[nonzero_rows_mask]
    
    return outputs, labels

def compute_loss(criterion,
                 labels,
                 outputs,
                 class_weights, 
                 ):
    
    label_weights = torch.tensor([class_weights[lab.item()] for lab in labels]
    loss = criterion(outputs.squeeze(), labels.float()) 
    loss = loss * label_weights
    loss = torch.mean(loss)
    
    return loss 



        
hstate, tok, lab = train_data[0]
input_size = hstate.size(1)
output_size = 1  
model = LogisticRegressionProbe(input_size)

# criterion = nn.BCELoss(reduction='none')
criterion = nn.BCELoss(reduction='none')
optimizer = optim.SGD(model.parameters(), lr=0.01)

num_epochs = 100
for epoch in range(num_epochs):
    for dat_idx, dat in enumerate(train_data):
        outputs, labels = run_model(dat, 
                            model)
        
        loss = compute_loss(criterion,
                 labels,
                 outputs,
                 class_weights)
        
        # Backward pass and optimization
        optimizer.zero_grad()  # Zero gradients
        loss.backward()  # Backward pass
        optimizer.step()  #
    
    filename = f'loss_{loss.item():.4f}_epoch{epoch}'
    torch.save(model.state_dict(), f'{write_dir}/{filename}')
    
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
        all_labels = []
        all_predictions = []

        for dat_idx, dat in enumerate(test_data):
                hstate, tokens, labels = dat
                nonzero_rows_mask = torch.any(hstate != 0, dim=1)
                labels = labels[nonzero_rows_mask]

                outputs = run_model(dat, 
                                    model)
                outputs = outputs.detach().numpy().squeeze()

                all_labels += labels.tolist()
                all_predictions += outputs.tolist()
                                 
        score_dict = compute_scores(all_labels, all_predictions)

Epoch [10/100], Loss: 1.1233
Epoch [20/100], Loss: 0.2319
Epoch [30/100], Loss: 0.0001
Epoch [40/100], Loss: 0.0000
Epoch [50/100], Loss: 0.0000
Epoch [60/100], Loss: 0.0000
Epoch [70/100], Loss: 0.0000
Epoch [80/100], Loss: 0.0000
Epoch [90/100], Loss: 0.0000
Epoch [100/100], Loss: 0.0000


In [67]:
#### inference code 
from sklearn.metrics import balanced_accuracy_score, roc_auc_score

def compute_scores(labels, predictions):
    auc_score = roc_auc_score(labels, predictions)
    predictions_binary = [0 if each > 0.5 else 1 for each in predictions]
    bacc_score = balanced_accuracy_score(labels, predictions_binary)
    return {'auc': auc_score, 'bacc': bacc_score}

all_labels = []
all_predictions = []

for dat_idx, dat in enumerate(test_data):
        hstate, tokens, labels = dat
        nonzero_rows_mask = torch.any(hstate != 0, dim=1)
        labels = labels[nonzero_rows_mask]
        
        outputs = run_model(dat, 
                            model)
        outputs = outputs.detach().numpy().squeeze()
        
        all_labels += labels.tolist()
        all_predictions += outputs.tolist()

In [68]:
compute_scores(all_labels, all_predictions)

{'auc': 0.7797624685945231, 'bacc': 0.462448747614988}

In [54]:
outputs.shape

(112,)

In [59]:
all_predictions

[2.0238524079621466e-28,
 8.231002418467236e-27,
 2.6920903909843815e-23,
 5.059092461065665e-19,
 1.0612483550401068e-27,
 2.0961315361832363e-22,
 1.605602455489211e-38,
 1.350132461133355e-35,
 1.4881230427307251e-35,
 1.1942419059758571e-18,
 0.0,
 0.0,
 0.0,
 0.0,
 1.2432102345206659e-23,
 0.0,
 9.29623908703731e-31,
 2.185813378254812e-34,
 2.232541689319177e-26,
 1.4457848887430185e-30,
 4.2308520869507695e-27,
 3.276869765805627e-32,
 0.0,
 1.087926119562972e-32,
 6.183651989624514e-26,
 8.138432040550109e-19,
 3.5426471759878896e-36,
 1.531635173541312e-27,
 1.2330288597571826e-28,
 9.828840339131011e-36,
 1.0737419986116756e-34,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.9088805339987512e-26,
 7.746586984518466e-38,
 1.6280110912967555e-28,
 5.267700046345547e-31,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 3.229566721144074e-30,
 0.0,
 5.795731772609828e-38,
 0.0,
 0.0,
 1.3295454123530243e-20,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 4.142701560958038e-16,
 0.0,
 0.0,
 0.0,
 0.0,
 0