In [2]:
# 获取目标类别图片，识别目标类别
# 获取九宫格验证图片，划分为9张图片
# 利用模型对9张图片进行分类，得到对应目标类别图片的坐标位置

In [4]:
from selenium import webdriver
from selenium.webdriver import Chrome, ChromeOptions
from selenium.webdriver.common.by import By
from selenium.webdriver import ActionChains
import requests
import time
import matplotlib.pyplot as plt
from io import BytesIO
import cv2
from PIL import Image
from PIL import ImageChops
import random
import string
import os
import numpy as np
import base64

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision.transforms import transforms
from PIL import Image

%matplotlib inline

In [5]:
def crop_image(images):
    """Split an image into a 3x3 grid of equally sized tiles.

    Tiles are produced column by column: the tile at list index
    ``3 * col + row`` is taken from grid column ``col`` and grid row ``row``,
    so entries 0-2 form the left column (top to bottom), 3-5 the middle
    column and 6-8 the right column. When a dimension is not divisible by
    three, the remainder pixels at the right/bottom edge are not included.

    Args:
        images: object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method taking ``(left, upper, right, lower)``
            (e.g. a ``PIL.Image.Image``).

    Returns:
        list: the nine results of ``images.crop`` in the order described above.
    """
    # PIL reports size as (width, height).
    full_w, full_h = images.size
    tile_w = full_w // 3
    tile_h = full_h // 3

    tiles = []
    for col in range(3):
        left, right = tile_w * col, tile_w * (col + 1)
        for row in range(3):
            upper, lower = tile_h * row, tile_h * (row + 1)
            tiles.append(images.crop((left, upper, right, lower)))
    return tiles

# def calc_similarity(img1_path, img2_path):
#     img1 = cv2.imdecode(np.fromfile(img1_path, dtype=np.uint8), -1)
#     H1 = cv2.calcHist([img1], [1], None, [256], [0, 256])  # 计算图直方图
#     H1 = cv2.normalize(H1, H1, 0, 1, cv2.NORM_MINMAX, -1)  # 对图片进行归一化处理
#     img2 = cv2.imdecode(np.fromfile(img2_path, dtype=np.uint8), -1)
#     H2 = cv2.calcHist([img2], [1], None, [256], [0, 256])  # 计算图直方图
#     H2 = cv2.normalize(H2, H2, 0, 1, cv2.NORM_MINMAX, -1)  # 对图片进行归一化处理
#     similarity1 = cv2.compareHist(H1, H2, 0)  # 相似度比较
#     # print('similarity:', similarity1)
#     if similarity1 == 1:  # 0.98是阈值，可根据需求调整
#         return True
#     else:
#         return False

class ImgClassifyModel(nn.Module):
    """torchvision EfficientNet-B5 with its final linear layer resized to ``class_num`` outputs."""

    def __init__(self, class_num, pretrained=None):
        """Build the backbone and replace the classification head.

        Args:
            class_num: number of outputs of the new final linear layer.
            pretrained: optional path to a state dict for the stock
                EfficientNet-B5. It is loaded *before* the head is replaced,
                so it must match the unmodified torchvision architecture.
        """
        super().__init__()
        self.model = models.efficientnet_b5(pretrained=False)
        if pretrained:
            # map_location='cpu' lets a state dict that was saved from a GPU
            # be loaded on a machine without CUDA; the module itself is still
            # on the CPU at this point, so nothing else changes.
            state_dict = torch.load(pretrained, map_location='cpu')
            self.model.load_state_dict(state_dict)
        # Take the input width from the layer being replaced instead of
        # hard-coding it, so the head stays correct if the backbone is swapped.
        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features, class_num)

    def forward(self, x):
        """Return the wrapped network's output for input batch ``x``."""
        return self.model(x)

def recog_image(img_path, best_model, trf, device):
    """Classify a single image file and return its 1-based class label.

    Args:
        img_path: path of the image file to classify.
        best_model: callable ``nn.Module``; it is switched to eval mode here.
        trf: callable turning the RGB image into a tensor; a leading batch
            dimension is added to its result before it is moved to ``device``.
        device: target device for the input tensor.

    Returns:
        The arg-max index over the last dimension of the model output for the
        single input, plus one (a NumPy integer scalar).
    """
    # Close the file handle deterministically; convert() yields a separate
    # image object that stays usable after the source file is closed.
    with Image.open(img_path) as raw:
        img = raw.convert('RGB')
    batch = trf(img).unsqueeze(0).to(device)

    best_model.eval()
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        logits = best_model(batch)
    pred_y = logits.argmax(dim=-1).cpu().numpy()

    return pred_y[0] + 1

In [6]:
# Class-name lookup: line i (0-based) of flags.txt is stored under key i.
# Key 0 is removed afterwards; recog_image() returns argmax + 1, so that key
# is never looked up.
with open('./flags.txt', 'r', encoding='utf8') as f:
    flags = f.read()
flags = flags.split('\n')
label_to_flag = dict(enumerate(flags))
del label_to_flag[0]

# Preprocessing applied to every tile before inference.
img_size = (112, 112)
# These values match the commonly used ImageNet channel mean / std.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
trf = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std)
])

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
class_num = 90
model_path = './model/model_0724.pt'
best_model = ImgClassifyModel(class_num=class_num, pretrained=None)
# map_location=device: without it a checkpoint saved from a GPU cannot be
# loaded when this notebook falls back to the CPU.
best_model.load_state_dict(torch.load(model_path, map_location=device))
best_model = best_model.to(device)

In [7]:
# Command-line switches handed to Chrome, applied in the order listed.
# NOTE(review): 'user-agent' appears twice with different values (the first
# one also carries literal double quotes inside the argument) — confirm which
# value is intended and drop the other.
chrome_args = (
    'user-agent="Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/69.0.3497.100 Safari/537.36"',
    '--ignore-certificate-errors',
    '--disable-gpu',
    '--ssl-protocol=any',
    '--no-sandbox',
    '--disable-dev-shm-usage',
    'user-agent=ywy',
    '--ignore-urlfetcher-cert-requests',
    '--ignore-ssl-errors',
)

options = ChromeOptions()
for arg in chrome_args:
    options.add_argument(arg)
options.add_experimental_option('excludeSwitches', ['enable-automation'])

In [8]:
# Script registered for every new document: makes navigator.webdriver read
# as undefined.
WEBDRIVER_PATCH_JS = """
    Object.defineProperty(navigator, 'webdriver', {
      get: () => undefined
    })
  """

driver = webdriver.Chrome(options=options)
try:
    driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
        "source": WEBDRIVER_PATCH_JS
    })
    driver.get('https://www.gsxt.gov.cn/index.html')
    time.sleep(5)

    # Type the search keyword and submit the query.
    input_box = driver.find_element(By.ID, 'keyword')
    input_box.send_keys('德信行')
    time.sleep(1)
    driver.find_element(By.ID, 'btn_query').click()
    time.sleep(3)

    # Only the "pick 3 matching pictures" prompt is handled; for any other
    # prompt text nothing is done and the browser is simply shut down.
    box = driver.find_element(By.XPATH, '/html/body/div[7]/div[1]/div[1]/div[1]/div[1]/div[1]')
    text = box.get_attribute('innerText')
    if text == '请选择 3 个符合的图片':
        # Make sure the scratch directories exist before writing into them.
        os.makedirs('./test/images', exist_ok=True)

        # Download the picture that shows the target category.
        src = driver.find_element(By.XPATH, '/html/body/div[7]/div[1]/div[1]/div[1]/div[1]/div[2]/img').get_attribute('src')
        content = requests.get(src).content
        with open('./test/background.png', 'wb') as f:
            f.write(content)
        time.sleep(1)

        # Identify the category by looking for a pixel-identical reference
        # picture in ./labels/; the file name without extension is the label.
        label_imgs = os.listdir('./labels/')
        with Image.open('./test/background.png') as tmp_img:
            for label_file in label_imgs:
                with Image.open(f'./labels/{label_file}') as tmp:
                    diff = ImageChops.difference(tmp, tmp_img)
                # getbbox() is None when the difference image is entirely zero.
                if diff.getbbox() is None:
                    label = os.path.splitext(label_file)[0]
                    break
            else:
                # No reference picture matched.
                raise ValueError('超出训练数据集类别范围！')

        # The grid picture URL is sliced out of the element's inline style.
        # NOTE(review): the fixed [23:-31] slice presumably strips
        # 'background-image: url("' and a fixed-length tail — confirm against
        # the live page, it breaks if the style text changes.
        url = driver.find_element(By.XPATH, '/html/body/div[7]/div[1]/div[1]/div[2]/div/div/div/div[1]/div[1]/div').get_attribute('style')[23:-31]
        content = requests.get(url).content
        images = Image.open(BytesIO(content))
        width, height = images.size
        w, h = width // 3, height // 3
        img_list = crop_image(images)
        for i, img in enumerate(img_list):
            img.save(f'./test/images/{str(i)}.png', compress_level=0)

        # Click offset for every tile index. Keys follow crop_image()'s
        # column-by-column order (0-2 left column, 3-5 middle, 6-8 right).
        # NOTE(review): these offsets only land on the intended tile if they
        # are measured from the element centre and the element is rendered at
        # the picture's native pixel size — confirm for the installed
        # Selenium version and page layout.
        location = {0: [-w, -h], 1: [-w, 0], 2: [-w, h], 3: [0, -h], 4: [0, 0], 5: [0, h], 6: [w, -h], 7: [w, 0], 8: [w, h]}

        # Classify each saved tile and remember those matching the target label.
        idx = []
        for i in range(9):
            img_path = f'./test/images/{str(i)}.png'
            pred = recog_image(img_path=img_path, best_model=best_model, trf=trf, device=device)
            res = label_to_flag[pred]
            if res == label:
                idx.append(i)

        code_img = driver.find_element(By.XPATH, '/html/body/div[7]/div[1]/div[1]/div[2]/div/div/div')
        for key in idx:
            value = location[key]
            # A fresh chain per click guarantees earlier queued actions are
            # never replayed.
            ActionChains(driver).move_to_element_with_offset(code_img, value[0], value[1]).click().perform()
            time.sleep(1)

        # The confirm / refresh buttons are not clicked by this cell.
        time.sleep(5)
finally:
    # Single shutdown point: ends the session exactly once, also when an
    # exception is raised above.
    driver.quit()