# question

## pdf

In [None]:
import os
from glob import glob 
import pandas as pd
from pdf2image import convert_from_path
from datetime import datetime

class QuePdf:
    """Render question PDFs to one PNG per page and keep a log of processed PDFs.

    The log is a pandas DataFrame pickled at ``LOG_PATH`` with one row per
    processed PDF (columns ``pdf_path`` and ``created_time``). PDFs already
    present in the log are skipped by ``pdf2papers``.
    """

    # Pickled processing-log location (class attribute so it can be
    # overridden, e.g. in tests, without editing setup_log).
    LOG_PATH = '/home/ryh/embedding-match/ocr/datasets/quePdf.pickle'
    # Question-PDF glob template, filled with (year, subject, press).
    PDF_PATTERN = '/home/ryh/embedding-match/ocr/datasets/taiwan/13to15/press/%s/%s/%s/question/pdf/*.pdf'

    def __init__(self):
        self.key = 'pdf_path'  # log column used as the unique key
        self.setup_log()

    def setup_log(self):
        """Load the processing log, or start an empty one if it does not exist yet."""
        self.log_path = self.LOG_PATH
        self.log_columns = [self.key, 'created_time']
        try:
            self.log = pd.read_pickle(self.log_path)
        except FileNotFoundError:
            # Only a missing file means "first run". Other failures (corrupt
            # pickle, permissions) propagate: silently starting from an empty
            # log would let the next save_log() overwrite the real history.
            self.log = pd.DataFrame(columns=self.log_columns)
        self.path_set = set(self.log[self.key].values)

    def update_log(self, pdf_path):
        """Append a row for ``pdf_path`` stamped with the current time (in memory only).

        The new row keeps index 0 until save_log() renumbers the log.
        """
        row = pd.DataFrame(columns=self.log_columns)
        row[self.key] = [pdf_path]
        row['created_time'] = [datetime.now()]

        # DataFrame.append was removed in pandas 2.0; concat has the same
        # semantics here. Skip concat for an empty log to avoid pandas'
        # empty-frame dtype FutureWarning.
        self.log = row if self.log.empty else pd.concat([self.log, row])
        self.path_set.add(pdf_path)

    def save_log(self):
        """Sort the log by path, renumber its index 0..n-1 and pickle it to ``log_path``."""
        self.log = self.log.sort_values(by=self.key, ascending=True).reset_index(drop=True)
        self.log.to_pickle(self.log_path)

    def get_pdf_path_list(self, year='*', subject='*', press='*'):
        """Return the sorted question-PDF paths for year/subject/press ('*' matches any)."""
        pattern = self.PDF_PATTERN % (year, subject, press)
        return sorted(glob(pattern))

    @staticmethod
    def _paper_dirname(pdf_dirname):
        """Map a PDF directory to the directory that receives its page images.

        Only a final ``pdf`` path component is swapped for ``paper``
        (``.../question/pdf`` -> ``.../question/paper``). Previously every
        'pdf' substring anywhere in the path was rewritten. Other layouts
        keep the old substring replacement.
        """
        parent, leaf = os.path.split(pdf_dirname)
        if leaf == 'pdf':
            return os.path.join(parent, 'paper')
        return pdf_dirname.replace('pdf', 'paper')

    def pdf2papers(self, pdf_path):
        """Render every page of ``pdf_path`` to PNG and record the PDF in the log.

        Pages are written as ``<stem>---NN.png`` (NN = 1-based page number,
        zero-padded to two digits) into the sibling ``paper`` directory.
        PDFs already in the log are skipped. The log is not saved here;
        call save_log().
        """
        if pdf_path in self.path_set:
            return

        images = convert_from_path(pdf_path)

        old_dirname, old_basename = os.path.split(pdf_path)
        new_dirname = self._paper_dirname(old_dirname)
        # Strip only the trailing extension, not every '.pdf' in the name.
        stem = old_basename[:-len('.pdf')] if old_basename.endswith('.pdf') else old_basename

        os.makedirs(new_dirname, exist_ok=True)
        for page_no, image in enumerate(images, start=1):
            print(page_no, len(images), end='\r')
            img_path = os.path.join(new_dirname, stem + '---%02d' % page_no + '.png')
            image.save(img_path, 'PNG')

        self.update_log(pdf_path)
        
    
        
    

### steps

In [None]:
# Create the PDF renderer; this loads the processing log (or starts an empty one).
quePdf = QuePdf()

In [None]:
# Sorted question-PDF paths for year '108', subject '數學' (mathematics), press '康軒';
# the count is displayed as the cell output.
pdf_path_list = quePdf.get_pdf_path_list(year='108', subject='數學', press='康軒')
len(pdf_path_list)


In [None]:
# Choose the PDF to render.
# NOTE(review): the value taken from pdf_path_list is discarded by the
# hard-coded assignment below (the first line still raises IndexError when
# the list is empty). The active path lives under answer/pdf, so pdf2papers
# writes its pages to answer/paper, while QuePaper.get_paper_path_list (called
# later with this pdf_path) searches question/paper. Confirm this is intended.
pdf_path = pdf_path_list[0]
# pdf_path = '/home/ryh/embedding-match/ocr/datasets/taiwan/13to15/press/108/數學/康軒/question/pdf/108上[康軒]國中試卷-(三)數學-B卷-(中上)-(題).pdf'
pdf_path = '/home/ryh/embedding-match/ocr/datasets/taiwan/13to15/press/108/數學/南一/answer/pdf/108下-南一國中試卷數學(4)D卷答.pdf'


In [None]:
# Render every page of pdf_path to PNG (no-op if the PDF is already logged).
quePdf.pdf2papers(pdf_path)

In [None]:
# Persist the processing log; pdf2papers only updates it in memory.
quePdf.save_log()

In [None]:
# Display the processing log.
quePdf.log


## paper

### class

In [None]:
from google.cloud import vision
import io
from glob import glob 
import os
import re
from datetime import datetime
import pandas as pd
# Show up to 500 rows when a DataFrame (e.g. the paper log) is displayed.
pd.set_option('display.max_rows', 500)

# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches

class QuePaper:
    """Find question-number anchors on exam page images ("papers") with OCR.

    Every processed page becomes one row of a pickled pandas log.
    ``paper2segments`` fills ``paper_path``, ``created_time``, ``paper_size``,
    ``vertical_list``, ``q_list`` and ``page``. The other columns start as the
    placeholder ``'x'`` and are filled later by ``check_q_list``,
    ``check_page``, ``update_box_list`` and ``update_name_list``. A value of
    ``'to_check'`` marks a row that needs manual review; manual fixes go
    through ``update_q_list_check_`` / ``update_page_check_``.

    All coordinates (``x_float``/``y_float`` and box corners) are fractions of
    the page width/height.
    """

    # Pickled log location (class attribute so it can be overridden, e.g. in tests).
    LOG_PATH = '/home/ryh/embedding-match/ocr/datasets/quePaper.pickle'
    # Page-image glob template, filled with (year, subject, press, pdf stem).
    PAPER_PATTERN = '/home/ryh/embedding-match/ocr/datasets/taiwan/13to15/press/%s/%s/%s/question/paper/%s*.png'
    # Columns of a freshly created log.
    LOG_COLUMNS = ['paper_path', 'created_time', 'paper_size', 'vertical_list', 'q_list',
                   'q_list_check', 'page', 'page_check', 'box_list', 'name_list', 'is_split']
    # Question-number candidate: "1." .. "29." followed by two more characters.
    # Raw strings: the former non-raw '\.' / '\(' escapes raise
    # DeprecationWarning/SyntaxWarning on recent Pythons (pattern text unchanged).
    Q_NUM_RE = re.compile(r'^([1-9]|[1-2][0-9])(\.)(.)(.)')
    # Page label such as "3-2", "(12-1" or "{4–1" (hyphen or en dash).
    PAGE_RE = re.compile(r'^(\(|\{)?([1-9]|1[0-9])(-|–)([1-2])')

    def __init__(self):
        self.key = 'paper_path'  # log column used as the unique key
        self.vision_client = vision.ImageAnnotatorClient()
        self.setup_log()

    def get_paper_path_list(self, year='*', subject='*', press='*', pdf_path=''):
        """Return sorted page-image paths, optionally only those rendered from ``pdf_path``.

        '[' and ']' in the PDF file name become '*' because glob would
        otherwise read them as a character class.
        """
        stem = os.path.basename(pdf_path)
        stem = stem.replace('.pdf', '').replace('[', '*').replace(']', '*')
        return sorted(glob(self.PAPER_PATTERN % (year, subject, press, stem)))

    def setup_log(self):
        """Load the paper log; create and persist an empty one if the file does not exist."""
        self.log_path = self.LOG_PATH
        try:
            self.log = pd.read_pickle(self.log_path)
        except FileNotFoundError:
            # Only a missing file means "first run". The former bare except
            # also caught a corrupt/unreadable log and immediately overwrote
            # it with an empty one below.
            self.log = pd.DataFrame(columns=self.LOG_COLUMNS)
            self.log.to_pickle(self.log_path)
        self.path_set = set(self.log[self.key].values)

    def update_log(self):
        """Append the current paper's detection results as one new log row (in memory only)."""
        row = pd.DataFrame()
        row[self.key] = [self.paper_path]
        row['created_time'] = [datetime.now()]
        row['paper_size'] = [(self.paper_width, self.paper_height)]
        row['vertical_list'] = [self.vertical_list]
        row['q_list'] = [self.sort_q_list]
        row['q_list_check'] = ['x']
        row['page'] = [self.page]
        row['page_check'] = ['x']
        row['box_list'] = ['x']
        row['name_list'] = ['x']
        row['is_split'] = ['x']

        # DataFrame.append was removed in pandas 2.0. The new row keeps index 0
        # until save_log() renumbers; an empty log is simply replaced to avoid
        # pandas' empty-frame concat FutureWarning.
        self.log = row if self.log.empty else pd.concat([self.log, row])
        self.path_set.add(self.paper_path)

    def save_log(self):
        """Sort the log by path, renumber its index 0..n-1 and pickle it."""
        self.log = self.log.sort_values(by=self.key, ascending=True).reset_index(drop=True)  # cost time
        self.log.to_pickle(self.log_path)  # cost time

    def paper2segments(self, paper_path):
        """OCR one page image and log its question anchors; pages already logged are skipped."""
        if paper_path in self.path_set:
            return

        self.setup_paper(paper_path)
        # setup_page() only assigns self.page when it finds a label, so reset
        # it per paper: otherwise an unlabelled page inherited the previous
        # page's label (or raised AttributeError on the very first page).
        self.page = None

        self.ann_list = self.detect_text(self.paper_byte)
        self.q_list = self.get_q_list(self.ann_list)
        self.filter_q_list = self.get_filter_q_list(self.q_list)
        self.vertical_list = self.get_vertical_list(self.filter_q_list)
        self.align_q_list = self.get_align_q_list(self.filter_q_list)
        self.sort_q_list = self.get_sort_q_list(self.align_q_list)

        self.update_log()

    def setup_paper(self, paper_path):
        """Load the page image's raw bytes (for OCR) and its pixel width/height."""
        self.paper_path = paper_path
        if self.paper_path in self.path_set:
            return
        with open(self.paper_path, 'rb') as f:
            self.paper_byte = f.read()
        self.paper_arr = mpimg.imread(self.paper_path)
        self.paper_height, self.paper_width = self.paper_arr.shape[0], self.paper_arr.shape[1]

    def detect_text(self, paper_byte):
        """Run Cloud Vision text detection and return the per-word annotations.

        Raises:
            RuntimeError: if the Vision API reports an error for the request.
        """
        image = vision.Image(content=paper_byte)
        response = self.vision_client.text_detection(image=image)
        # Vision reports request failures in response.error instead of
        # raising; without this check a failed call looked like a blank page.
        if response.error.message:
            raise RuntimeError(response.error.message)
        # The first annotation is the whole-page text block; keep the words.
        return response.text_annotations[1:]

    def get_q_list(self, ann_list):
        """Reduce annotations to {'x_float', 'y_float', 'text'} using each box's top-left corner."""
        q_list = []
        for ann in ann_list:
            top_left = ann.bounding_poly.vertices[0]
            q_list.append({
                'x_float': top_left.x / self.paper_width,
                'y_float': top_left.y / self.paper_height,
                'text': ann.description,
            })
        return q_list

    def get_filter_q_list(self, q_list):
        """Keep tokens that look like question numbers, tagging each with its parsed ``num``.

        ``words`` joins the token with the next four so a number split across
        several OCR tokens still matches. Every token is also offered to
        setup_page() so the page label is picked up in the same pass.
        """
        filter_q_list = []
        for i, q in enumerate(q_list):
            words = ''.join(x['text'] for x in q_list[i:i+5])
            if (self.is_valid_words(words)
                    and self.is_valid_x_position(q['x_float'])
                    and self.is_pre_valid(i, q_list)):
                q['num'] = self.get_num(q['text'])
                filter_q_list.append(q)
            # Side effect: may update self.page.
            self.setup_page(q['y_float'], words)
        return filter_q_list

    def is_valid_words(self, words):
        """Return True if ``words`` starts with a question number "1." .. "29." plus two characters."""
        return self.Q_NUM_RE.search(words) is not None

    def is_valid_x_position(self, x_float):
        """Return True for x in (0, 0.15) or (0.49, 0.6): the left margin of either column."""
        return (0.15 > x_float > 0) or (0.6 > x_float > 0.49)

    def is_pre_valid(self, i, q_list):
        """Reject token ``i`` when it sits right next to the previous token in the same column.

        Same column means both are left of 0.3 or both right of 0.5. "Right
        next to" means within 0.05 in x and 0.03 in y. Such a token is
        treated as running text unless the previous token is ')' (presumably
        the answer bracket printed before a question number).
        """
        if i == 0:
            return True
        pre = q_list[i-1]
        cur = q_list[i]
        pre_x_float, pre_y_float = pre['x_float'], pre['y_float']
        cur_x_float, cur_y_float = cur['x_float'], cur['y_float']
        x_diff, y_diff = abs(pre_x_float - cur_x_float), abs(pre_y_float - cur_y_float)

        same_side = (cur_x_float < 0.3 and pre_x_float < 0.3) or (cur_x_float > 0.5 and pre_x_float > 0.5)
        if same_side and x_diff < 0.05 and y_diff < 0.03 and pre['text'] != ')':
            return False
        return True

    def get_num(self, text):
        """Parse the leading question number of ``text`` ("12." -> 12); 999 if unparsable."""
        try:
            if '.' in text:
                return int(text[:text.index('.')])
            return int(text)
        except ValueError:
            return 999  # sentinel for "could not parse"

    def get_vertical_list(self, filter_q_list):
        """Return each column's left edge: [x_left] or [x_left, x_right], rounded to 3 places.

        Anchors with x_float < 0.45 belong to the left column. Returns [] when
        there are no anchors (e.g. a cover or blank page) instead of raising.
        Still raises ValueError when only right-column anchors were found.
        """
        if not filter_q_list:
            return []
        L = [q['x_float'] for q in filter_q_list if q['x_float'] < 0.45]
        R = [q['x_float'] for q in filter_q_list if q['x_float'] >= 0.45]
        x_L = min(L)
        if R:
            return [round(x_L, 3), round(min(R), 3)]
        return [round(x_L, 3)]

    def get_align_q_list(self, filter_q_list):
        """Snap each anchor's x_float (in place) to its column edge from self.vertical_list."""
        align_q_list = []
        for q in filter_q_list:
            if q['x_float'] < 0.45:
                q['x_float'] = self.vertical_list[0]
            else:
                # A right-column anchor exists, so vertical_list has two entries.
                q['x_float'] = self.vertical_list[1]
            align_q_list.append(q)
        return align_q_list

    def get_sort_q_list(self, align_q_list):
        """Return anchors ordered column by column (x), then top to bottom (y)."""
        return sorted(align_q_list, key=(lambda q: (q['x_float'], q['y_float'])), reverse=False)

    def setup_page(self, y_float, words):
        """Store a page label (e.g. '3-2') in self.page when ``words`` starts with one.

        Only tokens in the top or bottom 20% of the page are considered.
        """
        if y_float < 0.2 or y_float > 0.8:
            match = self.PAGE_RE.search(words)
            if match is not None:
                self.page = match.group(0).replace('–', '-').replace('(', '').replace('{', '')

    def get_q_list_check(self, q_list, remove_q_index_list=None, add_q_list=None):
        """Return a manually corrected, re-sorted anchor list.

        Args:
            q_list: detected anchors.
            remove_q_index_list: positions in ``q_list`` to drop.
            add_q_list: anchors to add (dicts with 'num', 'x_float', 'y_float', ...).
        """
        remove = set(remove_q_index_list or ())
        kept = [q for index, q in enumerate(q_list) if index not in remove]
        return self.get_sort_q_list(kept + list(add_q_list or ()))

    def check_q_list(self):
        """For rows still at 'x', set q_list_check to the q_list if consistent, else 'to_check'; save."""
        df = self.log[self.log.q_list_check == 'x'].copy()
        if df.empty:
            return
        df['q_list_check'] = df.apply(lambda row: self.update_q_list_check(row), axis=1)
        self.log.loc[df.index, 'q_list_check'] = df['q_list_check']
        self.save_log()

    def is_valid_q_list(self, q_list):
        """Return True if each number is the previous plus one, or restarts at 1."""
        for pre, cur in zip(q_list, q_list[1:]):
            if cur['num'] - pre['num'] != 1 and cur['num'] != 1:
                return False
        return True

    def update_q_list_check(self, row):
        """Return the row's q_list if its numbering is consistent, else 'to_check'."""
        q_list = row['q_list']
        return q_list if self.is_valid_q_list(q_list) else 'to_check'

    def update_q_list_check_(self, index=0, q_list_check=None):
        """Manually set q_list_check for log row ``index`` and save the log."""
        self.log.at[index, 'q_list_check'] = q_list_check
        self.save_log()

    def check_page(self):
        """For rows still at 'x', set page_check to the confirmed label or 'to_check'; save."""
        df = self.log[self.log.page_check == 'x'].copy()
        if df.empty:
            return
        df['page_check'] = df.apply(lambda row: self.update_page_check(row), axis=1)
        self.log.loc[df.index, 'page_check'] = df['page_check']
        self.save_log()

    def is_valid_page(self, pre_page='3-2', cur_page='4-1'):
        """Return True if label ``cur_page`` directly follows ``pre_page``.

        Labels are 'A-B' with B in {1, 2}; valid steps are A-1 -> A-2 and
        A-2 -> (A+1)-1 (e.g. 3-1 -> 3-2, 3-2 -> 4-1). A missing (None) or
        malformed label returns False instead of raising.
        """
        try:
            pre_1, pre_2 = map(int, pre_page.split('-'))
            cur_1, cur_2 = map(int, cur_page.split('-'))
        except (AttributeError, ValueError):
            return False
        return (cur_1 - pre_1, cur_2 - pre_2) in ((1, -1), (0, 1))

    def update_page_check(self, row):
        """Return the confirmed page label for ``row`` or 'to_check'.

        A label is accepted when it follows the previous row's raw ``page``,
        or when it is '1-1' (presumably the first page of another document).
        Row 0 must be '1-1'. Assumes the 0..n-1 index that save_log() leaves.
        """
        i = row.name
        cur_page = self.log.loc[i, 'page']
        if i == 0:
            return cur_page if cur_page == '1-1' else 'to_check'
        pre_page = self.log.loc[i-1, 'page']
        if self.is_valid_page(pre_page, cur_page) or cur_page == '1-1':
            return cur_page
        return 'to_check'

    def update_page_check_(self, index=0, page_check=None):
        """Manually set page_check for log row ``index`` and save the log."""
        self.log.at[index, 'page_check'] = page_check
        self.save_log()

    def update_box_list(self):
        """Compute box_list for rows with a resolved q_list_check and save the log."""
        c1 = (self.log.box_list == 'x')
        c2 = (self.log.q_list_check != 'x')
        c3 = (self.log.q_list_check != 'to_check')
        df = self.log[c1 & c2 & c3].copy()
        if df.empty:
            return
        df['box_list'] = df.apply(lambda row: self.get_box_list(row), axis=1)
        self.log.loc[df.index, 'box_list'] = df['box_list']
        self.save_log()

    def get_box_list(self, row):
        """Turn the row's checked anchors into crop boxes {'x1', 'x2', 'y1', 'y2'}.

        Each box runs from 0.013 above its anchor to 0.003 above the next
        anchor in the same column (y = 0.95 for the column's last anchor).
        Two-column pages: left boxes span x_L..x_R and right boxes x_R..1.
        Single-column pages: from the anchor's x to 1.
        """
        up_space, down_space, bottom = 0.013, 0.003, 0.95
        q_list_check = row.q_list_check
        box_list, L, R = [], [], []
        for q in q_list_check:
            if q['x_float'] < 0.45:
                L.append(q)
            else:
                R.append(q)

        if len(R) != 0:
            x_L, x_R = row.vertical_list
            for i in range(len(L)):
                cur = L[i]
                y2 = bottom if i == len(L)-1 else L[i+1]['y_float']
                box_list.append({'x1': x_L, 'x2': x_R, 'y1': cur['y_float']-up_space, 'y2': y2-down_space})
            for i in range(len(R)):
                cur = R[i]
                y2 = bottom if i == len(R)-1 else R[i+1]['y_float']
                box_list.append({'x1': x_R, 'x2': 1, 'y1': cur['y_float']-up_space, 'y2': y2-down_space})
        else:
            for i in range(len(L)):
                cur = L[i]
                y2 = bottom if i == len(L)-1 else L[i+1]['y_float']
                box_list.append({'x1': cur['x_float'], 'x2': 1, 'y1': cur['y_float']-up_space, 'y2': y2-down_space})
        return box_list

    def update_name_list(self):
        """Compute name_list for rows with a resolved q_list_check and save the log."""
        c1 = (self.log.name_list == 'x')
        c2 = (self.log.q_list_check != 'x')
        c3 = (self.log.q_list_check != 'to_check')
        df = self.log[c1 & c2 & c3].copy()
        if df.empty:
            return
        df['name_list'] = df.apply(lambda row: self.get_name_list(row), axis=1)
        self.log.loc[df.index, 'name_list'] = df['name_list']
        self.save_log()

    def get_name_list(self, row):
        """Name each anchor on this page '<A>-<group>-<num>', where '<A>' is page_check before '-'.

        The page is paired with its partner row (i+1 for an 'A-1' page, i-1
        for an 'A-2' page; presumably the two sides of one sheet) so the group
        counter of num_list_to_name_list runs across both.

        Returns 'x' (still pending, retried by the next update_name_list) when
        this page's label or either q_list_check is unresolved. Returns
        'to_check' when the partner row is missing or does not match. Both
        cases previously raised.
        """
        i = row.name
        page_check = row.page_check
        if not isinstance(page_check, str):
            return 'x'
        if page_check.endswith('-1'):
            i1, i2 = i, i + 1
        elif page_check.endswith('-2'):
            i1, i2 = i - 1, i
        else:
            return 'x'  # page label still 'x' / 'to_check'
        if i1 not in self.log.index or i2 not in self.log.index:
            return 'to_check'
        p1, p2 = self.log.loc[i1], self.log.loc[i2]
        if not (isinstance(p1.q_list_check, list) and isinstance(p2.q_list_check, list)):
            return 'x'
        if not self.is_valid_pair(p1, p2):
            return 'to_check'
        q_list = p1.q_list_check + p2.q_list_check
        name_list = self.num_list_to_name_list([q['num'] for q in q_list])
        n1 = len(p1.q_list_check)
        name_list_ = name_list[:n1] if i1 == i else name_list[n1:]

        prefix = page_check.split('-')[0]
        return [(prefix + '-' + name) for name in name_list_]

    def num_list_to_name_list(self, num_list):
        """Convert question numbers to '<group>-<num>' names.

        The group counter starts at 1 and increments whenever a number is not
        larger than its predecessor (numbering restarted), e.g.
        [1, 2, 1, 1, 2] -> ['1-1', '1-2', '2-1', '3-1', '3-2']. The first name
        now uses the actual first number (formerly a hard-coded '1-1'), and an
        empty input yields [].
        """
        name_list = []
        group = 1
        for k, num in enumerate(num_list):
            if k > 0 and num_list[k-1] >= num:
                group += 1
            name_list.append('%s-%s' % (group, num))
        return name_list

    def is_valid_pair(self, p1, p2):
        """Return True if p1/p2 come from the same PDF and p2's label is p1's with the last '1' -> '2'."""
        c1 = (p1.paper_path.split('---')[0] == p2.paper_path.split('---')[0])  # same PDF
        c2 = (p1.page_check[:-1] + p1.page_check[-1:].replace('1', '2')) == p2.page_check  # 12-1 vs 12-2
        return c1 and c2

    def plot_q(self, paper_path, q_list):
        """Show the page image with each anchor annotated as 'num --- (x, y) --- index'."""
        img = mpimg.imread(paper_path)
        img_height, img_width = img.shape[0], img.shape[1]

        d = 100
        # figsize is (width, height); it was previously passed transposed.
        fig, ax = plt.subplots(figsize=(img_width/d, img_height/d), dpi=d)

        for i, q in enumerate(q_list):
            x = q['x_float'] * img_width
            y = (q['y_float'] - 0.01) * img_height
            text = '%s --- (%.3f, %.3f) --- %s' % (q['num'], q['x_float'], q['y_float'], i)
            ax.text(x, y, text, size=d*0.15, color='red', bbox={'facecolor': 'white', 'alpha': 0.7})

        ax.imshow(img, interpolation='none')
        plt.tight_layout()
        plt.show()

    def plot_border(self, paper_path, box_list):
        """Show the page image with each crop box drawn as a green rectangle."""
        img = mpimg.imread(paper_path)
        img_height, img_width = img.shape[0], img.shape[1]

        d = 100
        # figsize is (width, height); it was previously passed transposed.
        fig, ax = plt.subplots(figsize=(img_width/d, img_height/d), dpi=d)

        for box in box_list:
            x1_, x2_ = box['x1']*img_width, box['x2']*img_width
            y1_, y2_ = box['y1']*img_height, box['y2']*img_height
            rect = patches.Rectangle((x1_, y1_), abs(x1_ - x2_), abs(y1_ - y2_),
                                     linewidth=2, edgecolor='g', facecolor='none')
            ax.add_patch(rect)

        ax.imshow(img, interpolation='none')
        plt.tight_layout()
        plt.show()
    

### steps

In [None]:
# Create the page processor. This instantiates a Cloud Vision client
# (presumably needs Google Cloud credentials in the environment) and loads
# the paper log.
quePaper = QuePaper()


In [None]:
# Page images rendered from pdf_path (any year/subject/press); the count is displayed.
paper_path_list = quePaper.get_paper_path_list(year='*', subject='*', press='*', pdf_path=pdf_path)
len(paper_path_list)


In [None]:
# OCR every page not yet in the log (one Vision API request per new page).
# NOTE(review): the log is saved only after the loop, so an exception
# mid-loop loses every row processed in this run. The loop also leaves `i`
# bound to the last index, and later cells reuse `i`.
for i in range(len(paper_path_list)):
# for i in [1]:
    print(i+1, len(paper_path_list), end='\r')
    paper_path = paper_path_list[i]
    quePaper.paper2segments(paper_path)
quePaper.save_log()


In [None]:
# Peek at the first two log rows.
quePaper.log.head(2)

### check q_list

In [None]:
# For rows whose q_list_check is still 'x': set it to the q_list when the
# numbering is consistent, otherwise 'to_check'. Saves the log.
quePaper.check_q_list()

#### original

In [None]:
# Show the first row that needs review, with its detected anchors overlaid.
log = quePaper.log
df_to_check = log[log.q_list_check=='to_check']
print(len(df_to_check))

if len(df_to_check) != 0:
    i = df_to_check.index[0]
    paper_path = quePaper.log.loc[i, 'paper_path']
    q_list = quePaper.log.loc[i, 'q_list']
    quePaper.plot_q(paper_path, q_list)

#### add or remove

In [None]:
# Manually correct the anchors of row `i`: drop entries by position and add missing ones.
# NOTE(review): if the cell above found nothing to check, `i` is stale (e.g.
# the last index of the OCR loop), and these edits would target an unrelated
# row. Confirm `i` before running.
paper_path = quePaper.log.loc[i, 'paper_path']
q_list = quePaper.log.loc[i, 'q_list']

# Per-page correction values; edit them for the row under review.
add_q_list = [{'num': 2, 'text': 'x', 'x_float': 0.064, 'y_float': 0.190}]
remove_q_index_list = []

# add_q_list = []
# remove_q_index_list = [9]

# add_q_list = [
#     {'num': 2, 'text': 'x', 'x_float': 0.493, 'y_float': 0.038},
#     {'num': 3, 'text': 'x', 'x_float': 0.493, 'y_float': 0.278},
# ]
# remove_q_index_list = [2]

q_list_check = quePaper.get_q_list_check(q_list, remove_q_index_list=remove_q_index_list, add_q_list=add_q_list)

# Print whether the corrected numbering is consistent, then plot it for a visual check.
print(quePaper.is_valid_q_list(q_list_check))
quePaper.plot_q(paper_path, q_list_check)


In [None]:
# Store the corrected list for row `i` and save the log.
quePaper.update_q_list_check_(index=i, q_list_check=q_list_check)

In [None]:
# quePaper.log

### check page

In [None]:
# For rows whose page_check is still 'x': confirm the page label against the
# previous row (label or 'to_check'). Saves the log.
quePaper.check_page()

#### revise

In [None]:
# Show the first row whose page label needs review, with its neighbouring rows and anchors.
log = quePaper.log
df_check_page = log[log.page_check=='to_check']
print('len(df_check_page)', len(df_check_page))

if len(df_check_page) != 0:
    i = df_check_page.index[0]
    print('i', i)
    # Positional slice of rows i-4 .. i+3 (clamped at 0; end is exclusive).
    d = 4
    start = i - d if (i-d) >= 0 else 0
    end = i + d
    # `display` is the IPython/Jupyter rich-output function (notebook only).
    display(log[start:end])
    
    paper_path = quePaper.log.loc[i, 'paper_path']
    q_list = quePaper.log.loc[i, 'q_list']
    quePaper.plot_q(paper_path, q_list)

In [None]:
# page_check = '11-1'
# quePaper.update_page_check_(index=i, page_check=page_check)

In [None]:
# quePaper.log

In [None]:
# Scratch arithmetic; presumably the share of pages needing manual fixes
# (40 and 10 out of 3066). The meaning is not recorded here.
40/3066, 10/3066

### update box_list

In [None]:
# Compute crop boxes for rows whose q_list_check is resolved; saves the log.
quePaper.update_box_list()

In [None]:
# Visually spot-check the crop boxes of a few rows.
# NOTE(review): assumes these rows already have a box_list. A row still
# holding the placeholder 'x' (or 'to_check') would fail inside plot_border.
for i in [22, 23, 24, 25]:
    paper_path = quePaper.log.loc[i, 'paper_path']
    box_list = quePaper.log.loc[i, 'box_list']
    quePaper.plot_border(paper_path, box_list)

In [None]:
# Display the first rows of the log.
quePaper.log.head()

### update name

In [None]:
# Quick check of the naming helper on a sample number sequence.
quePaper.num_list_to_name_list([1,2,3,4,1,1,2,3])

In [None]:
# One-off migration kept for reference: rename column 'name' to 'name_list' and reset it.
# quePaper.log = quePaper.log.rename(columns={'name': 'name_list'})
# quePaper.log['name_list'] = 'x'

In [None]:
# Compute question names for rows whose q_list_check is resolved; saves the log.
quePaper.update_name_list()

In [None]:
# Inspect the names of two adjacent rows (30 and 31), presumably both sides of one sheet.
i = 30
p1 = quePaper.log.loc[i]
p2 = quePaper.log.loc[i+1]
p1.name_list, p2.name_list

In [None]:
# Display the first rows of the log.
quePaper.log.head()

## segment

In [None]:
class QueSeg:
    """Segmentation stage of the pipeline (placeholder).

    Nothing is implemented yet: instances carry no state and expose no methods.
    """

    def __init__(self):
        """Create an empty segmenter; there is nothing to initialise yet."""

In [None]:
# Scratch values left in the notebook. Their meaning is not recorded here;
# TODO document or remove.
[43, 58, 73, 94, 98, 115]