In [1]:
import os
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from scripts.dataset_creators.read_internal_states import HiddenStatesDataset
from scripts.eval.run_token_scoring import score_predictions_labels
from transformers import AutoTokenizer, AutoModelForCausalLM

In [73]:
#### load data 
train_folder_path = '/scratch/ramprasad.sa/probing_summarization_factuality/internal_states/GPT_annotated/XSUM/mistral7b/document_context_gpt/'

train_data, test_data, class_weights = HiddenStatesDataset().make_data(folder_path, 
                                                            hidden_state_idx = 32)


eval_dir = '/scratch/ramprasad.sa/probing_summarization_factuality/internal_states/Genaudit/XSUM/mistral7b/document_context/'
eval_data_1, eval_data_2, _ = HiddenStatesDataset().make_data(eval_dir, 
                                                            hidden_state_idx = 32)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [03:03<00:00,  1.83s/it]


16 63


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [01:21<00:00,  2.71s/it]


6 17


In [76]:
eval_data = eval_data_2 + eval_data_1

In [21]:
write_dir = '/scratch/ramprasad.sa/probing_summarization_factuality/probes/linear_probe/GPT_annotated/XSUM/mistral7b'




In [77]:
#### train model and run validation 

import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple linear probe model
class LogisticRegressionProbe(nn.Module):
    def __init__(self, input_size):
        super(LogisticRegressionProbe, self).__init__()
        self.linear = nn.Linear(input_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.linear(x)
        return self.sigmoid(x)

    
def run_model(hstate,
             model):
    hstate, tokens, labels = dat
    nonzero_rows_mask = torch.any(hstate != 0, dim=1)
    hstate_filtered = hstate[nonzero_rows_mask] 
    outputs = model(hstate_filtered.float())
    
    labels = labels[nonzero_rows_mask]
    
    return outputs, labels

def compute_loss(criterion,
                 labels,
                 outputs,
                 class_weights, 
                 ):
    
    label_weights = torch.tensor([class_weights[lab.item()] for lab in labels])
    loss = criterion(outputs.squeeze(), labels.float()) 
    loss = loss * label_weights
    loss = torch.mean(loss)
    
    return loss 



        
hstate, tok, lab = train_data[0]
input_size = hstate.size(1)
output_size = 1  
model = LogisticRegressionProbe(input_size)

# criterion = nn.BCELoss(reduction='none')
criterion = nn.BCELoss(reduction='none')
optimizer = optim.SGD(model.parameters(), lr=0.01)

num_epochs = 100
for epoch in range(num_epochs):
    for dat_idx, dat in enumerate(train_data):
        outputs, labels = run_model(dat, 
                            model)
        
        loss = compute_loss(criterion,
                 labels,
                 outputs,
                 class_weights)
        
        # Backward pass and optimization
        optimizer.zero_grad()  # Zero gradients
        loss.backward()  # Backward pass
        optimizer.step()  #
    
    filename = f'loss_{loss.item():.4f}_epoch{epoch}'
    torch.save(model.state_dict(), f'{write_dir}/{filename}')
    
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
        all_labels = []
        all_predictions = []

        for dat_idx, dat in enumerate(test_data):
                hstate, tokens, labels = dat
                nonzero_rows_mask = torch.any(hstate != 0, dim=1)
                labels = labels[nonzero_rows_mask]

                outputs, labels = run_model(dat, 
                                    model)
                outputs = outputs.detach().numpy().squeeze()

                all_labels += labels.tolist()
                all_predictions += outputs.tolist()
                                 
        scores_dict = compute_scores(all_labels, all_predictions)
        print('Validation scores', scores_dict)
        
        
        all_labels = []
        all_predictions = []

        for dat_idx, dat in enumerate(eval_data):
                hstate, tokens, labels = dat
                nonzero_rows_mask = torch.any(hstate != 0, dim=1)
                labels = labels[nonzero_rows_mask]

                outputs, labels = run_model(dat, 
                                    model)
                outputs = outputs.detach().numpy().squeeze()

                all_labels += labels.tolist()
                all_predictions += outputs.tolist()
                                 
        scores_dict = compute_scores(all_labels, all_predictions)
        print('Test scores', scores_dict)
        
        
        

Epoch [10/100], Loss: 0.6077
Validation scores {'auc': 0.8207237806555795, 'bacc': 0.3040839524215483}
Test scores {'auc': 0.7168457330021956, 'bacc': 0.4215797430083144}
Epoch [20/100], Loss: 0.9490
Validation scores {'auc': 0.7715711083145014, 'bacc': 0.40677140421385943}
Test scores {'auc': 0.6590990893711982, 'bacc': 0.5075585789871504}
Epoch [30/100], Loss: 0.0002
Validation scores {'auc': 0.7860322873110597, 'bacc': 0.4705273413713311}
Test scores {'auc': 0.6605748119353562, 'bacc': 0.5034013605442177}
Epoch [40/100], Loss: 0.0001
Validation scores {'auc': 0.7860232659891655, 'bacc': 0.4705273413713311}
Test scores {'auc': 0.658559190872116, 'bacc': 0.5034013605442177}
Epoch [50/100], Loss: 0.0001
Validation scores {'auc': 0.7860322873110597, 'bacc': 0.4705273413713311}
Test scores {'auc': 0.6576143684987223, 'bacc': 0.5034013605442177}
Epoch [60/100], Loss: 0.0000
Validation scores {'auc': 0.7860999472252669, 'bacc': 0.4705273413713311}
Test scores {'auc': 0.6569664902998237, 'b

In [78]:
#### inference code 
from sklearn.metrics import balanced_accuracy_score, roc_auc_score

def compute_scores(labels, predictions):
    auc_score = roc_auc_score(labels, predictions)
    predictions_binary = [0 if each > 0.5 else 1 for each in predictions]
    bacc_score = balanced_accuracy_score(labels, predictions_binary)
    return {'auc': auc_score, 'bacc': bacc_score}

all_labels = []
all_predictions = []

for dat_idx, dat in enumerate(test_data):
        hstate, tokens, labels = dat
        nonzero_rows_mask = torch.any(hstate != 0, dim=1)
        labels = labels[nonzero_rows_mask]
        
        outputs = run_model(dat, 
                            model)
        outputs = outputs.detach().numpy().squeeze()
        
        all_labels += labels.tolist()
        all_predictions += outputs.tolist()

AttributeError: 'tuple' object has no attribute 'detach'

In [68]:
compute_scores(all_labels, all_predictions)

{'auc': 0.7797624685945231, 'bacc': 0.462448747614988}

In [54]:
outputs.shape

(112,)

In [81]:
len(train_data)

79

In [82]:
num_epochs = 100

for epoch in tqdm(range(0, num_epochs)):
    dat_idx = epoch%len(train_data)
    print(dat_idx)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 222864.19it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20



