### Imports

In [1]:
#!pip install transformers datasets accelerate evaluate tensorboard nnsight wandb
#!wandb login
#!huggingface-cli login

In [1]:
from tqdm.auto import tqdm
import os

from accelerate import Accelerator
from accelerate.utils import set_seed

import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss

from transformers import (
    AutoTokenizer, default_data_collator,
    get_scheduler, AutoModelForCausalLM,
    AutoConfig, GenerationConfig,
)

from datasets import load_dataset

2024-10-17 20:51:25.258160: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Hyperparameters

In [2]:
# causal model
# longest command is 9 words : https://arxiv.org/pdf/1711.00350
max_len = 9
dummy_token = "<empty>"

# command type maps
actions = {
    "walk": "I_WALK",
    "run": "I_RUN",
    "jump": "I_JUMP",
    "look": "I_LOOK",
    "turn": dummy_token,
    dummy_token: dummy_token,
    }

turns = {
    "around": "yyyy",
    "opposite": "yy",
    dummy_token: dummy_token
}

directions = {
    "right": "I_TURN_RIGHT",
    "left": "I_TURN_LEFT",
    dummy_token: dummy_token
}

nums = {
    "twice": "xx",
    "thrice": "xxx",
    dummy_token: dummy_token
}

conjs = ["and", "after", dummy_token]

# command structure
command_structure = {
    0: actions,
    1: turns,
    2: directions,
    3: nums,
    4: conjs,
    5: actions,
    6: turns,
    7: directions,
    8: nums,
}

# seed
seed = 42

# model, tokenizer
# need to load a 'half-trained' model
#model_name_or_path = 'distilbert/distilgpt2'
#model_name_or_path = '/root/caricatures/models/distilgpt2_40k'
model_name_or_path = '/users/ujan/caricatures/models/scan/distilgpt2_40k'
#model_name_or_path = '/home/drdo/Caricatures/models/scan_distilgpt2/checkpoint-40000'
special_tokens_dict = {
    "pad_token": "<pad>",
    "sep_token": "<sep>",
}

# dataset
dataset = "scan"
# 'simple', 'addprim_jump', 'addprim_turn_left', 'filler_num0', 
# 'filler_num1', 'filler_num2', 'filler_num3', 'length', 
# 'template_around_right', 'template_jump_around_right', 
# 'template_opposite_right', 'template_right'
dataset_config = "simple"
validation_split = 0.1
max_source_length = 512
max_target_length = 512
max_gen_length = 256

# training
output_dir = '/users/ujan/caricatures/models/scan/scan_distilgpt2_reinforce'
#output_dir = '/home/drdo/Caricatures/models/distilgpt2_reinforce'
num_workers = os.cpu_count()  # 1, None, 32/users/ujan/caricatures/models/scan/scan_distilgpt2_reinforce
per_device_train_batch_size = 4 # 64
per_device_eval_batch_size = 4  # 64
train_steps = 100000
warmup_steps = 0
gradient_accumulation_steps = 1
eval_steps = 5000
lr = 5e-5
weight_decay = 0.0
lr_scheduler_type = 'linear'
mixed_precision = 'no'
num_beams = 1
report_to = 'wandb'

### Seed and Trainer

In [3]:
set_seed(seed)

accelerator_log_kwargs = {}
accelerator_log_kwargs["log_with"] = report_to
accelerator_log_kwargs["project_dir"] = output_dir
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps, **accelerator_log_kwargs)

### Dataset

In [4]:
raw_datasets = load_dataset(dataset, dataset_config, trust_remote_code=True)

# split train set into train and validation
train_val_split = raw_datasets['train'].train_test_split(test_size=validation_split, seed=seed)
raw_datasets['train'] = train_val_split['train']
raw_datasets['validation'] = train_val_split['test']

column_names = raw_datasets["train"].column_names
input_column = column_names[0]
output_column = column_names[1]

# format dataset with dummy tokens
special_tokens_dict["additional_special_tokens"] = [dummy_token]

def add_empty_token(x):
    command_str = x[input_column]
    command = command_str.split()
    padded_command = []
    index = 0
    c = 0
    while index < max_len:
        expected_cs = command_structure[index]
        if c < len(command) and command[c] in expected_cs:
            padded_command.append(command[c])
            c += 1
        else:
            padded_command.append(dummy_token)
        index += 1
    
    x[input_column] = ' '.join(padded_command)
    return x

with accelerator.main_process_first():
    raw_datasets["train"] = raw_datasets["train"].map(
        add_empty_token,
        batched=False,
        num_proc=num_workers, 
        desc="Running tokenizer on dataset",
    )
    raw_datasets["validation"] = raw_datasets["validation"].map(
        add_empty_token,
        batched=False,
        num_proc=num_workers,
        desc="Running tokenizer on dataset",
)

### Model and Tokenizer

In [5]:
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True, trust_remote_code=True)
tokenizer.add_special_tokens(special_tokens_dict)
# left padding for batch generation
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            config=config,
            trust_remote_code=True,
        )

# Resize the embeddings only when necessary to avoid index errors
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))

# Generation config
generation_config = GenerationConfig.from_pretrained(model_name_or_path)
#gen_dict = generation_config.to_dict()
#gen_dict["language"] = model_lang
# reload with new attributes
#generation_config = GenerationConfig.from_dict(gen_dict)
#max_gen_length = model.config.max_length
#num_beams = args.num_beams if args.num_beams is not None else model.config.num_beams
gen_kwargs = {"max_new_tokens": max_gen_length, "num_beams": num_beams}

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Preprocess Dataset

In [6]:
# preprocess dataset
def preprocess_function(examples):
    # commands, actions
    inputs = examples[input_column]
    targets = examples[output_column]

    # tokenize as single sequence separated by special token
    model_inputs = tokenizer(
        [i+tokenizer.sep_token for i in inputs],
        padding='max_length', max_length=max_source_length
    )
    # labels same as inputs. labels shifted right in the model forward by default
    model_inputs['labels'] = tokenizer(
        [t+tokenizer.eos_token for t in targets],
        padding='max_length', max_length=max_source_length
    )['input_ids']
    # set label padding to -100 
    #model_inputs['labels'] = [
        #[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in model_inputs['labels']
    #]

    return model_inputs


with accelerator.main_process_first():
    train_dataset = raw_datasets["train"].map(
        preprocess_function,
        batched=True,
        num_proc=num_workers,
        remove_columns=column_names,
        desc="Running tokenizer on dataset",
    )
    eval_dataset = raw_datasets["validation"].map(
        preprocess_function,
        batched=True,
        num_proc=num_workers,
        remove_columns=column_names,
        desc="Running tokenizer on dataset",
    )

### Dataloaders

In [7]:
# data collator and loaders
train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=per_device_train_batch_size
)
eval_dataloader = DataLoader(
    eval_dataset, collate_fn=default_data_collator, batch_size=per_device_eval_batch_size
)

### Optimizer and Scheduler

In [8]:
# prepare optimizer and schedule (linear warmup and decay)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)

# scheduler
lr_scheduler = get_scheduler(
    name=lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=warmup_steps * accelerator.num_processes,
    num_training_steps=train_steps * accelerator.num_processes,
)

### Accelerator

In [9]:
# prepare everything for accelerator
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

### Eval

In [None]:
eval_bar = tqdm(range(len(eval_dataloader)), position=0)
accuracy = 0

model.eval()

for batch in eval_dataloader:
    with torch.no_grad():
        output_ids = accelerator.unwrap_model(model).generate(
            **batch,
            generation_config=generation_config,
            **gen_kwargs
        )

    # pad_acrss_processes to get equal length for each processs
    output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
    label_ids = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
    # gather
    output_ids = accelerator.gather(output_ids) 
    label_ids = accelerator.gather(label_ids)  
    # decode
    batch_output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    batch_input = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)
    outputs = [batch_output[b].replace(batch_input[b], '') for b in range(len(batch_output))]
    labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    # compute accuracy
    acc = [o==l for o, l in zip(outputs, labels)]
    accuracy += sum(acc)/len(acc)

    eval_bar.update(1)

print(accuracy/len(eval_dataloader))


### Train

In [10]:
# right padding for logits
tokenizer.padding_side = "right"
ignore_index = -100


# re-tokenize left padded sequences need for batch generation to right padded sequences
def re_tokenize(token_ids):
    tokens = tokenizer.batch_decode(token_ids, skip_special_tokens=False)
    tokens = [o.replace(tokenizer.pad_token, '') for o in tokens]
    tokens = [o.replace(tokenizer.eos_token, '') for o in tokens]
    tokenized_tokens = tokenizer(
        tokens,
        padding='max_length',
        max_length=max_source_length,
        return_tensors='pt',
    ).to(model.device)
    input_ids = tokenized_tokens['input_ids']
    attention_mask = tokenized_tokens['attention_mask']
    return input_ids, attention_mask


def prepare_input_for_rl_step(output_ids, gen_label_ids):

    generated_ids, attention_mask = re_tokenize(output_ids)
    gen_label_ids, _ = re_tokenize(gen_label_ids) 
    # context labels needed for ce loss for context
    # get only context labels
    all_tokens = tokenizer.batch_decode(generated_ids)
    context_tokens = [t.split(tokenizer.sep_token)[0] for t in all_tokens]
    tokenized_context = tokenizer(
        [c+tokenizer.sep_token for c in context_tokens],
        padding='max_length',
        max_length=max_source_length,
        return_tensors='pt',
    ).to(model.device)
    context_label_ids = tokenized_context['input_ids']
    # set context label padding to -100 
    context_label_ids = [
        [
            (l if l != tokenizer.pad_token_id else ignore_index) for l in label
        ] for label in context_label_ids.tolist()
    ]
    context_label_ids = torch.tensor(context_label_ids).to(model.device)

    return generated_ids, attention_mask, gen_label_ids, context_label_ids
    
    
def reward_function(output_ids, gen_label_ids):
    # decode output
    output_tokens = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
    output_tokens = [
        o.replace(tokenizer.pad_token, '').split(tokenizer.sep_token)[1] for o in output_tokens
    ]
    # decode labels
    label_tokens = tokenizer.batch_decode(gen_label_ids, skip_special_tokens=False)
    label_tokens = [l.replace(tokenizer.pad_token, '') for l in label_tokens]
    # compute reward(=accuracy)
    reward = [o==l for o, l in zip(output_tokens, label_tokens)]
    reward = torch.tensor(reward, dtype=torch.float32).to(model.device)
    return reward


def loss_function(logits, context_label_ids, gen_label_ids, attention_mask, reward):

    # ce loss for context
    # shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_context_labels = context_label_ids[..., 1:].contiguous()
    loss_fct = CrossEntropyLoss(ignore_index=ignore_index)
    context_loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_context_labels.view(-1)
    )

    # construct context mask
    context_mask = torch.ones_like(context_label_ids)
    context_mask[context_label_ids != ignore_index] = 0
    # apply mask on logits
    print(logits)
    logits[context_label_ids != ignore_index] = 0
    print(logits)
    raise

    # construct end padding mask -> attention_mask


    print(context_loss)
    print(reward.shape)
    print(logits.shape)
    raise

    # reinforce loss
    # TODO:
    
    
def reinforce_step(generated_ids, attention_mask, gen_label_ids, context_label_ids):
    # calculate reward 'to go' (reinforce)
    reward = reward_function(generated_ids, gen_label_ids)
    # model forward
    logits = model(input_ids=generated_ids, attention_mask=attention_mask).logits
    # compute loss
    loss_function(logits, context_label_ids, gen_label_ids, attention_mask, reward)
    raise

In [11]:
global_step = 0  # tracks total steps

progress_bar = tqdm(range(global_step, train_steps), disable=not accelerator.is_main_process, position=0)
# eval bar
eval_bar = tqdm(range(len(eval_dataloader)), position=1)

while True:
    for batch in train_dataloader:
        model.train()
        with torch.no_grad():
            output_ids = accelerator.unwrap_model(model).generate(
                **batch,
                generation_config=generation_config,
                **gen_kwargs
            )

        # pad_acrss_processes to get equal length for each processs
        output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
        label_ids = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)

        # gather
        output_ids = accelerator.gather(output_ids) 
        label_ids = accelerator.gather(label_ids)  

        # re-tokenize for rl step
        # generated_ids -> context ids + generated action ids
        # attention mask -> attention mask for generated_ids
        # gen_label_ids -> generated action ids
        # context_label_ids -> context ids, needed to compute ce loss for context
        generated_ids, attention_mask, gen_label_ids, context_label_ids = prepare_input_for_rl_step(output_ids, label_ids)

        # reinforce
        reinforce_step(generated_ids, attention_mask, gen_label_ids, context_label_ids)


        

  0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/419 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[[-1.5630e+01, -1.4854e+01, -1.4864e+01,  ...,  3.2617e-01,
          -5.9968e-02, -4.8850e-02],
         [-8.3999e+01, -8.2473e+01, -8.6874e+01,  ...,  5.1035e-01,
           4.3946e-03,  5.5931e-02],
         [-1.1466e+01, -1.1571e+01, -1.2388e+01,  ...,  5.8126e-03,
          -6.7820e-02,  8.5390e-02],
         ...,
         [ 2.7513e+01,  2.7691e+01,  2.9968e+01,  ..., -2.4414e-01,
           2.1287e-01,  3.2635e-02],
         [ 2.2583e+01,  2.2783e+01,  2.5139e+01,  ..., -1.4681e-01,
          -7.4961e-02,  5.7500e-02],
         [ 1.2865e+01,  1.3360e+01,  1.5330e+01,  ...,  3.3217e-01,
           3.2817e-02, -5.1834e-02]],

        [[-1.4374e+01, -1.3922e+01, -1.3552e+01,  ..., -2.0467e-01,
           8.7158e-02, -2.7965e-02],
         [-8.7493e+00, -8.3984e+00, -9.6252e+00,  ...,  1.8210e-01,
           5.8279e-02, -1.8418e-02],
         [-1.4499e+01, -1.4457e+01, -1.4552e+01,  ...,  2.0997e-01,
           1.0126e-01,  5.3064e-02],
         ...,
         [ 2.8604e+01,  3

RuntimeError: No active exception to reraise

In [73]:
tokenizer.decode([15344,   220, 50259,   826,   220, 50259,   706,  1210,  1088,  1364, 5636,   501, 50258])

'turn <empty> right <empty> after turn around left thrice<sep>'

In [1]:
import torch
a = torch.tensor([[0,0,1,1], [0,0,0,1], [0,1,1,1]])
a

tensor([[0, 0, 1, 1],
        [0, 0, 0, 1],
        [0, 1, 1, 1]])

In [6]:
b = torch.ones(3,4,2)*5
b

tensor([[[5., 5.],
         [5., 5.],
         [5., 5.],
         [5., 5.]],

        [[5., 5.],
         [5., 5.],
         [5., 5.],
         [5., 5.]],

        [[5., 5.],
         [5., 5.],
         [5., 5.],
         [5., 5.]]])

In [12]:
torch.stack([a,a]).shape

torch.Size([2, 3, 4])

In [9]:
torch.mul(c,b)

tensor([[[0., 0.],
         [0., 0.],
         [5., 5.],
         [5., 5.]],

        [[0., 0.],
         [0., 0.],
         [0., 0.],
         [5., 5.]],

        [[0., 0.],
         [5., 5.],
         [5., 5.],
         [5., 5.]]])