In [61]:
import os
import glob
import random

from PIL import Image
from tqdm import tqdm

In [62]:
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torch.utils import data
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, models, transforms
from torch.optim import lr_scheduler
from torchsummary import summary

In [63]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [64]:
# Preprocessing pipelines keyed by split name.
#   'train': random resized crop + random horizontal flip (augmentation).
#   'val'  : deterministic resize followed by a centre crop.
# Both finish by converting to a tensor and normalising each channel with the
# same mean / std values.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]),
    val=transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ]),
)

In [65]:
# Root folder of the evaluation images. The cell that builds `class_list` treats
# each entry directly under this folder as a class name and reads the open-set
# samples from its "unknown_class" sub-folder.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
src_path = "/home/rubel/projects_works/projects/algorithms/open_set_recognition/dataset/val"

In [66]:
# Checkpoint file handed to torch.load; the loaded object is indexed with the
# "state_dict" key to obtain the trained model weights.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
weight_path = "/home/rubel/projects_works/projects/algorithms/open_set_recognition/weights/model_best.pth"

In [67]:
# Known classes = every sub-directory of src_path except the open-set
# "unknown_class" folder. Plain files in the dataset root (e.g. .DS_Store) are
# skipped: counting them as classes would change len(class_list), which is used
# to size the classifier head when the model is built.
# A missing "unknown_class" folder no longer raises; it just yields no unknown
# samples.
class_list = sorted(
    os.path.basename(entry)
    for entry in glob.glob(os.path.join(src_path, "*"))
    if os.path.isdir(entry) and os.path.basename(entry) != "unknown_class"
)
print(len(class_list))

# Collect [image_path, label] pairs: all known classes first, then the unknown
# samples, and finally shuffle so the evaluation order is mixed.
img_list = []
for label in class_list + ["unknown_class"]:
    for img_path in glob.glob(os.path.join(src_path, label, "*")):
        img_list.append([img_path, label])
random.shuffle(img_list)
print(len(img_list))

9
6358


In [68]:
# Display the sorted known-class names ("unknown_class" has been removed from this list).
class_list

['dettol_asl_120ml',
 'dettol_asl_235ml',
 'dettol_asl_725ml',
 'dettol_mac_aqua_curved_650ml',
 'dettol_mac_jasmine_curved_650ml',
 'dettol_mac_lavander_curved_650ml',
 'dettol_mac_lemon_curved_650ml',
 'dettol_spray_crisp_breeze_450ml',
 'dettol_spray_morning_dew_450ml']

In [69]:
def load_model(weight_path):
    """Build the ResNet-18 embedding network and load trained weights into it.

    The stock ``fc`` layer is replaced by a two-layer head
    (``in_features -> 256 -> 128``), so the model outputs a 128-dimensional
    embedding. Backbone parameters are frozen before the head is attached,
    which leaves the new head layers with their default ``requires_grad``.

    NOTE(review): the head built here is shorter than the one assembled in the
    evaluation cell of this notebook (which appends a classifier and softmax);
    a checkpoint saved from that longer head would be rejected by the strict
    ``load_state_dict`` below — confirm which architecture ``weight_path``
    was trained with.

    Args:
        weight_path: Path to a checkpoint whose ``'state_dict'`` entry holds
            the trained weights.

    Returns:
        The model on ``DEVICE``, switched to evaluation mode.
    """
    # No pretrained download: load_state_dict (strict by default) replaces
    # every weight with the checkpoint's values anyway.
    base_model = models.resnet18(pretrained=False)
    for param in base_model.parameters():
        param.requires_grad = False
    num_ftrs = base_model.fc.in_features
    base_model.fc = nn.Sequential(nn.Linear(num_ftrs, 256), nn.Linear(256, 128))
    base_model = base_model.to(DEVICE)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    # (DEVICE already falls back to CPU, so the load must as well).
    checkpoint = torch.load(weight_path, map_location=DEVICE)
    base_model.load_state_dict(checkpoint['state_dict'])

    return base_model.eval()

In [70]:
# Build ResNet-18 with a custom head: in_features -> 256 -> 128 -> n_classes -> softmax.
# The convolutional backbone is frozen; only the new head keeps requires_grad.
# pretrained=False: the ImageNet download would be wasted because the strict
# load_state_dict below overwrites every weight with the checkpoint's values.
base_model = models.resnet18(pretrained=False)
for param in base_model.parameters():
    param.requires_grad = False
num_ftrs = base_model.fc.in_features
base_model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 256),
    nn.Linear(256, 128),
    nn.Linear(128, len(class_list)),
    nn.Softmax(dim=1),
)
base_model = base_model.to(DEVICE)


# Forward hook used to capture intermediate activations.
# `feature[name]` always holds the output of the most recent forward pass.
feature = {}
def get_activation(name):
    """Return a forward hook that stores the layer's detached output in ``feature[name]``."""
    def hook(model, inputs, output):
        feature[name] = output.detach()
    return hook

# fc[1] is the 256 -> 128 linear layer, so "embeddings" is the 128-d vector
# that feeds the final classifier.
base_model.fc[1].register_forward_hook(get_activation("embeddings"))

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load(weight_path, map_location=DEVICE)
base_model.load_state_dict(checkpoint["state_dict"])
base_model = base_model.eval()



In [73]:
# Open-set decision rule: an image is predicted "unknown" when the L2 norm of
# its embedding (captured by the forward hook) is below this threshold.
EMBEDDING_NORM_THRESHOLD = 5.5

# Counter naming, as printed by the summary cells:
#   correct_unknown - unknown image, norm below threshold
#   false_unknown   - unknown image, norm at/above threshold
#   correct_known   - known image,   norm at/above threshold
#   false_known     - known image,   norm below threshold
total_unknown = 0
correct_unknown = 0
false_unknown = 0

total_known = 0
correct_known = 0
false_known = 0

# Inference only: no_grad avoids building autograd graphs for the trainable head.
with torch.no_grad():
    # Iterating the list directly lets tqdm show a total and an ETA.
    for img_name, label in tqdm(img_list):
        # Force three channels so grayscale / RGBA / palette files do not break
        # the 3-channel Normalize; the context manager closes the file handle.
        with Image.open(img_name) as img_file:
            query_img = data_transforms["val"](img_file.convert("RGB"))
        query_img = query_img.unsqueeze(0).to(DEVICE)

        # The class probabilities are not needed; the forward pass is run so the
        # hook refreshes feature["embeddings"] for this image.
        base_model(query_img)
        embeddings = feature["embeddings"]
        embeddings_norm = torch.sqrt(torch.sum(torch.square(embeddings), dim=1)).item()

        is_unknown = label == "unknown_class"
        if is_unknown:
            total_unknown += 1
        else:
            total_known += 1

        # A plain if/else covers the boundary: a norm exactly equal to the
        # threshold used to be counted in the totals but in neither outcome.
        if embeddings_norm < EMBEDDING_NORM_THRESHOLD:
            if is_unknown:
                correct_unknown += 1
            else:
                false_known += 1
        else:
            if is_unknown:
                false_unknown += 1
            else:
                correct_known += 1

6358it [01:47, 59.15it/s]


In [74]:
# One row per group (unknown images first, then known images);
# columns are: total, correct, false.
for counts in (
    (total_unknown, correct_unknown, false_unknown),
    (total_known, correct_known, false_known),
):
    print(*counts)

2000 1396 604
4358 4059 299


In [59]:
# NOTE(review): stale duplicate of the summary cell — its execution count (59)
# is lower than the evaluation cell's, so the recorded output comes from an
# earlier run, presumably with a different model or threshold; confirm or delete.
# One row per group (unknown images first, then known images);
# columns are: total, correct, false.
for counts in (
    (total_unknown, correct_unknown, false_unknown),
    (total_known, correct_known, false_known),
):
    print(*counts)

2000 1049 951
4358 3046 1312


In [47]:
# NOTE(review): stale duplicate of the summary cell — its execution count (47)
# is lower than the evaluation cell's, so the recorded output comes from an
# earlier run, presumably with a different model or threshold; confirm or delete.
# One row per group (unknown images first, then known images);
# columns are: total, correct, false.
for counts in (
    (total_unknown, correct_unknown, false_unknown),
    (total_known, correct_known, false_known),
):
    print(*counts)

2000 777 1223
4358 4336 22


In [None]:
	# Nearest-neighbour accuracy check over a list of test image paths.
	# NOTE(review): this cell has no execution count and relies on names that are
	# not defined in the visible notebook cells — test_img_list, config,
	# annoy_index, annoy_index_to_label, individual_accuracy, correct, total —
	# presumably pasted from another script; confirm they exist before running.
	print("Staring accuracy check on test data...")
	for idx, i in tqdm(enumerate(test_img_list)):
		query_img_org = Image.open(i)
		# Ground-truth label = name of the image's parent directory (path split on "/").
		gt_class = i.split("/")[-2]
		query_img = config.data_transforms["val"](query_img_org)
		query_img = query_img.unsqueeze(0).to(config.DEVICE)
		# The model output is used directly as the query embedding here.
		query_img_embedding = base_model(query_img)
		query_img_embedding = query_img_embedding.squeeze()

		# Ask the index for 20 neighbours; element [0] of the result is iterated as
		# the neighbour ids and mapped to labels.
		similar_images = annoy_index.get_nns_by_vector(query_img_embedding, 20, include_distances=True)
		similar_image_labels = [annoy_index_to_label[i] for i in similar_images[0]]
		
		# Only the first returned neighbour decides the predicted class.
		pt_class = similar_image_labels[0]
		
		# individual_accuracy[class] = [images seen, correct predictions, correct / seen]
		if gt_class == pt_class:
			if gt_class in individual_accuracy:
				individual_accuracy[gt_class][0] += 1
				individual_accuracy[gt_class][1] += 1
				individual_accuracy[gt_class][2] = individual_accuracy[gt_class][1]/individual_accuracy[gt_class][0]
			else:
				# First image of this class, predicted correctly.
				individual_accuracy[gt_class] = [1, 1, 1]
				
			correct += 1
	#         query_img_org.save("./wrong_predictions/test/correct/" + os.path.splitext(os.path.basename(i))[0] + "_gt_" + gt_class + "_pt_" + pt_class + ".png")
		else:
			if gt_class in individual_accuracy:
				individual_accuracy[gt_class][0] += 1
				individual_accuracy[gt_class][1] += 0
				individual_accuracy[gt_class][2] = individual_accuracy[gt_class][1]/individual_accuracy[gt_class][0]
			else:
				# First image of this class, predicted wrongly.
				individual_accuracy[gt_class] = [1, 0, 0]
	#         query_img_org.save("./wrong_predictions/test/wrong/" + os.path.splitext(os.path.basename(i))[0] + "_gt_" + gt_class + "_pt_" + pt_class + ".png")
		total += 1
	
	# NOTE(review): correct/total divides by zero if no image was processed and
	# `total` started at 0 — confirm how `total` is initialised.
	print("Overall Accuracy ", correct/total)
	print("Individual Accuracy Report:")
	print(individual_accuracy)
