In [1]:
# Ensure all necessary imports are present
from transformers import ViTFeatureExtractor
import numpy as np
from PIL import Image
from datasets import load_dataset
from datasets import DatasetDict
import torch

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load dataset
ds = load_dataset('betul2')
ds = ds.rename_column('label', 'labels')

# Split the dataset
ds = ds['train'].train_test_split(test_size=0.1)
train_val = ds['train'].train_test_split(test_size=0.1)
ds['train'] = train_val['train']
ds['validation'] = train_val['test']
ds['test'] = ds['test']

# Convert to DatasetDict
dataset = DatasetDict(ds)


Using the latest cached version of the dataset since betul2 couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at C:\Users\semih\.cache\huggingface\datasets\betul2\default\0.0.0\c8b151b1a7bf91b4 (last modified on Thu Aug 22 16:40:50 2024).


In [3]:
# Feature extractor initialization
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)



In [4]:
# Function to process a single example
def process_example(example):
    image = example['image']
    # Convert the PIL image to RGB format
    image = image.convert('RGB')
    # Convert the PIL image to numpy array
    image_array = np.array(image)
    # Ensure the image is in (height, width, channels) format and rearrange to (channels, height, width)
    if image_array.ndim == 3 and image_array.shape[-1] in [1, 3, 4]:  # last dimension is channels
        image_array = np.moveaxis(image_array, -1, 0)
    # Now the image should be in (channels, height, width) format
    inputs = feature_extractor(image_array, return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs


In [5]:
# Test the process_example function
processed_example = process_example(ds['train'][0])
print(processed_example)

{'pixel_values': tensor([[[[ 0.0039,  0.0039,  0.0353,  ...,  0.0039, -0.0431,  0.0353],
          [ 0.0039,  0.0039,  0.0353,  ...,  0.0039, -0.0431,  0.0353],
          [ 0.0039,  0.0039,  0.0039,  ...,  0.0039, -0.0431,  0.0431],
          ...,
          [ 0.0039,  0.0039,  0.0039,  ...,  0.0039,  0.0039,  0.0039],
          [-0.0275, -0.0275,  0.0039,  ...,  0.0039,  0.0039,  0.0039],
          [-0.0196, -0.0275,  0.0039,  ...,  0.0039,  0.0039,  0.0039]],

         [[ 0.0039,  0.0039,  0.0353,  ...,  0.0039, -0.0118,  0.0353],
          [ 0.0039,  0.0039,  0.0353,  ...,  0.0039, -0.0118,  0.0353],
          [ 0.0039,  0.0039,  0.0039,  ...,  0.0039, -0.0118,  0.0431],
          ...,
          [ 0.0039,  0.0039,  0.0039,  ...,  0.0039,  0.0039,  0.0039],
          [-0.0275, -0.0275,  0.0039,  ...,  0.0039,  0.0039,  0.0039],
          [-0.0196, -0.0275,  0.0039,  ...,  0.0039,  0.0039,  0.0039]],

         [[ 0.0039,  0.0039,  0.0039,  ...,  0.0039,  0.0039,  0.0039],
          [ 0

In [6]:
# Function to transform a batch of examples
def transform(example_batch):
    images = [x.convert('RGB') for x in example_batch['image']]
    # Ensure each image is in the correct format
    images = [np.moveaxis(np.array(img), -1, 0) if img.ndim == 3 and img.shape[-1] in [1, 3, 4] else np.array(img) for img in images]
    inputs = feature_extractor(images, return_tensors='pt')
    inputs['labels'] = example_batch['labels']
    return inputs

In [7]:
# Function to transform a batch of examples
def transform(example_batch):
    images = [x.convert('RGB') for x in example_batch['image']]
    # Ensure each image is in the correct format
    images = [np.moveaxis(np.array(img), -1, 0) if np.array(img).ndim == 3 and np.array(img).shape[-1] in [1, 3, 4] else np.array(img) for img in images]
    inputs = feature_extractor(images, return_tensors='pt')
    inputs['labels'] = example_batch['labels']
    return inputs

In [8]:
# Apply the transform function to the dataset
prepared_ds = ds.with_transform(transform)

In [9]:
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

In [10]:
# Define evaluation metric
import numpy as np
from datasets import load_metric

metric = load_metric("accuracy", trust_remote_code=True)
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

  metric = load_metric("accuracy", trust_remote_code=True)


In [11]:
# Load pretrained model
from transformers import ViTForImageClassification

labels = ds['train'].features['labels'].names
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Move the model to the GPU
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [13]:
# Define training arguments
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./vit-base--v5",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)



In [14]:
# Initialize Trainer
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=feature_extractor,
)




In [15]:
import torch

# Train the model
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Evaluate the model
metrics = trainer.evaluate(prepared_ds['test'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# Push to hub or create model card
kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'custom brats layers',
    "tags": ['image-classification'],
}

if training_args.push_to_hub:
    trainer.push_to_hub('🍻 cheers', **kwargs)
else:
    trainer.create_model_card(**kwargs)

  0%|          | 0/744 [00:00<?, ?it/s]

  context_layer = torch.nn.functional.scaled_dot_product_attention(


{'loss': 1.1088, 'grad_norm': 1.39084792137146, 'learning_rate': 0.00019731182795698925, 'epoch': 0.05}
{'loss': 0.8508, 'grad_norm': 1.115374207496643, 'learning_rate': 0.0001946236559139785, 'epoch': 0.11}
{'loss': 0.8677, 'grad_norm': 1.134900450706482, 'learning_rate': 0.00019193548387096775, 'epoch': 0.16}
{'loss': 0.8366, 'grad_norm': 1.3971041440963745, 'learning_rate': 0.000189247311827957, 'epoch': 0.22}
{'loss': 0.7954, 'grad_norm': 1.8755857944488525, 'learning_rate': 0.00018655913978494625, 'epoch': 0.27}
{'loss': 0.7438, 'grad_norm': 2.1053378582000732, 'learning_rate': 0.00018387096774193548, 'epoch': 0.32}
{'loss': 0.8199, 'grad_norm': 1.3887600898742676, 'learning_rate': 0.00018118279569892475, 'epoch': 0.38}
{'loss': 0.8945, 'grad_norm': 5.668015956878662, 'learning_rate': 0.00017849462365591398, 'epoch': 0.43}
{'loss': 0.673, 'grad_norm': 1.4448820352554321, 'learning_rate': 0.00017580645161290325, 'epoch': 0.48}
{'loss': 0.7009, 'grad_norm': 2.706613063812256, 'learn

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.7025018930435181, 'eval_accuracy': 0.7636363636363637, 'eval_runtime': 2.2207, 'eval_samples_per_second': 148.604, 'eval_steps_per_second': 18.913, 'epoch': 0.54}
{'loss': 0.6874, 'grad_norm': 1.2416284084320068, 'learning_rate': 0.00017043010752688172, 'epoch': 0.59}
{'loss': 0.7256, 'grad_norm': 1.7809841632843018, 'learning_rate': 0.00016774193548387098, 'epoch': 0.65}
{'loss': 0.7037, 'grad_norm': 0.9938160181045532, 'learning_rate': 0.00016505376344086022, 'epoch': 0.7}
{'loss': 0.6592, 'grad_norm': 0.8089711666107178, 'learning_rate': 0.00016236559139784946, 'epoch': 0.75}
{'loss': 0.7134, 'grad_norm': 0.8656672835350037, 'learning_rate': 0.00015967741935483872, 'epoch': 0.81}
{'loss': 0.7579, 'grad_norm': 0.8881009817123413, 'learning_rate': 0.00015698924731182796, 'epoch': 0.86}
{'loss': 0.6383, 'grad_norm': 3.4207422733306885, 'learning_rate': 0.00015430107526881722, 'epoch': 0.91}
{'loss': 0.7042, 'grad_norm': 1.1370435953140259, 'learning_rate': 0.00015161290

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.6782025694847107, 'eval_accuracy': 0.7727272727272727, 'eval_runtime': 2.17, 'eval_samples_per_second': 152.073, 'eval_steps_per_second': 19.355, 'epoch': 1.08}
{'loss': 0.5988, 'grad_norm': 1.0203512907028198, 'learning_rate': 0.00014354838709677422, 'epoch': 1.13}
{'loss': 0.7319, 'grad_norm': 1.37605881690979, 'learning_rate': 0.00014086021505376346, 'epoch': 1.18}
{'loss': 0.6869, 'grad_norm': 1.5589251518249512, 'learning_rate': 0.0001381720430107527, 'epoch': 1.24}
{'loss': 0.5267, 'grad_norm': 0.7937272787094116, 'learning_rate': 0.00013548387096774193, 'epoch': 1.29}
{'loss': 0.5889, 'grad_norm': 1.6712664365768433, 'learning_rate': 0.0001327956989247312, 'epoch': 1.34}
{'loss': 0.752, 'grad_norm': 1.6246939897537231, 'learning_rate': 0.00013010752688172043, 'epoch': 1.4}
{'loss': 0.5647, 'grad_norm': 0.6406051516532898, 'learning_rate': 0.0001274193548387097, 'epoch': 1.45}
{'loss': 0.5053, 'grad_norm': 1.1794164180755615, 'learning_rate': 0.0001247311827956989

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.6262461543083191, 'eval_accuracy': 0.7818181818181819, 'eval_runtime': 1.8996, 'eval_samples_per_second': 173.718, 'eval_steps_per_second': 22.11, 'epoch': 1.61}
{'loss': 0.6414, 'grad_norm': 0.6208717823028564, 'learning_rate': 0.00011666666666666668, 'epoch': 1.67}
{'loss': 0.5549, 'grad_norm': 1.3186957836151123, 'learning_rate': 0.00011397849462365593, 'epoch': 1.72}
{'loss': 0.6742, 'grad_norm': 2.5451791286468506, 'learning_rate': 0.00011129032258064515, 'epoch': 1.77}
{'loss': 0.6711, 'grad_norm': 1.1998966932296753, 'learning_rate': 0.0001086021505376344, 'epoch': 1.83}
{'loss': 0.5431, 'grad_norm': 1.3649810552597046, 'learning_rate': 0.00010591397849462365, 'epoch': 1.88}
{'loss': 0.4781, 'grad_norm': 1.3462424278259277, 'learning_rate': 0.0001032258064516129, 'epoch': 1.94}
{'loss': 0.6391, 'grad_norm': 1.4028325080871582, 'learning_rate': 0.00010053763440860215, 'epoch': 1.99}
{'loss': 0.4725, 'grad_norm': 1.4146496057510376, 'learning_rate': 9.7849462365591

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.5823186635971069, 'eval_accuracy': 0.7848484848484848, 'eval_runtime': 1.9127, 'eval_samples_per_second': 172.529, 'eval_steps_per_second': 21.958, 'epoch': 2.15}
{'loss': 0.4014, 'grad_norm': 0.9386455416679382, 'learning_rate': 8.978494623655914e-05, 'epoch': 2.2}
{'loss': 0.4946, 'grad_norm': 0.6845691204071045, 'learning_rate': 8.709677419354839e-05, 'epoch': 2.26}
{'loss': 0.5419, 'grad_norm': 1.3128761053085327, 'learning_rate': 8.440860215053764e-05, 'epoch': 2.31}
{'loss': 0.5379, 'grad_norm': 1.4756325483322144, 'learning_rate': 8.172043010752689e-05, 'epoch': 2.37}
{'loss': 0.4826, 'grad_norm': 5.35335636138916, 'learning_rate': 7.903225806451613e-05, 'epoch': 2.42}
{'loss': 0.4639, 'grad_norm': 0.6271783709526062, 'learning_rate': 7.634408602150538e-05, 'epoch': 2.47}
{'loss': 0.5326, 'grad_norm': 2.0034284591674805, 'learning_rate': 7.365591397849463e-05, 'epoch': 2.53}
{'loss': 0.5112, 'grad_norm': 1.7605661153793335, 'learning_rate': 7.096774193548388e-05,

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.5916585922241211, 'eval_accuracy': 0.806060606060606, 'eval_runtime': 1.913, 'eval_samples_per_second': 172.506, 'eval_steps_per_second': 21.955, 'epoch': 2.69}
{'loss': 0.5579, 'grad_norm': 2.109804153442383, 'learning_rate': 6.290322580645161e-05, 'epoch': 2.74}
{'loss': 0.5259, 'grad_norm': 1.6437660455703735, 'learning_rate': 6.021505376344086e-05, 'epoch': 2.8}
{'loss': 0.3856, 'grad_norm': 1.876278281211853, 'learning_rate': 5.752688172043011e-05, 'epoch': 2.85}
{'loss': 0.4483, 'grad_norm': 2.5249826908111572, 'learning_rate': 5.4838709677419355e-05, 'epoch': 2.9}
{'loss': 0.3772, 'grad_norm': 1.4441992044448853, 'learning_rate': 5.2150537634408605e-05, 'epoch': 2.96}
{'loss': 0.4079, 'grad_norm': 5.1416497230529785, 'learning_rate': 4.9462365591397855e-05, 'epoch': 3.01}
{'loss': 0.2888, 'grad_norm': 1.4264943599700928, 'learning_rate': 4.67741935483871e-05, 'epoch': 3.06}
{'loss': 0.381, 'grad_norm': 1.0494979619979858, 'learning_rate': 4.408602150537635e-05, '

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.5405390858650208, 'eval_accuracy': 0.793939393939394, 'eval_runtime': 1.9088, 'eval_samples_per_second': 172.887, 'eval_steps_per_second': 22.004, 'epoch': 3.23}
{'loss': 0.2442, 'grad_norm': 1.3763409852981567, 'learning_rate': 3.602150537634409e-05, 'epoch': 3.28}
{'loss': 0.3434, 'grad_norm': 4.556057929992676, 'learning_rate': 3.3333333333333335e-05, 'epoch': 3.33}
{'loss': 0.4097, 'grad_norm': 0.828928530216217, 'learning_rate': 3.0645161290322585e-05, 'epoch': 3.39}
{'loss': 0.3498, 'grad_norm': 3.557579755783081, 'learning_rate': 2.7956989247311828e-05, 'epoch': 3.44}
{'loss': 0.2932, 'grad_norm': 2.2648048400878906, 'learning_rate': 2.5268817204301075e-05, 'epoch': 3.49}
{'loss': 0.2564, 'grad_norm': 2.0926642417907715, 'learning_rate': 2.258064516129032e-05, 'epoch': 3.55}
{'loss': 0.2702, 'grad_norm': 3.5984246730804443, 'learning_rate': 1.989247311827957e-05, 'epoch': 3.6}
{'loss': 0.2812, 'grad_norm': 0.8931816816329956, 'learning_rate': 1.7204301075268818e-

  0%|          | 0/42 [00:00<?, ?it/s]

{'eval_loss': 0.5512449145317078, 'eval_accuracy': 0.8, 'eval_runtime': 1.8804, 'eval_samples_per_second': 175.491, 'eval_steps_per_second': 22.335, 'epoch': 3.76}
{'loss': 0.3536, 'grad_norm': 1.8082451820373535, 'learning_rate': 9.13978494623656e-06, 'epoch': 3.82}
{'loss': 0.2789, 'grad_norm': 0.4126892685890198, 'learning_rate': 6.451612903225806e-06, 'epoch': 3.87}
{'loss': 0.2241, 'grad_norm': 0.8909844756126404, 'learning_rate': 3.763440860215054e-06, 'epoch': 3.92}
{'loss': 0.2632, 'grad_norm': 2.3688242435455322, 'learning_rate': 1.0752688172043011e-06, 'epoch': 3.98}
{'train_runtime': 163.7833, 'train_samples_per_second': 72.413, 'train_steps_per_second': 4.543, 'train_loss': 0.5439459455590094, 'epoch': 4.0}
***** train metrics *****
  epoch                    =         4.0
  total_flos               = 855959680GF
  train_loss               =      0.5439
  train_runtime            =  0:02:43.78
  train_samples_per_second =      72.413
  train_steps_per_second   =       4.543

  0%|          | 0/46 [00:00<?, ?it/s]

***** eval metrics *****
  epoch                   =        4.0
  eval_accuracy           =     0.7902
  eval_loss               =     0.5956
  eval_runtime            = 0:00:02.43
  eval_samples_per_second =    150.539
  eval_steps_per_second   =     18.869
