### Object Material Type (OMT) Classifier, based on OpenAI's CLIP Model
Source: https://openai.com/research/clip

In [1]:
from common import *
from model_functions import *
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import clip
import numpy as np

# Load dataset
df_dataset = load_from_pickle(dataset_file)

# Hold out a small stratified sample for zero-shot classification testing.
# test_size=0.01 is 1% of the dataset (not 10%); only the held-out part
# (zs_x, zs_y) is used, the remaining 99% is discarded.
# NOTE(review): this sample is drawn from the full dataset independently of the
# 80-20 split below (different random_state), so it overlaps both x_train and
# x_test. Fine for a zero-shot baseline; do not treat it as unseen data after
# any fine-tuning on x_train.
_x_train, zs_x, _y_train, zs_y = train_test_split(
    df_dataset['File'],
    df_dataset['Material Class'],
    test_size=0.01,
    stratify=df_dataset['Material Class'],
    random_state=9876,
)

# 80-20 Train-Test split, stratified on the material class
x_train, x_test, y_train, y_test = train_test_split(
    df_dataset['File'],
    df_dataset['Material Class'],
    test_size=0.2,
    stratify=df_dataset['Material Class'],
    random_state=1234,
)

# Initialise material classes: lower-cased class names in mapping order.
# The catch-all 'others' class is replaced by an explicit description so that
# the text prompt built from it reads as a meaningful sentence.
material_classes = [
    "anything other than paper, plastic, glass, or metal" if name == 'others' else name
    for name in (value.lower() for value in material_class_mapping.values())
]

# Preparations for model
device = "cuda" if torch.cuda.is_available() else "cpu"
# One tokenised prompt per class, concatenated in the same order as material_classes
text_prompt = torch.cat([clip.tokenize(f"a photo of an object made of {c}") for c in material_classes]).to(device)
model, preprocess = clip.load("ViT-L/14", device=device, download_root=MODEL_FOLDER)

In [2]:
""" Initial Performance """
# Zero-shot baseline: classify every held-out image by comparing its CLIP image
# embedding against one text embedding per material class, then report metrics.

# Get model specifications
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size, end='\n\n')

# The class prompts never change, so encode and L2-normalise them once
# instead of once per image.
with torch.no_grad():
    text_features = model.encode_text(text_prompt)
    text_features /= text_features.norm(dim=-1, keepdim=True)

# Initialise predictions
zs_y_pred = []
for current_image in zs_x:
    # Load and preprocess the image; the context manager closes the file handle
    # as soon as the tensor has been built.
    with Image.open(current_image) as img:
        image = preprocess(img).unsqueeze(0).to(device)

    # Classify image's material type
    with torch.no_grad():
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Scaled cosine similarity against every class prompt, as probabilities
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        # Classes ranked from most to least likely for this image
        values, indices = similarity[0].topk(len(material_classes))

    # Add result to predictions: position of the best-matching prompt.
    # NOTE(review): assumes the labels in zs_y are the positional indices of
    # material_classes -- confirm against material_class_mapping.
    zs_y_pred.append(int(indices[0]))

# Get model performance
results = multi_class_metrics(list(zs_y), zs_y_pred)
accuracy = results['accuracy']
precision = results['precision']
recall = results['recall']
f1 = results['f1']
mcc = results['mcc']
kappa = results['kappa']
hamming_loss_val = results['hamming_loss_val']
cm = results['cm']
class_report = results['class_report']

# Print results
print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)
print("Matthews Correlation Coefficient (MCC):", mcc)
print("Cohen's Kappa:", kappa)
print("Hamming Loss:", hamming_loss_val, end='\n\n')
print("Confusion Matrix:\n", cm, end="\n\n")
print("Classification Report:\n", class_report, end="\n\n\n")

Model parameters: 427,616,513
Input resolution: 224
Context length: 77
Vocab size: 49408

Accuracy: 0.720164609053498
Precision: 0.734947393591583
Recall: 0.7138424998330328
F1 Score: 0.703695785990944
Matthews Correlation Coefficient (MCC): 0.640289920044027
Cohen's Kappa: 0.6349013455887227
Hamming Loss: 0.27983539094650206

Confusion Matrix:
 [[70  9  0  0 14]
 [ 4 35  6  0  1]
 [ 2  2 26  0  0]
 [ 5  0  8 19  0]
 [ 1  5 11  0 25]]

Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.75      0.80        93
           1       0.69      0.76      0.72        46
           2       0.51      0.87      0.64        30
           3       1.00      0.59      0.75        32
           4       0.62      0.60      0.61        42

    accuracy                           0.72       243
   macro avg       0.73      0.71      0.70       243
weighted avg       0.76      0.72      0.73       243



