In [51]:
import logging
import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path
import pandas as pd
import numpy as np
import itertools

from datasets import load_dataset
import torch

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process

In [73]:
# set huggingface cache dir to prevent filling up home dir
os.environ['HF_HOME'] = '/net/scratch/shangao/latent-concept/cache'

In [74]:
print(f'available devices: {torch.cuda.device_count()}')
print(f'current device: {torch.cuda.current_device()}')
print(f'device name: {torch.cuda.get_device_name()}')

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("using: " + str(device))

available devices: 1
current device: 0
device name: NVIDIA A100 80GB PCIe
using: cuda


In [75]:
set_seed(1234)

## init tokenizer and model

In [76]:
model_type = 'gpt2'
n_events = 5

In [77]:
# model config is needed for loading tokenizer from json
# https://github.com/shaangao/incontext-learning/blob/84fab2141381001e33b5835e01f4fbf37f34a6a5/run_clm.py#L354
config_kwargs = {
    "cache_dir": None,
    "revision": "main",
    "use_auth_token": None,
}
config = CONFIG_MAPPING[model_type]()

In [78]:
# https://github.com/shaangao/incontext-learning/blob/84fab2141381001e33b5835e01f4fbf37f34a6a5/run_clm.py#L367
tokenizer_kwargs = {
    "cache_dir": None,
    "use_fast": True,
    "revision": "main",
    "use_auth_token": None,
}
# config = AutoConfig.from_pretrained('gpt2') 
# tokenizer = AutoTokenizer.from_pretrained('gpt2', tokenizer_file="/net/scratch/shangao/latent-concept/data/tokenizer.json", config=AutoConfig.from_pretrained('gpt2'), **tokenizer_kwargs)

# tokenizer = AutoTokenizer.from_pretrained(model_type, tokenizer_file="/net/scratch/shangao/latent-concept/data/tokenizer.json", config=config, **tokenizer_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_type, tokenizer_file=f"/net/scratch/shangao/latent-concept/data/tokenizer_{n_events}.json", config=config, **tokenizer_kwargs)

tokenizer.pad_token = tokenizer.eos_token

In [79]:
tokenizer('[endoftext] / a d c b')
# tokenizer.vocab

{'input_ids': [5, 0, 1, 4, 3, 2], 'attention_mask': [1, 1, 1, 1, 1, 1]}

In [80]:
# small_model
# https://github.com/shaangao/incontext-learning/blob/84fab2141381001e33b5835e01f4fbf37f34a6a5/run_clm.py#L393
config.vocab_size = tokenizer.vocab_size
config.n_layer = 4
config.n_head = 12

In [81]:
# train a new model from scratch
# https://github.com/shaangao/incontext-learning/blob/84fab2141381001e33b5835e01f4fbf37f34a6a5/run_clm.py#L409
model = AutoModelForCausalLM.from_config(config).to(device)
model.resize_token_embeddings(len(tokenizer))

Embedding(7, 768)

In [82]:
for name, param in model.named_parameters():
    print(name)

transformer.wte.weight
transformer.wpe.weight
transformer.h.0.ln_1.weight
transformer.h.0.ln_1.bias
transformer.h.0.attn.c_attn.weight
transformer.h.0.attn.c_attn.bias
transformer.h.0.attn.c_proj.weight
transformer.h.0.attn.c_proj.bias
transformer.h.0.ln_2.weight
transformer.h.0.ln_2.bias
transformer.h.0.mlp.c_fc.weight
transformer.h.0.mlp.c_fc.bias
transformer.h.0.mlp.c_proj.weight
transformer.h.0.mlp.c_proj.bias
transformer.h.1.ln_1.weight
transformer.h.1.ln_1.bias
transformer.h.1.attn.c_attn.weight
transformer.h.1.attn.c_attn.bias
transformer.h.1.attn.c_proj.weight
transformer.h.1.attn.c_proj.bias
transformer.h.1.ln_2.weight
transformer.h.1.ln_2.bias
transformer.h.1.mlp.c_fc.weight
transformer.h.1.mlp.c_fc.bias
transformer.h.1.mlp.c_proj.weight
transformer.h.1.mlp.c_proj.bias
transformer.h.2.ln_1.weight
transformer.h.2.ln_1.bias
transformer.h.2.attn.c_attn.weight
transformer.h.2.attn.c_attn.bias
transformer.h.2.attn.c_proj.weight
transformer.h.2.attn.c_proj.bias
transformer.h.2.ln_2

## dataset

In [83]:
# load
# https://github.com/shaangao/incontext-learning/blob/84fab2141381001e33b5835e01f4fbf37f34a6a5/run_clm.py#L332

data_files = {}
# data_files["train"] = f'/net/scratch/shangao/latent-concept/data/train_T{n_events}_N1000_L1024_E0.3.json'
# data_files["validation"] = f'/net/scratch/shangao/latent-concept/data/val_T{n_events}_N500_L1024_E0.3.json'
data_files["train"] = f'/net/scratch/shangao/latent-concept/data/rdn_train_T{n_events}_N1000_L1024_E0.3.json'
data_files["validation"] = f'/net/scratch/shangao/latent-concept/data/rdn_val_T{n_events}_N500_L1024_E0.3.json'
extension = (
    list(data_files.values())[0].split(".")[-1]
)
if extension == "txt":
    extension = "text"

datasets = load_dataset(extension, data_files=data_files)

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

In [102]:
train_concept_idx = np.array(datasets['train']['concept_idx'])
train_concept_distrib = [
    sum(train_concept_idx==0)/len(train_concept_idx),
    sum(train_concept_idx==1)/len(train_concept_idx),
    sum(train_concept_idx==2)/len(train_concept_idx)
]
train_concept_distrib

[0.35, 0.332, 0.318]

In [84]:
# tokenize
# https://github.com/shaangao/incontext-learning/blob/84fab2141381001e33b5835e01f4fbf37f34a6a5/run_clm.py#L433

def tokenize_function(examples):
    return tokenizer(examples['text'])

column_names = datasets["train"].column_names
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=1,
    remove_columns=column_names,
    load_from_cache_file=False,
)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [85]:
# set block_size
# https://github.com/shaangao/incontext-learning/blob/84fab2141381001e33b5835e01f4fbf37f34a6a5/run_clm.py#L444
block_size = min(1024, tokenizer.model_max_length)

In [86]:
# https://github.com/shaangao/incontext-learning/blob/84fab2141381001e33b5835e01f4fbf37f34a6a5/run_clm.py#L461
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower to preprocess.
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information: https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    num_proc=1,
    load_from_cache_file=False,
)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

## train & eval

In [87]:
logger = logging.getLogger(__name__)

In [88]:
training_args = TrainingArguments(
    do_train=True,
    overwrite_output_dir=True,
    learning_rate=8e-4,
    num_train_epochs=5,
    # output_dir=f'/net/scratch/shangao/latent-concept/outputs_small/pretrain/T{n_events}',
    output_dir=f'/net/scratch/shangao/latent-concept/outputs_small/pretrain/rdn_T{n_events}',
    logging_steps=1,
    save_total_limit=4,
    eval_strategy='steps',
    # eval_strategy='epoch',
    save_steps=1500,
    warmup_steps=1000,
    lr_scheduler_type='linear',
    per_device_train_batch_size=8,  # 2
    gradient_accumulation_steps=4
)

In [89]:
# https://github.com/shaangao/incontext-learning/blob/84fab2141381001e33b5835e01f4fbf37f34a6a5/run_clm.py#L494
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
    processing_class=tokenizer,
    # Data collator will default to DataCollatorWithPadding, so we change it.
    data_collator=default_data_collator,
)

In [90]:
# Detecting last checkpoint.
# https://github.com/shaangao/incontext-learning/blob/84fab2141381001e33b5835e01f4fbf37f34a6a5/run_clm.py#L268C1-L281C14

last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )
    elif last_checkpoint is not None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )

In [91]:
# Training
# https://github.com/shaangao/incontext-learning/blob/84fab2141381001e33b5835e01f4fbf37f34a6a5/run_clm.py#L505

if last_checkpoint is not None:
    checkpoint = last_checkpoint
# elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
#     checkpoint = model_args.model_name_or_path
else:
    checkpoint = None

train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model()  # Saves the tokenizer too for easy upload

metrics = train_result.metrics

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

Step,Training Loss,Validation Loss
1,2.1231,2.117914
2,2.1259,2.069924
3,2.0785,1.989856
4,2.0129,1.906206
5,1.9409,1.852222
6,1.886,1.826021
7,1.8507,1.787354
8,1.812,1.736797
9,1.7653,1.713228
10,1.7432,1.703602


***** train metrics *****
  epoch                    =       4.96
  total_flos               =   804697GF
  train_loss               =     1.6511
  train_runtime            = 0:08:39.48
  train_samples_per_second =      9.625
  train_steps_per_second   =      0.298


In [92]:
# Evaluation
# https://github.com/shaangao/incontext-learning/blob/84fab2141381001e33b5835e01f4fbf37f34a6a5/run_clm.py#L522

results = {}

logger.info("*** Evaluate ***")

eval_output = trainer.evaluate()

perplexity = math.exp(eval_output["eval_loss"])
results["perplexity"] = perplexity

trainer.log_metrics("eval", results)
trainer.save_metrics("eval", results)

***** eval metrics *****
  perplexity = 5.0648


In [93]:
eval_output

{'eval_loss': 1.6223130226135254,
 'eval_runtime': 2.8215,
 'eval_samples_per_second': 177.209,
 'eval_steps_per_second': 22.328,
 'epoch': 4.96}