### Imports

In [1]:
from tqdm.auto import tqdm
import os

from accelerate import Accelerator
from accelerate.utils import set_seed

import torch
from torch.utils.data import DataLoader

from transformers import (
    AutoTokenizer, default_data_collator,
    get_scheduler, AutoModelForCausalLM,
    AutoConfig, GenerationConfig,
)

from datasets import load_dataset

### Hyperparameters

In [11]:
# causal model
# longest command is 9 words : https://arxiv.org/pdf/1711.00350
max_len = 9
dummy_token = "<empty>"

# command type maps
actions = {
    "walk": "I_WALK",
    "run": "I_RUN",
    "jump": "I_JUMP",
    "look": "I_LOOK",
    "turn": dummy_token,
    dummy_token: dummy_token,
    }

turns = {
    "around": "yyyy",
    "opposite": "yy",
    dummy_token: dummy_token
}

directions = {
    "right": "I_TURN_RIGHT",
    "left": "I_TURN_LEFT",
    dummy_token: dummy_token
}

nums = {
    "twice": "xx",
    "thrice": "xxx",
    dummy_token: dummy_token
}

conjs = ["and", "after", dummy_token]

# command structure
command_structure = {
    0: actions,
    1: turns,
    2: directions,
    3: nums,
    4: conjs,
    5: actions,
    6: turns,
    7: directions,
    8: nums,
}

# seed
seed = 42

# model, tokenizer
model_name_or_path = 'distilbert/distilgpt2'
special_tokens_dict = {
    "pad_token": "<pad>",
    "sep_token": "<sep>",
}

# dataset
dataset = "scan"
# 'simple', 'addprim_jump', 'addprim_turn_left', 'filler_num0', 
# 'filler_num1', 'filler_num2', 'filler_num3', 'length', 
# 'template_around_right', 'template_jump_around_right', 
# 'template_opposite_right', 'template_right'
dataset_config = "simple"
validation_split = 0.1
max_source_length = 512
max_target_length = 512
max_gen_length = 256

# training
output_dir = '/root/Caricatures/models/scan_distilgpt2_reinforce'
num_workers = os.cpu_count()  # 1, None, 32
per_device_train_batch_size = 16
per_device_eval_batch_size = 16
train_steps = 100000
warmup_steps = 0
gradient_accumulation_steps = 1
eval_steps = 5000
lr = 5e-5
weight_decay = 0.0
lr_scheduler_type = 'linear'
mixed_precision = 'no'
num_beams = 1
report_to = 'wandb'

### Seed and Trainer

In [12]:
set_seed(seed)

accelerator_log_kwargs = {}
accelerator_log_kwargs["log_with"] = report_to
accelerator_log_kwargs["project_dir"] = output_dir
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, **accelerator_log_kwargs)

### Dataset

In [13]:
raw_datasets = load_dataset(dataset, dataset_config, trust_remote_code=True)

# split train set into train and validation
train_val_split = raw_datasets['train'].train_test_split(test_size=validation_split, seed=seed)
raw_datasets['train'] = train_val_split['train']
raw_datasets['validation'] = train_val_split['test']

column_names = raw_datasets["train"].column_names
input_column = column_names[0]
output_column = column_names[1]

# format dataset with dummy tokens
special_tokens_dict["additional_special_tokens"] = [dummy_token]

def add_empty_token(x):
    command_str = x[input_column]
    command = command_str.split()
    padded_command = []
    index = 0
    c = 0
    while index < max_len:
        expected_cs = command_structure[index]
        if c < len(command) and command[c] in expected_cs:
            padded_command.append(command[c])
            c += 1
        else:
            padded_command.append(dummy_token)
        index += 1
    
    x[input_column] = ' '.join(padded_command)
    return x

with accelerator.main_process_first():
    raw_datasets["train"] = raw_datasets["train"].map(
        add_empty_token,
        batched=False,
        num_proc=num_workers, 
        desc="Running tokenizer on dataset",
    )
    raw_datasets["validation"] = raw_datasets["validation"].map(
        add_empty_token,
        batched=False,
        num_proc=num_workers,
        desc="Running tokenizer on dataset",
)

### Model and Tokenizer

In [14]:
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True, trust_remote_code=True)
tokenizer.add_special_tokens(special_tokens_dict)
model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            config=config,
            trust_remote_code=True,
        )

# Resize the embeddings only when necessary to avoid index errors
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))

# Generation config
generation_config = GenerationConfig.from_pretrained(model_name_or_path)
#gen_dict = generation_config.to_dict()
#gen_dict["language"] = model_lang
# reload with new attributes
#generation_config = GenerationConfig.from_dict(gen_dict)
#max_gen_length = model.config.max_length
#num_beams = args.num_beams if args.num_beams is not None else model.config.num_beams
gen_kwargs = {"max_new_tokens": max_gen_length, "num_beams": num_beams}



### Preprocess Dataset

In [15]:
# preprocess dataset
def preprocess_function(examples):
    # commands, actions
    inputs = examples[input_column]
    targets = examples[output_column]

    # tokenize as single sequence separated by special token
    model_inputs = tokenizer(
        [i+tokenizer.sep_token for i in inputs],
        [t+tokenizer.eos_token for t in targets],
        padding='max_length', max_length=max_source_length
    )
    # labels same as inputs. labels shifted right in the model forward by default
    model_inputs['labels'] = model_inputs['input_ids'].copy()
    # set label padding to -100 
    model_inputs['labels'] = [
        [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in model_inputs['labels']
    ]

    return model_inputs


with accelerator.main_process_first():
    train_dataset = raw_datasets["train"].map(
        preprocess_function,
        batched=True,
        num_proc=num_workers,
        remove_columns=column_names,
        desc="Running tokenizer on dataset",
    )
    eval_dataset = raw_datasets["validation"].map(
        preprocess_function,
        batched=True,
        num_proc=num_workers,
        remove_columns=column_names,
        desc="Running tokenizer on dataset",
    )

Running tokenizer on dataset (num_proc=32):   0%|          | 0/15055 [00:00<?, ? examples/s]

Running tokenizer on dataset (num_proc=32):   0%|          | 0/1673 [00:00<?, ? examples/s]

### Dataloaders

In [16]:
# data collator and loaders
train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=per_device_train_batch_size
)
eval_dataloader = DataLoader(
    eval_dataset, collate_fn=default_data_collator, batch_size=per_device_eval_batch_size
)

### Optimizer and Scheduler

In [17]:
# prepare optimizer and schedule (linear warmup and decay)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)

# scheduler
lr_scheduler = get_scheduler(
    name=lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=warmup_steps * accelerator.num_processes,
    num_training_steps=train_steps * accelerator.num_processes,
)

### Accelerator

In [18]:
# prepare everything for accelerator
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

### Train

In [21]:
global_step = 0  # tracks total steps

progress_bar = tqdm(range(global_step, train_steps), disable=not accelerator.is_main_process, position=0)
# eval bar
eval_bar = tqdm(range(len(eval_dataloader)), position=1)

while True:

    model.train()

    for batch in train_dataloader:
        output_ids = accelerator.unwrap_model(model).generate(
            **batch,
            generation_config=generation_config,
            **gen_kwargs
        )

        # pad_acrss_processes to get equal length for each processs
        output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
        label_ids = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)

        output_ids = accelerator.gather(output_ids)  #.cpu().numpy()  # gather_for_metrics
        label_ids = accelerator.gather(label_ids)  #.cpu().numpy()  # gather_for_metrics

        print(output_ids.shape)
        print(label_ids.shape)
        raise

  0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/105 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([16, 768])
torch.Size([16, 512])


RuntimeError: No active exception to reraise