<a href="https://colab.research.google.com/github/shpotes/spanish-gpt-neo/blob/main/notebooks/test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Environment bootstrap: drop any earlier checkout, install the dependencies
# (transformers straight from its GitHub repository), clone the project and
# change into it. `%%capture` suppresses the output of this cell.
# NOTE(review): none of the installs are version-pinned, so a fresh run may pick
# up different library versions than the ones that produced the saved outputs.
!rm -rf spanish-gpt-neo
!pip install "datasets[streaming]" flax optax 
!pip install git+https://github.com/huggingface/transformers
!git clone https://github.com/shpotes/spanish-gpt-neo.git
%cd spanish-gpt-neo

In [2]:
# Run JAX's Colab TPU setup helper before any other JAX call.
# NOTE(review): the saved model-loading log in this notebook reports the 'tpu'
# backend as unavailable — confirm the runtime type before relying on this.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [3]:
import logging
import math
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional

import datasets
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm

import json
import jax
import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.serialization import to_bytes, from_bytes
from transformers import (
    AutoConfig,
    AutoTokenizer,
    GPTNeoForCausalLM,
    FlaxAutoModelForCausalLM,
    HfArgumentParser,
    TrainingArguments,
    is_tensorboard_available,
)
from transformers.testing_utils import CaptureLogger


# Module-level logger used by the warnings emitted in later cells.
logger = logging.getLogger(__name__)

# NOTE(review): wildcard import — presumably where `ModelArguments` and
# `DataTrainingArguments` (used below) come from; confirm against src/utils
# and prefer importing the needed names explicitly.
from src.utils import * 

In [4]:
class TrainState(train_state.TrainState):
    """Train state that additionally carries a PRNG key for dropout."""

    dropout_rng: jnp.ndarray

    def replicate(self):
        """Return a replicated copy whose `dropout_rng` has been passed through `shard_prng_key`."""
        replicated_state = jax_utils.replicate(self)
        sharded_rng = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=sharded_rng)


def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Yield batches of `batch_size` examples from `dataset`, each passed through `shard`.

    Examples that do not fill a complete batch are dropped. With `shuffle=True` the
    example order is a random permutation drawn from `rng`; otherwise examples are
    visited in index order.
    """
    num_examples = len(dataset)
    num_batches = num_examples // batch_size

    order = jax.random.permutation(rng, num_examples) if shuffle else jnp.arange(num_examples)

    # Keep whole batches only and lay the indices out one row per batch.
    index_grid = order[: num_batches * batch_size].reshape((num_batches, batch_size))

    for batch_indices in index_grid:
        examples = dataset[batch_indices]
        arrays = {name: jnp.array(column) for name, column in examples.items()}
        yield shard(arrays)


def write_train_metric(summary_writer, train_metrics, train_time, step):
    """
    Write the elapsed training time and the accumulated training metrics.

    `train_time` is logged at `step`. Each metric series returned by `get_metrics`
    is logged under ``train_<key>``, one scalar per value, with the last value
    landing on `step` and earlier values on the steps immediately before it.
    """
    summary_writer.scalar("train_time", train_time, step)

    stacked_metrics = get_metrics(train_metrics)
    for name in stacked_metrics:
        series = stacked_metrics[name]
        label = f"train_{name}"
        for offset, value in enumerate(series):
            summary_writer.scalar(label, value, step - len(series) + offset + 1)


def write_eval_metric(summary_writer, eval_metrics, step):
    """Log every entry of `eval_metrics` as an ``eval_<name>`` scalar at `step`."""
    for name in eval_metrics:
        summary_writer.scalar(f"eval_{name}", eval_metrics[name], step)


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """
    Build a learning-rate schedule: linear warmup followed by linear decay.

    The rate rises from 0 to `learning_rate` over `num_warmup_steps`, then falls
    to 0 over the remaining steps, where the total step count is
    ``(train_ds_size // train_batch_size) * num_train_epochs``.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    decay_steps = total_steps - num_warmup_steps

    warmup = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=decay_steps)

    # Switch from the warmup schedule to the decay schedule once warmup ends.
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[num_warmup_steps])

In [5]:
# Model selection: the checkpoint to start from and a separate tokenizer.
# `dtype` is a string here; it is turned into a `jnp` dtype with `getattr`
# when the model is loaded.
model_args = ModelArguments(
    model_name_or_path="EleutherAI/gpt-neo-125M",
    tokenizer_name="flax-community/bertin-roberta-large-spanish",
    dtype="bfloat16"
)

In [6]:
# Dataset selection and preprocessing limits. `dataset_name` and
# `dataset_config_name` are passed to `load_dataset`; `block_size` is the
# requested sequence length and is clamped against the tokenizer limit later.
data_args = DataTrainingArguments(
    dataset_name="oscar", 
    dataset_config_name="unshuffled_deduplicated_es", 
    block_size=1024,
    max_train_samples=10000, 
    max_eval_samples=1000, 
    preprocessing_num_workers=32
)

In [7]:
# Training hyper-parameters. Fields that later cells read but that are not set
# here (`seed`, `max_steps`, `gradient_accumulation_steps`, `adafactor`, the
# Adam betas/epsilon) fall back to the `TrainingArguments` defaults.
training_args = TrainingArguments(
    num_train_epochs=1,
    output_dir="model/", 
    per_device_train_batch_size=16, 
    per_device_eval_batch_size=16, 
    learning_rate=3e-4,
    weight_decay=0.1,
    do_train=True,
    do_eval=True,
    warmup_steps=100,
    push_to_hub=False,
    overwrite_output_dir=True,
    report_to=None,
)

In [8]:
# Open the dataset in streaming mode.
# NOTE: despite its name, `train_dataset` is a mapping of splits — later cells
# select the training split with `train_dataset["train"]`.
train_dataset = load_dataset(
    data_args.dataset_name,
    data_args.dataset_config_name, 
    streaming=True,
)

Downloading:   0%|          | 0.00/5.58k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/359k [00:00<?, ?B/s]

In [9]:
# Load the checkpoint's configuration and the (separately chosen) tokenizer,
# then overwrite the configured vocabulary size with the tokenizer's so the
# model is instantiated with a matching embedding table.
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
)
config.vocab_size = tokenizer.vocab_size

Downloading:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/292 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/846k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/505k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.45M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

In [10]:
# Load the pretrained weights into a model built from the modified config.
# `ignore_mismatched_sizes=True` lets loading continue although the vocabulary
# size was changed; per the log below, the token embedding
# ('transformer', 'wte', 'embedding') is newly initialized as a result.
model = FlaxAutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), ignore_mismatched_sizes=True,
)

Downloading:   0%|          | 0.00/501M [00:00<?, ?B/s]

INFO:absl:Unable to initialize backend 'gpu': Failed precondition: No visible GPU devices.
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
Some weights of FlaxGPTNeoForCausalLM were not initialized from the model checkpoint at EleutherAI/gpt-neo-125M and are newly initialized because the shapes did not match:
- ('transformer', 'wte', 'embedding'): found shape (50257, 768) in the checkpoint and (50265, 768) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
def tokenize_function(examples):
    """Run the notebook-level `tokenizer` over the ``"text"`` field of `examples`."""
    return tokenizer(examples["text"])

# Apply `tokenize_function` to the "train" split; `batched=True` means the
# function receives groups of examples rather than one example at a time.
tokenized_dataset = train_dataset["train"].map(
    tokenize_function,
    batched=True,
)

In [12]:
# Pull the first tokenized example for inspection; `unbatch` is not referenced
# by any other visible cell.
unbatch = next(iter(tokenized_dataset))

In [14]:
# Resolve the sequence length (`block_size`) from the requested value and the
# tokenizer's maximum length.
requested_block_size = data_args.block_size
tokenizer_limit = tokenizer.model_max_length

if requested_block_size is not None:
    # Explicit request: warn if it exceeds the tokenizer limit, then clamp to it.
    if requested_block_size > tokenizer_limit:
        logger.warning(
            f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
            f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
        )
    block_size = min(requested_block_size, tokenizer_limit)
else:
    # No request: start from the tokenizer limit, falling back to 1024 when it
    # exceeds the model's position-embedding capacity.
    block_size = tokenizer_limit
    if block_size > config.max_position_embeddings:
        logger.warning(
            f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
            "Picking 1024 instead. You can change that default value by passing --block_size xxx."
        )
        block_size = 1024

In [15]:
# Seed the PRNG from the training arguments and split off a dedicated key for
# dropout, keeping `rng` for everything else.
rng = jax.random.PRNGKey(training_args.seed)
rng, dropout_rng = jax.random.split(rng)

In [16]:
# Global batch sizes: per-device size times the number of JAX devices (the
# training size is additionally multiplied by the gradient-accumulation factor).
# NOTE(review): `max_steps` is not set when `training_args` is built, so
# `total_train_steps` is derived from its default — confirm that is intended.
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * training_args.gradient_accumulation_steps
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
total_train_steps = training_args.max_steps * training_args.gradient_accumulation_steps

In [17]:
# Size the schedule from the number of training examples.
# `train_dataset` is the split mapping returned by the streaming `load_dataset`
# call (it is indexed with ["train"] elsewhere), so `len(train_dataset)` counts
# splits, not examples. That made steps-per-epoch 0 and left the decay phase
# with a non-positive number of steps (a constant schedule after warmup).
# A streamed split has no length to query, so use the configured sample cap.
linear_decay_lr_schedule_fn = create_learning_rate_fn(
    data_args.max_train_samples,
    train_batch_size,
    training_args.num_train_epochs,
    training_args.warmup_steps,
    training_args.learning_rate,
)

INFO:absl:A polynomial schedule was set with a non-positive `transition_steps` value; this results in a constant schedule with value `init_value`.


In [18]:
# Pick the optimizer; both variants follow the warmup/decay learning-rate schedule.
if not training_args.adafactor:
    # AdamW configured from the training arguments ...
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )
    # ... with global-norm gradient clipping applied first (AdamW path only).
    optimizer = optax.chain(optax.clip_by_global_norm(1), adamw)
else:
    # Adafactor is initialized with its default parameters.
    # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
    optimizer = optax.adafactor(learning_rate=linear_decay_lr_schedule_fn)

In [19]:
# Display the token-embedding parameters (the 'wte' subtree of the model params).
model.params['transformer']['wte']

{'embedding': DeviceArray([[-0.0175781, 0.0108032, -0.0314941, ..., -0.0303955,
               0.0209961, 0.00897217],
              [-0.0314941, 0.0288086, 0.0112305, ..., -0.0390625,
               -0.0390625, 0.0224609],
              [0.012207, 0.0196533, -0.0314941, ..., 0.0108032, 0.0098877,
               -0.0105591],
              ...,
              [0.0310059, 0.00285339, 0.00166321, ..., 0.0146484,
               0.00817871, -0.0100708],
              [-0.0212402, 0.0179443, -0.0332031, ..., 0.00689697,
               -0.0228271, 0.00897217],
              [-0.0159912, 0.0098877, 0.00166321, ..., 0.0112305,
               -0.0153809, 0.00402832]], dtype=bfloat16)}

In [20]:
# Parameter tree containing only the token embedding ('wte'), mirroring its
# position inside `model.params`; this is what the train state is built from.
trainable_params = {'transformer': {'wte': model.params['transformer']['wte']}}

In [21]:
# Build the train state: the optimizer state is created for `trainable_params`.
# NOTE(review): `params` holds only the 'wte' subtree while `apply_fn` is the
# full model's `__call__` — confirm how the remaining parameters reach the
# forward pass.
state = TrainState.create(
    apply_fn=model.__call__, 
    params=trainable_params, 
    tx=optimizer,
    dropout_rng=dropout_rng
)

In [22]:
def loss_fn(logits, labels):
    """
    Mean softmax cross-entropy for next-token prediction.

    The logits at each position are scored against the label of the following
    position, so the last logit and the first label are dropped.
    """
    predictions = logits[..., :-1, :]
    targets = labels[..., 1:]
    num_classes = predictions.shape[-1]
    per_token_loss = optax.softmax_cross_entropy(predictions, onehot(targets, num_classes))
    return per_token_loss.mean()

In [23]:
# Split the state's dropout key: `dropout_rng` (rebinding the notebook-level
# name) is used for this step, `new_dropout_rng` is the key to carry forward.
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

In [24]:
def compute_loss(params):
    """
    Compute the language-modelling loss of `params` on the notebook-level `batch`.

    Reads the globals `batch`, `state` and `dropout_rng`. `batch` must contain a
    "labels" entry; every other entry is forwarded to the model as a keyword
    argument. Returns the scalar produced by `loss_fn`.
    """
    # Read the labels without popping them: `batch` is shared notebook state, and
    # removing the key would make a second call (or a cell re-run) fail with KeyError.
    labels = batch["labels"]
    model_inputs = {name: value for name, value in batch.items() if name != "labels"}
    logits = state.apply_fn(**model_inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
    return loss_fn(logits, labels)

In [25]:
# NOTE(review): this call raises the NameError recorded in the saved output —
# `compute_loss` reads a global `batch`, and no visible cell defines one.
# A batch (including a "labels" entry) must be built before this cell can run.
compute_loss(trainable_params)

NameError: ignored