In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import ViTFeatureExtractor, TFAutoModel
from PIL import Image
import tensorflow as tf
from carb_calc.ml_logic.model import prediction
import cv2
import os

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
base_model = TFAutoModel.from_pretrained('google/vit-base-patch16-224')

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFViTModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing TFViTModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFViTModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFViTModel were not initialized from the PyTorch model and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")



In [5]:
image = Image.open('../test_images/banana.jpg')

In [6]:
processed_image = feature_extractor(images=image, return_tensors="tf")

In [7]:
embeddings = base_model(processed_image)

In [8]:
X = embeddings.last_hidden_state

In [9]:
X.shape

TensorShape([1, 197, 768])

In [10]:
X = X[:,0,:]

In [11]:
X.shape

TensorShape([1, 768])

In [12]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(768,)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(131, activation='softmax')
])

In [13]:
def load_images_from_folder(folder):
    images = []
    for filename in os.listdir(folder):
        img = cv2.imread(os.path.join(folder,filename))
        if img is not None:
            images.append(img)
    return images

**Processing**

In [14]:
labels_input = [
    "Apple Braeburn", "Cantaloupe 1", "Grape Blue", "Mangostan", "Pear Monster", "Potato White",
    "Apple Crimson Snow", "Cantaloupe 2", "Grape Pink", "Maracuja", "Pear Red", "Quince",
    "Apple Golden 1", "Carambula", "Grape White", "Melon Piel de Sapo", "Pear Stone", "Rambutan",
    "Apple Golden 2", "Cauliflower", "Grape White 2", "Mulberry", "Pear Williams", "Raspberry",
    "Apple Golden 3", "Cherry 1", "Grape White 3", "Nectarine", "Pepino", "Redcurrant",
    "Apple Granny Smith", "Cherry 2", "Grape White 4", "Nectarine Flat", "Pepper Green", "Salak",
    "Apple Pink Lady", "Cherry Rainier", "Grapefruit Pink", "Nut Forest", "Pepper Orange", "Strawberry",
    "Apple Red 1", "Cherry Wax Black", "Grapefruit White", "Nut Pecan", "Pepper Red", "Strawberry Wedge",
    "Apple Red 2", "Cherry Wax Red", "Guava", "Onion Red", "Pepper Yellow", "Tamarillo",
    "Apple Red 3", "Cherry Wax Yellow", "Hazelnut", "Onion Red Peeled", "Physalis", "Tangelo",
    "Apple Red Delicious", "Chestnut", "Huckleberry", "Onion White", "Physalis with Husk", "Tomato 1",
    "Apple Red Yellow 1", "Clementine", "Kaki", "Orange", "Pineapple", "Tomato 2",
    "Apple Red Yellow 2", "Cocos", "Kiwi", "Papaya", "Pineapple Mini", "Tomato 3",
    "Apricot", "Corn", "Kohlrabi", "Passion Fruit", "Pitahaya Red", "Tomato 4",
    "Avocado", "Corn Husk", "Kumquats", "Peach", "Plum", "Tomato Cherry Red",
    "Avocado ripe", "Cucumber Ripe", "Lemon", "Peach 2", "Plum 2", "Tomato Heart",
    "Banana", "Cucumber Ripe 2", "Lemon Meyer", "Peach Flat", "Plum 3", "Tomato Maroon",
    "Banana Lady Finger", "Dates", "Limes", "Pear", "Pomegranate", "Tomato Yellow",
    "Banana Red", "Eggplant", "Lychee", "Pear 2", "Pomelo Sweetie", "Tomato not Ripened",
    "Beetroot", "Fig", "Mandarine", "Pear Abate", "Potato Red", "Walnut",
    "Blueberry", "Ginger Root", "Mango", "Pear Forelle", "Potato Red Washed", "Watermelon",
    "Cactus fruit", "Granadilla", "Mango Red", "Pear Kaiser", "Potato Sweet"
]


In [15]:
def modify_label(label):
    return label.replace(" ", "_")

In [16]:
key_pair = {}

for i in labels_input:
    modified_label = modify_label(i)
    key_pair[modified_label] = load_images_from_folder(f'/Users/kymbradshaw/Downloads/archive (2)/fruits-360_dataset/fruits-360/Training/{i}')

In [35]:
key_pair['Apple_Braeburn'][30].shape

(100, 100, 3)

In [36]:
key_pair["Papaya"][10].shape

(100, 100, 3)

In [41]:
file_path_train = '/Users/kymbradshaw/Downloads/archive (2)/fruits-360_dataset/fruits-360/Training'
file_path_test = '/Users/kymbradshaw/Downloads/archive (2)/fruits-360_dataset/fruits-360/Test'

In [42]:
train = tf.keras.utils.image_dataset_from_directory(file_path_train, labels='inferred', label_mode='categorical', color_mode='rgb', batch_size=32, image_size=(100,100))

Found 67692 files belonging to 131 classes.


In [43]:
test = tf.keras.utils.image_dataset_from_directory(file_path_test, labels='inferred', label_mode='categorical', color_mode='rgb', batch_size=32, image_size=(100,100))

Found 22688 files belonging to 131 classes.


In [44]:
type(train)

tensorflow.python.data.ops.prefetch_op._PrefetchDataset