In [1]:
# Auto-reload edited project modules (e.g. method.ours) before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import openai
from dotenv import load_dotenv

from selenium.webdriver.common.by import By
from bs4 import BeautifulSoup as bs

from method.ours import (
    create_driver,
    embed_properties,
    get_processable_nodes,
    create_relations_graph,
    create_2d_span_ordered_dict,
    add_for_links,
    add_parent_child_links,
    add_left_right_links,
    add_top_bottom_links,
    create_node2vec_model,
    add_weight_to_graph,
    cutoff_low_score_edges,
)

In [3]:
# Pull variables from a local .env file (if any) into the process environment,
# then hand the OpenAI key to the client. The key is never hard-coded here.
load_dotenv()
openai.api_key = os.environ.get("OPENAI_API_KEY")

In [4]:
# ---- Global configuration ----
# Passed to create_driver; False presumably means a visible browser window.
HEADLESS = False
# Text embedding backend. Options: 'ADA', 'WORD2VEC', 'SPACY'.
TEXT_EMBEDDING_METHOD = 'ADA'
# Graph embedding backend. Options: 'NODE2VEC', 'GCN'.
GRAPH_EMBEDDING_METHOD = 'NODE2VEC'

In [5]:
# Start the Selenium browser (create_driver is project code; HEADLESS is forwarded to it)
# and open the Ant Design Form demo page, which hosts the 'register' form used below.
# NOTE(review): nothing visible in this notebook calls driver.quit(); close the
# browser when finished so the driver process is not leaked.
driver = create_driver(HEADLESS)
driver.get('https://ant.design/components/form')

In [7]:
# Alternative kept for reference: select the form by position instead of id, e.g.
#   form = driver.find_elements(By.TAG_NAME, 'form')[48]
# Locate the demo registration form by its DOM id.
form = driver.find_element(By.ID, 'register')
# Project helper; presumably annotates the live element with layout data (the
# nodes printed later carry x/y spans) — confirm in method.ours.
form = embed_properties(driver, form)

# Parse the element's outer HTML with BeautifulSoup for offline processing.
form_doc = bs(form.get_attribute('outerHTML'), 'html.parser')

In [8]:
# NOTE: this cell previously carried a disabled prototype, `merge_single_child_parents`
# (collapse single-child wrapper tags into their parent), stored as a no-op string
# literal. It was dropped because it never executed and it called an
# `is_input_element` helper that is not defined or imported anywhere in this notebook.

# Collect the nodes of the parsed form that the relation graph is built from.
form_processable_nodes = get_processable_nodes(form_doc)

In [9]:
# Build the relation graph over the processable nodes, embedding their text with
# TEXT_EMBEDDING_METHOD ('ADA' presumably calls the OpenAI embeddings API — confirm).
relation_graph = create_relations_graph(form_processable_nodes, TEXT_EMBEDDING_METHOD)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:06<00:00,  6.01it/s]


In [10]:
# Index the graph's nodes by their 2-D spans. The result feeds the
# parent/child, left/right and top/bottom link builders in the next cell.
spans_2d = create_2d_span_ordered_dict(relation_graph)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:00<00:00, 112812.31it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 32640.50it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 40920.04it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 13797.05it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 29537.35it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [0

In [11]:
# Enrich the relation graph: `for`-attribute links first, then the links derived
# from the 2-D span index, applied in a fixed order.
relation_graph = add_for_links(relation_graph)
for add_span_links in (add_parent_child_links, add_left_right_links, add_top_bottom_links):
    relation_graph = add_span_links(spans_2d, relation_graph)

In [12]:
# Only node2vec is wired up in this notebook. Fail loudly instead of silently
# ignoring a different GRAPH_EMBEDDING_METHOD (e.g. 'GCN').
if GRAPH_EMBEDDING_METHOD != 'NODE2VEC':
    raise NotImplementedError(
        f"GRAPH_EMBEDDING_METHOD={GRAPH_EMBEDDING_METHOD!r} is not supported here; "
        "only 'NODE2VEC' is implemented."
    )

# NOTE(review): node2vec is fit on the parsed form DOM (`form_doc`), not on
# `relation_graph`. The progress output (193 walk nodes vs 39 graph nodes)
# suggests this is deliberate — confirm. No seed is passed, so the walks are
# presumably not reproducible between runs.
model = create_node2vec_model(form_doc)

Computing transition probabilities:   0%|          | 0/193 [00:00<?, ?it/s]


Generating walks (CPU: 1):   0%|          | 0/50 [00:00<?, ?it/s]
Generating walks (CPU: 3):   0%|          | 0/50 [00:00<?, ?it/s]
Generating walks (CPU: 2):   0%|          | 0/50 [00:00<?, ?it/s]
Generating walks (CPU: 4):   0%|          | 0/50 [00:00<?, ?it/s]
Generating walks (CPU: 3): 100%|██████████| 50/50 [00:00<00:00, 754.69it/s]

Generating walks (CPU: 1): 100%|██████████| 50/50 [00:00<00:00, 725.55it/s]

Generating walks (CPU: 4): 100%|██████████| 50/50 [00:00<00:00, 748.07it/s]

Generating walks (CPU: 2): 100%|██████████| 50/50 [00:00<00:00, 723.47it/s]


In [13]:
# Score the relation edges with the node2vec model, then prune low-scoring edges.
# NOTE(review): the literal 0 looks like a std-dev multiplier (the printed cutoff
# equals the printed mean) — confirm against cutoff_low_score_edges and consider
# naming it. The output shows a warning raised at a cosine-similarity division,
# which suggests some embedding has zero norm (NaN score); worth checking.
relation_graph = add_weight_to_graph(model, relation_graph)
relation_graph = cutoff_low_score_edges(relation_graph, 0)

  return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:00<00:00, 111.60it/s]

Mean: 0.5952570922023833
Standard Deviation: 0.190799543415622
Cutoff: 0.5952570922023833





In [14]:
# Dump every remaining edge as source / target / weight, one blank line between edges.
for edge in relation_graph.edges():
    print(edge.source, edge.target, edge.weight, sep='\n', end='\n\n')

<label>E-mail</label> at y: (10640, 10672), x: (545, 609)
<input></input> at y: (10640, 10672), x: (609, 1009)
0.7188399

<label>Password</label> at y: (10696, 10728), x: (524, 609)
<input></input> at y: (10701, 10723), x: (621, 979)
0.7338136

<label>Confirm Password</label> at y: (10752, 10784), x: (468, 609)
<input></input> at y: (10757, 10779), x: (621, 979)
0.7415908

<label></label> at y: (10808, 10840), x: (503, 609)
<input></input> at y: (10808, 10840), x: (609, 1009)
0.6465418

<label>Habitual Residence</label> at y: (10864, 10896), x: (461, 609)
<input></input> at y: (10865, 10895), x: (621, 997)
0.697651629738871

<label>Phone Number</label> at y: (10920, 10952), x: (489, 609)
<input></input> at y: (10920, 10952), x: (680, 1009)
0.7216641

<label>Donation</label> at y: (10976, 11008), x: (527, 609)
<input></input> at y: (10977, 11007), x: (610, 937)
0.61011803

<label>Intro</label> at y: (11088, 11120), x: (556, 609)
<textarea></textarea> at y: (11089, 11141), x: (610, 1008)

# Create Global Connections

In [51]:
def append_if_not_in_list(groups, node_to_index, node):
    """Place ``node`` in a new singleton group unless it already has one.

    Args:
        groups: dict mapping group id -> list of nodes; mutated in place.
        node_to_index: dict mapping node -> group id, where -1 marks a node
            that has not been assigned yet; mutated in place. A node that is
            missing from this dict is treated as unassigned.
        node: the node to register.

    The new group id is one past the largest existing id (0 when ``groups`` is
    empty), so ids stay unique even after merges have removed some keys.
    """
    if node_to_index.get(node, -1) != -1:
        return
    # Only scan the existing ids when a new one is actually needed.
    new_idx = max(groups, default=-1) + 1
    groups[new_idx] = [node]
    node_to_index[node] = new_idx

def merge_subgraphs(groups, node_to_index, idx1, idx2):
    """Merge groups ``idx1`` and ``idx2`` in place.

    The merged member list is ``groups[idx1]`` followed by ``groups[idx2]``
    (argument order is kept). It is stored under the smaller of the two ids,
    the larger id is removed, and every merged node is re-pointed at the
    surviving id in ``node_to_index``.

    Merging a group with itself is a no-op. Without this guard, the naive
    merge would duplicate the list and then pop the key, deleting the group.

    Returns:
        ``groups`` (the same dict object), so callers may reassign it.
    """
    if idx1 == idx2:
        return groups

    keep_idx = min(idx1, idx2)
    drop_idx = max(idx1, idx2)
    merged = [*groups[idx1], *groups[idx2]]

    groups[keep_idx] = merged
    groups.pop(drop_idx)
    for node in merged:
        node_to_index[node] = keep_idx
    return groups


def create_subgraph_groups(nodes, edges):
    """Partition ``nodes`` into the connected components induced by ``edges``.

    Each distinct node first gets its own singleton group (ids 0, 1, 2, ... in
    first-seen order). Each edge whose endpoints sit in different groups then
    merges those groups via ``merge_subgraphs``, so surviving ids are not
    contiguous.

    Args:
        nodes: iterable of hashable nodes; duplicates are ignored.
        edges: iterable of objects exposing ``.source`` and ``.target``. Both
            endpoints must appear in ``nodes`` (otherwise ``KeyError``).

    Returns:
        dict mapping group id -> list of member nodes.
    """
    groups = {}
    # -1 marks "not yet assigned to any group".
    node_to_index = dict.fromkeys(nodes, -1)

    for node in node_to_index:
        append_if_not_in_list(groups, node_to_index, node)

    for edge in edges:
        source_idx = node_to_index[edge.source]
        target_idx = node_to_index[edge.target]
        if source_idx != target_idx:
            # merge_subgraphs mutates `groups` and `node_to_index` in place.
            merge_subgraphs(groups, node_to_index, source_idx, target_idx)

    return groups

In [52]:
# Group the relation-graph nodes into connected components (id -> member nodes).
groups = create_subgraph_groups(relation_graph.nodes(), relation_graph.edges())

In [53]:
# Print each group's members, one per line, with a blank line after every group.
for members in groups.values():
    listing = ''.join(str(member) + '\n' for member in members)
    print(listing)

<label>E-mail</label> at y: (10640, 10672), x: (545, 609)
<input></input> at y: (10640, 10672), x: (609, 1009)

<label>Password</label> at y: (10696, 10728), x: (524, 609)
<input></input> at y: (10701, 10723), x: (621, 979)
<label>Confirm Password</label> at y: (10752, 10784), x: (468, 609)
<input></input> at y: (10757, 10779), x: (621, 979)

<label></label> at y: (10808, 10840), x: (503, 609)
<input></input> at y: (10808, 10840), x: (609, 1009)

<label>Habitual Residence</label> at y: (10864, 10896), x: (461, 609)
<input></input> at y: (10865, 10895), x: (621, 997)
<span>Zhejiang / Hangzhou / West Lake</span> at y: (10865, 10895), x: (621, 997)

<input></input> at y: (10921, 10951), x: (622, 668)
<span>+86</span> at y: (10921, 10951), x: (622, 668)
<label>Phone Number</label> at y: (10920, 10952), x: (489, 609)
<input></input> at y: (10920, 10952), x: (680, 1009)

<label>Donation</label> at y: (10976, 11008), x: (527, 609)
<input></input> at y: (10977, 11007), x: (610, 937)
<input></i

In [61]:
import itertools
from method.ours.embedding_distance import get_text_similarity


# Score every ordered pair of distinct groups by the best text similarity found
# between their <label> nodes (label vs label) or their <input> nodes (input vs input).
# Each entry is (group_id_a, group_id_b, max_score); 0 means nothing was comparable.
scores = []

# Split each group once up front instead of re-filtering inside the inner loop.
labels_by_group = {
    idx: [node for node in group if node.element.name == 'label']
    for idx, group in groups.items()
}
# TODO: do all types of input (only <input> is compared; <textarea> etc. are skipped)
inputs_by_group = {
    idx: [node for node in group if node.element.name == 'input']
    for idx, group in groups.items()
}

for idx1, group1 in groups.items():
    labels1 = labels_by_group[idx1]
    inputs1 = inputs_by_group[idx1]

    for idx2, group2 in groups.items():
        # Group ids are unique and groups are disjoint, so comparing ids is
        # equivalent to comparing the member lists, but O(1).
        if idx1 == idx2:
            continue

        labels2 = labels_by_group[idx2]
        inputs2 = inputs_by_group[idx2]

        labels_scores = [
            get_text_similarity(n1, n2) for n1, n2 in itertools.product(labels1, labels2)
        ]
        inputs_scores = [
            get_text_similarity(n1, n2) for n1, n2 in itertools.product(inputs1, inputs2)
        ]

        # `default=0` covers the "nothing to compare" case. It replaces a bare
        # `except:` that was meant for the empty-sequence ValueError but also
        # hid every other error (including KeyboardInterrupt).
        max_score = max([*labels_scores, *inputs_scores], default=0)

        print(
            idx1, idx2,
            '\n',
            group1,
            '\n',
            group2,
            '\n',
            max_score,
            '\n\n'
        )

        scores.append((idx1, idx2, max_score))

0 2 
 [<label>E-mail</label> at y: (10640, 10672), x: (545, 609), <input></input> at y: (10640, 10672), x: (609, 1009)] 
 [<label>Password</label> at y: (10696, 10728), x: (524, 609), <input></input> at y: (10701, 10723), x: (621, 979), <label>Confirm Password</label> at y: (10752, 10784), x: (468, 609), <input></input> at y: (10757, 10779), x: (621, 979)] 
 0.8135655690224038 


0 6 
 [<label>E-mail</label> at y: (10640, 10672), x: (545, 609), <input></input> at y: (10640, 10672), x: (609, 1009)] 
 [<label></label> at y: (10808, 10840), x: (503, 609), <input></input> at y: (10808, 10840), x: (609, 1009)] 
 0 


0 8 
 [<label>E-mail</label> at y: (10640, 10672), x: (545, 609), <input></input> at y: (10640, 10672), x: (609, 1009)] 
 [<label>Habitual Residence</label> at y: (10864, 10896), x: (461, 609), <input></input> at y: (10865, 10895), x: (621, 997), <span>Zhejiang / Hangzhou / West Lake</span> at y: (10865, 10895), x: (621, 997)] 
 0.7402768992024891 


0 11 
 [<label>E-mail</labe

6 19 
 [<label></label> at y: (10808, 10840), x: (503, 609), <input></input> at y: (10808, 10840), x: (609, 1009)] 
 [<label>Website</label> at y: (11032, 11064), x: (533, 609)] 
 0 


6 20 
 [<label></label> at y: (10808, 10840), x: (503, 609), <input></input> at y: (10808, 10840), x: (609, 1009)] 
 [<input></input> at y: (11032, 11064), x: (609, 1009), <span>website</span> at y: (11032, 11062), x: (609, 1009)] 
 0 


6 22 
 [<label></label> at y: (10808, 10840), x: (503, 609), <input></input> at y: (10808, 10840), x: (609, 1009)] 
 [<label>Intro</label> at y: (11088, 11120), x: (556, 609), <textarea></textarea> at y: (11089, 11141), x: (610, 1008), <span>0 / 100</span> at y: (11141, 11163), x: (964, 1008)] 
 0 


6 25 
 [<label></label> at y: (10808, 10840), x: (503, 609), <input></input> at y: (10808, 10840), x: (609, 1009)] 
 [<label>Gender</label> at y: (11166, 11198), x: (538, 609), <input></input> at y: (11167, 11197), x: (621, 997), <span>select your gender</span> at y: (11167,

15 2 
 [<label>Donation</label> at y: (10976, 11008), x: (527, 609), <input></input> at y: (10977, 11007), x: (610, 937), <input></input> at y: (10977, 11007), x: (950, 996), <span></span> at y: (10977, 11007), x: (950, 996)] 
 [<label>Password</label> at y: (10696, 10728), x: (524, 609), <input></input> at y: (10701, 10723), x: (621, 979), <label>Confirm Password</label> at y: (10752, 10784), x: (468, 609), <input></input> at y: (10757, 10779), x: (621, 979)] 
 0.7977074501973495 


15 6 
 [<label>Donation</label> at y: (10976, 11008), x: (527, 609), <input></input> at y: (10977, 11007), x: (610, 937), <input></input> at y: (10977, 11007), x: (950, 996), <span></span> at y: (10977, 11007), x: (950, 996)] 
 [<label></label> at y: (10808, 10840), x: (503, 609), <input></input> at y: (10808, 10840), x: (609, 1009)] 
 0 


15 8 
 [<label>Donation</label> at y: (10976, 11008), x: (527, 609), <input></input> at y: (10977, 11007), x: (610, 937), <input></input> at y: (10977, 11007), x: (950,

25 11 
 [<label>Gender</label> at y: (11166, 11198), x: (538, 609), <input></input> at y: (11167, 11197), x: (621, 997), <span>select your gender</span> at y: (11167, 11197), x: (621, 997), <label>Captcha</label> at y: (11222, 11254), x: (541, 609)] 
 [<input></input> at y: (10921, 10951), x: (622, 668), <span>+86</span> at y: (10921, 10951), x: (622, 668), <label>Phone Number</label> at y: (10920, 10952), x: (489, 609), <input></input> at y: (10920, 10952), x: (680, 1009)] 
 0.8026090614856838 


25 15 
 [<label>Gender</label> at y: (11166, 11198), x: (538, 609), <input></input> at y: (11167, 11197), x: (621, 997), <span>select your gender</span> at y: (11167, 11197), x: (621, 997), <label>Captcha</label> at y: (11222, 11254), x: (541, 609)] 
 [<label>Donation</label> at y: (10976, 11008), x: (527, 609), <input></input> at y: (10977, 11007), x: (610, 937), <input></input> at y: (10977, 11007), x: (950, 996), <span></span> at y: (10977, 11007), x: (950, 996)] 
 0.8146041891392553 


25

In [62]:
# Inspect all (group_id_a, group_id_b, max_score) entries from the scoring loop.
scores

[(0, 2, 0.8135655690224038),
 (0, 6, 0),
 (0, 8, 0.7402768992024891),
 (0, 11, 0.8062587585593722),
 (0, 15, 0.7822140480207036),
 (0, 19, 0.823041756481178),
 (0, 20, 0),
 (0, 22, 0.787209070792227),
 (0, 25, 0.802094992993893),
 (0, 29, 0),
 (0, 33, 0),
 (0, 37, 0),
 (2, 0, 0.8135655690224038),
 (2, 6, 0),
 (2, 8, 0.7687151355234573),
 (2, 11, 0.8065038988546304),
 (2, 15, 0.7977074501973495),
 (2, 19, 0.8165486270878868),
 (2, 20, 0),
 (2, 22, 0.8141775524584348),
 (2, 25, 0.8605638459416701),
 (2, 29, 0),
 (2, 33, 0),
 (2, 37, 0),
 (6, 0, 0),
 (6, 2, 0),
 (6, 8, 0),
 (6, 11, 0),
 (6, 15, 0),
 (6, 19, 0),
 (6, 20, 0),
 (6, 22, 0),
 (6, 25, 0),
 (6, 29, 0),
 (6, 33, 0),
 (6, 37, 0),
 (8, 0, 0.7402768992024891),
 (8, 2, 0.7687151355234573),
 (8, 6, 0),
 (8, 11, 0.768813267646753),
 (8, 15, 0.7690852256580688),
 (8, 19, 0.7683330560673156),
 (8, 20, 0),
 (8, 22, 0.7739410245450628),
 (8, 25, 0.7947589417276986),
 (8, 29, 0),
 (8, 33, 0),
 (8, 37, 0),
 (11, 0, 0.8062587585593722),
 (11,

In [63]:
import numpy as np

In [64]:
# Keep only non-zero scores (0 is what the scoring loop records when a pair had
# nothing to compare), so those pairs do not drag down the statistics below.
non_zero_scores = [entry[2] for entry in scores if entry[2] != 0]

In [71]:
# How many standard deviations above the mean a pair must score to be kept.
CUTOFF_STD_MULTIPLIER = 0.5

mean = np.mean(non_zero_scores)
std_dev = np.std(non_zero_scores)  # population std (numpy default ddof=0)

# Set the cutoff CUTOFF_STD_MULTIPLIER (currently half a) standard deviation above
# the mean. The previous comment said "one standard deviation", which did not match the code.
cutoff = mean + CUTOFF_STD_MULTIPLIER * std_dev

In [72]:
# Show mean, standard deviation and resulting cutoff, one per line.
print(mean, std_dev, cutoff, sep='\n')

0.7961643094449412
0.02396026763160815
0.8081444432607453


In [73]:
# Keep only the group pairs whose best similarity is strictly above the cutoff.
cutoff_score_list = [entry for entry in scores if entry[2] > cutoff]

In [74]:
# Pairs above the cutoff; each unordered pair appears in both directions.
cutoff_score_list

[(0, 2, 0.8135655690224038),
 (0, 19, 0.823041756481178),
 (2, 0, 0.8135655690224038),
 (2, 19, 0.8165486270878868),
 (2, 22, 0.8141775524584348),
 (2, 25, 0.8605638459416701),
 (15, 25, 0.8146041891392553),
 (19, 0, 0.823041756481178),
 (19, 2, 0.8165486270878868),
 (19, 25, 0.8117730955116259),
 (22, 2, 0.8141775524584348),
 (22, 25, 0.8387258151412464),
 (25, 2, 0.8605638459416701),
 (25, 15, 0.8146041891392553),
 (25, 19, 0.8117730955116259),
 (25, 22, 0.8387258151412464)]

In [75]:
# For each above-cutoff pair, print both groups' node lists and then a blank line.
for group_a_idx, group_b_idx, _score in cutoff_score_list:
    pair = (groups[group_a_idx], groups[group_b_idx])
    print(*pair, sep='\n', end='\n\n')

[<label>E-mail</label> at y: (10640, 10672), x: (545, 609), <input></input> at y: (10640, 10672), x: (609, 1009)]
[<label>Password</label> at y: (10696, 10728), x: (524, 609), <input></input> at y: (10701, 10723), x: (621, 979), <label>Confirm Password</label> at y: (10752, 10784), x: (468, 609), <input></input> at y: (10757, 10779), x: (621, 979)]

[<label>E-mail</label> at y: (10640, 10672), x: (545, 609), <input></input> at y: (10640, 10672), x: (609, 1009)]
[<label>Website</label> at y: (11032, 11064), x: (533, 609)]

[<label>Password</label> at y: (10696, 10728), x: (524, 609), <input></input> at y: (10701, 10723), x: (621, 979), <label>Confirm Password</label> at y: (10752, 10784), x: (468, 609), <input></input> at y: (10757, 10779), x: (621, 979)]
[<label>E-mail</label> at y: (10640, 10672), x: (545, 609), <input></input> at y: (10640, 10672), x: (609, 1009)]

[<label>Password</label> at y: (10696, 10728), x: (524, 609), <input></input> at y: (10701, 10723), x: (621, 979), <labe

In [96]:
# Chat prompt: the system message defines the assertion DSL and output format;
# the user message supplies the target field plus related form context.
# NOTE(review): the relation sample below uses `.toBe(...)`, which is not in the
# numbered list (only `toBeEqual`). Confirm which name the downstream consumer
# expects; the model copies the sample (its reply uses `.toBe`).
messages = [
    {
        'role': 'system',
        'content': '''
        Your task is to generate a set of assertions for form fields. The available assertions and their signatures are as follows:

        1. toBeEqual(value) # for any input type
        2. toHaveLengthCondition(condition, value) # for textual inputs
        3. toBeTruthy() # for boolean inputs
        4. toHaveCondition(condition, value) # for numeric or date inputs
        5. toBeEmpty() # for any input type
        6. toMatch(regexPattern) # for textual inputs

        Generate the assertions in the same format as the sample:

        # sample
        expect(field('username'))
        .toHaveLengthCondition('>', 8)
        .toHaveLengthCondition('<', 50)
        .not.toBeEmpty()
        # end of sample

        If multiple inputs must be related to each other, you can use a format such as:

        # sample
        expect(field('password'))
        .toBe(field('confirm password'))
        # end of sample

        Only generate the assertions and nothing else. Only generate assertions for the inputs in question, and not the ones in the relevant information section.
        '''
    },
    {
        'role': 'user',
        'content': """
            We are filling the following field in the form:
            <input id="register_password" aria-required="true" type="password" class="ant-input css-dfjnss">

            The relevant information available in the form are (in order of relevance):
            1. <label for="register_password" class="ant-form-item-required" title="Password">Password</label>
            2. <input id="register_confirm" aria-required="true" type="password" class="ant-input css-dfjnss">
            3. <label for="register_confirm" class="ant-form-item-required" title="Confirm Password">Confirm Password</label>
        """
    }
]

In [97]:
# Ask the chat model to produce assertions for the prompt above.
# NOTE(review): `openai.ChatCompletion.create` is the pre-1.0 openai-python
# interface and was removed in openai>=1.0. Pin the SDK version or migrate.
# temperature=0.5 makes the generated assertions vary between runs.
response = openai.ChatCompletion.create(
    # model='gpt-3.5-turbo',  # alternative model; swap by (un)commenting
    model='gpt-4',
    messages=messages,
    temperature=0.5,
)

In [98]:
# Extract the text of the first completion choice and show it.
response_text = response.choices[0].message.content
print(response_text)

expect(field('register_password'))
.toHaveLengthCondition('>', 8)
.toHaveLengthCondition('<', 50)
.not.toBeEmpty()
.toMatch(/^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)[a-zA-Z\d]{8,50}$/)

expect(field('register_password'))
.toBe(field('register_confirm'))
