In [14]:
import pandas as pd
import pathlib
import keras
import cv2
import glob
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from keras.applications.vgg16 import preprocess_input
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from transformers import TFAutoModelForImageClassification, AutoImageProcessor

from sklearn.metrics import accuracy_score
from tensorflow.keras import layers, models, optimizers

from tensorflow.keras.applications import ResNet50, imagenet_utils
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions

from tensorflow.keras.preprocessing import image

from tensorflow.keras.applications import DenseNet121

from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, concatenate, UpSampling2D

# Trying to use a transformer for classification

## Importing libraries and images

In [32]:
import cv2
import glob
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf
from transformers import ViTFeatureExtractor, TFAutoModelForImageClassification, TFViTForImageClassification

# Define your URLs and limits
url_normal = '../data/MR/knee/normal_filter/'
url_airspace = '../data/MR/knee/acl_pathology_filter/'
url_bronch = '../data/MR/knee/bone_inflammation_filter/'
url_inter = '../data/MR/knee/chondral_abnormality/'
url_nodule = '../data/MR/knee/fracture_filter/'
url_parenchyma = '../data/MR/knee/hematoma_filter/'

limite = 100  # Adjust the limit as needed

In [33]:
# Load images
images_normal = [cv2.imread(file) for file in glob.glob(url_normal + "*.png")][:limite]
images_airspace = [cv2.imread(file) for file in glob.glob(url_airspace + "*.png")][:limite]
images_bronch = [cv2.imread(file) for file in glob.glob(url_bronch + "*.png")][:limite]
images_inter = [cv2.imread(file) for file in glob.glob(url_inter + "*.png")][:limite]
images_nodule = [cv2.imread(file) for file in glob.glob(url_nodule + "*.png")][:limite]
images_parenchyma = [cv2.imread(file) for file in glob.glob(url_parenchyma + "*.png")][:limite]

# Assign labels
labels_normal = [0] * len(images_normal)
labels_airspace = [1] * len(images_airspace)
labels_bronch = [2] * len(images_bronch)
labels_inter = [3] * len(images_inter)
labels_nodule = [4] * len(images_nodule)
labels_parenchyma = [5] * len(images_parenchyma)

## Importing ViT model and preparing data

In [55]:
# Concatenate data
X = np.concatenate((images_normal, images_airspace, images_bronch, images_inter, images_nodule, images_parenchyma), axis=0)
y = np.concatenate((labels_normal, labels_airspace, labels_bronch, labels_inter, labels_nodule, labels_parenchyma), axis=0)

# Splitting data
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.3, shuffle=True)
X_test, X_val, y_test, y_val = train_test_split(X_val, y_val, test_size=0.5)

# Use the ViT feature extractor
vit_feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')  

X_train_vit = vit_feature_extractor(X_train, return_tensors="np", padding=True, truncation=True)
X_val_vit = vit_feature_extractor(X_val, return_tensors="np", padding=True, truncation=True)
X_test_vit = vit_feature_extractor(X_test, return_tensors="np", padding=True, truncation=True)

# Reshape the input to match the expected shape of the ViT model
X_train_vit_reshaped = X_train_vit['pixel_values']
X_val_vit_reshaped = X_val_vit['pixel_values']
X_test_vit_reshaped = X_test_vit['pixel_values']

# Build a simple model
num_classes = 6

input_shape = X_train_vit_reshaped.shape[1:]

# Creating the ViT model
vit_model = TFViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=6)

# Setting the layers of the ViT model to be non-trainable
for layer in vit_model.layers:
    layer.trainable = False

# Creating a Sequential model with the ViT model and a Dense layer with softmax activation
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=input_shape),  
    vit_model,  
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax')  # Dense layer with softmax activation for classification
])

# Compiling the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Printing the model summary
model.summary()

Some layers from the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing TFViTForImageClassification: ['vit/pooler/dense/bias:0', 'vit/pooler/dense/kernel:0']
- This IS expected if you are initializing TFViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model: "sequential_15"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 tf_vi_t_for_image_classifi  TFSequenceClassifierOut   85803270  
 cation_22 (TFViTForImageCl  put(loss=None, logits=(             
 assification)               None, 6),                           
                              hidden_states=None, at             
                             tentions=None)                      
                                                                 
 dense_24 (Dense)            (None, 256)               1792      
                                                                 
 dense_25 (Dense)            (None, 6)                 1542      
                                                                 
Total params: 85806604 (327.33 MB)
Trainable params: 3334 (13.02 KB)
Non-trainable params: 85803270 (327.31 MB)
_________________________________________________________________


## Training model

In [56]:
epochs = 5
batch_size = 32

history = model.fit(
    X_train_vit['pixel_values'],  
    y_train,
    validation_data=(X_val_vit['pixel_values'], y_val),
    epochs=epochs,
    batch_size=batch_size
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [52]:
history.history

array([2, 4, 1, 4, 2, 0, 1, 3, 0, 4, 0, 2, 0, 1, 0, 5, 3, 1, 2, 4, 3, 3,
       2, 3, 2, 5, 0, 1, 5, 3, 2, 2, 4, 0, 0, 2, 1, 1, 0, 0, 4, 2, 3, 4,
       3, 2, 4, 2, 4, 1, 2, 4, 0, 0, 2, 3, 1, 4, 3, 1, 5, 5, 5, 2, 4, 2,
       2, 0, 0, 0, 0, 2, 3, 2, 1, 2, 3, 3, 1, 1, 0, 4, 3, 2, 1, 3, 1, 3,
       1, 2, 3, 3, 5, 2, 1, 1, 4, 4, 4, 3, 5, 4, 0, 5, 5, 1, 0, 5, 1, 3,
       0, 2, 5, 3, 3, 2, 3, 3, 3, 3, 0, 4, 3, 2, 5, 4, 0, 1, 5, 2, 4, 4,
       4, 2, 3, 4, 0, 2, 0, 1, 1, 1, 3, 4, 0, 1, 3, 3, 5, 2, 4, 5, 2, 0,
       2, 4, 3, 3, 5, 1, 4, 1, 5, 2, 3, 4, 5, 2, 4, 3, 5, 5, 5, 2, 2, 2,
       0, 0, 5, 1, 4, 2, 1, 4, 1, 3, 2, 5, 4, 4, 1, 3, 3, 0, 0, 3, 5, 1,
       2, 1, 4, 1, 4, 3, 3, 4, 5, 4, 0, 4, 1, 0, 1, 5, 2, 4, 0, 5, 0, 1,
       4, 1, 2, 3, 1, 0, 0, 3, 3, 3, 4, 0, 2, 0, 1, 5, 0, 1, 0, 1, 2, 0,
       2, 0, 4, 5, 4, 3, 1, 3, 1, 2, 1, 4, 3, 1, 5, 3, 4, 3, 2, 3, 2, 4,
       3, 1, 3, 4, 3, 3, 4, 3, 0, 0, 5, 1, 1, 0, 4, 2, 3, 2, 1, 2, 4, 3,
       0, 2, 3, 3, 3, 0, 2, 3, 4, 3, 4, 2, 1, 3, 5,

## Conclusion

Accuracy of vision transformer model (less than 20%) is really low compared with CNN (more than 50%), data quantity is not sufficient for transformer usage in our case