In [None]:
%pip install kaggle transformers[torch] datasets[vision]

^C


In [1]:
import os
import random
import cv2
import torch
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm, trange
from datasets import load_dataset, load_metric, DatasetDict
from PIL import ImageDraw, ImageFont, Image
from transformers import (
    AutoImageProcessor,
    AutoFeatureExtractor,
    ViTImageProcessor,
    ViTForImageClassification,
    ResNetForImageClassification,
    TrainingArguments,
    AutoConfig,
    ViTConfig,
    TrainerCallback
)

print(torch.cuda.is_available())

True


# Load dataset

In [2]:
def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):
  w, h = size
  labels = ds['train'].features['labels'].names
  grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
  draw = ImageDraw.Draw(grid)
  font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

  for label_id, label in enumerate(labels):
    ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

    for i, example in enumerate(ds_slice):
      image = example['image']
      idx = examples_per_class * label_id + i
      box = (idx % examples_per_class * w, idx // examples_per_class * h)
      grid.paste(image.resize(size), box=box)
      draw.text(box, label, (255, 255, 255), font=font)

  return grid

## Tomato leaf dataset

In [3]:
tomato_dataset = load_dataset("wellCh4n/tomato-leaf-disease-image")

tomato_dataset = tomato_dataset.rename_column("text", "labels")

In [4]:
from datasets import ClassLabel

texts = tomato_dataset['train']['labels']

name_map = {'A tomato leaf with Septoria Leaf Spot': "Septoria Leaf Spot",
            'A tomato leaf with Target Spot': "Target Spot",
            'A tomato leaf with Bacterial Spot': "Bacterial Spot",
            'A tomato leaf with Tomato Yellow Leaf Curl Virus': "Yellow Leaf Curl Virus",
            'A healthy tomato leaf': "Healthy",
            'A tomato leaf with Tomato Mosaic Virus': "Mosiac Virus",
            'A tomato leaf with Leaf Mold': "Leaf Mold",
            'A tomato leaf with Late Blight': "Late Blight",
            'A tomato leaf with Early Blight': "Early Blight",
            'A tomato leaf with Spider Mites Two-spotted Spider Mite': "Spider Mites"}

labels = list(set(texts))

for idx in trange(len(labels)):
    labels[idx] = name_map[labels[idx]]

print(labels)
class_labels = ClassLabel(names=labels)

id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}

def encode_labels(example):
    simplified_name = name_map[example['labels']]
    example['labels'] = label2id[simplified_name]

    return example

dataset = tomato_dataset.map(encode_labels)
dataset = dataset.cast_column("labels", class_labels)

  0%|          | 0/10 [00:00<?, ?it/s]

['Spider Mites', 'Healthy', 'Septoria Leaf Spot', 'Mosiac Virus', 'Leaf Mold', 'Yellow Leaf Curl Virus', 'Target Spot', 'Bacterial Spot', 'Early Blight', 'Late Blight']


Map:   0%|          | 0/18160 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/18160 [00:00<?, ? examples/s]

In [5]:
train_test_sets = dataset['train'].train_test_split(test_size=0.2)
test_valid_sets = train_test_sets['test'].train_test_split(test_size=0.5)

dataset = DatasetDict({
    'train': train_test_sets['train'],
    'test': test_valid_sets['test'],
    'valid': test_valid_sets['train'],
})

dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'labels'],
        num_rows: 14528
    })
    test: Dataset({
        features: ['image', 'labels'],
        num_rows: 1816
    })
    valid: Dataset({
        features: ['image', 'labels'],
        num_rows: 1816
    })
})

In [6]:
show_examples(dataset, seed=random.randint(0, 1337), examples_per_class=1, size=(224, 224))

OSError: cannot open resource

# Load models

## ResNet

In [7]:
resnet_path = "microsoft/resnet-50"
resnet_processor = AutoImageProcessor.from_pretrained(resnet_path) # On HG ResNet has a feature extractor, not a preprocessor but they function in the same way

In [8]:
def transform_resnet(example_batch):
    inputs = resnet_processor([x for x in example_batch['image']], return_tensors='pt')

    inputs['labels'] = example_batch['labels']
    return inputs

In [9]:
prepared_dataset_resnet = dataset.map(transform_resnet, batched=True)

Map:   0%|          | 0/14528 [00:00<?, ? examples/s]

Map:   0%|          | 0/1816 [00:00<?, ? examples/s]

Map:   0%|          | 0/1816 [00:00<?, ? examples/s]

## ViT

In [10]:
vit_path = 'google/vit-base-patch16-224-in21k'
vit_processor = ViTImageProcessor.from_pretrained(vit_path)

In [11]:
def transform_vit(example_batch):
  inputs = vit_processor([x for x in example_batch['image']], return_tensors='pt')

  inputs['labels'] = example_batch['labels']
  return inputs

In [12]:
prepared_dataset_vit = dataset.with_transform(transform_vit)

# Training

Based on (https://huggingface.co/blog/fine-tune-vit)

## Full training

## Fine-tuning

In [13]:
def collate_fn(batch):
  # Necessary for ViT
  return {
      'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
      'labels': torch.tensor([x['labels'] for x in batch])
  }

In [14]:
metric = load_metric("accuracy", trust_remote_code=True)

def compute_metrics(p):
  return metric.compute(predictions=np.argmax(p.predictions, axis=1),
                        references=p.label_ids,)

  metric = load_metric("accuracy", trust_remote_code=True)


In [15]:
class TimerTracker(TrainerCallback):
    def __init__(self):
        self.start_time = None
        self.end_time = None

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()

    def on_train_end(self, args, state, control, **kwargs):
        self.end_time = time.time()

        total_time = self.end_time - self.start_time

        if state.is_world_process_zero:
            state.log_history.append({'total_time': total_time})

In [16]:
from transformers import ViTConfig, ResNetConfig

##
## CREATE VIT MODELS
##

vit_config = ViTConfig(
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id
)

vit = ViTForImageClassification(vit_config)

vit_pretrained = ViTForImageClassification.from_pretrained(
    vit_path,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)


##
## CREATE RESNET MODELS
##

resnet_config = ResNetConfig(
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id
)

resnet = ResNetForImageClassification(resnet_config)

resnet_pretrained = ResNetForImageClassification.from_pretrained(
    resnet_path,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# The way HuggingFace initializes the weights makes it more likely we experience
# gradient vanishing (for some reason)
in_features = resnet_pretrained.classifier[-1].in_features
resnet_pretrained.classifier[-1] = torch.nn.Linear(in_features, len(labels))

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([10, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
EPOCHS = 3
BATCH_SIZE = 32
SEED = 42
EVAL_STEPS = 100
LOGGING_STEPS = 100

training_args_vit = TrainingArguments(
  output_dir="./vit-tomato-leaf",
  per_device_train_batch_size=BATCH_SIZE,
  per_device_eval_batch_size=BATCH_SIZE,
  evaluation_strategy="steps",
  seed = SEED,
  num_train_epochs=EPOCHS,
  eval_steps=EVAL_STEPS,
  logging_steps=LOGGING_STEPS,
  learning_rate=2e-4,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

training_args_vit_pretrained = TrainingArguments(
  output_dir="./vit-pretrained-tomato-leaf",
  per_device_train_batch_size=BATCH_SIZE,
  per_device_eval_batch_size=BATCH_SIZE,
  evaluation_strategy="steps",
  seed = SEED,
  num_train_epochs=EPOCHS,
  eval_steps=EVAL_STEPS,
  logging_steps=LOGGING_STEPS,
  learning_rate=2e-4,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

training_args_resnet = TrainingArguments(
  output_dir="./resnet-tomato-leaf",
  per_device_train_batch_size=BATCH_SIZE,
  per_device_eval_batch_size=BATCH_SIZE,
  evaluation_strategy="steps",
  seed = SEED,
  num_train_epochs=EPOCHS,
  logging_dir='./logs',
  eval_steps=EVAL_STEPS,
  logging_steps=LOGGING_STEPS,
  learning_rate=5e-5,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

training_args_resnet_pretrained = TrainingArguments(
  output_dir="./resnet-pretrained-tomato-leaf",
  per_device_train_batch_size=BATCH_SIZE,
  per_device_eval_batch_size=BATCH_SIZE,
  evaluation_strategy="steps",
  seed = SEED,
  num_train_epochs=EPOCHS,
  logging_dir='./logs',
  eval_steps=EVAL_STEPS,
  logging_steps=LOGGING_STEPS,
  learning_rate=5e-5,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)



In [18]:
from transformers import DefaultDataCollator

model_stack  = [
    ("resnet", resnet, training_args_resnet, prepared_dataset_resnet, resnet_processor, DefaultDataCollator()),
    ("vit", vit, training_args_vit, prepared_dataset_vit, vit_processor, collate_fn),
    ("resnet_pretrained", resnet_pretrained, training_args_resnet_pretrained, prepared_dataset_resnet, resnet_processor, DefaultDataCollator()),
    ("vit_pretrained", vit_pretrained, training_args_vit_pretrained, prepared_dataset_vit, vit_processor, collate_fn),
]

In [19]:
from transformers import Trainer

for (name, model, training_args, dataset, processor, collate) in tqdm(model_stack):
  print(f"Training: {name}")

  trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate,
    compute_metrics=compute_metrics,
    train_dataset=dataset["train"],
    eval_dataset=dataset["valid"],
    tokenizer=processor,
    callbacks=[TimerTracker()]
  )

  train_results = trainer.train()

  trainer.save_model()
  trainer.log_metrics("train", train_results.metrics)
  trainer.save_metrics("train", train_results.metrics)

  trainer.log_metrics("train", train_results.metrics)
  trainer.save_metrics("train", train_results.metrics)
  trainer.save_state()

  print(f"Metrics for: {name} after training for {trainer.state.log_history[-1]['total_time']}")
  metrics = trainer.evaluate(dataset['test'])
  trainer.log_metrics("eval", metrics)
  trainer.save_metrics("eval", metrics)



  0%|          | 0/4 [00:00<?, ?it/s]

Training: resnet


Step,Training Loss,Validation Loss,Accuracy
100,1.7474,1.384149,0.519824
200,1.2224,1.008597,0.64207
300,0.9479,0.898375,0.689427
400,0.8247,0.779903,0.713656
500,0.7478,0.740646,0.722467
600,0.6358,0.528925,0.805066
700,0.5774,0.493329,0.821035
800,0.5326,0.468809,0.835903
900,0.5513,0.402855,0.851872
1000,0.4592,0.38345,0.863987


***** train metrics *****
  epoch                    =         3.0
  total_flos               = 862562164GF
  train_loss               =      0.7218
  train_runtime            =  2:21:16.35
  train_samples_per_second =       5.142
  train_steps_per_second   =       0.161
***** train metrics *****
  epoch                    =         3.0
  total_flos               = 862562164GF
  train_loss               =      0.7218
  train_runtime            =  2:21:16.35
  train_samples_per_second =       5.142
  train_steps_per_second   =       0.161
Metrics for: resnet after training for 8476.359677553177


***** eval metrics *****
  epoch                   =        3.0
  eval_accuracy           =     0.8651
  eval_loss               =      0.377
  eval_runtime            = 0:03:24.67
  eval_samples_per_second =      8.872
  eval_steps_per_second   =      0.278
Training: vit


  context_layer = torch.nn.functional.scaled_dot_product_attention(


Step,Training Loss,Validation Loss,Accuracy
100,1.7209,1.327215,0.529736
200,1.1863,1.143225,0.620595
300,0.9891,1.162088,0.551211
400,0.8129,0.759277,0.731828
500,0.6689,0.510452,0.838106
600,0.4947,0.52012,0.820485
700,0.4548,0.516084,0.829846
800,0.4218,0.351454,0.877753
900,0.3624,0.340381,0.876652
1000,0.2714,0.277942,0.906388


***** train metrics *****
  epoch                    =          3.0
  total_flos               = 3145684526GF
  train_loss               =       0.5952
  train_runtime            =   4:34:25.95
  train_samples_per_second =        2.647
  train_steps_per_second   =        0.083
***** train metrics *****
  epoch                    =          3.0
  total_flos               = 3145684526GF
  train_loss               =       0.5952
  train_runtime            =   4:34:25.95
  train_samples_per_second =        2.647
  train_steps_per_second   =        0.083
Metrics for: vit after training for 16465.955320119858


***** eval metrics *****
  epoch                   =        3.0
  eval_accuracy           =     0.9058
  eval_loss               =     0.2668
  eval_runtime            = 0:01:10.55
  eval_samples_per_second =     25.739
  eval_steps_per_second   =      0.808
Training: resnet_pretrained


Step,Training Loss,Validation Loss,Accuracy
100,2.1675,1.972697,0.309471
200,1.8174,1.612785,0.351322
300,1.609,1.464275,0.537445
400,1.4469,1.27825,0.627753
500,1.2614,1.102868,0.701542
600,1.0739,0.925075,0.744493
700,0.932,0.789578,0.779185
800,0.7798,0.685544,0.80011
900,0.7281,0.604512,0.819934
1000,0.6436,0.550776,0.832048


***** train metrics *****
  epoch                    =         3.0
  total_flos               = 862562164GF
  train_loss               =      1.0671
  train_runtime            =  3:15:58.64
  train_samples_per_second =       3.707
  train_steps_per_second   =       0.116
***** train metrics *****
  epoch                    =         3.0
  total_flos               = 862562164GF
  train_loss               =      1.0671
  train_runtime            =  3:15:58.64
  train_samples_per_second =       3.707
  train_steps_per_second   =       0.116
Metrics for: resnet_pretrained after training for 11758.625509738922


***** eval metrics *****
  epoch                   =        3.0
  eval_accuracy           =     0.8216
  eval_loss               =     0.5517
  eval_runtime            = 0:03:30.02
  eval_samples_per_second =      8.647
  eval_steps_per_second   =      0.271
Training: vit_pretrained


Step,Training Loss,Validation Loss,Accuracy
100,0.6584,0.220406,0.959251
200,0.194,0.206465,0.94163
300,0.1351,0.10647,0.973568
400,0.0845,0.065736,0.98348
500,0.0718,0.151506,0.954846
600,0.0589,0.067276,0.984581
700,0.0277,0.036966,0.991189
800,0.0195,0.027644,0.995044
900,0.0176,0.03045,0.993943
1000,0.0082,0.029469,0.994493


***** train metrics *****
  epoch                    =          3.0
  total_flos               = 3145684526GF
  train_loss               =        0.095
  train_runtime            =   5:13:12.04
  train_samples_per_second =        2.319
  train_steps_per_second   =        0.072
***** train metrics *****
  epoch                    =          3.0
  total_flos               = 3145684526GF
  train_loss               =        0.095
  train_runtime            =   5:13:12.04
  train_samples_per_second =        2.319
  train_steps_per_second   =        0.072
Metrics for: vit_pretrained after training for 18792.063464403152


***** eval metrics *****
  epoch                   =        3.0
  eval_accuracy           =     0.9983
  eval_loss               =     0.0122
  eval_runtime            = 0:01:10.33
  eval_samples_per_second =     25.818
  eval_steps_per_second   =       0.81


In [21]:
from random import randint
import torch
from tqdm import tqdm

NUM_ROWS = 100

for (name, model, training_args, _, processor, collate) in tqdm(model_stack):
    rows = [randint(0, tomato_dataset['train'].num_rows-1) for _ in range(NUM_ROWS)]

    samples = tomato_dataset['train'].select(rows)
    total_time = 0
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for sample in tqdm(samples):
        inputs = processor(sample['image'], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        correct_label = name_map[sample['labels']]
        with torch.no_grad():
            start = time.time()
            outputs = model(**inputs)
            prediction = outputs.logits.argmax(-1).item()
            prediction = id2label[prediction]
            end = time.time()

            total_time += (end - start)
    print(f"Total inference time for {name}: {total_time}s ({total_time / NUM_ROWS}s per sample)")

  0%|                                                                                                   | 0/4 [00:00<?, ?it/s]
  0%|                                                                                                 | 0/100 [00:00<?, ?it/s][A
  1%|▉                                                                                        | 1/100 [00:00<00:31,  3.12it/s][A
  2%|█▊                                                                                       | 2/100 [00:00<00:26,  3.76it/s][A
  3%|██▋                                                                                      | 3/100 [00:00<00:24,  4.00it/s][A
  4%|███▌                                                                                     | 4/100 [00:01<00:22,  4.18it/s][A
  5%|████▍                                                                                    | 5/100 [00:01<00:22,  4.25it/s][A
  6%|█████▎                                                                                  

Total inference time for resnet: 22.281595468521118s (0.22281595468521118s per sample)



  0%|                                                                                                 | 0/100 [00:00<?, ?it/s][A
  2%|█▊                                                                                       | 2/100 [00:00<00:08, 11.98it/s][A
  4%|███▌                                                                                     | 4/100 [00:00<00:07, 12.37it/s][A
  6%|█████▎                                                                                   | 6/100 [00:00<00:07, 12.50it/s][A
  8%|███████                                                                                  | 8/100 [00:00<00:07, 12.10it/s][A
 10%|████████▊                                                                               | 10/100 [00:00<00:07, 11.95it/s][A
 12%|██████████▌                                                                             | 12/100 [00:00<00:07, 12.10it/s][A
 14%|████████████▎                                                                       

Total inference time for vit: 7.144994020462036s (0.07144994020462037s per sample)



  0%|                                                                                                 | 0/100 [00:00<?, ?it/s][A
  2%|█▊                                                                                       | 2/100 [00:00<00:04, 19.82it/s][A
  4%|███▌                                                                                     | 4/100 [00:00<00:05, 19.01it/s][A
  6%|█████▎                                                                                   | 6/100 [00:00<00:04, 18.85it/s][A
  8%|███████                                                                                  | 8/100 [00:00<00:04, 19.14it/s][A
 10%|████████▊                                                                               | 10/100 [00:00<00:04, 19.21it/s][A
 12%|██████████▌                                                                             | 12/100 [00:00<00:04, 18.87it/s][A
 14%|████████████▎                                                                       

Total inference time for resnet_pretrained: 4.461665153503418s (0.04461665153503418s per sample)



  0%|                                                                                                 | 0/100 [00:00<?, ?it/s][A
  2%|█▊                                                                                       | 2/100 [00:00<00:05, 17.35it/s][A
  4%|███▌                                                                                     | 4/100 [00:00<00:05, 16.42it/s][A
  6%|█████▎                                                                                   | 6/100 [00:00<00:05, 16.66it/s][A
  8%|███████                                                                                  | 8/100 [00:00<00:05, 17.06it/s][A
 10%|████████▊                                                                               | 10/100 [00:00<00:05, 16.79it/s][A
 12%|██████████▌                                                                             | 12/100 [00:00<00:05, 16.98it/s][A
 14%|████████████▎                                                                       

Total inference time for vit_pretrained: 4.880549192428589s (0.04880549192428589s per sample)



