In [5]:
# Authenticate this session with Weights & Biases before any run is created.
import wandb
# NOTE: per the recorded output of this cell, login prompts for an API key when
# no stored credentials are found, and the call evaluated to True.
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mvinyk-sd[0m ([33mvinyk-sd-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
import html
import json

import pandas as pd
from IPython.display import display, HTML

# Attention records exported by the model, parsed from UTF-8 JSON text
with open("attention.json", encoding="utf-8") as attention_file:
    attention_data = json.loads(attention_file.read())

# Per-example predictions, read into a DataFrame
pred_df = pd.read_csv("predictions.csv")

In [3]:
def visualize_attention_with_predictions(attention_data, pred_df, return_html=False):
    """Build an interactive HTML view of character-level attention.

    For every example, the input characters and output characters are drawn
    as two rows of boxes. Hovering an output character tints each input box
    with an orange whose opacity is the corresponding attention weight.

    Args:
        attention_data: Iterable of mappings with keys ``"input"`` (sequence
            of input characters), ``"output"`` (sequence of output characters)
            and ``"attention"`` (one row of weights per output character;
            ``row[i]`` is applied to input position ``i``). Weights are used
            directly as a CSS alpha value, so they are presumably in [0, 1].
        pred_df: DataFrame with ``"Input"``, ``"Target"`` and ``"Prediction"``
            columns. Row ``k`` (by position) is shown under example ``k`` when
            it exists.
        return_html: When true, return the markup as a string instead of
            rendering it in the notebook.

    Returns:
        The HTML string when ``return_html`` is true, otherwise ``None``.
    """
    html_head = """
    <style>
        .container { font-family: Arial, sans-serif; margin: 10px 0; }
        .row { display: flex; margin-bottom: 5px; }
        .char {
            padding: 8px;
            margin: 2px;
            border: 1px solid #ccc;
            border-radius: 4px;
            min-width: 20px;
            text-align: center;
            transition: background-color 0.3s;
            font-size: 20px;
        }
        .input-char { background-color: #eee; }
        .output-char { background-color: #f0f8ff; cursor: pointer; }
        .word-info { font-size: 16px; margin-top: 5px; }
    </style>
    <div class="container">
    """

    # Collect fragments and join once instead of repeatedly growing a string.
    body_parts = []

    for idx, item in enumerate(attention_data):
        input_chars = item["input"]
        output_chars = item["output"]
        attention_matrix = item["attention"]

        # Element ids are namespaced per example so hover handlers only touch
        # the input row that belongs to the hovered output row.
        input_id_prefix = f"input_{idx}_"

        body_parts.append(f'<div><strong>Example {idx + 1}</strong></div>')
        body_parts.append('<div class="row">')
        for i, ch in enumerate(input_chars):
            # Escape so characters such as "<" or "&" are shown, not parsed.
            body_parts.append(
                f'<div class="char input-char" id="{input_id_prefix}{i}">'
                f'{html.escape(str(ch))}</div>'
            )
        body_parts.append('</div>')

        body_parts.append('<div class="row">')
        for j, out_ch in enumerate(output_chars):
            # An output position without an attention row gets no highlight
            # instead of aborting the whole rendering with an IndexError.
            weights = attention_matrix[j] if j < len(attention_matrix) else []
            weight_str = ",".join(str(w) for w in weights)
            body_parts.append(
                f'<div class="char output-char" '
                f'onmouseover="highlightAttention([{weight_str}], \'{input_id_prefix}\')" '
                f'onmouseout="clearHighlight(\'{input_id_prefix}\', {len(input_chars)})">'
                f'{html.escape(str(out_ch))}</div>'
            )
        body_parts.append('</div>')

        if idx < len(pred_df):
            # The guard above is positional, so the lookup must be positional
            # too; .loc would fail on any non-default index.
            row = pred_df.iloc[idx]
            body_parts.append(
                f'<div class="word-info">'
                f'<b>Input Word:</b> {html.escape(str(row["Input"]))}<br>'
                f'<b>Actual Word:</b> {html.escape(str(row["Target"]))}<br>'
                f'<b>Predicted Word:</b> {html.escape(str(row["Prediction"]))}'
                f'</div>'
            )

        body_parts.append('<hr style="margin:10px 0;">')

    html_tail = """
    </div>
    <script>
        function highlightAttention(weights, inputPrefix) {
            for (let i = 0; i < weights.length; i++) {
                const el = document.getElementById(inputPrefix + i);
                if (el) {
                    el.style.backgroundColor = `rgba(255, 165, 0, ${weights[i]})`;
                }
            }
        }

        function clearHighlight(inputPrefix, len) {
            for (let i = 0; i < len; i++) {
                const el = document.getElementById(inputPrefix + i);
                if (el) {
                    el.style.backgroundColor = '#eee';
                }
            }
        }
    </script>
    """

    html_content = html_head + "".join(body_parts) + html_tail

    if return_html:
        return html_content
    else:
        display(HTML(html_content))

In [4]:
# Render the visualization to a string rather than displaying it inline
html_content = visualize_attention_with_predictions(
    attention_data,
    pred_df,
    return_html=True,
)

# Write the markup to disk as UTF-8
html_file_path = "attention_viz.html"
with open(html_file_path, mode="w", encoding="utf-8") as html_file:
    html_file.write(html_content)

print(f"Saved visualization to {html_file_path}")

Saved visualization to attention_viz.html


In [6]:
# Upload the saved HTML visualization to W&B as an artifact.
# Using the run as a context manager finishes it when the block exits, even if
# adding or logging the artifact raises, so a failed cell does not leave an
# active run behind.
with wandb.init(project="Assignment3_Attempt1", name="attention-visual") as run:  # <-- Update project name
    artifact = wandb.Artifact("attention_visualization", type="html")
    artifact.add_file("attention_viz.html")
    run.log_artifact(artifact)
