In [1]:
import torch
import torch_xla



In [2]:
from datasets import load_dataset

dataset = load_dataset("Salesforce/wikitext", "wikitext-2-v1")

In [28]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.bos_token_id = 128000
tokenizer.eos_token_id = 128001
tokenizer.pad_token_id = tokenizer.eos_token_id 

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

# Tokenize the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"], batch_size=100)

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [29]:
tokenized_datasets.keys()  # type:ignore

dict_keys(['test', 'train', 'validation'])

In [30]:
tokenized_datasets["train"][1].keys()  # type:ignore

dict_keys(['input_ids', 'attention_mask'])

In [31]:
block_size = 128

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=100,
)

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [32]:
lm_datasets["train"][1].keys(), lm_datasets["validation"][1].keys()  # type:ignore

(dict_keys(['input_ids', 'attention_mask', 'labels']),
 dict_keys(['input_ids', 'attention_mask', 'labels']))

In [33]:
len(lm_datasets["validation"])  # type:ignore

3760

In [40]:
from transformers import LlamaConfig, LlamaForCausalLM

# Define model configuration
config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=128,  # Model size
    num_hidden_layers=4,  # Number of transformer layers
    num_attention_heads=2,  # Number of attention heads
    intermediate_size=256,  # Size of the feedforward layer
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=True,
)

# Instantiate the model
model = LlamaForCausalLM(config)

In [41]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=1,
    max_steps=5,
    save_steps=5,
    save_total_limit=2,
    learning_rate=1e-7,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=1,
    gradient_accumulation_steps=1,
    fp16=True,  # Enable mixed precision
    report_to="tensorboard",  # Enable logging to TensorBoard
    push_to_hub=False,
)

In [42]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42).select(range(20)),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42).select(range(20)),  # type:ignore
    tokenizer=tokenizer,
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
max_steps is given, it will override any value given in num_train_epochs


In [43]:
trainer.train()

Using for loop to run decoder layers


Epoch,Training Loss,Validation Loss
1,11.8525,11.843368


  xldata.append(torch.load(xbio))


Using for loop to run decoder layers
Using for loop to run decoder layers
Using for loop to run decoder layers
Using for loop to run decoder layers
Using for loop to run decoder layers
Using for loop to run decoder layers
Using for loop to run decoder layers
Using for loop to run decoder layers
Using for loop to run decoder layers


TrainOutput(global_step=5, training_loss=11.807607269287109, metrics={'train_runtime': 33.4144, 'train_samples_per_second': 0.599, 'train_steps_per_second': 0.15, 'total_flos': 261742264320.0, 'train_loss': 11.807607269287109, 'epoch': 1.0})

## Train again, this time using scan

In [44]:
from transformers import LlamaConfig, LlamaForCausalLM

# Define model configuration
config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=128,  # Model size
    num_hidden_layers=4,  # Number of transformer layers
    num_attention_heads=2,  # Number of attention heads
    intermediate_size=256,  # Size of the feedforward layer
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=False,
)

# Instantiate the model
model = LlamaForCausalLM(config)

In [45]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42).select(range(20)),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42).select(range(20)),  # type:ignore
    tokenizer=tokenizer,
)

trainer.train()

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
max_steps is given, it will override any value given in num_train_epochs


Using apply_layers to speed up compilation


Epoch,Training Loss,Validation Loss
1,11.8525,11.843363


  xldata.append(torch.load(xbio))


Using apply_layers to speed up compilation
Using apply_layers to speed up compilation
Using apply_layers to speed up compilation
Using apply_layers to speed up compilation
Using for loop to run decoder layers
Using for loop to run decoder layers
Using for loop to run decoder layers
Using for loop to run decoder layers
Using for loop to run decoder layers


TrainOutput(global_step=5, training_loss=11.807618522644043, metrics={'train_runtime': 31.2497, 'train_samples_per_second': 0.64, 'train_steps_per_second': 0.16, 'total_flos': 261742264320.0, 'train_loss': 11.807618522644043, 'epoch': 1.0})