# Fine-tune a GPT model
* [1. Import libraries](#heading1)
* [2. Prepare dataset](#heading2)
* [3. Load GPT model](#heading3)
    * [3.1. Pretrained GPT2 model](#heading3-1)
    * [3.2. Load checkpoint](#heading3-2)
* [4. Fine-tuning](#heading4)
* [5. Generate text](#heading5)

# 1. Import libraries <a class="anchor" id="heading1"></a>

In [None]:
import os

import jadegpt

# 2. Prepare dataset <a class="anchor" id="heading2"></a>

In [None]:
# load data
# Location of the dataset file; edit these for your machine.
input_dir = 'C:\\data'
data_file_name = 'input.txt'

# os.path.join uses the platform's path separator instead of a hard-coded '\\',
# so the same cell works on Windows, Linux and macOS (and tolerates a trailing
# separator on input_dir).
data = jadegpt.open_dataset_file(os.path.join(input_dir, data_file_name))

In [None]:
# split data into training and validation portions
# NOTE(review): presumably the fraction of `data` that goes to train_data,
# with the rest going to val_data; confirm in jadegpt.split_dataset.
split = 0.9
train_data, val_data = jadegpt.split_dataset(data, split)

In [None]:
# encode and export datasets to files
# Encoding choice: True selects GPT-2 encoding, False a custom encoding.
use_gpt2_encoding = True

# Output directory plus file names for the encoded train/val splits and
# the metadata file.
data_dir = 'C:\\data'
train_file_name = 'train.bin'
val_file_name = 'val.bin'
meta_file_name = 'meta.pkl'

# Positional arguments; their order must match jadegpt.export_data_to_files.
jadegpt.export_data_to_files(
    data,
    train_data,
    val_data,
    use_gpt2_encoding,
    data_dir,
    train_file_name,
    val_file_name,
    meta_file_name,
)

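# 3. Load GPT model <a class="anchor" id="heading3"></a>

Run **either** 3.1 (start from pretrained GPT-2 weights) **or** 3.2 (resume from a saved checkpoint); both cells assign `model`.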

## 3.1. Load a pretrained GPT-2 model <a class="anchor" id="heading3-1"></a>

In [None]:
# Which pretrained GPT-2 variant to start from.
# Options: 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'
gpt2_model = 'gpt2'

# Random seed forwarded to jadegpt.init_gpt2.
random_seed = 1337

In [None]:
# Build the model from the selected pretrained GPT-2 variant and seed.
model = jadegpt.init_gpt2(
    gpt2_model,
    random_seed,
)

## 3.2. Load a checkpoint <a class="anchor" id="heading3-2"></a>

In [None]:
# Checkpoint to resume from: directory and file name.
model_dir = 'C:\\model'
model_file_name = 'model-100.ckpt'

# Seed and target device, both forwarded to jadegpt.resume_gpt.
random_seed = 1337
device = 'cuda'  # e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:1'

In [None]:
# resume the model from the checkpoint file configured above
# os.path.join instead of a hard-coded '\\' separator so the checkpoint path
# is valid on any operating system.
model = jadegpt.resume_gpt(os.path.join(model_dir, model_file_name), random_seed, device)

# 4. Fine-tuning <a class="anchor" id="heading4"></a>

In [None]:
# Load the encoded train/val files (as written by jadegpt.export_data_to_files)
# as memory-maps; this rebinds train_data / val_data.
train_data, val_data = (
    jadegpt.load_data_file_to_memmap(data_dir, file_name)
    for file_name in (train_file_name, val_file_name)
)

In [None]:
# ---- training parameters ----

# batching and hardware
batch_size = 8
block_size = 32
gradient_accumulation_steps = 5
device = 'cuda'  # 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc.
dtype = 'bfloat16'  # 'float32', 'bfloat16', or 'float16'

# evaluation and logging cadence
eval_interval = 50
eval_iters = 20
log_interval = 10

# AdamW optimizer
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.99  # a bit bigger than usual because the number of tokens per iteration is small
grad_clip = 1.0  # clip gradients at this value; 0.0 disables clipping

# learning-rate schedule
learning_rate = 1e-3  # small networks can afford a somewhat higher rate
max_iters = 100
decay_lr = True  # whether to decay the learning rate
warmup_iters = 10  # not strictly necessary
lr_decay_iters = max_iters  # usually equal to max_iters
min_lr = learning_rate / 10.0  # usually learning_rate / 10

# checkpoint saving
only_save_on_finish = False
save_interval = 50
model_dir = 'C:\\model'
model_name = 'model'

In [None]:
# fine-tuning
# Positional arguments; their order must match jadegpt.train_gpt's signature.
# No backslash continuations are needed inside the parentheses.
jadegpt.train_gpt(
    model, dtype, device,
    train_data, val_data,
    block_size, batch_size, max_iters,
    weight_decay, learning_rate, beta1, beta2,
    warmup_iters, lr_decay_iters, min_lr, decay_lr,
    eval_interval, eval_iters,
    gradient_accumulation_steps, grad_clip, log_interval,
    only_save_on_finish, save_interval, model_dir, model_name,
)

# 5. Generate text from fine-tuned model <a class="anchor" id="heading5"></a>

In [None]:
# Starting text passed to jadegpt.generate_text.
prompt = 'hello world'

In [None]:
# configuration
# Reuse data_dir and meta_file_name from the export cell in section 2 instead of
# re-hard-coding them. Otherwise, editing them there leaves generation pointing at
# a stale or missing file. That cell also defines use_gpt2_encoding, which
# generation already requires, so this adds no new cell dependency.
# NOTE(review): assumes export_data_to_files writes meta_file_name into
# data_dir; confirm in jadegpt.
meta_dir = data_dir
num_samples = 3  # number of samples to draw
max_new_tokens = 100  # number of tokens generated in each sample
temperature = 1.0  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 20  # retain only the top_k most likely tokens, clamp others to have 0 probability
device = 'cuda'  # 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc.
dtype = 'bfloat16'  # 'float32', 'bfloat16', or 'float16'

In [None]:
# Generate num_samples continuations of the prompt with the current model.
# Positional arguments; their order must match jadegpt.generate_text.
jadegpt.generate_text(
    model,
    prompt,
    use_gpt2_encoding,
    meta_dir,
    meta_file_name,
    num_samples,
    max_new_tokens,
    temperature,
    top_k,
    device,
    dtype,
)