# Inference with a Fine-Tuned PyTorch Image Classifier

In this notebook, we run inference with a PyTorch image classification model that was pretrained and then fine-tuned on a custom dataset. The workflow shows how to load the model, prepare input images, and generate predictions on unseen data.

Along the way, we will also evaluate the model’s performance on a held-out test set. To better understand the strengths and weaknesses of the classifier, we will generate and plot a confusion matrix, which visually summarizes prediction results across all classes.

By the end of this notebook, you will be able to:

*   Load a fine-tuned PyTorch image classifier for inference.
*   Preprocess and batch images for evaluation.
*   Run predictions and compute evaluation metrics.
*   Plot and interpret a confusion matrix to assess model accuracy and misclassifications.

In [None]:
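# Connect to Google Drive if the model is stored there; otherwise skip this
# step.
import os
from google.colab import drive
drive.mount('/content/gdrive')

# A `!ln -s` shell command never raises a Python exception, so the link is
# created with os.symlink to let the except branch report real failures.
try:
  os.symlink('/content/gdrive/My Drive/', '/mydrive')
  print('Successful')
except OSError as e:
  print(e)
  print('Not successful')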

In [None]:
# Get the inference utils.
url = (
    "https://raw.githubusercontent.com/tensorflow/models/refs/heads/master/"
    "official/projects/waste_identification_ml/fine_tuning/"
    "Pytorch_Image_Classifier/inference_utils.py"
)
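# -q keeps the output quiet; -O overwrites the file on re-runs instead of
# saving a stale duplicate such as inference_utils.py.1.
!wget -q -O inference_utils.py {url}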

In [None]:
import glob
import os
import warnings

import torch
import tqdm
from sklearn.metrics import confusion_matrix, classification_report

import inference_utils

# Silence library warnings to keep the notebook output readable.
warnings.filterwarnings('ignore')

## Define Model Path, Image Path, and Class Labels

In [None]:
# Change the paths and labels according to your case.
MODEL_PATH = "/mydrive/LLM/pet_grade_bottles/best_vit_model_epoch_10.pt"
CLASS_NAMES = ['dairy', 'others']
device = "cuda" if torch.cuda.is_available() else "cpu"

input_dir = '/mydrive/LLM/pet_grade_bottles/test_dataset/mixed_bottles_objects_images/'

In [None]:
# Load Model.
model = inference_utils.load_vit_classifier(
    model_path=MODEL_PATH,
    num_classes=len(CLASS_NAMES),
    device=device
)


# Get the same transform used during training.
transform = inference_utils.get_default_transform(image_size=(224, 224))

## Inference

In [None]:
files = glob.glob(os.path.join(input_dir, '*'))
len(files)
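The `'*'` pattern above matches every entry in `input_dir`, including non-image files, and the loop that follows labels all of them with a single class. As a minimal sketch of an alternative, the helper below keeps only image files and derives each label from its parent folder. The `list_labeled_images` function, the `IMAGE_EXTENSIONS` set, and the `<root>/<class_name>/` folder layout are all assumptions made for illustration; they are not part of `inference_utils` or of the dataset used here.

```python
from pathlib import Path

# Assumption: common raster image extensions; adjust for your dataset.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def list_labeled_images(root_dir, class_names):
  """Returns sorted (path, label) pairs for images under root_dir/<label>/."""
  pairs = []
  for label in class_names:
    class_dir = Path(root_dir) / label
    if not class_dir.is_dir():
      continue  # No test images for this class.
    for path in sorted(class_dir.iterdir()):
      # Skip subfolders and files that do not look like images.
      if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS:
        pairs.append((str(path), label))
  return pairs
```

If the test set were organized this way, the returned paths could feed the prediction loop and the returned labels could serve as the ground truth, so that the report and confusion matrix cover every class.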

In [None]:
y_pred = []  # Predicted class.
# Actual class. This assumes every image in input_dir belongs to 'dairy';
# change the labels if your test folder contains other classes.
y_test = ['dairy'] * len(files)

for path in tqdm.tqdm(files):
  image_tensor = inference_utils.process_image(image_path=path, transform=transform)
  logits = inference_utils.predict(model=model, image_tensor=image_tensor, device=device)
  pred_class, pred_prob = inference_utils.get_prediction_details(
      logits=logits,
      class_names=CLASS_NAMES
  )
  print(f"  - Class: {pred_class}")
  print(f"  - Probability: {pred_prob:.4f}")
  inference_utils.plot_prediction(
      image_path=path,
      pred_class=pred_class,
      pred_prob=pred_prob
  )

  y_pred.append(pred_class)
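The loop turns raw logits into a class name and a probability. A common way to do this is to apply a softmax and take the highest-scoring class. The sketch below illustrates that idea in plain Python. It is an assumption about what `get_prediction_details` does, not a copy of its implementation, and `softmax` and `top_prediction` are hypothetical names.

```python
import math


def softmax(logits):
  """Converts raw scores into probabilities that sum to 1."""
  shift = max(logits)  # Subtract the max for numerical stability.
  exps = [math.exp(x - shift) for x in logits]
  total = sum(exps)
  return [e / total for e in exps]


def top_prediction(logits, class_names):
  """Returns the most likely class name and its probability."""
  probs = softmax(logits)
  best = max(range(len(probs)), key=probs.__getitem__)
  return class_names[best], probs[best]


# Toy logits, not output from the notebook's model.
toy_class, toy_prob = top_prediction([2.0, 0.5], ["dairy", "others"])
print(toy_class, round(toy_prob, 4))  # dairy 0.8176
```

With only two classes, a probability close to 0.5 indicates a near tie between them, so such predictions deserve a closer look in the plotted images.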

## Visualize Metrics

In [None]:
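# Show classification report. Passing `labels` keeps the report aligned with
# CLASS_NAMES and avoids an error when a class is absent from both the actual
# and predicted labels.
print(classification_report(
    y_test, y_pred, labels=CLASS_NAMES, target_names=CLASS_NAMES
))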

In [None]:
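# Show confusion matrix. Passing `labels` fixes the row and column order to
# match CLASS_NAMES, even if a class never appears.
matrix = confusion_matrix(y_test, y_pred, labels=CLASS_NAMES)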
inference_utils.show_confusion_matrix(matrix, CLASS_NAMES)
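To help interpret the plot, the sketch below rebuilds a confusion matrix by hand using the same convention as scikit-learn: rows are actual classes and columns are predicted classes. The helper names `confusion_counts` and `per_class_recall` are hypothetical, and the labels are toy data rather than results from this model.

```python
def confusion_counts(y_true, y_hat, class_names):
  """Builds a matrix where rows are actual classes and columns are predictions."""
  index = {name: i for i, name in enumerate(class_names)}
  counts = [[0] * len(class_names) for _ in class_names]
  for actual, predicted in zip(y_true, y_hat):
    counts[index[actual]][index[predicted]] += 1
  return counts


def per_class_recall(counts):
  """Returns the fraction of each actual class that was predicted correctly."""
  return [
      row[i] / sum(row) if sum(row) else 0.0
      for i, row in enumerate(counts)
  ]


# Toy labels, not results from the notebook's model.
toy_true = ["dairy", "dairy", "dairy", "others"]
toy_hat = ["dairy", "others", "dairy", "others"]
toy_matrix = confusion_counts(toy_true, toy_hat, ["dairy", "others"])
print(toy_matrix)  # [[2, 1], [0, 1]]
print(per_class_recall(toy_matrix))
```

Diagonal entries count correct predictions, and each off-diagonal entry counts images of the row's class that were mistaken for the column's class. In the toy matrix, one of the three `dairy` images was predicted as `others`.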