In [None]:
!pip install transformers
!pip install datasets
!pip install evaluate

Imports

In [2]:
import numpy as np
import torch
from datasets import load_dataset
from evaluate import load
from transformers import (
    ViTForImageClassification,
    ViTFeatureExtractor,
    Trainer,
    TrainingArguments,
)

### train_vit_clf.py

In [3]:
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["labels"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}


def create_dataloaders_and_mappings(data_path):
    dataset = load_dataset("imagefolder", data_dir=data_path)

    splits = dataset["train"].train_test_split(test_size=0.33)
    dataset["train"] = splits["train"]
    dataset["val"] = splits["test"]

    id2label = {
        id: label for id, label in enumerate(dataset["train"].features["label"].names)
    }

    label2id = {label: id for id, label in id2label.items()}

    return dataset, id2label, label2id


def compute_metrics(eval_pred):
    metric1 = load("accuracy")
    metric2 = load("precision")
    metric3 = load("recall")
    metric4 = load("f1")

    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = metric1.compute(predictions=predictions, references=labels)["accuracy"]
    precision = metric2.compute(predictions=predictions, references=labels, average="weighted")["precision"]
    recall = metric3.compute(predictions=predictions, references=labels, average="weighted")["recall"]
    f1 = metric4.compute(predictions=predictions, references=labels, average="weighted")["f1"]
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


def main():
    from google.colab import drive
    drive.mount('/content/drive')
    colab_data_path = "/content/drive/MyDrive/Seminar2/data/ribe_512x768"
    colab_dir = "/content/drive/MyDrive/Seminar2/model"
    model_id = "google/vit-base-patch16-224"

    dataset, id2label, label2id = create_dataloaders_and_mappings(colab_data_path)

    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id, do_resize=False, patch_size=64)

    def transform(example_batch):
        inputs = feature_extractor(
            [x.convert("RGB") for x in example_batch["image"]], return_tensors="pt"
        )
        inputs["labels"] = example_batch["label"]
        return inputs

    dataset = dataset.with_transform(transform)

    model = ViTForImageClassification.from_pretrained(
        pretrained_model_name_or_path=model_id,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir=colab_dir,
        per_device_train_batch_size=4,
        evaluation_strategy="steps",
        num_train_epochs=5,
        fp16=True,
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        remove_unused_columns=False,
        push_to_hub=False,
        report_to="tensorboard",
        load_best_model_at_end=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=dataset["train"],
        eval_dataset=dataset["val"],
    )

    train_results = trainer.train()
    trainer.save_model()
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    trainer.save_state()

    metrics = trainer.evaluate(dataset["test"])
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)


### Experiment

In [4]:
main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Resolving data files:   0%|          | 0/1332 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1434 [00:00<?, ?it/s]



  0%|          | 0/2 [00:00<?, ?it/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
100,0.709,0.820743,0.747727,0.883004,0.747727,0.734477
200,0.5575,0.3814,0.813636,0.863129,0.813636,0.769381
300,0.1275,0.242318,0.911364,0.923508,0.911364,0.912789
400,0.1635,0.277191,0.906818,0.913864,0.906818,0.907465
500,0.2165,0.224583,0.936364,0.936692,0.936364,0.936519
600,0.0317,0.138756,0.968182,0.968732,0.968182,0.967871
700,0.0002,0.106988,0.972727,0.973665,0.972727,0.972317
800,0.0763,0.053277,0.984091,0.984694,0.984091,0.984169
900,0.0001,0.051447,0.988636,0.98863,0.988636,0.988624
1000,0.0,0.07217,0.981818,0.982667,0.981818,0.981923


Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.55k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.77k [00:00<?, ?B/s]

***** train metrics *****
  epoch                    =          5.0
  total_flos               = 2522498038GF
  train_loss               =       0.2014
  train_runtime            =   0:52:18.18
  train_samples_per_second =        1.421
  train_steps_per_second   =        0.355


***** eval metrics *****
  epoch                   =        5.0
  eval_accuracy           =     0.5314
  eval_f1                 =     0.5289
  eval_loss               =     3.4565
  eval_precision          =     0.6794
  eval_recall             =     0.5314
  eval_runtime            = 0:10:27.77
  eval_samples_per_second =      2.284
  eval_steps_per_second   =      0.287
