In [None]:
# from google.colab import auth
# auth.authenticate_user()

# # https://cloud.google.com/resource-manager/docs/creating-managing-projects
# project_id = 'ottawadev-d26ce '
# !gcloud config set project {project_id}



# # prompt: download all objects from a google cloud bucket

# from google.colab import drive

# drive.mount('/content/drive')

# !gsutil -m cp -r gs://webleaf /content/drive/MyDrive/product_page_dataset/html

# !gsutil -m cp -r gs://webleaftest /content/drive/MyDrive/product_page_dataset/html

In [None]:
# Mount Google Drive: every dataset and checkpoint path used in this notebook
# lives under /content/drive/MyDrive.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install sentence-transformers

Collecting sentence-transformers
  Downloading sentence_transformers-3.0.1-py3-none-any.whl.metadata (10 kB)
Downloading sentence_transformers-3.0.1-py3-none-any.whl (227 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m227.1/227.1 kB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentence-transformers
Successfully installed sentence-transformers-3.0.1


In [None]:
import torch
import torch.nn as nn
import os

# Vocabulary of HTML tag names recognised by the pipeline.
# A tag's position in this list is its embedding row (TagEmbeddingModel looks tags up
# by their index here), and the embedding weights are saved to / loaded from disk, so
# reordering or inserting entries changes which stored vector each tag maps to.
html_tags = [
    'a', 'abbr', 'address', 'area', 'article', 'aside', 'audio', 'b', 'base', 'bdi', 'bdo', 'blockquote',
    'body', 'br', 'button', 'canvas', 'caption', 'cite', 'code', 'col', 'colgroup', 'data', 'datalist', 'dd',
    'del', 'details', 'dfn', 'dialog', 'div', 'dl', 'dt', 'em', 'embed', 'fieldset', 'figcaption', 'figure',
    'footer', 'form', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'head', 'header', 'hr', 'html', 'i', 'iframe', 'img',
    'input', 'ins', 'kbd', 'label', 'legend', 'li', 'link', 'main', 'map', 'mark', 'meter', 'nav', 'noscript',
    'object', 'ol', 'optgroup', 'option', 'output', 'p', 'param', 'picture', 'pre', 'progress', 'q', 'rp', 'rt', 'ruby',
    's', 'samp', 'section', 'select', 'small', 'source', 'span', 'strong', 'sub', 'summary', 'sup',
    'table', 'tbody', 'td', 'template', 'textarea', 'tfoot', 'th', 'thead', 'time', 'title', 'tr', 'track', 'u', 'ul',
    'var', 'video', 'wbr'
]



class NormalizedEmbedding(nn.Module):
    """Embedding table whose looked-up vectors are rescaled to unit L2 norm.

    Args:
        n_classes: Number of rows (distinct indices) in the table.
        m_dimensions: Length of each embedding vector.
    """

    def __init__(self, n_classes, m_dimensions):
        super().__init__()
        self.embedding = nn.Embedding(n_classes, m_dimensions)

        # Overwrite nn.Embedding's own initialisation with Xavier-uniform weights.
        nn.init.xavier_uniform_(self.embedding.weight)

    def forward(self, x):
        """Look up ``x`` and L2-normalise every returned vector.

        Args:
            x: Integer index tensor of any shape (scalar, 1-D or batched).

        Returns:
            Tensor of shape ``x.shape + (m_dimensions,)`` whose vectors along the
            last axis have unit length.
        """
        embed = self.embedding(x)

        # Normalise over the last (feature) axis rather than a hard-coded axis 1, so
        # scalar and batched index tensors are handled as well as the 1-D case.
        # F.normalize also guards the division with a small epsilon, so an all-zero
        # row yields zeros instead of NaN.
        return nn.functional.normalize(embed, p=2, dim=-1)

# Location of the persisted tag-embedding weights (a torch state_dict, despite the
# .pkl suffix).
TAG_PATH = '/content/drive/MyDrive/product_page_dataset/tag_embeddings.pkl'

class TagEmbeddingModel:
    """Maps HTML tag names to fixed 8-dimensional unit-length embeddings.

    The weights are loaded from ``TAG_PATH`` when that file exists. Otherwise the
    freshly initialised weights are written there, so later sessions reuse the same
    vectors.
    """

    def __init__(self):
        self.embedding_model = NormalizedEmbedding(len(html_tags), 8)

        # Tag name -> row index, built once so lookups do not rescan the list.
        # setdefault keeps the first position of a name, matching list.index().
        self._tag_to_index = {}
        for position, tag in enumerate(html_tags):
            self._tag_to_index.setdefault(tag, position)

        if os.path.exists(TAG_PATH):
            # weights_only=True limits unpickling to tensors and plain containers,
            # which is all a state_dict holds, and avoids executing arbitrary
            # pickled code from the checkpoint file.
            self.embedding_model.load_state_dict(torch.load(TAG_PATH, weights_only=True))
        else:
            os.makedirs(os.path.dirname(TAG_PATH), exist_ok=True)
            torch.save(self.embedding_model.state_dict(), TAG_PATH)

    def get_tag_embedding(self, tags):
        """Return the embedding for each tag name in ``tags``.

        Args:
            tags: Iterable of tag names, each of which must appear in ``html_tags``.

        Returns:
            Tensor with one 8-dimensional row per tag, in input order.

        Raises:
            ValueError: If a tag is not in ``html_tags``.
        """
        try:
            indices = [self._tag_to_index[tag] for tag in tags]
        except KeyError as err:
            raise ValueError(f"{err.args[0]!r} is not in html_tags") from None
        # An explicit integer dtype keeps an empty tag list valid for nn.Embedding.
        return self.embedding_model(torch.tensor(indices, dtype=torch.long))


In [None]:
from sentence_transformers import SentenceTransformer
import nltk
from nltk.tokenize import sent_tokenize
from functools import lru_cache

# Presumably the output width of the sentence encoder loaded below; not referenced
# elsewhere in the visible cells.
TEXT_EMBEDDING = 384

class TextEmbeddingModel:
    """Encodes the first sentence of each text with a SentenceTransformer model."""

    def __init__(self):
        # Older NLTK releases read the sentence tokenizer from "punkt"; newer ones
        # look for "punkt_tab" instead. Requesting both keeps sent_tokenize working
        # on either (a failed download is reported by NLTK but does not raise).
        for resource in ('punkt', 'punkt_tab'):
            nltk.download(resource)
        self.model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')

    def get_text_embeddings(self, text):
        """Embed the first sentence of every entry in ``text``.

        Args:
            text: Iterable of strings. ``None`` and empty entries are encoded as the
                empty string.

        Returns:
            Whatever ``self.model.encode`` returns for the list of first sentences
            (one embedding per input entry, in order).
        """
        sentences = []
        for t in text:
            # Skip tokenisation for None/"" so a missing text cannot crash the batch.
            parts = sent_tokenize(t) if t else []
            sentences.append(parts[0] if parts else "")
        return self.model.encode(sentences)


  from tqdm.autonotebook import tqdm, trange


In [None]:
import os
import re
from collections import deque

import torch
from lxml import etree

class Web:
    """Converts an HTML document into graph tensors for the page's element tree.

    Each retained element becomes a node whose feature vector is the text embedding
    of its (cleaned, truncated) text concatenated with the embedding of its tag.
    """

    # Inline formatting tags: the markup is dropped, their text is merged into the parent.
    _FORMATTING_TAGS = ('b', 'i', 'u', 'strong', 'em', 'mark', 'small', 'del', 'ins')
    # Attributes consulted, in order, when an element has no usable text of its own.
    _TEXT_ATTRIBUTES = ("alt", "title", "aria-label")
    # Anything other than ASCII letters, whitespace and basic punctuation is removed
    # (note that this also removes digits).
    _DISALLOWED_CHARS = re.compile(r'[^a-zA-Z\s.,!?\'\";:]')
    # Node texts are cut to this many characters before embedding.
    _MAX_TEXT_LENGTH = 256

    def __init__(self):
        self.tag_model = TagEmbeddingModel()
        self.text_model = TextEmbeddingModel()

    def extract(self, html):
        """Parse ``html`` and return its node features, edges, text mask and metadata.

        Args:
            html: HTML source accepted by ``lxml.etree.HTML``.

        Returns:
            Tuple ``(x, edge_index, mask, metas)``:
              * ``x``: float tensor with one row per node (text embedding followed by
                tag embedding). Node 0 is the document root with empty text.
              * ``edge_index``: int64 tensor of shape ``(E, 2)`` holding
                ``[parent_id, child_id]`` pairs.
              * ``mask``: bool tensor, True for nodes with non-empty text.
              * ``metas``: list of dicts with ``'path'`` and ``'text'`` for every node
                and ``'tag'`` for every non-root node.

        Raises:
            ValueError: If the parser produces no root element.
        """
        root = etree.HTML(html)
        if root is None:
            raise ValueError("html did not contain a parsable document")
        tree = etree.ElementTree(root)

        etree.strip_tags(root, *self._FORMATTING_TAGS)

        tag_lookup = set(html_tags)
        texts = [""]
        tags = [root.tag]
        edge_index = []
        masks = [False]
        metas = [{'path': 'root', 'text': ""}]

        last_id = 0
        # Breadth-first walk; each entry pairs an element with the node id that its
        # children should be attached to.
        queue = deque([(root, 0)])
        while queue:
            element, parent_id = queue.popleft()

            # A div with exactly one child: that child is not emitted as a node, and
            # its own children are attached to the div's node instead.
            # NOTE(review): the skipped child's tag and text are never recorded;
            # confirm this is the intended way to collapse wrapper divs.
            if element.tag == "div" and len(element) == 1:
                queue.append((element[0], parent_id))
                continue

            for child in element:
                if isinstance(child, etree._Comment):
                    continue
                # Elements with tags outside the vocabulary are dropped together
                # with their whole subtree.
                if child.tag not in tag_lookup:
                    continue

                text = self.extract_text(child)[:self._MAX_TEXT_LENGTH]
                last_id += 1
                tags.append(child.tag)
                texts.append(text)
                masks.append(bool(text))
                metas.append({'path': tree.getpath(child), 'text': text, 'tag': child.tag})
                edge_index.append([parent_id, last_id])
                queue.append((child, last_id))

        text_embeddings = self.text_model.get_text_embeddings(texts)
        tag_embeddings = self.tag_model.get_tag_embedding(tags)
        # Join the two embedding matrices along the feature axis in a single call.
        x = torch.cat((torch.as_tensor(text_embeddings), tag_embeddings), dim=1)

        # reshape keeps the (E, 2) layout even when the page has no edges.
        edges = torch.tensor(edge_index, dtype=torch.int64).reshape(-1, 2)
        return x, edges, torch.tensor(masks, dtype=torch.bool), metas

    def clean_text(self, text):
        """Strip disallowed characters from ``text`` and collapse whitespace.

        Returns:
            The cleaned string (possibly empty), or ``None`` when ``text`` is
            ``None`` or empty.
        """
        if not text:
            return None
        return ' '.join(self._DISALLOWED_CHARS.sub('', text).split())

    def extract_text(self, element) -> str:
        """Return the element's cleaned text, falling back to descriptive attributes.

        The element's own ``.text`` is preferred. If that is empty after cleaning,
        the ``alt``, ``title`` and ``aria-label`` attributes are tried in that order.
        Returns ``""`` when none of them yields text.
        """
        text = self.clean_text(element.text)
        if text:
            return text

        for label in self._TEXT_ATTRIBUTES:
            text = self.clean_text(element.get(label))
            if text:
                return text
        return ""

In [None]:
import pickle

# Build the extractor once and reuse it for every page: constructing Web creates
# both the tag-embedding model and the sentence-transformer text model.
web = Web()


  self.embedding_model.load_state_dict(torch.load(TAG_PATH))
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/11.6k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/383 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]



1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

# Train

In [None]:
import pickle
import os
webleaf_path = "/content/drive/MyDrive/product_page_dataset/html/webleaftest"
dataset_path = "/content/drive/MyDrive/product_page_dataset/dataset/test"

BATCH_SIZE = 100   # number of HTML pages merged into one saved shard
BATCH_STOP = 1000  # upper bound on the webleaf{N}.pkl input files scanned


def _save_shard(shard_index, xs, es, masks, shard_metas):
    """Write one shard to dataset_path as {shard_index}.x / .e / .mask / .meta.

    Edges are stored transposed, i.e. with shape (2, E).
    """
    torch.save(torch.cat(xs), os.path.join(dataset_path, f"{shard_index}.x"))
    torch.save(torch.cat(es).permute(1, 0), os.path.join(dataset_path, f"{shard_index}.e"))
    torch.save(torch.cat(masks), os.path.join(dataset_path, f"{shard_index}.mask"))
    # The context manager guarantees the metadata pickle is flushed and closed.
    with open(os.path.join(dataset_path, f"{shard_index}.meta"), "wb") as meta_file:
        pickle.dump(shard_metas, meta_file)


os.makedirs(dataset_path, exist_ok=True)

# Per-page tensors are collected in lists and concatenated once per shard instead of
# re-concatenating a growing tensor after every page.
pending_x, pending_e, pending_mask = [], [], []
metas = []
index = 0        # node-id offset of the next page inside the current shard
html_index = 0   # pages collected for the current shard
batch_index = 0  # index of the next shard to write

for batch in range(BATCH_STOP):
    print("Processing:", batch_index)
    filename = f"webleaf{batch}.pkl"
    source_file = os.path.join(webleaf_path, filename)
    if not os.path.exists(source_file):
        print("could not find", filename)
        break

    # NOTE: pickle.load can execute arbitrary code; only read files you produced yourself.
    with open(source_file, "rb") as html_file:
        htmls = pickle.load(html_file)

    for html in htmls:
        x, e, mask, meta = web.extract(html)
        # detach() drops the autograd graph inherited from the tag-embedding layer;
        # the saved features are plain data.
        pending_x.append(x.detach())
        # Shift node ids so edges stay valid after pages are stacked into one shard.
        pending_e.append(e.reshape(-1, 2).to(torch.int64) + index)
        pending_mask.append(mask)
        metas.extend(meta)
        index += len(x)
        html_index += 1
        if html_index >= BATCH_SIZE:
            _save_shard(batch_index, pending_x, pending_e, pending_mask, metas)
            pending_x, pending_e, pending_mask = [], [], []
            metas = []
            batch_index += 1
            html_index = 0
            index = 0
            print(batch_index)

# Write the final, partially filled shard instead of silently dropping its pages.
if html_index:
    _save_shard(batch_index, pending_x, pending_e, pending_mask, metas)
    pending_x, pending_e, pending_mask = [], [], []
    metas = []
    batch_index += 1
    html_index = 0
    index = 0
    print(batch_index)

# Later cells inspect these accumulators; nothing is pending once every shard has
# been written, so they are exposed as empty tensors.
batchx = torch.tensor([])
batche = torch.tensor([])
batchmask = torch.tensor([], dtype=torch.bool)

Processing: 0
1
Processing: 1
2
Processing: 2
3
Processing: 3
4
Processing: 4
5
Processing: 5
6
Processing: 6
7
Processing: 7
8
Processing: 8
9
Processing: 9
10
Processing: 10
11
Processing: 11
12
Processing: 12
13
Processing: 13
14
Processing: 14
15
Processing: 15
16
Processing: 16
17
Processing: 17
18
Processing: 18
19
Processing: 19
20
Processing: 20
21
Processing: 21
22
Processing: 22
23
Processing: 23
24
Processing: 24
25
Processing: 25
26
Processing: 26
27
Processing: 27
28
Processing: 28
29
Processing: 29
30
Processing: 30
31
Processing: 31
32
Processing: 32
33
Processing: 33
34
Processing: 34
35
Processing: 35
36
Processing: 36
37
Processing: 37
38
Processing: 38
39
Processing: 39
40
Processing: 40
41
Processing: 41
42
Processing: 42
43
Processing: 43
44
Processing: 44
45
Processing: 45
46
Processing: 46
47
Processing: 47
48
Processing: 48
49
Processing: 49
50
Processing: 50
51
Processing: 51
52
Processing: 52
53
Processing: 53
54
Processing: 54
55
Processing: 55
56
Processing:

In [None]:
# Inspect the edge accumulator left over after the batching loop (cast to int32).
batche.to(torch.int32)

tensor([], dtype=torch.int32)

In [None]:
# Shape of the node-feature accumulator left over after the batching loop.
batchx.shape

torch.Size([0])