In [1]:
import os

# The dataset paths below are relative ("./data/..."), so the notebook must run
# from the directory that contains `data/` -- presumably the parent of the
# notebook's folder. Only move up when `data/` is not already visible, so that
# re-running this cell does not keep climbing the directory tree.
if not os.path.isdir("data"):
    os.chdir("..")

In [2]:
import json
import pandas as pd
from tqdm import tqdm 

import torch
from torch import nn, optim
from torch.nn import functional as F

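# Temperature scaling baseline

For every class, fit one temperature on a few labelled positive and negative training examples (from 5 up to 160 per side). Then rescale the BART logits of the test set and save them back to `testing.pkl` as new `Temp_Scaling_BART_*` columns.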

In [3]:
# One entry per dataset / label space:
#   PATH        -- directory holding training.pkl, testing.pkl and the claims JSON
#   NAME        -- short identifier used to pick dataset-specific handling
#   COLUMN_NAME -- DataFrame column whose cells are dicts of scores keyed by claim
config_list = [
    {'PATH': path, 'NAME': name, 'COLUMN_NAME': column_name}
    for path, name, column_name in [
        ("./data/climate_change/", "CCC", "FSL_BART"),
        ("./data/topic_stance/", "TS_topic", "FSL_BART_topic"),
        ("./data/topic_stance/", "TS_stance", "FSL_BART_stance"),
        ("./data/depression/", "D", "FSL_BART"),
    ]
]

In [4]:
# Show each available configuration next to its position in `config_list`.
for position, entry in enumerate(config_list):
    print('\t*', position, ':\t', entry['NAME'])

	* 0 :	 CCC
	* 1 :	 TS_topic
	* 2 :	 TS_stance
	* 3 :	 D


In [5]:
# Select the dataset to calibrate by NAME rather than by list position, so that
# reordering `config_list` cannot silently switch datasets.
CONFIG_NAME = "D"

_config_names = [c['NAME'] for c in config_list]
if CONFIG_NAME not in _config_names:
    raise ValueError(f"Unknown config {CONFIG_NAME!r}; expected one of {_config_names}")
config_index = _config_names.index(CONFIG_NAME)
config = config_list[config_index]
print(config['NAME'])

D


## Load data

In [6]:
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
train_df, test_df = (
    pd.read_pickle(os.path.join(config['PATH'], pickle_name))
    for pickle_name in ('training.pkl', 'testing.pkl')
)

In [7]:
# The topic/stance dataset has a separate claims file per label space; every
# other dataset uses claims.json.
claims_file_name = {
    "TS_topic": "claims_topic.json",
    "TS_stance": "claims_stance.json",
}.get(config['NAME'], 'claims.json')

# JSON text is UTF-8 by specification; be explicit so loading does not depend
# on the platform's default locale encoding.
with open(os.path.join(config['PATH'], claims_file_name), encoding="utf-8") as file:
    claims = json.load(file)

# `class_descr` is metadata rather than a claim entry: split it off so that
# iterating over `claims` yields only claim keys.
class_descr = claims.pop("class_descr")

In [8]:
def temperature_scale(temperature, logits):
    """
    Rescale raw logits by a temperature.

    Args:
        temperature: divisor, e.g. the 1-element parameter fitted by set_temperature.
        logits: tensor (or number) of uncalibrated logits.

    Returns:
        logits / temperature, with standard broadcasting.
    """
    scaled_logits = logits / temperature
    return scaled_logits


def set_temperature(logits, labels, init_temperature=1.5, lr=0.01, max_iter=50, verbose=False):
    """
    Fit a single temperature T minimising the binary NLL of `logits / T`.

    Args:
        logits: sequence of uncalibrated logits, one per example.
        labels: sequence of 0/1 targets, same length as `logits`.
        init_temperature: starting value of T.
        lr: LBFGS learning rate.
        max_iter: maximum LBFGS iterations for the single optimizer step.
        verbose: if True, print the NLL before and after scaling.

    Returns:
        nn.Parameter of shape (1,) holding the fitted temperature; pass it to
        `temperature_scale` to calibrate new logits.
    """
    temperature = nn.Parameter(torch.ones(1) * init_temperature)
    logits = torch.as_tensor(logits, dtype=torch.float32)
    labels = torch.as_tensor(labels, dtype=torch.float32)

    loss_function = nn.BCEWithLogitsLoss()
    optimizer = optim.LBFGS([temperature], lr=lr, max_iter=max_iter)

    def closure():
        # LBFGS may re-evaluate the objective several times per step, so it
        # requires a closure that recomputes the loss and its gradient.
        optimizer.zero_grad()
        loss = loss_function(temperature_scale(temperature, logits), labels)
        loss.backward()
        return loss

    optimizer.step(closure)

    if verbose:
        with torch.no_grad():
            nll_before = loss_function(logits, labels).item()
            nll_after = loss_function(temperature_scale(temperature, logits), labels).item()
        print(f"NLL before: {nll_before:.4f}, after: {nll_after:.4f}, T = {temperature.item():.4f}")

    return temperature

def get_calibrated_predictions_TS(zsl_dict, models, claims):
    """
    Apply the per-claim fitted temperatures to one row's raw scores.

    Args:
        zsl_dict: dict mapping each value of `claims` to a raw logit.
        models: dict mapping claim key -> fitted temperature
            (e.g. the parameter returned by set_temperature).
        claims: dict mapping claim key -> the key used to index `zsl_dict`.

    Returns:
        dict mapping each value of `claims` to its temperature-scaled logit
        (Python float).
    """
    # Inference only: the fitted temperatures may require grad, and there is
    # no reason to build an autograd graph for every row.
    with torch.no_grad():
        return {
            # FloatTensor keeps the float32 arithmetic of the fitting step.
            claim: temperature_scale(models[key], torch.FloatTensor([zsl_dict[claim]]))[0].item()
            for key, claim in claims.items()
        }

In [9]:
# Few-shot budgets: positives (and, separately, negatives) sampled per class,
# doubling from 5 up to 160.
sample_sizes = [5 * 2 ** k for k in range(6)]

In [10]:
# Seed for the few-shot sampling so that re-running the notebook reproduces
# the same calibration.
SEED = 42

# Prediction-column suffix for the topic/stance split, so that both label
# spaces can store their results in the same testing.pkl.
_COLUMN_SUFFIX = {"TS_topic": "_topic", "TS_stance": "_stance"}


def _class_index(dataset_name, claim_key):
    """Return the `<class_idx>_annot` column prefix for `claim_key` in `dataset_name`."""
    if dataset_name == "CCC":
        return claim_key[:3]
    if dataset_name == "TS_topic":
        return claim_key[:1]
    if dataset_name == "TS_stance":
        return claim_key[:2]
    if dataset_name == "D":
        return claim_key.split("_")[0]
    raise ValueError(f"Unknown dataset name: {dataset_name!r}")


for samp_size in sample_sizes:

    print(samp_size)

    models = dict()
    volumes = dict()

    total_dfs = list()

    for t in tqdm(claims):
        if t == "class_descr":
            continue

        class_idx = _class_index(config["NAME"], t)
        annot_col = class_idx + "_annot"

        # Sample up to `samp_size` positive and negative training rows for this class
        sub_pos = train_df[train_df[annot_col] == 1]
        samp_pos = sub_pos.sample(min(samp_size, len(sub_pos)), random_state=SEED)
        sub_neg = train_df[train_df[annot_col] == 0]
        samp_neg = sub_neg.sample(min(samp_size, len(sub_neg)), random_state=SEED)
        volumes[t] = (len(samp_pos), len(samp_neg))
        total_dfs += [samp_pos, samp_neg]

        # Fit one temperature per class on the sampled rows' scores for this claim
        score_dicts = samp_pos[config["COLUMN_NAME"]].to_list() + samp_neg[config["COLUMN_NAME"]].to_list()
        X = [d[claims[t]] for d in score_dicts]
        y = samp_pos[annot_col].to_list() + samp_neg[annot_col].to_list()
        calibrator = set_temperature(X, y)
        models[t] = calibrator

    # Get predictions
    prediction_column = "Temp_Scaling_BART_" + str(samp_size) + _COLUMN_SUFFIX.get(config["NAME"], "")
    test_df[prediction_column] = test_df[config["COLUMN_NAME"]].apply(
        lambda x: get_calibrated_predictions_TS(x, models, claims)
    )

5


  from .autonotebook import tqdm as notebook_tqdm
100%|███████████████████████████████████████████████████████████████████████████████| 64/64 [00:02<00:00, 31.25it/s]


10


100%|███████████████████████████████████████████████████████████████████████████████| 64/64 [00:01<00:00, 45.35it/s]


20


100%|███████████████████████████████████████████████████████████████████████████████| 64/64 [00:01<00:00, 46.94it/s]


40


100%|███████████████████████████████████████████████████████████████████████████████| 64/64 [00:01<00:00, 38.99it/s]


80


100%|███████████████████████████████████████████████████████████████████████████████| 64/64 [00:01<00:00, 36.30it/s]


160


100%|███████████████████████████████████████████████████████████████████████████████| 64/64 [00:02<00:00, 30.57it/s]


In [11]:
# Write the augmented test set (now with Temp_Scaling_BART_* columns) back to
# testing.pkl. os.path.join, as used when loading, stays correct even if PATH
# has no trailing slash.
test_df.to_pickle(os.path.join(config["PATH"], "testing.pkl"))