In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
# %% [markdown]
# Instead of using the complex TRL we code it from scratch, using lighting
# 
# https://github.com/rasbt/LLMs-from-scratch/blob/main/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb

# %%
from pathlib import Path

# ML
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from einops import rearrange, reduce, repeat
from jaxtyping import Float, Int, Bool
from torch.utils.data import DataLoader

import wandb

# Numeric
import numpy as np
import pandas as pd
# from matplotlib import pyplot as plt

# lightning
import lightning as pl
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.loggers.csv_logs import CSVLogger


# %%
# Local
from reprpo.helpers.torch import clear_mem
from reprpo.gen import generation_test
import reprpo.silence
from reprpo.helpers.lightning_hist import read_metrics_csv, plot_hist


from reprpo.train.dpo import compute_dpo_loss_batch, PL_DPO_MODEL

# %%
torch.set_float32_matmul_precision("high")

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"



# %%
from simple_parsing import ArgumentParser
from dataclasses import dataclass

In [3]:

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from reprpo.train.lightning import TrainingArguments
args = TrainingArguments()

In [None]:
tokenizer = AutoTokenizer.from_pretrained(args.model_name)

In [3]:


@dataclass
class CLIArguments:
    method: str = 'dpo' # reprpo_svd # reprpo_side
    # dataset: str = 'code_easy'
    dataset: str = 'us_history_textbook'
    verbose: bool = False
    dev: bool = False


parser = ArgumentParser()
parser.add_arguments(CLIArguments, dest='cli')
# parser.add_argument('-m', '--method', type=str, default='dpo', help='dpo, reprpo_svd, reprpo_side')
# parser.add_argument('-d', '--dataset', type=str, default='code_easy', help='code_easy etc see subsets in https://huggingface.co/datasets/wassname/genie_dpo')
# parser.add_argument('-v', '--verbose', type=bool, default=False, action="store_true", help='print dataset')
# parser.add_argument('--dev', type=bool, default=False, action="store_true", help='fast dev run')
args1 = parser.parse_known_args([])[0].cli


if args1.method == 'dpo':
    from reprpo.train.dpo import DPOTrainingArguments as TrainingArguments, PL_DPO_MODEL as PL_MODEL
elif args1.method == 'reprpo_svd':
    from reprpo.train.reprpo_svd import ReprPOSVDTrainingArguments as TrainingArguments, PL_REPRPO_SVD_MODEL as PL_MODEL
elif args1.method == 'reprpo_side':
    from reprpo.train.reprpo_side import ReprPOSideInTrainingArguments as TrainingArguments, PL_REPRPO_SIDE_MODEL as PL_MODEL
else:
    raise ValueError(f"method {args1.method} not found. options: `reprpo_side`, `dpo`, `reprpo_svd`")

parser.add_arguments(TrainingArguments, dest='args')
# parser.add_arguments(CLIArguments, dest='cli')
args2 = parser.parse_args([])
args = TrainingArguments(**args2.args.__dict__)
print(f"args = {args}")

# %% [markdown]
# ## Load model

ts = pd.Timestamp.now().strftime("%Y-%m-%d_%H-%M-%S")
run_fname = f'{args1.dataset}/{args.adapter_name}/{ts}'
from reprpo.helpers.wandb import init_wandb
wandb.require(experiment='service')
from peft import LoraConfig, get_peft_model
from reprpo.models.load import load_model, print_trainable_parameters


model, tokenizer = load_model(args.model_name, load_in_4bit=args.load_in_4bit,  load_in_8bit=args.load_in_8bit,  
                              attn_implementation='eager' # for gemma
)



args = DPOTrainingArguments(model_name='NousResearch/Meta-Llama-3.1-8B-Instruct', load_in_4bit=True, load_in_8bit=True, use_gradient_checkpointing=False, batch_size=12, lr=0.0001, weight_decay=0.0, n_samples=3600, max_length=256, max_prompt_length=64, adapter_name='dpo')


`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:

# %%
from datasets import load_dataset
dataset2 = load_dataset("wassname/genie_dpo", name=args1.dataset)



# %% [markdown]
# ### Data Loader
# 
# We use huggingface datasets, which are pretokenized. So that we can stack

# %%


# %%
# from reprpo.data.collate import DPODataCollatorWithPadding, tokenize_row
from reprpo.data.collate3 import TokenizeRow
tokenize_row = TokenizeRow(tokenizer, max_length=args.max_length, max_prompt_length=args.max_prompt_length)



In [7]:
tokenizer.get_special_tokens_mask?

[0;31mSignature:[0m
[0mtokenizer[0m[0;34m.[0m[0mget_special_tokens_mask[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mtoken_ids_0[0m[0;34m:[0m [0mList[0m[0;34m[[0m[0mint[0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtoken_ids_1[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mint[0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0malready_has_special_tokens[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0mList[0m[0;34m[[0m[0mint[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.

Args:
    token_ids_0 (`List[int]`):
        List of ids of the first sequence.
    token_ids_1 (`List[int]`, *optional*):
        List 

In [19]:
r = dataset2['test'][0]
r2 = tokenize_row(r)

In [21]:
len(r2['prompt'])==args.max_prompt_length

False

In [17]:
print(r['prompt'])
print('-'*20)
print(tokenizer.decode(r2['prompt']))


Below is an instruction that describes a task, paired with an input that provides further context. Complete the request to the best of your ability.

### Instruction:
Predict the next few sentences of the following excerpt from a high-quality US History textbook. 

### Input:
The Louisiana Purchase occurred in 1803 and doubled the size of the United States. The purchase was negotiated by ...

### Response:

--------------------
 provides further context. Complete the request to the best of your ability.

### Instruction:
Predict the next few sentences of the following excerpt from a high-quality US History textbook. 

### Input:
The Louisiana Purchase occurred in 1803 and doubled the size of the United States. The purchase was negotiated by...

### Response


In [18]:
print(r['chosen'])
print('-'*20)
print(tokenizer.decode(r2['chosen']))

President Thomas Jefferson and his administration. The Louisiana territory was purchased from France for $15 million, acquiring land that stretched from the Mississippi River to the Rocky Mountains. This acquisition paved the way for further westward expansion and the eventual settlement of much of the American West.
--------------------
<|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|>

In [15]:
print(tokenizer.decode(r2['rejected'])

SyntaxError: incomplete input (591574286.py, line 1)

In [None]:
dataset3 = dataset2.map(tokenize_row, batched=False)

In [1]:
tokenizer.build_inputs_with_special_tokens

NameError: name 'tokenizer' is not defined

In [None]:


ds = dataset3
dl_train = DataLoader(ds['train'], batch_size=args.batch_size, 
                    #   collate_fn=custom_collate_fn
                      )

dl_val = DataLoader(ds['test'], batch_size=args.batch_size
                    # , collate_fn=custom_collate_fn
                    )

if args1.verbose:

    print('QC one dataset row')
    r = dataset2['train'][0]
    print(r['prompt'])
    print('===')
    print(r['chosen'])
    print('---')
    print(r['rejected'])
    print()
    print()

    print('QC one train batch (after pad/crop')
    batch = next(iter(dl_train))
    print(batch.keys())
    print(tokenizer.decode(batch['prompt'][0]))
    print('===')
    print(tokenizer.decode(batch['chosen'][0]))
    print('---')
    print(tokenizer.decode(batch['rejected'][0]))
    print()
    print()
