In [None]:
! cp -r '/content/drive/MyDrive/CS 421/datasets' '/content/dataset_folders'

In [None]:
%%capture
! pip install transformers datasets sentencepiece protobuf==3.20.* tensorboardX
! pip install accelerate -U

In [None]:
from sklearn.metrics import mean_squared_error
import warnings
warnings.filterwarnings('ignore')

## Data Utils

In [None]:
import argparse
import re
import json
import numpy as np
import accelerate

from datasets import Dataset, DatasetDict, load_dataset

# Root directory for all dataset files; DatasetLoader stores it as `data_root`
# and builds every JSON path (splits and LLM predictions) beneath it.
DATASET_ROOT = '/content/dataset_folders'


class DatasetLoader(object):
    """Base class for reading a dataset and its LLM-generated annotations.

    Subclasses provide the dataset-specific hooks: ``_post_process`` (applied
    to the freshly loaded splits) and ``_parse_llm_output`` /
    ``_parse_gpt_output`` (turn one raw generation into ``(rationale, label)``).
    """

    def __init__(self, dataset_name, source_dataset_name, dataset_version, has_valid, split_map,
                 batch_size, train_batch_idxs, test_batch_idxs, valid_batch_idxs=None):
        # Location and identity of the dataset.
        self.data_root = DATASET_ROOT
        self.dataset_name = dataset_name
        self.source_dataset_name = source_dataset_name
        self.dataset_version = dataset_version

        # Split layout: maps the local split name to the split name in the source dataset.
        self.split_map = split_map
        self.has_valid = has_valid

        # Batch bookkeeping: which fixed-size batches (and prediction files) belong to each split.
        self.batch_size = batch_size
        self.train_batch_idxs = train_batch_idxs
        self.valid_batch_idxs = valid_batch_idxs
        self.test_batch_idxs = test_batch_idxs

        assert self.split_map is not None


    def load_from_source(self):
        """Load the dataset via ``load_dataset``.

        Falls back to ``dataset_name`` when no source name was given, and only
        passes the version/config argument when one is set.
        """
        if self.source_dataset_name is None:
            self.source_dataset_name = self.dataset_name

        load_args = [self.source_dataset_name]
        if self.dataset_version is not None:
            load_args.append(self.dataset_version)
        return load_dataset(*load_args)


    def to_json(self, datasets):
        """Write every split in ``split_map`` to ``<root>/<name>/<name>_<split>.json``."""
        out_dir = f'{self.data_root}/{self.dataset_name}'
        for local_split, source_split in self.split_map.items():
            datasets[source_split].to_json(f'{out_dir}/{self.dataset_name}_{local_split}.json')


    def load_from_json(self):
        """Load the JSON splits, post-process them, and subsample the training split.

        Returns:
            The dataset dict produced by ``_post_process``, with ``'train'``
            restricted to the rows covered by ``train_batch_idxs``.
        """
        file_prefix = f'{self.data_root}/{self.dataset_name}/{self.dataset_name}'
        split_names = ['train', 'test']
        if self.has_valid:
            split_names.append('valid')
        data_files = {name: f'{file_prefix}_{name}.json' for name in split_names}

        datasets = self._post_process(load_dataset('json', data_files=data_files))

        # Keep only the training rows that fall inside the selected batches
        # (row indices past the end of the split are dropped).
        num_train = len(datasets['train'])
        kept_rows = [
            row
            for batch in self.train_batch_idxs
            for row in range(batch * self.batch_size, (batch + 1) * self.batch_size)
            if row < num_train
        ]
        datasets['train'] = Dataset.from_dict(datasets['train'][kept_rows])

        return datasets


    def load_llm_preds(self, split):
        """Read the per-batch ``<split>_CoT_<idx>.json`` files and parse every output.

        Returns:
            ``(rationales, labels)`` as two parallel lists.
        """
        rationales, labels = [], []
        llm_dir = f'{self.data_root}/{self.dataset_name}/llm'

        for batch_idx in getattr(self, f'{split}_batch_idxs'):
            with open(f'{llm_dir}/{split}_CoT_{batch_idx}.json') as f:
                raw_outputs = json.load(f)

            parsed = [self._parse_llm_output(raw) for raw in raw_outputs]
            rationales.extend(rationale for rationale, _ in parsed)
            labels.extend(label for _, label in parsed)

        return rationales, labels


    def load_gpt_preds(self, split):
        """Read ``<root>/gpt-neox/<name>/<split>.json`` and parse every output.

        Returns:
            ``(rationales, labels)`` as two parallel lists.
        """
        with open(f'{self.data_root}/gpt-neox/{self.dataset_name}/{split}.json') as f:
            raw_outputs = json.load(f)

        parsed = [self._parse_gpt_output(raw) for raw in raw_outputs]
        rationales = [rationale for rationale, _ in parsed]
        labels = [label for _, label in parsed]

        return rationales, labels


    def _post_process(self, datasets):
        """Hook: dataset-specific clean-up of the loaded splits."""
        raise NotImplementedError


    def _parse_llm_output(self, output):
        """Hook: split one raw LLM output into ``(rationale, label)``."""
        raise NotImplementedError


    def _parse_gpt_output(self, output):
        """Hook: split one raw GPT output into ``(rationale, label)``."""
        raise NotImplementedError


class CQADatasetLoader(DatasetLoader):
    """Loader for the ``cqa`` dataset (source dataset ``cos_e``, version ``v1.11``)."""

    def __init__(self):
        dataset_name = 'cqa'
        source_dataset_name = 'cos_e'
        dataset_version = 'v1.11'
        has_valid = False
        split_map = {
            'train': 'train',
            'test': 'validation',
        }
        batch_size = 1000
        train_batch_idxs = range(10)
        test_batch_idxs = range(2)

        super().__init__(dataset_name, source_dataset_name, dataset_version, has_valid, split_map,
                 batch_size, train_batch_idxs, test_batch_idxs, valid_batch_idxs=None)


    def _post_process(self, datasets):
        """Build the multiple-choice prompt (``input``) and ``label``, then drop the raw columns."""

        def prepare_input(example):
            question = example['question']
            c_0 = example['choices'][0]
            c_1 = example['choices'][1]
            c_2 = example['choices'][2]
            c_3 = example['choices'][3]
            c_4 = example['choices'][4]

            # Named `prompt` rather than `input` so the builtin is not shadowed.
            prompt = f'{question}\nAnswer Choices:\n(a) {c_0}\n(b) {c_1}\n(c) {c_2}\n(d) {c_3}\n(e) {c_4}'

            example['input'] = prompt
            example['label'] = example['answer']

            return example

        datasets = datasets.map(prepare_input)
        datasets = datasets.remove_columns(['id', 'question', 'choices', 'answer', 'abstractive_explanation', 'extractive_explanation'])

        return datasets


    @staticmethod
    def _extract_choice_text(label):
        """Return the answer text following the ``(x)`` choice marker.

        A single trailing period is removed. Returns ``' '`` when there is no
        choice marker or nothing follows it.
        """
        try:
            label = re.split(r'\(.\)', label)[1].strip()
            return label[:-1] if label[-1] == '.' else label
        except IndexError:
            # No "(x)" marker (split yields one piece) or empty text after it.
            return ' '


    def _split_rationale_label(self, rationale_label):
        """Split ``'<rationale> So the answer is <label>'`` into its two parts.

        Returns ``(' ', ' ')`` when the marker is absent or occurs more than once.
        """
        try:
            rationale, label = rationale_label.split('So the answer is')
        except ValueError:
            return ' ', ' '

        return rationale.rstrip(), self._extract_choice_text(label)


    def _parse_llm_output(self, output):
        """Parse one LLM output into ``(rationale, label)``.

        Only the text before the first ``'Q:'`` is considered. A malformed
        output yields the ``' '`` placeholders instead of raising, matching
        ``_parse_gpt_output``, so one bad generation does not abort loading.
        """
        rationale_label = output.split('Q:')[0].rstrip()
        return self._split_rationale_label(rationale_label)


    def _parse_gpt_output(self, output):
        """Parse one GPT output into ``(rationale, label)``.

        Same as ``_parse_llm_output`` except that leading whitespace is
        stripped as well before splitting.
        """
        rationale_label = output.split('Q:')[0].strip()
        return self._split_rationale_label(rationale_label)


class SVAMPDatasetLoader(DatasetLoader):
    """Loader for the ``svamp`` dataset, whose labels are equations."""

    def __init__(self):
        dataset_name = 'svamp'
        source_dataset_name = 'svamp'
        dataset_version = None
        has_valid = False
        split_map = {
            'train': 'train',
            'test': 'test',
        }
        batch_size = 500
        train_batch_idxs = range(2)
        test_batch_idxs = range(1)


        super().__init__(dataset_name, source_dataset_name, dataset_version, has_valid, split_map,
                 batch_size, train_batch_idxs, test_batch_idxs, valid_batch_idxs=None)


    def load_from_source(self):
        """Build train/test splits from the local ``SVAMP.json`` file.

        Each record becomes ``{'input': Body + '\\n' + Question, 'label': Equation}``.
        The records are shuffled with a fixed seed (0); the first 800 shuffled
        records form the training split and the remainder the test split.
        """
        with open(f'{self.data_root}/{self.dataset_name}/SVAMP.json') as f:
            original_dataset = json.load(f)

        dataset = [
            {
                'input': f'{data["Body"]}\n{data["Question"]}',
                'label': data["Equation"],
            }
            for data in original_dataset
        ]

        num_train = 800
        idxs = np.random.RandomState(seed=0).permutation(len(dataset))

        # Index the Python list directly; no need to round-trip the dicts
        # through a NumPy object array.
        train_dataset = Dataset.from_list([dataset[i] for i in idxs[:num_train]])
        test_dataset = Dataset.from_list([dataset[i] for i in idxs[num_train:]])

        datasets = DatasetDict({
            'train': train_dataset,
            'test': test_dataset
        })

        return datasets


    def _post_process(self, datasets):
        """No clean-up is needed; the JSON splits are returned as-is."""
        return datasets


    @staticmethod
    def _split_rationale_equation(rationale_label):
        """Split ``'<rationale> The answer is <equation>'`` into ``(rationale, equation)``.

        The equation is the span from the first ``(`` to the last ``)`` after
        the marker. Returns ``(' ', ' ')`` when the marker is missing or
        repeated, and ``' '`` as the label when no parenthesised span exists.
        """
        try:
            rationale, label = rationale_label.split('The answer is')
        except ValueError:
            return ' ', ' '

        rationale = rationale.rstrip()
        match = re.search(r'\(.*\)', label)
        label = match.group(0) if match is not None else ' '

        return rationale, label


    def _parse_llm_output(self, output):
        """Parse one LLM output; only the text before the first ``'Q:'`` is used."""
        return self._split_rationale_equation(output.split('Q:')[0].rstrip())

    def _parse_gpt_output(self, output):
        """Parse one GPT output; like ``_parse_llm_output`` but also strips leading whitespace."""
        return self._split_rationale_equation(output.split('Q:')[0].strip())


class ASDivDatasetLoader(DatasetLoader):
    """Loader for the ``asdiv`` dataset (JSON files only; no source download)."""

    def __init__(self):
        dataset_name = 'asdiv'
        # There is no source dataset to download: load_from_source is not
        # implemented for this loader, so the source name is left unset.
        source_dataset_name = None
        dataset_version = None
        has_valid = False
        split_map = {
            'train': 'train',
            'test': 'test',
        }
        batch_size = 1000
        train_batch_idxs = range(3)
        test_batch_idxs = range(1)

        # Keyword arguments keep every value bound to the intended parameter.
        # The previous positional call omitted source_dataset_name, which
        # shifted every argument by one and raised TypeError on construction.
        super().__init__(
            dataset_name=dataset_name,
            source_dataset_name=source_dataset_name,
            dataset_version=dataset_version,
            has_valid=has_valid,
            split_map=split_map,
            batch_size=batch_size,
            train_batch_idxs=train_batch_idxs,
            test_batch_idxs=test_batch_idxs,
            valid_batch_idxs=None,
        )


    def load_from_source(self):
        """Not supported for this dataset."""
        raise NotImplementedError


    def _post_process(self, datasets):
        """Build ``input`` (Body + newline + Question) and ``label`` (first token of Answer)."""

        def prepare_input(example):
            example['input'] = example['Body'] + '\n' + example['Question']
            # Keep only the text before the first space of the answer.
            answer = example['Answer'].split(' ')[0]
            example['label'] = answer

            return example

        datasets = datasets.map(prepare_input)
        datasets = datasets.remove_columns(['Body', 'Question', 'Formula', 'Answer'])

        return datasets


    def _parse_llm_output(self, output):
        """Parse one LLM output into ``(rationale, label)``.

        Only the text before the first ``'Q:'`` is used. Returns ``(' ', ' ')``
        when ``'The answer is'`` is missing or repeated; the label is ``' '``
        when no parenthesised span follows the marker.
        """
        rationale_label = output.split('Q:')[0].rstrip()
        try:
            rationale, label = rationale_label.split('The answer is')
        except ValueError:
            return ' ', ' '

        rationale = rationale.rstrip()
        match = re.search(r'\(.*\)', label)
        label = match.group(0) if match is not None else ' '

        return rationale, label


    def _parse_gpt_output(self, output):
        """Not supported for this dataset."""
        raise NotImplementedError


class ESNLIDatasetLoader(DatasetLoader):
    """Loader for the ``esnli`` dataset, with a ``'full'`` or ``'small'`` training subset."""

    def __init__(self, subset='full'):
        dataset_name = 'esnli'
        source_dataset_name = 'esnli'
        dataset_version = None
        has_valid = True
        split_map = {
            'train': 'train',
            'valid': 'validation',
            'test': 'test',
        }
        batch_size = 5500
        if subset == 'full':
            train_batch_idxs = range(100)
        elif subset == 'small':
            train_batch_idxs = range(10)
        else:
            raise ValueError(f"subset must be 'full' or 'small', got {subset!r}")
        test_batch_idxs = range(2)
        valid_batch_idxs = range(2)

        super().__init__(dataset_name, source_dataset_name, dataset_version, has_valid, split_map,
                 batch_size, train_batch_idxs, test_batch_idxs, valid_batch_idxs=valid_batch_idxs)


    def _post_process(self, datasets):
        """Convert integer labels to text and drop the explanation columns."""
        label_names = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}

        def prepare_input(example):
            # Any value outside 0/1/2 is left untouched.
            example['label'] = label_names.get(example['label'], example['label'])
            return example

        datasets = datasets.map(prepare_input)
        datasets = datasets.remove_columns(['explanation_1', 'explanation_2', 'explanation_3'])

        return datasets


    @staticmethod
    def _extract_label(output):
        """Return the text between ``'Answer: '`` and the next ``'Premise'``, or ``' '`` if absent."""
        try:
            return output.split("Answer: ")[1].split("Premise")[0].rstrip()
        except IndexError:
            # No "Answer: " marker in the output.
            return ' '


    def _parse_llm_output(self, output):
        """Parse one LLM output: the rationale is everything before ``'Answer:'``."""
        rationale = output.split("Answer:")[0].rstrip()
        return rationale, self._extract_label(output)


    def _parse_gpt_output(self, output):
        """Parse one GPT output; like ``_parse_llm_output`` but the rationale is stripped on both sides."""
        rationale = output.split("Answer:")[0].strip()
        return rationale, self._extract_label(output)


class ANLIDatasetLoader(DatasetLoader):
    """Shared behaviour for ANLI loaders; concrete rounds pass their own configuration."""

    def __init__(self, dataset_name, source_dataset_name, dataset_version, has_valid, split_map,
                 batch_size, train_batch_idxs, test_batch_idxs, valid_batch_idxs):

        super().__init__(dataset_name, source_dataset_name, dataset_version, has_valid, split_map,
                 batch_size, train_batch_idxs, test_batch_idxs, valid_batch_idxs=valid_batch_idxs)

    def _post_process(self, datasets):
        """Convert integer labels to text and drop the ``uid`` and ``reason`` columns."""
        label_names = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}

        def label_idx2text(example):
            # Any value outside 0/1/2 is left untouched.
            example['label'] = label_names.get(example['label'], example['label'])
            return example

        datasets = datasets.map(label_idx2text)
        datasets = datasets.remove_columns(['uid', 'reason'])

        return datasets


    @staticmethod
    def _clean(rationale, label):
        """Trim the rationale and remove one trailing period from the label.

        The period is removed only when present. Slicing off the last
        character unconditionally would truncate labels written without a
        final period (``'neutral'`` would become ``'neutra'``).
        """
        rationale = rationale.rstrip()
        label = label.lstrip()
        if label.endswith('.'):
            label = label[:-1]

        return rationale, label


    def _parse_llm_output(self, output):
        """Parse one LLM output into ``(rationale, label)``.

        Only the text before the first ``'Premise:'`` is used. Returns
        ``('', '')`` when ``'So the answer is'`` is missing or repeated.
        """
        try:
            rationale, label = output.split("Premise:")[0].rstrip().split("So the answer is")
        except ValueError:
            rationale = ''
            label = ''

        return self._clean(rationale, label)


    def _parse_gpt_output(self, output):
        """Parse one GPT output into ``(rationale, label)``.

        Tries the ``'So the answer is'`` marker first and falls back to
        ``'The answer is'``; returns ``('', '')`` when neither splits the
        text into exactly two parts.
        """
        body = output.split("Premise:")[0].strip()
        rationale, label = '', ''
        for marker in ("So the answer is", "The answer is"):
            try:
                rationale, label = body.split(marker)
                break
            except ValueError:
                continue

        return self._clean(rationale, label)


class ANLI1DatasetLoader(ANLIDatasetLoader):
    """ANLI loader configured for the ``*_r1`` splits (dataset name ``anli1``)."""

    def __init__(self):
        super().__init__(
            dataset_name='anli1',
            source_dataset_name='anli',
            dataset_version=None,
            has_valid=True,
            # local split name -> split name in the source dataset
            split_map={
                'train': 'train_r1',
                'valid': 'dev_r1',
                'test': 'test_r1',
            },
            batch_size=5000,
            train_batch_idxs=range(4),
            test_batch_idxs=range(1),
            valid_batch_idxs=range(1),
        )

## Metrics

In [None]:
import numpy as np
import accelerate
from sklearn.preprocessing import LabelEncoder

def compute_text_acc(preds, labels):
    """Return the fraction of positions where the prediction equals the label."""
    pred_arr = np.array(preds)
    label_arr = np.array(labels)
    is_match = pred_arr == label_arr
    return np.mean(is_match)

def compute_text_loss(preds, labels):
    """Mean squared error between label-encoded predictions and labels.

    The encoder is fitted on ``labels`` only; any value it has not seen
    (in either sequence) is encoded as ``-1``.
    """
    encoder = LabelEncoder()
    encoder.fit(np.array(labels))
    code_of = {
        cls: code
        for cls, code in zip(encoder.classes_, encoder.transform(encoder.classes_))
    }

    def encode(values):
        return np.array([code_of.get(value, -1) for value in values])

    encoded_labels = encode(labels)
    encoded_preds = encode(preds)
    return mean_squared_error(encoded_labels, encoded_preds)

def compute_equation_acc(preds, labels):
    """Evaluate both equation lists and return the fraction of equal results."""
    pred_values = list(map(eval_equation, preds))
    label_values = list(map(eval_equation, labels))
    is_match = np.array(pred_values) == np.array(label_values)
    return np.mean(is_match)

def compute_equation_loss(preds, labels):
    """Mean squared error between label-encoded equation results.

    Both lists are evaluated with ``eval_equation`` first. The encoder is
    fitted on the evaluated labels; any value missing from the resulting
    lookup table is encoded as ``-1``.
    """
    pred_values = list(map(eval_equation, preds))
    label_values = list(map(eval_equation, labels))

    encoder = LabelEncoder()
    encoder.fit(np.array(label_values))
    code_of = {
        cls: code
        for cls, code in zip(encoder.classes_, encoder.transform(encoder.classes_))
    }

    def encode(values):
        return np.array([code_of.get(value, -1) for value in values])

    encoded_labels = encode(label_values)
    encoded_preds = encode(pred_values)
    return mean_squared_error(encoded_labels, encoded_preds)

def eval_equation(equation):
    try:
        answer = eval(equation)
    except:
        answer = np.nan

    return answer


def compute_metrics_text(tokenizer):
    """Build an exact-match accuracy metric for outputs that carry two heads.

    The returned function expects ``eval_pred`` to unpack into
    ``(predictions, labels)`` where each is a sequence whose first element
    holds the token ids to score.

    Args:
        tokenizer: object providing ``batch_decode`` and ``pad_token_id``.

    Returns:
        A function mapping ``eval_pred`` to ``{'accuracy': float}``.
    """
    def decode(token_ids):
        # -100 marks padding/ignored positions and cannot be decoded, so it is
        # replaced by the pad token first. Predictions get the same treatment
        # as labels, since gathered generations can be padded with -100 too.
        token_ids = np.asarray(token_ids)
        token_ids = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
        return tokenizer.batch_decode(token_ids, skip_special_tokens=True)

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        decoded_preds = decode(predictions[0])
        decoded_labels = decode(labels[0])

        acc = np.mean(np.array(decoded_preds) == np.array(decoded_labels))

        return {'accuracy': acc}

    return compute_metrics


def compute_metrics_text_aux(tokenizer):
    """Build an exact-match accuracy metric for single-output models.

    The returned function expects ``eval_pred`` to unpack into
    ``(predictions, labels)``, each being the token ids themselves.

    Args:
        tokenizer: object providing ``batch_decode`` and ``pad_token_id``.

    Returns:
        A function mapping ``eval_pred`` to ``{'accuracy': float}``.
    """
    def decode(token_ids):
        # -100 marks padding/ignored positions and cannot be decoded, so it is
        # replaced by the pad token first. Predictions get the same treatment
        # as labels, since gathered generations can be padded with -100 too.
        token_ids = np.asarray(token_ids)
        token_ids = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
        return tokenizer.batch_decode(token_ids, skip_special_tokens=True)

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        decoded_preds = decode(predictions)
        decoded_labels = decode(labels)

        acc = np.mean(np.array(decoded_preds) == np.array(decoded_labels))

        return {'accuracy': acc}

    return compute_metrics



def compute_metrics_equation(tokenizer):
    """Build an equation-value accuracy metric for outputs that carry two heads.

    The returned function expects ``eval_pred`` to unpack into
    ``(predictions, labels)`` where each is a sequence whose first element
    holds the token ids to score. Decoded strings are evaluated with
    ``eval_equation`` and compared by value.

    Args:
        tokenizer: object providing ``batch_decode`` and ``pad_token_id``.

    Returns:
        A function mapping ``eval_pred`` to ``{'accuracy': float}``.
    """
    def decode(token_ids):
        # -100 marks padding/ignored positions and cannot be decoded, so it is
        # replaced by the pad token first. Predictions get the same treatment
        # as labels, since gathered generations can be padded with -100 too.
        token_ids = np.asarray(token_ids)
        token_ids = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
        return tokenizer.batch_decode(token_ids, skip_special_tokens=True)

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        decoded_preds = decode(predictions[0])
        decoded_labels = decode(labels[0])

        preds = [eval_equation(pred) for pred in decoded_preds]
        labels = [eval_equation(label) for label in decoded_labels]

        acc = np.mean(np.array(preds) == np.array(labels))

        return {'accuracy': acc}

    return compute_metrics


def compute_metrics_equation_aux(tokenizer):
    """Build an equation-value accuracy metric for single-output models.

    The returned function expects ``eval_pred`` to unpack into
    ``(predictions, labels)``, each being the token ids themselves. Decoded
    strings are evaluated with ``eval_equation`` and compared by value.

    Args:
        tokenizer: object providing ``batch_decode`` and ``pad_token_id``.

    Returns:
        A function mapping ``eval_pred`` to ``{'accuracy': float}``.
    """
    def decode(token_ids):
        # -100 marks padding/ignored positions and cannot be decoded, so it is
        # replaced by the pad token first. Predictions get the same treatment
        # as labels, since gathered generations can be padded with -100 too.
        token_ids = np.asarray(token_ids)
        token_ids = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
        return tokenizer.batch_decode(token_ids, skip_special_tokens=True)

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        decoded_preds = decode(predictions)
        decoded_labels = decode(labels)

        preds = [eval_equation(pred) for pred in decoded_preds]
        labels = [eval_equation(label) for label in decoded_labels]

        acc = np.mean(np.array(preds) == np.array(labels))

        return {'accuracy': acc}

    return compute_metrics

## Model Utils

In [None]:
import pandas as pd
import torch
from typing import Any, Dict, List, Optional, Tuple, Union
from torch import nn
from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer
import accelerate

"""T5 Multi-Task by Task Prefix
"""
class TaskPrefixDataCollator(DataCollatorForSeq2Seq):
    """Collate each batch twice: once for the prediction task, once for the explanation task."""

    def __call__(self, features, return_tensors=None):
        """Split the features into two task views and collate each with the parent collator.

        Returns:
            ``{'pred': <collated prediction batch>, 'expl': <collated explanation batch>}``.
        """
        frame = pd.DataFrame(features)

        expl_columns = ['aux_labels', 'expl_input_ids', 'expl_attention_mask']
        pred_columns = ['labels', 'input_ids', 'attention_mask']

        # Prediction view: everything except the explanation-only columns.
        pred_records = frame.drop(columns=expl_columns, errors='ignore').to_dict('records')

        # Explanation view: drop the prediction columns, then rename the
        # explanation columns to the standard names the parent collator expects.
        to_standard_names = {
            'aux_labels': 'labels',
            'expl_input_ids': 'input_ids',
            'expl_attention_mask': 'attention_mask',
        }
        expl_records = (
            frame.drop(columns=pred_columns, errors='ignore')
            .rename(columns=to_standard_names)
            .to_dict('records')
        )

        pred_batch = super().__call__(pred_records, return_tensors)
        expl_batch = super().__call__(expl_records, return_tensors)

        return {'pred': pred_batch, 'expl': expl_batch}


class TaskPrefixTrainer(Seq2SeqTrainer):
    """Seq2Seq trainer that optimises a weighted sum of a prediction and an explanation loss.

    Expects each batch to be a dict with ``'pred'`` and ``'expl'`` entries
    (as produced by ``TaskPrefixDataCollator``).
    """

    def __init__(self, alpha, output_rationale, **kwargs):
        """
        Args:
            alpha: weight of the prediction loss; the explanation loss gets ``1 - alpha``.
            output_rationale: whether evaluation also runs the explanation task.
            **kwargs: forwarded unchanged to ``Seq2SeqTrainer``.
        """
        super().__init__(**kwargs)
        self.alpha = alpha
        self.output_rationale = output_rationale


    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return ``alpha * pred_loss + (1 - alpha) * expl_loss``.

        Args:
            model: the model being trained.
            inputs: dict with ``'pred'`` and ``'expl'`` model inputs.
            return_outputs: when true, also return both model outputs.
            num_items_in_batch: unused. Accepted because newer transformers
                training loops pass this keyword to ``compute_loss``; without
                it the override raises ``TypeError`` on those versions.
        """
        pred_outputs = model(**inputs['pred'])
        expl_outputs = model(**inputs['expl'])

        loss = self.alpha * pred_outputs.loss + (1. - self.alpha) * expl_outputs.loss

        return (loss, {'pred': pred_outputs, 'expl': expl_outputs}) if return_outputs else loss


    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run the parent prediction step for each task and combine the results.

        Note that the parent step is always called with
        ``prediction_loss_only=False``; the argument passed in is not used.

        Returns:
            ``(loss, [pred_output, expl_output], [pred_labels, expl_labels])``
            where ``loss`` is the same weighted sum used for training.
        """
        pred_outputs = super().prediction_step(model, inputs['pred'], prediction_loss_only=False, ignore_keys=ignore_keys)
        if self.output_rationale:
            expl_outputs = super().prediction_step(model, inputs['expl'], prediction_loss_only=False, ignore_keys=ignore_keys)
        else:
            expl_outputs = pred_outputs # placeholder only

        loss = self.alpha * pred_outputs[0]  + (1 - self.alpha) * expl_outputs[0]

        return (
            loss,
            [pred_outputs[1], expl_outputs[1]],
            [pred_outputs[2], expl_outputs[2]],
        )

## Trainer Utils

In [None]:
import os
import shutil
import logging

from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import T5ForConditionalGeneration
from transformers import DataCollatorForSeq2Seq
from transformers.trainer_utils import set_seed
import accelerate

# from model_utils import TaskPrefixDataCollator, TaskPrefixTrainer


def get_config_dir(args):
    """Build the relative directory path that identifies one experiment configuration.

    The path is the '/'-joined sequence: dataset, the part of
    ``from_pretrained`` before the first '/', model type, llm, subsample,
    label type, alpha, max input length, effective batch size
    (``grad_steps * batch_size``), optimizer name and learning rate.
    """
    components = [
        args["dataset"],
        args["from_pretrained"].split("/")[0],
        args["model_type"],
        args["llm"],
        args["subsample"],
        args["label_type"],
        args["alpha"],
        args["max_input_length"],
        args["grad_steps"] * args["batch_size"],
        args["optimizer_name"],
        args["lr"],
    ]
    return '/'.join(f'{component}' for component in components)


def train_and_evaluate(args, run, tokenizer, tokenized_datasets, compute_metrics):
    """Train a T5 model for one run, evaluating on the test split during training.

    Args:
        args: experiment configuration dict (see the experiment cells for the keys).
        run: integer run id; used as the random seed and as the last path
            component of the checkpoint and log directories.
        tokenizer: tokenizer matching ``args['from_pretrained']``.
        tokenized_datasets: dataset dict with at least ``'train'`` and ``'test'``.
        compute_metrics: metric function handed to the trainer.

    Raises:
        ValueError: if ``args['model_type']`` is neither ``'task_prefix'`` nor ``'standard'``.
    """
    set_seed(run)

    model = T5ForConditionalGeneration.from_pretrained(args['from_pretrained'])

    if args['parallelize']:
        model.parallelize()

    config_dir = get_config_dir(args)
    output_dir = f'ckpts/{config_dir}/{run}'  # for model ckpts
    logging_dir = f'logs/{config_dir}/{run}'  # for training logs

    if args['no_log']:
        logging_strategy = 'no'
        logging_dir = None
    else:
        logging_strategy = 'steps'

    # clear output dir if already exists
    if os.path.exists(output_dir):
        logging.info('Found existing ckpt directory. Deleted the old directory for the latest run.')
        shutil.rmtree(output_dir)

    training_args = Seq2SeqTrainingArguments(
        output_dir,
        remove_unused_columns = False,
        evaluation_strategy = 'steps',
        eval_steps=args['eval_steps'],
        save_strategy='no',
        save_steps=args['eval_steps'],
        logging_dir=logging_dir,
        logging_strategy=logging_strategy,
        logging_steps=args['eval_steps'],
        max_steps=args['max_steps'],
        learning_rate=args['lr'],
        gradient_accumulation_steps=args['grad_steps'],
        per_device_train_batch_size=args['batch_size'],
        per_device_eval_batch_size=args['batch_size'],
        predict_with_generate=True,
        seed=run,
        local_rank=args['local_rank'],
        bf16=args['bf16'],
        generation_max_length=args['gen_max_len'],
        prediction_loss_only=False,
    )

    if args['model_type'] == 'task_prefix':
        data_collator = TaskPrefixDataCollator(tokenizer=tokenizer, model=model)
    elif args['model_type'] == 'standard':
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    else:
        raise ValueError(f"unknown model_type: {args['model_type']!r}")


    trainer_kwargs = {
        'alpha': args['alpha'],
        'output_rationale': args['output_rationale'],
        'model': model,
        'args': training_args,
        'train_dataset': tokenized_datasets["train"],
        'eval_dataset': {'test': tokenized_datasets["test"],},
        'data_collator': data_collator,
        'tokenizer': tokenizer,
        'compute_metrics': compute_metrics,
    }


    if args['model_type'] == 'task_prefix':
        trainer = TaskPrefixTrainer(**trainer_kwargs)
    elif args['model_type'] == 'standard':
        # The plain Seq2SeqTrainer does not accept the task-prefix options, so
        # remove them. This must be dict.pop(); subscripting with the string
        # 'pop' raised KeyError and made standard training impossible.
        trainer_kwargs.pop('alpha')
        trainer_kwargs.pop('output_rationale')
        trainer = Seq2SeqTrainer(**trainer_kwargs)
    else:
        raise ValueError(f"unknown model_type: {args['model_type']!r}")


    trainer.train()

## RUN

In [None]:
import argparse

from datasets import DatasetDict, concatenate_datasets
from transformers import AutoTokenizer
import accelerate

# from data_utils import CQADatasetLoader, SVAMPDatasetLoader, ESNLIDatasetLoader, ANLI1DatasetLoader, ASDivDatasetLoader
# from metrics import compute_text_acc, compute_equation_acc, compute_metrics_text, compute_metrics_equation, compute_metrics_text_aux, compute_metrics_equation_aux
# from train_utils import train_and_evaluate


def run(args):
    """Load and tokenize the data for one experiment configuration, then train.

    Steps: pick the dataset loader, load the JSON splits, optionally attach
    LLM labels/rationales, optionally subsample the training split, obtain a
    validation split, optionally replace the training labels by LLM labels,
    tokenize, choose the metric, and call ``train_and_evaluate``.

    Args:
        args: experiment configuration dict (see the experiment cells for the keys).

    Raises:
        ValueError: on an unknown ``dataset``, ``llm``, ``label_type`` or
            ``model_type`` value (or an unsupported combination of them).
    """
    #### Prepare datasets
    if args['dataset'] == 'cqa':
        dataset_loader = CQADatasetLoader()
    elif args['dataset'] == 'svamp':
        dataset_loader = SVAMPDatasetLoader()
    elif args['dataset'] == 'esnli':
        dataset_loader = ESNLIDatasetLoader()
    elif args['dataset'] == 'anli1':
        dataset_loader = ANLI1DatasetLoader()
    elif args['dataset'] == 'asdiv':  # NOTE: for augmenting SVAMP only
        dataset_loader = SVAMPDatasetLoader()
        dataset_loader_svamp = SVAMPDatasetLoader()
        dataset_loader_asdiv = ASDivDatasetLoader()
    else:
        raise ValueError(f"unknown dataset: {args['dataset']!r}")

    if args['dataset'] == 'asdiv':
        datasets_svamp = dataset_loader_svamp.load_from_json()
        datasets_asdiv = dataset_loader_asdiv.load_from_json()
        datasets = DatasetDict({
            'train': concatenate_datasets([datasets_svamp['train'], datasets_asdiv['train']]),
            'test': datasets_svamp['test']
        })
    else:
        datasets = dataset_loader.load_from_json()

    if args['llm'] is None:
        pass
    elif args['llm'] == 'palm':
        if args['dataset'] == 'asdiv':
            # training set = SVAMP training + ASDiv training
            train_llm_rationales_svamp, train_llm_labels_svamp = dataset_loader_svamp.load_llm_preds(split='train')
            train_llm_rationales_asdiv, train_llm_labels_asdiv = dataset_loader_asdiv.load_llm_preds(split='train')
            train_llm_rationales = train_llm_rationales_svamp + train_llm_rationales_asdiv
            train_llm_labels = train_llm_labels_svamp + train_llm_labels_asdiv
            # test set = SVAMP test
            test_llm_rationales, test_llm_labels = dataset_loader_svamp.load_llm_preds(split='test')
        else:
            train_llm_rationales, train_llm_labels = dataset_loader.load_llm_preds(split='train')
            test_llm_rationales, test_llm_labels = dataset_loader.load_llm_preds(split='test')
    elif args['llm'] == 'gpt':
        train_llm_rationales, train_llm_labels = dataset_loader.load_gpt_preds(split='train')
        test_llm_rationales, test_llm_labels = dataset_loader.load_gpt_preds(split='test')
    else:
        raise ValueError(f"unknown llm: {args['llm']!r}")

    if args['llm'] is not None:
        datasets['train'] = datasets['train'].add_column('llm_label', train_llm_labels)
        datasets['test'] = datasets['test'].add_column('llm_label', test_llm_labels)
        datasets['train'] = datasets['train'].add_column('llm_rationale', train_llm_rationales)
        datasets['test'] = datasets['test'].add_column('llm_rationale', test_llm_rationales)

    if args['subsample'] < 1.0:
        datasets['train'] = datasets['train'].train_test_split(test_size=1.0-args['subsample'], seed=args['run'])['train']

    if dataset_loader.has_valid:
        if args['llm'] is None:
            pass
        elif args['llm'] == 'palm':
            valid_llm_rationales, valid_llm_labels = dataset_loader.load_llm_preds(split='valid')
        elif args['llm'] == 'gpt':
            valid_llm_rationales, valid_llm_labels = dataset_loader.load_gpt_preds(split='valid')
        else:
            raise ValueError(f"unknown llm: {args['llm']!r}")

        # Attach the LLM columns only when an LLM was requested. Without this
        # guard the names below are undefined for llm=None and the call
        # failed with NameError.
        if args['llm'] is not None:
            datasets['valid'] = datasets['valid'].add_column('llm_label', valid_llm_labels)
            datasets['valid'] = datasets['valid'].add_column('llm_rationale', valid_llm_rationales)
    else:
        # No validation split on disk: hold out 10% of the training split.
        train_valid_datasets = datasets['train'].train_test_split(test_size=0.1, seed=0)

        datasets = DatasetDict({
            'train': train_valid_datasets['train'],
            'valid': train_valid_datasets['test'],
            'test': datasets['test'],
        })

    if args['label_type'] == 'gt':
        pass
    elif args['label_type'] == 'llm' and args['llm'] is not None:
        if args['dataset'] not in ['svamp', 'asdiv']:
            train_label_acc = compute_text_acc(datasets['train']['llm_label'], datasets['train']['label'])
            test_label_acc = compute_text_acc(datasets['test']['llm_label'], datasets['test']['label'])
            train_label_loss = compute_text_loss(datasets['train']['llm_label'], datasets['train']['label'])
            test_label_loss = compute_text_loss(datasets['test']['llm_label'], datasets['test']['label'])

        else:
            train_label_acc = compute_equation_acc(datasets['train']['llm_label'], datasets['train']['label'])
            test_label_acc = compute_equation_acc(datasets['test']['llm_label'], datasets['test']['label'])
            train_label_loss = compute_equation_loss(datasets['train']['llm_label'], datasets['train']['label'])
            test_label_loss = compute_equation_loss(datasets['test']['llm_label'], datasets['test']['label'])

        print(f'LLM Train Acc: {train_label_acc:.4f}')
        print(f'LLM Test Acc: {test_label_acc:.4f}')
        print(f'LLM Train Loss: {train_label_loss:.4f}')
        print(f'LLM Test Loss: {test_label_loss:.4f}')

        # Train on the LLM labels instead of the ground truth (training split only).
        datasets['train'] = datasets['train'].remove_columns('label')
        datasets['train'] = datasets['train'].add_column('label', datasets['train']['llm_label'])

    else:
        raise ValueError(
            f"unsupported label_type/llm combination: {args['label_type']!r} / {args['llm']!r}"
        )

    if args['llm'] is not None:
        if 'rationale' in datasets['train'].column_names:
            datasets = datasets.remove_columns('rationale')
        datasets = datasets.rename_column('llm_rationale', 'rationale')


    #### Prepare data for training
    tokenizer = AutoTokenizer.from_pretrained(args['from_pretrained'], use_fast=False)

    if 'nli' in args['dataset']:
        # NLI datasets: join premise and hypothesis with the EOS token to form the input.
        datasets = datasets.map(
            lambda example: {'input': tokenizer.eos_token.join([example['premise'], example['hypothesis']])},
            remove_columns=['premise', 'hypothesis'],
        )


    if args['model_type'] == 'task_prefix' and args['llm'] is not None:
        def tokenize_function(examples):
            model_inputs = tokenizer(['predict: ' + text for text in examples['input']], max_length=args['max_input_length'], truncation=True)
            expl_model_inputs = tokenizer(['explain: ' + text for text in examples['input']], max_length=args['max_input_length'], truncation=True)
            model_inputs['expl_input_ids'] = expl_model_inputs['input_ids']
            model_inputs['expl_attention_mask'] = expl_model_inputs['attention_mask']

            with tokenizer.as_target_tokenizer():
                label_output_encodings = tokenizer(examples['label'], max_length=256, truncation=True)
                rationale_output_encodings = tokenizer(examples['rationale'], max_length=256, truncation=True)

            model_inputs['labels'] = label_output_encodings['input_ids']
            model_inputs['aux_labels'] = rationale_output_encodings['input_ids']

            return model_inputs

    elif args['model_type'] == 'standard':
        def tokenize_function(examples):
            model_inputs = tokenizer(
                examples['input'],
                max_length=args['max_input_length'],
                truncation=True
            )

            with tokenizer.as_target_tokenizer():
                label_output_encodings = tokenizer(examples['label'], max_length=256, truncation=True)

            model_inputs['labels'] = label_output_encodings['input_ids']

            return model_inputs

    else:
        raise ValueError(
            f"unsupported model_type/llm combination: {args['model_type']!r} / {args['llm']!r}"
        )


    if args['llm'] is None:
        tokenized_datasets = datasets.map(
            tokenize_function,
            remove_columns=['input', 'label'],
            batched=True
        )
    else:
        tokenized_datasets = datasets.map(
            tokenize_function,
            remove_columns=['input', 'rationale', 'label', 'llm_label'],
            batched=True
        )


    # The standard trainer yields a single prediction array (the "_aux"
    # metrics); the task-prefix trainer yields a [pred, expl] pair.
    if args['model_type'] == 'standard':
        if args['dataset'] not in ['svamp', 'asdiv']:
            compute_metrics = compute_metrics_text_aux(tokenizer)
        else:
            compute_metrics = compute_metrics_equation_aux(tokenizer)

    else:
        if args['dataset'] not in ['svamp', 'asdiv']:
            compute_metrics = compute_metrics_text(tokenizer)
        else:
            compute_metrics = compute_metrics_equation(tokenizer)


    train_and_evaluate(args, args['run'], tokenizer, tokenized_datasets, compute_metrics)

## Standard Training

In [None]:
# Standard fine-tuning: plain Seq2SeqTrainer ('standard') on the dataset's own
# labels ('gt'), t5-small on anli1.
args = dict(
    dataset='anli1',
    subsample=1.0,
    alpha=0.5,
    max_steps=5_000,
    eval_steps=500,
    batch_size=1,
    optimizer_name='AdamW',
    lr=5e-5,
    run=0,
    from_pretrained='t5-small',
    label_type='gt',
    llm='palm',
    max_input_length=2048,
    grad_steps=1,
    local_rank=-1,
    gen_max_len=64,
    parallelize=True,
    model_type='standard',
    bf16=False,
    no_log=False,
    output_rationale=True,
)

run(args)

## Standard Knowledge Distillation Training

In [None]:
# Standard knowledge distillation: plain Seq2SeqTrainer ('standard') trained on
# the LLM-produced labels ('llm') instead of the dataset's own labels.
args = dict(
    dataset='anli1',
    subsample=1.0,
    alpha=0.5,
    max_steps=5_000,
    eval_steps=500,
    batch_size=1,
    optimizer_name='AdamW',
    lr=5e-5,
    run=0,
    from_pretrained='t5-small',
    label_type='llm',
    llm='palm',
    max_input_length=2048,
    grad_steps=1,
    local_rank=-1,
    gen_max_len=64,
    parallelize=True,
    model_type='standard',
    bf16=False,
    no_log=False,
    output_rationale=True,
)

run(args)

## Step-by-Step Distillation Training

In [None]:
# Step-by-step distillation: TaskPrefixTrainer ('task_prefix') trained on the
# LLM labels plus the LLM rationales as an auxiliary explanation task.
args = dict(
    dataset='anli1',
    subsample=1.0,
    alpha=0.5,
    max_steps=5_000,
    eval_steps=500,
    batch_size=1,
    optimizer_name='AdamW',
    lr=5e-5,
    run=0,
    from_pretrained='t5-small',
    label_type='llm',
    llm='palm',
    max_input_length=2048,
    grad_steps=1,
    local_rank=-1,
    gen_max_len=64,
    parallelize=True,
    model_type='task_prefix',
    bf16=False,
    no_log=False,
    output_rationale=True,
)

run(args)