In [1]:
from datasets import load_dataset

ds = load_dataset("dominguesm/alpaca-data-pt-br")

In [2]:
from transformers import AutoTokenizer

# Example usage of the GPT model
tokenizer = AutoTokenizer.from_pretrained("pierreguillou/gpt2-small-portuguese")

In [3]:
df = ds["train"].to_pandas()

In [4]:
df["concat"] = (
    df["instruction"] + "\n\n" + df["input"].fillna("") + "\n\n" + df["output"]
)

In [5]:
# create a column with count of tokens for each row
df["tokens"] = df["concat"].apply(lambda x: len(tokenizer.encode(x)))

In [6]:
df["tokens"].describe(percentiles=[0.25, 0.5, 0.75, 0.9, 0.95, 0.99])

count    51759.000000
mean        74.283371
std         40.517703
min         10.000000
25%         38.000000
50%         67.000000
75%        110.000000
90%        127.000000
95%        136.000000
99%        159.000000
max        838.000000
Name: tokens, dtype: float64

## Train

In [7]:
import torch

from model import GPT, GPTConfig

device = "cuda" if torch.cuda.is_available() else "cpu"

config = GPTConfig(
    block_size=128,
    vocab_size=tokenizer.vocab_size,
    n_layer=4,
    n_head=4,
    n_embd=256,
    n_experts=2,
    capacity_factor=1,
    k=2,
    experts_weight=0.01,
    router_weight=0.001,
    dropout=0.2,
    bias=True,
)

model = GPT(config)

number of parameters: 18.13M


In [8]:
inputs = tokenizer(
    df["concat"].tolist(),
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=128,
)

input_ids = inputs["input_ids"]
targets = input_ids.clone()

targets[:, :-1] = input_ids[:, 1:]
targets[:, -1] = -1

# convert targets 0 to -1
targets[targets == 0] = -1

# print("input ids:", inputs["input_ids"], "shape:", inputs["input_ids"].shape)

In [9]:
# invert 0 -> 1 and 1 -> 0 in attention mask
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

# create a new upper triangular mask
upper_mask = torch.triu(torch.ones(input_ids.shape[1], input_ids.shape[1]), diagonal=1)

# apply the mask to the attention mask
attention_mask = attention_mask.unsqueeze(1) * upper_mask.unsqueeze(0)

# conver to bool
attention_mask = attention_mask.bool()

In [10]:
model(input_ids[:5, :], attention_mask=attention_mask[:5, :, :])[0].shape

Router loss: 0.0005323818186298013, Balance loss: 0.009999999776482582
Router loss: 0.0005640311283059418, Balance loss: 0.009999999776482582
Router loss: 0.0006321023683995008, Balance loss: 0.009999999776482582
Router loss: 0.0004682167782448232, Balance loss: 0.009999999776482582


torch.Size([5, 1, 50257])

### Train loop

In [11]:
# send data to device
model = model.to(device)
input_ids = input_ids[:32, :].to(device)  # limit to 16 for testing
attention_mask = attention_mask[:32, :, :].to(device)
targets = targets[:32, :].to(device)

# get only one sample for each
input_ids = input_ids[0:1, :]
attention_mask = attention_mask[0:1, :, :]
targets = targets[0:1, :]

In [12]:
from tqdm.notebook import tqdm

with torch.no_grad():
    logits, loss = model(input_ids, attention_mask, targets)
    # print("logits:", logits, "shape:", logits.shape)
print("loss:", loss.item() if loss is not None else "N/A")
if loss is not None:
    perplexity = torch.exp(loss)
    print("perplexity:", perplexity.item())

# pred_tokens = logits.argmax(dim=-1)
# tokens = pred_tokens[0].tolist()
# print("predicted tokens:", tokens)
# print("predicted text:", tokenizer.decode(tokens))

optimizer = model.configure_optimizers(
    weight_decay=0.0,
    learning_rate=3e-4,
    betas=(0.9, 0.95),
    device_type="cuda",
)

pbar = tqdm(range(500), desc="Training Epochs")

for epoch in pbar:
    model.train()
    optimizer.zero_grad()
    logits, loss = model(input_ids, attention_mask, targets)
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        logits, loss = model(input_ids, attention_mask, targets)

    perplexity = torch.exp(loss)
    pbar.set_postfix(loss=loss.item(), perplexity=perplexity.item())

    pred_tokens = logits.argmax(dim=-1)
    tokens = pred_tokens[0].tolist()

# Router loss: 0.0004672443028539419, Balance loss: 0.019999997690320015

Router loss: 0.0004927273839712143, Balance loss: 0.009999999776482582
Router loss: 0.0007793569238856435, Balance loss: 0.009999999776482582
Router loss: 0.0005962931318208575, Balance loss: 0.009999999776482582
Router loss: 0.00046999318874441087, Balance loss: 0.009999999776482582
loss: 10.907464027404785
perplexity: 54582.25390625
using fused AdamW: True


Training Epochs:   0%|          | 0/500 [00:00<?, ?it/s]

Router loss: 0.0005381920491345227, Balance loss: 0.009999999776482582
Router loss: 0.0007281192811205983, Balance loss: 0.009999999776482582
Router loss: 0.0006311602191999555, Balance loss: 0.009999999776482582
Router loss: 0.0005019208765588701, Balance loss: 0.009999999776482582
Router loss: 0.00043407001066952944, Balance loss: 0.009999999776482582
Router loss: 0.0005083784344606102, Balance loss: 0.009999999776482582
Router loss: 0.00041413685539737344, Balance loss: 0.009999999776482582
Router loss: 0.0004415688163135201, Balance loss: 0.010000000707805157
Router loss: 0.00042982472223229706, Balance loss: 0.009999999776482582
Router loss: 0.0005580366123467684, Balance loss: 0.009999999776482582
Router loss: 0.00047444243682548404, Balance loss: 0.009999998845160007
Router loss: 0.00046147702960297465, Balance loss: 0.009999999776482582
Router loss: 0.0003632362640928477, Balance loss: 0.009999999776482582
Router loss: 0.0003565136867109686, Balance loss: 0.009999999776482582
R

In [15]:
with torch.no_grad():
    generated_ids = model.generate(
        input_ids[:, :1].clone(),
        max_new_tokens=100,
        temperature=0.1,
        top_k=1,
        greedy=True,
    )
generated_text = tokenizer.decode(generated_ids[0].tolist())
_df = df.head(1)
_df["generated_text"] = [
    tokenizer.decode(generated_ids[i].tolist()) for i in range(len(generated_ids))
]
print(f"Generated text after epoch {epoch + 1}: {generated_text}")


Router loss: 1.0006592674471904e-06, Balance loss: 0.009999998845160007
Router loss: 7.375970199063886e-07, Balance loss: 0.009999999776482582
Router loss: 3.7990373584761983e-06, Balance loss: 0.009999999776482582
Router loss: 9.688825230114162e-06, Balance loss: 0.009999999776482582
Router loss: 6.226515552043566e-07, Balance loss: 0.009999999776482582
Router loss: 5.606545073533198e-06, Balance loss: 0.009999999776482582
Router loss: 1.3037845747021493e-05, Balance loss: 0.009999999776482582
Router loss: 1.4436724995903205e-05, Balance loss: 0.009999999776482582
Router loss: 1.3016343700655852e-06, Balance loss: 0.009999999776482582
Router loss: 8.00445741333533e-06, Balance loss: 0.009999999776482582
Router loss: 1.9286160750198178e-05, Balance loss: 0.009999999776482582
Router loss: 1.9746912585105747e-05, Balance loss: 0.009999999776482582
Router loss: 1.5882696970948018e-06, Balance loss: 0.009999999776482582
Router loss: 7.777319297019858e-06, Balance loss: 0.009999999776482582

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  _df["generated_text"] = [


In [16]:
_df.head(16).style

Unnamed: 0,instruction,input,output,concat,tokens,generated_text
0,Dê três dicas para se manter saudável.,,1. Coma uma dieta equilibrada e certifique-se de incluir muitas frutas e vegetais. 2. Exercite-se regularmente para manter seu corpo ativo e forte. 3. Durma o suficiente e mantenha um horário de sono consistente.,Dê três dicas para se manter saudável. 1. Coma uma dieta equilibrada e certifique-se de incluir muitas frutas e vegetais. 2. Exercite-se regularmente para manter seu corpo ativo e forte. 3. Durma o suficiente e mantenha um horário de sono consistente.,62,Dêêê três três três três três três três três três três dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas dicas se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se se
