## Data Loading and Initial Processing

Read dataset and prints the first 10 rows of the dataset to provide an initial view of the data.

In [2]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import TextVectorization
from bs4 import BeautifulSoup
import re
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

df=pd.read_csv("./data/html_cleaned_prompt_aligned.csv")

print('First 10 rows of data: ')
df.head(10)

First 10 rows of data: 


Unnamed: 0,prompt,output,is_valid,beautified_html
0,#907f54 active button with text contain,"<button style=""background:#907f54;"" class=""act...",True,"<button class=""active"" style=""background:#907f..."
1,#463658 button for treatment,"<button style=""background:#463658;"" disabled> ...",True,"<button disabled="""" style=""background:#463658;..."
2,#59849e disabled button for movie,"<button style=""background:#59849e;"" disabled> ...",True,"<button disabled="""" style=""background:#59849e;..."
3,Button with #bb6219 having down,"<button style=""background:#bb6219;"" disabled> ...",True,"<button disabled="""" style=""background:#bb6219;..."
4,Submit button having message,"<button style=""background:#973a18;"" disabled> ...",True,"<button disabled="""" style=""background:#973a18;..."
5,active button in #983e17 background with strategy,"<button style=""background:#983e17;"" class=""act...",True,"<button class=""active"" style=""background:#983e..."
6,Button in #7367a0 having choose,"<button style=""background:#7367a0;"" disabled> ...",True,"<button disabled="""" style=""background:#7367a0;..."
7,disabled clickable button for text yard,"<button style=""background:#51756a;"" disabled> ...",True,"<button disabled="""" style=""background:#51756a;..."
8,#e0197f colored button having return,"<button style=""background:#e0197f;"" class=""act...",True,"<button class=""active"" style=""background:#e019..."
9,Button in disabled state containing song,"<button style=""background:#47f621;"" disabled> ...",True,"<button disabled="""" style=""background:#47f621;..."


# 1. Data cleaning
Removing new lines, Improving spacings and adding "[start]" and "[end]" tokens to mark the beginning and end of each output sequence.

In [3]:
df = df.dropna(subset=["prompt", "output"])
def extract_label_from_prompt(prompt):
    # Use all capitalized or non-stopword tokens as label candidates
    words = prompt.strip().split()
    for word in reversed(words):
        w = word.strip(".,!?:;'").capitalize()
        if len(w) > 1 and w.isalpha():
            return w
    return "Submit"

def inject_text_into_html(row):
    soup = BeautifulSoup(row["output"], "html.parser")
    label = extract_label_from_prompt(row["prompt"])
    for tag in soup.find_all("button"):
        tag.string = label
    return str(soup)

df["output"] = df.apply(inject_text_into_html, axis=1)

# Add special tokens
start_token, end_token = "[start]", "[end]"
df["output"] = df["output"].apply(lambda x: f"{start_token} {x} {end_token}")

prompts = df["prompt"].astype(str).tolist()
outputs = df["output"].astype(str).tolist()
all_text = prompts + outputs

# 2. Input Extraction and Tokenization


In [4]:
max_seq_len = max(len(txt.split()) for txt in prompts + outputs)

print("max_seq_len: ", max_seq_len)

max_seq_len = 160 if max_seq_len > 160 else max_seq_len

print("updated max_seq_len: ", max_seq_len)

vectorizer = TextVectorization(
    output_mode='int',
    output_sequence_length=max_seq_len,
    standardize=None,
    split='whitespace'
)
vectorizer.adapt(all_text)

vocab = vectorizer.get_vocabulary()
vocab_size = len(vocab)
print('Vocab length: ', vocab_size)

max_seq_len:  200
updated max_seq_len:  160
Vocab length:  55414


# 3. Prepare dataset

In [5]:
def format_dataset(prompt, output):
    enc_tokens = vectorizer(prompt)
    dec_tokens = vectorizer(output)
    dec_input = tf.concat([[0], dec_tokens[:-1]], axis=0)  # shift right
    return (enc_tokens, dec_input), dec_tokens

dataset = tf.data.Dataset.from_tensor_slices((prompts, outputs))
dataset = dataset.map(lambda p, o: format_dataset(p, o))
dataset = dataset.shuffle(64).batch(16).prefetch(tf.data.AUTOTUNE)

print('Dataset is ready.')

Dataset is ready.


# 4. Model building

In [6]:
def transformer_model(vocab_size, seq_len):
    enc_inputs = layers.Input(shape=(seq_len,), dtype="int64")
    dec_inputs = layers.Input(shape=(seq_len,), dtype="int64")

    embed = layers.Embedding(vocab_size, 256)
    enc_emb = embed(enc_inputs)
    dec_emb = embed(dec_inputs)

    pos_enc = layers.Embedding(seq_len, 256)
    enc_emb += pos_enc(tf.range(start=0, limit=seq_len))
    dec_emb += pos_enc(tf.range(start=0, limit=seq_len))

    for _ in range(2):
        attn_out = layers.MultiHeadAttention(num_heads=2, key_dim=256)(dec_emb, enc_emb)
        x = layers.LayerNormalization()(attn_out + dec_emb)
        ffn = layers.Dense(512, activation="relu")(x)
        ffn = layers.Dense(256)(ffn)
        dec_emb = layers.LayerNormalization()(ffn + x)

    outputs = layers.Dense(vocab_size, activation="softmax")(dec_emb)
    return tf.keras.Model([enc_inputs, dec_inputs], outputs)

print("Model is ready to train")

Model is ready to train


# 5. Create and compile Model

In [7]:
model = transformer_model(vocab_size, max_seq_len)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.summary()

# 6. Train model with dataset

In [8]:
callbacks = [
    EarlyStopping(patience=3, restore_best_weights=True),
    ModelCheckpoint("best_model.keras", save_best_only=True)
]

history = model.fit(dataset, epochs=100, callbacks=callbacks)

Epoch 1/100
[1m  33/5113[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:31:40[0m 1s/step - accuracy: 0.8487 - loss: 6.9255

KeyboardInterrupt: 

# 7. See loss and accuracy

In [None]:
plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['accuracy'], label='accuracy')
plt.legend()
plt.show()

# 8. Post processing and HTML generation from prompt

In [94]:
def clean_generated_html(text):
    text = re.sub(r'(class="[^"]+")(?=\sclass=")', '', text)
    text = re.sub(r'(<[^ >]+)(?=\s|$)', r'\1>', text)
    text = re.sub(r'>\s+<', '><', text)
    text = re.sub(r'\s{2,}', ' ', text)
    text = re.sub(r'(\s+style="[^"]+"){2,}', lambda m: m.group(1), text)
    try:
        soup = BeautifulSoup(text, 'html.parser')
        cleaned = soup.prettify()
    except:
        cleaned = text
    return cleaned.strip()

def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    for idx in set(generated_ids):
        logits[idx] = logits[idx] / penalty
    return logits

# === Greedy Decoding with Repetition Penalty ===
def generate_html(prompt, seq_length=max_seq_len):
    input_tokens = vectorizer(tf.constant([prompt]))
    start_token_idx = vocab.index("[start]")
    end_token_idx = vocab.index("[end]")

    decoder_input = [start_token_idx]
    generated_ids = [start_token_idx]

    for _ in range(seq_length - 1):
        decoder_input_padded = decoder_input + [0] * (seq_length - len(decoder_input))
        decoder_tensor = tf.constant([decoder_input_padded])
        preds = model([input_tokens, decoder_tensor], training=False)
        logits = preds[0, len(generated_ids) - 1].numpy()

        # Apply repetition penalty
        logits = apply_repetition_penalty(logits, generated_ids, penalty=1.3)

        next_token = int(np.argmax(logits))
        if next_token == end_token_idx:
            break

        decoder_input.append(next_token)
        generated_ids.append(next_token)

    tokens = [vocab[idx] for idx in generated_ids if idx not in [0, start_token_idx, end_token_idx]]
    return clean_generated_html(" ".join(tokens))

# === Test Example ===
print(generate_html("red button with text submit"))
print(generate_html("simple button with text random"))


<button>
 class="active" style="background:#1c9b63;"&gt;Language
</button>
<button>
 class="active" style="background:#1c9b63;"&gt;Language
</button>


# 9. Save model

In [80]:
model.save("html_gen_transformer_working_v1.keras")