In [2]:
from transformers import ViTFeatureExtractor, ViTForImageClassification, Trainer, TrainingArguments
from datasets import load_dataset, DatasetDict, load_metric
import numpy as np
import torch
from PIL import Image

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load and preprocess dataset
ds = load_dataset('betul2').rename_column('label', 'labels')
ds = ds['train'].train_test_split(test_size=0.1)
train_val = ds['train'].train_test_split(test_size=0.1)
ds['train'], ds['validation'], ds['test'] = train_val['train'], train_val['test'], ds['test']
dataset = DatasetDict(ds)

# Initialize feature extractor and model
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
labels = ds['train'].features['labels'].names
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
).to(device)

# Transform function
def transform(example_batch):
    images = [np.moveaxis(np.array(x.convert('RGB')), -1, 0) for x in example_batch['image']]
    inputs = feature_extractor(images, return_tensors='pt')
    inputs['labels'] = example_batch['labels']
    return inputs

# Prepare dataset
prepared_ds = ds.with_transform(transform)

# Data collator
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

# Metric and compute function
metric = load_metric("accuracy", trust_remote_code=True)
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

# Training arguments
training_args = TrainingArguments(
    output_dir="./vit-base--v5",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    fp16=True,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

# Trainer initialization
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=feature_extractor,
)

# Train and evaluate
train_results = trainer.train()
trainer.save_model()
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

metrics = trainer.evaluate(prepared_ds['test'])
trainer.save_metrics("eval", metrics)

# Model card creation
kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'custom brats layers',
    "tags": ['image-classification'],
}
if training_args.push_to_hub:
    trainer.push_to_hub('🍻 cheers', **kwargs)
else:
    trainer.create_model_card(**kwargs)


Resolving data files:   0%|          | 0/3662 [00:00<?, ?it/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/744 [00:00<?, ?it/s]

  context_layer = torch.nn.functional.scaled_dot_product_attention(


{'loss': 1.1698, 'grad_norm': 1.4556313753128052, 'learning_rate': 0.00019731182795698925, 'epoch': 0.05}
{'loss': 0.852, 'grad_norm': 1.0453912019729614, 'learning_rate': 0.0001946236559139785, 'epoch': 0.11}
{'loss': 0.8861, 'grad_norm': 0.5025181770324707, 'learning_rate': 0.00019193548387096775, 'epoch': 0.16}
{'loss': 0.7948, 'grad_norm': 10.801403045654297, 'learning_rate': 0.000189247311827957, 'epoch': 0.22}
{'loss': 0.6557, 'grad_norm': 2.3779098987579346, 'learning_rate': 0.00018655913978494625, 'epoch': 0.27}
{'loss': 0.8153, 'grad_norm': 2.8715460300445557, 'learning_rate': 0.00018387096774193548, 'epoch': 0.32}
{'loss': 0.8188, 'grad_norm': 0.9872220158576965, 'learning_rate': 0.00018118279569892475, 'epoch': 0.38}
{'loss': 0.5689, 'grad_norm': 1.0359077453613281, 'learning_rate': 0.00017849462365591398, 'epoch': 0.43}
{'loss': 0.7094, 'grad_norm': 1.143519639968872, 'learning_rate': 0.00017580645161290325, 'epoch': 0.48}
{'loss': 0.8534, 'grad_norm': 1.0831692218780518, '

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.7068464756011963, 'eval_accuracy': 0.7515151515151515, 'eval_runtime': 2.0267, 'eval_samples_per_second': 162.823, 'eval_steps_per_second': 20.723, 'epoch': 0.54}
{'loss': 0.7839, 'grad_norm': 2.110886573791504, 'learning_rate': 0.00017043010752688172, 'epoch': 0.59}
{'loss': 0.6712, 'grad_norm': 1.1638151407241821, 'learning_rate': 0.00016774193548387098, 'epoch': 0.65}
{'loss': 0.7448, 'grad_norm': 0.6875391602516174, 'learning_rate': 0.00016505376344086022, 'epoch': 0.7}
{'loss': 0.7018, 'grad_norm': 1.2545368671417236, 'learning_rate': 0.00016236559139784946, 'epoch': 0.75}
{'loss': 0.7516, 'grad_norm': 0.8829709887504578, 'learning_rate': 0.00015967741935483872, 'epoch': 0.81}
{'loss': 0.7675, 'grad_norm': 1.0640416145324707, 'learning_rate': 0.00015698924731182796, 'epoch': 0.86}
{'loss': 0.698, 'grad_norm': 0.7242034077644348, 'learning_rate': 0.00015430107526881722, 'epoch': 0.91}
{'loss': 0.9138, 'grad_norm': 1.1022175550460815, 'learning_rate': 0.0001516129032

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.6504932045936584, 'eval_accuracy': 0.7484848484848485, 'eval_runtime': 1.8829, 'eval_samples_per_second': 175.262, 'eval_steps_per_second': 22.306, 'epoch': 1.08}
{'loss': 0.8373, 'grad_norm': 0.7066428661346436, 'learning_rate': 0.00014354838709677422, 'epoch': 1.13}
{'loss': 0.583, 'grad_norm': 1.286678433418274, 'learning_rate': 0.00014086021505376346, 'epoch': 1.18}
{'loss': 0.4949, 'grad_norm': 1.0037462711334229, 'learning_rate': 0.0001381720430107527, 'epoch': 1.24}
{'loss': 0.5139, 'grad_norm': 0.7832070589065552, 'learning_rate': 0.00013548387096774193, 'epoch': 1.29}
{'loss': 0.6129, 'grad_norm': 1.1215609312057495, 'learning_rate': 0.0001327956989247312, 'epoch': 1.34}
{'loss': 0.597, 'grad_norm': 0.653411328792572, 'learning_rate': 0.00013010752688172043, 'epoch': 1.4}
{'loss': 0.7375, 'grad_norm': 1.1260840892791748, 'learning_rate': 0.0001274193548387097, 'epoch': 1.45}
{'loss': 0.671, 'grad_norm': 1.0885182619094849, 'learning_rate': 0.0001247311827956989

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.6633191704750061, 'eval_accuracy': 0.7636363636363637, 'eval_runtime': 1.8628, 'eval_samples_per_second': 177.151, 'eval_steps_per_second': 22.547, 'epoch': 1.61}
{'loss': 0.6966, 'grad_norm': 1.2594823837280273, 'learning_rate': 0.00011666666666666668, 'epoch': 1.67}
{'loss': 0.6268, 'grad_norm': 1.7371882200241089, 'learning_rate': 0.00011397849462365593, 'epoch': 1.72}
{'loss': 0.6121, 'grad_norm': 1.5216954946517944, 'learning_rate': 0.00011129032258064515, 'epoch': 1.77}
{'loss': 0.5024, 'grad_norm': 1.3477145433425903, 'learning_rate': 0.0001086021505376344, 'epoch': 1.83}
{'loss': 0.672, 'grad_norm': 1.0715265274047852, 'learning_rate': 0.00010591397849462365, 'epoch': 1.88}
{'loss': 0.6296, 'grad_norm': 3.713111400604248, 'learning_rate': 0.0001032258064516129, 'epoch': 1.94}
{'loss': 0.6717, 'grad_norm': 1.4756056070327759, 'learning_rate': 0.00010053763440860215, 'epoch': 1.99}
{'loss': 0.5909, 'grad_norm': 3.632136821746826, 'learning_rate': 9.78494623655914e

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.5752906799316406, 'eval_accuracy': 0.793939393939394, 'eval_runtime': 1.9159, 'eval_samples_per_second': 172.247, 'eval_steps_per_second': 21.922, 'epoch': 2.15}
{'loss': 0.4867, 'grad_norm': 3.231328248977661, 'learning_rate': 8.978494623655914e-05, 'epoch': 2.2}
{'loss': 0.4909, 'grad_norm': 0.8531060814857483, 'learning_rate': 8.709677419354839e-05, 'epoch': 2.26}
{'loss': 0.389, 'grad_norm': 1.2248698472976685, 'learning_rate': 8.440860215053764e-05, 'epoch': 2.31}
{'loss': 0.4955, 'grad_norm': 3.5868959426879883, 'learning_rate': 8.172043010752689e-05, 'epoch': 2.37}
{'loss': 0.5262, 'grad_norm': 2.8980119228363037, 'learning_rate': 7.903225806451613e-05, 'epoch': 2.42}
{'loss': 0.5601, 'grad_norm': 1.0979009866714478, 'learning_rate': 7.634408602150538e-05, 'epoch': 2.47}
{'loss': 0.3772, 'grad_norm': 1.8252800703048706, 'learning_rate': 7.365591397849463e-05, 'epoch': 2.53}
{'loss': 0.5798, 'grad_norm': 1.623577356338501, 'learning_rate': 7.096774193548388e-05, '

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.5808688998222351, 'eval_accuracy': 0.806060606060606, 'eval_runtime': 1.8913, 'eval_samples_per_second': 174.485, 'eval_steps_per_second': 22.207, 'epoch': 2.69}
{'loss': 0.4606, 'grad_norm': 1.7373626232147217, 'learning_rate': 6.290322580645161e-05, 'epoch': 2.74}
{'loss': 0.4491, 'grad_norm': 2.6274023056030273, 'learning_rate': 6.021505376344086e-05, 'epoch': 2.8}
{'loss': 0.5265, 'grad_norm': 2.8111517429351807, 'learning_rate': 5.752688172043011e-05, 'epoch': 2.85}
{'loss': 0.4482, 'grad_norm': 2.3688714504241943, 'learning_rate': 5.4838709677419355e-05, 'epoch': 2.9}
{'loss': 0.4656, 'grad_norm': 1.8154428005218506, 'learning_rate': 5.2150537634408605e-05, 'epoch': 2.96}
{'loss': 0.3828, 'grad_norm': 1.449847936630249, 'learning_rate': 4.9462365591397855e-05, 'epoch': 3.01}
{'loss': 0.3487, 'grad_norm': 0.42441222071647644, 'learning_rate': 4.67741935483871e-05, 'epoch': 3.06}
{'loss': 0.3545, 'grad_norm': 1.6662935018539429, 'learning_rate': 4.408602150537635e-0

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.5549391508102417, 'eval_accuracy': 0.796969696969697, 'eval_runtime': 2.0841, 'eval_samples_per_second': 158.339, 'eval_steps_per_second': 20.152, 'epoch': 3.23}
{'loss': 0.3444, 'grad_norm': 6.237390995025635, 'learning_rate': 3.602150537634409e-05, 'epoch': 3.28}
{'loss': 0.3708, 'grad_norm': 5.466879367828369, 'learning_rate': 3.3333333333333335e-05, 'epoch': 3.33}
{'loss': 0.2087, 'grad_norm': 0.4771074950695038, 'learning_rate': 3.0645161290322585e-05, 'epoch': 3.39}
{'loss': 0.3231, 'grad_norm': 1.6142545938491821, 'learning_rate': 2.7956989247311828e-05, 'epoch': 3.44}
{'loss': 0.3607, 'grad_norm': 2.0825023651123047, 'learning_rate': 2.5268817204301075e-05, 'epoch': 3.49}
{'loss': 0.2245, 'grad_norm': 2.978687286376953, 'learning_rate': 2.258064516129032e-05, 'epoch': 3.55}
{'loss': 0.3053, 'grad_norm': 1.6812710762023926, 'learning_rate': 1.989247311827957e-05, 'epoch': 3.6}
{'loss': 0.2696, 'grad_norm': 3.742980480194092, 'learning_rate': 1.7204301075268818e-0

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.5628989338874817, 'eval_accuracy': 0.7818181818181819, 'eval_runtime': 1.9213, 'eval_samples_per_second': 171.759, 'eval_steps_per_second': 21.86, 'epoch': 3.76}
{'loss': 0.3178, 'grad_norm': 0.31403791904449463, 'learning_rate': 9.13978494623656e-06, 'epoch': 3.82}
{'loss': 0.2307, 'grad_norm': 2.097884178161621, 'learning_rate': 6.451612903225806e-06, 'epoch': 3.87}
{'loss': 0.3203, 'grad_norm': 3.1524269580841064, 'learning_rate': 3.763440860215054e-06, 'epoch': 3.92}
{'loss': 0.3157, 'grad_norm': 1.4197919368743896, 'learning_rate': 1.0752688172043011e-06, 'epoch': 3.98}
{'train_runtime': 164.9859, 'train_samples_per_second': 71.885, 'train_steps_per_second': 4.509, 'train_loss': 0.5454547696857042, 'epoch': 4.0}


  0%|          | 0/46 [00:00<?, ?it/s]

In [3]:
trainer.evaluate(prepared_ds['test'])

  0%|          | 0/46 [00:00<?, ?it/s]

{'eval_loss': 0.4978531301021576,
 'eval_accuracy': 0.8256130790190735,
 'eval_runtime': 2.3023,
 'eval_samples_per_second': 159.403,
 'eval_steps_per_second': 19.98,
 'epoch': 4.0}

In [4]:
print(metrics)

{'eval_loss': 0.4978531301021576, 'eval_accuracy': 0.8256130790190735, 'eval_runtime': 2.2373, 'eval_samples_per_second': 164.039, 'eval_steps_per_second': 20.561, 'epoch': 4.0}
