<a href="https://colab.research.google.com/github/viniciusrpb/cloud_image_segmentation/blob/main/cloud_classification_ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Cloud Classification using Vision Transformers

In [1]:
# Optional (Google Colab only): uncomment to mount Google Drive under /content/drive.
#from google.colab import drive
#drive.mount('/content/drive')

In [2]:
# Optional (Google Colab only): uncomment to copy the CCSN train/val/test splits
# from Drive into the local "training", "validation" and "testing" folders.
#!cp -r "/content/drive/My Drive/img_satelite/classificacao/CCSN/train" "training"
#!cp -r "/content/drive/My Drive/img_satelite/classificacao/CCSN/val" "validation"
#!cp -r "/content/drive/My Drive/img_satelite/classificacao/CCSN/test" "testing"

In [3]:
# Optional environment setup: uncomment the lines below to install the dependencies.
# Note: PyTorch is published on PyPI as `torch` (not `pytorch`).
#!pip install torch torchvision
#!pip install timm==0.3.2
#!pip install datasets transformers
#!pip install transformers pytorch-lightning --quiet
#!sudo apt -qq install git-lfs

In [4]:
from datasets import load_dataset
import tensorflow as tf
import torchvision
from torchvision.transforms import ToTensor
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import math
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, UnidentifiedImageError
from pathlib import Path
import torch
import glob
import pytorch_lightning as pl
from huggingface_hub import HfApi, Repository
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchmetrics import Accuracy
from transformers import ViTFeatureExtractor,ViTForImageClassification,DeiTForImageClassification,BeitForImageClassification,DeiTFeatureExtractor,  BeitFeatureExtractor
from pytorch_lightning.callbacks import ModelCheckpoint

In [5]:
# Local directories that hold the training, validation and test image folders.
path_train, path_validation, path_test = 'training', 'validation', 'testing'

Create the ImageFolder datasets for the training, validation and test splits

In [6]:
# Build one ImageFolder dataset per split directory; each image goes through
# the ToTensor() transform when it is read.
train_ds, valid_ds, test_ds = [
    torchvision.datasets.ImageFolder(split_dir, transform=ToTensor())
    for split_dir in (path_train, path_validation, path_test)
]

In [7]:
# Display the class names that ImageFolder found for the training split.
train_ds.classes

['Ac', 'As', 'Cb', 'Cc', 'Ci', 'Cs', 'Ct', 'Cu', 'Ns', 'Sc', 'St']

In [8]:
def fn_collator(batch):
    """Collate a list of (image, label) samples into one model-ready batch.

    The images are handed to the module-level ``feature_extractor`` (asked to
    return PyTorch tensors) and the labels are attached to the resulting
    encoding under the ``'labels'`` key as a ``torch.long`` tensor.

    Args:
        batch: sequence of samples; element 0 of each sample is the image and
            element 1 is its integer class label.

    Returns:
        The feature extractor's output with the extra ``'labels'`` entry.
    """
    images = []
    targets = []
    for sample in batch:
        images.append(sample[0])
        targets.append(sample[1])

    encodings = feature_extractor(images, return_tensors='pt')
    encodings['labels'] = torch.tensor(targets, dtype=torch.long)
    return encodings

Get the class codes of the dataset (class name ↔ id mappings)

In [9]:
# Lookup tables between class names and their ids (stored as strings),
# numbered in the order of train_ds.classes.
dic_label2id = {class_name: str(i) for i, class_name in enumerate(train_ds.classes)}
dic_id2label = {str(i): class_name for i, class_name in enumerate(train_ds.classes)}

Load the pretrained ViT feature extractor that preprocesses the images

In [10]:
from transformers import ViTFeatureExtractor

# Name of the pretrained checkpoint to load from.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
# Feature extractor loaded from that checkpoint; fn_collator uses it to encode image batches.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)


In [11]:
import numpy as np
from datasets import load_metric

# load_metric's second positional parameter is the metric's configuration name,
# not an additional metric, so the former "f1-score" argument never produced an
# F1 value. Only accuracy is loaded and reported.
metric = load_metric("accuracy")


def compute_metrics(p):
    """Compute accuracy for one evaluation pass.

    Args:
        p: prediction object exposing ``predictions`` (per-class scores; the
            predicted class is the argmax over axis 1) and ``label_ids``
            (the reference labels).

    Returns:
        The result of ``metric.compute`` for the predicted vs. reference labels.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_ids, references=p.label_ids)

In [12]:
from transformers import ViTForImageClassification

# Reuse model_name_or_path (instead of repeating the checkpoint string) so the
# feature extractor and the model are always loaded from the same checkpoint.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

# Load the pretrained ViT with a classification head sized to the dataset's
# classes and the class-name <-> id mappings stored in the model config.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(train_ds.classes),
    id2label=dic_id2label,
    label2id=dic_label2id)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
from transformers import TrainingArguments

# Hyper-parameters and bookkeeping options for the Hugging Face Trainer.
training_args = TrainingArguments(
  output_dir="./vit-base-clouds",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=20,
  # Mixed precision needs a CUDA device; enable it only when one is present so
  # the notebook also runs (in full precision) on CPU-only machines.
  fp16=torch.cuda.is_available(),
  # Checkpoints and evaluations share the same 100-step cadence.
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=1e-5,
  save_total_limit=2,
  # fn_collator receives raw (image, label) samples, so no column pruning.
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

In [14]:
from transformers import Trainer

# Assemble the Trainer from the model, its training configuration, the two
# datasets used during fitting, and the batching / metric callbacks.
trainer = Trainer(
    model=model,
    args=training_args,
    # data
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=fn_collator,
    # reporting
    compute_metrics=compute_metrics,
    # the feature extractor is passed through the `tokenizer` argument
    tokenizer=feature_extractor,
)

Using amp half precision backend


In [15]:
# Run the fine-tuning loop; the returned object carries the training metrics.
train_results = trainer.train()
# Persist the model held by the trainer after training.
trainer.save_model()
# Print the training metrics and write them out under the "train" name.
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
# Persist the trainer state.
trainer.save_state()

***** Running training *****
  Num examples = 1774
  Num Epochs = 20
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 2220


Step,Training Loss,Validation Loss,Accuracy
100,2.2754,2.238539,0.356
200,1.9629,2.021473,0.436
300,1.9279,1.862088,0.496
400,1.6787,1.753656,0.496


***** Running Evaluation *****
  Num examples = 250
  Batch size = 8
Saving model checkpoint to ./vit-base-clouds/checkpoint-100
Configuration saved in ./vit-base-clouds/checkpoint-100/config.json
Model weights saved in ./vit-base-clouds/checkpoint-100/pytorch_model.bin
Feature extractor saved in ./vit-base-clouds/checkpoint-100/preprocessor_config.json
Deleting older checkpoint [vit-base-clouds/checkpoint-200] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 250
  Batch size = 8
Saving model checkpoint to ./vit-base-clouds/checkpoint-200
Configuration saved in ./vit-base-clouds/checkpoint-200/config.json
Model weights saved in ./vit-base-clouds/checkpoint-200/pytorch_model.bin
Feature extractor saved in ./vit-base-clouds/checkpoint-200/preprocessor_config.json
Deleting older checkpoint [vit-base-clouds/checkpoint-800] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 250
  Batch size = 8
Saving model checkpoint to ./vit-base-cloud

Step,Training Loss,Validation Loss,Accuracy
100,2.2754,2.238539,0.356
200,1.9629,2.021473,0.436
300,1.9279,1.862088,0.496
400,1.6787,1.753656,0.496
500,1.517,1.684789,0.496
600,1.4436,1.627982,0.508
700,1.2844,1.582013,0.5
800,1.2259,1.559297,0.536
900,1.0839,1.524139,0.54
1000,1.0853,1.513997,0.536


***** Running Evaluation *****
  Num examples = 250
  Batch size = 8
Saving model checkpoint to ./vit-base-clouds/checkpoint-500
Configuration saved in ./vit-base-clouds/checkpoint-500/config.json
Model weights saved in ./vit-base-clouds/checkpoint-500/pytorch_model.bin
Feature extractor saved in ./vit-base-clouds/checkpoint-500/preprocessor_config.json
Deleting older checkpoint [vit-base-clouds/checkpoint-300] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 250
  Batch size = 8
Saving model checkpoint to ./vit-base-clouds/checkpoint-600
Configuration saved in ./vit-base-clouds/checkpoint-600/config.json
Model weights saved in ./vit-base-clouds/checkpoint-600/pytorch_model.bin
Feature extractor saved in ./vit-base-clouds/checkpoint-600/preprocessor_config.json
Deleting older checkpoint [vit-base-clouds/checkpoint-400] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 250
  Batch size = 8
Saving model checkpoint to ./vit-base-cloud

***** train metrics *****
  epoch                    =         20.0
  total_flos               = 2560799541GF
  train_loss               =       1.1512
  train_runtime            =   1:33:05.07
  train_samples_per_second =        6.353
  train_steps_per_second   =        0.397


In [16]:
# Evaluate the trained model on the held-out test split.
metrics = trainer.evaluate(test_ds)
# NOTE(review): these results come from test_ds but are logged and saved under
# the "eval" name — confirm that this labelling is intended.
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** Running Evaluation *****
  Num examples = 519
  Batch size = 8


***** eval metrics *****
  epoch                   =       20.0
  eval_accuracy           =     0.5684
  eval_loss               =     1.3881
  eval_runtime            = 0:00:30.17
  eval_samples_per_second =     17.202
  eval_steps_per_second   =      2.154
