In [1]:
import torch
from PIL import Image
import open_clip
from datasets import load_dataset
from torch.nn.functional import cosine_similarity
from sklearn.metrics import confusion_matrix
import numpy as np

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

## Prepare Datasets

### Download datasets from GLUE benchmark

In [84]:
dataset_mrpc = load_dataset("nyu-mll/glue", "mrpc")
dataset_rte = load_dataset("nyu-mll/glue", "rte")
dataset_wnli = load_dataset("nyu-mll/glue", "wnli")

### Check all were downloaded

In [82]:

print(dataset_mrpc['train']['label'])
print(dataset_rte['train']['label'])
print(dataset_wnli['train']['label'])

[5.0, 3.799999952316284, 3.799999952316284, 2.5999999046325684, 4.25, 4.25, 0.5, 1.600000023841858, 2.200000047683716, 5.0, 4.199999809265137, 4.599999904632568, 3.867000102996826, 4.666999816894531, 1.6670000553131104, 3.75, 5.0, 0.5, 3.799999952316284, 5.0, 3.200000047683716, 2.799999952316284, 4.599999904632568, 3.0, 5.0, 4.800000190734863, 5.0, 4.199999809265137, 4.199999809265137, 4.0, 4.0, 4.908999919891357, 3.0, 2.4000000953674316, 4.199999809265137, 3.4000000953674316, 5.0, 3.75, 2.75, 5.0, 4.0, 3.5999999046325684, 1.600000023841858, 1.75, 5.0, 1.0, 1.0, 2.375, 3.799999952316284, 3.200000047683716, 3.200000047683716, 4.400000095367432, 3.75, 4.75, 3.200000047683716, 1.555999994277954, 3.937999963760376, 5.0, 5.0, 4.0, 1.600000023841858, 4.75, 3.5, 1.399999976158142, 1.399999976158142, 4.0, 5.0, 3.8329999446868896, 0.6000000238418579, 2.9170000553131104, 4.199999809265137, 2.0, 2.5999999046325684, 1.600000023841858, 2.0, 4.199999809265137, 2.0, 4.800000190734863, 4.4000000953674

## Prepare for tests
we use pretrained models RN50 RN101 ViT-B-32 for tests,  
we create 3 lists holding models we are going to test and each datasets we are going to test them on

In [88]:
models = ['RN50','RN101','ViT-B-32']
datasets = [dataset_mrpc,dataset_rte,dataset_wnli]
datasets_names = ["mrpc","rte","wnli"]

In [101]:
def run_tests(dataset,n,set,model,tokenizer):
    result = {"index":[],"recall":[],"precision":[],"accuracy":[]}
    for i in range(n):
        threshold = 0.9
        prediction = []
        actual = []
        for test_case in dataset[set]:   

            tokenized_text_1 = tokenizer([test_case['sentence1']])
            tokenized_text_2 = tokenizer([test_case['sentence2']])

            with torch.no_grad(), torch.cuda.amp.autocast():
                text_embedding_1 = model.encode_text(tokenized_text_1)
                text_embedding_2 = model.encode_text(tokenized_text_2)

            if cosine_similarity(text_embedding_1,text_embedding_2) < 0.9:
                prediction.append(-1)
            else:
                prediction.append(1)
            actual.append(test_case['label'])


        cm = confusion_matrix(actual,prediction)
        tn, fp, fn, tp =  [i for i in cm.ravel() if i != 0]
        recall = tp / (tp + fn)       
        precision = tp / (tp + fp)
        accuracy = (tp+tn)/(tn+fp+fn+tp)

        result["index"].append(i)
        result["recall"].append(recall)
        result["precision"].append(precision)
        result["accuracy"].append(accuracy)

    
    return result

In [100]:
def print_results(res:dict):
    print("recall: ",np.mean(res["recall"]))
    print("precision: ",np.mean(res["precision"]))
    print("accuracy: ",np.mean(res["accuracy"]))

## Tests

we run tests and print results for each combination of model and dataset

In [102]:
for model_name in models:
    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='openai')
    tokenizer = open_clip.get_tokenizer(model_name)

    for dataset in datasets:
        i=0
        result = run_tests(dataset=dataset,n=1,set = 'train',model=model,tokenizer = tokenizer)
        print(model_name)
        print(datasets_names[i])
        print_results(result)
        i+=1



RN50
mrpc
recall:  0.6616814874696847
precision:  0.7599814298978644
accuracy:  0.6308615049073064




RN50
mrpc
recall:  0.11522965350523771
precision:  0.44135802469135804
accuracy:  0.4863453815261044




RN50
mrpc
recall:  0.3108974358974359
precision:  0.49489795918367346
accuracy:  0.5055118110236221




RN101
mrpc
recall:  0.5739692805173807
precision:  0.7763805358119191
accuracy:  0.601145038167939




RN101
mrpc
recall:  0.09024979854955681
precision:  0.42424242424242425
accuracy:  0.4855421686746988




RN101
mrpc
recall:  0.23397435897435898
precision:  0.5177304964539007
accuracy:  0.5165354330708661




ViT-B-32
mrpc
recall:  0.7611156022635408
precision:  0.734399375975039
accuracy:  0.653217011995638




ViT-B-32
mrpc
recall:  0.1587429492344883
precision:  0.43680709534368073
accuracy:  0.478714859437751




ViT-B-32
mrpc
recall:  0.391025641025641
precision:  0.4860557768924303
accuracy:  0.49763779527559054
