In [1]:
import pandas as pd
import os
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
import torch

In [5]:
processed_data_path = "../data/processed"
validation_csv = os.path.join(processed_data_path, "validation_data.csv")
test_csv = os.path.join(processed_data_path, "test_data.csv")
train_csv = os.path.join(processed_data_path, "train_data.csv")

valid_df = pd.read_csv(validation_csv)
train_df = pd.read_csv(train_csv)
test_df = pd.read_csv(test_csv)

In [6]:
def load_image(image_path):
    return Image.open(image_path)

train_df['image_path'] = train_df['image']
train_df['image'] = train_df['image_path'].apply(load_image)

valid_df['image_path'] = valid_df['image']
valid_df['image'] = valid_df['image_path'].apply(load_image)

In [7]:
train_df["label"].unique()
mapping = {
    'Banana_Bad': 1,
    'Lemon_Mixed': 2,
    'Apple_Good' : 3, 
    'Guava_Mixed': 4,
    'Guava_Bad' : 5, 
    'Lime_Bad':6, 
    'Pomegranate_Good':7,
    'Guava_Good':8,
    'Lime_Good':9, 
    'Banana_Good':10, 
    'Apple_Bad':11, 
    'Pomegranate_Bad':12,
    'Orange_Good':13, 
    'Banana_Mixed':14, 
    'Orange_Bad':15, 
    'Pomegranate_Mixed':16,
    'Orange_Mixed':17, 
    'Apple_Mixed':18
}
train_df['label'] = train_df['label'].map(mapping)
valid_df['label'] = valid_df['label'].map(mapping)

In [8]:
print(valid_df)


                                                  image  label  \
0     <PIL.JpegImagePlugin.JpegImageFile image mode=...      1   
1     <PIL.JpegImagePlugin.JpegImageFile image mode=...     13   
2     <PIL.JpegImagePlugin.JpegImageFile image mode=...      8   
3     <PIL.JpegImagePlugin.JpegImageFile image mode=...     10   
4     <PIL.JpegImagePlugin.JpegImageFile image mode=...     17   
...                                                 ...    ...   
1948  <PIL.JpegImagePlugin.JpegImageFile image mode=...     15   
1949  <PIL.JpegImagePlugin.JpegImageFile image mode=...     14   
1950  <PIL.JpegImagePlugin.JpegImageFile image mode=...      7   
1951  <PIL.JpegImagePlugin.JpegImageFile image mode=...      9   
1952  <PIL.JpegImagePlugin.JpegImageFile image mode=...      7   

                                             image_path  
0     ../data/external/fruit_images/bad_quality_frui...  
1     ../data/external/fruit_images/good_quality_fru...  
2     ../data/external/fruit_imag

In [9]:
def check_and_transform_image(image):
    if image.mode == "L":
        # If it's a black and white image, convert it to RGB
        print(f" RGB image {image}")
        image = image.convert("RGB")
    return image

# Apply the function to each image path in the DataFrame
train_df["image"] = train_df["image"].apply(check_and_transform_image)
valid_df["image"] = valid_df["image"].apply(check_and_transform_image)

 RGB image <PIL.JpegImagePlugin.JpegImageFile image mode=L size=224x224 at 0x1D7E3E28490>
 RGB image <PIL.JpegImagePlugin.JpegImageFile image mode=L size=224x224 at 0x1D7E3E284F0>
 RGB image <PIL.JpegImagePlugin.JpegImageFile image mode=L size=224x224 at 0x1D7E3E289D0>
 RGB image <PIL.JpegImagePlugin.JpegImageFile image mode=L size=224x224 at 0x1D7E3E28A90>
 RGB image <PIL.JpegImagePlugin.JpegImageFile image mode=L size=224x224 at 0x1D7E3E28B50>
 RGB image <PIL.JpegImagePlugin.JpegImageFile image mode=L size=224x224 at 0x1D7E3E28C70>
 RGB image <PIL.JpegImagePlugin.JpegImageFile image mode=L size=224x224 at 0x1D7E3E28D30>
 RGB image <PIL.JpegImagePlugin.JpegImageFile image mode=L size=224x224 at 0x1D7E3E28D90>
 RGB image <PIL.JpegImagePlugin.JpegImageFile image mode=L size=224x224 at 0x1D7E3E28E50>
 RGB image <PIL.JpegImagePlugin.JpegImageFile image mode=L size=224x224 at 0x1D7E3E55070>
 RGB image <PIL.JpegImagePlugin.JpegImageFile image mode=L size=224x224 at 0x1D7E3E55190>
 RGB image

In [10]:
from transformers import ViTImageProcessor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)
print(processor)

  from .autonotebook import tqdm as notebook_tqdm


ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}



In [11]:
class CustomImageDataset(Dataset):
    def __init__(self, dataframe, feature_extractor):
        self.dataframe = dataframe
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        image = self.dataframe.iloc[idx, 0]  # Assuming the "image" column contains actual images
        label = self.dataframe.iloc[idx, 1]  # Assuming the "label" column is at index 1 
        
        pixel_values = self.feature_extractor(image, return_tensors="pt").pixel_values
        print(f"image_path : {self.dataframe.iloc[idx, 2]}, shape: {pixel_values.shape}")
        return {"pixel_values": pixel_values, "labels": torch.tensor(label)}      

In [12]:
train_dataset = CustomImageDataset(dataframe=train_df, feature_extractor=processor)
val_dataset = CustomImageDataset(dataframe=valid_df, feature_extractor=processor)

In [13]:
train_dataset[0]

image_path : ../data/external/fruit_images/bad_quality_fruits\Banana_Bad\IMG_20190910_175634.jpg, shape: torch.Size([1, 3, 224, 224])


{'pixel_values': tensor([[[[0.5686, 0.5765, 0.5922,  ..., 0.8275, 0.8275, 0.8667],
           [0.6078, 0.5922, 0.5843,  ..., 0.8431, 0.8353, 0.8588],
           [0.5608, 0.5373, 0.5294,  ..., 0.8745, 0.8510, 0.8588],
           ...,
           [0.4431, 0.4275, 0.3961,  ..., 0.9608, 0.9686, 0.9686],
           [0.4039, 0.3804, 0.3569,  ..., 0.9843, 0.9765, 0.9686],
           [0.4510, 0.4196, 0.3725,  ..., 1.0000, 0.9765, 0.9529]],
 
          [[0.5843, 0.5922, 0.6078,  ..., 0.8431, 0.8431, 0.8824],
           [0.6235, 0.6078, 0.6000,  ..., 0.8588, 0.8510, 0.8745],
           [0.5765, 0.5529, 0.5451,  ..., 0.8902, 0.8667, 0.8745],
           ...,
           [0.4588, 0.4431, 0.4118,  ..., 0.9686, 0.9765, 0.9765],
           [0.4196, 0.3961, 0.3725,  ..., 0.9922, 0.9843, 0.9765],
           [0.4667, 0.4353, 0.3882,  ..., 1.0000, 0.9843, 0.9608]],
 
          [[0.5608, 0.5686, 0.5843,  ..., 0.8196, 0.8196, 0.8588],
           [0.6000, 0.5843, 0.5765,  ..., 0.8353, 0.8275, 0.8510],
        

In [14]:
val_dataset[2]

image_path : ../data/external/fruit_images/good_quality_fruits\Guava_Good\IMG_8625.JPG, shape: torch.Size([1, 3, 224, 224])


{'pixel_values': tensor([[[[ 0.3725,  0.3725,  0.4353,  ...,  0.7412,  0.8196,  0.6314],
           [ 0.3647,  0.4902,  0.4667,  ...,  0.7333,  0.7490,  0.6863],
           [ 0.4196,  0.4980,  0.5137,  ...,  0.7333,  0.5137,  0.5843],
           ...,
           [-0.3098, -0.3020, -0.2706,  ..., -0.1451, -0.0353,  0.0118],
           [-0.2549, -0.3333, -0.3725,  ..., -0.1216, -0.0745,  0.0039],
           [-0.3569, -0.2863, -0.1686,  ..., -0.1059,  0.1216,  0.0980]],
 
          [[ 0.3647,  0.3647,  0.4275,  ...,  0.7098,  0.7882,  0.6000],
           [ 0.3569,  0.4824,  0.4588,  ...,  0.7020,  0.7176,  0.6549],
           [ 0.4118,  0.4902,  0.5059,  ...,  0.7020,  0.4824,  0.5529],
           ...,
           [-0.3255, -0.3176, -0.2863,  ..., -0.1922, -0.0824, -0.0353],
           [-0.2706, -0.3490, -0.3882,  ..., -0.1686, -0.1216, -0.0431],
           [-0.3725, -0.3020, -0.1843,  ..., -0.1529,  0.0745,  0.0431]],
 
          [[ 0.4275,  0.4275,  0.4902,  ...,  0.7176,  0.7961,  0.6078

In [17]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset)
valid_dataloader = DataLoader(val_dataset)

In [18]:
import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

In [19]:
from transformers import AutoModelForImageClassification

labels = train_df['label'].unique().tolist()

model = AutoModelForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from torch import softmax
from tqdm import tqdm
from transformers import AutoFeatureExtractor, get_scheduler

train_dataloader = DataLoader(train_dataset)
valid_dataloader = DataLoader(val_dataset)


optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
epochs = 3

num_training_steps = epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

device = 'cpu'
model.to(device)

for epoch in tqdm(range(epochs), desc="Training"):

    running_loss = 0.0
    accuracy = 0.0
    model.train()

    for batch in tqdm(train_dataloader, desc="Batch", leave=False):

        batch["pixel_values"] = torch.squeeze(batch["pixel_values"], 1)

        optimizer.zero_grad()

        batch = {k: v.to(device) for k, v in batch.items()}

        y_pred = model(**batch)

        class_pred = torch.argmax(softmax(y_pred.logits, dim=1), dim=1)

        is_correct = (
            class_pred.detach().cpu().numpy() == np.array(batch["labels"].cpu())
        ).sum()

        accuracy += is_correct

        loss = y_pred.loss

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

        lr_scheduler.step()

    running_loss /= len(train_dataloader)
    accuracy /= len(train_dataloader)

    print(f"Training Loss: {running_loss}, Training Accuracy: {accuracy}")

    model.eval()
    running_loss = 0.0
    accuracy = 0.0
    with torch.no_grad():
        for batch in tqdm(valid_dataloader, desc="Validation", leave=False):
            batch["pixel_values"] = torch.squeeze(batch["pixel_values"], 1)
            batch = {k: v.to(device) for k, v in batch.items()}

            y_pred = model(**batch)

            class_pred = torch.argmax(softmax(y_pred.logits, dim=1), dim=1)

            is_correct = (
                class_pred.detach().cpu().numpy() == np.array(batch["labels"].cpu())
            ).sum()

            accuracy += is_correct

            loss = y_pred.loss
            running_loss += loss.item()

        running_loss /= len(valid_dataloader)
        accuracy /= len(valid_dataloader)

        print(f"Validation Loss: {running_loss}, Validation Accuracy: {accuracy}")
