In [None]:
import sys
import os
from pathlib import Path
sub_project_dir = Path(os.path.abspath(''))
project_dir = sub_project_dir.parent
sys.path.insert(0, project_dir.parent.as_posix())

import evaluate
import datasets
from tqdm import trange, tqdm
import numpy as np
import torch
from matplotlib import pyplot as plt
import timm
from transformers import Trainer, TrainingArguments, AutoImageProcessor

%load_ext autoreload
%autoreload 2
from llm_papers.vit.models import ViT
from llm_papers.vit.dataset import CIFARDataset
from llm_papers.utils import device

%matplotlib inline

In [None]:
# ViT-B/32 at 224x224 with a 100-way head (CIFAR-100). The backbone
# hyper-parameters are chosen to line up with timm's "vit_base_patch32_224"
# so its pretrained weights can be copied in below.
timm_name = "vit_base_patch32_224"
vit_kwargs = dict(
    img_size=224,
    patch_size=32,
    num_classes=100,
    d_model=768,
    d_mlp=3072,
    num_layers=12,
    num_heads=12,
)
model = ViT(**vit_kwargs)

# Pretrained reference model, created with the same class count as `model`
# (presumably so every tensor in its state dict matches for `load_timm`).
model_timm = timm.create_model(timm_name, pretrained=True, num_classes=model.num_classes)
# Preprocessing config published under the timm checkpoint's HF Hub id.
image_processor = AutoImageProcessor.from_pretrained(model_timm.default_cfg['hf_hub_id'])

model.load_timm(model_timm.state_dict())
model.freeze()
# Kept as the cell's last expression so the module tree is displayed.
model.to(device)

ViT(
  (patch_embed): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
  (blocks): Sequential(
    (0): TransformerBlock(
      (attn): Attention(
        (Wqkv): Linear(in_features=768, out_features=2304, bias=True)
        (Wo): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm_attn): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (dropout): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
      )
      (norm_mlp): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    )
    (1): TransformerBlock(
      (attn): Attention(
        (Wqkv): Linear(in_features=768, out_features=2304, bias=True)
        (Wo): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm_attn): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc

In [5]:
# CIFAR-100 from the Hugging Face Hub. Both splits are wrapped with the same
# image processor that was loaded for the pretrained timm backbone.
dataset = datasets.load_dataset("uoft-cs/cifar100")
train_dataset, test_dataset = (
    CIFARDataset(dataset[split_name], image_processor)
    for split_name in ("train", "test")
)

In [6]:
def compute_metrics(eval_pred):
    """Top-1 accuracy callback for ``transformers.Trainer``.

    Args:
        eval_pred: an ``EvalPrediction`` (or a plain ``(logits, labels)``
            pair) holding predictions and labels for the whole eval set.
            The class dimension of ``logits`` is assumed to be the last axis.

    Returns:
        ``{"accuracy": float}``: the fraction of samples whose arg-max class
        equals the label. This is the same key/value shape as
        ``evaluate``'s "accuracy" metric, so ``metric_for_best_model``
        keeps working.
    """
    # Indexing (rather than 2-way unpacking) also works when the
    # EvalPrediction carries extra elements such as inputs or losses.
    logits, labels = eval_pred[0], eval_pred[1]
    # Trainer passes a tuple when the model emits several non-loss outputs;
    # the first entry is taken as the class logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    # Computed directly: calling `evaluate.load("accuracy")` here re-resolved
    # the metric from the hub/cache on every single evaluation.
    accuracy = float(np.mean(predictions == np.asarray(labels)))
    return {"accuracy": accuracy}


# Training hyper-parameters.
lr = 1e-3
batch_size = 512  # effective batch size per optimizer step (single device)
grad_accum_steps = 8  # micro-batches accumulated into one optimizer step
epochs = 5
wd = 0.1
logging_steps = 10
# Evaluate and checkpoint on the same cadence: `load_best_model_at_end`
# requires save_steps to be a round multiple of eval_steps.
eval_steps = logging_steps * 4

# `batch_size` is split into `grad_accum_steps` micro-batches; a remainder
# would silently make the effective batch smaller than the `bs` tag below.
if batch_size % grad_accum_steps:
    raise ValueError(
        f"batch_size ({batch_size}) must be divisible by "
        f"grad_accum_steps ({grad_accum_steps})"
    )

output_dir = (
    sub_project_dir
    / "checkpoints"
    / f"vit_b_32_224_cifar100_lr{lr}_bs{batch_size}_ep{epochs}_wd{wd}"
)
output_dir.mkdir(parents=True, exist_ok=True)

# W&B destination. `setdefault` respects an entity/project already exported
# in the environment instead of overwriting it with these defaults.
os.environ.setdefault("WANDB_ENTITY", "ztzhu11")
os.environ.setdefault("WANDB_PROJECT", "ViT")

args = TrainingArguments(
    output_dir=output_dir.as_posix(),
    per_device_train_batch_size=batch_size // grad_accum_steps,
    gradient_accumulation_steps=grad_accum_steps,
    num_train_epochs=epochs,
    learning_rate=lr,
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr_rate": 0.1},
    weight_decay=wd,
    seed=42,
    data_seed=43,
    fp16=True,
    fp16_full_eval=True,
    eval_strategy="steps",
    eval_on_start=True,
    per_device_eval_batch_size=128,
    eval_steps=eval_steps,
    save_strategy="steps",
    save_steps=eval_steps,
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    log_level="info",
    logging_first_step=True,
    logging_steps=logging_steps,
    report_to="wandb",
    run_name=output_dir.name,
)
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

Using auto half precision backend


In [7]:
# Run training with the arguments configured above; the returned TrainOutput
# is the cell's last expression, so it is displayed.
train_output = trainer.train()
train_output

***** Running training *****
  Num examples = 50,000
  Num Epochs = 5
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 512
  Gradient Accumulation steps = 8
  Total optimization steps = 490
  Number of trainable parameters = 76,900
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mztzhu1[0m ([33mztzhu11[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Could not log the number of model parameters in Weights & Biases due to an AttributeError.

***** Running Evaluation *****
  Num examples = 10000
  Batch size = 128


Step,Training Loss,Validation Loss,Accuracy
0,No log,5.036108,0.0091
40,0.663000,0.621115,0.8328
80,0.515400,0.49388,0.8586
120,0.393500,0.461918,0.8657
160,0.373900,0.450013,0.8665
200,0.339500,0.436018,0.8721
240,0.285300,0.431257,0.8712
280,0.322600,0.424499,0.875
320,0.267300,0.421453,0.8749
360,0.260700,0.420154,0.8749



***** Running Evaluation *****
  Num examples = 10000
  Batch size = 128
Saving model checkpoint to /workspace/LLM-papers-from-scratch/llm_papers/vit/checkpoints/vit_b_32_224_cifar100_lr0.001_bs512_ep5_wd0.1/checkpoint-40
Trainer.model is not a `PreTrainedModel`, only saving its state dict.

***** Running Evaluation *****
  Num examples = 10000
  Batch size = 128
Saving model checkpoint to /workspace/LLM-papers-from-scratch/llm_papers/vit/checkpoints/vit_b_32_224_cifar100_lr0.001_bs512_ep5_wd0.1/checkpoint-80
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Deleting older checkpoint [/workspace/LLM-papers-from-scratch/llm_papers/vit/checkpoints/vit_b_32_224_cifar100_lr0.001_bs512_ep5_wd0.1/checkpoint-40] due to args.save_total_limit

***** Running Evaluation *****
  Num examples = 10000
  Batch size = 128
Saving model checkpoint to /workspace/LLM-papers-from-scratch/llm_papers/vit/checkpoints/vit_b_32_224_cifar100_lr0.001_bs512_ep5_wd0.1/checkpoint-120
Trainer.mod

TrainOutput(global_step=490, training_loss=0.4452876207779865, metrics={'train_runtime': 1402.5514, 'train_samples_per_second': 178.247, 'train_steps_per_second': 0.349, 'total_flos': 0.0, 'train_loss': 0.4452876207779865, 'epoch': 5.0})