TODO:

*   Create the dataloader/Dataset



References:

---

1.   https://medium.com/the-owl/extracting-output-from-intermediate-layer-in-any-pretrained-model-the-pytorch-way-b201926a1eec

2. Sohn, Kihyuk, et al. "Fixmatch: Simplifying semi-supervised learning with consistency and confidence." Advances in neural information processing systems 33 (2020): 596-608.

3. Wallin, Erik, et al. "ProSub: Probabilistic Open-Set Semi-supervised Learning with Subspace-Based Out-of-Distribution Detection." European Conference on Computer Vision. Cham: Springer Nature Switzerland, 2024.

In [1]:
"Required Imports"
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torchvision.models import resnet50, resnet18
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from collections import defaultdict
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import random
import os
from itertools import compress
from torchvision.models._utils import IntermediateLayerGetter
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import wandb

In [2]:
"For code reproducability"
MANUAL_SEED = 42
torch.manual_seed(MANUAL_SEED)
torch.cuda.manual_seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
random.seed(MANUAL_SEED)

Dataset description:
*   Contains a total of 1,00,000 images
*   200 Classes each class contains 500 Images
*   resolution of image is 64x64

In [3]:
"Source: https://github.com/seshuad/IMagenet"
"Extracting the ImageNet folder out of some random repository I found on github"
! git clone https://github.com/seshuad/IMagenet
! mv IMagenet/tiny-imagenet-200 ./
! rm -r IMagenet

Cloning into 'IMagenet'...
remote: Enumerating objects: 120594, done.[K
remote: Total 120594 (delta 0), reused 0 (delta 0), pack-reused 120594 (from 1)[K
Receiving objects: 100% (120594/120594), 212.68 MiB | 17.76 MiB/s, done.
Resolving deltas: 100% (1115/1115), done.
Updating files: 100% (120206/120206), done.


In [4]:
dataset_path = './tiny-imagenet-200'
dataset_classes = sorted(os.listdir(os.path.join(dataset_path, 'train')))

ID_MASK = np.concatenate((np.ones(100, dtype=bool), np.zeros(100, dtype=bool)))
OOD_MASK = ~ID_MASK

OOD_CLASSES = list(compress(dataset_classes, OOD_MASK))
ID_CLASSES = list(compress(dataset_classes, ID_MASK))
ID_classes = ID_CLASSES
OOD_classes = OOD_CLASSES
ID_OOD_ratio = len(ID_classes)/len(OOD_classes)

ltoi = {l: i for i, l in enumerate(ID_CLASSES)}
itol = {i: l for i, l in enumerate(ID_CLASSES)}

In [5]:
from collections import defaultdict

train_path = os.path.join(dataset_path, 'train')

train_Labeled = []
train_Unlabeled = []

all_labeled = []
all_unlabeled = []

for label in os.listdir(train_path):
    image_dir = os.path.join(train_path, label, 'images')
    image_paths = [os.path.join(image_dir, img) for img in os.listdir(image_dir)]

    if label in OOD_classes:
        all_unlabeled.extend(image_paths)
    else:
        all_labeled.extend([(img_path, ltoi[label]) for img_path in image_paths])

np.random.shuffle(all_labeled)
np.random.shuffle(all_unlabeled)

train_Labeled = all_labeled[:10_000]
remaining_labeled = all_labeled[10_000:]
train_Unlabeled = all_unlabeled + [x[0] for x in remaining_labeled]  # Only image paths for unlabeled

print(f"# Labeled = {len(train_Labeled)}, # Unlabeled = {len(train_Unlabeled)}")


# Labeled = 10000, # Unlabeled = 90000


In [6]:
print(f"Statistic of the Train Data")
print(f"Number of ID classes are: {len(ID_classes)}, Number of OOD classes are: {len(OOD_classes)}")
print(f"Number of Labeled Samples: {len(train_Labeled)}, Number of UnLabeled Samples: {len(train_Unlabeled)}")

Statistic of the Train Data
Number of ID classes are: 100, Number of OOD classes are: 100
Number of Labeled Samples: 10000, Number of UnLabeled Samples: 90000


In [7]:
val_path = os.path.join(dataset_path, 'val', 'images')
annotation_path = './tiny-imagenet-200/val/val_annotations.txt'

val_data = []
with open(annotation_path, 'r') as f:
    val_annotations = [(line.split('\t')[0], line.split('\t')[1]) for line in f.readlines()]

for img_name, label in val_annotations:
    img_path = os.path.join(val_path, img_name)
    val_data.append((img_path, ltoi.get(label, -1))) # Returning -1 for labels not ID classes


In [8]:
class TinyImageNet(Dataset):
    def __init__(self, data:list, labeled=True, is_val=False):
        super().__init__()
        self.data = data
        self.is_labeled = labeled
        self.is_val = is_val

    def __len__(self):
        return len(self.data)

    "For some reasons it takes all ram while running it transform"
    def strong_augment(self, image):
        strong_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.8, 1.2)),
            transforms.ToTensor(),
        ])
        return strong_transform(image)

    def weak_augment(self, image):
        weak_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ToTensor()
                            ])
        return weak_transform(image)

    def __getitem__(self, idx):
        if(self.is_val):
            X, y = self.data[idx]
            img = Image.open(X).convert('RGB')
            img = self.weak_augment(img)
            return img, torch.tensor(y)
        if(self.is_labeled):
            X, y = self.data[idx]
            img = Image.open(X).convert('RGB')
            alpha_img = self.weak_augment(img)
            return alpha_img, torch.tensor(y)
        else:
            X = self.data[idx]
            img = Image.open(X).convert('RGB')
            alpha_img, beta_img = self.weak_augment(img), self.strong_augment(img)
            return alpha_img, beta_img

In [9]:
unlabeled_batch_size = 9*128
batch_size = 128
test_size = 1024

In [10]:
train_Labeled_dataset = TinyImageNet(train_Labeled, labeled=True)
train_Unlabeled_dataset = TinyImageNet(train_Unlabeled, labeled=False)
val_dataset = TinyImageNet(val_data, labeled=True, is_val=True)
train_labeled_dataloader = DataLoader(train_Labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
train_unlabeled_dataloader = DataLoader(train_Unlabeled_dataset, batch_size=unlabeled_batch_size, shuffle=True, drop_last=True)
test_labeled_dataloader = DataLoader(val_dataset, batch_size=test_size, shuffle=False, drop_last=True)

In [11]:
import timm
class Classifier(nn.Module):
    def __init__(self, in_f, out_f):
        super().__init__()
        self.classifier = nn.Linear(in_f, out_f)

    def forward(self, x):
        return self.classifier(x)

feature_dim = 200
backbone = timm.create_model('resnet34', pretrained=True)
backbone.fc = nn.Linear(backbone.fc.in_features, feature_dim)
classifier = Classifier(feature_dim, 100)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]

In [12]:
def get_batch_mean(Z, y, num_classes, feature_dim, batch_size, device):
    sums = torch.zeros(num_classes, feature_dim, dtype=Z_l.dtype, device=device)
    counts = torch.zeros(num_classes, dtype=torch.float, device=device)
    sums = sums.index_add(0, y_l, Z_l, alpha=1)
    ones = torch.ones(batch_size, dtype=torch.float, device=device)
    counts = counts.index_add(0, y_l, ones)
    counts[counts == 0] = 1.0
    avg_values = sums / counts.unsqueeze(1)
    return avg_values


def subspace_score(Z, batch_means):
    """
    Z: (batch_size, feature_dim)
    batch_means: (num_classes, feature_dim)
    Returns:
        cosine similarity between Z and its projection on the subspace
        spanned by batch_means — shape (batch_size, 1)
    """
    Q, _ = torch.linalg.qr(batch_means.T)
    proj_Z = Z @ Q @ Q.T
    cos_sim = F.cosine_similarity(Z, proj_Z, dim=1).clamp(0, 1)
    return cos_sim.unsqueeze(1)


In [13]:
"Taken from the code given by the authors... "
def beta_pdf(x, alpha, beta, loc=0.0, scale=1.0):
    x = (x - loc) / scale
    alpha = torch.as_tensor(alpha, dtype=torch.float32, device=x.device)
    beta = torch.as_tensor(beta, dtype=torch.float32, device=x.device)
    scale = torch.as_tensor(scale, dtype=torch.float32, device=x.device)
    def xlogy(a, b):
        return torch.where(a == 0, torch.zeros_like(b), a * torch.log(b + 1e-10))
    def xlog1py(a, y):
        return torch.where(a == 0, torch.zeros_like(y), a * torch.log1p(y + 1e-10))

    log_unnormalized = xlogy(alpha - 1.0, x) + xlog1py(beta - 1.0, -x)
    log_normalization = torch.lgamma(alpha) + torch.lgamma(beta) - torch.lgamma(alpha + beta)
    log_prob = log_unnormalized - log_normalization
    log_prob = torch.where((x >= 0) & (x <= 1), log_prob, torch.tensor(float('-inf'), device=x.device))

    return torch.exp(log_prob) / scale

def update_beta_parameters(w_id, w_ood, s_l, s_u ,alpha_id, beta_id, alpha_ood, beta_ood, l):
    # Detaching from computational graph so that it won't cause some bullshit error
    w_id = w_id.squeeze(1).detach(); w_ood = w_ood.squeeze(1).detach()
    s_l = s_l.squeeze(1).detach(); s_u = s_u.squeeze(1).detach()
    nu_id = torch.sum(s_l) + torch.dot(w_id, s_u)/(s_l.size(0) + torch.sum(w_id))
    sigma2_id = (torch.sum(s_l - nu_id)**2 + torch.dot(w_id, (s_u - nu_id)**2))/(s_l.size(0) + torch.sum(w_id))

    nu_ood = torch.dot(w_ood, s_u)/(torch.sum(w_ood))
    sigma2_ood = torch.dot(w_ood, (s_u - nu_id)**2)/((torch.sum(w_ood)))

    _alpha_id = nu_id*((nu_id*(1-nu_id))/sigma2_id - 1)
    _alpha_ood = nu_ood*((nu_ood*(1-nu_ood))/sigma2_ood - 1)
    _beta_id = (1 - nu_id)*((nu_id*(1-nu_id))/sigma2_id - 1)
    _beta_ood = (1 - nu_ood)*((nu_ood*(1-nu_ood))/sigma2_ood - 1)

    return alpha_id*l + (1 - l)*_alpha_id, beta_id*l + (1 - l)*_beta_id, alpha_ood*l + (1 - l)*_alpha_ood, beta_ood*l + (1 - l)*_beta_ood

def get_p_id(s, alpha_id, beta_id, alpha_ood, beta_ood, pi=1.0):
    """
    Compute the ID probability using Beta PDFs instead of torch.distributions.Beta,
    following the formulation in the paper (with loc/scale support implicitly assumed as 0/1).
    """
    beta_pdf_id = beta_pdf(s, alpha_id, beta_id)
    beta_pdf_ood = beta_pdf(s, alpha_ood, beta_ood)
    numerator = beta_pdf_id * pi
    denominator = numerator + beta_pdf_ood * (1 - pi)
    p_id = numerator / (denominator + 1e-8)
    return p_id


In [14]:
"this function has been tested"
class SelfSupervisionLoss(nn.Module):
    def __init__(self, in_f, out_f):
        super().__init__()
        self.proj = nn.Linear(in_f, out_f, bias=False)

    def forward(self, x_w, x_s):
        x_s_proj = self.proj(x_s)
        x_w = x_w.detach()
        dot_product = torch.sum(x_s_proj * x_w, dim=1)
        norm_product = torch.norm(x_s_proj, dim=1) * torch.norm(x_w, dim=1)
        cosine_sim = dot_product / (norm_product + 1e-8)
        return -cosine_sim.mean()

"Can't this be negative??"
"Though i did test this.."
class SubspaceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, m_id, m_ood, s_z):
        m_id, m_ood = m_id.to(dtype=float), m_ood.to(dtype=float)
        return torch.mean((m_ood - m_id) * s_z)

"This function has been tested"

class SemiSupervisedLoss(nn.Module):
    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def forward(self,y_pred, y_true, mask_id):
        """
        logits of model on weakly augmented data: y_true
        logits of model on strongly augmented data: y_pred
        boolean tensors for ID/OOD classification: mask_id
        """
        y_true = y_true.detach()
        batch_size = y_pred.size(0)
        entropy = F.cross_entropy(y_pred, torch.argmax(y_true, dim=1), reduction="none") # input logits, and true labels
        y_prob_true = F.softmax(y_true, dim=1)
        sample_id = torch.logical_and(torch.max(y_prob_true, dim=1)[0] >= self.threshold, (mask_id == 1))
        return (torch.sum(entropy[sample_id]))/batch_size

In [15]:
device = "cuda" if torch.cuda.is_available() else "cpu"
num_epoch = 10
alpha_id, beta_id = torch.tensor(10, device=device),torch.tensor(2, device=device)
alpha_ood, beta_ood = torch.tensor(2, device=device), torch.tensor(10, device=device)
warm_up = 5
lr = 3e-3
backbone.to(device)
classifier.to(device)
supervised_loss_fn = nn.CrossEntropyLoss()
selfsupervised_loss_fn = SelfSupervisionLoss(feature_dim, feature_dim).to(device)
subspace_loss_fn = SubspaceLoss().to(device)
semisupervised_loss_fn = SemiSupervisedLoss(0.9).to(device)
all_params = list(backbone.parameters()) + list(classifier.parameters()) + list(selfsupervised_loss_fn.parameters())
optimizer = torch.optim.SGD(all_params, lr=0.01, momentum=0.9, nesterov=True)
class_means = torch.zeros((len(ID_CLASSES), feature_dim), requires_grad=False, device=device)
class_mean_EMA = 0.2
test_accuracy, test_classification_report, test_confusion_matrix, test_auc_roc_score = [], [], [], []
num_batches = len(train_labeled_dataloader)
batch_means = torch.zeros(len(ID_CLASSES), feature_dim, device=device)
scheduler = CosineAnnealingLR(optimizer, T_max=100)

In [16]:
wandb.init(
    project="Tiny-Imagenet-Prosub",
    name=f"run_{wandb.util.generate_id()}",
    config={
        "epochs": num_epoch,
        "learning_rate": lr,
        "warmup_epochs": warm_up,
        "scheduler": "CosineAnnealingLR",
        "optimizer": "SGD w/ Nesterov",
        "alpha_id": alpha_id.item(),
        "beta_id": beta_id.item(),
        "alpha_ood": alpha_ood.item(),
        "beta_ood": beta_ood.item()
    }
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msujal22514[0m ([33msujal22514-indraprastha-institute-of-information-technol[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
for epoch in range(num_epoch):
    backbone.train();classifier.train();selfsupervised_loss_fn.train()
    for (X_l, y_l), (X_w, X_s) in tqdm(zip(train_labeled_dataloader, train_unlabeled_dataloader), total=num_batches, desc="Training:"):
        "========================Loading inputs to device========================"
        X_l = X_l.to(device)
        y_l = y_l.to(device)
        X_w = X_w.to(device)
        X_s = X_s.to(device)
        "========================Inputs loaded========================"

        "======================= Processing the labeled dataset ======================="
        Z_l = backbone(X_l)


        "=====================Calculating labeled subspace score============================"
        y_pred_l = classifier(Z_l)
        s_l = subspace_score(Z_l, batch_means.detach())
        "=====================Label subspace score calculated============================"
        "============================Label data processed ============================ "

        "======================= Processing the labeled dataset ======================="
        Z_s_u, Z_w_u = backbone(X_w), backbone(X_s)

        " ======================= Calculating subspace score and mask ======================="
        s_u = subspace_score(Z_w_u.detach(), batch_means.detach())
        Z_s_u_for_cls = Z_s_u.detach()
        Z_w_u_for_cls = Z_w_u.detach()

        y_pred_u = classifier(Z_s_u_for_cls)
        y_gold_u = classifier(Z_w_u_for_cls)
        prob_id_u = get_p_id(s_u, alpha_id, beta_id, alpha_ood, beta_ood) # size should be (Un_labeled_batch_size)
        prob_ood_u = 1 - prob_id_u
        uniform = torch.rand(unlabeled_batch_size, device=device)
        mask_id = (prob_id_u.squeeze(1) >= uniform)
        mask_ood = ~mask_id
        " ======================= Calculated subspace score and mask ======================="

        " ============================Calculating losses ============================"
        loss_selfsupervised = selfsupervised_loss_fn(Z_w_u, Z_s_u)
        loss_semi_supervised = semisupervised_loss_fn(y_pred_u, y_gold_u, mask_id)
        loss_subspace = subspace_loss_fn(mask_id, mask_ood, s_u)
        loss_supervised = supervised_loss_fn(y_pred_l, y_l)
        " ============================Losses Calculated ============================"

        " ============================ Optimizing the model ============================"
        if epoch < warm_up:
            loss = 40*loss_selfsupervised + loss_supervised
        else:
            loss = 40*loss_selfsupervised + loss_semi_supervised + loss_subspace + loss_supervised
        loss.backward()
        optimizer.zero_grad()
        optimizer.step()
        " ============================ Model optimized ============================"

        " ======================== Performing IIM Step ======================== "
        alpha_id, alpha_ood, alpha_ood, beta_ood = update_beta_parameters(prob_id_u,prob_ood_u, s_l, s_u ,alpha_id, beta_id, alpha_ood, beta_ood, 0.2)
        " ======================== IIM step done ======================== "
        "======================= Updating class means ======================="
        batch_means = get_batch_mean(Z_l, y_l, len(ID_CLASSES), feature_dim, batch_size, device)
        class_means = class_means*class_mean_EMA + (1 - class_mean_EMA)*batch_means
        "======================= Class means updataed ======================="
        wandb.log({
            "epoch": epoch,
            "loss_supervised": loss_supervised.item(),
            "loss_selfsupervised": loss_selfsupervised.item(),
            "loss_semi_supervised": loss_semi_supervised.item() if epoch >= warm_up else 0.0,
            "loss_subspace": loss_subspace.item() if epoch >= warm_up else 0.0,
        })

    scheduler.step()
    backbone.eval();classifier.eval();selfsupervised_loss_fn.eval()
    test_predictions,test_targets = [], []
    with torch.inference_mode():
        for X, y_true in test_labeled_dataloader:
            y_true = y_true.to(device)
            X = X.to(device)
            Z = backbone(X)
            y_logits = classifier(Z)
            prob_id = get_p_id(Z, alpha_id, beta_id, alpha_ood, beta_ood)
            y_softmax = F.softmax(y_logits, dim=1)
            y_pred = torch.argmax(y_softmax, dim=1, keepdim=False)
            uniform = torch.rand(size=(test_size, ), device=device)
            y_pred[(torch.max(y_softmax, dim=1)[0] < uniform)] = -1
            test_predictions.extend(y_true.cpu().detach().numpy())
            test_targets.extend(y_pred.cpu().detach().numpy())

    acc = accuracy_score(test_predictions, test_targets)
    print(f"epoch = {epoch}, acc = {acc:.4f}")
    test_accuracy.append(acc)
    test_classification_report.append(classification_report(test_predictions, test_targets, zero_division=0))
    test_confusion_matrix.append(confusion_matrix(test_predictions, test_targets))
    torch.save(backbone.state_dict(), f"backbone_epoch_{epoch}.pth")
    torch.save(classifier.state_dict(), f"classifier_epoch_{epoch}.pth")
    torch.save(selfsupervised_loss_fn.state_dict(), f"selfsupervised_loss_fn_epoch_{epoch}.pth")

wandb.finish()

Training:: 100%|██████████| 78/78 [05:33<00:00,  4.27s/it]


epoch = 0, acc = 0.4929


Training:: 100%|██████████| 78/78 [05:46<00:00,  4.44s/it]


epoch = 1, acc = 0.4922


Training:: 100%|██████████| 78/78 [05:58<00:00,  4.59s/it]


epoch = 2, acc = 0.4936


Training:: 100%|██████████| 78/78 [05:34<00:00,  4.29s/it]


epoch = 3, acc = 0.4933


Training:: 100%|██████████| 78/78 [05:30<00:00,  4.24s/it]


epoch = 4, acc = 0.4932


Training:: 100%|██████████| 78/78 [05:32<00:00,  4.26s/it]


epoch = 5, acc = 0.4927


Training:: 100%|██████████| 78/78 [05:37<00:00,  4.32s/it]


epoch = 6, acc = 0.4933


Training:: 100%|██████████| 78/78 [05:31<00:00,  4.25s/it]


epoch = 7, acc = 0.4915


Training:: 100%|██████████| 78/78 [05:32<00:00,  4.26s/it]


epoch = 8, acc = 0.4933


Training::  58%|█████▊    | 45/78 [03:12<02:17,  4.18s/it]

In [None]:
beta_parameters = torch.stack([alpha_id, beta_id, alpha_ood, beta_ood])
torch.save(beta_parameters, "beta_parameters.pth")
torch.save(batch_means, "batch_means.pth")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import timm


def get_batch_mean(Z, y, num_classes, feature_dim, batch_size, device):
    sums = torch.zeros(num_classes, feature_dim, dtype=Z_l.dtype, device=device)
    counts = torch.zeros(num_classes, dtype=torch.float, device=device)
    sums = sums.index_add(0, y_l, Z_l, alpha=1)
    ones = torch.ones(batch_size, dtype=torch.float, device=device)
    counts = counts.index_add(0, y_l, ones)
    counts[counts == 0] = 1.0
    avg_values = sums / counts.unsqueeze(1)
    return avg_values


def subspace_score(Z, batch_means):
    """
    Z: (batch_size, feature_dim)
    batch_means: (num_classes, feature_dim)
    Returns:
        cosine similarity between Z and its projection on the subspace
        spanned by batch_means — shape (batch_size, 1)
    """
    Q, _ = torch.linalg.qr(batch_means.T)
    proj_Z = Z @ Q @ Q.T
    cos_sim = F.cosine_similarity(Z, proj_Z, dim=1).clamp(0, 1)
    return cos_sim.unsqueeze(1)

"Taken from the code given by the authors... "
def beta_pdf(x, alpha, beta, loc=0.0, scale=1.0):
    x = (x - loc) / scale
    alpha = torch.as_tensor(alpha, dtype=torch.float32, device=x.device)
    beta = torch.as_tensor(beta, dtype=torch.float32, device=x.device)
    scale = torch.as_tensor(scale, dtype=torch.float32, device=x.device)
    def xlogy(a, b):
        return torch.where(a == 0, torch.zeros_like(b), a * torch.log(b + 1e-10))
    def xlog1py(a, y):
        return torch.where(a == 0, torch.zeros_like(y), a * torch.log1p(y + 1e-10))

    log_unnormalized = xlogy(alpha - 1.0, x) + xlog1py(beta - 1.0, -x)
    log_normalization = torch.lgamma(alpha) + torch.lgamma(beta) - torch.lgamma(alpha + beta)
    log_prob = log_unnormalized - log_normalization
    log_prob = torch.where((x >= 0) & (x <= 1), log_prob, torch.tensor(float('-inf'), device=x.device))

    return torch.exp(log_prob) / scale

def update_beta_parameters(w_id, w_ood, s_l, s_u ,alpha_id, beta_id, alpha_ood, beta_ood, l):
    # Detaching from computational graph so that it won't cause some bullshit error
    w_id = w_id.squeeze(1).detach(); w_ood = w_ood.squeeze(1).detach()
    s_l = s_l.squeeze(1).detach(); s_u = s_u.squeeze(1).detach()
    nu_id = torch.sum(s_l) + torch.dot(w_id, s_u)/(s_l.size(0) + torch.sum(w_id))
    sigma2_id = (torch.sum(s_l - nu_id)**2 + torch.dot(w_id, (s_u - nu_id)**2))/(s_l.size(0) + torch.sum(w_id))

    nu_ood = torch.dot(w_ood, s_u)/(torch.sum(w_ood))
    sigma2_ood = torch.dot(w_ood, (s_u - nu_id)**2)/((torch.sum(w_ood)))

    _alpha_id = nu_id*((nu_id*(1-nu_id))/sigma2_id - 1)
    _alpha_ood = nu_ood*((nu_ood*(1-nu_ood))/sigma2_ood - 1)
    _beta_id = (1 - nu_id)*((nu_id*(1-nu_id))/sigma2_id - 1)
    _beta_ood = (1 - nu_ood)*((nu_ood*(1-nu_ood))/sigma2_ood - 1)

    return alpha_id*l + (1 - l)*_alpha_id, beta_id*l + (1 - l)*_beta_id, alpha_ood*l + (1 - l)*_alpha_ood, beta_ood*l + (1 - l)*_beta_ood

def get_p_id(s, alpha_id, beta_id, alpha_ood, beta_ood, pi=1.0):
    """
    Compute the ID probability using Beta PDFs instead of torch.distributions.Beta,
    following the formulation in the paper (with loc/scale support implicitly assumed as 0/1).
    """
    beta_pdf_id = beta_pdf(s, alpha_id, beta_id)
    beta_pdf_ood = beta_pdf(s, alpha_ood, beta_ood)
    numerator = beta_pdf_id * pi
    denominator = numerator + beta_pdf_ood * (1 - pi)
    p_id = numerator / (denominator + 1e-8)
    return p_id


class Classifier(nn.Module):
    def __init__(self, in_f, out_f):
        super().__init__()
        self.classifier = nn.Linear(in_f, out_f)

    def forward(self, x):
        return self.classifier(x)

class InferenceModelImage:
    def __init__(self, backbone_pth, classifier_pth, batch_mean_pth, beta_parameters_pth):
        self.feature_dim = 200
        self.num_classes = 100
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.backbone = timm.create_model('resnet34', pretrained=True)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, self.feature_dim)
        self.backbone.load_state_dict(torch.load(backbone_pth, map_location=self.device))
        self.backbone = self.backbone.to(self.device)
        self.backbone.eval()

        self.classifier = Classifier(self.feature_dim, self.num_classes)
        self.classifier.load_state_dict(torch.load(classifier_pth, map_location=self.device))
        self.classifier = self.classifier.to(self.device)
        self.classifier.eval()

        self.batch_means = torch.load(batch_mean_pth, map_location=self.device).to(self.device)
        beta_parameters = torch.load(beta_parameters_pth, map_location=self.device).to(self.device)
        self.alpha_id, self.beta_id, self.alpha_ood, self.beta_ood = beta_parameters.chunk(4)

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def _preprocess_image(self, img_path):
        img = Image.open(img_path).convert('RGB')
        tensor = self.transform(img).unsqueeze(0).to(self.device)
        return tensor

    def __call__(self, image_path):
        x = self._preprocess_image(image_path)
        Z = self.backbone(x)
        s_score = subspace_score(Z, self.batch_means)
        p_id = get_p_id(
            s_score,
            self.alpha_id.to(self.device),
            self.beta_id.to(self.device),
            self.alpha_ood.to(self.device),
            self.beta_ood.to(self.device)
        )
        if p_id[0] < torch.rand(1, device=self.device):
            return -1
        y = self.classifier(Z)
        return torch.argmax(y, dim=1)


In [None]:
backbone_pth = "backbone_fsd_kaggle.pth"
classifier_pth = "classifier_epoch_fsd_kaggle.pth"
beta_parameters_pth = "beta_parameters_fsd_kaggle.pth"
batch_mean_pth = "batch_means_fsd_kaggle.pth"

test_inf = InferenceModelImage(backbone_pth, classifier_pth, batch_mean_pth, beta_parameters_pth)
print(test_inf("path"))