<a href="https://colab.research.google.com/github/yuraoh12/AI-bigdata/blob/main/231121_GradCAM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!wget https://github.com/kairess/toy-datasets/raw/master/dog-inflammation.zip
!unzip -q dog-inflammation.zip

In [None]:
#dataloader
from fastai.vision.all import *

path = 'dog-inflammation'

block = DataBlock(
    blocks = (ImageBlock, CategoryBlock),  # input: image, output: single class label
    get_items = get_image_files,           # collects the image file paths under `path`
    get_y = parent_label,                  # label = name of the parent folder (e.g. Negative / Positive)
    # Fixed seed so the train/validation split is identical on every run.
    splitter = RandomSplitter(valid_pct=0.2, seed=42),
)

loader = block.dataloaders(path)

loader.show_batch()

In [None]:
# Training via transfer learning: fine-tune a ResNet-18 backbone
# (pretrained weights, per vision_learner's default) for 3 epochs.
learn = vision_learner(loader, resnet18, metrics=[accuracy])
learn.fine_tune(3)

In [None]:
# Inspect predictions on validation samples; in each title the upper label is the ground truth.
learn.show_results()

In [None]:
# Grad-CAM inputs: one Negative and one Positive example from the dataset.
# Paths are built from the dataset folder `path` instead of Colab's absolute
# /content/... location, so the notebook also runs outside Colab.
neg_path = f'{path}/Negative/D0_03f7e7c0-60a5-11ec-8402-0a7404972c70.jpg'
pos_path = f'{path}/Positive/D0_02fa7d26-60a5-11ec-8402-0a7404972c70.jpg'

# test_dl runs the files through the same preprocessing pipeline as the training loader.
test_loader = loader.test_dl([neg_path, pos_path])

xb = next(iter(test_loader))[0]  # one batch holding both images, in the order listed above
# Slice rather than index so each tensor keeps a leading batch dimension of 1.
neg_x, pos_x = xb[0:1], xb[1:2]

print(neg_x.shape, pos_x.shape)

neg_img = Image.open(neg_path)
pos_img = Image.open(pos_path)

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
for panel, shown, title in zip(ax, (neg_img, pos_img), ('Negative', 'Positive')):
    panel.imshow(shown)
    panel.set_title(title)
    panel.axis('off')
plt.tight_layout()
plt.show()

In [None]:
class Hook:
    """Context manager that captures the forward output of a module.

    Construction registers a forward hook on `m`. Every forward pass through `m`
    then overwrites `self.stored` with a detached copy of the module's output.
    Leaving the `with` block removes the hook again.
    """

    def __init__(self, m):
        handle = m.register_forward_hook(self.hook_func)
        self.hook = handle

    def hook_func(self, m, i, o):
        # Detach so the saved tensor carries no autograd history; clone so it is an
        # independent copy of the output.
        captured = o.detach()
        self.stored = captured.clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *exc_info):
        self.hook.remove()

class HookBwd():
    """Context manager that captures the gradient flowing into a module's output.

    Construction registers a backward hook on `m`. Each backward pass through `m`
    sets `self.stored` to a detached copy of the gradient w.r.t. the module's first
    output. Leaving the `with` block removes the hook.
    """
    def __init__(self, m):
        # register_backward_hook is deprecated and, for modules built from several
        # operations, only reports gradients of the last op; the full backward hook
        # is tied to the module's actual inputs/outputs.
        self.hook = m.register_full_backward_hook(self.hook_func)
    def hook_func(self, m, gi, go):
        # go: tuple of gradients w.r.t. the module outputs; keep the first one.
        self.stored = go[0].detach().clone()
    def __enter__(self, *args):
        return self
    def __exit__(self, *args):
        self.hook.remove()

In [None]:
def get_gradcam(x, location: int = -1, cls=None):
    """Compute a Grad-CAM heatmap for one preprocessed image.

    Args:
        x: image batch, e.g. from `loader.test_dl`. Only the first item is explained;
            presumably shaped (1, C, H, W) — TODO confirm for other callers.
        location: index into `learn.model[0]` selecting the layer whose activations
            and gradients are used (i.e. which layer to extract from).
        cls: class index to explain; defaults to the model's top-scoring class.

    Returns:
        2-D CPU tensor: the layer's activation channels weighted by their
        spatially averaged gradients and summed (no ReLU applied).
    """
    model = learn.model.eval()
    layer = model[0][location]
    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device

    with HookBwd(layer) as hookg:
        with Hook(layer) as hook:
            output = model(x.to(device))
            act = hook.stored
        # Pick the class from the first sample's outputs; argmax over the whole
        # flattened tensor would give an out-of-range index for batches > 1.
        idx = output[0].argmax() if cls is None else cls
        output[0, idx].backward()
        grad = hookg.stored
    # backward() also fills parameter .grad; clear it so repeated calls don't accumulate.
    model.zero_grad()

    # Channel weights = gradients averaged over the spatial dimensions.
    w = grad[0].mean(dim=[1, 2], keepdim=True)
    heatmap = (w * act[0]).sum(0).detach().cpu()

    return heatmap

# Grad-CAM for the positive example, taken from layer index -2 of learn.model[0].
heatmap = get_gradcam(pos_x, location=-2)
print(heatmap.shape)

fig, ax = plt.subplots()
mappable = ax.imshow(heatmap, cmap='jet')
ax.set_title('Grad-CAM heatmap (Positive, location=-2)')
fig.colorbar(mappable, ax=ax)
plt.show()

In [None]:
# Positive example: raw image (left) and Grad-CAM heatmap overlaid on it (right).
# Recompute the heatmap here so this cell never shows a stale `heatmap` left behind
# by whichever cell ran last (e.g. the negative example).
heatmap = get_gradcam(pos_x, location=-2)
img = Image.open(pos_path)
# imshow(img) draws in (roughly) pixel coordinates, 0..width by height..0; stretch the
# coarse heatmap over exactly that area rather than a hard-coded 224x224 box.
img_w, img_h = img.size

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(img)
ax[0].set_title('Positive')
ax[0].axis('off')
ax[1].imshow(img)
ax[1].imshow(heatmap.cpu(), alpha=0.5, extent=(0, img_w, img_h, 0),
              interpolation='bilinear', cmap='jet')
ax[1].set_title('Grad-CAM overlay (location=-2)')
ax[1].axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Negative example: raw image (left) and Grad-CAM heatmap overlaid on it (right).
# NOTE(review): this uses location=-1 while the positive example uses location=-2;
# confirm which layer is intended before comparing the two heatmaps.
heatmap = get_gradcam(neg_x, location=-1)
img = Image.open(neg_path)
# imshow(img) draws in (roughly) pixel coordinates, 0..width by height..0; stretch the
# coarse heatmap over exactly that area rather than a hard-coded 224x224 box.
img_w, img_h = img.size

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(img)
ax[0].set_title('Negative')
ax[0].axis('off')
ax[1].imshow(img)
ax[1].imshow(heatmap.cpu(), alpha=0.5, extent=(0, img_w, img_h, 0),
              interpolation='bilinear', cmap='jet')
ax[1].set_title('Grad-CAM overlay (location=-1)')
ax[1].axis('off')
plt.tight_layout()
plt.show()