# torchvision.models.resnet18を用いたファインチューニングによる学習とGradCAMによる注視領域の可視化

### import

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import time
from copy import deepcopy

In [None]:
from data_loader import train_loader
from data_loader import train_size
from data_loader import test_loader
from data_loader import test_size
from models.resnet import resnet

### device config

In [None]:
import os

# Force synchronous CUDA kernel launches so that errors are reported at the
# call that caused them. Assigning a plain Python variable has no effect on
# the CUDA runtime: the setting is read from the process environment, so it
# is exported here, before the first CUDA call of the notebook.
CUDA_LAUNCH_BLOCKING = 1  # kept so existing references to the name still work
os.environ["CUDA_LAUNCH_BLOCKING"] = str(CUDA_LAUNCH_BLOCKING)

In [None]:
# Whether PyTorch reports CUDA as available in this process.
use_gpu = bool(torch.cuda.is_available())

In [None]:
# Compute device for the model and the batches: GPU 0 if CUDA is available,
# otherwise the CPU.
device = torch.device("cuda", 0) if use_gpu else torch.device("cpu")

In [None]:
# Report whether the GPU path was selected.
print(f"{use_gpu}")

### show samples

In [None]:
def imshow(img, mean=0.5, std=0.5):
    """Display a channel-first image tensor with matplotlib.

    The tensor is un-normalised as ``img * std + mean`` and its axes are
    reordered from (C, H, W) to (H, W, C) before plotting.

    Args:
        img: 3-D image tensor in channel-first layout. It may live on any
            device and may be attached to an autograd graph.
        mean: Value added back after scaling. A scalar, or a tensor that
            broadcasts against ``img`` (e.g. shape ``(C, 1, 1)``).
        std: Scale applied before adding ``mean``; same conventions.

    The defaults reproduce the previous hard-coded ``img / 2 + 0.5``.
    """
    # .numpy() only works for CPU tensors that do not require grad, so detach
    # and copy to host memory first.
    img = img.detach().cpu() * std + mean
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Fetch one mini-batch of training data (random only if the loader shuffles).
# The builtin next() is used because DataLoader iterators no longer provide
# a .next() method in current PyTorch releases.
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Show the whole batch as one image grid.
imshow(torchvision.utils.make_grid(images))
# Print the label of each of the first (at most) 10 samples. The previous
# code indexed `labels` with a label value instead of the sample index.
print(' '.join('%5s' % labels[j].item() for j in range(min(10, len(labels)))))

### define model, parameter, optimizer, loss function

In [None]:
# Hyper-parameters.
learning_ratio = 1e-4  # passed to Adam as its learning rate
# NOTE(review): the training cell iterates over range(1) and does not read
# `epochs` -- confirm which value is intended.
epochs = 2

# Move the network to the selected device before the optimizer captures
# its parameters.
resnet = resnet.to(device)

loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    params=resnet.parameters(),
    lr=learning_ratio,
    weight_decay=1e-4,
)

### Training the model (fine tuning)

↓ prepare result container

In [None]:
# One entry is appended to each list at every logging step of the training
# loop: seconds since training started, the loss value recorded for that
# step, and the accuracy in percent.
elapsed_times, loss_transition_holder, accracy_transition_holder = [], [], []

↓ run iteration

In [None]:
# Fine-tune the network. Every `log_interval` mini-batches the mean loss and
# the accuracy over that window are printed and appended to the result lists.
log_interval = 100
start_time = time.time()

# Make sure BatchNorm/Dropout layers are in training mode, even if the
# evaluation cell was run before this one.
resnet.train()

# NOTE(review): the epoch count is hard-coded to 1 although `epochs` is
# defined together with the other hyper-parameters -- confirm which is
# intended before changing it.
for epoch in range(1):

    running_loss = 0.0
    total = 0
    correct = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        # Original author's note: the labels apparently arrive as 1, 2, so
        # they are shifted to be 0-indexed for CrossEntropyLoss.
        # NOTE(review): ImageFolder normally assigns 0-based indices --
        # confirm where the offset comes from in data_loader.
        labels = labels.to(device) - 1

        optimizer.zero_grad()
        outputs = resnet(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

        # statistics (detach() replaces the deprecated `.data` access)
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item()

        if (i + 1) % log_interval == 0:    # report every 100 mini-batches
            elapsed = time.time() - start_time
            mean_loss = running_loss / log_interval
            accuracy = correct / total * 100

            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, mean_loss))
            print("elapsed time: %.1f s" % elapsed)

            # Record the same mean loss that is printed; previously the
            # un-averaged sum over the window was stored.
            elapsed_times.append(elapsed)
            loss_transition_holder.append(mean_loss)
            accracy_transition_holder.append(accuracy)

            # start a fresh statistics window
            running_loss = 0.0
            total = 0
            correct = 0


elapsed_time = time.time() - start_time
print('Training was done. Elapsed time: ', elapsed_time)

### save the trained model

In [None]:
# Save the model's state_dict (not the module object itself) to disk.
state_holder_path = "./state_holder3.pth"
trained_weights = resnet.state_dict()
torch.save(trained_weights, state_holder_path)

### Visualize stats

In [None]:
# Plot the recorded loss and accuracy side by side; each point is one logging
# step of the training loop.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

# plot() uses 0..len-1 as x values when only y is given, so no explicit range
# is needed.
axes[0].plot(loss_transition_holder, label="loss")
axes[0].set_title("transition of loss")
axes[0].legend()

# Legend label spelling fixed ("accracy" -> "accuracy").
axes[1].plot(accracy_transition_holder, label="accuracy")
axes[1].set_title("transition of accuracy")
axes[1].legend()

# Adjust the layout so the two sub-plots do not overlap.
fig.tight_layout()

### test the model

In [None]:
# Measure classification accuracy on the test set.
# Switch to inference mode so BatchNorm uses its running statistics and
# Dropout is disabled; without this the model is evaluated in training mode.
resnet.eval()

correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        # Original author's note: the labels apparently arrive as 1, 2, so
        # they are shifted to be 0-indexed (same convention as in training).
        labels = labels.to(device) - 1
        predicted = resnet(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total:
    print('Accuracy on the test: %d %%' % (
        100 * correct / total))
else:
    # An empty loader would otherwise end in a ZeroDivisionError.
    print('Accuracy on the test: n/a (test_loader yielded no samples)')

### シーン解析なので、注目している物体の詳細よりもおおまかな領域が欲しい。そこでGradCAMを試す。

In [None]:
# Take one mini-batch from the test loader for the Grad-CAM visualisation.
# The builtin next() is used because DataLoader iterators no longer provide
# a .next() method in current PyTorch releases.
dataiter = iter(test_loader)
images, labels = next(dataiter)

def toHeatmap(x):
    """Convert a map of values in [0, 1] to RGB colours of the 'jet' colormap.

    Args:
        x: Array-like of floats, expected in [0, 1]. Values outside that
            range are clipped and NaNs are treated as 0.

    Returns:
        A float array of RGB triples (alpha dropped). For 2-D input the
        shape is ``x.shape + (3,)``. Any other input is reshaped to
        ``(224, 224, 3)`` as before, which requires 224 * 224 elements.
    """
    values = np.clip(np.nan_to_num(np.asarray(x, dtype=float)), 0.0, 1.0)
    # Integer input makes the colormap index its lookup table directly, which
    # is what the former per-pixel `cm(int(round(v * 255)))` loop did.
    indices = np.round(values * 255).astype(int)
    rgb = plt.get_cmap('jet')(indices)[..., :3]
    if values.ndim != 2:
        # Legacy behaviour for flattened (or otherwise shaped) input.
        rgb = rgb.reshape(224, 224, 3)
    return rgb

# Grad-CAM: for every image of the batch, weight the feature maps by the
# spatially averaged gradient of the top class score, keep the positive part,
# and overlay the resulting map on the input image.
#
# Changes from the previous version of this cell:
#   * the classifier pass is no longer inside torch.no_grad(); no graph is
#     recorded there, so the gradient computation raised;
#   * `F` and `cv2` were never imported -- torch equivalents are used instead;
#   * the input is moved to `device`, where the model lives;
#   * a constant map no longer causes a division by zero.
#
# NOTE(review): assumes `resnet` exposes `.features` / `.classifier` and that
# `.features` yields 512 maps of 7x7 for a 224x224 input -- confirm against
# models/resnet.py.
# NOTE(review): ImageNet mean/std are used below to undo the normalisation,
# whereas imshow() assumes 0.5/0.5 -- confirm against data_loader.
num_channels, map_size, image_size = 512, 7, 224
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
overlay_weight = 0.3  # share of the original image in the blended picture

resnet.eval()  # inference behaviour for BatchNorm/Dropout

for img_tensor in images:
    # Compute the feature maps; no graph is needed for this part.
    with torch.no_grad():
        feature = resnet.features(
            img_tensor.view(-1, 3, image_size, image_size).to(device))
    # Fresh leaf tensor so gradients with respect to the maps can be taken.
    feature = feature.clone().detach().requires_grad_(True)

    # Run the prediction with autograd enabled.
    y_pred = resnet.classifier(
        feature.view(-1, num_channels * map_size * map_size))
    top_score = y_pred[0][torch.argmax(y_pred)]
    # Gradient of the highest class score with respect to the feature maps.
    # autograd.grad leaves the parameters' .grad buffers untouched.
    (grads,) = torch.autograd.grad(top_score, feature)

    # Channel weights: spatial mean of the gradient of each map.
    channel_weights = torch.mean(
        grads.view(num_channels, map_size * map_size), 1)
    maps = feature.detach().view(num_channels, map_size, map_size)
    cam = torch.relu(torch.sum(maps * channel_weights.view(-1, 1, 1), 0))

    # Normalise to [0, 1]; an all-constant map is left at zero.
    cam = cam - cam.min()
    peak = cam.max().item()
    if peak > 0:
        cam = cam / peak

    # Upsample to the image size (bilinear, like cv2.resize's default) and
    # convert the attention map to a heat map.
    cam = torch.nn.functional.interpolate(
        cam[None, None], size=(image_size, image_size),
        mode="bilinear", align_corners=False)[0, 0]
    heatmap = toHeatmap(cam.cpu().numpy())

    # Undo the input normalisation for display.
    original = (img_tensor.detach().cpu() * std + mean).permute(1, 2, 0).numpy()

    plt.figure(figsize=(10, 10))
    # Clipping matches what imshow does for float data, without the warning.
    plt.imshow(np.clip(original, 0.0, 1.0))

    blended = original * overlay_weight + heatmap * (1 - overlay_weight)
    # Show the result.
    plt.figure(figsize=(10, 10))
    plt.imshow(np.clip(blended, 0.0, 1.0))
    plt.axis('off')
    plt.show()