In [1]:
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
import matplotlib.lines as mlines
import textwrap
from typing import List, Dict
import html
import re
from html2image import Html2Image
import numpy as np
from io import BytesIO
import base64

import matplotlib
matplotlib.use('Agg')

In [8]:
class StoryGenerator:
    """Build an HTML figure: story boxes with colored keywords, SVG arrows and a layer-wise plot.

    The page is a fixed 1300x400px flex layout. The left column (at most 60% wide) holds the
    story boxes, the "Target:" box and the arrow overlays. The right column (40%) holds the
    matplotlib plot, embedded as a base64 PNG data URI.

    Args:
        stories: dicts with 'story', 'question' and 'answer' strings. The first story is
            labelled "Alternate", every later one "Original".
        arrows: dicts with 'start' and 'end' (x, y) pixel coordinates for the SVG overlay and
            an optional 'color' (default 'black'). An arrow whose start equals its end is drawn
            as a curved loop instead of a straight line.
        plot_data: keys 'labels', 'acc_one_layer', 'acc_upto_layer', 'acc_from_layer',
            'title', 'x_label' and 'y_label'. If 'prob_one_layer' is present,
            'prob_upto_layer' and 'prob_from_layer' are required too; they are drawn on a
            secondary y-axis.
        target: text shown after the "Target:" label (keyword-colored like any other text).
    """

    # Static page header: CSS for the two-column layout plus the opening of the left column.
    _PAGE_HEAD = """
<!DOCTYPE html>
<html>
<head>
    <style>
        body {
            font-family: monospace;
            white-space: pre-wrap;
            line-height: 1.2;
            margin: 0 5px;
            background-color: white;
            position: relative;
            font-size: 18px;
            display: flex;
            height: 400px;
            width: 1300px;
        }
        .left-container {
            flex: 1;
            display: flex;
            max-width: 60%;
            flex-direction: column;
            justify-content: space-around;
        }
        .right-container {
            display: flex;
            flex-direction: column;
            flex: 0 0 40%;
        }
        .story-container {
            position: relative;
            border: 1px solid black;
            padding: 5px;
        }
        .target-container {
            border: 1px solid black;
            display: inline-block;
            padding: 5px;
        }
        .svg-container {
            position: absolute;
            top: 0;
            left: 0;
            width: 100%;
            height: 100%;
            pointer-events: none;
        }
        .label {
            background-color: brown;
            color: white;
            padding: 2px 5px;
            margin-bottom: 2px;
            display: inline-block;
            font-weight: bold;
        }
        img {
            height: 93%;
            width: -webkit-fill-available;
            margin: auto;
            position: relative;
            top: 15px;
        }
    </style>
</head>
<body>
<div class="left-container">
"""

    def __init__(self, stories: List[Dict], arrows: List[Dict], plot_data: Dict,
                 target: str = "champagne"):
        self.character_words = {'carl', 'phil'}  # sky blue
        self.container_words = {'tote', 'flute'}  # maroon
        self.object_words = {'cider', 'vodka', 'rum', 'champagne'}  # dark green
        self.arrows = arrows
        self.stories = stories
        # Only the target word itself; the "Target:" label is added when rendering.
        self.target = target
        self.plot_data = plot_data

    def _word_color(self, clean_word: str):
        """Return the CSS color for a normalized word, or None if it is not a keyword."""
        if clean_word in self.character_words:
            return 'skyblue'
        if clean_word in self.container_words:
            return 'maroon'
        if clean_word in self.object_words:
            return 'darkgreen'
        return None

    def color_text(self, text: str) -> str:
        """Wrap known character/container/object words in bold, colored <span> tags.

        The text is split on whitespace and re-joined with single spaces. Matching ignores
        case and punctuation, so "Phil." matches "phil". Every word is HTML-escaped, so
        characters such as '<' or '&' in the input cannot break the generated markup.
        """
        colored_words = []
        for word in text.split():
            clean_word = re.sub(r'[^\w\s]', '', word.lower())  # Remove punctuation
            safe_word = html.escape(word)
            color = self._word_color(clean_word)
            if color is None:
                colored_words.append(safe_word)
            else:
                colored_words.append(
                    f'<span style="color:{color}; font-weight:bold;">{safe_word}</span>')
        return ' '.join(colored_words)

    def _stories_html(self) -> str:
        """One labelled box per story: the first is "Alternate", the rest "Original"."""
        blocks = []
        for i, story in enumerate(self.stories):
            label = "Alternate" if i == 0 else "Original"
            blocks.append(
                f'<div class="box-wrapper"><div class="label">{label}</div>'
                f'<div class="story-container" id="story-{i}">'
                f'Story: {self.color_text(story["story"])}<br>'
                f'Question: {self.color_text(story["question"])}<br>'
                f'Answer: {self.color_text(story["answer"])}</div></div>'
            )
        return ''.join(blocks)

    def _target_html(self) -> str:
        """Box showing the "Target:" label followed by the colored target text."""
        return f'<div class="target-container">Target: {self.color_text(self.target)}</div>'

    def _arrows_svg(self) -> str:
        """One absolutely-positioned SVG overlay per arrow, each with its own arrowhead marker.

        Arrows missing 'start' or 'end' are reported on stdout and skipped.
        """
        svgs = []
        for i, arrow in enumerate(self.arrows):
            try:
                start_x, start_y = arrow['start']
                end_x, end_y = arrow['end']
            except KeyError as e:
                print(f"Arrow configuration error: {e}")
                continue
            color = arrow.get('color', 'black')

            if start_x == end_x and start_y == end_y:
                # Self-loop: cubic Bezier with both control points 75px above the point
                # (SVG y grows downward), ending 10px to its left so the arrowhead is visible.
                path = (f"M {start_x},{start_y} C {start_x + 75},{start_y - 75} "
                        f"{end_x - 75},{end_y - 75} {end_x - 10},{end_y}")
            else:
                # Extra coordinate pairs after "M" are treated by SVG as an implicit line-to.
                path = f"M {start_x},{start_y} {end_x},{end_y}"

            # Marker ids are suffixed with the arrow index so each arrowhead keeps its color.
            svgs.append(f"""
<svg class="svg-container">
    <path d="{path}"
          fill="none"
          stroke="{color}"
          stroke-width="4"
          style="paint-order: stroke fill;"
          marker-end="url(#arrowhead_{i})"/>
    <defs>
        <marker id="arrowhead_{i}" markerWidth="6" markerHeight="4.2" refX="4.5" refY="2.1" orient="auto">
            <polygon points="0 0, 6 2.1, 0 4.2" fill="{color}"/>
        </marker>
    </defs>
</svg>
""")
        return ''.join(svgs)

    def _plot_img_html(self) -> str:
        """Draw the accuracy (and optional probability) curves and return them as an <img> tag."""
        data = self.plot_data
        x = np.arange(len(data['labels']))
        fig, ax = plt.subplots(figsize=(6, 4))
        try:
            ax.plot(x, data['acc_one_layer'], marker='o', color='black', linestyle='-', label='One layer')
            ax.plot(x, data['acc_upto_layer'], marker='*', color='black', linestyle='-', label='Upto layer')
            ax.plot(x, data['acc_from_layer'], marker='^', color='black', linestyle='-', label='From layer')

            ax.set_xticks(x)
            ax.set_xticklabels(data['labels'])
            ax.set_title(data['title'])
            ax.set_xlabel(data['x_label'])
            ax.set_ylabel(data['y_label'], color='black')
            ax.set_ylim(-0.1, 1.1)
            ax.tick_params(axis='y', labelcolor='black')
            ax.legend()
            ax.grid(True)

            for line in ax.get_lines():
                line.set_markersize(8)

            text_items = ([ax.title, ax.xaxis.label, ax.yaxis.label]
                          + ax.get_xticklabels() + ax.get_yticklabels())

            if 'prob_one_layer' in data:
                # Rotate the x-axis labels
                ax.tick_params(axis='x', labelrotation=90)

                ax2 = ax.twinx()
                ax2.set_ylabel('Probability', color='deeppink')
                ax2.set_ylim(-0.1, 1.1)
                ax2.tick_params(axis='y', labelcolor='deeppink')

                ax2.plot(x, data['prob_one_layer'], marker='o', color='hotpink', linestyle='--', label='One layer')
                ax2.plot(x, data['prob_upto_layer'], marker='*', color='hotpink', linestyle='--', label='Upto layer')
                ax2.plot(x, data['prob_from_layer'], marker='^', color='hotpink', linestyle='--', label='From layer')

                # Semi-transparent so the probability curves do not hide the accuracy curves.
                for line in ax2.get_lines():
                    line.set_alpha(0.5)

                text_items += [ax2.yaxis.label] + ax2.get_yticklabels()

            for item in text_items:
                item.set_fontsize(18)
                item.set_fontname('Times New Roman')

            fig.tight_layout()

            with BytesIO() as buf:
                fig.savefig(buf, format='png', dpi=300)
                png_bytes = buf.getvalue()
        finally:
            # Without this, each generate_html() call would leave another figure open in pyplot.
            plt.close(fig)

        encoded = base64.b64encode(png_bytes).decode("utf-8")
        return f'<img src="data:image/png;base64,{encoded}" alt="Plot"/>'

    def generate_html(self) -> str:
        """Return the complete standalone HTML page (stories, target, arrows and plot)."""
        return ''.join([
            self._PAGE_HEAD,
            self._stories_html(),
            self._target_html(),
            self._arrows_svg(),
            """
</div>
<div class="right-container">
""",
            self._plot_img_html(),
            """
</div>
</body>
</html>""",
        ])

    def save_html(self, filename: str = "../plots/experiments/output.html"):
        """Write generate_html() to `filename` as UTF-8, creating its parent directory if needed."""
        from pathlib import Path

        path = Path(filename)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(self.generate_html(), encoding="utf-8")

    def save_image(self, filename: str = "output.png",
                   output_path: str = "../plots/experiments/"):
        """Render generate_html() to a PNG named `filename` inside `output_path` via Html2Image."""
        hti = Html2Image(output_path=output_path)
        hti.screenshot(html_str=self.generate_html(), save_as=filename)

In [12]:
plot_data = {
    'labels': [i for i in range(0, 126, 10)],
    'acc_one_layer': [min(1, i/100) for i in range(0, 126, 10)],
    'acc_upto_layer': [min(1, i/100 + 0.1) for i in range(0, 126, 10)],
    'acc_from_layer': [min(1, max(i/100 - 0.1, 0)) for i in range(0, 126, 10)],
    'title': 'Aligning Value Fetcher Variable',
    'x_label': 'Layers',
    'y_label': 'Intervention Accuracy'
}

generator = StoryGenerator(stories, arrows, plot_data)
generator.save_html()

In [6]:
arrows = [{'start': token_pos_coords['e1_charac1'], 'end': token_pos_coords['e2_charac2'], 'color': 'skyblue'},
          {'start': token_pos_coords['e1_charac2'], 'end': token_pos_coords['e2_charac1'], 'color': 'skyblue'},
          {'start': token_pos_coords['e1_obj1'], 'end': token_pos_coords['e2_obj2'], 'color': 'maroon'},
          {'start': token_pos_coords['e1_obj2'], 'end': token_pos_coords['e2_obj1'], 'color': 'maroon'},
          {'start': token_pos_coords['e2_state2'], 'end': token_pos_coords['e2_state2'], 'color': 'darkgreen'},
          {'start': token_pos_coords['e2_state1'], 'end': token_pos_coords['e2_state1'], 'color': 'darkgreen'}]

In [4]:
stories =  [
            {
                "story": "Carl and Phil are working in entirely separate sections of a busy restaurant, with no visibility between them. To complete an order, Carl grabs an opaque tote and fills it with cider. Then Phil grabs another opaque flute and fills it with vodka.",
                "question": "What does the flute contain?",
                "answer": "vodka"
            },
            {
                "story": "Carl and Phil are working in entirely separate sections of a busy restaurant, with no visibility between them. To complete an order, Phil grabs an opaque flute and fills it with rum. Then Carl grabs another opaque tote and fills it with champagne.",
                "question": "What does the flute contain?",
                "answer": "rum"
            }
        ]

In [5]:
token_pos_coords = {
    "e1_last": (95, 170),
    "e2_last": (95, 325),
    "e1_query_obj_real": (310, 145),
    "e1_query_obj_belief": (450, 145),
    "e2_query_obj_real": (310, 300),
    "e2_query_obj_belief": (450, 300),
    "e1_query_charac": (260, 145),
    "e2_query_charac": (260, 300),
    "e1_obj1": (360, 105),
    "e1_obj2": (280, 125),
    "e2_obj1": (360, 260),
    "e2_obj2": (280, 280),
    "e1_state1": (610, 105),
    "e1_state2": (550, 125),
    "e2_state1": (610, 260),
    "e2_state2": (550, 280),
    "e1_charac1": (120, 105),
    "e1_charac2": (730, 105),
    "e2_charac1": (120, 255),
    "e2_charac2": (720, 250),
}