In [1]:
import torch

In [2]:
from datasets import load_dataset

# Food-101 from the Hugging Face Hub; later cells index it by split name ("train", "validation").
ds = load_dataset(path="food101")

In [3]:
# Class names in label-index order, plus lookups in both directions; the model
# config below is given both maps so predictions come back as readable names.
labels = ds["train"].features["label"].names
label2id = {name: idx for idx, name in enumerate(labels)}
id2label = dict(enumerate(labels))

id2label[2]

'baklava'

In [4]:
from transformers import AutoImageProcessor

# Preprocessing settings (target size, normalization mean/std) for the ViT checkpoint;
# the torchvision pipelines below read them from this object.
image_processor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224-in21k"
)

In [5]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# Both pipelines use the image processor's target side length and its
# normalization statistics.
crop_size = image_processor.size["height"]
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Training: random resized crops and horizontal flips as augmentation.
train_transforms = Compose([
    RandomResizedCrop(crop_size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Evaluation: deterministic resize followed by a centered square crop.
val_transforms = Compose([
    Resize(crop_size),
    CenterCrop(crop_size),
    ToTensor(),
    normalize,
])

def _transform_images(example_batch, transform):
    """Add a ``pixel_values`` column by applying ``transform`` to every image in the batch.

    Shared by the train and validation preprocessors so both prepare images the same way.

    Args:
        example_batch: batch dict with an ``"image"`` column of image objects.
        transform: callable mapping an RGB image to a tensor.

    Returns:
        The same ``example_batch`` object, mutated to include ``"pixel_values"``.
    """
    # Convert to RGB so every tensor has the 3 channels the ViT patch embedding expects.
    example_batch["pixel_values"] = [transform(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch


def preprocess_train(example_batch):
    """Apply the augmenting ``train_transforms`` to a batch (used via ``set_transform``)."""
    return _transform_images(example_batch, train_transforms)


def preprocess_val(example_batch):
    """Apply the deterministic ``val_transforms`` to a batch (used via ``set_transform``)."""
    return _transform_images(example_batch, val_transforms)

In [6]:
# Attach the preprocessors lazily: set_transform applies them to batches as they are
# accessed rather than materializing tensors up front. It changes the split objects
# in place, so ds["train"] / ds["validation"] carry these transforms too.
train_ds, val_ds = ds["train"], ds["validation"]
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

In [7]:
def collate_fn(examples):
    """Assemble a list of transformed examples into one model-ready batch.

    Each example's ``pixel_values`` tensor is stacked along a new leading batch
    dimension, and the integer ``label`` fields are gathered into a tensor stored
    under ``"labels"`` (the name given to the Trainer via ``label_names``).
    """
    images = [item["pixel_values"] for item in examples]
    targets = [item["label"] for item in examples]
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }

In [8]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# ViT backbone with a classification head sized from the Food-101 label maps.
# The head is not in the checkpoint, so it is newly initialized (see the warning below);
# ignore_mismatched_sizes tolerates any shape mismatch between checkpoint and new head.
model = AutoModelForImageClassification.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224-in21k",
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# Show the module tree; the submodule names it lists ("query", "value", "dense", ...)
# are what the PEFT configs below use as target_modules.
model

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

In [10]:
from peft import AdaLoraConfig, get_peft_model, IA3Model, IA3Config, LoraConfig

# AdaLoRA builds each adapted layer at rank ``init_r`` and reallocates the rank budget
# toward an average of ``target_r``. ``r`` (inherited from LoraConfig) is not used by
# AdaLoRA, so the target rank is set explicitly via ``target_r`` (same value, 8).
# tinit / tfinal / deltaT: warm-up steps, final fine-tuning steps, and the interval
# between budget reallocations.
# NOTE(review): the training run below logged 735 optimizer steps in total, fewer than
# tfinal=1000, so the reallocation window looks empty; newer PEFT releases also require
# ``total_step``. Reallocation additionally relies on ``update_and_allocate(global_step)``
# being called during training, which the plain Trainer used here is not shown doing —
# confirm before drawing AdaLoRA-vs-LoRA conclusions.
ada_lora_config = AdaLoraConfig(
    target_r=8,
    init_r=12,
    tinit=200,
    tfinal=1000,
    deltaT=10,
    target_modules=["query", "value"],
    modules_to_save=["classifier"],
)

# The classifier head is freshly initialized, so every config must keep it trainable
# via modules_to_save; otherwise it stays frozen at its random initialization.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["query", "value"],
    lora_dropout=0.01,
    modules_to_save=["classifier"],
)
# IA3Config sets its own peft_type, so it is not passed explicitly.
ia3_config = IA3Config(
    target_modules=["key", "value", "dense"],
    feedforward_modules=["dense"],
    modules_to_save=["classifier"],
)

# NOTE(review): get_peft_model injects adapter layers into `model` itself; reload the base
# model before re-running this cell or switching to another config.
peft_model = get_peft_model(model, ada_lora_config)
# Output for the AdaLoRA config: trainable params: 520,325 || all params: 86,396,674 || trainable%: ~0.60
peft_model.print_trainable_parameters()
#"trainable params: 520,325 || all params: 87,614,722 || trainable%: 0.5938785036606062"

trainable params: 520,325 || all params: 86,396,674 || trainable%: 0.6022511931419953


In [None]:
def compute_metrics(eval_pred=None):
    """Top-1 accuracy, usable as ``Trainer(compute_metrics=compute_metrics)``.

    Args:
        eval_pred: an ``EvalPrediction`` (or any indexable ``(predictions, label_ids)``
            pair) as passed by the Trainer at evaluation time. ``predictions`` holds
            per-class scores with classes on the last axis; if it is a tuple of
            model outputs, its first element is taken as the logits.

    Returns:
        ``{"accuracy": float}``, or ``None`` when called without arguments
        (matching the previous placeholder).
    """
    import numpy as np

    if eval_pred is None:
        return None
    predictions, label_ids = eval_pred[0], eval_pred[1]
    # Some models return several outputs; the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_ids = np.argmax(np.asarray(predictions), axis=-1)
    return {"accuracy": float(np.mean(predicted_ids == np.asarray(label_ids)))}

In [11]:
from transformers import TrainingArguments, Trainer

batch_size = 16

# Effective optimizer batch = per-device batch (16) x gradient_accumulation_steps (32)
# x number of devices.
# NOTE(review): newer transformers releases renamed `evaluation_strategy` to `eval_strategy`.
args = TrainingArguments(
    output_dir="peft_test",
    # Data: keep the raw "image" column, which the set_transform preprocessors read.
    remove_unused_columns=False,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=32,
    label_names=["labels"],
    # Optimization.
    learning_rate=5e-3,
    num_train_epochs=5,
    fp16=True,
    # Evaluate and checkpoint once per epoch; reload the best checkpoint at the end.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    logging_steps=10,
)

In [12]:
# Fine-tune the PEFT-wrapped model; collate_fn turns transformed examples into batches.
# NOTE(review): newer transformers releases replace `tokenizer=` with `processing_class=`.
trainer = Trainer(
    model=peft_model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor,
)
trainer.train()

  0%|          | 0/735 [00:00<?, ?it/s]

{'loss': 3.71, 'learning_rate': 0.004931972789115647, 'epoch': 0.07}
{'loss': 1.9883, 'learning_rate': 0.0048639455782312924, 'epoch': 0.14}
{'loss': 1.2399, 'learning_rate': 0.004795918367346939, 'epoch': 0.2}
{'loss': 1.0627, 'learning_rate': 0.004727891156462586, 'epoch': 0.27}
{'loss': 0.9556, 'learning_rate': 0.004659863945578231, 'epoch': 0.34}
{'loss': 0.917, 'learning_rate': 0.004591836734693878, 'epoch': 0.41}
{'loss': 0.8947, 'learning_rate': 0.004523809523809524, 'epoch': 0.47}
{'loss': 0.8512, 'learning_rate': 0.00445578231292517, 'epoch': 0.54}
{'loss': 0.892, 'learning_rate': 0.004387755102040816, 'epoch': 0.61}
{'loss': 0.8804, 'learning_rate': 0.004319727891156463, 'epoch': 0.68}
{'loss': 0.8493, 'learning_rate': 0.004251700680272108, 'epoch': 0.74}
{'loss': 0.8486, 'learning_rate': 0.004183673469387756, 'epoch': 0.81}
{'loss': 0.8406, 'learning_rate': 0.0041156462585034016, 'epoch': 0.88}
{'loss': 0.802, 'learning_rate': 0.004047619047619048, 'epoch': 0.95}


  0%|          | 0/1579 [00:00<?, ?it/s]

{'eval_loss': 0.4648701250553131, 'eval_runtime': 328.176, 'eval_samples_per_second': 76.94, 'eval_steps_per_second': 4.811, 'epoch': 0.99}
{'loss': 0.7775, 'learning_rate': 0.003979591836734694, 'epoch': 1.01}
{'loss': 0.7401, 'learning_rate': 0.0039115646258503405, 'epoch': 1.08}
{'loss': 0.7486, 'learning_rate': 0.0038435374149659862, 'epoch': 1.15}
{'loss': 0.7394, 'learning_rate': 0.0037755102040816324, 'epoch': 1.22}
{'loss': 0.7224, 'learning_rate': 0.0037074829931972794, 'epoch': 1.28}
{'loss': 0.7298, 'learning_rate': 0.0036394557823129256, 'epoch': 1.35}
{'loss': 0.7716, 'learning_rate': 0.0035714285714285718, 'epoch': 1.42}
{'loss': 0.7458, 'learning_rate': 0.003503401360544218, 'epoch': 1.49}
{'loss': 0.7402, 'learning_rate': 0.003435374149659864, 'epoch': 1.55}
{'loss': 0.7094, 'learning_rate': 0.0033673469387755102, 'epoch': 1.62}
{'loss': 0.7616, 'learning_rate': 0.0032993197278911564, 'epoch': 1.69}
{'loss': 0.7013, 'learning_rate': 0.003231292517006803, 'epoch': 1.76}


  0%|          | 0/1579 [00:00<?, ?it/s]

{'eval_loss': 0.41003984212875366, 'eval_runtime': 341.3613, 'eval_samples_per_second': 73.969, 'eval_steps_per_second': 4.626, 'epoch': 1.99}
{'loss': 0.6716, 'learning_rate': 0.0029591836734693877, 'epoch': 2.03}
{'loss': 0.6728, 'learning_rate': 0.002891156462585034, 'epoch': 2.1}
{'loss': 0.6483, 'learning_rate': 0.00282312925170068, 'epoch': 2.16}
{'loss': 0.6904, 'learning_rate': 0.002755102040816326, 'epoch': 2.23}
{'loss': 0.6771, 'learning_rate': 0.002687074829931973, 'epoch': 2.3}
{'loss': 0.6848, 'learning_rate': 0.0026190476190476194, 'epoch': 2.37}
{'loss': 0.6883, 'learning_rate': 0.0025510204081632655, 'epoch': 2.43}
{'loss': 0.6542, 'learning_rate': 0.0024829931972789117, 'epoch': 2.5}
{'loss': 0.6782, 'learning_rate': 0.002414965986394558, 'epoch': 2.57}
{'loss': 0.6697, 'learning_rate': 0.002346938775510204, 'epoch': 2.64}
{'loss': 0.6685, 'learning_rate': 0.0022789115646258506, 'epoch': 2.7}
{'loss': 0.6469, 'learning_rate': 0.002210884353741497, 'epoch': 2.77}
{'los

  0%|          | 0/1579 [00:00<?, ?it/s]

{'eval_loss': 0.37821701169013977, 'eval_runtime': 337.8259, 'eval_samples_per_second': 74.743, 'eval_steps_per_second': 4.674, 'epoch': 2.99}
{'loss': 0.6124, 'learning_rate': 0.0019387755102040817, 'epoch': 3.04}
{'loss': 0.6352, 'learning_rate': 0.001870748299319728, 'epoch': 3.11}
{'loss': 0.584, 'learning_rate': 0.0018027210884353742, 'epoch': 3.18}
{'loss': 0.615, 'learning_rate': 0.0017346938775510204, 'epoch': 3.24}
{'loss': 0.6225, 'learning_rate': 0.0016734693877551022, 'epoch': 3.31}
{'loss': 0.6091, 'learning_rate': 0.0016054421768707484, 'epoch': 3.38}
{'loss': 0.645, 'learning_rate': 0.0015374149659863946, 'epoch': 3.45}
{'loss': 0.6192, 'learning_rate': 0.001469387755102041, 'epoch': 3.51}
{'loss': 0.6277, 'learning_rate': 0.0014013605442176871, 'epoch': 3.58}
{'loss': 0.5923, 'learning_rate': 0.0013333333333333333, 'epoch': 3.65}
{'loss': 0.6192, 'learning_rate': 0.0012653061224489795, 'epoch': 3.72}
{'loss': 0.619, 'learning_rate': 0.0011972789115646258, 'epoch': 3.78}

  0%|          | 0/1579 [00:00<?, ?it/s]

{'eval_loss': 0.36065277457237244, 'eval_runtime': 335.3836, 'eval_samples_per_second': 75.287, 'eval_steps_per_second': 4.708, 'epoch': 3.99}
{'loss': 0.5912, 'learning_rate': 0.0009251700680272108, 'epoch': 4.05}
{'loss': 0.5885, 'learning_rate': 0.0008571428571428572, 'epoch': 4.12}
{'loss': 0.5673, 'learning_rate': 0.0007891156462585034, 'epoch': 4.19}
{'loss': 0.5821, 'learning_rate': 0.0007210884353741496, 'epoch': 4.26}
{'loss': 0.6128, 'learning_rate': 0.000653061224489796, 'epoch': 4.33}
{'loss': 0.5744, 'learning_rate': 0.0005850340136054422, 'epoch': 4.39}
{'loss': 0.5872, 'learning_rate': 0.0005170068027210885, 'epoch': 4.46}
{'loss': 0.5888, 'learning_rate': 0.0004489795918367347, 'epoch': 4.53}
{'loss': 0.582, 'learning_rate': 0.000380952380952381, 'epoch': 4.6}
{'loss': 0.5681, 'learning_rate': 0.00031292517006802724, 'epoch': 4.66}
{'loss': 0.5654, 'learning_rate': 0.00024489795918367346, 'epoch': 4.73}
{'loss': 0.5919, 'learning_rate': 0.00017687074829931976, 'epoch': 

  0%|          | 0/1579 [00:00<?, ?it/s]

{'eval_loss': 0.3536034822463989, 'eval_runtime': 309.0997, 'eval_samples_per_second': 81.689, 'eval_steps_per_second': 5.108, 'epoch': 4.97}
{'train_runtime': 9997.6021, 'train_samples_per_second': 37.884, 'train_steps_per_second': 0.074, 'train_loss': 0.7547856937460348, 'epoch': 4.97}


TrainOutput(global_step=735, training_loss=0.7547856937460348, metrics={'train_runtime': 9997.6021, 'train_samples_per_second': 37.884, 'train_steps_per_second': 0.074, 'train_loss': 0.7547856937460348, 'epoch': 4.97})

- AdaLoRA: 10.6 s/iter
- LoRA: 13.22 s/iter