In [None]:
import argparse
import os
import torch
import torchvision
import torchvision.datasets as dset
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torchvision.models import resnet18
from transformers import set_seed
from tqdm import tqdm
import pickle
import pandas as pd
from itertools import product
from sklearn.metrics import accuracy_score, balanced_accuracy_score

from torch.utils.data import Subset

from torchcp.classification import Metrics
from torchcp.classification.predictor import SplitPredictor
from torchcp.classification.score import THR, APS, SAPS, RAPS, Margin, TOPK
from torch import nn
from torch.utils.data import DataLoader

# Preprocessing applied to every image before it reaches the network:
# upscale to 224x224, convert to a float tensor, then standardise each
# colour channel with the fixed mean/std below.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
transform = transforms.Compose(_preprocessing_steps)

# Number of target classes; also sizes the classifier head below so the two
# can never drift apart.
num_classes = 10

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the architecture only. Every parameter and buffer is replaced by the
# fine-tuned checkpoint loaded right below (load_state_dict is strict by
# default), so fetching the ImageNet weights first would be a wasted download.
model = resnet18()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)

# Deserialise onto the CPU so the checkpoint also loads on machines without a
# GPU; the model is moved to the target device afterwards.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load(
    "finetuned_models/clf_cifar10h_dbg.pth",
    map_location=torch.device("cpu"),
)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()


def _stack_samples(samples):
    """Materialise a dataset of (input, label) pairs as two tensors.

    The dataset is walked exactly once, so the preprocessing transform runs
    a single time per sample.

    Args:
        samples: Non-empty indexable dataset yielding ``(tensor, int)`` pairs.

    Returns:
        Tuple ``(inputs, labels)``: the stacked input tensors and a 1-D
        integer tensor of labels, in dataset order.
    """
    inputs, labels = zip(*samples)
    return torch.stack(inputs), torch.tensor(labels)


# NOTE: random_split draws from torch's global RNG, so the calibration/test
# partition differs between runs unless a seed was set beforehand.
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
cal_data, test_data = torch.utils.data.random_split(dataset, [10000, 40000])

# Keep only 100 samples of each split (presumably to keep debug runs fast).
cal_data = Subset(cal_data, range(0, 100))
test_data = Subset(test_data, range(500, 600))

cal_data_loader = DataLoader(cal_data, batch_size=64)
test_data_loader = DataLoader(test_data, batch_size=64)

# Materialise both subsets as tensors.
# NOTE: despite the "_logits" names, these hold the preprocessed input images
# (no model is run here); the names are kept for any code that refers to them.
cal_logits, cal_labels = _stack_samples(cal_data)
test_logits, test_labels = _stack_samples(test_data)

#######################################
# A standard process of conformal prediction
#######################################
scoring_methods = [THR(), APS(), RAPS(penalty=0)]
alphas = [0.05, 0.1, 0.2]

# The metric registry does not depend on the configuration, so build it once.
metric = Metrics()

# One row of metrics per (score function, alpha) configuration. Without this
# the per-configuration `metrics` dict is overwritten on every iteration and
# only the last one survives.
results = []

for score_function, alpha in product(scoring_methods, alphas):
    # Calibrate a split conformal predictor for this score function and
    # miscoverage level on the calibration loader.
    predictor = SplitPredictor(score_function, model)
    predictor.calibrate(cal_data_loader, alpha)

    predictions_sets_list = []
    predictions_list = []
    labels_list = []
    logits_list = []
    feature_list = []

    # Evaluate in inference mode
    model.eval()
    with torch.no_grad():
        for batch in test_data_loader:
            inputs = batch[0]
            labels = batch[1]

            # The model lives on `device`; feeding it CPU tensors raises a
            # device-mismatch error whenever CUDA is available.
            device_inputs = inputs.to(device)

            # Prediction sets for the batch (N x C)
            batch_predictions = predictor.predict(device_inputs)

            # Separate forward pass to obtain the point prediction (argmax).
            logits = model(device_inputs)
            predicted_label = logits.argmax(dim=1)

            # Accumulate everything on the CPU: the labels already live
            # there, and this keeps GPU memory flat across batches.
            predictions_sets_list.append(batch_predictions.cpu())
            predictions_list.append(predicted_label.cpu())
            labels_list.append(labels)
            logits_list.append(logits.cpu())
            feature_list.append(inputs)

    # Concatenate all batches
    val_prediction_sets = torch.cat(predictions_sets_list, dim=0)  # (N_val x C)
    val_predictions = torch.cat(predictions_list, dim=0)
    val_labels = torch.cat(labels_list, dim=0)  # (N_val,)
    val_logits = torch.cat(logits_list, dim=0)
    val_features = torch.cat(feature_list, dim=0)

    y_pred = val_predictions.detach().cpu().numpy()
    y_true = val_labels.detach().cpu().numpy()

    # Compute evaluation metrics.
    # WSC is intentionally left out: it rejects the 4-D image batch held in
    # `val_features` ("features must be 2D tensor"), so it would need a 2-D
    # feature representation of the inputs instead.
    metrics = {
        "coverage_rate": metric("coverage_rate")(
            prediction_sets=val_prediction_sets, labels=val_labels
        ),
        "average_size": metric("average_size")(
            prediction_sets=val_prediction_sets, labels=val_labels
        ),
        "cov_gap": metric("CovGap")(
            prediction_sets=val_prediction_sets,
            labels=val_labels,
            alpha=alpha,
            num_classes=num_classes,
        ),
        "vio_classes": metric("VioClasses")(
            prediction_sets=val_prediction_sets,
            labels=val_labels,
            alpha=alpha,
            num_classes=num_classes,
        ),
        "sscv": metric("SSCV")(
            prediction_sets=val_prediction_sets,
            labels=val_labels,
            alpha=alpha,
        ),
        "acc": accuracy_score(y_true, y_pred),
        "bacc": balanced_accuracy_score(y_true, y_pred),
    }

    results.append(
        {
            "score_function": type(score_function).__name__,
            "alpha": alpha,
            **metrics,
        }
    )

# Summarise every configuration in one table so no result is lost.
results_df = pd.DataFrame(results)
print(results_df.to_string(index=False))




Files already downloaded and verified


ValueError: features must be 2D tensor, got shape torch.Size([100, 3, 224, 224])