# Parameters

## Defaults

In [10]:
class Params:
    """Default hyper-parameter values.

    Every setting is a plain class attribute, so ``Params()`` starts out with
    these defaults and a single value can be overridden by assigning to the
    instance. The groups below are an organisational aid only.
    """

    # --- task / run identification ---
    task = 'mrpc'
    run_id = 0
    size = None
    logdir = 'checkpoints/'
    test_file = ''

    # --- language model ---
    lm = 'roberta'
    bert_path = None
    max_len = 256

    # --- optimisation ---
    batch_size = 64
    lr = 3e-5
    n_epochs = 20
    fp16 = False
    finetuning = False
    warmup = False

    # --- augmentation / SSL-related knobs (grouped by name) ---
    da = None
    num_aug = 2
    alpha_aug = 0.8
    alpha = 0.2
    u_lambda = 10.0
    no_ssl = False
    balance = False

## Set Parameters Here

In [11]:
## Hyper-parameters ##
# Start from the defaults and override only what this evaluation run needs.
hp = Params()

# task / run identification (these also feed the run tag / checkpoint name)
hp.task = 'em_SANTOS-XS_notest'
hp.run_id = 0
hp.size = 300
hp.logdir = 'results_em/'

# language model
hp.lm = 'roberta'
hp.max_len = 128

# optimisation
hp.batch_size = 32
hp.lr = 3e-5
hp.n_epochs = 20
hp.fp16 = True
hp.finetuning = True

# augmentation / balancing
# NOTE(review): this is the string 'None', not the None object. The run tag
# formats it with %s, so both spell the same checkpoint name -- confirm the
# string is what the training run used.
hp.da = 'None'
hp.balance = True

## Directory with test files ##
test_dir = 'data/em/SANTOS-XS/individual tests/'

# Setup

In [12]:
import argparse
import json
import os
import random
import uuid
from functools import partial

import numpy as np
import pandas as pd
import sklearn.metrics as metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

from rotom.dataset import TextCLSDataset
from ditto.dataset import DittoDataset
from snippext.model import MultiTaskNet
# The wildcard import is order-sensitive: it may rebind any name imported
# above it, and the imports below it may rebind names it provides. It is kept
# at the same relative position as before. SnippextDataset (used as the
# DataLoader collate source) is presumably supplied by this line.
from snippext.dataset import *

from tensorboardX import SummaryWriter
from transformers import AdamW, get_linear_schedule_with_warmup
from apex import amp
from transformers.data import glue_processors, glue_compute_metrics

# Configuration

Copied from `train_any.py`

In [13]:
# Number of classes for text-classification datasets that have no explicit
# label list in `vocabs`; their labels are generated as '1'..'N'.
num_classes = {
    'AMAZON2': 2,
    'AMAZON5': 5,
    'AG': 4,
}

# Explicit label vocabularies, keyed by upper-case dataset name.
# The order of each list is significant and must not be changed.
vocabs = {
    'SNIPS': [
        'AddToPlaylist',
        'BookRestaurant',
        'GetWeather',
        'PlayMusic',
        'RateBook',
        'SearchCreativeWork',
        'SearchScreeningEvent',
    ],
    'ATIS': [
        'atis_abbreviation',
        'atis_aircraft',
        'atis_airfare',
        'atis_airline',
        'atis_airline#atis_flight_no',
        'atis_airport',
        'atis_capacity',
        'atis_city',
        'atis_distance',
        'atis_flight',
        'atis_flight#atis_airfare',
        'atis_flight_no',
        'atis_flight_time',
        'atis_ground_fare',
        'atis_ground_service',
        'atis_quantity',
        'atis_restriction',
        'atis_meal',
        'atis_day_name',
        'atis_airfare#atis_flight',
        'atis_flight#atis_airline',
        'atis_flight_no#atis_airline',
        'atis_airfare#atis_flight_time',
        'atis_ground_service#atis_ground_fare',
    ],
    'TREC': [str(label) for label in range(6)],
    'SST-2': [str(label) for label in range(2)],
    'SST-5': [str(label) for label in range(5)],
    'IMDB': ['pos', 'neg'],
}


def get_cls_config(hp):
    """Get the configuration of the task and the dataset classes to use.

    The task family is chosen by substring match on ``hp.task``, tested in
    this order: ``'em_'`` / ``'cleaning_'`` (binary tasks read with
    ``DittoDataset``), then ``'compare'``, and finally everything else, which
    is treated as a plain text-classification dataset.

    Args:
        hp: hyper-parameter object. ``hp.task`` is always read; ``hp.size`` is
            read only for plain text-classification tasks.

    Returns:
        tuple: ``(config, dataset_class, dataset_class)``. ``config`` is a dict
        with the keys ``'name'``, ``'task_type'`` (always
        ``'classification'``) and ``'vocab'`` (list of label strings). The
        same class is returned twice; the caller unpacks the pair as the
        training and the test dataset class.

    Raises:
        IndexError: a ``compare`` task name without an underscore.
        KeyError: a dataset name found in neither ``vocabs`` nor
            ``num_classes``.
        ValueError: a text-classification task with ``hp.size`` unset whose
            name is not exactly ``<dataset>_<size>``.
    """
    taskname = hp.task

    if 'em_' in taskname or 'cleaning_' in taskname:
        # Entity matching and data cleaning are both binary tasks.
        vocab = ['0', '1']
        dataset_cls = DittoDataset
    elif 'compare' in taskname:
        # e.g. compare2_SST-2 -> labels of SST-2
        dataset_name = taskname.split('_')[1]
        vocab = vocabs[dataset_name]
        dataset_cls = TextCLSDataset
    else:
        # Text CLS datasets; the optional 'textcls_' marker is not part of
        # the name stored in the config.
        if 'textcls_' in taskname:
            taskname = taskname.replace('textcls_', '')
        if hp.size is None:
            # The name must then be exactly "<dataset>_<size>"; the two-way
            # unpacking deliberately rejects anything else.
            dataset_name, _ = taskname.split('_')
        else:
            dataset_name = taskname
        dataset_name = dataset_name.upper()
        if dataset_name in vocabs:
            vocab = vocabs[dataset_name]
        else:
            # No explicit label list: labels are '1'..'N'.
            vocab = [str(i) for i in range(1, num_classes[dataset_name] + 1)]
        dataset_cls = TextCLSDataset

    config = {'name': taskname,
              'task_type': 'classification',
              'vocab': vocab}
    return config, dataset_cls, dataset_cls

def get_ops(hp):
    """Return the list of DA operator names for the task in ``hp.task``.

    The task family is chosen by substring match on ``hp.task``:
    ``'cleaning_'`` first, then ``'em_'``; anything else gets the
    text-classification operators.

    Args:
        hp: hyper-parameter object; only ``hp.task`` is read.

    Returns:
        list of str: three operator names.
    """
    # Read the task from the argument. The function used to depend on a
    # module-level `task` variable that is only defined further down the
    # notebook, so it failed on a fresh kernel.
    task = hp.task

    if 'cleaning_' in task:
        return ["t5", "swap", "swap"]
    if "em_" in task:  # EM
        return ["t5", "del", "del"]
    return ["t5", "token_repl_tfidf", "token_del_tfidf"]

# Initialize Model

In [14]:
def get_model(task_config, hp):
    """Build the model together with its optimizer.

    Args:
        task_config (dictionary): the configuration of the task
        hp: hyper-parameter object; ``lm``, ``bert_path``, ``lr`` are read,
            and ``fp16`` is read when CUDA is available

    Returns:
        model, optimizer
    """
    cuda_available = torch.cuda.is_available()
    device = 'cuda' if cuda_available else 'cpu'

    # Single-task network, created for and moved to the selected device.
    net = MultiTaskNet([task_config], device,
                       lm=hp.lm, bert_path=hp.bert_path)
    net = net.to(device)

    opt = AdamW(net.parameters(), lr=hp.lr)

    # Mixed precision (apex, opt level O2) is only set up on a GPU and only
    # when requested.
    if cuda_available and hp.fp16:
        net, opt = amp.initialize(net, opt, opt_level='O2')

    return net, opt

# Evaluation

Adapted from `snippext/train_util.py`.

In [15]:
def eval_classifier(model, iterator, threshold=None, get_threshold=False):
    """Evaluate a classification model state on a dev/test set.

    The task name is taken from the batches themselves (last element of each
    batch tuple). Whether the task counts as binary is decided from the number
    of distinct gold labels actually seen in ``iterator``.

    For binary tasks whose name contains ``'cleaning_'`` or ``'em_'`` a
    confidence cut-off is applied: a prediction is kept only if the model's
    maximum softmax probability exceeds the threshold, otherwise it is set to
    label 0. With ``threshold=None`` the cut-off is swept over
    ``[0.0, 1.0)`` in steps of 0.005 and the value with the best F1 *on the
    data being evaluated* is used. The metrics reported in that mode are
    therefore optimistic when ``iterator`` is the test set.

    Args:
        model (MultiTaskNet): the model state
        iterator (DataLoader): a batch iterator of the dev/test set
        threshold (float, optional): the cut-off threshold for binary cls
        get_threshold (boolean, optional): also return the threshold that was
            used (binary tasks only)

    Returns:
        dict: for GLUE tasks, the result of ``glue_compute_metrics`` with an
            added ``'loss'`` entry.
        tuple: for binary tasks, ``(accuracy, precision, recall, f1, loss)``,
            followed by ``threshold`` when ``get_threshold`` is True (the
            threshold stays ``None`` if none was given and the sweep found no
            improvement or was not run).
        tuple: for tasks with more than two observed labels,
            ``(accuracy, macro_f1, loss)``.

    Raises:
        ValueError: if ``iterator`` yields no examples.
    """
    model.eval()

    Y = []        # gold labels
    Y_hat = []    # predicted labels
    Y_prob = []   # max softmax probability of each prediction
    loss_list = []
    total_size = 0
    taskname = None
    with torch.no_grad():
        for batch in iterator:
            _, x, _, _, _, y, _, taskname = batch
            taskname = taskname[0]
            logits, y1, y_hat = model(x, y, task=taskname)
            logits = logits.view(-1, logits.shape[-1])
            y1 = y1.view(-1)
            if 'sts-b' in taskname.lower():
                loss = nn.MSELoss()(logits, y1)
            else:
                loss = nn.CrossEntropyLoss()(logits, y1)

            # weight the batch loss by batch size for an example-level mean
            loss_list.append(loss.item() * y.shape[0])
            total_size += y.shape[0]

            Y.extend(y.cpu().numpy().tolist())
            Y_hat.extend(y_hat.cpu().numpy().tolist())
            Y_prob.extend(logits.softmax(dim=-1).max(dim=-1)[0].cpu().numpy().tolist())

    if total_size == 0:
        raise ValueError("eval_classifier: the iterator yielded no examples")

    loss = sum(loss_list) / total_size

    print("=============%s==================" % taskname)

    # for glue
    if taskname in glue_processors:
        Y_hat = np.array(Y_hat).squeeze()
        Y = np.array(Y)
        result = glue_compute_metrics(taskname, Y_hat, Y)
        result['loss'] = loss
        print(result)
        return result
    elif taskname[:5] == 'glue_':
        glue_task = taskname.split('_')[1].lower()
        Y_hat = np.array(Y_hat).squeeze()
        Y = np.array(Y)
        result = glue_compute_metrics(glue_task, Y_hat, Y)
        result['loss'] = loss
        print(result)
        return result

    # Number of distinct gold labels seen, not the size of the label
    # vocabulary.
    n_labels = len(set(Y))

    if n_labels <= 2:
        # Binary classification: start from the unthresholded predictions.
        accuracy = metrics.accuracy_score(Y, Y_hat)
        precision = metrics.precision_score(Y, Y_hat)
        recall = metrics.recall_score(Y, Y_hat)
        f1 = metrics.f1_score(Y, Y_hat)

        if any(prefix in taskname for prefix in ('cleaning_', 'em_')):  # handle imbalance
            if threshold is None:
                best_f1 = f1
                for th in np.arange(0.0, 1.0, 0.005):
                    # Always threshold the raw predictions so each candidate
                    # is independent of the previous iterations.
                    candidate = [y if p > th else 0 for (y, p) in zip(Y_hat, Y_prob)]
                    candidate_f1 = metrics.f1_score(Y, candidate)
                    if candidate_f1 > best_f1:
                        best_f1 = candidate_f1
                        accuracy = metrics.accuracy_score(Y, candidate)
                        precision = metrics.precision_score(Y, candidate)
                        recall = metrics.recall_score(Y, candidate)
                        threshold = th
                f1 = best_f1
            else:
                Y_hat = [y if p > threshold else 0 for (y, p) in zip(Y_hat, Y_prob)]
                accuracy = metrics.accuracy_score(Y, Y_hat)
                precision = metrics.precision_score(Y, Y_hat)
                recall = metrics.recall_score(Y, Y_hat)
                f1 = metrics.f1_score(Y, Y_hat)

        print("accuracy=%.3f"%accuracy)
        print("precision=%.3f"%precision)
        print("recall=%.3f"%recall)
        print("f1=%.3f"%f1)
        print("======================================")
        if get_threshold:
            return accuracy, precision, recall, f1, loss, threshold
        return accuracy, precision, recall, f1, loss

    # More than two observed labels: accuracy and macro-averaged F1 only.
    accuracy = metrics.accuracy_score(Y, Y_hat)
    f1 = metrics.f1_score(Y, Y_hat, average='macro')
    print("accuracy=%.3f"%accuracy)
    print("macro_f1=%.3f"%f1)
    print("======================================")
    return accuracy, f1, loss


def eval_on_task(model,
                 task,
                 test_iter,
                 test_dataset):
    """Run the eval function on the test set and collect the metrics.

    Args:
        model (MultiTaskNet): the model state
        task (str): the task name to be evaluated
        test_iter (DataLoader): the test set iterator; ``None`` means there is
            nothing to evaluate
        test_dataset: not used by this function; kept so existing calls keep
            working

    Returns:
        dict: the test metrics. For tasks whose name contains ``'cleaning_'``
        or ``'em_'`` and for other binary tasks the keys are ``'acc'``,
        ``'precision'``, ``'recall'``, ``'f1'`` and ``'loss'``; for multi-class
        tasks ``'acc'``, ``'f1'`` and ``'loss'``; for GLUE tasks a copy of the
        dict returned by ``eval_classifier``. Empty when ``test_iter`` is
        ``None``.
    """
    if test_iter is None:
        return {}

    print('Test:')
    if any(prefix in task for prefix in ('cleaning_', 'em_')):  # handle imbalance
        # get_threshold=True adds the selected cut-off as a sixth value,
        # which is not reported here.
        t_acc, t_prec, t_recall, t_f1, t_loss, _ = eval_classifier(
            model, test_iter, get_threshold=True)
        return {'acc': t_acc,
                'precision': t_prec,
                'recall': t_recall,
                'f1': t_f1,
                'loss': t_loss}

    t_output = eval_classifier(model, test_iter)

    if isinstance(t_output, dict):
        # GLUE tasks already come back as a metrics dict.
        return dict(t_output)

    if len(t_output) == 5:
        t_acc, t_prec, t_recall, t_f1, t_loss = t_output
        return {'acc': t_acc,
                'precision': t_prec,
                'recall': t_recall,
                'f1': t_f1,
                'loss': t_loss}

    t_acc, t_f1, t_loss = t_output
    return {'acc': t_acc,
            'f1': t_f1,
            'loss': t_loss}

# Main

In [16]:
# set seed (the run id doubles as the seed for every RNG in use)
seed = hp.run_id
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# create the tag of the run; it must spell exactly the name the checkpoint
# file was saved under
ssl_marker = '_no_ssl' if hp.no_ssl else ''
run_tag = '%s_lm=%s_da=%s%s_alpha=%.1f_id=%d' % (
    hp.task, hp.lm, hp.da, ssl_marker, hp.alpha_aug, hp.run_id)
if hp.size is not None:
    run_tag += '_size=%d' % hp.size

checkpt_path = run_tag + '_dev.pt'

config, Dataset, TestDataset = get_cls_config(hp)

# Only the training dataset class is wrapped; TestDataset stays as returned.
if hp.balance:
    Dataset = partial(Dataset, balance=hp.balance)

task = config['name']
vocab = config['vocab']
task_type = config['task_type']

model, optimizer = get_model(config, hp)

# Without CUDA, a checkpoint saved from a GPU cannot be deserialized unless
# its tensors are remapped to the CPU. With CUDA the default is kept.
model.load_state_dict(
    torch.load(checkpt_path,
               map_location=None if torch.cuda.is_available() else 'cpu'))



Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.

Defaults for this optimization level are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic


<All keys matched successfully>

In [17]:
# Sorted so that the evaluation order (and the row order of the results) is
# the same on every run; os.listdir returns entries in arbitrary order.
test_files = sorted(os.listdir(test_dir))
results = []

# The collate function is the same for every test file.
padder = SnippextDataset.pad

for testset in test_files:
    test_path = os.path.join(test_dir, testset)
    # skip sub-directories of test_dir
    if os.path.isdir(test_path):
        continue
    print(testset)
    test_dataset = TestDataset(test_path, vocab, task, lm=hp.lm, max_len=hp.max_len)
    test_iter = data.DataLoader(dataset=test_dataset,
                                batch_size=hp.batch_size*4,
                                shuffle=False,
                                num_workers=0,
                                collate_fn=padder)
    scalars = eval_on_task(model,
                           task,
                           test_iter,
                           test_dataset)
    # Label the row with the file name minus its extension, whatever the
    # extension's length.
    scalars["dataset"] = os.path.splitext(testset)[0]
    results.append(scalars)

civic_building_locations.txt
Test:
accuracy=0.909
precision=1.000
recall=0.750
f1=0.857
complaint_by_practice.txt
Test:
accuracy=0.300
precision=0.300
recall=1.000
f1=0.462
contributors_parties.txt
Test:
accuracy=0.714
precision=0.667
recall=1.000
f1=0.800
data_mill.txt
Test:
accuracy=0.636
precision=0.500
recall=1.000
f1=0.667
deaths_2012_2018.txt
Test:
accuracy=0.667
precision=0.625
recall=0.909
f1=0.741
film_locations_in_san_francisco.txt
Test:
accuracy=0.818
precision=0.833
recall=0.833
f1=0.833
311_calls_historic_data.txt
Test:
accuracy=0.867
precision=0.875
recall=0.875
f1=0.875
abandoned_wells.txt
Test:
accuracy=1.000
precision=1.000
recall=1.000
f1=1.000
albums.txt
Test:
accuracy=0.250
precision=0.000
recall=0.000
f1=0.000
animal_tag_data.txt
Test:
accuracy=1.000
precision=1.000
recall=1.000
f1=1.000
biodiversity.txt
Test:
accuracy=0.800
precision=1.000
recall=0.667
f1=0.800
business_rates.txt
Test:
accuracy=0.543
precision=0.529
recall=1.000
f1=0.692
cdc_nutrition_physical_act

In [18]:
# One row per evaluated test file.
results_df = pd.DataFrame(results)

# Results go to an 'out' sub-directory of the test directory; create it first
# so the export also works on a fresh checkout.
out_dir = os.path.join(test_dir, 'out')
os.makedirs(out_dir, exist_ok=True)
results_df.to_csv(os.path.join(out_dir, run_tag + '__all_results.csv'), index=False)