# Sequential Prediction of F-Terms

This notebook shows how to load the model from Hugging Face and predict F-terms for given patent abstracts.

## Load model and files

In [None]:
#pip install -r requirements.txt
from transformers import AutoTokenizer, OPTForCausalLM
import torch
import pickle as pk

# Report whether PyTorch can use a CUDA device; as the last expression of the
# cell, the boolean result is displayed as the cell output.
torch.cuda.is_available()

In [None]:
# Request access via https://huggingface.co/RWTH-TIME/galactica-125m-f-term-classification
MODEL_ID = "RWTH-TIME/galactica-125m-f-term-classification"

# bfloat16 becomes torch's global default floating-point dtype and is also the
# dtype requested for the model weights.
default_dtype = torch.bfloat16
torch.set_default_dtype(default_dtype)

# Tokenizer and model are fetched from the same repository.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = OPTForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=default_dtype,
    low_cpu_mem_usage=True,
    device_map="auto",
)

In [3]:
# Load the theme, viewpoint and F-term dictionaries from pickle files.
# NOTE(security): pickle.load can execute arbitrary code while unpickling, so
# only load these files when they come from a source you trust.
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    with open(path, 'rb') as f:
        return pk.load(f)


# Optional, disabled by default: the full descriptions dictionary.
# full_descriptions_dict = _load_pickle('data/full_descriptions.pk')

f_term_dict = _load_pickle('data/f_term_dict.pk')
themes_descriptions = _load_pickle('data/themes_descriptions.pk')
viewpoints_descriptions = _load_pickle('data/viewpoints_descriptions.pk')

## Predict F-terms for given abstracts or technological descriptions

In [4]:
def _first_allowed_token(ranked_indices, token_offset, blocked_tokens):
    """
    Returns the most likely token ID that is not blocked.

    Parameters
    ----------
    ranked_indices : torch.Tensor
        1-D tensor of logit indices, ordered from most to least likely.
    token_offset : int
        Value added to a logit index to obtain the corresponding token ID.
    blocked_tokens : set
        Token IDs that must not be selected.

    Returns
    -------
    int or None
        The first admissible token ID, or None if every candidate is blocked.
    """
    for rank in range(ranked_indices.shape[0]):
        token_id = ranked_indices[rank].item() + token_offset
        if token_id not in blocked_tokens:
            return token_id
    return None


def generate(
    prompt,
    model,
    tokenizer,
    max_pred_tokens=10,
    decode=True,
    enforce_no_repetition=True,
    ignore_eos_token=True,
    fterm_token_offset=50000
):
    """
    Generates FTERM classifications for a given patent abstract.

    Appends "<START F-TERMS>" to the prompt and then predicts one token per
    model forward pass, feeding every predicted token back into the input.
    The forward passes run without gradient tracking.

    Parameters
    ----------
    prompt : str
        A patent abstract text or technology description.
    model : transformers.models.opt.modeling_opt.OPTForCausalLM
        The transformer model used for classification.
    tokenizer : transformers.PreTrainedTokenizer
        The tokenizer associated with the model.
    max_pred_tokens : int, optional
        The maximum number of patent classes to predict (default is 10).
    decode : bool, optional
        If True, outputs decoded text classes; otherwise, returns token IDs (default is True).
    enforce_no_repetition : bool, optional
        If True, inhibits repeated prediction of the same FTERM class (default is True).
    ignore_eos_token : bool, optional
        If True, the tokenizer's EOS token is never selected and generation only stops
        after max_pred_tokens predictions or when no admissible candidate is left
        (default is True).
    fterm_token_offset : int, optional
        Value added to the index of a logit to obtain the token ID that is fed back
        into the model and returned (default is 50000, the previously hard-coded value).
        Presumably the model only scores F-term tokens, whose IDs start at this
        offset in the tokenizer vocabulary -- confirm against the model card.

    Returns
    -------
    str or list of int
        The predicted FTERM classes for the given prompt, either decoded by the
        tokenizer or as token IDs. May hold fewer than max_pred_tokens entries if
        EOS is predicted (ignore_eos_token=False) or all candidates are exhausted.
    """
    # Add the start FTERM token to the prompt
    prompt += "<START F-TERMS>"

    # Convert the prompt to tokens. The last tokenized position is dropped, as in
    # the original implementation (presumably a trailing special token -- TODO confirm).
    tokenized = tokenizer(prompt, return_tensors="pt")
    prompt_tokens = tokenized["input_ids"][:, :-1]
    attention_mask = tokenized["attention_mask"][:, :-1]

    # A sentinel that no prediction can equal disables the EOS stop condition
    eos_token_id = -999 if ignore_eos_token else tokenizer.eos_token_id

    # Token IDs that may not be selected: EOS when it is ignored, plus (when
    # repetition is forbidden) every token predicted so far.
    blocked_tokens = set()
    if ignore_eos_token:
        blocked_tokens.add(tokenizer.eos_token_id)

    predictions = []
    current_token = -100

    # Generate the FTERMS
    while current_token != eos_token_id and len(predictions) < max_pred_tokens:
        # Inference only: skip autograd bookkeeping for the forward pass
        with torch.no_grad():
            output = model(
                input_ids=prompt_tokens,
                attention_mask=attention_mask,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
        logits = output["logits"]

        # Rank the candidates of the last position by likelihood and take the
        # best one that is still allowed
        ranked_indices = torch.sort(logits[0, -1], dim=-1, descending=True).indices
        current_token = _first_allowed_token(ranked_indices, fterm_token_offset, blocked_tokens)
        if current_token is None:
            # Every candidate is blocked: stop instead of indexing past the end
            break

        predictions.append(current_token)
        if enforce_no_repetition:
            blocked_tokens.add(current_token)

        # Update prompt tokens and attention mask
        next_token = torch.tensor(
            [[current_token]], dtype=prompt_tokens.dtype, device=prompt_tokens.device
        )
        prompt_tokens = torch.cat([prompt_tokens, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, attention_mask[:, -1:]], dim=-1)

    # Decode predictions if required
    if decode:
        return tokenizer.decode(predictions)
    return predictions

In [None]:
# Example patent abstract used as the prompt for the prediction below.
abstract = "PROBLEM TO BE SOLVED: To enable biological rhythm of small animals to be regulated. <P>SOLUTION: The cushion 1 for small animal includes a base part 2 on which the small animal can lay the body, a swelled part 3 formed on the base part 2, a light-radiating part 5 for irradiating light to the small animal lying on the base part 2, and a control part 8 for switching the light irradiated from the light-radiating part 5 according to the time regulated based on a previously set light pattern. For example, the light-irradiating part 5 includes a light source part 6 that includes a light source in the inside, and a light-inlet part 7 for introducing the light from the light source, while emitting light."
# Last expression of the cell, so the prediction is shown as the cell output.
generate(prompt=abstract, model=model, tokenizer=tokenizer)

In [None]:
# Translating F-term predictions:
# the decoded prediction is split on commas and the final piece is discarded.
output = generate(abstract, model, tokenizer).split(",")[:-1]
for fterm in output:
    # The first five characters key the theme, the first eight the viewpoint,
    # and the full string the F-term itself.
    theme = themes_descriptions[fterm[:5]]
    viewpoint = viewpoints_descriptions[fterm[:8]]
    description = f_term_dict[fterm]
    print(f' Theme: {theme} | Viewpoint: {viewpoint} | F-term: {description}')