Проверка работы wandb

In [1]:
import wandb

import os
import numpy as np
from datasets import load_dataset
from transformers import TrainingArguments, Trainer
from transformers import AutoTokenizer, AutoModelForSequenceClassification

Определяем функцию вычисления метрик

In [2]:
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": np.mean(predictions == labels)}

Готовим предварительные данные

In [3]:
dataset = load_dataset("yelp_review_full")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

small_train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = dataset["test"].shuffle(seed=42).select(range(300))

small_train_dataset = small_train_dataset.map(tokenize_function, batched=True)
small_eval_dataset = small_train_dataset.map(tokenize_function, batched=True)

README.md: 0.00B [00:00, ?B/s]

yelp_review_full/train-00000-of-00001.pa(…):   0%|          | 0.00/299M [00:00<?, ?B/s]

yelp_review_full/test-00000-of-00001.par(…):   0%|          | 0.00/23.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/650000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/50000 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

скачиваем модель

In [4]:
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=5)

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


даем название проекту

In [5]:
os.environ["WANDB_PROJECT"]="my-awesome-project"

можно поставить false чтобы ускорить работу программы

In [6]:
os.environ["WANDB_WATCH"]="false"

обязательно передаем "wandb" в параметр report_to, чтобы включить ведение журнала wandb

In [7]:
training_args = TrainingArguments(
    output_dir='models',
    report_to="wandb",
    logging_steps=5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    eval_strategy="steps",
    eval_steps=20,
    max_steps = 100,
    save_steps = 100
)

передаем параметры для обучения и обучаем модель

In [8]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)
trainer.train()

[34m[1mwandb[0m: Loading settings from /root/.config/wandb/settings
[34m[1mwandb[0m: [wandb.login()] Loaded credentials for http://wandb:8080 from /root/.netrc.
[34m[1mwandb[0m: Currently logged in as: [33m-skorinaka[0m to [32mhttp://wandb:8080[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Accuracy
20,1.435,1.349941,0.474
40,1.1329,1.046686,0.598
60,0.9405,0.903021,0.662
80,0.8583,0.792507,0.712
100,0.7089,0.749986,0.76


TrainOutput(global_step=100, training_loss=1.104944987297058, metrics={'train_runtime': 137.0741, 'train_samples_per_second': 23.345, 'train_steps_per_second': 0.73, 'total_flos': 414380191457280.0, 'train_loss': 1.104944987297058, 'epoch': 3.125})

останавливаем wandb

In [9]:
wandb.finish()

0,1
eval/accuracy,▁▄▆▇█
eval/loss,█▄▃▁▁
eval/runtime,▁▂▄▅█
eval/samples_per_second,█▇▅▄▁
eval/steps_per_second,█▇▅▄▁
train/epoch,▁▁▂▂▂▂▃▃▄▄▄▄▅▅▅▅▆▆▇▇▇▇████
train/global_step,▁▁▂▂▂▂▃▃▄▄▄▄▅▅▅▅▆▆▇▇▇▇████
train/grad_norm,▂▁▁▄▅▇▇▃▅▃▅▅▅█▄▆▅▅▇▄
train/learning_rate,██▇▇▇▆▆▅▅▅▄▄▄▃▃▂▂▂▁▁
train/loss,██▇▇▆▆▅▄▄▄▄▃▃▃▂▂▂▂▂▁

0,1
eval/accuracy,0.76
eval/loss,0.74999
eval/runtime,9.2482
eval/samples_per_second,108.129
eval/steps_per_second,3.46
total_flos,414380191457280.0
train/epoch,3.125
train/global_step,100
train/grad_norm,3.87702
train/learning_rate,0
