In [1]:

import torch
import torch.nn as nn
import numpy as np
from torchvision.models import resnet18
from dataset import get_CIFAR10
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset
from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import DataLoader
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ResNet_DUQ(nn.Module):
    def __init__(
        self,
        feature_extractor,
        num_classes,
        centroid_size,
        model_output_size,
        length_scale,
        gamma,
    ):
        super().__init__()

        self.gamma = gamma

        self.W = nn.Parameter(
            torch.zeros(centroid_size, num_classes, model_output_size)
        )
        nn.init.kaiming_normal_(self.W, nonlinearity="relu")

        self.feature_extractor = feature_extractor

        self.register_buffer("N", torch.zeros(num_classes) + 13)
        self.register_buffer(
            "m", torch.normal(torch.zeros(centroid_size, num_classes), 0.05)
        )
        self.m = self.m * self.N

        self.sigma = length_scale
    def rbf(self, z):
        z = torch.einsum("ij,mnj->imn", z, self.W)

        embeddings = self.m / self.N.unsqueeze(0)
        print(embeddings.shape)

        diff = z - embeddings.unsqueeze(0)
        diff = (diff ** 2).mean(1).div(2 * self.sigma ** 2).mul(-1).exp()

        return diff



    def update_embeddings(self, x, y):
        self.N = self.gamma * self.N + (1 - self.gamma) * y.sum(0)

        z = self.feature_extractor(x)

        z = torch.einsum("ij,mnj->imn", z, self.W)
        embedding_sum = torch.einsum("ijk,ik->jk", z, y)

        self.m = self.gamma * self.m + (1 - self.gamma) * embedding_sum

    def forward(self, x):
        z = self.feature_extractor(x)
        y_pred = self.rbf(z)

        return y_pred

In [3]:
    
model_output_size = 512
epochs = 5
milestones = [25, 50, 75]
feature_extractor = resnet18()

# Adapted resnet from:
# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
feature_extractor.conv1 = torch.nn.Conv2d(
    3, 64, kernel_size=3, stride=1, padding=1, bias=False
)
feature_extractor.maxpool = torch.nn.Identity()
feature_extractor.fc = torch.nn.Identity()

centroid_size = model_output_size
num_classes = 10
length_scale = 0.1
gamma =0.999


model = ResNet_DUQ(
feature_extractor,
num_classes,
centroid_size,
model_output_size,
length_scale,
gamma)
model = model.cuda()

In [4]:

def predict_uncertainty( inputs, resnet_duq_model):
    with torch.no_grad():
        resnet_duq_model.eval()
        inputs = inputs/255
        uncertainty_scores = resnet_duq_model(inputs)
    return uncertainty_scores

def query( n_pool, idxs_lb, X, model ):
    idxs_unlabeled = np.arange(n_pool)[~idxs_lb]
    idxs_unlabeled = torch.tensor(idxs_unlabeled)

    # Assuming your data needs some transformation before passing it to ResNet_DUQ
    # transformed_inputs = torch.tensor([transform(Image.fromarray(img)) for img in self.X[idxs_unlabeled]])

    uncertainty_scores = predict_uncertainty(X[idxs_unlabeled], model)
    # U = uncertainty_scores.sum(dim=0).mean(dim=1)  # Modify this based on your ResNet_DUQ model
    # U = uncertainty_scores.sum(dim=1).mean(dim=0)
    # _, uncertain_indices = U.sort(descending=True)
    # result = idxs_unlabeled[uncertain_indices[:n]]

    return uncertainty_scores

In [5]:
X_tr, Y_tr, X_te, Y_te = get_CIFAR10('data')
X_tr = X_tr.cuda()
n_pool = len(Y_tr)
idxs_lb = np.zeros(n_pool, dtype=bool)
idxs_tmp = np.arange(n_pool)
np.random.shuffle(idxs_tmp)
idxs_lb[idxs_tmp[:48000]] = True

Files already downloaded and verified
Files already downloaded and verified


In [6]:
idxs_unlabeled = np.arange(n_pool)[~idxs_lb]

In [7]:
uner_scores = query(n_pool=n_pool, idxs_lb= idxs_lb, X = X_tr, model=model)

torch.Size([512, 10])


In [8]:
class DataHandler3(Dataset):
    def __init__(self, X, Y, transform=None):
        self.X = X
        self.Y = Y
        self.transform = transform

    def __getitem__(self, index):
        x,  y = self.X[index], self.Y[index]
        if self.transform is not None:
            # x = Image.fromarray(x)
            x = Image.fromarray(x.astype(np.uint8)) 
            x = self.transform(x)
        return x,  y, index

    def __len__(self):
        return len(self.X)
    

def get_handler(name):
    if name == 'CIFAR10':
        return DataHandler3
    
handler = get_handler('CIFAR10')

In [9]:
def predict_prob( X, Y, model, exp=True):
	transform =  transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
	# model = nn.DataParallel(model)
	loader_te = DataLoader(handler(X, Y, transform=transform), shuffle=False, batch_size =  100, num_workers=1)
	# model = model.eval().cuda()
	model = model.eval()
	probs = torch.zeros([len(Y), len(np.unique(Y))])
	with torch.no_grad():
		for x, y, idxs in loader_te:
			print(idxs)
			# x, y = x.cuda(), y.cuda()
			out, e1 = model(x)
			if exp: out = F.softmax(out, dim=1)
			probs[idxs] = out.cpu().data
	
	return probs

def query_entropy( n_pool, idxs_lb, X, Y, model ):
		idxs_unlabeled = np.arange(n_pool)[~idxs_lb]
		probs = predict_prob(X[idxs_unlabeled], Y.numpy()[idxs_unlabeled],model )
		log_probs = torch.log(probs)
		# U = (probs*log_probs).sum(1)
		# result = idxs_unlabeled[U.sort()[1][:n]]
		# save_queried_idx(result, self.savefile, self.alg)
		return log_probs

In [10]:
net = models.resnet18(weights=ResNet18_Weights.DEFAULT)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 10)

In [11]:
log_probs = query_entropy(n_pool=n_pool, idxs_lb=idxs_lb, X=X_tr, Y=Y_tr, model=net)

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/home/sgchr/.local/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/home/sgchr/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/home/sgchr/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_168021/3651546271.py", line 8, in __getitem__
    x,  y = self.X[index], self.Y[index]
RuntimeError: CUDA error: initialization error
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.


In [None]:
model.W.shape

torch.Size([512, 10, 512])

In [None]:
ten = torch.randn(100,32,32,3)

In [None]:
uncer_scores = model_DUQ.rbf(ten)

AssertionError: Expected 3 dimensions for z, got 4

In [None]:
from torchvision import datasets, transforms
def get_CIFAR10():
    input_size = 32
    num_classes = 10
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    train_dataset = datasets.CIFAR10(
        "data/CIFAR10", train=True, transform=train_transform, download=True
    )

    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    test_dataset = datasets.CIFAR10(
         "data/CIFAR10", train=False, transform=test_transform, download=True
    )

    return input_size, num_classes, train_dataset, test_dataset

input_size, num_classes, train_dataset, test_dataset = get_CIFAR10()

Files already downloaded and verified
Files already downloaded and verified


In [None]:
kwargs = {"num_workers": 4, "pin_memory": True}
train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=64, shuffle=True, drop_last=True, **kwargs
    )

test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=64, shuffle=False,  **kwargs
    )

In [None]:
from ignite.engine import Events, Engine
from ignite.metrics import Accuracy, Average, Loss
import torch.nn.functional as F

In [None]:
optimizer = torch.optim.SGD(
    model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4
)

scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=milestones, gamma=0.2)

In [None]:
def calc_gradients_input(x, y_pred):
    gradients = torch.autograd.grad(
        outputs=y_pred,
        inputs=x,
        grad_outputs=torch.ones_like(y_pred),
        create_graph=True,
    )[0]

    gradients = gradients.flatten(start_dim=1)

    return gradients

def calc_gradient_penalty(x, y_pred):
    gradients = calc_gradients_input(x, y_pred)

    # L2 norm
    grad_norm = gradients.norm(2, dim=1)

    # Two sided penalty
    gradient_penalty = ((grad_norm - 1) ** 2).mean()

    return gradient_penalty

def step(engine, batch):
    model.train()

    optimizer.zero_grad()

    x, y = batch
    x, y = x.cuda(), y.cuda()
    print('step function','x:',x.shape, 'y:', y.shape)

    x.requires_grad_(True)

    y_pred = model(x)

    y = F.one_hot(y, num_classes).float()
    l_gradient_penalty = 0.75
    loss = F.binary_cross_entropy(y_pred, y, reduction="mean")
    if l_gradient_penalty > 0:
        gp = calc_gradient_penalty(x, y_pred)
        loss += l_gradient_penalty * gp

    loss.backward()
    optimizer.step()

    x.requires_grad_(False)

    with torch.no_grad():
        model.eval()
        model.update_embeddings(x, y)

    return loss.item()

def eval_step(engine, batch):
    model.eval()

    x, y = batch
    x, y = x.cuda(), y.cuda()

    print('eval_step function','x:',x.shape, 'y:', y.shape)

    x.requires_grad_(True)

    y_pred = model(x)

    return {"x": x, "y": y, "y_pred": y_pred}

trainer = Engine(step)
evaluator = Engine(eval_step)

In [None]:
trainer.run(train_loader, max_epochs=epochs)
evaluator.run(test_loader)
acc = evaluator.state.metrics["accuracy"]

print(f"Test - Accuracy {acc:.4f}")

step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])
step function x: torch.Si

Engine run is terminating due to exception: 


step function x: torch.Size([64, 3, 32, 32]) y: torch.Size([64])


KeyboardInterrupt: 