In [17]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel, AutoProcessor, AutoModel


In [18]:
torch.set_default_device("cuda")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to("cuda")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


In [19]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # CLIP expects 224x224 images
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])  # CLIP normalization
])


In [20]:
cifar_data = datasets.CIFAR100(root="./data", train=False, download=True, transform=transform)

In [21]:
class_names = cifar_data.classes  # CIFAR-10 class names
text_inputs = processor(text=class_names, return_tensors="pt", padding=True)


In [22]:
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

def denormalize(img: torch.Tensor, mean: torch.Tensor, std: torch.Tensor):
    """Denormalizes the image given the mean and standard deviation."""
    return img * torch.tensor(std, device="cpu").view(3, 1, 1) + torch.tensor(mean, device="cpu").view(3, 1, 1)

In [23]:
import numpy as np
from itertools import product

seeds = list(range(10))
results = []

for seed in seeds:

    fractions = [0.9, 0.1]
    total_len = len(cifar_data)
    lengths = [int(f * total_len) for f in fractions]
    lengths[-1] = total_len - sum(lengths[:-1])
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Split the dataset
    test_data, calib_data = torch.utils.data.random_split(cifar_data, lengths, generator=generator)
    calib_loader = DataLoader(calib_data, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

    # Forward pass through CLIP

    all_labels = []
    all_predictions = []
    scores = []

    # Compute nonconformity scores

    for images, labels in calib_loader:

        pil_images = [transforms.ToPILImage()(denormalize(img, processor.image_processor.image_mean, processor.image_processor.image_std)) for img in images]
        
        # Process images using CLIP's processor (automatically normalizes them)
        inputs = processor(images=pil_images, return_tensors="pt").to("cuda")
        input_image_processed = inputs['pixel_values'].squeeze(0)

        outputs = model(**inputs, **text_inputs)
        logits_per_image = outputs.logits_per_image  # Image-to-text similarity scores
        probs = logits_per_image.softmax(dim=1)  # Convert to probabilities
        predictions = probs.argmax(dim=1)
        all_labels.extend(labels.tolist())
        all_predictions.extend(predictions.tolist())
        scores += logits_per_image.take_along_dim(torch.tensor(labels).unsqueeze(-1),dim=1).squeeze().tolist()
        


    print(scores)


    alphas = [0.02, 0.05, 0.1, 0.2]
    for alpha in alphas:
        # Compute the quantile for the nonconformity scores
        n = len(scores)
        threshold = np.quantile(scores, np.ceil((n+1)*(alpha))/n, method="inverted_cdf")
        prediction_sets = []
        all_labels = []
        all_predictions = []

        for images, labels in test_loader:

            pil_images = [transforms.ToPILImage()(denormalize(img, processor.image_processor.image_mean, processor.image_processor.image_std)) for img in images]
            
            # Process images using CLIP's processor (automatically normalizes them)
            inputs = processor(images=pil_images, return_tensors="pt").to("cuda")
            input_image_processed = inputs['pixel_values'].squeeze(0)

            outputs = model(**inputs, **text_inputs)
            logits_per_image = outputs.logits_per_image  # Image-to-text similarity scores
            probs = logits_per_image.softmax(dim=1)  # Convert to probabilities
            predictions = probs.argmax(dim=1)
            all_labels.extend(labels.tolist())
            all_predictions.extend(predictions.tolist())
            indices = (logits_per_image > threshold).nonzero(as_tuple=True)
            row_indices = [indices[1][indices[0] == i] for i in range(logits_per_image.size(0))]
            prediction_sets.extend(row_indices)

        pred_sets = [x.tolist() for x in prediction_sets]
        coverage = np.mean([all_labels[i] in pred_sets[i] for i in range(len(all_labels))])
        avg_set_size = np.mean([len(s) for s in pred_sets])
        median_set_size = np.median([len(s) for s in pred_sets])
        acc_score = accuracy_score(all_labels, all_predictions)
        results.append([seed, alpha, coverage, avg_set_size, median_set_size, acc_score])
        # print(f"accuracy =\t\t {acc_score}")
        # print(f"coverage =\t\t {coverage}")
        # print(f"mean set size =\t\t {avg_set_size}")
        # print(f"median set size =\t {median_set_size}")

        print(seed,len(calib_data), len(test_data))

  return func(*args, **kwargs)


[26.071617126464844, 25.708654403686523, 24.741636276245117, 26.970535278320312, 31.971580505371094, 29.927047729492188, 25.04533576965332, 23.204730987548828, 27.71392822265625, 21.056894302368164, 25.62864112854004, 26.237699508666992, 23.74863624572754, 27.233070373535156, 26.079936981201172, 24.95155143737793, 23.75617790222168, 26.399112701416016, 21.7309513092041, 27.720064163208008, 24.37055778503418, 28.99827003479004, 25.35816192626953, 26.025936126708984, 24.532203674316406, 24.556041717529297, 21.9766902923584, 28.77896499633789, 26.57601547241211, 27.338102340698242, 28.3701114654541, 24.658321380615234, 27.535707473754883, 24.452802658081055, 25.89330291748047, 29.384262084960938, 26.907533645629883, 25.056367874145508, 27.237817764282227, 22.52166748046875, 27.22722625732422, 24.39377212524414, 25.84654998779297, 27.79705810546875, 27.818716049194336, 32.714759826660156, 22.509963989257812, 25.529510498046875, 25.640722274780273, 27.915945053100586, 25.10906219482422, 24.

  return func(*args, **kwargs)


[22.762100219726562, 27.456279754638672, 28.905805587768555, 25.87360954284668, 27.470102310180664, 31.801795959472656, 30.860631942749023, 25.906539916992188, 28.657691955566406, 27.26395034790039, 25.2750244140625, 26.343568801879883, 27.244754791259766, 26.335399627685547, 25.106525421142578, 31.05863380432129, 27.519031524658203, 27.680015563964844, 24.692790985107422, 25.237712860107422, 26.321020126342773, 22.017099380493164, 28.158676147460938, 24.500951766967773, 25.366010665893555, 25.57065200805664, 23.40249252319336, 28.756141662597656, 30.398881912231445, 29.642223358154297, 26.02989959716797, 30.761932373046875, 29.652034759521484, 28.304401397705078, 26.788789749145508, 28.207489013671875, 24.79159927368164, 27.864131927490234, 32.65085220336914, 26.10931968688965, 21.41775131225586, 25.755638122558594, 25.390827178955078, 30.046554565429688, 24.262292861938477, 26.53919792175293, 24.812517166137695, 25.356233596801758, 27.792762756347656, 26.29544448852539, 24.4424762725

  return func(*args, **kwargs)


[25.17888641357422, 28.32703971862793, 25.626962661743164, 26.065439224243164, 27.216472625732422, 26.579294204711914, 29.20517921447754, 28.963167190551758, 25.696298599243164, 27.63831329345703, 26.952045440673828, 22.248926162719727, 23.384492874145508, 29.075557708740234, 26.092100143432617, 27.957365036010742, 24.675296783447266, 24.780685424804688, 31.567188262939453, 26.407611846923828, 29.881229400634766, 27.126249313354492, 28.350547790527344, 27.9173641204834, 26.912561416625977, 26.89937973022461, 23.078950881958008, 27.963388442993164, 25.75542640686035, 29.628881454467773, 28.173080444335938, 26.727306365966797, 28.109663009643555, 25.887969970703125, 29.386192321777344, 22.646854400634766, 30.064491271972656, 25.49402618408203, 22.381908416748047, 27.76716423034668, 23.628524780273438, 25.552188873291016, 22.998170852661133, 28.644216537475586, 30.98226547241211, 23.453731536865234, 25.28302764892578, 31.362306594848633, 30.850324630737305, 27.086788177490234, 27.71181106

  return func(*args, **kwargs)


[25.92511558532715, 28.355121612548828, 25.49911117553711, 31.44304084777832, 24.703142166137695, 28.072494506835938, 25.287639617919922, 29.38153076171875, 26.9257755279541, 28.90053939819336, 29.023462295532227, 25.249000549316406, 28.683805465698242, 23.619470596313477, 27.71107292175293, 30.588775634765625, 25.74884796142578, 28.608444213867188, 28.116924285888672, 24.311330795288086, 24.182273864746094, 23.3862361907959, 24.74578094482422, 30.099376678466797, 27.948219299316406, 22.912187576293945, 24.889739990234375, 28.507713317871094, 24.99156951904297, 24.265613555908203, 30.197650909423828, 31.159570693969727, 24.350276947021484, 23.80701446533203, 26.625246047973633, 26.45738410949707, 25.254730224609375, 29.78798484802246, 23.146663665771484, 27.340599060058594, 23.34373664855957, 31.158578872680664, 24.707855224609375, 28.860183715820312, 27.347660064697266, 29.609811782836914, 27.038061141967773, 25.920963287353516, 23.372140884399414, 24.138124465942383, 24.801025390625,

  return func(*args, **kwargs)


[26.903200149536133, 29.235570907592773, 27.34648895263672, 27.244754791259766, 27.254486083984375, 30.220678329467773, 27.658084869384766, 25.49806022644043, 28.047714233398438, 27.28965950012207, 27.471145629882812, 23.872819900512695, 28.64423179626465, 28.74825668334961, 28.21817398071289, 26.66858673095703, 26.50684356689453, 30.90094566345215, 27.568265914916992, 27.14006805419922, 24.898162841796875, 23.974803924560547, 27.25086212158203, 26.73405647277832, 28.018375396728516, 30.391416549682617, 23.822729110717773, 24.237546920776367, 30.368391036987305, 28.459808349609375, 26.31705093383789, 31.04182243347168, 26.74043846130371, 28.81772804260254, 27.68631935119629, 24.836894989013672, 29.953598022460938, 33.657325744628906, 26.39014434814453, 26.485820770263672, 24.09297752380371, 22.2784366607666, 31.167551040649414, 25.118122100830078, 25.23868179321289, 26.209163665771484, 29.261791229248047, 22.593624114990234, 27.219623565673828, 24.389219284057617, 26.334815979003906, 2

  return func(*args, **kwargs)


[27.814239501953125, 31.868209838867188, 22.738664627075195, 25.0175838470459, 25.30113983154297, 28.508726119995117, 22.710708618164062, 25.961679458618164, 29.323625564575195, 25.282896041870117, 27.244308471679688, 25.48639678955078, 25.234771728515625, 31.94669532775879, 28.92306137084961, 31.247440338134766, 24.53476905822754, 29.976404190063477, 29.59800910949707, 27.280691146850586, 26.483001708984375, 26.98031997680664, 30.0057373046875, 26.52389144897461, 26.471046447753906, 25.46332550048828, 28.53177833557129, 26.903270721435547, 26.1904296875, 28.294506072998047, 26.25679588317871, 27.470571517944336, 23.43621063232422, 24.307655334472656, 27.680015563964844, 24.2386474609375, 25.87360954284668, 25.671541213989258, 26.898378372192383, 30.02449607849121, 26.679506301879883, 28.06100845336914, 26.984724044799805, 28.461149215698242, 31.541828155517578, 29.871612548828125, 25.299612045288086, 23.993755340576172, 24.14185333251953, 31.151649475097656, 26.194990158081055, 25.682

  return func(*args, **kwargs)


[30.28114128112793, 28.566497802734375, 27.92881202697754, 24.406396865844727, 27.12752914428711, 26.950660705566406, 27.684343338012695, 32.19256591796875, 23.465065002441406, 28.643781661987305, 24.609397888183594, 23.641706466674805, 25.526836395263672, 29.334041595458984, 24.759735107421875, 26.058643341064453, 30.910316467285156, 29.583251953125, 25.50620460510254, 25.154457092285156, 29.41361427307129, 30.813684463500977, 27.863649368286133, 26.44109344482422, 24.839017868041992, 24.303905487060547, 32.31694793701172, 27.017303466796875, 27.625761032104492, 28.309253692626953, 29.259485244750977, 23.48627471923828, 32.82435607910156, 24.295743942260742, 25.824695587158203, 23.823339462280273, 28.07745361328125, 27.885665893554688, 26.10776138305664, 29.247018814086914, 24.230609893798828, 25.530349731445312, 25.484567642211914, 26.715267181396484, 26.32097625732422, 23.386781692504883, 25.19097900390625, 30.72998809814453, 24.07415199279785, 28.610342025756836, 28.298542022705078

  return func(*args, **kwargs)


[26.754131317138672, 25.28302764892578, 30.715856552124023, 28.663665771484375, 25.567960739135742, 26.454463958740234, 23.560131072998047, 26.89361000061035, 23.708782196044922, 27.053504943847656, 28.130638122558594, 25.46332550048828, 25.067752838134766, 25.848329544067383, 27.68151092529297, 27.152177810668945, 31.407794952392578, 31.297164916992188, 23.11898422241211, 26.495330810546875, 27.299510955810547, 25.254730224609375, 25.34688377380371, 28.343061447143555, 28.204662322998047, 21.155080795288086, 24.846891403198242, 29.64426612854004, 28.090848922729492, 26.927345275878906, 26.745182037353516, 28.578975677490234, 27.84810447692871, 26.8494930267334, 30.98226547241211, 26.002389907836914, 26.580110549926758, 26.815505981445312, 29.089111328125, 28.41048812866211, 27.814239501953125, 22.212413787841797, 24.01544189453125, 22.797731399536133, 30.15604019165039, 24.288705825805664, 24.30720329284668, 24.882970809936523, 27.35338592529297, 26.03546714782715, 27.7921199798584, 2

  return func(*args, **kwargs)


[32.31989288330078, 25.750776290893555, 21.41775131225586, 25.898433685302734, 27.313175201416016, 23.825965881347656, 27.217641830444336, 26.343568801879883, 27.352603912353516, 22.445913314819336, 26.899429321289062, 28.85500144958496, 25.311391830444336, 27.76251792907715, 27.327590942382812, 25.69289779663086, 28.803010940551758, 27.93332290649414, 28.53507423400879, 25.491634368896484, 24.066862106323242, 27.911996841430664, 25.073894500732422, 26.860624313354492, 31.704092025756836, 26.76797866821289, 29.86167335510254, 24.171070098876953, 28.621841430664062, 30.669736862182617, 22.157852172851562, 25.648550033569336, 25.46376609802246, 27.37433624267578, 25.770591735839844, 28.53013038635254, 29.45604705810547, 32.25091552734375, 25.333694458007812, 25.691560745239258, 30.037185668945312, 22.392671585083008, 27.826875686645508, 23.020811080932617, 26.04001808166504, 26.10934066772461, 24.837186813354492, 28.131141662597656, 23.45197868347168, 26.267013549804688, 27.2669048309326

  return func(*args, **kwargs)


[26.17561149597168, 25.594388961791992, 22.874258041381836, 26.869184494018555, 27.22165870666504, 27.436279296875, 26.799806594848633, 26.43955421447754, 29.126110076904297, 29.66977882385254, 25.527339935302734, 27.76251792907715, 25.138723373413086, 26.163307189941406, 26.198795318603516, 27.196273803710938, 30.12479591369629, 31.79623794555664, 29.101457595825195, 24.20018768310547, 23.838720321655273, 25.51698112487793, 27.862401962280273, 26.503541946411133, 25.556337356567383, 25.45867919921875, 23.801067352294922, 27.945201873779297, 25.21749496459961, 28.51174545288086, 27.697736740112305, 22.368450164794922, 21.197328567504883, 23.775611877441406, 23.445247650146484, 23.385644912719727, 25.013917922973633, 26.631677627563477, 27.806360244750977, 28.27010726928711, 26.017560958862305, 27.374385833740234, 24.60886001586914, 31.459348678588867, 21.514482498168945, 24.6646671295166, 25.414339065551758, 24.606483459472656, 24.13692283630371, 26.93402862548828, 29.681962966918945, 

In [24]:
import pandas as pd
result_df = pd.DataFrame(data=results, columns=["seed", "alpha", "coverage", "avg_set_size", "median_set_size", "acc_score"])

In [25]:
result_df.to_csv("cifar100_clip_results.csv")