In [None]:
import os

# Central run configuration read by the cells below.
args = {
    # model / data / output
    "model": "Salesforce/codet5p-220m",
    "dataset": "benjis/diversevul",
    "save_dir": "saved_model",
    # tokenization lengths (in tokens)
    "max_source_len": 512,
    "max_target_len": 32,
    # optimisation schedule
    "epochs": 2,
    "batch_size_per_replica": 2,
    "grad_acc_steps": 16,
    "lr": 5e-5,
    "lr_warmup_steps": 200,
    "log_freq": 10,
    # Distributed launchers (e.g. torchrun) export LOCAL_RANK per process; when it
    # is absent (plain notebook run) fall back to -1, i.e. a single process. The
    # rank guards further down rely on this value being correct per process.
    "local_rank": int(os.environ.get("LOCAL_RANK", -1)),
    "deepspeed": None,
    "fp16": False,
}

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
import os

# Directory that receives the tokenizer here and the fine-tuned model at the end.
final_checkpoint_dir = os.path.join(args['save_dir'], "final_checkpoint")

# Raw training split of the dataset and the tokenizer paired with the base model.
train_data = load_dataset(args['dataset'], split="train")
tokenizer = AutoTokenizer.from_pretrained(args['model'])

# Only the main process (rank 0, or -1 for a non-distributed run) writes to disk.
if args['local_rank'] in (-1, 0):
    tokenizer.save_pretrained(final_checkpoint_dir)

def preprocess_function(examples, ignore_pad_token_for_loss=True):
    """Tokenize a batch of rows into seq2seq inputs and labels.

    The ``func`` column is the encoder input. The target string is the row's
    ``cwe`` entries joined with ``","``, or the placeholder ``"CWE0"`` when the
    row has no CWE. Sources and targets are truncated/padded to
    ``args['max_source_len']`` and ``args['max_target_len']`` respectively,
    using the module-level ``tokenizer``.

    Args:
        examples: Batched mapping as passed by ``Dataset.map(batched=True)``;
            ``examples["func"]`` is a list of strings and ``examples["cwe"]`` a
            list of lists of strings.
        ignore_pad_token_for_loss: If True (default), padding positions in the
            labels are replaced by -100 so they are excluded from the loss.

    Returns:
        The tokenizer output for the sources with an added ``"labels"`` entry
        holding one list of target token ids per row.
    """
    sources = list(examples["func"])
    # Multi-label rows collapse to one comma-separated string; rows without a
    # CWE get a sentinel class so every example has a non-empty target.
    targets = [",".join(cwes) if len(cwes) > 0 else "CWE0" for cwes in examples["cwe"]]

    model_inputs = tokenizer(sources, max_length=args['max_source_len'], padding="max_length", truncation=True)
    labels = tokenizer(targets, max_length=args['max_target_len'], padding="max_length", truncation=True)

    pad_id = tokenizer.pad_token_id
    if ignore_pad_token_for_loss and pad_id is not None:
        # -100 is the default ignore_index of PyTorch's cross-entropy loss used by
        # HF seq2seq models; without this, every padded target slot (most of the
        # max_length="max_length" labels) is trained on as if it were a real token.
        model_inputs["labels"] = [
            [token if token != pad_id else -100 for token in seq]
            for seq in labels["input_ids"]
        ]
    else:
        model_inputs["labels"] = list(labels["input_ids"])
    return model_inputs

# Tokenize the whole split up front. Raw columns are dropped so only the
# tokenizer outputs plus "labels" from preprocess_function reach the Trainer.
# Worker count is capped at the available CPUs instead of always forking 64
# processes, which oversubscribes smaller machines.
num_map_procs = min(64, os.cpu_count() or 1)
train_data = train_data.map(
    preprocess_function,
    batched=True,
    remove_columns=train_data.column_names,
    num_proc=num_map_procs,
    load_from_cache_file=False,
)

In [None]:
from transformers import AutoModelForSeq2SeqLM

# Pretrained encoder-decoder checkpoint named by args['model'].
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=args['model'])

In [None]:
from transformers import TrainingArguments, Trainer
import os

# Hugging Face training configuration; most values come from `args`.
training_args = TrainingArguments(
    # output locations: checkpoints and TensorBoard logs share save_dir
    output_dir=args['save_dir'],
    logging_dir=args['save_dir'],
    overwrite_output_dir=False,
    report_to='tensorboard',

    # schedule and checkpointing (one checkpoint per epoch, keep only the latest)
    do_train=True,
    num_train_epochs=args['epochs'],
    save_strategy='epoch',
    save_total_limit=1,

    # batching: effective batch = per-device batch * accumulation steps * devices
    per_device_train_batch_size=args['batch_size_per_replica'],
    gradient_accumulation_steps=args['grad_acc_steps'],
    dataloader_drop_last=True,
    dataloader_num_workers=4,

    # optimiser
    learning_rate=args['lr'],
    warmup_steps=args['lr_warmup_steps'],
    weight_decay=0.05,

    # logging cadence
    logging_first_step=True,
    logging_steps=args['log_freq'],

    # distributed execution and precision
    local_rank=args['local_rank'],
    deepspeed=args['deepspeed'],
    fp16=args['fp16'],
)

# Wire the model, training configuration and tokenized split together. No eval
# set or data collator is passed, so the Trainer's defaults apply.
trainer = Trainer(model=model, args=training_args, train_dataset=train_data)

trainer.train()

# Persist the fine-tuned weights next to the tokenizer saved earlier.
# Trainer.save_model performs the "should this process write?" check itself
# (TrainingArguments.should_save). When DeepSpeed is configured, it also gathers
# partitioned weights before writing. That gather needs every rank, so the call
# is deliberately not wrapped in a local_rank guard.
trainer.save_model(final_checkpoint_dir)