In [8]:
from transformers import ViTFeatureExtractor, ViTForImageClassification, Trainer, TrainingArguments
from datasets import load_dataset, DatasetDict, load_metric
import numpy as np
import torch
from PIL import Image

DATASET_PATH = "mnist"

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load and preprocess dataset
ds = load_dataset(DATASET_PATH).rename_column('label', 'labels')


In [9]:
ds

DatasetDict({
    train: Dataset({
        features: ['image', 'labels'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['image', 'labels'],
        num_rows: 10000
    })
})

In [10]:
ds.keys()

dict_keys(['train', 'test'])

In [11]:
# Check for existing splits and handle accordingly
if 'train' in ds.keys() and 'validation' in ds.keys() and 'test' in ds.keys():
    # Dataset already has train, validation, and test splits
    dataset = DatasetDict({
        'train': ds['train'],
        'validation': ds['validation'],
        'test': ds['test']
    })
elif 'train' in ds.keys() and 'validation' in ds.keys():
    # Dataset has only train and validation splits, so create a test split from validation
    ds = ds['validation'].train_test_split(test_size=0.5)
    dataset = DatasetDict({
        'train': ds['train'],
        'validation': ds['test'],
        'test': ds['test']
    })
elif 'train' in ds.keys() and 'test' in ds.keys():
    # Dataset has only train and test splits, so create a validation split from train
    train_val = ds['train'].train_test_split(test_size=0.1)
    dataset = DatasetDict({
        'train': train_val['train'],
        'validation': train_val['test'],
        'test': ds['test']
    })
elif 'train' in ds.keys():
    # Dataset only has a train split, so create both validation and test splits
    ds = ds['train'].train_test_split(test_size=0.2)
    train_val = ds['train'].train_test_split(test_size=0.1)
    dataset = DatasetDict({
        'train': train_val['train'],
        'validation': train_val['test'],
        'test': ds['test']
    })
else:
    raise ValueError("The dataset does not have a 'train' split.")


In [13]:
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'labels'],
        num_rows: 54000
    })
    validation: Dataset({
        features: ['image', 'labels'],
        num_rows: 6000
    })
    test: Dataset({
        features: ['image', 'labels'],
        num_rows: 10000
    })
})

In [None]:

# Initialize feature extractor and model
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
labels = ds['train'].features['labels'].names
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
).to(device)

# Transform function
def transform(example_batch):
    images = [np.moveaxis(np.array(x.convert('RGB')), -1, 0) for x in example_batch['image']]
    inputs = feature_extractor(images, return_tensors='pt')
    inputs['labels'] = example_batch['labels']
    return inputs

# Prepare dataset
prepared_ds = ds.with_transform(transform)

# Data collator
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

# Metric and compute function
metric = load_metric("accuracy", trust_remote_code=True)
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

# Training arguments
training_args = TrainingArguments(
    output_dir="./vit-base--v5",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    fp16=True,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

# Trainer initialization
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=feature_extractor,
)

# Train and evaluate
train_results = trainer.train()
trainer.save_model()
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

metrics = trainer.evaluate(prepared_ds['test'])
trainer.save_metrics("eval", metrics)

# Model card creation
kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'custom brats layers',
    "tags": ['image-classification'],
}
if training_args.push_to_hub:
    trainer.push_to_hub('🍻 cheers', **kwargs)
else:
    trainer.create_model_card(**kwargs)

In [3]:
trainer.evaluate(prepared_ds['test'])

  0%|          | 0/46 [00:00<?, ?it/s]

{'eval_loss': 0.4978531301021576,
 'eval_accuracy': 0.8256130790190735,
 'eval_runtime': 2.3023,
 'eval_samples_per_second': 159.403,
 'eval_steps_per_second': 19.98,
 'epoch': 4.0}

In [4]:
print(metrics)

{'eval_loss': 0.4978531301021576, 'eval_accuracy': 0.8256130790190735, 'eval_runtime': 2.2373, 'eval_samples_per_second': 164.039, 'eval_steps_per_second': 20.561, 'epoch': 4.0}
