In [11]:
import json
import os

import numpy as np

# Singular / plural wording used for a class in the generated pages.
WORDING_CLASS = "Angiosperm Family"
WORDING_CLASSES = "Angiosperm Families"

# Opacity component shared by every rgba() colour defined below.
ALPHA = 0.95

# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only load
# these .npy files when they come from a trusted source.
DATA = np.load('website_data.npy', allow_pickle=True).item()
DATA_CONCEPTS = np.load('website_data_concepts.npy', allow_pickle=True).item()
ID_TO_CLS = DATA['id']
# Class names ordered by the consecutive integer keys 0..len-1 of ID_TO_CLS.
CLASS_NAMES = [ID_TO_CLS[class_index] for class_index in range(len(ID_TO_CLS))]


def _rgba(red, green, blue):
    """Format a CSS rgba() colour string using the shared ALPHA opacity."""
    return f"rgba({red},{green},{blue}, {ALPHA})"


slack_blue = _rgba(54, 197, 240)
slack_green = _rgba(46, 182, 125)
slack_red = _rgba(210, 40, 95)
slack_yellow = _rgba(236, 178, 46)
slack_violet = _rgba(84, 25, 85)

google_blue = _rgba(66, 133, 244)
google_red = _rgba(219, 68, 55)
google_yellow = _rgba(244, 180, 0)
google_green = _rgba(15, 157, 88)

deep_purple = _rgba(103, 58, 183)
pink = _rgba(236, 64, 122)
anthracite = _rgba(13, 13, 21)

# Ordered palette; entries are looked up by position elsewhere in the notebook.
COLORS = [
    google_blue,
    slack_red,
    slack_yellow,
    slack_green,
    slack_violet,
    anthracite,
    pink,
    slack_blue,
    deep_purple,
    google_red,
]

In [3]:
def class_scatter_plot(x, y, container_id, colors, point_size=5, labels=None, chart_title="Scatter plot", radius=0.1):
    """Return the JavaScript source of a Chart.js bubble chart with one point per class.

    Args:
        x, y: Indexable sequences of point coordinates; one point is emitted per
            element of ``x``.
        container_id: DOM id of the element whose 2D context the chart is drawn on.
        colors: Iterable of CSS colour strings, used for both fill and border.
        point_size: Bubble radius (``r``) given to every point.
        labels: Optional sequence of labels (one per point). The label is shown in
            the tooltip and used to build the ``./classes/<label>`` link.
        chart_title: Title / dataset label of the chart.
        radius: Distance, in chart data coordinates, within which points are
            listed in the custom tooltip.

    Returns:
        str: The stripped JavaScript snippet.

    All Python strings are embedded as JSON string literals (``json.dumps``), so
    quotes, backslashes or line breaks in a label, title or id cannot break the
    generated script.
    """
    dataset = []
    for i in range(len(x)):
        # JSON string literals are valid JavaScript string literals.
        label = json.dumps(str(labels[i])) if labels is not None else '""'
        point = f"{{x: {x[i]}, y: {y[i]}, r: {point_size}, label: {label}, cls_id: {i}}}"
        dataset.append(point)

    background_colors = [json.dumps(str(color)) for color in colors]
    border_colors = background_colors

    container_js = json.dumps(str(container_id))
    title_js = json.dumps(str(chart_title))
    data_js = ', '.join(dataset)
    background_js = ', '.join(background_colors)
    border_js = ', '.join(border_colors)

    # Doubled braces are literal braces of the emitted JavaScript.
    return f"""
    const ctx_cls = document.getElementById({container_js}).getContext('2d');
    const chart_cls = new Chart(ctx_cls, {{
        type: 'bubble',
        options: {{
            plugins: {{
                legend: {{
                    display: false
                }},
                title: {{
                    display: false,
                    text: {title_js}
                }},
                tooltip: {{
                    enabled: false, // Disable default tooltips
                    external: function(context) {{
                        let tooltipEl = document.getElementById('chartjs-tooltip');
                        if (!tooltipEl) {{
                            tooltipEl = document.createElement('div');
                            tooltipEl.id = 'chartjs-tooltip';
                            tooltipEl.style.position = 'absolute';
                            tooltipEl.style.background = 'rgba(255, 255, 255, 0.9)';
                            tooltipEl.style.border = '1px solid #ccc';
                            tooltipEl.style.borderRadius = '5px';
                            tooltipEl.style.padding = '10px';
                            tooltipEl.style.pointerEvents = 'none';
                            tooltipEl.style.textAlign = 'center';
                            document.body.appendChild(tooltipEl);
                        }}

                        const tooltipModel = context.tooltip;
                        if (!tooltipModel) {{
                            tooltipEl.style.opacity = '0';
                            return;
                        }}

                        // Mouse position in chart coordinates
                        const chart = context.chart;
                        const mouseX = chart.scales.x.getValueForPixel(tooltipModel.caretX);
                        const mouseY = chart.scales.y.getValueForPixel(tooltipModel.caretY);

                        // Calculate distances and filter points within the radius
                        const radius = {radius};
                        const points = chart.data.datasets[0].data
                            .map((point, index) => {{
                                const dx = point.x - mouseX;
                                const dy = point.y - mouseY;
                                const distance = Math.sqrt(dx * dx + dy * dy);
                                return {{ distance, point, index }};
                            }})
                            .filter(item => item.distance <= radius)
                            .sort((a, b) => a.distance - b.distance);

                        // Hide if no points in range
                        if (points.length === 0) {{
                            tooltipEl.style.opacity = '0';
                            return;
                        }}

                        // Build tooltip content
                        const content = points.map(({{point}}) => {{
                            const label = point.label;
                            const cls_id = point.cls_id;
                            const link = `./classes/${{label}}`;
                            return `<strong><a href="${{link}}" target="_blank">${{label}}</a></strong>`;
                        }}).join('<br>');

                        tooltipEl.innerHTML = `
                            <div style="color: black; font-size:18px">
                                ${{content}}
                            </div>`;

                        // Position tooltip
                        const position = chart.canvas.getBoundingClientRect();
                        tooltipEl.style.opacity = '1';
                        tooltipEl.style.left = position.left + window.pageXOffset + tooltipModel.caretX + 'px';
                        tooltipEl.style.top = position.top + window.pageYOffset + tooltipModel.caretY + 'px';
                        tooltipEl.style.zIndex = '100';
                    }}
                }}
            }},
            scales: {{
                x: {{
                    display: false
                }},
                y: {{
                    display: false
                }}
            }},
            onClick: function(event, elements) {{
                if (elements.length > 0) {{
                    const index = elements[0].index;
                    const datasetIndex = elements[0].datasetIndex;
                    const dataPoint = this.data.datasets[datasetIndex].data[index];
                    const label = dataPoint.label;
                    const link = `./classes/${{label}}`;
                    window.open(link, '_blank'); // Open the link in a new tab
                }}
            }}
        }},
        data: {{
            datasets: [{{
                label: {title_js},
                data: [{data_js}],
                backgroundColor: [{background_js}],
                borderColor: [{border_js}],
                borderWidth: 1
            }}]
        }}
    }});
    """.strip()

# 2D embedding of the classes (columns 0 and 1 are used as x / y below).
umap_cls = np.load('umap_cls.npy')
# Not used in this cell; the dictionary cell loads the same file again.
umap_dictionary = np.load('umap_dictionary.npy')

nb_classes = len(CLASS_NAMES)
js_cls = class_scatter_plot(umap_cls[:nb_classes, 0], umap_cls[:nb_classes, 1], "scatterClass",
                            [COLORS[0]] * nb_classes, point_size=5,
                            labels=CLASS_NAMES,
                            chart_title="Classes")

# Make sure the output directory exists; opening with 'w' already replaces any
# previous file, so no explicit removal is needed.
os.makedirs('docs/js', exist_ok=True)
with open('docs/js/class_umap.js', 'w', encoding='utf-8') as f:
    f.write(js_cls)

print('done')

done


In [12]:
def dictionary_scatter_plot(x, y, container_id, colors, point_size=5, labels=None, cls_id=None, chart_title="Scatter plot", radius=0.07):
    """Return the JavaScript source of a Chart.js bubble chart with one point per concept.

    Args:
        x, y: Indexable sequences of point coordinates; one point is emitted per
            element of ``x``.
        container_id: DOM id of the element whose 2D context the chart is drawn on.
        colors: Iterable of CSS colour strings, used for both fill and border.
        point_size: Bubble radius (``r``) given to every point.
        labels: Optional sequence of labels (one per point). The label is shown in
            the tooltip and used to build the ``./concepts/<label>`` link.
        cls_id: Optional sequence of ids (one per point) used to build the
            ``concept_<id>_*.webp`` image URLs in the tooltip. When omitted, the
            position of the point is used.
        chart_title: Title / dataset label of the chart.
        radius: Distance, in chart data coordinates, within which points are
            listed in the custom tooltip.

    Returns:
        str: The stripped JavaScript snippet.

    All Python strings are embedded as JSON string literals (``json.dumps``), so
    quotes, backslashes or line breaks in a label, title or id cannot break the
    generated script.
    """
    # Fall back to positional ids instead of failing on ``None[i]``.
    ids = cls_id if cls_id is not None else range(len(x))

    dataset = []
    for i in range(len(x)):
        # JSON string literals are valid JavaScript string literals.
        label = json.dumps(str(labels[i])) if labels is not None else '""'
        point = f"{{x: {x[i]}, y: {y[i]}, r: {point_size}, label: {label}, cls_id: {ids[i]}}}"
        dataset.append(point)

    background_colors = [json.dumps(str(color)) for color in colors]
    border_colors = background_colors

    container_js = json.dumps(str(container_id))
    title_js = json.dumps(str(chart_title))
    data_js = ', '.join(dataset)
    background_js = ', '.join(background_colors)
    border_js = ', '.join(border_colors)

    # Doubled braces are literal braces of the emitted JavaScript.
    return f"""
    const ctx_dico = document.getElementById({container_js}).getContext('2d');
    const chart_dico = new Chart(ctx_dico, {{
        type: 'bubble',
        options: {{
            plugins: {{
                legend: {{
                    display: false
                }},
                title: {{
                    display: false,
                    text: {title_js}
                }},
                tooltip: {{
                    enabled: false, // Disable default tooltips
                    external: function(context) {{
                        let tooltipEl = document.getElementById('chartjs-tooltip');
                        if (!tooltipEl) {{
                            tooltipEl = document.createElement('div');
                            tooltipEl.id = 'chartjs-tooltip';
                            tooltipEl.style.position = 'absolute';
                            tooltipEl.style.background = 'rgba(255, 255, 255, 0.9)';
                            tooltipEl.style.border = '1px solid #ccc';
                            tooltipEl.style.borderRadius = '5px';
                            tooltipEl.style.padding = '10px';
                            tooltipEl.style.pointerEvents = 'none';
                            tooltipEl.style.textAlign = 'center';
                            document.body.appendChild(tooltipEl);
                        }}

                        const tooltipModel = context.tooltip;
                        if (!tooltipModel) {{
                            tooltipEl.style.opacity = '0';
                            return;
                        }}

                        // Mouse position in chart coordinates
                        const chart = context.chart;
                        const mouseX = chart.scales.x.getValueForPixel(tooltipModel.caretX);
                        const mouseY = chart.scales.y.getValueForPixel(tooltipModel.caretY);

                        // Calculate distances and filter points within the radius
                        const radius = {radius};
                        const points = chart.data.datasets[0].data
                            .map((point, index) => {{
                                const dx = point.x - mouseX;
                                const dy = point.y - mouseY;
                                const distance = Math.sqrt(dx * dx + dy * dy);
                                return {{ distance, point, index }};
                            }})
                            .filter(item => item.distance <= radius)
                            .sort((a, b) => a.distance - b.distance);

                        // Hide if no points in range
                        if (points.length === 0) {{
                            tooltipEl.style.opacity = '0';
                            return;
                        }}

                        // Build tooltip content
                        const content = points.map(({{point}}) => {{
                            const label = point.label;
                            const cls_id = point.cls_id;
                            const link = `./concepts/${{label}}`;
                            return `
                            <div style="width: 500px; text-align: center;"> <!-- Fixed width -->
                                <strong><a href="${{link}}" target="_blank" style="text-decoration: none; color: black;">${{label}}</a></strong>
                                <br>
                                <div style="display: flex; flex-direction: row; justify-content: center; width: 100%; box-sizing: border-box;">
                                    <img src="https://storage.googleapis.com/serrelab/prj_fossils/thomas_sae_compressed/concept_${{cls_id}}_fv.webp"
                                        style="width: 50%; object-fit: contain;"/>
                                    <img src="https://storage.googleapis.com/serrelab/prj_fossils/thomas_sae_compressed/concept_${{cls_id}}_0.webp"
                                        style="width: 25%; object-fit: contain;"/>
                                    <img src="https://storage.googleapis.com/serrelab/prj_fossils/thomas_sae_compressed/concept_${{cls_id}}_1.webp"
                                        style="width: 25%; object-fit: contain;"/>
                                </div>
                            </div>
                            <br>
                            `;
                        }}).join('<br>');

                        tooltipEl.innerHTML = `
                            <div style="color: black; font-size:18px">
                                ${{content}}
                            </div>`;

                        // Position tooltip
                        const position = chart.canvas.getBoundingClientRect();
                        tooltipEl.style.opacity = '1';
                        tooltipEl.style.left = position.left + window.pageXOffset + tooltipModel.caretX + 'px';
                        tooltipEl.style.top = position.top + window.pageYOffset + tooltipModel.caretY + 'px';
                        tooltipEl.style.zIndex = '100';
                    }}
                }}
            }},
            scales: {{
                x: {{
                    display: false
                }},
                y: {{
                    display: false
                }}
            }},
            onClick: function(event, elements) {{
                if (elements.length > 0) {{
                    const index = elements[0].index;
                    const datasetIndex = elements[0].datasetIndex;
                    const dataPoint = this.data.datasets[datasetIndex].data[index];
                    const label = dataPoint.label;
                    const link = `./concepts/${{label}}`;
                    window.open(link, '_blank'); // Open the link in a new tab
                }}
            }}
        }},
        data: {{
            datasets: [{{
                label: {title_js},
                data: [{data_js}],
                backgroundColor: [{background_js}],
                borderColor: [{border_js}],
                borderWidth: 1
            }}]
        }}
    }});
    """.strip()


# 2D embedding of the dictionary concepts (columns 0 and 1 are used as x / y below).
umap_dictionary = np.load('umap_dictionary.npy')
# Used both to index rows of the embedding and as the concept ids in labels / image URLs.
# NOTE(review): presumably an integer id array (file is named ids_alive) - confirm.
is_alive = np.load('ids_alive.npy')

js_dico = dictionary_scatter_plot(umap_dictionary[is_alive, 0], umap_dictionary[is_alive, 1], "scatterDico",
                                  [COLORS[0]] * len(is_alive), point_size=1,
                                  labels=[f'Concept {i}' for i in is_alive],
                                  cls_id=is_alive,
                                  chart_title="Dictionary")

# Make sure the output directory exists; opening with 'w' already replaces any
# previous file, so no explicit removal is needed.
os.makedirs('docs/js', exist_ok=True)
with open('docs/js/dico_umap.js', 'w', encoding='utf-8') as f:
    f.write(js_dico)

print('done')

done


In [13]:
# IPython shell escapes: delete everything under docs/classes/ and docs/concepts/
# (presumably to clear pages from a previous run before they are regenerated below).
!rm -rf docs/classes/*
!rm -rf docs/concepts/*

In [14]:
# Upper bound applied (as a `[:limit]` slice) to the class and concept keys
# processed below when writing pages and the mkdocs nav.
limit = 1_000_000

def block_feature_viz(concept_id, score, order, alpha=1.0, nb_crops=10):
    """Return the HTML gallery card for one concept.

    Args:
        concept_id: Id used in the card's ``id`` attribute, its caption and the
            ``concept_<id>_*.webp`` image URLs.
        score: Importance value, rendered with two decimals followed by ``%``.
        order: Position of the concept in the gallery. Cards with
            ``score > 10`` or ``order < 3`` are coloured from ``COLORS``; all
            others use the ``--text-color`` CSS variable.
        alpha: Unused; kept so existing callers keep working.
        nb_crops: Number of hidden crop images (indices ``0..nb_crops-1``).

    Returns:
        str: The HTML snippet.
    """
    if score > 10.0 or order < 3:
        # Wrap around the palette so an order beyond its length cannot raise IndexError.
        color = COLORS[order % len(COLORS)]
    else:
        color = "var(--text-color)"

    hidden_images = ''.join(
        f'<img src="https://storage.googleapis.com/serrelab/prj_fossils/thomas_sae_compressed/concept_{concept_id}_{i}.webp" class="surprise-img img{i+1}">'
        for i in range(nb_crops)
    )

    content = f"""
<div class="gallery-container-img" style="border-color: {color}; color: {color}"  id="card-{concept_id}">
    <span style="color: {color}; z-index:3"><small>Concept</small> {concept_id} (<small>Importance</small> {score:.2f} %)</span>
    <img src="https://storage.googleapis.com/serrelab/prj_fossils/thomas_sae_compressed/concept_{concept_id}_fv.webp" class="gallery-img">
    <div class="hidden-images">
        {hidden_images}
    </div>
</div>
    """
    return content

def get_class_page_template(d, nb_neighbors=4, nb_top_concepts=5, nb_blocks=15):
    """Return the Markdown/HTML page content for one class.

    Args:
        d: Mapping with the keys ``'name'``, ``'nb_points'``, ``'similarity'``
            (a pair: class ids, similarity scores) and ``'importance'`` (a pair:
            concept ids, importance scores).
        nb_neighbors: Number of closest classes listed. Entry 0 of the similarity
            lists is skipped (presumably the class itself - confirm).
        nb_top_concepts: Number of concepts in the bullet list.
        nb_blocks: Number of feature-visualization cards in the gallery.

    Returns:
        str: The page content.

    Fewer entries than requested are tolerated: the lists are simply shorter
    instead of raising ``IndexError``.
    """
    name = d['name']

    similarity_cls, similarity_score = d['similarity']
    importance_concepts, importance_scores = d['importance']

    nb_points = d['nb_points']

    # Normalize the importance scores to percentages; an all-zero total would
    # otherwise produce NaN for every score.
    scores = np.array(importance_scores)
    total = scores.sum()
    if total:
        scores = scores / total
        scores = scores * 100
    else:
        scores = np.zeros(scores.shape)

    blocks = [
        block_feature_viz(concept_id, score, order)
        for order, (concept_id, score) in enumerate(zip(importance_concepts[:nb_blocks], scores))
    ]

    # "A,\nB,\nC,\nand D" - one link per line, the last one introduced by "and".
    neighbor_links = [
        f"[{ID_TO_CLS[cls]}](../classes/{ID_TO_CLS[cls]}.md) ({sim:.2f} similarity)"
        for cls, sim in zip(similarity_cls[1:nb_neighbors + 1], similarity_score[1:nb_neighbors + 1])
    ]
    if len(neighbor_links) > 1:
        neighbor_links[-1] = "and " + neighbor_links[-1]
    neighbors_text = ",\n".join(neighbor_links)

    # The bullet list shows the raw (un-normalized) importance scores.
    concepts_text = "\n".join(
        f"  - [Concept {concept}](../concepts/Concept {concept}.md) with an importance score of {raw_score:.2f}"
        for concept, raw_score in zip(importance_concepts[:nb_top_concepts], importance_scores[:nb_top_concepts])
    )

    gallery = ' '.join(blocks)

    content = f"""
# {name}

The {WORDING_CLASS} {name} contains <u><b>{nb_points}</b></u> data points.
The closest {WORDING_CLASSES} to {name} are
{neighbors_text}.

{name} are mainly characterized by the following concepts:

{concepts_text}

# Feature Visualization

The images below display the feature visualizations (or maximally activating images) for each concept.
Click on any feature visualization to view the top 10 images most strongly activating that concept.

<div class="feature-viz-intro">
    <div class="gallery">
        {gallery}
    </div>
</div>

"""

    return content

# write the class pages
# Ensure the target directory exists; opening with 'w' replaces any existing
# page, so no explicit removal is needed.
os.makedirs('docs/classes', exist_ok=True)
for cls_id in list(DATA['class_info'].keys())[:limit]:
    d = DATA['class_info'][cls_id]
    name = d['name']

    fn = f"docs/classes/{name}.md"
    cls_page = get_class_page_template(d)

    # Explicit UTF-8 so the output does not depend on the platform default encoding.
    with open(fn, 'w', encoding='utf-8') as file:
        file.write(cls_page)

def get_concept_page_template(concept_id):
    """Return the Markdown/HTML page content for one concept.

    Args:
        concept_id: Key into ``DATA_CONCEPTS``. The numeric index is taken from
            the second space-separated token (e.g. ``'Concept 12'`` -> ``12``).

    Returns:
        str: The page content.

    Raises:
        KeyError: If ``concept_id`` is not in ``DATA_CONCEPTS``.
        IndexError, ValueError: If ``concept_id`` does not have the
            ``'<word> <integer>'`` form.
    """
    idx = int(concept_id.split(' ')[1])
    d = DATA_CONCEPTS[concept_id]
    # 'nb_fire' is scaled by 100 and displayed as a percentage below.
    nb_fire = d['nb_fire'] * 100
    top_classes = d['top_classes']
    classes_links = ', '.join([f"[{ID_TO_CLS[cls]}](../classes/{ID_TO_CLS[cls]}.md)" for cls in top_classes])
    heatmaps = ''.join(
        f'<img src="https://storage.googleapis.com/serrelab/prj_fossils/thomas_sae_compressed/concept_{idx}_{i}.webp" class="heatmaps" style="width: 32%">'
        for i in range(10)
    )
    # The card id uses the numeric index: an HTML id must not contain whitespace,
    # and this matches the "card-<id>" form used by the class-page gallery cards.
    content = f"""
# Concept {idx}

The concept {idx} is activated in <u><b>{nb_fire:.2f}%</b></u> of the data points.
It is most important for the {WORDING_CLASSES} {classes_links}.

# Visualizations

<div style="border-color: {slack_blue}; color: {slack_blue}"  id="card-{idx}">
    <img src="https://storage.googleapis.com/serrelab/prj_fossils/thomas_sae_compressed/concept_{idx}_fv.webp" class="gallery-img" style="width: 100%">
</div>

The following images are the top 10 images that activate the concept {idx} the most.

<div class="heatmap-container">
    {heatmaps}
</div>
    """
    return content


# now write the concept pages
# Ensure the target directory exists; opening with 'w' replaces any existing
# page, so no explicit removal is needed.
os.makedirs('docs/concepts', exist_ok=True)
for concept_id in list(DATA_CONCEPTS.keys())[:limit]:
    # Concepts flagged 'is_dead' get no page.
    if DATA_CONCEPTS[concept_id]['is_dead']:
        continue

    fn = f"docs/concepts/{concept_id}.md"
    concept_page = get_concept_page_template(concept_id)

    # Explicit UTF-8 so the output does not depend on the platform default encoding.
    with open(fn, 'w', encoding='utf-8') as file:
        file.write(concept_page)


# now update the mkdocs.yml file
with open('mkdocs_template.yml', 'r', encoding='utf-8') as file:
    mkdocs_template = file.readlines()

# Entries for the generated nav sections (same selection as the page writers:
# first `limit` keys, dead concepts skipped).
nav_class_names = [DATA['class_info'][cls_id]['name']
                   for cls_id in list(DATA['class_info'].keys())[:limit]]
nav_concept_ids = [concept_id
                   for concept_id in list(DATA_CONCEPTS.keys())[:limit]
                   if not DATA_CONCEPTS[concept_id]['is_dead']]

home_line_found = False
with open('mkdocs.yml', 'w', encoding='utf-8') as file:
    # Copy the template line by line and insert the generated nav sections
    # right after the line '  - Home: index.md'; all other lines are preserved.
    for line in mkdocs_template:
        file.write(line)
        if line.strip() == '- Home: index.md':
            home_line_found = True
            # If the marker is the last line without a newline, start a new
            # line so the nav is not glued to it.
            if not line.endswith('\n'):
                file.write('\n')
            file.write(f'  - {WORDING_CLASSES}:\n')
            for name in nav_class_names:
                file.write(f'    - {name}: classes/{name}.md\n')
            file.write('\n')
            file.write('  - Concepts:\n')
            for concept_id in nav_concept_ids:
                file.write(f'    - {concept_id}: concepts/{concept_id}.md\n')

if not home_line_found:
    # Previously this case passed silently and produced a nav without any generated pages.
    print("warning: '- Home: index.md' not found in mkdocs_template.yml; no class/concept nav entries were added")


