In [58]:
import transformers as tf
import os
import datasets as ds
import evaluate
import torch
from torchvision.transforms import RandomResizedCrop, Compose, Normalize

In [85]:
name = "blanchon/EuroSAT_RGB"
# Cache under the scratch directory in $PSCRATCH when it is set; otherwise let
# `datasets` use its default cache instead of failing with a KeyError.
cache_dir = os.environ.get("PSCRATCH")
train_ds, val_ds, test_ds = (
    ds.load_dataset(name, split=split, cache_dir=cache_dir)
    for split in ("train", "validation", "test")
)

# Class-name <-> integer-id maps; passed to the classification model config.
labels = train_ds.features["label"].names
id2label = dict(enumerate(labels))
label2id = {label: idx for idx, label in id2label.items()}

id2label[2]

'Herbaceous Vegetation'

In [84]:
checkpoint = "google/vit-base-patch16-224-in21k"
# Fall back to the default HF cache when $PSCRATCH is not set.
image_processor = tf.AutoImageProcessor.from_pretrained(
    checkpoint, use_fast=True, cache_dir=os.environ.get("PSCRATCH")
)

# Derive torchvision resize/crop targets from the processor config: processors
# specify either an exact (height, width) or a shortest-edge length.
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
else:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    # NOTE(review): max_size is not consumed by the transforms in this notebook.
    max_size = image_processor.size.get("longest_edge")


In [96]:
# Sanity check: run the bare ViT backbone (no classification head) on one raw
# training image and keep its token embeddings. Named separately so it is not
# confused with the classification `model` built later.
vit_backbone = tf.ViTModel.from_pretrained(checkpoint, cache_dir=os.environ.get("PSCRATCH"))

inputs = image_processor(train_ds[0]["image"], return_tensors="pt")

with torch.no_grad():  # inference only; skip building the autograd graph
    outputs = vit_backbone(**inputs)

last_hidden_states = outputs.last_hidden_state

In [98]:
# Backbone output is (batch, 1 [CLS] + num_patches, hidden_size).
last_hidden_states.size()

torch.Size([1, 197, 768])

In [45]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# Normalize with the checkpoint's own statistics so inputs match pre-training.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Training: random scale/aspect crop plus horizontal flip as augmentation.
train_transforms = Compose(
    [
        RandomResizedCrop(crop_size),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)
# Evaluation: deterministic resize followed by a center crop.
val_transforms = Compose([Resize(size), CenterCrop(crop_size), ToTensor(), normalize])

def _transform_images(example_batch, transforms):
    """Replace the batch's ``image`` column with transformed ``pixel_values``.

    Args:
        example_batch: batch dict as passed by ``Dataset.set_transform``; its
            ``"image"`` entry is an iterable of PIL images.
        transforms: callable mapping an RGB PIL image to a tensor.

    Returns:
        The same dict, mutated: ``"pixel_values"`` added, ``"image"`` removed.
    """
    # convert("RGB") guards against grayscale / RGBA / palette images.
    example_batch["pixel_values"] = [transforms(image.convert("RGB")) for image in example_batch["image"]]
    # Drop the raw images; the collator only needs pixel_values and label.
    del example_batch["image"]
    return example_batch


def preprocess_train(example_batch):
    """Apply the augmenting ``train_transforms`` to a training batch."""
    return _transform_images(example_batch, train_transforms)


def preprocess_val(example_batch):
    """Apply the deterministic ``val_transforms`` to an evaluation batch."""
    return _transform_images(example_batch, val_transforms)

In [67]:
# Attach preprocessing lazily: it runs on every item access, so training
# samples get a fresh random augmentation each time they are read.
for split_ds, preprocess in ((train_ds, preprocess_train), (val_ds, preprocess_val)):
    split_ds.set_transform(preprocess)

In [75]:
# One transformed training sample: (channels, height, width), sized by crop_size.
sample_pixels = train_ds[0]["pixel_values"]
assert tuple(sample_pixels.shape) == (3, *crop_size), f"unexpected shape {tuple(sample_pixels.shape)}"
sample_pixels.shape

torch.Size([3, 224, 224])

In [82]:
# NOTE(review): `pixel_values` has already been rescaled and normalized by the
# torchvision pipeline. Passing it through `image_processor` presumably
# preprocesses it a second time, which would explain the output squashed to
# ~-1.0. To compare the two pipelines, feed the processor the raw PIL image.
already_normalized = train_ds[0]["pixel_values"]
reprocessed = image_processor(already_normalized, return_tensors="pt")
reprocessed["pixel_values"]

tensor([[[[-1.0014, -1.0014, -1.0014,  ..., -1.0023, -1.0023, -1.0023],
          [-1.0014, -1.0014, -1.0014,  ..., -1.0023, -1.0023, -1.0023],
          [-1.0014, -1.0014, -1.0014,  ..., -1.0023, -1.0023, -1.0023],
          ...,
          [-1.0018, -1.0018, -1.0018,  ..., -1.0018, -1.0018, -1.0018],
          [-1.0018, -1.0018, -1.0018,  ..., -1.0018, -1.0018, -1.0018],
          [-1.0018, -1.0018, -1.0018,  ..., -1.0018, -1.0018, -1.0018]],

         [[-1.0024, -1.0024, -1.0024,  ..., -1.0028, -1.0028, -1.0028],
          [-1.0024, -1.0024, -1.0024,  ..., -1.0028, -1.0028, -1.0028],
          [-1.0024, -1.0024, -1.0024,  ..., -1.0028, -1.0028, -1.0028],
          ...,
          [-1.0024, -1.0024, -1.0024,  ..., -1.0025, -1.0025, -1.0025],
          [-1.0024, -1.0024, -1.0024,  ..., -1.0025, -1.0025, -1.0025],
          [-1.0024, -1.0024, -1.0024,  ..., -1.0025, -1.0025, -1.0025]],

         [[-1.0021, -1.0021, -1.0021,  ..., -1.0023, -1.0023, -1.0023],
          [-1.0021, -1.0021, -

In [83]:
# Each access re-applies preprocess_train, so this is a fresh random crop/flip,
# not the same tensor inspected earlier.
first_example = train_ds[0]
first_example["pixel_values"]

tensor([[[ 0.0588,  0.0588,  0.0588,  ..., -0.3098, -0.3098, -0.3098],
         [ 0.0588,  0.0588,  0.0588,  ..., -0.3098, -0.3098, -0.3098],
         [ 0.0588,  0.0588,  0.0588,  ..., -0.3098, -0.3098, -0.3098],
         ...,
         [-0.2941, -0.2941, -0.2941,  ..., -0.3098, -0.3098, -0.3098],
         [-0.2941, -0.2941, -0.2941,  ..., -0.3098, -0.3098, -0.3098],
         [-0.2941, -0.2941, -0.2941,  ..., -0.3098, -0.3098, -0.3098]],

        [[-0.1529, -0.1529, -0.1529,  ..., -0.3412, -0.3412, -0.3412],
         [-0.1529, -0.1529, -0.1529,  ..., -0.3412, -0.3412, -0.3412],
         [-0.1529, -0.1529, -0.1529,  ..., -0.3412, -0.3412, -0.3412],
         ...,
         [-0.3490, -0.3490, -0.3490,  ..., -0.3725, -0.3725, -0.3725],
         [-0.3490, -0.3490, -0.3490,  ..., -0.3725, -0.3725, -0.3725],
         [-0.3490, -0.3490, -0.3490,  ..., -0.3725, -0.3725, -0.3725]],

        [[-0.1451, -0.1451, -0.1451,  ..., -0.3176, -0.3176, -0.3176],
         [-0.1451, -0.1451, -0.1451,  ..., -0

In [52]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# ViT with a freshly initialised classification head sized for the dataset's labels.
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    label2id=label2id,
    id2label=id2label,
    # Needed when the checkpoint already has a head with a different number of
    # labels (e.g. when fine-tuning an already fine-tuned checkpoint).
    ignore_mismatched_sizes=True,
    # Same cache location as the dataset / processor; default cache if unset.
    cache_dir=os.environ.get("PSCRATCH"),
)

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [55]:
model_name = checkpoint.rsplit("/", 1)[-1]

batch_size = 32

args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-eurosat",
    # Keep the "image" column: set_transform needs it, and Trainer would
    # otherwise drop columns not in the model's forward() signature.
    remove_unused_columns=False,
    # Evaluate and checkpoint every epoch so the best epoch can be restored.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Optimisation; effective per-device train batch = batch_size * 4 accumulation steps.
    learning_rate=5e-5,
    warmup_ratio=0.1,
    num_train_epochs=3,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,
    logging_steps=10,
    push_to_hub=False,
)

In [59]:
import numpy as np

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Compute accuracy for the Trainer's evaluation loop.

    By default, Trainer calls this once with the predictions accumulated over
    the whole evaluation set, not per batch.

    Args:
        eval_pred: ``EvalPrediction`` named tuple with ``predictions`` (logits
            as a NumPy array, or a tuple whose first element is the logits when
            the model returns extra outputs) and ``label_ids`` (ground-truth
            class indices as a NumPy array).

    Returns:
        dict: the ``evaluate`` accuracy result, ``{"accuracy": ...}``.
    """
    logits = eval_pred.predictions
    # Models that return extra outputs yield (logits, ...); logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    # Predicted class = argmax over the last (class) axis.
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [69]:
def collate_fn(examples):
    """Batch already-transformed examples for the Trainer.

    Args:
        examples: sequence of dicts, each with a ``"pixel_values"`` tensor
            (all of the same shape) and an integer ``"label"``.

    Returns:
        dict with ``"pixel_values"`` (examples stacked along a new dim 0) and
        ``"labels"`` (1-D tensor of the labels), the keys the model expects.
    """
    pixel_list, label_list = [], []
    for ex in examples:
        pixel_list.append(ex["pixel_values"])
        label_list.append(ex["label"])
    return dict(pixel_values=torch.stack(pixel_list), labels=torch.tensor(label_list))

In [70]:
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    # Saved alongside checkpoints so they can be reloaded for inference.
    processing_class=image_processor,
)

In [71]:
# Fine-tune from scratch (no checkpoint resume); the result carries
# global_step, training_loss and a runtime/throughput metrics dict.
train_results = trainer.train(resume_from_checkpoint=None)

Epoch,Training Loss,Validation Loss,Accuracy
1,0.4604,0.38684,0.965741
2,0.2565,0.208515,0.980741


In [72]:
trainer.save_model()  # final (best, when load_best_model_at_end) model -> args.output_dir
for report in (trainer.log_metrics, trainer.save_metrics):
    report("train", train_results.metrics)
trainer.save_state()

***** train metrics *****
  epoch                    =       2.9783
  total_flos               = 3484041929GF
  train_loss               =       0.5828
  train_runtime            =   0:17:13.29
  train_samples_per_second =       47.034
  train_steps_per_second   =        0.366


In [73]:
# Display the full TrainOutput returned by trainer.train(): global_step, training_loss, metrics.
train_results

TrainOutput(global_step=378, training_loss=0.5828112521499553, metrics={'train_runtime': 1033.2911, 'train_samples_per_second': 47.034, 'train_steps_per_second': 0.366, 'total_flos': 3.740961535884067e+18, 'train_loss': 0.5828112521499553, 'epoch': 2.978303747534517})

In [74]:
# Validation metrics for the loaded model. NOTE: with load_best_model_at_end
# the validation split also selected the checkpoint, so this score is
# optimistically biased.
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# Held-out estimate on the test split, which played no part in training or
# checkpoint selection. It gets the same deterministic preprocessing as validation.
test_metrics = trainer.evaluate(
    eval_dataset=test_ds.with_transform(preprocess_val),
    metric_key_prefix="test",
)
trainer.log_metrics("test", test_metrics)
trainer.save_metrics("test", test_metrics)

***** eval metrics *****
  epoch                   =     2.9783
  eval_accuracy           =     0.9807
  eval_loss               =     0.2085
  eval_runtime            = 0:00:49.13
  eval_samples_per_second =    109.904
  eval_steps_per_second   =       3.44
