In [2]:
import torch
from datasets import load_dataset
from nltk.tokenize import word_tokenize
from sklearn.isotonic import IsotonicRegression
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq
)

import tools
from const import (
    LEN,
    QUESTION,
    DATASET_NAME,
    FORMAT,
    ZERO_TOKEN,
    ONE_TOKEN,
    MAX_INPUT_LENGTH,
    SEED
)

In [2]:
# Evaluation configuration.
LANGUAGES = ["ru"]  # values of the dataset's 'lang' column that are kept
PER_DEVICE_EVAL_BATCH_SIZE = 80

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [3]:
# Seed through the project helper before any data or model work.
# (Which RNGs it seeds is defined in tools.set_seed, not visible in this notebook.)
tools.set_seed(SEED)

In [4]:
seahorse = load_dataset(DATASET_NAME)

# Language passed to nltk's word_tokenize and the minimum number of purely
# alphabetic tokens a summary must contain to be kept.
# NOTE(review): TOKENIZER_LANGUAGE is fixed to Russian; keep it in sync with
# LANGUAGES if that list is ever changed.
TOKENIZER_LANGUAGE = 'russian'
MIN_ALPHA_TOKENS = 20


def filter_data(example):
    """Decide whether a single dataset row is kept for evaluation.

    A row is kept only when all of the following hold:
      * its QUESTION value is not 0.5,
      * its 'lang' value is listed in LANGUAGES,
      * its summary has more than LEN characters,
      * its summary contains at least MIN_ALPHA_TOKENS alphabetic tokens.

    This is a pure predicate. `Dataset.filter` only uses the returned boolean
    to keep or drop the row and does not store values modified on `example`,
    so no value conversion is attempted here.

    Args:
        example: mapping with at least the QUESTION, 'lang' and 'summary' keys.

    Returns:
        bool: True to keep the row, False to drop it.
    """
    if example[QUESTION] == 0.5:
        return False
    if example['lang'] not in LANGUAGES:
        return False

    summary = example['summary']
    if len(summary) <= LEN:
        return False

    # Tokenize last: it is the most expensive check and is skipped for rows
    # already rejected by the cheap tests above.
    alpha_token_count = sum(
        1
        for token in word_tokenize(summary, language=TOKENIZER_LANGUAGE)
        if token.isalpha()
    )
    return alpha_token_count >= MIN_ALPHA_TOKENS


seahorse_filtered = seahorse.filter(filter_data)
seahorse_filtered

DatasetDict({
    train: Dataset({
        features: ['gem_id', 'lang', 'text', 'summary', 'model', 'comprehensible', 'repetition', 'grammar', 'attribution', 'main_ideas', 'conciseness'],
        num_rows: 4111
    })
    validation: Dataset({
        features: ['gem_id', 'lang', 'text', 'summary', 'model', 'comprehensible', 'repetition', 'grammar', 'attribution', 'main_ideas', 'conciseness'],
        num_rows: 597
    })
    test: Dataset({
        features: ['gem_id', 'lang', 'text', 'summary', 'model', 'comprehensible', 'repetition', 'grammar', 'attribution', 'main_ideas', 'conciseness'],
        num_rows: 1233
    })
})

In [5]:
# Checkpoint identifier handed to the tokenizer and model `from_pretrained` calls.
# (Whether this is a local directory or a hub id cannot be told from here; TODO confirm.)
model_name = "seahorse_metric"

In [6]:
# Load the tokenizer and the seq2seq checkpoint, move the model to DEVICE and
# switch it to evaluation mode.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = model.to(DEVICE).eval()

# First token id the tokenizer emits for each of the two label strings.
zero_token_id, one_token_id = (
    tokenizer(label_text).input_ids[0] for label_text in (ZERO_TOKEN, ONE_TOKEN)
)

In [7]:
def filter_long_examples(example):
    """Keep a row only if its formatted prompt fits into MAX_INPUT_LENGTH tokens.

    The prompt is built from the row's 'text' and 'summary' via FORMAT and
    tokenized without truncation so that its real length is measured.
    """
    prompt = FORMAT.format(example['text'], example['summary'])
    token_ids = tokenizer(prompt, truncation=False)['input_ids']
    return MAX_INPUT_LENGTH >= len(token_ids)


# Apply the same length filter to both evaluation splits (validation first, then test).
validation_data_filtered_by_len, test_data_filtered_by_len = (
    seahorse_filtered[split_name].filter(filter_long_examples, num_proc=4)
    for split_name in ('validation', 'test')
)

In [8]:
def preprocess_function(examples):
    """Tokenize a batch of rows and attach one integer label per row.

    Args:
        examples: batched mapping of column name -> list of values; must
            contain 'text', 'summary' and the QUESTION column.

    Returns:
        The tokenizer output for the formatted prompts (unpadded, truncated to
        MAX_INPUT_LENGTH) with an added "labels" entry holding a one-element
        list of int per row.
    """
    inputs = [
        FORMAT.format(article, summary)
        for article, summary in zip(examples['text'], examples['summary'])
    ]
    model_inputs = tokenizer(
        inputs,
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
        padding=False,  # Defer padding to DataCollator
    )
    # Cast explicitly. `Dataset.filter` keeps or drops rows but never stores
    # values modified inside its predicate, so the raw column values (possibly
    # floats such as 0.0 / 1.0) arrive here unchanged.
    model_inputs["labels"] = [[int(label)] for label in examples[QUESTION]]
    return model_inputs


validation_tokenized = validation_data_filtered_by_len.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=validation_data_filtered_by_len.column_names,
)
test_tokenized = test_data_filtered_by_len.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=test_data_filtered_by_len.column_names,
)

In [None]:
# Pads each batch at collation time (tokenization above left the rows unpadded).
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)

# Identically configured, order-preserving loaders for the two evaluation splits.
val_data_loader, test_data_loader = (
    torch.utils.data.DataLoader(
        tokenized_split,
        batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        collate_fn=data_collator,
        shuffle=False,
    )
    for tokenized_split in (validation_tokenized, test_tokenized)
)

# Validation Optimal Threshold

In [10]:
# Run the model over the validation loader and collect what the project helper
# returns: per-example logits for the two label tokens and the binary labels.
# (The exact array layout is defined in tools.get_logits_and_labels; TODO confirm.)
val_logits_target, val_labels_binary = tools.get_logits_and_labels(
    model, val_data_loader, zero_token_id, one_token_id, DEVICE
)
# Choose the decision threshold on validation data only; the recorded output
# reports it as the threshold with the maximum F1.
optimal_threshold = tools.find_best_threshold(val_logits_target, val_labels_binary)

Evaluating and collecting logits:   0%|          | 0/8 [00:00<?, ?it/s]

Finding best threshold:   0%|          | 0/101 [00:00<?, ?it/s]


Optimal threshold found: 0.4400 (max F1 = 0.7265)


# Test Results

In [11]:
# Score the test split with the same model, label token ids and device as validation.
test_logits_target, test_labels_binary = tools.get_logits_and_labels(
    model, test_data_loader, zero_token_id, one_token_id, DEVICE
)
# Metrics use the threshold selected on validation; the test split takes no
# part in choosing it.
final_results = tools.calculate_final_metrics(
    test_logits_target, test_labels_binary, optimal_threshold
)
final_results

Evaluating and collecting logits:   0%|          | 0/16 [00:00<?, ?it/s]

{'pearson_corr': 0.4587684140690601,
 'roc_auc': 0.7805154307213431,
 'accuracy': 0.710804224207961,
 'f1': 0.7156035358651394,
 'mean_confidence_overall': 0.7909254,
 'mean_confidence_correct': 0.82086253,
 'mean_confidence_incorrect': 0.71734405,
 'ece': 0.08814920775726995,
 'mce': 0.1410748458677723}

# Results After Calibration

In [12]:
def positive_class_probability(logits_target):
    """Turn a two-column logit array into one probability per row.

    Computes sigmoid(column 1 - column 0), which equals the softmax
    probability of column 1 over the two columns.

    Args:
        logits_target: 2-D array indexable as [:, 0] and [:, 1].
            Presumably column 0 is the zero-token logit and column 1 the
            one-token logit; verify against tools.get_logits_and_labels.

    Returns:
        numpy array with one probability per row.
    """
    logit_margin = logits_target[:, 1] - logits_target[:, 0]
    return torch.sigmoid(torch.tensor(logit_margin)).numpy()


val_probs = positive_class_probability(val_logits_target)
test_probs = positive_class_probability(test_logits_target)

# Fit the calibrator on validation scores only, then apply it to the test
# scores; test scores outside the fitted range are clipped to its bounds.
isotonic_calibrator = IsotonicRegression(out_of_bounds="clip")
isotonic_calibrator.fit(val_probs, val_labels_binary)
calibrated_probabilities = isotonic_calibrator.transform(test_probs)

In [13]:
# Map the validation scores through the fitted calibrator and re-select the
# decision threshold on the calibrated scale.
# NOTE(review): the calibrator was fitted on these same validation examples, so
# this threshold is chosen on in-sample calibrated scores.
val_calibrated_probabilities = isotonic_calibrator.transform(val_probs)
optimal_calibrated_threshold = tools.find_best_threshold_from_probabilities(val_calibrated_probabilities, val_labels_binary)

Finding best threshold:   0%|          | 0/101 [00:00<?, ?it/s]


Optimal threshold found: 0.3700 (max F1 = 0.7298)


In [14]:
# Test-set metrics using the calibrated probabilities and the threshold chosen
# on calibrated validation scores.
# NOTE(review): in the recorded outputs, pearson_corr and roc_auc are identical
# to the uncalibrated run, which suggests the helper derives them from the raw
# logits rather than the calibrated probabilities; confirm in tools.
calibrated_final_results = tools.calculate_calibrated_metrics(
    test_logits_target,
    test_labels_binary,
    optimal_calibrated_threshold,
    calibrated_probabilities,
)
calibrated_final_results

Bin 3 is empty.
Bin 6 is empty.
Bin 8 is empty.


{'pearson_corr': 0.4587684140690601,
 'roc_auc': 0.7805154307213431,
 'accuracy': 0.7124289195775793,
 'f1': 0.7172072325970649,
 'mean_confidence_overall': 0.7259493,
 'mean_confidence_correct': 0.751194,
 'mean_confidence_incorrect': 0.6634081,
 'ece': 0.04807778492049412,
 'mce': 0.1524728826574377}