# demo_multChoice

This demo code owes much to that provided by [Lucia Zheng here](https://github.com/reglab/casehold/tree/main/multiple_choice), and also to hints from this [Captum tutorial](https://captum.ai/tutorials/Bert_SQUAD_Interpret).  

It has been streamlined somewhat, e.g., by dropping some options (e.g., TensorFlow support and computing the pre-training loss).

## Preliminaries

Both data and models are [made available via the casehold account at Hugging Face](https://huggingface.co/casehold).

In general, you'll probably want to cache these files on some local machine.
Two parameters connect this code to your local environment:

- `DataDir`: containing training and testing examples
- `ModelPath`: points to the cached source for the models

These parameters are set below.

## package imports

In [None]:
from collections import defaultdict
import logging
import os
import sys
import warnings

from dataclasses import dataclass, field
from itertools import chain
from typing import Optional, Union
import evaluate
from functools import partial

import numpy as np

import torch
import datasets
from datasets import load_dataset
import datasets.utils
# Support for load_metric has been removed in datasets@3.0.0, see Release 3.0.0 · huggingface/datasets
# from datasets import load_metric

import platform
import torch

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from torch.nn import functional as F

import transformers
from transformers import (
    AutoConfig,
    AutoModelForMultipleChoice,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import  check_min_version
import transformers.utils

import captum
from captum.attr import (
    FeatureAblation, 
    ShapleyValues,
    LayerIntegratedGradients, 
    LLMAttribution, 
    LLMGradientAttribution, 
    TextTokenInput, 
    TextTemplateInput,
    ProductBaselines,
    LayerConductance,
    TokenReferenceBase
)
from captum.attr import visualization as viz

## Multiple choice utilities from https://github.com/reglab/casehold

In [None]:
from utils_casehold import *

## important constants

In [None]:
import socket

# Machine name; the host-configuration cell below chooses paths and device from it.
HOST = socket.gethostname()

logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.40.0.dev0")

# Task name -> data processor class (CaseHOLDProcessor comes from utils_casehold).
processors = {"casehold": CaseHOLDProcessor}

# CaseHOLD is a 5-way multiple-choice task.
NUM_MULTIPLE_CHOICE_LABELS = 5

# Debug-verbosity flag (not referenced in the code visible here).
Verbose = False

## utility functions

In [None]:
# https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9
def count_TrainParam(model: torch.nn.Module) -> int:
    """Return how many scalar parameter elements of ``model`` require gradients."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

def count_AllParam(model: torch.nn.Module) -> int:
    """Return the total number of parameter elements of ``model``.

    Unlike count_TrainParam, frozen parameters (``requires_grad == False``)
    are counted too.
    """
    return sum(param.numel() for param in model.parameters())

def tensor2scalar(t):
    """Detach ``t``, move it to host memory and return it as plain Python data.

    A 0-d tensor yields a Python number; higher-rank tensors yield nested lists.
    """
    host = t.detach().cpu()
    return host.numpy().tolist()

def tensor2list(t):
    """Return the first element along dim 0 of ``t`` as plain Python data.

    Typically used on a 1-element result such as ``prob.argmax(1)`` to get a
    Python int/float.
    """
    values = t.detach().cpu().numpy().tolist()
    return values[0]

def roundAttrib(atensor):
    '''Convert an attribution value to a 2-decimal string, used as a histogram key.

    ``atensor`` may be a 0-d tensor (converted via tensor2scalar) or a plain
    Python int/float. Values in (-0.005, 0) would format as '-0.00'; they are
    folded onto '0.00' so that zero has a single key.
    '''
    if isinstance(atensor, (int, float)):
        a = float(atensor)
    else:
        a = tensor2scalar(atensor)
    key = f'{a:.2f}'
    if key == '-0.00':
        key = '0.00'
    return key


## Transformer classes/functions

These patterns seem to be common across transformer demos.

In [None]:
@dataclass
class DataTrainingArguments:
    """
    Data-side options parsed by HfArgumentParser: which task to run, where its
    files live, and how inputs are tokenized.

    NOTE(review): max_train_samples is declared but not applied anywhere in the
    code visible here.
    """

    # Required; must name a key of the module-level `processors` table.
    task_name: str = field(
        metadata={"help": "The name of the task to train on: " + ", ".join(processors)}
    )
    # Required; directory holding the task's data files.
    data_dir: str = field(
        metadata={"help": "Should contain the data files for the task."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )

    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
        },
    )

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Parsed from the `args` list by HfArgumentParser together with
    DataTrainingArguments and TrainingArguments; each field becomes a
    `--<field_name>` option whose help text is the `help` metadata.
    NOTE(review): the exact annotations matter to HfArgumentParser (e.g. a
    `bool` field with default None is handled differently from Optional[bool]),
    so confirm against the parser before changing any of them.
    """

    # Required: local directory or Hub id. Also used as the fallback when
    # config_name / tokenizer_name are not given.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # NOTE(review): annotated `str` but defaults to None; Optional[str] would be
    # the more accurate annotation.
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # Deprecated alias of `token`; not referenced in the code visible here.
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    # Forwarded to the config/tokenizer/model from_pretrained() calls.
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    
# from https://github.com/reglab/casehold/blob/main/multiple_choice/run_multiple_choice.py
# Define custom compute_metrics function, returns macro F1 metric for CaseHOLD task
# Loaded once here rather than inside compute_metrics_f1 on every evaluation.
EvalF1Metric = evaluate.load('f1')
def compute_metrics_f1(p: EvalPrediction):
    """Macro-averaged F1 of the highest-scoring choice against the gold labels.

    Each row of ``p.predictions`` holds per-choice scores for one example; the
    argmax column is taken as the predicted choice. Returns the dict produced
    by the ``evaluate`` F1 metric.
    """
    chosen = np.argmax(p.predictions, axis=1)
    return EvalF1Metric.compute(
        predictions=chosen,
        references=p.label_ids,
        average='macro',
    )

# from https://captum.ai/tutorials/Bert_SQUAD_Interpret
def multChoice_forward(model, inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    """Run ``model`` and return the softmax over its logits along dim 1.

    Intended as a captum forward function for multiple-choice models: each row
    of the result is a probability distribution over the answer choices.
    """
    logits = model(
        inputs,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
    ).logits
    return logits.softmax(dim=1)


## Captum hooks

Most of the work is done in the `captumIG` function. It assumes it is given a model, its tokenizer, a list of test examples, and a parallel list of their target labels. It writes several result files to an output directory.

It makes use of a function `vis_text2`, which is a modified version of captum's [`visualization.visualize_text`](https://github.com/pytorch/captum/blob/master/captum/attr/_utils/visualization.py) function, "...partially copied from experiments conducted by Davide Testuggine at Facebook."


In [None]:
def vis_text2(
    datarecords: "Iterable[viz.VisualizationDataRecord]", idx, legend: bool = False
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    """Render attribution records as HTML table rows and display them inline.

    A trimmed copy of captum's ``visualization.visualize_text``. Differences:
      * every row starts with an extra cell holding ``idx`` (the example number);
      * no ``<table>`` wrapper or header row is emitted; the caller supplies them;
      * the legend is off by default.

    Args:
        datarecords: records to render, one table row each.
        idx: value shown in the leading "Idx" cell of every row.
        legend: when True, prepend the Negative/Neutral/Positive colour legend.

    Returns:
        The IPython ``HTML`` object that was displayed; its ``.data`` is the markup.
    """
    # IPython's HTML/display and the HAS_IPYTHON flag are taken from captum's
    # visualization module, which imports them itself; they are not imported here.
    assert getattr(viz, "HAS_IPYTHON", False), (
        "IPython must be available to visualize text. "
        "Please run 'pip install ipython'."
    )
    dom = []

    rows = []
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    f'<td>{idx}</td>',
                    viz.format_classname(datarecord.true_class),
                    viz.format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.pred_class, datarecord.pred_prob
                        )
                    ),
                    viz.format_classname(datarecord.attr_class),
                    viz.format_classname("{0:.2f}".format(datarecord.attr_score)),
                    viz.format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    # newline escape; "/n" previously leaked literal text between rows
                    "</tr>\n",
                ]
            )
        )

    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; \
            padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label}  '.format(
                    value=viz._get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append("".join(rows))
    html = viz.HTML("".join(dom))
    viz.display(html)

    return html

# https://github.com/pytorch/captum/blob/master/tutorials/Bert_SQUAD_Interpret.ipynb
def summarize_attributions(attributions):
    """Collapse the last (embedding) dimension by summation and L2-normalize.

    A leading singleton dimension is dropped by ``squeeze(0)``; the norm is
    taken over every remaining element at once.
    """
    per_token = attributions.sum(dim=-1).squeeze(0)
    return per_token / torch.norm(per_token)

def _orderedAttribKeys(tokFreq):
    '''Return every rounded-attribution key in tokFreq, in ascending numeric order.

    tokFreq maps token id -> {attribution string (from roundAttrib) -> count}.
    A numeric sort replaces the former lexicographic sort plus manual reversal of
    the negative keys, which emitted every key twice when no key was non-negative.
    Ties such as '-0.00' vs '0.00' are broken by the string ('-0.00' first) so the
    column order is deterministic; any 'nan' keys go last.
    '''
    def sortKey(a):
        v = float(a)
        isNan = v != v  # NaN is the only float that is not equal to itself
        return (isNan, 0.0 if isNan else v, a)

    return sorted({a for freqs in tokFreq.values() for a in freqs}, key=sortKey)


def _writeTokFreq(tokenizer, tokFreq, allAttrib, path, signifThresh=0.05):
    '''Write the per-token attribution histogram CSV to path.

    Columns: Idx (token id); Token (quoted, CSV-escaped); SigNeg / SigPos (number
    of DISTINCT rounded values below -signifThresh / above +signifThresh);
    NUAttrib (number of distinct rounded values seen for the token); one count
    column per key in allAttrib (blank when absent); TotFreq (sum over keys of
    count * |value|).
    '''
    with open(path, 'w', encoding='utf-8') as outs:
        outs.write('Idx,Token,SigNeg,SigPos,NUAttrib'
                   + ''.join(f',{a}' for a in allAttrib)
                   + ',TotFreq\n')
        for tokIdx in sorted(tokFreq):
            freqs = tokFreq[tokIdx]
            # double any embedded quote so the quoted CSV field stays well-formed
            tok = tokenizer.convert_ids_to_tokens([tokIdx])[0].replace('"', '""')
            nsignifNeg = sum(1 for a in freqs if float(a) < -signifThresh)
            nsignifPos = sum(1 for a in freqs if float(a) > signifThresh)
            cells = [str(tokIdx), f'"{tok}"', str(nsignifNeg), str(nsignifPos), str(len(freqs))]
            tot = 0
            for attrib in allAttrib:
                # `in` does not create entries in the defaultdict
                if attrib in freqs:
                    cells.append(str(freqs[attrib]))
                    tot += freqs[attrib] * abs(float(attrib))
                else:
                    cells.append(' ')
            cells.append(str(tot))
            outs.write(','.join(cells) + '\n')


def captumIG(model,tokenizer,testList,targetList,outdir, n_steps=5):
    '''captum integrated gradient attribution for multiple choice

    For each example dict in testList (tensors 'input_ids', 'attention_mask',
    'token_type_ids') and the parallel entry of targetList (a tensor holding the
    correct choice index, as built by the egTarget cell), this:
      * runs the model to get the predicted choice and its probability;
      * computes LayerIntegratedGradients attributions over model.bert.embeddings
        (so a BERT-style model is assumed) for the TARGET choice's probability;
      * keeps the attribution row for the PREDICTED choice's tokens.
    It writes into outdir:
      * stats.csv   - one row per example
      * viz.html    - HTML table of all examples (each also displayed inline)
      * tokfreq.csv - per-token histogram of rounded attribution values

    n_steps is the number of integration steps passed to captum (default 5).
    The model is switched to eval mode for the duration, so dropout does not
    perturb predictions or attributions, and its previous mode is restored.

    Returns the complete HTML table string (also written to viz.html).
    '''
    # Captum calls the forward function with the input tuple unpacked positionally,
    # i.e. in tstEGTuple order (input_ids, attention_mask, token_type_ids), so map
    # them onto multChoice_forward by keyword. Binding it positionally would feed the
    # attention mask in as token_type_ids and the token types as position_ids.
    def mcForward(input_ids, attention_mask, token_type_ids):
        return multChoice_forward(model, input_ids,
                                  token_type_ids=token_type_ids,
                                  attention_mask=attention_mask)

    lig = LayerIntegratedGradients(mcForward, model.bert.embeddings)

    CaptumTblHdr = '''<table width: 100%>
      <tr><th>Idx</th><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th>'''
    CaptumTblTrlr = '</table>\n'

    tokFreq = defaultdict(lambda: defaultdict(int)) # token id -> roundAttrib string -> freq
    htmlParts = [CaptumTblHdr]

    wasTraining = model.training
    model.eval()
    try:
        with open(os.path.join(outdir, 'stats.csv'), 'w', encoding='utf-8') as outs:
            outs.write('Idx,target,predLbl,correct,predProb,attrSum\n')

            for egIdx, tst in enumerate(testList):
                # NB: testList and targetList are parallel
                targetIdx = tensor2scalar(targetList[egIdx])

                model.zero_grad()
                # prediction only; lig.attribute computes its own gradients below
                with torch.no_grad():
                    out = model(**tst)

                prob = torch.softmax(out.logits, 1)
                predIdx = tensor2list(prob.argmax(1))
                predProb = tensor2list(prob.max(1)[0])
                correct = 1 if predIdx == targetIdx else 0
                print(f'captumIG: {egIdx=} {targetIdx=} {predIdx=} {correct} {predProb=}')

                tstEGTuple = (tst['input_ids'],
                              tst['attention_mask'],
                              tst['token_type_ids'])
                attributions_ig = lig.attribute(tstEGTuple, n_steps=n_steps, target=targetIdx)

                # attributions_sum has shape [5,128]
                attributions_sum = summarize_attributions(attributions_ig)
                # focus on the predicted choice's tokens
                predAttrib = attributions_sum[predIdx]
                attrSum = predAttrib.sum()

                # token ids of the predicted choice for the single (first) example
                indices = tst['input_ids'][0][predIdx].detach().tolist()
                all_tokens = tokenizer.convert_ids_to_tokens(indices)

                # collect token attributions across all examples, skipping padding
                for i, tokIdx in enumerate(indices):
                    if tokIdx == PAD_IDX:
                        continue
                    tokFreq[tokIdx][roundAttrib(predAttrib[i])] += 1

                delta_start = 0.
                # cf. viz.VisualizationDataRecord.__init__()
                vis = viz.VisualizationDataRecord(
                    predAttrib,       # word_attributions
                    predProb,         # pred_prob
                    predIdx,          # pred_class
                    targetIdx,        # true_class
                    str(targetIdx),   # attr_class
                    attrSum,          # attr_score
                    all_tokens,       # raw_input_ids
                    delta_start)      # convergence_score
                htmlParts.append(vis_text2([vis], egIdx).data)

                outs.write(f'{egIdx},{targetIdx},{predIdx},{correct},"{predProb}",{attrSum}\n')
    finally:
        model.train(wasTraining)

    htmlParts.append(CaptumTblTrlr)
    allHTML = ''.join(htmlParts)

    vizPath = os.path.join(outdir, 'viz.html')
    with open(vizPath, 'w', encoding='utf-8') as outs:
        outs.write(allHTML)
    print(f'# viz written to {vizPath}')

    allAttrib = _orderedAttribKeys(tokFreq)
    print(f'captumEG: Total attributes={len(allAttrib)}')
    _writeTokFreq(tokenizer, tokFreq, allAttrib, os.path.join(outdir, 'tokfreq.csv'))
    print('captumEG: done')

    return allHTML

## Ready to get some work done!

### Different machines

I use a couple of different machines. The default for this demo assumes Apple's MPS backend (`platform.platform()="macOS-14.7-arm64-arm-64bit"`) with `DEVICE='mps'`. It has also been run on a Linux machine with an NVIDIA GPU (`DEVICE='cuda'`).

Modify `BatchSize` to fit your GPU capabilities.

In [None]:
# Machine-specific device, paths and batch size, selected by host name.
if HOST.startswith('wayne'):
    # DEVICE='cpu'
    DEVICE='mps'
    DataDir = '/Users/rik/.cache/'
    OutDir = '/Users/rik/data/ai4law/'
    ModelPath = '/Users/rik/.cache/casehold-models/'
    BatchSize = 16
elif HOST=='mjq':
    DEVICE='cuda'
    DataDir =  '/home/rik/.cache/'
    OutDir = '/home/Data/ai4law/'
    ModelPath = '/home/Data/ai4lawData/casehold-models/'
    BatchSize = 12
else:
    # raise rather than `assert False`: asserts are stripped under `python -O`,
    # which would let execution continue to a confusing NameError later on.
    raise RuntimeError(f'Unknown host?! {HOST}')

### Echo some platform specific details

In [None]:
# Record where, and with which library versions, this run executed.
print(f'# {HOST=} {DEVICE=}')
print(f'# platform={platform.platform()}')  # e.g. macOS-14.7-arm64-arm-64bit
mpsAvailable = torch.backends.mps.is_available()
print(f'# MPS backend: {mpsAvailable}')
for pkg in (torch, transformers, captum):
    print(f'#\t {pkg.__name__} {pkg.__version__}')

### Identify model and other parameters for transformer

To make this Python module self-contained, the arguments that the casehold demo passed on the shell command line are explicitly constructed here as a list.

This demo assumes locally cached models.

Note that the data cache is being over-written, forcing the SWAG-style `MultipleChoiceDataset.convert_examples_to_features()` to be rerun.

In [None]:
# Which cached model to fine-tune: 'legalbert' 'custom-legalbert' 'bert-double'
ModelName = 'legalbert'
ModelPath = f'{ModelPath}{ModelName}/'

DataSet = DataSetTags = 'casehold'
DataDir = f'{DataDir}datasets/casehold___casehold/all/direct/'

# 'train' adds --do_train to the arguments below; 'useTrained' omits it.
Trained = 'train'

RunName = f'{ModelName}_{DataSet}_{Trained}'
OutDir = f'{OutDir}{RunName}/'

print(f'# {ModelPath=}\n# {RunName=}\n# {DataDir=}\n# {OutDir=} {Trained=}')

`args` is used as a module-level parameter list, rather than being passed on the command line to a Python script as in the [casehold demo](https://github.com/reglab/casehold/blob/main/demo.ipynb). `HfArgumentParser` is used to split these into separate argument sets for the model, data, and training.

In [None]:
# The command line the casehold demo would pass, as a flat list of tokens.
args = ['--do_train'] if Trained == 'train' else []
args += [
    '--model_name_or_path', ModelPath,
    '--data_dir', DataDir,
    '--task_name', 'casehold',
    '--output_dir', OutDir,
    '--overwrite_cache=True',
    '--max_seq_length', '128',
    '--do_eval',
    '--eval_strategy', 'steps',
    f'--per_device_train_batch_size={BatchSize}',
    f'--per_device_eval_batch_size={BatchSize}',
    '--learning_rate=1e-5',
    '--num_train_epochs=2.0',
    '--overwrite_output_dir=True',
    '--logging_steps', '50',
]

# Split the flat list into the model / data / training argument dataclasses.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(args)


### Set up logging, checkpoints, config 

In [None]:
# exist_ok makes this idempotent across re-runs, with no separate exists() check.
os.makedirs(OutDir, exist_ok=True)

# Setup logging: records go to <OutDir><RunName>.log.
# NOTE(review): OutDir is also --output_dir. If basicConfig installs this file
# handler, the directory is no longer empty, so the "already exists and is not
# empty" check below would raise (when no checkpoint exists) if
# --overwrite_output_dir were dropped from `args`. It is currently passed, so
# that branch is skipped.
logging.basicConfig(
    filename=OutDir+f'{RunName}.log',
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)

if training_args.should_log:
    # The default of training_args.log_level is passive, so we set log level at info here to have that default.
    transformers.utils.logging.set_verbosity_info()

log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log on each process the small summary:
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )
    elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )

# Set seed before initializing model.
set_seed(training_args.seed)

config = AutoConfig.from_pretrained(
    model_args.config_name if model_args.config_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    token=model_args.token,
    trust_remote_code=model_args.trust_remote_code,
)


### Get the model and its tokenizer

In [None]:
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    use_fast=model_args.use_fast_tokenizer,
    revision=model_args.model_revision,
    token=model_args.token,
    trust_remote_code=model_args.trust_remote_code,
)

model = AutoModelForMultipleChoice.from_pretrained(
    model_args.model_name_or_path,
    from_tf=bool(".ckpt" in model_args.model_name_or_path),
    config=config,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    token=model_args.token,
    trust_remote_code=model_args.trust_remote_code,
)
model.to(DEVICE)

# Special-token ids. Module-level assignments are already global, so no
# `global` statement is needed; captumIG() reads PAD_IDX to skip padding tokens.
PAD_IDX = tokenizer.pad_token_id
CLS_IDX = tokenizer.cls_token_id
SEP_IDX = tokenizer.sep_token_id

# Work around "ValueError: You are trying to save a non contiguous tensor" when saving:
# https://github.com/huggingface/transformers/issues/28293#issuecomment-2284567863
for param in model.parameters():
    param.data = param.data.contiguous()

if data_args.max_seq_length > tokenizer.model_max_length:
    logger.warning(
        f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
        f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
    )
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)


### Establish training and evaluation datasets

Note that the cache is being over-written, forcing the SWAG-style `MultipleChoiceDataset.convert_examples_to_features()` to be rerun.

In [None]:
# Use the clamped `max_seq_length` (the smaller of the requested length and the
# tokenizer's model_max_length) rather than the raw data_args value, so the
# "Using max_seq_length=..." warning in the previous cell is actually honored.
train_dataset = MultipleChoiceDataset(
    data_dir=data_args.data_dir,
    tokenizer=tokenizer,
    task=data_args.task_name,
    max_seq_length=max_seq_length,
    overwrite_cache=data_args.overwrite_cache,
    mode=Split.train,
)

eval_dataset = MultipleChoiceDataset(
    data_dir=data_args.data_dir,
    tokenizer=tokenizer,
    task=data_args.task_name,
    max_seq_length=max_seq_length,
    overwrite_cache=data_args.overwrite_cache,
    mode=Split.test,
)


For TESTING, you can select a subset of `eval_dataset`, as in the commented-out code below.

In [None]:
# caseholdSamples = [0,1,2,3,4]
# tstEG = [eval_dataset[i] for i in caseholdSamples]

### Create modified versions of examples for various purposes

In [None]:
inBits = ['input_ids','attention_mask','token_type_ids']

# One dict of tensors per eval example, usable as model(**d); unsqueeze(0)
# adds a leading batch dimension of size 1.
eval_TensorDict = [
    {k: torch.tensor(getattr(eg, k), device=DEVICE).unsqueeze(0) for k in inBits}
    for eg in eval_dataset
]

# The same tensors as tuples, in inBits order (input_ids, attention_mask,
# token_type_ids), for SummaryWriter.add_graph(), which takes a tuple of input
# tensors. NOTE(review): these are passed positionally, so the order is assumed
# to match the model's forward() signature; verify if the model changes.
eval_TensorTuple = [tuple(d[k] for k in inBits) for d in eval_TensorDict]


### Create vector of examples' target labels

In [None]:
# idxTbl: position in eval_dataset -> example_id.
# egTarget: per-example target for captum, in eval_dataset order.
idxTbl = {}
egTarget = []
for pos, example in enumerate(eval_dataset):
    idxTbl[pos] = example.example_id
    # captum's `target` is an output index (here the correct choice), not a
    # one-hot probability vector; see captum IntegratedGradients.attribute.
    egTarget.append(torch.tensor(example.label, device=DEVICE))

### Produce details of network for use by tensorboard

In [None]:
# Export the model graph for TensorBoard, traced on the first eval example.
# NB: add_graph takes a tuple of input tensor[s] for the model.
# The context manager closes (flushes) the writer even if tracing raises.
with SummaryWriter(OutDir + 'runs/summary/') as writer:
    writer.add_graph(model, eval_TensorTuple[0], use_strict_trace=False)

### Train the model!

This takes several hours on my machines.  The trainer also produces periodic evaluations of the model every 50 steps (batches of training examples) during the training process.

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics_f1,
)

# Fine-tune only when --do_train was requested (Trained == 'train'); the trainer
# is still built either way so later cells (model card) can use it.
if training_args.do_train:
    # An explicit --resume_from_checkpoint wins over one auto-detected in output_dir.
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)

    trainer.save_model()
    # No tokenizer is passed to Trainer, so save_model() does not save it; save it
    # explicitly (main process only) so output_dir holds a loadable model + tokenizer.
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)

    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)

    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()


### Document our work

In [None]:
# Metadata passed to Trainer.create_model_card().
kwargs = dict(
    finetuned_from=model_args.model_name_or_path,
    tasks="multiple-choice",
    dataset_tags=DataSetTags,
    dataset_args="regular",
    dataset=DataSet,
    language="en",
)
trainer.create_model_card(**kwargs)

### Apply captum_IG to test examples

This produces the visualized annotation of all sentences (inline in the notebook) and several files in `OutDir`:
  * `viz.html` a file with all of these annotations
  * `stats.csv`, a file with summary statistics for each test example, and
  * `tokfreq.csv` a file with the distribution of tokens' attribution scores across all test examples

In [None]:
captumIG(model, tokenizer, testList=eval_TensorDict, targetList=egTarget, outdir=OutDir)