# Visualize rules

In [None]:
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import os
import sys

sys.path.append('../')
from rule_mining import utils as ut
from rule_mining import config as cfg
from rule_mining.detector import VsDetector

# Experiment settings.
im_keep = 10      # NOTE(review): presumably how many transaction items to keep per image; confirm
im_range = 2048   # NOTE(review): presumably the size of the visual-element id range; confirm
mode = 'test'     # dataset split used to build every path below

# Per-split transaction data and story annotations.
trans_info_file = os.path.join(cfg.rule_data_path, '{}_trans_info.npz'.format(mode))
annofile = os.path.join(cfg.dataset_path, 'annotations/sis/{}.story-in-sequence.json'.format(mode))

annotations = ut.load_json(annofile)['annotations']
trans_info = ut.load_npz_dict(trans_info_file)
id2words = ut.load_json(cfg.txtdata_file)['id2words']

# Create a detector from a mined rule file.
rule_file = os.path.join(cfg.rule_data_path, 'rules03_0.6.npz')
detector = VsDetector(rule_file, id2words)
# Single-argument print(...) behaves identically on Python 2 and 3, matching
# the print(...) calls used in the other cells.
print(rule_file)
print("rules_num = %d, categories_kinds = %d" % (detector.rules_num, detector.elem_kinds))

In [None]:
from __future__ import division
import torch
import torch.nn as nn
import torchvision
from scipy.misc import imread
from torchvision import transforms
from torch.autograd import Variable
from scipy.ndimage import zoom

class MyResnet(nn.Module):
    """Wrap a torchvision-style ResNet so one call also returns its last feature map.

    Only the convolutional trunk (stem + layer1..layer4) of the wrapped
    network is run; its own pooling / fc head is never called.
    """

    def __init__(self, resnet):
        super(MyResnet, self).__init__()
        # Registered as a submodule (when it is an nn.Module), so .cuda() and
        # .eval() on the wrapper propagate to it.
        self.resnet = resnet

    def forward(self, img):
        """Run a single image through the trunk.

        Args:
            img: one image tensor without a batch axis (presumably C x H x W);
                a batch axis of size 1 is prepended here.

        Returns:
            (feature_map, feature): the layer4 output, still carrying its batch
            axis, and that map averaged over its last two (spatial) axes with
            every size-1 axis squeezed away.
        """
        net = self.resnet
        trunk = (net.conv1, net.bn1, net.relu, net.maxpool,
                 net.layer1, net.layer2, net.layer3, net.layer4)
        out = img.unsqueeze(0)
        for stage in trunk:
            out = stage(out)
        pooled = out.mean(3).mean(2).squeeze()
        return out, pooled
    
# Normalise with the ImageNet channel statistics expected by torchvision's
# pretrained weights. There is deliberately no ToTensor step: the image is
# turned into a float CHW tensor by hand before this transform is applied.
preprocess = transforms.Compose([
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Pretrained ResNet-152 wrapped so forward() also yields the layer4 feature
# map; moved to the GPU, then switched to inference mode so BatchNorm layers
# use their running statistics.
model = torchvision.models.resnet152(pretrained=True)
resnet = MyResnet(model).cuda()
resnet.eval()
print("resnet loaded.")

#### Get the rules involving a specific concept

In [None]:
feature_dir = os.path.join(cfg.feature_dir, '{}/'.format(mode))
img_id_list = ['6828584731']
# Only rules whose visual elements map to this word are printed below.
concept = 'pumpkin'
for img_id in img_id_list:
    # Prefer the .jpg file and fall back to .png.
    img_file = os.path.join(cfg.dataset_path, 'images', mode, "{}.jpg".format(img_id))
    if not os.path.exists(img_file):
        img_file = os.path.join(cfg.dataset_path, "images", mode, "{}.png".format(img_id))
    # Check the image first so no detection work is done for a missing file.
    # PIL reports unreadable or missing files as IOError (OSError on Python 3).
    try:
        img = Image.open(img_file)
    except IOError:
        print("image %s not found" % img_id)
        continue

    # Keep only the first `im_keep` transaction items, the same cut-off the
    # setup cell defines.
    im_trans = trans_info['id2trans'][img_id][:im_keep]
    semantic_words = detector.detect(im_trans)
    semantic_words = [word.encode('utf-8') for word in semantic_words]

    # Borderless display of the image.
    plt.figure(figsize=(8, 5))
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.imshow(img)
    plt.axis('off')
    plt.show()
    print(semantic_words)

    # List every rule whose pattern is fully present in this image and whose
    # visual elements decode to `concept`.
    for rule, elem_set in detector.rules.items():
        if detector.pattern_to_set[rule].issubset(im_trans):
            # Element ids are offset by detector.imrange before the word lookup.
            ve_words = [detector.id2words[str(int(idx) - detector.imrange)] for idx in elem_set]
            if concept in ve_words:
                print("{} {}".format(detector.pattern_to_set[rule], ve_words))

#### Load the image and get the feature maps of specific filters

In [None]:
# img_file is whatever the previous cell's loop assigned last (the final id in
# img_id_list, e.g. '6828584731', the pumpkin image).
# Channel indices of the layer4 feature map whose activations are shown below.
filters = [1484, 853, 646, 1287, 458]

# Load once and coerce to exactly 3 channels. A 2-D greyscale image breaks the
# CHW transpose; grey+alpha or RGBA images (the loader falls back to .png) give
# channel counts that neither the 3-channel Normalize nor the RGB mask in the
# next cell can handle.
image = imread(img_file)
if image.ndim == 2:
    image = image[:, :, np.newaxis]
if image.shape[2] < 3:
    image = np.repeat(image[:, :, :1], 3, axis=2)
image = image[:, :, :3].astype('float32') / 255.0

# Hand torch a copy: some (older) torchvision Normalize versions modify their
# input in place, which would otherwise corrupt `image` through a shared buffer.
image_v = torch.from_numpy(image.transpose([2, 0, 1]).copy())
image_v = Variable(preprocess(image_v), volatile=True).cuda()
feature_map, feature = resnet(image_v)
img_feature_map = feature_map.data.cpu().float().numpy()

#### Show each filter's feature map and the image masked by it

In [None]:
print("{} {}".format(image.shape, img_feature_map.shape))


def _show_panel(data):
    """Draw `data` on the current axes with no ticks, margins or frame."""
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.margins(0, 0)
    plt.imshow(data)
    plt.axis('off')


for filter_id in filters:
    print("filter_id%d" % filter_id)
    # Activation map of one channel for the single image in the batch.
    f = img_feature_map[0, filter_id]

    plt.figure(figsize=(16, 5))
    # Figure-level layout; applies to both panels, so set it once.
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)

    plt.subplot(1, 2, 1)
    _show_panel(f)

    print("{} {}".format(image.shape, f.shape))
    # Explicit float division: with Python 2 integer division the upsampled
    # map would be smaller than the image and the multiply below would fail.
    zoom_factor = (float(image.shape[0]) / f.shape[0],
                   float(image.shape[1]) / f.shape[1])
    f = zoom(f, zoom_factor, order=1)  # bilinear upsampling to image size
    # Clip activations at 1 and broadcast the mask across the colour channels.
    image_mask = np.minimum(f, 1)[:, :, np.newaxis]
    image_with_mask = image * image_mask

    plt.subplot(1, 2, 2)
    _show_panel(image_with_mask)

    plt.show()