In [None]:
import os
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from torch.autograd import Variable
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import models, datasets, transforms

In [None]:
class U_Net_Encoder(nn.Module):
    """AlexNet-shaped encoder that exposes the pre-activation of every layer.

    Layers: five convolutions (conv2, conv4 and conv5 are grouped, groups=2),
    max-pooling after conv1, conv2 and conv5 (with local response normalization
    after the first two), and three fully connected layers fc6..fc8.

    The attribute names below are the keys of the saved state_dict
    (e.g. ``conv1.weight``, ``fc6.1.weight``), so they must not be renamed or
    re-nested, or existing checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()

        # Convolutional trunk.
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0)
        self.pool1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75),
        )
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2, groups=2)
        self.pool2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(5, alpha=0.0001, beta=0.75),
        )
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1, groups=2)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1, groups=2)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Fully connected head. fc6 flattens the (256, 6, 6) map first, which is
        # why its Linear layer sits at index 1 of the Sequential.
        self.fc6 = nn.Sequential(nn.Flatten(), nn.Linear(256 * 6 * 6, 4096))
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, 1000)

        # Parameter-free modules shared by every stage.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Run the encoder and return every layer's output taken before ReLU.

        Returns the 8-tuple (conv1, conv2, conv3, conv4, conv5, fc6, fc7, fc8).
        Shape comments assume a (N, 3, 227, 227) input. Dropout after fc6/fc7
        only acts in training mode.
        """
        # (layer whose raw output is returned, pooling applied after its ReLU)
        conv_stages = (
            (self.conv1, self.pool1),   # (96, 55, 55)  -> pool1 -> (96, 27, 27)
            (self.conv2, self.pool2),   # (256, 27, 27) -> pool2 -> (256, 13, 13)
            (self.conv3, None),         # (384, 13, 13)
            (self.conv4, None),         # (384, 13, 13)
            (self.conv5, self.pool3),   # (256, 13, 13) -> pool3 -> (256, 6, 6)
        )

        taps = []
        hidden = x
        for layer, pool in conv_stages:
            pre_act = layer(hidden)
            taps.append(pre_act)
            hidden = self.relu(pre_act)
            if pool is not None:
                hidden = pool(hidden)

        for layer in (self.fc6, self.fc7):   # (4096) each
            pre_act = layer(hidden)
            taps.append(pre_act)
            hidden = self.dropout(self.relu(pre_act))

        taps.append(self.fc8(hidden))        # (1000)
        return tuple(taps)


In [None]:
class U_Net_Decoder(nn.Module):
    """Mirror of U_Net_Encoder with U-Net style skip connections.

    ``forward`` takes the encoder's eight outputs (same names and order as
    U_Net_Encoder.forward returns them). It first inverts the fully connected
    layers, concatenating fc7 and fc6 along the feature axis. It then walks back
    up the spatial hierarchy with transposed convolutions, concatenating
    conv5, conv4, conv3, conv2 and conv1 along the channel axis. The output
    goes through a sigmoid, so its values lie in [0, 1].

    Attribute names are state_dict keys; do not rename them.
    """

    def __init__(self):
        super().__init__()

        # Inverse fully connected layers. Input widths of 8192 come from
        # concatenating a 4096-wide output with fc7 / fc6.
        self.rfc8 = nn.Linear(1000, 4096)
        self.rfc7 = nn.Linear(8192, 4096)
        self.rfc6 = nn.Linear(8192, 256 * 6 * 6)

        # Inverse conv trunk. "rpool" layers are stride-2 transposed convs that
        # restore the resolution lost at pool3 / pool2 / pool1. Channel counts
        # include the concatenated skip tensor (e.g. 256 + conv5's 256 = 512).
        self.rpool3 = nn.ConvTranspose2d(256, 256, kernel_size=3, stride=2)
        self.rconv5 = nn.ConvTranspose2d(512, 384, kernel_size=3, stride=1, padding=1, groups=2)
        self.rconv4 = nn.ConvTranspose2d(768, 384, kernel_size=3, stride=1, padding=1, groups=2)
        self.rconv3 = nn.ConvTranspose2d(768, 256, kernel_size=3, stride=1, padding=1)
        self.rpool2 = nn.ConvTranspose2d(256, 256, kernel_size=3, stride=2)
        self.rconv2 = nn.ConvTranspose2d(512, 96, kernel_size=5, stride=1, padding=2, groups=2)
        self.rpool1 = nn.ConvTranspose2d(96, 96, kernel_size=3, stride=2)
        self.rconv1 = nn.ConvTranspose2d(192, 3, kernel_size=11, stride=4, padding=0)

        # Parameter-free helpers.
        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()
        self.loss_func = nn.MSELoss()
        self.sigmoid = nn.Sigmoid()

    def forward(self, conv1, conv2, conv3, conv4, conv5, fc6, fc7, fc8):
        """Reconstruct an image from the encoder's eight outputs.

        Shape comments assume the encoder was fed 227x227 images; the result is
        (N, 3, 227, 227). Dropout only acts in training mode.
        """
        act, drop = self.relu, self.dropout

        # Fully connected part.
        h = drop(act(torch.cat((self.rfc8(fc8), fc7), dim=1)))   # (8192)
        h = drop(act(torch.cat((self.rfc7(h), fc6), dim=1)))     # (8192)
        h = drop(act(self.rfc6(h))).view(-1, 256, 6, 6)          # (256, 6, 6)

        # 13x13 stage. No ReLU directly after the unpooling concat.
        h = torch.cat((self.rpool3(h), conv5), dim=1)            # (512, 13, 13)
        h = act(torch.cat((self.rconv5(h), conv4), dim=1))       # (768, 13, 13)
        h = act(torch.cat((self.rconv4(h), conv3), dim=1))       # (768, 13, 13)
        h = act(self.rconv3(h))                                  # (256, 13, 13)

        # 27x27 stage.
        h = torch.cat((self.rpool2(h), conv2), dim=1)            # (512, 27, 27)
        h = act(self.rconv2(h))                                  # (96, 27, 27)

        # 55x55 stage, then back to image resolution.
        h = torch.cat((self.rpool1(h), conv1), dim=1)            # (192, 55, 55)
        return self.sigmoid(self.rconv1(h))                      # (3, 227, 227)

    def loss(self, x, x_recon):
        """Mean squared error between ``x`` and ``x_recon`` (nn.MSELoss, mean reduction)."""
        return self.loss_func(x, x_recon)


In [None]:
# Directory holding the trained checkpoints. Set the U_NET_MODEL_DIR environment
# variable to point elsewhere instead of editing this absolute, per-user path.
model_dir = os.environ.get('U_NET_MODEL_DIR', '/home/shunosuga/data/model/u_net/epoch_10')

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

encoder = U_Net_Encoder().to(device)
decoder = U_Net_Decoder().to(device)


In [None]:
# map_location remaps tensors saved on another device (e.g. a training GPU) onto
# the current one, so the checkpoints also load on a CPU-only machine.
encoder.load_state_dict(torch.load(os.path.join(model_dir, 'encoder_10'), map_location=device))
decoder.load_state_dict(torch.load(os.path.join(model_dir, 'decoder_10'), map_location=device))
# The models are only used for inference below: disable dropout.
encoder.eval()
decoder.eval();

In [None]:
import copy
def img_preprocess(img, img_mean=np.array([0.485, 0.456, 0.406], dtype=np.float64),
                   img_std=np.array([0.229, 0.224, 0.225], dtype=np.float64), norm=255):
    '''Convert an HWC image into PyTorch's normalized CHW float32 layout.

    Parameters
    ----------
    img : array-like of shape (H, W, 3)
        Channels-last image. Values are divided by ``norm`` first.
    img_mean, img_std : array-like of shape (3,)
        Per-channel mean subtracted and std divided after scaling. The defaults
        appear to be the common ImageNet statistics.
    norm : number
        Divisor applied to the raw pixel values (255 for 8-bit images).

    Returns
    -------
    np.ndarray of shape (3, H, W) and dtype float32.
    '''
    # np.float (an alias of the builtin float) was removed in NumPy 1.24, so
    # the defaults spell out float64 explicitly.
    scaled = np.asarray(img) / norm
    mean = np.reshape(img_mean, (3, 1, 1))
    std = np.reshape(img_std, (3, 1, 1))
    # Cast once, after the division: casting before it lets the float64 std
    # promote the result straight back to float64.
    return ((np.transpose(scaled, (2, 0, 1)) - mean) / std).astype(np.float32)

def normalized_img(img):
    '''Min-max rescale an array to 0..255 and return it as uint8.

    The minimum value maps to 0 and the maximum to 255; the fractional part is
    truncated by the uint8 cast. A constant array becomes all zeros.
    '''
    shifted = img - img.min()
    peak = shifted.max()
    # Only rescale when there is a positive range; otherwise keep the zeros.
    scaled = shifted * (255.0 / peak) if peak > 0 else shifted
    return np.uint8(scaled)

def _resolve_submodule(root, path):
    '''Return the submodule of ``root`` named by an attribute/index path.

    Accepts dotted attribute chains with integer indexing, e.g. ``'conv1'``,
    ``'fc6[1]'`` or ``'features[0].conv'``. Raises ValueError for anything
    else instead of executing it as code.
    '''
    import re
    if not re.fullmatch(r'[A-Za-z_]\w*(?:\.[A-Za-z_]\w*|\[-?\d+\])*', path):
        raise ValueError(f'unsupported layer path: {path!r}')
    module = root
    for attr, index in re.findall(r'([A-Za-z_]\w*)|\[(-?\d+)\]', path):
        module = getattr(module, attr) if attr else module[int(index)]
    return module


def get_cnn_features(model, input, extract_feat_list):
    '''Run ``model`` on ``input`` and collect the outputs of selected layers.

    Parameters
    ----------
    model : nn.Module
        Network to probe. It is deep-copied, so no hooks are left on it.
    input : tensor
        Batch passed to the copied network's forward.
    extract_feat_list : iterable of str
        Layer paths relative to the model, e.g. ``'conv1'`` or ``'fc6[1]'``.

    Returns
    -------
    list of tensors
        Cloned outputs, in the order the hooked layers fire during forward
        (not necessarily the order of ``extract_feat_list``).
    '''
    # Hooks go on a throwaway copy so the caller's model stays untouched.
    net = copy.deepcopy(model)
    outputs = []

    def hook(module, inputs, output):
        outputs.append(output.clone())

    for path in extract_feat_list:
        _resolve_submodule(net, path).register_forward_hook(hook)
    net(input)
    return outputs


In [None]:
# Sanity check: push one uniform-noise 227x227 RGB image through the encoder
# and print the shape of the conv1 activation.
random_input = np.random.randint(0, 256, size=(227, 227, 3))
plt.imshow(random_input)

# HWC integers in [0, 255] -> CHW floats in [0, 1], plus a leading batch axis.
random_norm = np.transpose(random_input / 255, (2, 0, 1))
img_input = torch.Tensor(np.expand_dims(random_norm, 0)).to(device)

(conv1, conv2, conv3, conv4, conv5,
 fc6, fc7, fc8) = encoder(img_input)
print(conv1.size())

In [None]:
# Reverse correlation for a single conv1 unit: feed uniform-noise images through
# the encoder and record the unit's (pre-ReLU) response to each one.

# Target conv1 channel
channel = 54
# Target spatial position. conv1 is indexed (batch, channel, row, col).
row, col = 28, 28

rand_list = []
value_list = []
iter_num = 10000

# NOTE(review): each float64 227x227x3 image is ~1.2 MB, so keeping all
# iter_num images costs roughly 12 GB of RAM. Accumulating value * image inside
# this loop would need constant memory.
with torch.no_grad():  # activations only; skip building autograd graphs
    for i in tqdm(range(iter_num)):
        # Random image with pixel values in [0, 1], HWC layout
        random_input = np.random.randint(0, 256, [227, 227, 3]) / 255
        # HWC -> CHW
        random_norm = random_input.transpose(2, 0, 1)
        # Add a batch axis and convert to a float32 tensor
        img_input = torch.Tensor(random_norm[np.newaxis]).to(device)
        # Encoder outputs; only conv1 is used here
        conv1, conv2, conv3, conv4, conv5, fc6, fc7, fc8 = encoder(img_input)
        # Response of the target unit
        act_value = conv1[0, channel, row, col].item()
        # Keep the stimulus and its response for the weighted sum
        rand_list.append(random_input)
        value_list.append(act_value)
               


In [None]:
# Reverse-correlation estimate: the response-weighted sum of the noise images.
rc_img = np.zeros([227, 227, 3])
for i, noise_img in enumerate(rand_list):
    rc_img += value_list[i] * noise_img

In [None]:
# rc_img has arbitrary sign and scale (weights are raw pre-ReLU responses), so
# min-max rescale it to 0..255 before converting to uint8 for display.
rc_img_out = normalized_img(rc_img)

In [None]:
# Show the full reconstruction, then zoom in on the target unit's region.
# conv1 uses an 11x11 kernel with stride 4 and no padding, so the unit at
# (28, 28) chosen in the sampling cell sees input pixels 112..122. The crop adds
# a 4-pixel margin on each side, i.e. slice(108, 127).
# NOTE(review): these bounds are hard-coded; update them if the unit changes.
zoom = slice(28 * 4 - 4, 28 * 4 + 11 + 4)
norm_ = rc_img_out[zoom, zoom]
for shown in (rc_img_out, norm_):
    plt.imshow(shown)
    plt.show()