In [1]:
# 搭一个spark的环境
import numpy as np
import pandas as pd
import cv2 
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
import os
import PIL
PIL.Image.MAX_IMAGE_PIXELS = 933120000
from pyspark.sql import SparkSession
spark=SparkSession.builder.master("local").appName("大数据大作业-预处理").getOrCreate()

sc=spark.sparkContext
sc

In [2]:
class Picture(object):
    """用类的形式保存一张图片的图片本身，灰度图像,名称等，主要是完成了小作业中的用cv2提取特征的过程"""
    def __init__(self,name,image):
        self.name=name
        self.image=image
        self.gray=cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)
    def show(self):
        """展示图片"""
        plt.imshow(cv2.cvtColor(self.image,cv2.COLOR_BGR2RGB))
        plt.pause(2)
        plt.close()
    def show_gray(self):
        """展示灰度图"""
        plt.imshow(self.gray,cmap="gray")
        plt.pause(2)
        plt.close()
    def corner_detection(self):
        """角点检测函数，返回本身的角点数量,并创建新的属性self.coners,即角点个数"""
        dst = cv2.cornerHarris(self.gray,2,3,0.04)
        #result is dilated for marking the corners, not important
        dst = cv2.dilate(dst,None)
        self.corners=np.array(np.sum(dst>0.01*dst.max()))
        return self.corners
    def face_detection(self):
        """人脸检测，返回检测到的人脸的坐标array,并创建新的属性self.faces:[人脸个数,人脸平均中心，人脸平均大小]"""
        face_cascade=cv2.CascadeClassifier("D:/Anaconda/lib/site-packages/cv2/data/haarcascade_frontalface_default.xml")
        faces=face_cascade.detectMultiScale(self.gray,1.3,5)
        if type(faces)==np.ndarray:
            temp=np.mean(faces,axis=0)
            self.faces=np.array([faces.shape[0],np.mean([temp[0],temp[2]]),np.mean([temp[1],temp[3]]),temp[2]*temp[3]])
            return self.faces
        return np.array([0,0,0,0])
    def keypoint_detection(self):
        """检测keypoint，并返回得到的属性,并把获得的重要特征进行区分"""
        # Initiate FAST detector
        star = cv2.xfeatures2d.StarDetector_create()
        # Initiate BRIEF extractor
        brief = cv2.xfeatures2d.BriefDescriptorExtractor_create()
        # find the keypoints with STAR
        kp = star.detect(self.image,None)
        # compute the descriptors with BRIEF
        kp, des = brief.compute(self.image, kp)
        if des is None:
            self.keypoint=np.array([0])
            return self.keypoint
        self.keypoint=np.array([des.shape[0]])
        return self.keypoint
    def ocr_detection(self):
        """尝试找到图片的文字区域,返回图片的文字区域个数。"""
        # 2. 形态学变换的预处理，得到可以查找矩形的图片
        def preprocess(gray):
            # 1. Sobel算子，x方向求梯度
            sobel = cv2.Sobel(gray, cv2.CV_8U, 1, 0, ksize = 3)
            # 2. 二值化
            ret, binary = cv2.threshold(sobel, 0, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
            # 3. 膨胀和腐蚀操作的核函数
            element1 = cv2.getStructuringElement(cv2.MORPH_RECT, (30, 9))
            element2 = cv2.getStructuringElement(cv2.MORPH_RECT, (24, 6))
            # 4. 膨胀一次，让轮廓突出
            dilation = cv2.dilate(binary, element2, iterations = 1)
            # 5. 腐蚀一次，去掉细节，如表格线等。注意这里去掉的是竖直的线
            erosion = cv2.erode(dilation, element1, iterations = 1)
            # 6. 再次膨胀，让轮廓明显一些
            dilation2 = cv2.dilate(erosion, element2, iterations = 3)
            # 7. 存储中间图片 
            cv2.imwrite("binary.png", binary)
            cv2.imwrite("dilation.png", dilation)
            cv2.imwrite("erosion.png", erosion)
            cv2.imwrite("dilation2.png", dilation2)
            return dilation2
        def findTextRegion(img):
            region = []
            # 1. 查找轮廓
            contours, hierarchy = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            # 2. 筛选那些面积小的
            for i in range(len(contours)):
                cnt = contours[i]
                # 计算该轮廓的面积
                area = cv2.contourArea(cnt) 
                # 面积小的都筛选掉
                if(area < 1000):
                    continue
                # 轮廓近似，作用很小
                epsilon = 0.001 * cv2.arcLength(cnt, True)
                approx = cv2.approxPolyDP(cnt, epsilon, True)
                # 找到最小的矩形，该矩形可能有方向
                rect = cv2.minAreaRect(cnt)
                # box是四个点的坐标
                box = cv2.boxPoints(rect)
                box = np.int0(box)
                # 计算高和宽
                height = abs(box[0][1] - box[2][1])
                width = abs(box[0][0] - box[2][0])
                # 筛选那些太细的矩形，留下扁的
                if(height > width * 1.2):
                    continue
                region.append(box)
            return region
        dilation = preprocess(self.gray)
        # 3. 查找和筛选文字区域
        region = findTextRegion(dilation)
        if region is  None:
            self.regions=np.array([0])
        else:
            self.regions=np.array([len(region)])
        return self.regions
    def contour_detection(self):
        """返回contours个数"""
        canny_edges = cv2.Canny(self.image, 30, 200)
        contours, hierarchy = cv2.findContours(canny_edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        self.contours=np.array([len(contours)])
        return self.contours
    def HSV_detection(self):
        """返回对HSV的三个维度的一阶矩、二阶矩、三阶矩描述"""
        hsv=cv2.cvtColor(self.image,cv2.COLOR_BGR2HSV)
        h,s,v=cv2.split(hsv)
        h_1moment=np.mean(h)
        s_1moment=np.mean(s)
        v_1moment=np.mean(v)
        h_2moment=np.std(h)
        s_2moment=np.std(s)
        v_2moment=np.std(v)
        h_3moment=(np.mean(abs(h-h.mean())**3))**(1/3)
        s_3moment=(np.mean(abs(s-s.mean())**3))**(1/3)
        v_3moment=(np.mean(abs(v-v.mean())**3))**(1/3)
        self.hsv=np.array([h_1moment,h_2moment,h_3moment,s_1moment,s_2moment,s_3moment,v_1moment,v_2moment,v_3moment])
        return self.hsv
    def detections(self):
        """对不同的detections进行聚合操作，获得图片的特征，返回一个ndarray列表。"""
        temp=np.array([])   
        temp=np.append(temp,self.corner_detection())
        temp=np.append(temp,self.face_detection())
        temp=np.append(temp,self.keypoint_detection())
        temp=np.append(temp,self.contour_detection())
        temp=np.append(temp,self.HSV_detection())
        temp=np.append(temp,self.ocr_detection())
        return temp

In [3]:
def Process(image_name,train_or_test="train"):
    """对一张图片进行处理，把他的人工特征提取出来,返回Picture类,并顺便把这些图片保存到mydata方便之后训练。
    其中读取过程优先使用PIL，从而保证可以读取动图等异常图片的前几部分。
    """
    def Read_cv2(image_name,train_or_test="train",is_exists_before=False):
        """读取图片，返回opencv格式的图片"""
        if is_exists_before:
            return cv2.imread(write_full_path)
        read_full_path=train_or_test+"/"+image_name
        if  os.path.isdir(read_full_path):
            return None
        try:
            img_data = Image.open(read_full_path)
            if type(img_data)==PIL.GifImagePlugin.GifImageFile or type(img_data)==PIL.PngImagePlugin.PngImageFile:
                img_data=img_data.convert("RGB")
            img_data=np.asanyarray(img_data)
            img_data=cv2.cvtColor(img_data,cv2.COLOR_RGB2BGR)
            return img_data
        except BaseException as e:
            print(str(e),"尝试直接使用cv2读取文件...")
            img_data=cv2.imread(read_full_path)
            return img_data
            
    write_full_path="mydata/"+train_or_test+"/"+image_name# 图片写入的路径和名字train_or_test 是 "train" 或 "test"
    is_exists_before=os.path.exists(write_full_path)# 如果已经存在这个图片，就不多此一举了
    if is_exists_before:
        img_data=Read_cv2(image_name,train_or_test,is_exists_before=True)# 只读取
    else:# 如果路径下没有相应图片，说明该图片没有处理过，应该一顿处理，然后放到相应文件夹下面
        img_data=Read_cv2(image_name,train_or_test,is_exists_before=False)
        img_data=cv2.resize(img_data,(300,300),interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(write_full_path,img_data)   
    # 下面是说：返回这个图片类
    picture=Picture(name=image_name[:-4],image=img_data)
    return picture
    

In [4]:
columns=["product_id","coner_numbers","face_numbers","face_x","faces_y","face_size","keypoint",
"contour_numbers","h1","h2","h3","s1","s2","s3","v1","v2","v3","text_regions"]
df_features_by_human=pd.DataFrame([],columns=columns)# 最后生成的dataframe，用来表明所有图片的用手工提取的属性。
image_names=os.listdir("train")
# 循环处理图片，包括把图片手工特征提取出来以及把图片裁剪一下放到mydata文件夹下面
for i,image_name in enumerate(image_names):
    picture=Process(image_name,train_or_test="train")
    features=picture.detections()
    data=np.append([picture.name],features,axis=0)
    data=pd.DataFrame(data.reshape(1,len(data)),columns=columns)
    df_features_by_human=df_features_by_human.append(data)
    if i%800==0:
        print("第",i,"张图片正在处理...")

df_features_by_human.head()

第 0 张图片正在处理...
第 800 张图片正在处理...


  "Palette images with Transparency expressed in bytes should be "


第 1600 张图片正在处理...
第 2400 张图片正在处理...
第 3200 张图片正在处理...
第 4000 张图片正在处理...
第 4800 张图片正在处理...
第 5600 张图片正在处理...
第 6400 张图片正在处理...
第 7200 张图片正在处理...


Unnamed: 0,product_id,coner_numbers,face_numbers,face_x,faces_y,face_size,keypoint,contour_numbers,h1,h2,h3,s1,s2,s3,v1,v2,v3,text_regions
0,1,47168.0,0.0,0.0,0.0,0.0,248.0,529.0,37.42097777777778,41.83535771913986,53.564119841293504,227.6955888888889,60.90090685776077,88.4149539885625,159.27947777777777,106.88535931836584,112.1505354942421,1.0
0,10,5800.0,0.0,0.0,0.0,0.0,69.0,65.0,29.012355555555555,32.32981256786961,48.50907983309518,40.68761111111111,48.03626248140147,69.6923422603402,222.0541333333333,25.91577045875949,37.44877274753828,1.0
0,100,7118.0,0.0,0.0,0.0,0.0,61.0,52.0,24.50458888888889,23.621974728062213,40.849771170394845,40.113455555555554,47.36650322577035,71.08319336977125,223.1409555555556,23.32328065303522,34.6235543344856,2.0
0,1000,1102.0,0.0,0.0,0.0,0.0,56.0,9.0,17.91017777777778,43.23500020908414,58.62402304267309,49.04227777777778,85.87807404383493,98.1411203586388,201.7015111111111,76.75561039171946,100.70790184201356,2.0
0,1001,981.0,0.0,0.0,0.0,0.0,56.0,9.0,5.817122222222222,10.957340013107173,13.397384503878452,36.318133333333336,73.03100431448429,93.9597658909266,192.06744444444445,79.73796061072808,98.4170580126663,2.0


In [5]:
df_features_by_human.to_csv("train_人工提取特征.csv")

In [6]:
# 对test进行类似操作
columns=["product_id","coner_numbers","face_numbers","face_x","faces_y","face_size","keypoint",
"contour_numbers","h1","h2","h3","s1","s2","s3","v1","v2","v3","text_regions"]
df_features_by_human=pd.DataFrame([],columns=columns)
image_names=os.listdir("test")
for i,image_name in enumerate(image_names):
    picture=Process(image_name,train_or_test="test")
    features=picture.detections()
    data=np.append([picture.name],features,axis=0)
    data=pd.DataFrame(data.reshape(1,len(data)),columns=columns)
    df_features_by_human=df_features_by_human.append(data)
    if i%800==0:
        print("第",i,"张图片正在处理...")

df_features_by_human.head()

第 0 张图片正在处理...
第 800 张图片正在处理...
第 1600 张图片正在处理...


Unnamed: 0,product_id,coner_numbers,face_numbers,face_x,faces_y,face_size,keypoint,contour_numbers,h1,h2,h3,s1,s2,s3,v1,v2,v3,text_regions
0,10000,16003.0,0.0,0.0,0.0,0.0,194.0,577.0,40.82513333333333,56.26588466561638,61.52987582843196,35.55336666666667,67.13927610902081,87.47258109034925,49.25275555555555,75.07420287190108,90.3119731404843,1.0
0,8001,2801.0,0.0,0.0,0.0,0.0,65.0,71.0,11.091033333333334,34.70390934813156,54.40594152053097,19.957844444444444,42.54721247715062,54.401014352361365,42.95611111111111,90.21224766804664,112.12642647082195,3.0
0,8002,3512.0,0.0,0.0,0.0,0.0,79.0,78.0,17.035255555555555,37.99694039351096,47.94449725812648,7.104588888888889,22.25126126382475,37.236187944676416,38.04937777777778,80.5661989067345,99.96325739655798,3.0
0,8003,2880.0,1.0,93.5,87.0,7921.0,50.0,77.0,13.701255555555557,30.05876220156524,38.8548018792236,8.163866666666667,26.719803994291077,46.73273434043582,40.32233333333333,86.60116647341162,107.67151910702783,2.0
0,8004,3370.0,0.0,0.0,0.0,0.0,75.0,105.0,7.348277777777778,19.967390146340428,32.20081817664204,10.86438888888889,30.14526125767049,46.8720324100535,43.809955555555554,92.4571659069949,114.08110508021488,3.0


In [7]:
df_features_by_human.to_csv("test_人工提取特征.csv")

In [8]:
df_data=spark.read.csv("data.csv",header=True)
df_data.show()

+----------+----------------+-----+------------------+--------+--------+-------------+--------------------+
|product_id|product_category|brand|             price|quantity|favorite|negative_info|           image_url|
+----------+----------------+-----+------------------+--------+--------+-------------+--------------------+
|         1|               2|    2|127778.66305461951|       1|    1056|            0|https://lh3.googl...|
|         2|               1|    1|           3250.34|       1|       2|            1|https://lh3.googl...|
|         3|               1|    1| 977.9385000000001|       1|       2|            1|https://lh3.googl...|
|         4|               1|    1| 977.9385000000001|       1|       2|            1|https://lh3.googl...|
|         5|               1|    1|1776.1040000000003|       1|       0|            0|https://lh3.googl...|
|         6|               1|    1|1776.1040000000003|       1|       0|            0|https://lh3.googl...|
|         7|               1

In [9]:
df_data_trainwithval=df_data.sample(fraction=0.75,seed=0)
df_data_valwithval=df_data.subtract(df_data_trainwithval)
df_data_trainwithval.show(10)
df_data_valwithval.show(10)

+----------+----------------+-----+------------------+--------+--------+-------------+--------------------+
|product_id|product_category|brand|             price|quantity|favorite|negative_info|           image_url|
+----------+----------------+-----+------------------+--------+--------+-------------+--------------------+
|         2|               1|    1|           3250.34|       1|       2|            1|https://lh3.googl...|
|         3|               1|    1| 977.9385000000001|       1|       2|            1|https://lh3.googl...|
|         5|               1|    1|1776.1040000000003|       1|       0|            0|https://lh3.googl...|
|         7|               1|    1| 6825.714000000001|       1|       0|            0|https://lh3.googl...|
|         8|               1|    1| 6825.714000000001|       1|       0|            0|https://lh3.googl...|
|         9|               1|    1|           1625.17|       1|       2|            0|https://lh3.googl...|
|        10|               1

## 再创造一个类别出来，用来给pytorch预训练

In [10]:
data_trainwithval=df_data_trainwithval.toPandas()
data_valwithval=df_data_valwithval.toPandas()

In [11]:
# 观察price数据
import matplotlib.pyplot as plt 
from pyspark.sql.types import FloatType
from pyspark.sql.functions import log
df_data=df_data.withColumn("price",df_data["price"].cast(FloatType()))
hists=df_data.withColumn("price",log(df_data["price"]+1)).select("price").rdd.flatMap(
    lambda row:row
).histogram(20)
# 这里尝试对price进行log 然后分组，尽量均匀的分20组供pytorch训练
bound=np.exp(hists[0])*np.exp(-1)

In [12]:
df_test=spark.read.csv("test.csv",header=True)
df_test=df_test.withColumn("price",df_test["price"].cast(FloatType()))
df_data=df_data.withColumn("price",df_data["price"].cast(FloatType()))


In [13]:
write_path="mydata/"
def category(price):
    """利用之前的bound上下界给所有图片加一个类别标签"""
    for i,b in enumerate(bound):
        if price<b:
            return i
    return len(bound)
def move_data(id,train_test,train_val,price):
    """按照之前分好的类，对图片进行移动"""
    cate=category(float(price))
    img=cv2.imread(write_path+train_test+"/"+str(id)+".jpg")
    if img is None :
        return 0
    if train_test=="train" and train_val=="train":
            cv2.imwrite(write_path+"train_with_val/"+str(cate)+"/"+str(id)+".jpg",img)
            return 0
    elif train_test=="train" and train_val=="val":
            cv2.imwrite(write_path+"val_with_val/"+str(cate)+"/"+str(id)+".jpg",img)
            return 0
    else:
        cv2.imwrite(write_path+"test_with_val/"+str(cate)+"/"+str(id)+".jpg",img)
        return 0

# 正式对图片分类+移动到mydata/train_with_val的操作
for index in data_trainwithval.index:
    move_data(id=data_trainwithval.loc[index,"product_id"],train_test="train",train_val="train",price=data_trainwithval.loc[index,"price"])
    if index%800==0:
        print("train",index)
for index in data_valwithval.index:
    move_data(id=data_valwithval.loc[index,"product_id"],train_test="train",train_val="val",price=data_valwithval.loc[index,"price"])
    if index%800==0:
        print("val",index)
test_data=pd.read_csv("test.csv")
for index in test_data.index:
    move_data(id=test_data.loc[index,"product_id"],train_test="test",train_val="train",price=test_data.loc[index,"price"])
    if index%800==0:
        print("test",index)

train 0
train 800
train 1600
train 2400
train 3200
train 4000
train 4800
train 5600
val 0
val 800
val 1600
test 0
test 800
test 1600


In [14]:
test_data

Unnamed: 0,product_id,product_category,brand,price,quantity,favorite,negative_info,image_url
0,8001,3,4,75.413360,1,0,0,https://lh3.googleusercontent.com/BAvQgUsNDz_C...
1,8002,3,4,3.139110,1,2,0,https://lh3.googleusercontent.com/KRKwIkBz6nT_...
2,8003,3,4,252.789600,1,38,0,https://lh3.googleusercontent.com/o_bWO6YIzbQ2...
3,8004,3,4,0.784100,1,0,2,https://lh3.googleusercontent.com/ANaCEfjGaKOJ...
4,8005,3,4,1.443310,1,0,0,https://lh3.googleusercontent.com/NojrW0o82kbt...
...,...,...,...,...,...,...,...,...
1995,9996,1,3,90.585810,1,0,0,https://lh3.googleusercontent.com/d1b23HJPTiqJ...
1996,9997,1,3,0.177616,1,0,0,https://lh3.googleusercontent.com/vF4b9nK-WMWo...
1997,9998,1,3,41.532400,1,0,0,https://lh3.googleusercontent.com/nvOOLQpY7MXK...
1998,9999,1,3,41.532400,1,0,0,https://lh3.googleusercontent.com/yV44pocJq9Vl...
