In [None]:
from dataclasses import dataclass
import itertools as it

from pathlib import Path

import numpy as np

from concernbert.semantic import find_entity_docs
from concernbert.embeddings import load_caching_embedder
from concernbert.metrics import to_aad


@dataclass()
class CdResult:
    """Dispersion scores produced by ``CdCalculator.calc_cd``.

    Attributes:
        inter_cd: Mean of the per-group scores stored in ``groups``.
        intra_cd: ``to_aad`` computed over the embeddings of all texts at once.
        groups: Group name -> ``to_aad`` score of that group's embeddings.
    """

    # Average over the values of `groups`.
    inter_cd: float
    # Score over every embedding of the source taken together.
    intra_cd: float
    # One score per group reported by find_entity_docs.
    groups: dict[str, float]


class CdCalculator:
    """Score how spread out the entity embeddings of a source file are.

    Each call to :meth:`calc_cd` splits the source into named groups of texts
    with ``find_entity_docs``, embeds every text with a caching embedder, and
    summarises the embeddings with ``to_aad``.
    """

    def __init__(self, model: str, cache_dir: str, batch_size: int = 16):
        """Create the embedder used by :meth:`calc_cd`.

        Args:
            model: Model path/name passed to ``load_caching_embedder``.
            cache_dir: Cache directory passed to ``load_caching_embedder``.
            batch_size: Embedding batch size passed to ``load_caching_embedder``.
        """
        self._embedder = load_caching_embedder(model, cache_dir, batch_size)

    def calc_cd(self, source: str, *, pbar: bool = False) -> CdResult:
        """Compute per-group and whole-source dispersion for ``source``.

        Args:
            source: Source text handed to ``find_entity_docs``.
            pbar: Forwarded to the embedder to toggle its progress bar.

        Returns:
            A ``CdResult`` where:

            - ``groups`` maps each group name from ``find_entity_docs`` to
              ``to_aad`` of that group's embeddings (one row per text,
              duplicates included).
            - ``inter_cd`` is the mean of those group scores.
            - ``intra_cd`` is ``to_aad`` over the values of the embedder's
              text -> embedding mapping (presumably one entry per distinct
              text).

            If ``source`` yields no texts at all, every score is NaN rather
            than being computed from empty arrays.

        NOTE(review): ``inter_cd`` is built from within-group scores and
        ``intra_cd`` from all texts together. This reads opposite to the
        usual inter/intra meaning — confirm the naming.
        """
        res = find_entity_docs(source)
        texts = list(it.chain.from_iterable(res.values()))
        if not texts:
            # Nothing to embed: np.mean([]) would emit "Mean of empty slice"
            # and to_aad would be handed a 1-D empty array, so report NaN.
            nan = float("nan")
            return CdResult(nan, nan, {name: nan for name in res})
        embeddings = self._embedder.embed(texts, pbar=pbar)
        groups: dict[str, float] = {}
        for group_name, group_texts in res.items():
            group_embeddings = [embeddings[t] for t in group_texts]
            # float() keeps stored scores plain Python floats, like inter_cd.
            groups[group_name] = float(to_aad(np.array(group_embeddings)))
        inter_cd = float(np.mean(list(groups.values())))
        # NOTE(review): this aggregates the embedder's mapping, where a text that
        # appears in several groups (or twice in one) is presumably counted once.
        # The group scores above count duplicates — confirm this is intended.
        intra_cd = float(to_aad(np.array(list(embeddings.values()))))
        return CdResult(inter_cd, intra_cd, groups)

In [None]:
# Inline sample input: a Java class (Apache DeltaSpike's MessageBundleExtension)
# used as the document scored by calc_cd / find_entity_docs in the cells below.
# NOTE(review): the next cell rebinds SOURCE from View.java, so this literal is
# only used when that cell is skipped.
SOURCE = """
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.deltaspike.core.impl.message;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.event.Observes;
import jakarta.enterprise.inject.Default;
import jakarta.enterprise.inject.spi.*;
import jakarta.enterprise.inject.spi.configurator.BeanConfigurator;
import org.apache.deltaspike.core.api.message.Message;
import org.apache.deltaspike.core.api.message.MessageBundle;
import org.apache.deltaspike.core.api.message.MessageTemplate;
import org.apache.deltaspike.core.spi.activation.Deactivatable;
import org.apache.deltaspike.core.util.BeanConfiguratorUtils;
import org.apache.deltaspike.core.util.ClassDeactivationUtils;
import org.apache.deltaspike.core.util.ClassUtils;

import java.io.Serializable;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.*;

/**
 * Extension for handling {@link MessageBundle}s.
 *
 * @see MessageBundle
 * @see MessageTemplate
 */
public class MessageBundleExtension implements Extension, Deactivatable
{
    private final Collection<AnnotatedType<?>> messageBundleTypes = new HashSet<AnnotatedType<?>>();

    private List<String> deploymentErrors = new ArrayList<String>();

    private Boolean isActivated = true;

    @SuppressWarnings("UnusedDeclaration")
    protected void init(@Observes BeforeBeanDiscovery beforeBeanDiscovery)
    {
        isActivated = ClassDeactivationUtils.isActivated(getClass());
    }

    @SuppressWarnings("UnusedDeclaration")
    protected void detectInterfaces(@Observes ProcessAnnotatedType processAnnotatedType)
    {
        if (!isActivated)
        {
            return;
        }

        AnnotatedType<?> type = processAnnotatedType.getAnnotatedType();

        if (type.isAnnotationPresent(MessageBundle.class))
        {
            if (validateMessageBundle(type.getJavaClass()))
            {
                messageBundleTypes.add(type);
            }
        }
    }

    /**
     * @return <code>true</code> if all is well
     */
    private boolean validateMessageBundle(Class<?> currentClass)
    {
        boolean ok = true;

        // sanity check: annotated class must be an Interface
        if (!currentClass.isInterface())
        {
            deploymentErrors.add("@MessageBundle must only be used on Interfaces, but got used on class " +
                    currentClass.getName());
            return false;
        }

        for (Method currentMethod : currentClass.getDeclaredMethods())
        {
            if (!currentMethod.isAnnotationPresent(MessageTemplate.class))
            {
                continue;
            }
            
            if (String.class.isAssignableFrom(currentMethod.getReturnType()))
            {
                continue;
            }

            if (Message.class.isAssignableFrom(currentMethod.getReturnType()))
            {
                continue;
            }

            deploymentErrors.add(currentMethod.getReturnType().getName() + " isn't supported. Details: " +
                    currentMethod.getDeclaringClass().getName() + "#" + currentMethod.getName() +
                    " only " + String.class.getName() + " or " + Message.class.getName());
            ok = false;
        }

        return ok;
    }

    @SuppressWarnings("UnusedDeclaration")
    protected void installMessageBundleProducerBeans(@Observes AfterBeanDiscovery abd, BeanManager beanManager)
    {
        if (!deploymentErrors.isEmpty())
        {
            abd.addDefinitionError(new IllegalArgumentException("The following MessageBundle problems where found: " +
                    Arrays.toString(deploymentErrors.toArray())));
            return;
        }

        for (AnnotatedType<?> type : messageBundleTypes)
        {
            addAsBean(abd, beanManager, type);
        }
    }

    protected <T> void addAsBean(AfterBeanDiscovery abd, BeanManager beanManager, AnnotatedType<T> type)
    {
        BeanConfigurator<T> beanConfigurator = abd.addBean()
                .createWith(cc ->
                    {
                        final Bean<?> invocationHandlerBean = beanManager.resolve(
                                beanManager.getBeans(MessageBundleInvocationHandler.class));

                        return createMessageBundleProxy(type.getJavaClass(),
                                (MessageBundleInvocationHandler)
                                        beanManager.getReference(invocationHandlerBean, MessageBundleInvocationHandler.class, cc));
                    });
        BeanConfiguratorUtils.read(beanManager, beanConfigurator, type)
                .types(type.getJavaClass(), Object.class, Serializable.class)
                .addQualifier(Default.Literal.INSTANCE)
                .scope(ApplicationScoped.class) // needs to be a normalscope due to a bug in older Weld versions
                .id("MessageBundleBean#" + type.getJavaClass().getName());
    }

    @SuppressWarnings("UnusedDeclaration")
    protected void cleanup(@Observes AfterDeploymentValidation afterDeploymentValidation)
    {
        messageBundleTypes.clear();
    }

    private <T> T createMessageBundleProxy(Class<T> type, MessageBundleInvocationHandler handler)
    {
        return type.cast(Proxy.newProxyInstance(ClassUtils.getClassLoader(null),
            new Class<?>[]{type, Serializable.class}, handler));
    }
}
"""

In [None]:
# Replace the inline sample with a Java file from the working directory.
# Decode explicitly as UTF-8 so the text read does not depend on the
# platform's locale encoding (read_text() defaults to it).
SOURCE = Path("View.java").read_text(encoding="utf-8")

In [None]:
# Scorer backed by the EntityBERT checkpoint; embeddings are cached in "_foo".
cd_calculator = CdCalculator(
    model="_models/EntityBERT-v3_train_nonldl-lr5e5-2_83-e3",
    cache_dir="_foo",
)

In [None]:
# Score the current SOURCE with a progress bar; the notebook displays the
# returned CdResult as this cell's output.
cd_calculator.calc_cd(SOURCE, pbar=True)

In [None]:
# Inspect the raw grouping that calc_cd embeds: group name -> list of texts.
res = find_entity_docs(SOURCE)
res

In [None]:
# Select a single group's texts for a manual dispersion check.
# NOTE(review): the key "(45, 0)" is specific to the file currently in SOURCE
# (View.java above). It presumably raises KeyError for other inputs — pick a
# key from the `res` output first.
texts = res["(45, 0)"]

In [None]:
# Standalone embedder for the manual check below. It uses the same model path
# and cache directory as cd_calculator, but batch size 32 instead of
# CdCalculator's default of 16.
embedder = load_caching_embedder("_models/EntityBERT-v3_train_nonldl-lr5e5-2_83-e3", "_foo", 32)

In [None]:
# Stack one embedding row per entry of `texts`, duplicates included. This
# mirrors how CdCalculator.calc_cd builds each group's array, so to_aad(arr)
# is comparable to that group's score. Taking `.values()` of the returned
# mapping would collapse repeated texts and could disagree with calc_cd.
text_embeddings = embedder.embed(texts, pbar=True)
arr = np.array([text_embeddings[t] for t in texts])

In [None]:
# Dispersion of the selected group's embeddings, displayed as the cell output.
# Presumably this should match that group's entry in calc_cd(SOURCE).groups.
to_aad(arr)