## B-cos Explanations

In [None]:
import torch
from bcos_layer import BcosConv2d

# Build the B-cos ResNet-34 from torch.hub, replace its head with a B-cos
# layer, then overwrite the weights with the local teacher checkpoint.
# NOTE(review): load_state_dict is strict by default, so the hub weights fetched
# by pretrained=True are presumably all overwritten — confirm they are needed.
model = torch.hub.load('B-cos/B-cos-v2', 'resnet34', pretrained=True)
# NOTE(review): out_channels=5 presumably equals the number of classes in the
# local "imagenet" subset — confirm against load_dataset's num_classes.
model.fc = BcosConv2d(in_channels=512, out_channels=5)
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; load_state_dict copies the values into the model anyway.
state_dict = torch.load('./../models/imagenet_bcos_teacher_resnet34.pth', map_location="cpu")
model.load_state_dict(state_dict)

In [None]:
%%time
import matplotlib.pyplot as plt
from utils import load_dataset

def generate_bcos_explanations(model, images, labels):
    """Return B-cos contribution maps for the given target labels.

    For each sample, the logit of ``labels[b]`` is differentiated with respect
    to the input inside ``model.explanation_mode()``. The input-times-gradient
    product is summed over dim 1 (presumably the channel axis of NCHW input —
    TODO confirm for all loaders).

    Args:
        model: module exposing an ``explanation_mode()`` context manager
            (e.g. a B-cos model). It is switched to eval mode as a side effect.
        images: input batch whose first dimension is the sample index.
        labels: one target class index per sample.

    Returns:
        ``(images * d logit / d images).sum(dim=1)``, detached from the
        autograd graph. The caller's ``images`` tensor is not modified.
    """
    model.eval()
    # Differentiate w.r.t. a detached leaf copy, so the caller's tensor is not
    # left with requires_grad=True and non-leaf inputs also work.
    images = images.detach().requires_grad_(True)
    with torch.enable_grad(), model.explanation_mode():
        # Pick each sample's target-class logit.
        logits = model(images)[range(images.size(0)), labels]
        # grad_outputs of ones sums the selected logits. In eval mode the
        # samples are presumably independent, so each sample gets only the
        # gradient of its own logit. No retain_graph: the graph is not reused,
        # so autograd can free it.
        (gradients,) = torch.autograd.grad(
            outputs=logits,
            inputs=images,
            grad_outputs=torch.ones_like(logits),
        )
    # Detach so the returned maps do not keep the forward graph alive.
    return (images * gradients).sum(dim=1).detach()

# NOTE(review): the positional args presumably are (dataset name, batch size,
# B-cos input flag) — confirm against utils.load_dataset.
image_size, idx_to_tag, num_classes, train_loader, validation_loader = load_dataset("imagenet", 64, True)
# Explain only the first training batch. next(iter(...)) fails loudly on an
# empty loader instead of silently leaving `images`/`result` undefined.
images, labels = next(iter(train_loader))
result = generate_bcos_explanations(model, images, labels)

In [None]:
import numpy as np
# Show one random sample of the explained batch next to its B-cos map.
# Draw from the batch actually held in `images` rather than the loader's
# nominal batch size, which can exceed it when the dataset is small.
# (For a full batch this is the same draw as choice(range(batch_size)).)
i = np.random.choice(len(images))
# Only the first 3 channels are drawn as RGB. The B-cos input presumably
# carries extra channels — TODO confirm against load_dataset.
# matplotlib clips float RGB data to [0, 1], so cast the 0-255 values to uint8;
# a float array scaled by 255 renders almost entirely white.
rgb = (images[i][:3].detach() * 255).clamp(0, 255).to(torch.uint8)
plt.imshow(rgb.permute(1, 2, 0).cpu().numpy())
plt.show()
# Single-channel map: imshow colour-maps and autoscales it, so floats are fine.
plt.imshow((result[i]*255).cpu().detach().numpy())
plt.show()

## GradCAM Explanations

In [None]:
from torchvision import models
# ImageNet-pretrained torchvision ResNet-50 ("IMAGENET1K_V1" weights, given by
# name) for the GradCAM experiments; this rebinds the B-cos `model` from above.
model = models.resnet50(weights="IMAGENET1K_V1")

In [None]:
import torch
from utils import load_dataset, load_model, generate_gradcam_heatmaps

In [None]:
# Reload the data with the third load_dataset flag off (the B-cos cell passed
# True). GradCAM is run on the CPU.
(
    image_size,
    idx_to_tag,
    num_classes,
    train_loader,
    validation_loader,
) = load_dataset("imagenet", 64, False)
device = torch.device("cpu")

In [None]:
%%time
# To explain the local teacher checkpoint instead of the torchvision ResNet-50
# defined above, uncomment:
# model = load_model(False, "resnet34", True, num_classes, "./../models/imagenet_teacher_resnet34.pth")
model.to(device)
# GradCAM on the first training batch only. next(iter(...)) fails loudly on an
# empty loader instead of silently leaving `heatmaps` undefined.
images, labels = next(iter(train_loader))
images, labels = images.to(device), labels.to(device)
# NOTE(review): `model` is currently torchvision's resnet50, but "resnet34" is
# passed as the architecture name. Confirm generate_gradcam_heatmaps picks the
# right target layer, or pass "resnet50". (224, 224) is presumably the output
# heatmap size — confirm it matches image_size.
heatmaps = generate_gradcam_heatmaps(model, "resnet34", images, labels, (224, 224))

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Show one random sample next to its GradCAM heatmap.
# Draw from the batch actually held in `images` rather than the loader's
# nominal batch size, which can exceed it when the dataset is small.
# (For a full batch this is the same draw as choice(range(batch_size)).)
i = np.random.choice(len(images))
# matplotlib clips float RGB data to [0, 1], so cast the 0-255 values to uint8;
# a float array scaled by 255 renders almost entirely white.
rgb = (images[i][:3].detach() * 255).clamp(0, 255).to(torch.uint8)
plt.imshow(rgb.permute(1, 2, 0).cpu().numpy())
plt.show()
# Heatmap (presumably single-channel): imshow colour-maps and autoscales it.
plt.imshow((heatmaps[i]*255).cpu().detach().numpy())
plt.show()

In [None]:
# Unscaled heatmap: imshow normalises to the data range by default, so the
# picture matches the x255 plot above; only the underlying values differ.
plt.imshow(heatmaps[i].detach().cpu().numpy())


In [None]:
# Display the raw heatmap tensor for the sample chosen above (Jupyter renders
# the cell's last expression).
heatmaps[i]

In [None]:
from torchvision.utils import save_image
# Write the selected heatmap to disk as a PNG for inspection.
# NOTE(review): save_image scales values by 255 and clamps to [0, 255], so it
# presumably expects heatmap values in [0, 1] — confirm the GradCAM output range.
save_image(heatmaps[i], fp='img1.png')

In [None]:
from PIL import Image
# Reopen the PNG written by save_image to check what was stored on disk. The
# bare `image` on the last line makes Jupyter render it inline.
image = Image.open('img1.png')
image

In [None]:
import torchvision.transforms as transforms

# Read the reopened PNG back as a tensor. PILToTensor keeps the raw integer
# pixel values (no rescaling) and returns them channel-first.
transform = transforms.Compose([transforms.PILToTensor()])
img_tensor = transform(image)
print(img_tensor)

In [None]:
# Persist the selected heatmap. heatmaps[i] is presumably a view into the whole
# batch tensor, and torch.save serialises a view's entire underlying storage,
# so clone() to write only this one heatmap. detach()/cpu() store a plain CPU
# tensor that loads without autograd state or a GPU.
torch.save(heatmaps[i].detach().cpu().clone(), 'teacher_heatmap.pt')

In [None]:
# Reload the saved heatmap and display it. map_location="cpu" lets the file
# load on a CPU-only machine even if it was written from a GPU tensor.
plt.imshow(torch.load('teacher_heatmap.pt', map_location="cpu").detach().numpy())
