### Imports

In [1]:
#!pip install transformers datasets accelerate evaluate tensorboard nnsight wandb
#!wandb login
#!huggingface-cli login

In [1]:
from tqdm.auto import tqdm
import os

from accelerate import Accelerator
from accelerate.utils import set_seed

import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

from transformers import (
    AutoTokenizer, default_data_collator,
    get_scheduler, AutoModelForCausalLM,
    AutoConfig, GenerationConfig,
)

from datasets import load_dataset

### Hyperparameters

In [2]:
# causal model
# longest command is 9 words : https://arxiv.org/pdf/1711.00350
max_len = 9
dummy_token = "<empty>"

# command type maps
actions = {
    "walk": "I_WALK",
    "run": "I_RUN",
    "jump": "I_JUMP",
    "look": "I_LOOK",
    "turn": dummy_token,
    dummy_token: dummy_token,
    }

turns = {
    "around": "yyyy",
    "opposite": "yy",
    dummy_token: dummy_token
}

directions = {
    "right": "I_TURN_RIGHT",
    "left": "I_TURN_LEFT",
    dummy_token: dummy_token
}

nums = {
    "twice": "xx",
    "thrice": "xxx",
    dummy_token: dummy_token
}

conjs = ["and", "after", dummy_token]

# command structure
command_structure = {
    0: actions,
    1: turns,
    2: directions,
    3: nums,
    4: conjs,
    5: actions,
    6: turns,
    7: directions,
    8: nums,
}

# seed
seed = 42

# model, tokenizer
# need to load a 'half-trained' model
#model_name_or_path = 'distilbert/distilgpt2'
model_name_or_path = '/root/Caricatures/models/distilgpt2_40k'
#model_name_or_path = '/users/ujan/caricatures/models/scan/distilgpt2_40k'
#model_name_or_path = '/home/drdo/Caricatures/models/scan_distilgpt2/checkpoint-40000'
special_tokens_dict = {
    "pad_token": "<pad>",
    "sep_token": "<sep>",
}

# dataset
dataset = "scan"
# 'simple', 'addprim_jump', 'addprim_turn_left', 'filler_num0', 
# 'filler_num1', 'filler_num2', 'filler_num3', 'length', 
# 'template_around_right', 'template_jump_around_right', 
# 'template_opposite_right', 'template_right'
dataset_config = "simple"
validation_split = 0.1
max_source_length = 512
max_target_length = 512
max_gen_length = 256

# training
#output_dir = '/users/ujan/caricatures/models/scan/scan_distilgpt2_reinforce'
#output_dir = '/home/drdo/Caricatures/models/distilgpt2_reinforce'
output_dir = '/root/Caricatures/models/distilgpt2_reinforce'
num_workers = os.cpu_count()  # 1, None, 32
per_device_train_batch_size = 4 # 64
per_device_eval_batch_size = 4  # 64
train_steps = 100000
warmup_steps = 0
gradient_accumulation_steps = 1
eval_steps = 5000
lr = 5e-5
weight_decay = 0.0
lr_scheduler_type = 'linear'
mixed_precision = 'no'
num_beams = 1
report_to = 'wandb'
EPSILON = 1e-20

### Seed and Trainer

In [3]:
set_seed(seed)

accelerator_log_kwargs = {}
accelerator_log_kwargs["log_with"] = report_to
accelerator_log_kwargs["project_dir"] = output_dir
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, **accelerator_log_kwargs)

### Dataset

In [4]:
raw_datasets = load_dataset(dataset, dataset_config, trust_remote_code=True)

# split train set into train and validation
train_val_split = raw_datasets['train'].train_test_split(test_size=validation_split, seed=seed)
raw_datasets['train'] = train_val_split['train']
raw_datasets['validation'] = train_val_split['test']

column_names = raw_datasets["train"].column_names
input_column = column_names[0]
output_column = column_names[1]

# format dataset with dummy tokens
special_tokens_dict["additional_special_tokens"] = [dummy_token]

def add_empty_token(x):
    command_str = x[input_column]
    command = command_str.split()
    padded_command = []
    index = 0
    c = 0
    while index < max_len:
        expected_cs = command_structure[index]
        if c < len(command) and command[c] in expected_cs:
            padded_command.append(command[c])
            c += 1
        else:
            padded_command.append(dummy_token)
        index += 1
    
    x[input_column] = ' '.join(padded_command)
    return x

with accelerator.main_process_first():
    raw_datasets["train"] = raw_datasets["train"].map(
        add_empty_token,
        batched=False,
        num_proc=num_workers, 
        desc="Running tokenizer on dataset",
    )
    raw_datasets["validation"] = raw_datasets["validation"].map(
        add_empty_token,
        batched=False,
        num_proc=num_workers,
        desc="Running tokenizer on dataset",
)

### Model and Tokenizer

In [5]:
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True, trust_remote_code=True)
tokenizer.add_special_tokens(special_tokens_dict)
# left padding for batch generation
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            config=config,
            trust_remote_code=True,
        )

# Resize the embeddings only when necessary to avoid index errors
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))

# Generation config
generation_config = GenerationConfig.from_pretrained(model_name_or_path)
#gen_dict = generation_config.to_dict()
#gen_dict["language"] = model_lang
# reload with new attributes
#generation_config = GenerationConfig.from_dict(gen_dict)
#max_gen_length = model.config.max_length
#num_beams = args.num_beams if args.num_beams is not None else model.config.num_beams
gen_kwargs = {"max_new_tokens": max_gen_length, "num_beams": num_beams}

### Preprocess Dataset

In [6]:
# preprocess dataset
def preprocess_function(examples):
    # commands, actions
    inputs = examples[input_column]
    targets = examples[output_column]

    # tokenize as single sequence separated by special token
    model_inputs = tokenizer(
        [i+tokenizer.sep_token for i in inputs],
        padding='max_length', max_length=max_source_length
    )
    # labels same as inputs. labels shifted right in the model forward by default
    model_inputs['labels'] = tokenizer(
        [t+tokenizer.eos_token for t in targets],
        padding='max_length', max_length=max_source_length
    )['input_ids']
    # set label padding to -100 
    #model_inputs['labels'] = [
        #[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in model_inputs['labels']
    #]

    return model_inputs


with accelerator.main_process_first():
    train_dataset = raw_datasets["train"].map(
        preprocess_function,
        batched=True,
        num_proc=num_workers,
        remove_columns=column_names,
        desc="Running tokenizer on dataset",
    )
    eval_dataset = raw_datasets["validation"].map(
        preprocess_function,
        batched=True,
        num_proc=num_workers,
        remove_columns=column_names,
        desc="Running tokenizer on dataset",
    )

### Dataloaders

In [7]:
# data collator and loaders
train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=per_device_train_batch_size
)
eval_dataloader = DataLoader(
    eval_dataset, collate_fn=default_data_collator, batch_size=per_device_eval_batch_size
)

### Optimizer and Scheduler

In [8]:
# prepare optimizer and schedule (linear warmup and decay)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)

# scheduler
lr_scheduler = get_scheduler(
    name=lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=warmup_steps * accelerator.num_processes,
    num_training_steps=train_steps * accelerator.num_processes,
)

### Accelerator

In [9]:
# prepare everything for accelerator
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

### Eval

In [None]:
eval_bar = tqdm(range(len(eval_dataloader)), position=0)
accuracy = 0

model.eval()

for batch in eval_dataloader:
    with torch.no_grad():
        output_ids = accelerator.unwrap_model(model).generate(
            **batch,
            generation_config=generation_config,
            **gen_kwargs
        )

    # pad_acrss_processes to get equal length for each processs
    output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
    label_ids = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
    # gather
    output_ids = accelerator.gather(output_ids) 
    label_ids = accelerator.gather(label_ids)  
    # decode
    batch_output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    batch_input = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)
    outputs = [batch_output[b].replace(batch_input[b], '') for b in range(len(batch_output))]
    labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    # compute accuracy
    acc = [o==l for o, l in zip(outputs, labels)]
    accuracy += sum(acc)/len(acc)

    eval_bar.update(1)

print(accuracy/len(eval_dataloader))


### Train

In [10]:
# right padding for logits
tokenizer.padding_side = "right"
ignore_index = -100


# re-tokenize left padded sequences need for batch generation to right padded sequences
def re_tokenize(token_ids):
    tokens = tokenizer.batch_decode(token_ids, skip_special_tokens=False)
    tokens = [o.replace(tokenizer.pad_token, '') for o in tokens]
    tokens = [o.replace(tokenizer.eos_token, '') for o in tokens]
    tokenized_tokens = tokenizer(
        tokens,
        padding='max_length',
        max_length=max_source_length,
        return_tensors='pt',
    ).to(model.device)
    input_ids = tokenized_tokens['input_ids']
    attention_mask = tokenized_tokens['attention_mask']
    return input_ids, attention_mask


def prepare_input_for_rl_step(output_ids, gen_label_ids):

    generated_ids, attention_mask = re_tokenize(output_ids)
    gen_label_ids, _ = re_tokenize(gen_label_ids) 
    # context labels needed for ce loss for context
    # get only context labels
    all_tokens = tokenizer.batch_decode(generated_ids)
    context_tokens = [t.split(tokenizer.sep_token)[0] for t in all_tokens]
    tokenized_context = tokenizer(
        [c+tokenizer.sep_token for c in context_tokens],
        padding='max_length',
        max_length=max_source_length,
        return_tensors='pt',
    ).to(model.device)
    context_label_ids = tokenized_context['input_ids']
    # set context label padding to -100 
    context_label_ids = [
        [
            (l if l != tokenizer.pad_token_id else ignore_index) for l in label
        ] for label in context_label_ids.tolist()
    ]
    context_label_ids = torch.tensor(context_label_ids).to(model.device)

    return generated_ids, attention_mask, gen_label_ids, context_label_ids
    

# returns 1 if accurate, 0 otherwise
# experiment with other rewards
def reward_function(output_ids, gen_label_ids):
    # decode output
    output_tokens = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
    output_tokens = [
        o.replace(tokenizer.pad_token, '').split(tokenizer.sep_token)[1] for o in output_tokens
    ]
    # decode labels
    label_tokens = tokenizer.batch_decode(gen_label_ids, skip_special_tokens=False)
    label_tokens = [l.replace(tokenizer.pad_token, '') for l in label_tokens]
    # compute reward(=accuracy)
    reward = [o==l for o, l in zip(output_tokens, label_tokens)]
    reward = torch.tensor(reward, dtype=torch.float32).to(model.device)
    return reward


def logprobs_from_logits(logits, labels):
    # https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
    logp = torch.log(F.softmax(logits, dim=2) + EPSILON)
    logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
    return logpy
    

def loss_function(logits, context_label_ids, gen_label_ids, attention_mask, reward):

    # ce loss for context
    # shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_context_labels = context_label_ids[..., 1:].contiguous()
    loss_fct = CrossEntropyLoss(ignore_index=ignore_index)
    context_loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_context_labels.view(-1)
    )

    # reinforce loss
    # logprobs -> b x seq_len
    logprob = logprobs_from_logits(logits, gen_label_ids)
    # zero out context positions in logits
    logprob[context_label_ids != ignore_index] = 0
    # zero out padding positions in logits
    logprob[attention_mask == 0] = 0
    # reshape reward
    reward = reward.unsqueeze(1).repeat(1,logprob.shape[1])
    reinforce_loss = -logprob * reward

    # total loss
    # zero out context from attention_mask
    attention_mask[context_label_ids != ignore_index] = 0
    reinforce_loss = torch.sum(reinforce_loss) / torch.sum(attention_mask)
    total_loss = context_loss + reinforce_loss

    return total_loss
    
    
def reinforce_step(generated_ids, attention_mask, gen_label_ids, context_label_ids):
    # calculate reward 'to go' (reinforce)
    # reward -> batch_size
    reward = reward_function(generated_ids, gen_label_ids)
    # model forward
    logits = model(input_ids=generated_ids, attention_mask=attention_mask).logits
    # compute loss
    loss = loss_function(logits, context_label_ids, gen_label_ids, attention_mask, reward)
    return loss

In [11]:
global_step = 0  # tracks total steps

progress_bar = tqdm(range(global_step, train_steps), disable=not accelerator.is_main_process, position=0)
# eval bar
eval_bar = tqdm(range(len(eval_dataloader)), position=1)

while True:
    for batch in train_dataloader:
        model.train()
        with torch.no_grad():
            output_ids = accelerator.unwrap_model(model).generate(
                **batch,
                generation_config=generation_config,
                **gen_kwargs
            )

        # pad_acrss_processes to get equal length for each processs
        output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
        label_ids = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)

        # gather
        output_ids = accelerator.gather(output_ids) 
        label_ids = accelerator.gather(label_ids)  

        # re-tokenize for rl step
        # generated_ids -> context ids + generated action ids
        # attention mask -> attention mask for generated_ids
        # gen_label_ids -> generated action ids
        # context_label_ids -> context ids, needed to compute ce loss for context
        generated_ids, attention_mask, gen_label_ids, context_label_ids = prepare_input_for_rl_step(output_ids, label_ids)

        # reinforce
        loss = reinforce_step(generated_ids, attention_mask, gen_label_ids, context_label_ids)

        # backprop
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
        

        

  0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/419 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Settin

KeyboardInterrupt: 

In [None]:
a = torch.rand(4)
a

In [28]:
a.unsqueeze(1).repeat(1,4)

tensor([[0.8854, 0.8854, 0.8854, 0.8854],
        [0.5739, 0.5739, 0.5739, 0.5739],
        [0.2666, 0.2666, 0.2666, 0.2666],
        [0.6274, 0.6274, 0.6274, 0.6274]])

In [29]:
a.unsqueeze(1).repeat(1,4).shape

torch.Size([4, 4])