In [None]:
!nvidia-smi
!pip install datasets evaluate numpy torch accelerate tqdm

In [None]:
#!/usr/bin/env python
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning a 🤗 Transformers model for question answering using 🤗 Accelerate.
"""
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.

import argparse
import collections
import json
import logging
import math
import os
import random
from pathlib import Path
from types import SimpleNamespace
from typing import Optional

import datasets
import evaluate
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    SchedulerType,
    default_data_collator,
    get_scheduler,
)

In [None]:
# Run configuration. Kept in a SimpleNamespace so the cells below can read settings as `args.<name>`.
args = SimpleNamespace(
    # Input JSON files: training split, validation split and the shared context paragraphs.
    train_file="/kaggle/input/ntu-adl-2025-hw-1/train.json",
    validation_file="/kaggle/input/ntu-adl-2025-hw-1/valid.json",
    context_file="/kaggle/input/ntu-adl-2025-hw-1/context.json",
    # Upper bound on the tokenized length of one feature (further capped by tokenizer.model_max_length).
    max_seq_length=512,
    # When False, features are padded per batch by the data collator instead of to max_seq_length.
    pad_to_max_length=False,
    # Checkpoint passed to AutoConfig / AutoTokenizer / AutoModelForQuestionAnswering.
    model_name_or_path="hfl/chinese-macbert-base",
    # Batch sizes handed to the train / eval DataLoaders.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    # Learning rate given to AdamW.
    learning_rate=3e-5,
    # Number of epochs; used to derive max_train_steps when that is None.
    num_train_epochs=2,
    max_train_steps=None,
    # Passed to the Accelerator and used to compute the number of optimizer steps per epoch.
    gradient_accumulation_steps=2,
    lr_scheduler_type=SchedulerType.LINEAR, # choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
    # Directory that receives prediction files, the saved model and the tokenizer.
    output_dir="/kaggle/working/",
    # Seed passed to accelerate's set_seed.
    seed=1234,
    # Passed as `stride` to the tokenizer when a long context is split into several features.
    doc_stride=128,
    # Post-processing limits: number of candidate start/end logits and maximum answer length in tokens.
    n_best_size=20,
    max_answer_length=30,
)

print(args)

In [None]:
def postprocess_qa_predictions(
    examples,
    features,
    predictions: tuple[np.ndarray, np.ndarray],
    version_2_with_negative: bool = False,
    n_best_size: int = 20,
    max_answer_length: int = 30,
    null_score_diff_threshold: float = 0.0,
    output_dir: Optional[str] = None,
    prefix: Optional[str] = None,
    log_level: Optional[int] = logging.WARNING,
):
    """
    Convert per-feature start/end logits into one answer string per example.

    Args:
        examples: The un-tokenized examples. Must support ``examples["id"]`` (all ids, in order), ``len()`` and
            iteration yielding one mapping per example with ``"id"`` and ``"context"`` keys.
        features: The tokenized features, indexable by position. Each feature provides ``"example_id"`` and
            ``"offset_mapping"`` and may provide ``"token_is_max_context"``.
        predictions: ``(start_logits, end_logits)``; each must have one row per feature.
        version_2_with_negative: Accepted for interface compatibility; not used by this implementation.
        n_best_size: Number of top start logits and top end logits combined per feature, and the number of
            candidates kept per example.
        max_answer_length: Maximum answer length in tokens (inclusive).
        null_score_diff_threshold: Accepted for interface compatibility; not used by this implementation.
        output_dir: If not None, an existing directory in which the best predictions and the n-best lists are
            written as JSON.
        prefix: Optional prefix for the two output file names.
        log_level: Accepted for interface compatibility; not used by this implementation.

    Returns:
        An ``OrderedDict`` mapping each example id to its best answer text.

    Raises:
        ValueError: If ``predictions`` does not hold exactly two elements, or the number of prediction rows
            differs from the number of features.
        OSError: If ``output_dir`` is given but is not a directory.
    """
    if len(predictions) != 2:
        raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).")
    all_start_logits, all_end_logits = predictions

    if len(predictions[0]) != len(features):
        raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")

    # Map every example (by position) to the positions of the features that were cut from it.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()

    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    for example_index, example in enumerate(tqdm(examples)):
        prelim_predictions = []

        # Collect candidate spans from every feature associated with the current example.
        for feature_index in features_per_example[example_index]:
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # Look the feature up once; it supplies both the offsets and the optional max-context flags.
            feature = features[feature_index]
            # Maps token positions back to character spans of the original context.
            offset_mapping = feature["offset_mapping"]
            # Optional: if provided, answers whose start token lacks maximum context are dropped.
            token_is_max_context = feature.get("token_is_max_context", None)

            # Indices of the `n_best_size` largest start and end logits, best first.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip out-of-scope answers: indices beyond the feature, or tokens without a usable offset
                    # (i.e. not part of the context).
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or len(offset_mapping[start_index]) < 2
                        or offset_mapping[end_index] is None
                        or len(offset_mapping[end_index]) < 2
                    ):
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    # Skip answers that don't have the maximum context available (if that info is provided).
                    if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                        continue

                    prelim_predictions.append(
                        {
                            "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                            "score": start_logits[start_index] + end_logits[end_index],
                            "start_logit": start_logits[start_index],
                            "end_logit": end_logits[end_index],
                        }
                    )

        # Only keep the best `n_best_size` candidates. A separate name is used so the `predictions`
        # argument is not shadowed.
        nbest = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

        # Use the offsets to gather the answer text in the original context.
        context = example["context"]
        for pred in nbest:
            offsets = pred.pop("offsets")
            pred["text"] = context[offsets[0] : offsets[1]]

        # If there is no candidate at all (or only an empty one), insert a fake prediction to avoid failure.
        if len(nbest) == 0 or (len(nbest) == 1 and nbest[0]["text"] == ""):
            nbest.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})

        # Softmax over the candidate scores, subtracting the max first for numerical stability.
        scores = np.array([pred.pop("score") for pred in nbest])
        exp_scores = np.exp(scores - np.max(scores))
        probs = exp_scores / exp_scores.sum()

        for prob, pred in zip(probs, nbest):
            pred["probability"] = prob

        # The best candidate is the answer for this example.
        all_predictions[example["id"]] = nbest[0]["text"]

        # Make the n-best list JSON-serializable by casting any numpy float back to a Python float.
        all_nbest_json[example["id"]] = [
            {k: (float(v) if isinstance(v, np.floating) else v) for k, v in pred.items()} for pred in nbest
        ]

    # If we have an output_dir, save both dictionaries.
    if output_dir is not None:
        if not os.path.isdir(output_dir):
            raise OSError(f"{output_dir} is not a directory.")

        prediction_file = os.path.join(
            output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
        )
        nbest_file = os.path.join(
            output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
        )

        # Write UTF-8 without ASCII escaping so non-ASCII answer text stays readable in the files.
        print(f"Saving predictions to {prediction_file}.")
        with open(prediction_file, "w", encoding="utf-8") as writer:
            writer.write(json.dumps(all_predictions, indent=4, ensure_ascii=False) + "\n")
        print(f"Saving nbest_preds to {nbest_file}.")
        with open(nbest_file, "w", encoding="utf-8") as writer:
            writer.write(json.dumps(all_nbest_json, indent=4, ensure_ascii=False) + "\n")

    return all_predictions


In [None]:
def save_prefixed_metrics(results, output_dir, file_name: str = "all_results.json", metric_key_prefix: str = "eval"):
    """
    Prefix the metric names in ``results`` and dump the dictionary to a JSON file.

    ``results`` is modified in place: every key that does not already start with
    ``metric_key_prefix + "_"`` is removed and re-inserted under the prefixed name.

    Args:
        results (dict): Metric name -> value.
        output_dir (str): Directory the file is written to.
        file_name (str, optional): Name of the output file. Defaults to ``all_results.json``.
        metric_key_prefix (str, optional): Prefix for metric names. Defaults to ``eval``.
    """
    marker = f"{metric_key_prefix}_"
    # Decide what to rename from a snapshot of the keys, since the dict is mutated below.
    needs_prefix = [name for name in list(results.keys()) if not name.startswith(marker)]
    for name in needs_prefix:
        results[f"{metric_key_prefix}_{name}"] = results.pop(name)

    target_path = os.path.join(output_dir, file_name)
    with open(target_path, "w") as f:
        json.dump(results, f, indent=4)

In [None]:
# Create the Accelerator, configured with the requested number of gradient accumulation steps.
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)

# Seed everything up front when a seed is configured.
if args.seed is not None:
    set_seed(args.seed)

# Only the main process creates the output directory; every process then waits at the barrier.
if accelerator.is_main_process and args.output_dir is not None:
    os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()

In [None]:
# NOTE(review): `json` and `datasets` are already imported in the imports cell at the top of the notebook,
# so these two re-imports are redundant.
import json
import datasets

# Load the context paragraphs.
with open(args.context_file, "r", encoding="utf-8") as f:
    contexts = json.load(f)

def load_span_selection(filename, contexts):
    """
    Read a span-selection JSON file and return its records as a list of dicts.

    Args:
        filename: Path of a UTF-8 JSON file holding a list of records. Each record provides ``"id"``,
            ``"question"``, ``"relevant"`` and ``"answer"`` (with ``"text"`` and ``"start"``).
        contexts: Container of paragraphs, indexed by a record's ``"relevant"`` value.

    Returns:
        A list with one dict per record, holding ``"id"``, ``"question"``, ``"context"`` and ``"answers"``;
        ``"answers"`` wraps the single answer text and start offset in one-element lists.
    """

    def to_example(record):
        # Look up the paragraph through the record's relevant id.
        paragraph = contexts[record["relevant"]]
        answer = record["answer"]
        text, start = answer["text"], answer["start"]
        return {
            "id": record["id"],
            "question": record["question"],
            "context": paragraph,
            "answers": {"text": [text], "answer_start": [start]},
        }

    with open(filename, "r", encoding="utf-8") as f:
        records = json.load(f)

    return [to_example(record) for record in records]

# === Build the DatasetDict ===
# Only splits whose file is configured are loaded; "train" comes before "validation".
dataset_splits = {
    split: load_span_selection(path, contexts)
    for split, path in (("train", args.train_file), ("validation", args.validation_file))
    if path is not None
}

raw_datasets = datasets.DatasetDict(
    {split: datasets.Dataset.from_list(data) for split, data in dataset_splits.items()}
)

print(raw_datasets)
print(raw_datasets["train"][0])

In [None]:
# ========= Load pretrained model and tokenizer =========
config = AutoConfig.from_pretrained(args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
model = AutoModelForQuestionAnswering.from_pretrained(
    args.model_name_or_path,
    from_tf=".ckpt" in args.model_name_or_path,
    config=config,
)

# ========= Dataset column names =========
# Prefer the conventional column names; otherwise fall back to a fixed column position.
column_names = raw_datasets["train"].column_names

if "question" in column_names:
    question_column_name = "question"
else:
    question_column_name = column_names[0]

if "context" in column_names:
    context_column_name = "context"
else:
    context_column_name = column_names[1]

if "answers" in column_names:
    answer_column_name = "answers"
else:
    answer_column_name = column_names[2]

# The padding side decides whether the question or the context is passed first to the tokenizer.
pad_on_right = tokenizer.padding_side == "right"
# Never ask for more tokens than the tokenizer reports as its maximum.
max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)

In [None]:
# ========= Preprocessing (Train) =========
def prepare_train_features(examples):
    """
    Tokenize a batch of training examples and label every resulting feature with answer token positions.

    Relies on the notebook-level names `tokenizer`, `args`, `pad_on_right`, `max_seq_length` and the
    `*_column_name` variables.

    Args:
        examples: A batch of examples in columnar form (column name -> list of values). The question
            column is overwritten in place with left-stripped questions.

    Returns:
        The tokenizer output with `overflow_to_sample_mapping` and `offset_mapping` removed and two added
        lists, `start_positions` and `end_positions`, holding one token index per feature. Features that do
        not fully contain the answer (or whose example has no answer) are labelled with the CLS index.
    """
    # Strip leading whitespace from the questions.
    examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]

    # Tokenize question/context pairs; the argument order follows the tokenizer's padding side and only the
    # context is truncated. Overflowing tokens are requested with a stride of `args.doc_stride`, and
    # character offsets are returned for every token.
    tokenized_examples = tokenizer(
        examples[question_column_name if pad_on_right else context_column_name],
        examples[context_column_name if pad_on_right else question_column_name],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_seq_length,
        stride=args.doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length" if args.pad_to_max_length else False,
    )

    # For each feature, the index of the example it came from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # For each feature, the (start_char, end_char) pair of every token.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        # Position of the CLS token (0 if absent); used as the label when the answer is not in this feature.
        cls_index = input_ids.index(tokenizer.cls_token_id) if tokenizer.cls_token_id in input_ids else 0
        # Per-token sequence ids, used below to locate the context tokens.
        sequence_ids = tokenized_examples.sequence_ids(i)
        sample_index = sample_mapping[i]
        answers = examples[answer_column_name][sample_index]

        if len(answers["answer_start"]) == 0:
            # No answer given: label the feature with the CLS index.
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Character span of the first answer within the context.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # First token of the context in this feature.
            token_start_index = next(idx for idx, s in enumerate(sequence_ids) if s == (1 if pad_on_right else 0))
            # Last token of the context in this feature: walk back from the end over non-context tokens.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # If the answer is not fully inside this feature's context span, label it with the CLS index.
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Advance past every token starting at or before start_char; the token just before the stop
                # point is the answer's first token.
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)

                # Walk back over every token ending at or after end_char; the token just after the stop
                # point is the answer's last token.
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

# Tokenize the training split in batches. The raw columns are removed, so the result holds only what
# prepare_train_features returns.
# NOTE(review): main_process_first() presumably lets the main process run the map before the other
# processes do - confirm against the Accelerate documentation.
with accelerator.main_process_first():
    train_dataset = raw_datasets["train"].map(
        prepare_train_features,
        batched=True,
        remove_columns=column_names,
        desc="Running tokenizer on train dataset",
    )

# ========= Preprocessing (Validation) =========
def prepare_validation_features(examples):
    """
    Tokenize a batch of validation examples and keep what post-processing needs.

    Relies on the notebook-level names `tokenizer`, `args`, `pad_on_right`, `max_seq_length` and the
    `*_column_name` variables.

    Args:
        examples: A batch of examples in columnar form (column name -> list of values). The question
            column is overwritten in place with left-stripped questions.

    Returns:
        The tokenizer output without `overflow_to_sample_mapping`, plus an `example_id` list (the id of
        the example each feature came from). In `offset_mapping`, every token that is not part of the
        context has its offset replaced by None.
    """
    examples[question_column_name] = [question.lstrip() for question in examples[question_column_name]]

    # The padding side decides the argument order, which sequence gets truncated and which sequence id
    # marks the context.
    if pad_on_right:
        first_texts = examples[question_column_name]
        second_texts = examples[context_column_name]
        truncation_mode = "only_second"
        context_index = 1
    else:
        first_texts = examples[context_column_name]
        second_texts = examples[question_column_name]
        truncation_mode = "only_first"
        context_index = 0

    tokenized_examples = tokenizer(
        first_texts,
        second_texts,
        truncation=truncation_mode,
        max_length=max_seq_length,
        stride=args.doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length" if args.pad_to_max_length else False,
    )

    # For each feature, the index of the example it came from.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    tokenized_examples["example_id"] = []

    for feature_idx in range(len(tokenized_examples["input_ids"])):
        sequence_ids = tokenized_examples.sequence_ids(feature_idx)
        source_example = sample_mapping[feature_idx]
        tokenized_examples["example_id"].append(examples["id"][source_example])

        # Keep offsets for context tokens only, so a None offset marks "not part of the context".
        context_only_offsets = []
        for token_idx, offset in enumerate(tokenized_examples["offset_mapping"][feature_idx]):
            if sequence_ids[token_idx] == context_index:
                context_only_offsets.append(offset)
            else:
                context_only_offsets.append(None)
        tokenized_examples["offset_mapping"][feature_idx] = context_only_offsets

    return tokenized_examples

# Keep the raw validation examples: post-processing needs their ids, contexts and reference answers.
eval_examples = raw_datasets["validation"]
with accelerator.main_process_first():
    eval_dataset = eval_examples.map(
        prepare_validation_features,
        batched=True,
        remove_columns=column_names,
        desc="Running tokenizer on validation dataset",
    )

# ========= DataLoaders =========
if args.pad_to_max_length:
    # Features already share one length, so the default collator is enough.
    data_collator = default_data_collator
else:
    # Pad per batch; the padding multiple depends on the accelerator's mixed-precision mode.
    if accelerator.mixed_precision == "fp8":
        pad_to_multiple_of = 16
    elif accelerator.mixed_precision != "no":
        pad_to_multiple_of = 8
    else:
        pad_to_multiple_of = None
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=pad_to_multiple_of)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=data_collator,
    batch_size=args.per_device_train_batch_size,
)

# The model does not take `example_id` / `offset_mapping`; they stay available on `eval_dataset`.
eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"])
eval_dataloader = DataLoader(
    eval_dataset_for_model,
    collate_fn=data_collator,
    batch_size=args.per_device_eval_batch_size,
)

# ========= Post-processing =========
def post_processing_function(examples, features, predictions, stage="eval"):
    """
    Turn raw logits into predictions and references in the shape passed to `metric.compute`.

    Args:
        examples: The un-tokenized examples (each provides "id" and the answers column).
        features: The tokenized features matching `predictions` row by row.
        predictions: `(start_logits, end_logits)` arrays.
        stage: Used as the file-name prefix for the prediction files written to `args.output_dir`.

    Returns:
        An `EvalPrediction` whose `predictions` is a list of `{"id", "prediction_text"}` dicts and whose
        `label_ids` is a list of `{"id", "answers"}` dicts.
    """
    answers_by_id = postprocess_qa_predictions(
        examples=examples,
        features=features,
        predictions=predictions,
        version_2_with_negative=False,
        n_best_size=args.n_best_size,
        max_answer_length=args.max_answer_length,
        null_score_diff_threshold=0.0,
        output_dir=args.output_dir,
        prefix=stage,
    )

    formatted_predictions = []
    for example_id, answer_text in answers_by_id.items():
        formatted_predictions.append({"id": example_id, "prediction_text": answer_text})

    references = []
    for example in examples:
        references.append({"id": example["id"], "answers": example[answer_column_name]})

    return EvalPrediction(predictions=formatted_predictions, label_ids=references)

# Load the "squad" metric from `evaluate`; the evaluation code below reads `exact_match` from its result.
metric = evaluate.load("squad")

# Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
    """
    Stack per-batch logit arrays into one `(len(dataset), max_len)` float64 array.

    Args:
        start_or_end_logits: Iterable of 2-D arrays `(batch_size, seq_len)`, in dataset order.
        dataset: Only its length is used; it fixes the number of rows of the result.
        max_len: Number of columns of the result.

    Returns:
        The stacked array. Cells not covered by any batch keep the fill value -100, and rows of the
        final batch that would go past `len(dataset)` are dropped.
    """
    total_rows = len(dataset)
    stacked = np.full((total_rows, max_len), -100, dtype=np.float64)

    row = 0
    for chunk in start_or_end_logits:
        n_rows, n_cols = chunk.shape[0], chunk.shape[1]
        next_row = row + n_rows

        if next_row < total_rows:
            # The whole batch fits before the end of the array.
            stacked[row:next_row, :n_cols] = chunk
        else:
            # Last batch: keep only as many rows as are still free.
            stacked[row:, :n_cols] = chunk[: total_rows - row]

        row = next_row

    return stacked

# ========= Optimizer & Scheduler =========
# Parameters whose name contains one of these substrings are kept out of weight decay.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        # Previously hard-coded to 0.0 in both groups, which made the split below a no-op. `args` defines
        # no `weight_decay` today, so the 0.0 default keeps current behaviour until one is configured.
        "weight_decay": getattr(args, "weight_decay", 0.0),
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

# Derive the number of optimizer steps from the epoch count unless it was set explicitly.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    # Defaults to 0 (no warmup), as before, unless `args.num_warmup_steps` is configured.
    num_warmup_steps=getattr(args, "num_warmup_steps", 0),
    num_training_steps=args.max_train_steps,
)

# ========= Accelerator prepare =========
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

# The dataloader length may differ after prepare(), so recompute the steps per epoch and from that the
# number of epochs needed to reach max_train_steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

# ========= Training info =========
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps


In [None]:
print("***** Running training *****")
print(f"  Num examples = {len(train_dataset)}")
print(f"  Num Epochs = {args.num_train_epochs}")
print(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
print(f"  Total optimization steps = {args.max_train_steps}")

# ====== Logs ======
loss_history = []    # average micro-batch loss over each logging window
metric_history = []  # exact-match score of each intermediate evaluation
steps_record = []    # value of `completed_steps` when each loss_history entry was recorded

# Log the running loss every `log_interval` micro-batches of an epoch.
log_interval = 1000
# Evaluate at five points over the run; clamp to 1 so a very short run cannot hit a modulo-by-zero.
eval_interval = max(1, args.max_train_steps // 5)

# A single progress bar, shown on the local main process only.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
starting_epoch = 0

# No-op on a fresh run; advances the bar if `completed_steps` is ever restored from a checkpoint.
progress_bar.update(completed_steps)


def _evaluate_exact_match():
    """
    Run the model over the validation set and return the exact-match score.

    The forward passes and gather calls run on every process, because they are collective operations.
    Post-processing and metric computation run on the local main process only; other processes
    return None. The model is put back into training mode before returning.
    """
    model.eval()
    all_start_logits, all_end_logits = [], []
    for e_batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(**e_batch)
            start_logits = outputs.start_logits
            end_logits = outputs.end_logits

        if not args.pad_to_max_length:
            # Batches are padded dynamically, so pad across processes before gathering.
            start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
            end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)

        all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
        all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())
    model.train()

    if not accelerator.is_local_main_process:
        return None

    max_len = max(x.shape[1] for x in all_start_logits)
    start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len)
    end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len)
    outputs_numpy = (start_logits_concat, end_logits_concat)

    prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
    eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
    return eval_metric["exact_match"]


# -------- Training Loop --------
for epoch in range(starting_epoch, args.num_train_epochs):
    model.train()
    running_loss = 0.0

    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss

            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        running_loss += loss.item()

        # Record the average loss every `log_interval` micro-batches.
        if (step + 1) % log_interval == 0 and accelerator.is_local_main_process:
            avg_loss = running_loss / log_interval
            print(f"Epoch {epoch} | Step {step+1} | Avg Loss: {avg_loss:.4f}")
            loss_history.append(avg_loss)
            steps_record.append(completed_steps)
            running_loss = 0.0

        # `sync_gradients` is True only on the micro-batch that performs a real optimizer step.
        if accelerator.sync_gradients:
            progress_bar.update(1)
            completed_steps += 1

            # Evaluate once per `eval_interval` optimizer steps. Checking here, right after the step
            # counter advances, avoids re-running the same evaluation on every micro-batch of an
            # accumulation window and also covers the very last step.
            if completed_steps % eval_interval == 0:
                exact_match = _evaluate_exact_match()
                if exact_match is not None:
                    metric_history.append(exact_match)
                    print(f"Step {completed_steps} | EM = {exact_match:.2f}")

        if completed_steps >= args.max_train_steps:
            break

    # Do not start another epoch once the step budget is used up.
    if completed_steps >= args.max_train_steps:
        break


In [None]:
# -------- Evaluation Loop --------
print("***** Running Evaluation *****")
print(f"  Num examples = {len(eval_dataset)}")
print(f"  Batch size = {args.per_device_eval_batch_size}")

model.eval()
all_start_logits, all_end_logits = [], []

for step, batch in enumerate(eval_dataloader):
    with torch.no_grad():
        outputs = model(**batch)
        start_logits, end_logits = outputs.start_logits, outputs.end_logits

        # Batches are padded dynamically unless pad_to_max_length is set, so pad across processes
        # before gathering.
        if not args.pad_to_max_length:
            start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
            end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)

        all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
        all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())

# Widest batch seen; this becomes the column count of the stacked arrays.
max_len = max(batch_logits.shape[1] for batch_logits in all_start_logits)

# Stack the per-batch arrays into one array per logit type.
start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len)
end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len)

# The per-batch lists are no longer needed.
del all_start_logits, all_end_logits

# Post-process the logits into answers and score them.
outputs_numpy = (start_logits_concat, end_logits_concat)
prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy)
eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
print(f"Evaluation metrics: {eval_metric}")

In [None]:
# -------- Save Model --------
# Save the fine-tuned model and the tokenizer into args.output_dir (skipped when no directory is set).
if args.output_dir is not None:
    # Barrier: every process reaches this point before anything is written.
    accelerator.wait_for_everyone()
    # NOTE(review): unwrap_model presumably returns the underlying model without the wrappers added by
    # accelerator.prepare() - confirm against the Accelerate documentation.
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        args.output_dir, 
        is_main_process=accelerator.is_main_process, 
        save_function=accelerator.save
    )
    # The tokenizer is saved by the main process only.
    if accelerator.is_main_process:
        tokenizer.save_pretrained(args.output_dir)

In [None]:
# Bundle everything under /kaggle/working into a single archive.
# NOTE(review): the archive is written into the directory it globs, so on a re-run the glob also matches
# working.zip itself - confirm zip handles that as intended.
!zip -r /kaggle/working/working.zip /kaggle/working/*