### Object Material Type (OMT) Classifier, based on OpenAI's CLIP Model
Source: https://openai.com/research/clip

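### Zero-Shot Classification
Each image is scored against one text prompt per material class ("photo of an object made of …"); the prompt with the highest image–text similarity is taken as the prediction, without any training on this dataset.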

In [3]:
from common import *
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, classification_report, matthews_corrcoef, cohen_kappa_score, hamming_loss
import torch
import clip
import numpy as np

# Load dataset
# NOTE(review): `load_from_pickle` comes from `common`; if it unpickles, only
# load trusted files (unpickling can execute arbitrary code).
df_dataset = load_from_pickle(dataset_file)

# Stratified hold-out used for zero-shot evaluation. The 10% results recorded
# below use 0.1; the 1% results presumably used 0.01.
ZS_TEST_SIZE = 0.1
ZS_RANDOM_STATE = 9876
# 80-20 train/test split. Not used by the zero-shot evaluation in this section;
# presumably consumed by later cells.
TEST_SIZE = 0.2
SPLIT_RANDOM_STATE = 1234

_x_train, zs_x, _y_train, zs_y = train_test_split(
    df_dataset['File'], df_dataset['Material Class'],
    test_size=ZS_TEST_SIZE, stratify=df_dataset['Material Class'], random_state=ZS_RANDOM_STATE,
)
x_train, x_test, y_train, y_test = train_test_split(
    df_dataset['File'], df_dataset['Material Class'],
    test_size=TEST_SIZE, stratify=df_dataset['Material Class'], random_state=SPLIT_RANDOM_STATE,
)

# Class names in label order: the evaluation compares prediction index i with
# label i, so `material_class_mapping` must iterate in label order.
# The catch-all "others" class is reworded into something a text prompt can describe.
PROMPT_OVERRIDES = {'others': "anything other than paper, plastic, glass, or metal"}
material_classes = [
    PROMPT_OVERRIDES.get(name.lower(), name.lower())
    for name in material_class_mapping.values()
]
# Guard the label/prompt alignment that the zero-shot evaluation relies on.
assert set(zs_y) <= set(range(len(material_classes))), \
    "Material Class labels must be integers 0..len(material_classes)-1"

# Preparations for model
device = "cuda" if torch.cuda.is_available() else "cpu"
# One tokenized prompt per class, in `material_classes` order.
text_prompt = clip.tokenize([f"photo of an object made of {c}" for c in material_classes]).to(device)

In [None]:
# Compare zero-shot performance of every CLIP checkpoint listed by `clip.available_models()`.


def _predict_zero_shot(model, preprocess, image_paths, text_prompt, device, batch_size=32):
    """Return the zero-shot class index predicted for each image.

    Args:
        model: a CLIP model returned by `clip.load`.
        preprocess: the image transform returned alongside `model`.
        image_paths: iterable of image file paths.
        text_prompt: tokenized class prompts (output of `clip.tokenize`), one per class.
        device: device string the model was loaded on.
        batch_size: number of images encoded per forward pass.

    Returns:
        list[int]: for each image, the index of the prompt with the highest
        cosine similarity (i.e. an index into `material_classes`).
    """
    paths = list(image_paths)
    predictions = []
    with torch.no_grad():
        # The prompts do not depend on the image: encode and normalise them once
        # per model instead of once per image.
        text_features = model.encode_text(text_prompt)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        for start in range(0, len(paths), batch_size):
            batch = []
            for path in paths[start:start + batch_size]:
                # Context manager closes the file handle deterministically once
                # the pixels have been consumed by `preprocess`.
                with Image.open(path) as img:
                    batch.append(preprocess(img))
            images = torch.stack(batch).to(device)

            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # Softmax is monotonic and the 100x logit scale is positive, so the
            # argmax of the raw cosine similarity picks the same class.
            predictions.extend((image_features @ text_features.T).argmax(dim=-1).tolist())
    return predictions


# Labels are assumed to be 0..n_classes-1, in the same order as `material_classes`.
class_ids = list(range(len(material_classes)))
list_zs_y = list(zs_y)

for current_model in clip.available_models():
    # Release the previous checkpoint before loading the next one so two large
    # models are never resident at the same time.
    model = preprocess = None
    if device == "cuda":
        torch.cuda.empty_cache()

    # Initialise model
    model, preprocess = clip.load(current_model, device=device, download_root=MODEL_FOLDER)

    # Model specifications
    print("Model:", current_model)
    print("Model parameters:", f"{sum(p.numel() for p in model.parameters()):,}")
    print("Input resolution:", model.visual.input_resolution)
    print("Context length:", model.context_length)
    print("Vocab size:", model.vocab_size, end='\n\n')

    # Predict the material class of every held-out image
    zs_y_pred = _predict_zero_shot(model, preprocess, zs_x, text_prompt, device)

    # Accuracy
    accuracy = accuracy_score(list_zs_y, zs_y_pred)
    print("Accuracy:", accuracy)

    # Macro-averaged scores; zero_division=0 matches sklearn's default value
    # for classes that are never predicted, without the warning.
    precision = precision_score(list_zs_y, zs_y_pred, average='macro', zero_division=0)
    print("Precision:", precision)

    recall = recall_score(list_zs_y, zs_y_pred, average='macro', zero_division=0)
    print("Recall:", recall)

    f1 = f1_score(list_zs_y, zs_y_pred, average='macro', zero_division=0)
    print("F1 Score:", f1)

    # Matthews Correlation Coefficient (MCC)
    mcc = matthews_corrcoef(list_zs_y, zs_y_pred)
    print("Matthews Correlation Coefficient (MCC):", mcc)

    # Cohen's Kappa
    kappa = cohen_kappa_score(list_zs_y, zs_y_pred)
    print("Cohen's Kappa:", kappa)

    # Hamming Loss (for single-label multiclass this equals 1 - accuracy)
    hamming_loss_val = hamming_loss(list_zs_y, zs_y_pred)
    print("Hamming Loss:", hamming_loss_val, end='\n\n')

    # Confusion matrix; explicit labels keep it n_classes x n_classes even if a
    # class is absent from both labels and predictions.
    cm = confusion_matrix(list_zs_y, zs_y_pred, labels=class_ids)
    print("Confusion Matrix:")
    print(cm, end="\n\n")

    # Classification Report
    class_report = classification_report(list_zs_y, zs_y_pred, labels=class_ids, zero_division=0)
    print("Classification Report:")
    print(class_report, end="\n\n\n")

### Zero-Shot Classification Performance

#### Performance on a 10% Stratified Sample of the Dataset (2,424 images)

In [None]:
""" 
Model: RN50
Model parameters: 102,007,137
Input resolution: 224
Context length: 77
Vocab size: 49408

Accuracy: 0.459983498349835
Precision: 0.5364220929770019
Recall: 0.38003848168448484
F1 Score: 0.3680587218509033
Matthews Correlation Coefficient (MCC): 0.25340113747196713
Cohen's Kappa: 0.23193098306528914
Hamming Loss: 0.540016501650165

Confusion Matrix:
[[698 122  77   1  25]
 [227 173  51   0   6]
 [127   5 159   7   0]
 [216   1  61  42   2]
 [267  26  83   5  43]]

Classification Report:
              precision    recall  f1-score   support

           0       0.45      0.76      0.57       923
           1       0.53      0.38      0.44       457
           2       0.37      0.53      0.44       298
           3       0.76      0.13      0.22       322
           4       0.57      0.10      0.17       424

    accuracy                           0.46      2424
   macro avg       0.54      0.38      0.37      2424
weighted avg       0.52      0.46      0.41      2424



Model: RN101
Model parameters: 119,688,033
Input resolution: 224
Context length: 77
Vocab size: 49408

Accuracy: 0.43853135313531355
Precision: 0.5527112807582848
Recall: 0.41461532266208617
F1 Score: 0.3941133714670186
Matthews Correlation Coefficient (MCC): 0.24808290524134521
Cohen's Kappa: 0.2311673445189052
Hamming Loss: 0.5614686468646864

Confusion Matrix:
[[569  41 234   6  73]
 [298  84  69   2   4]
 [ 57   1 237   3   0]
 [135   2  88  93   4]
 [250   7  85   2  80]]

Classification Report:
              precision    recall  f1-score   support

           0       0.43      0.62      0.51       923
           1       0.62      0.18      0.28       457
           2       0.33      0.80      0.47       298
           3       0.88      0.29      0.43       322
           4       0.50      0.19      0.27       424

    accuracy                           0.44      2424
   macro avg       0.55      0.41      0.39      2424
weighted avg       0.53      0.44      0.41      2424



Model: RN50x4
Model parameters: 178,300,601
Input resolution: 288
Context length: 77
Vocab size: 49408

Accuracy: 0.4236798679867987
Precision: 0.48681441780354745
Recall: 0.4645814783022167
F1 Score: 0.4189558129951056
Matthews Correlation Coefficient (MCC): 0.2785911295413147
Cohen's Kappa: 0.267045705425068
Hamming Loss: 0.5763201320132013

Confusion Matrix:
[[265 118 281   3 256]
 [ 48 339  56   5   9]
 [ 33   7 244  12   2]
 [147   4  93  71   7]
 [149  46 120   1 108]]

Classification Report:
              precision    recall  f1-score   support

           0       0.41      0.29      0.34       923
           1       0.66      0.74      0.70       457
           2       0.31      0.82      0.45       298
           3       0.77      0.22      0.34       322
           4       0.28      0.25      0.27       424

    accuracy                           0.42      2424
   macro avg       0.49      0.46      0.42      2424
weighted avg       0.47      0.42      0.41      2424



Model: RN50x16
Model parameters: 290,979,217
Input resolution: 384
Context length: 77
Vocab size: 49408

Accuracy: 0.3778877887788779
Precision: 0.6005714234807596
Recall: 0.2875830346805791
F1 Score: 0.2680951382913359
Matthews Correlation Coefficient (MCC): 0.10396886235842223
Cohen's Kappa: 0.08371582946158096
Hamming Loss: 0.6221122112211221

Confusion Matrix:
[[672   2 228   2  19]
 [338  54  51   0  14]
 [170   0 126   1   1]
 [276   0  22  24   0]
 [376   1   6   1  40]]

Classification Report:
              precision    recall  f1-score   support

           0       0.37      0.73      0.49       923
           1       0.95      0.12      0.21       457
           2       0.29      0.42      0.34       298
           3       0.86      0.07      0.14       322
           4       0.54      0.09      0.16       424

    accuracy                           0.38      2424
   macro avg       0.60      0.29      0.27      2424
weighted avg       0.56      0.38      0.31      2424



Model: RN50x64
Model parameters: 623,258,305
Input resolution: 448
Context length: 77
Vocab size: 49408

Accuracy: 0.5107260726072608
Precision: 0.6150121280504154
Recall: 0.503943698798375
F1 Score: 0.4940139667238478
Matthews Correlation Coefficient (MCC): 0.36780074396645157
Cohen's Kappa: 0.3562257593536424
Hamming Loss: 0.4892739273927393

Confusion Matrix:
[[507  92 301   4  19]
 [ 93 289  71   0   4]
 [ 60   9 214  15   0]
 [151   6  54 110   1]
 [123 126  55   2 118]]

Classification Report:
              precision    recall  f1-score   support

           0       0.54      0.55      0.55       923
           1       0.55      0.63      0.59       457
           2       0.31      0.72      0.43       298
           3       0.84      0.34      0.49       322
           4       0.83      0.28      0.42       424

    accuracy                           0.51      2424
   macro avg       0.62      0.50      0.49      2424
weighted avg       0.61      0.51      0.51      2424



Model: ViT-B/32
Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408

Accuracy: 0.5198019801980198
Precision: 0.5566003027870373
Recall: 0.5115291964599042
F1 Score: 0.5088150348719772
Matthews Correlation Coefficient (MCC): 0.36312681718413553
Cohen's Kappa: 0.35875201985422955
Hamming Loss: 0.4801980198019802

Confusion Matrix:
[[515 195 110   0 103]
 [ 65 308  61   3  20]
 [ 79  12 191  14   2]
 [150   2  27 140   3]
 [190  84  37   7 106]]

Classification Report:
              precision    recall  f1-score   support

           0       0.52      0.56      0.54       923
           1       0.51      0.67      0.58       457
           2       0.45      0.64      0.53       298
           3       0.85      0.43      0.58       322
           4       0.45      0.25      0.32       424

    accuracy                           0.52      2424
   macro avg       0.56      0.51      0.51      2424
weighted avg       0.54      0.52      0.51      2424



Model: ViT-B/16
Model parameters: 149,620,737
Input resolution: 224
Context length: 77
Vocab size: 49408

Accuracy: 0.49834983498349833
Precision: 0.5593171334247528
Recall: 0.4970808217221248
F1 Score: 0.48833988282391905
Matthews Correlation Coefficient (MCC): 0.3454062610394499
Cohen's Kappa: 0.3408978292174629
Hamming Loss: 0.5016501650165016

Confusion Matrix:
[[483  27  93   0 320]
 [126 246  68   1  16]
 [ 48  14 223  10   3]
 [121   0  97  96   8]
 [108  97  57   2 160]]

Classification Report:
              precision    recall  f1-score   support

           0       0.55      0.52      0.53       923
           1       0.64      0.54      0.59       457
           2       0.41      0.75      0.53       298
           3       0.88      0.30      0.45       322
           4       0.32      0.38      0.34       424

    accuracy                           0.50      2424
   macro avg       0.56      0.50      0.49      2424
weighted avg       0.55      0.50      0.50      2424



Model: ViT-L/14
Model parameters: 427,616,513
Input resolution: 224
Context length: 77
Vocab size: 49408

Accuracy: 0.6621287128712872
Precision: 0.6671066030427073
Recall: 0.6464374130645891
F1 Score: 0.6345731510632313
Matthews Correlation Coefficient (MCC): 0.5628138868486012
Cohen's Kappa: 0.557143634710465
Hamming Loss: 0.3378712871287129

Confusion Matrix:
[[683  81  16   2 141]
 [ 56 334  65   0   2]
 [ 27  13 246  12   0]
 [ 72   2  73 173   2]
 [ 34  82 134   5 169]]

Classification Report:
              precision    recall  f1-score   support

           0       0.78      0.74      0.76       923
           1       0.65      0.73      0.69       457
           2       0.46      0.83      0.59       298
           3       0.90      0.54      0.67       322
           4       0.54      0.40      0.46       424

    accuracy                           0.66      2424
   macro avg       0.67      0.65      0.63      2424
weighted avg       0.69      0.66      0.66      2424



Model: ViT-L/14@336px
Model parameters: 427,944,193
Input resolution: 336
Context length: 77
Vocab size: 49408

Accuracy: 0.6542904290429042
Precision: 0.663429523779947
Recall: 0.627868365041197
F1 Score: 0.6192232789472756
Matthews Correlation Coefficient (MCC): 0.5501572631477752
Cohen's Kappa: 0.5438385557779997
Hamming Loss: 0.3457095709570957

Confusion Matrix:
[[708  82   8   1 124]
 [ 65 325  65   0   2]
 [ 36  12 238  12   0]
 [ 81   2  79 160   0]
 [ 41  75 150   3 155]]

Classification Report:
              precision    recall  f1-score   support

           0       0.76      0.77      0.76       923
           1       0.66      0.71      0.68       457
           2       0.44      0.80      0.57       298
           3       0.91      0.50      0.64       322
           4       0.55      0.37      0.44       424

    accuracy                           0.65      2424
   macro avg       0.66      0.63      0.62      2424
weighted avg       0.68      0.65      0.65      2424
"""

#### Performance on a 1% Stratified Sample of the Dataset (243 images)

In [None]:
""" 
Model: RN50
Model parameters: 102,007,137
Input resolution: 224
Context length: 77
Vocab size: 49408

Accuracy: 0.448559670781893
Precision: 0.583571768469864
Recall: 0.3751343251185467
F1 Score: 0.37071455515687757
Matthews Correlation Coefficient (MCC): 0.23762702989158305
Cohen's Kappa: 0.2213959494034098
Hamming Loss: 0.551440329218107

Confusion Matrix:
[[65 15 12  0  1]
 [23 20  3  0  0]
 [14  1 14  1  0]
 [18  1  8  5  0]
 [26  4  7  0  5]]

Classification Report:
              precision    recall  f1-score   support

           0       0.45      0.70      0.54        93
           1       0.49      0.43      0.46        46
           2       0.32      0.47      0.38        30
           3       0.83      0.16      0.26        32
           4       0.83      0.12      0.21        42

    accuracy                           0.45       243
   macro avg       0.58      0.38      0.37       243
weighted avg       0.56      0.45      0.41       243



Model: RN101
Model parameters: 119,688,033
Input resolution: 224
Context length: 77
Vocab size: 49408

Accuracy: 0.4732510288065844
Precision: 0.6328468406593407
Recall: 0.4481032525212048
F1 Score: 0.4353587588881706
Matthews Correlation Coefficient (MCC): 0.2964528719346169
Cohen's Kappa: 0.27898189573239984
Hamming Loss: 0.5267489711934157

Confusion Matrix:
[[59 10 22  0  2]
 [25 12  9  0  0]
 [ 4  1 25  0  0]
 [15  1  8  8  0]
 [25  0  6  0 11]]

Classification Report:
              precision    recall  f1-score   support

           0       0.46      0.63      0.53        93
           1       0.50      0.26      0.34        46
           2       0.36      0.83      0.50        30
           3       1.00      0.25      0.40        32
           4       0.85      0.26      0.40        42

    accuracy                           0.47       243
   macro avg       0.63      0.45      0.44       243
weighted avg       0.59      0.47      0.45       243



Model: RN50x4
Model parameters: 178,300,601
Input resolution: 288
Context length: 77
Vocab size: 49408

Accuracy: 0.4444444444444444
Precision: 0.5236139332365748
Recall: 0.48317045682228005
F1 Score: 0.4459617423004151
Matthews Correlation Coefficient (MCC): 0.30142422420908876
Cohen's Kappa: 0.2899196952314985
Hamming Loss: 0.5555555555555556

Confusion Matrix:
[[29 11 31  0 22]
 [ 9 35  1  0  1]
 [ 4  1 24  1  0]
 [13  0 10  9  0]
 [13  6 12  0 11]]

Classification Report:
              precision    recall  f1-score   support

           0       0.43      0.31      0.36        93
           1       0.66      0.76      0.71        46
           2       0.31      0.80      0.44        30
           3       0.90      0.28      0.43        32
           4       0.32      0.26      0.29        42

    accuracy                           0.44       243
   macro avg       0.52      0.48      0.45       243
weighted avg       0.50      0.44      0.43       243



Model: RN50x16
Model parameters: 290,979,217
Input resolution: 384
Context length: 77
Vocab size: 49408

Accuracy: 0.39094650205761317
Precision: 0.7298524087997772
Recall: 0.29988821545448474
F1 Score: 0.30113976304214096
Matthews Correlation Coefficient (MCC): 0.11915305898735758
Cohen's Kappa: 0.09360350824134278
Hamming Loss: 0.6090534979423868

Confusion Matrix:
[[68  0 25  0  0]
 [36  8  2  0  0]
 [19  0 11  0  0]
 [27  0  0  5  0]
 [39  0  0  0  3]]

Classification Report:
              precision    recall  f1-score   support

           0       0.36      0.73      0.48        93
           1       1.00      0.17      0.30        46
           2       0.29      0.37      0.32        30
           3       1.00      0.16      0.27        32
           4       1.00      0.07      0.13        42

    accuracy                           0.39       243
   macro avg       0.73      0.30      0.30       243
weighted avg       0.67      0.39      0.34       243



Model: RN50x64
Model parameters: 623,258,305
Input resolution: 448
Context length: 77
Vocab size: 49408

Accuracy: 0.5267489711934157
Precision: 0.6232332944832945
Recall: 0.5095527783343352
F1 Score: 0.49698702494362906
Matthews Correlation Coefficient (MCC): 0.3827235763528328
Cohen's Kappa: 0.37162709120345394
Hamming Loss: 0.4732510288065844

Confusion Matrix:
[[56  9 25  1  2]
 [10 29  7  0  0]
 [ 7  0 22  1  0]
 [15  2  4 11  0]
 [11 15  6  0 10]]

Classification Report:
              precision    recall  f1-score   support

           0       0.57      0.60      0.58        93
           1       0.53      0.63      0.57        46
           2       0.34      0.73      0.47        30
           3       0.85      0.34      0.49        32
           4       0.83      0.24      0.37        42

    accuracy                           0.53       243
   macro avg       0.62      0.51      0.50       243
weighted avg       0.61      0.53      0.52       243



Model: ViT-B/32
Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408

Accuracy: 0.5596707818930041
Precision: 0.6148514851485148
Recall: 0.5433259032925932
F1 Score: 0.5445632206500818
Matthews Correlation Coefficient (MCC): 0.4163739149401716
Cohen's Kappa: 0.41027443864821955
Hamming Loss: 0.4403292181069959

Confusion Matrix:
[[58 18  8  0  9]
 [ 6 32  7  0  1]
 [ 8  2 20  0  0]
 [14  1  2 15  0]
 [15 11  5  0 11]]

Classification Report:
              precision    recall  f1-score   support

           0       0.57      0.62      0.60        93
           1       0.50      0.70      0.58        46
           2       0.48      0.67      0.56        30
           3       1.00      0.47      0.64        32
           4       0.52      0.26      0.35        42

    accuracy                           0.56       243
   macro avg       0.61      0.54      0.54       243
weighted avg       0.60      0.56      0.55       243



Model: ViT-B/16
Model parameters: 149,620,737
Input resolution: 224
Context length: 77
Vocab size: 49408

Accuracy: 0.51440329218107
Precision: 0.5705212894514913
Recall: 0.5004149135109864
F1 Score: 0.5024798569075044
Matthews Correlation Coefficient (MCC): 0.3559923760937538
Cohen's Kappa: 0.3533444589779442
Hamming Loss: 0.48559670781893005

Confusion Matrix:
[[52  5  4  0 32]
 [12 29  4  0  1]
 [ 8  2 20  0  0]
 [14  0  8 10  0]
 [11  9  7  1 14]]

Classification Report:
              precision    recall  f1-score   support

           0       0.54      0.56      0.55        93
           1       0.64      0.63      0.64        46
           2       0.47      0.67      0.55        30
           3       0.91      0.31      0.47        32
           4       0.30      0.33      0.31        42

    accuracy                           0.51       243
   macro avg       0.57      0.50      0.50       243
weighted avg       0.56      0.51      0.51       243



Model: ViT-L/14
Model parameters: 427,616,513
Input resolution: 224
Context length: 77
Vocab size: 49408

Accuracy: 0.6831275720164609
Precision: 0.7025869963369964
Recall: 0.6598076537767983
F1 Score: 0.6540962367049323
Matthews Correlation Coefficient (MCC): 0.5866950250269293
Cohen's Kappa: 0.5815404571275216
Hamming Loss: 0.3168724279835391

Confusion Matrix:
[[72  9  0  0 12]
 [ 6 34  5  0  1]
 [ 3  2 25  0  0]
 [ 8  0  8 16  0]
 [ 2 11 10  0 19]]

Classification Report:
              precision    recall  f1-score   support

           0       0.79      0.77      0.78        93
           1       0.61      0.74      0.67        46
           2       0.52      0.83      0.64        30
           3       1.00      0.50      0.67        32
           4       0.59      0.45      0.51        42

    accuracy                           0.68       243
   macro avg       0.70      0.66      0.65       243
weighted avg       0.72      0.68      0.68       243



Model: ViT-L/14@336px
Model parameters: 427,944,193
Input resolution: 336
Context length: 77
Vocab size: 49408

Accuracy: 0.6255144032921811
Precision: 0.6483090561920349
Recall: 0.5988711347091431
F1 Score: 0.5949292342066084
Matthews Correlation Coefficient (MCC): 0.5095524392810765
Cohen's Kappa: 0.5044483786388185
Hamming Loss: 0.37448559670781895

Confusion Matrix:
[[68 11  1  0 13]
 [10 29  6  0  1]
 [ 4  2 23  1  0]
 [10  0  8 14  0]
 [ 2  8 14  0 18]]

Classification Report:
              precision    recall  f1-score   support

           0       0.72      0.73      0.73        93
           1       0.58      0.63      0.60        46
           2       0.44      0.77      0.56        30
           3       0.93      0.44      0.60        32
           4       0.56      0.43      0.49        42

    accuracy                           0.63       243
   macro avg       0.65      0.60      0.59       243
weighted avg       0.66      0.63      0.62       243

"""