# Sample Inference using RAG and DeepL (DPL)
Here we use the precomputed RAG context and the DeepL translations to build the prompt that is passed to the Mistral Nemo 12B model.




In [None]:

from transformers import AutoTokenizer, AutoModelForCausalLM
from pydantic import BaseModel, ValidationError, validator, Field


import json
import torch
import numpy as np
import pickle
import pandas as pd


In [None]:
# Model identifier and data locations used by the rest of the notebook.
model_id = "mistralai/Mistral-Nemo-Instruct-2407"

# Pickled RAG contexts, precomputed in advance.
context_file = '/home/automate_eurocrops/data/BW_context.pkl'

# Raw inputs: directory, HCAT4 name workbook, and the BW crop-description table.
data_root, HCAT_fname, country_fname = (
    '/home/automate_eurocrops/data/raw/',
    'HCAT4.xlsx',
    'BW.csv',
)

# DeepL translations of the crop descriptions: directory and workbook name.
d_path, file_name = '/home/automate_eurocrops/data/processed/', 'BW_trans.xlsx'

In [None]:
# Load the Mistral Nemo tokenizer and model. The model weights are requested in
# bfloat16 and placed on the available device(s) via device_map="auto".
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)


In [None]:
# Load the reference HCAT name list and the crop descriptions to be translated
# and mapped, plus their precomputed RAG contexts.

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load context files that were produced locally and are trusted.
with open(context_file, 'rb') as f:
    contexts = pickle.load(f)

hcat_file = pd.read_excel(f"{data_root}{HCAT_fname}")
country_file = pd.read_csv(f"{data_root}{country_fname}")

# Allowed HCAT names, as a Series and as a plain list (the list is used by the
# HCAT validator).
hcat_name = hcat_file['HCAT4_name']
HCAT_list = list(hcat_name)

# Original (untranslated) crop descriptions from the country file.
original_name = country_file['original_name']


In [None]:
# DeepL translations of the crop descriptions, read from the `dpl_trans` column.
dpl = pd.read_excel(f"{d_path}{file_name}")
dpl_trans = dpl['dpl_trans']

In [None]:
# Guard against hallucinated labels: the pydantic model below only accepts an
# HCAT_name that appears in the module-level list of HCAT names (HCAT_list).

class HCAT(BaseModel):
    """Structured LLM output for one crop description.

    Fields:
      - original_name: the crop description as given to the model
      - HCAT_name: the mapped HCAT name; must be a member of ``HCAT_list``
      - translated_name: the English translation produced by the model

    A name outside ``HCAT_list`` makes the validator raise ``ValueError``,
    which pydantic reports as a ``ValidationError``.
    """

    # No id field: the original crop code is attached outside the LLM prompt.
    original_name: str
    HCAT_name: str
    translated_name: str

    @validator('HCAT_name')
    def validate_HCAT_mapping(cls, v):
        """Return ``v`` unchanged if it is an allowed HCAT name, else raise ``ValueError``."""
        if v not in HCAT_list:
            # Single formatted message. Passing several positional args to
            # ValueError would produce a tuple-shaped, hard-to-read message.
            raise ValueError(f"{v!r} is not a valid HCAT mapping.")
        return v
    

In [None]:
# defining the inference function
# To change the prompt, adjust whether the RAG context / DeepL translation is included.
def _extract_json_object(text):
    """Return the span from the first '{' to the last '}' in *text*.

    LLM replies may include surrounding prose or Markdown code fences. Trimming
    to the outermost braces passes only the JSON object to the parser. If no
    such span exists, *text* is returned unchanged so the parser reports the
    failure.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        return text
    return text[start:end + 1]


def perform_inference(model, data, hcatn, context, dpl_trans, max_retries=5, max_new_tokens=100):
    """Translate one crop description and map it to an HCAT name with the LLM.

    Code adapted from https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407
    (function calling with transformers).

    The prompt combines the RAG context, the original description and its
    DeepL translation. The reply is parsed and validated with the ``HCAT``
    pydantic model. Generation is sampled (``do_sample=True``), so a reply that
    fails parsing or validation is regenerated from the same prompt, up to
    ``max_retries`` attempts.

    Inputs:
      - model: the LLM (loaded via AutoModelForCausalLM.from_pretrained)
      - data: original crop description from the country GSA
      - hcatn: list of allowed HCAT names
      - context: RAG context string (cosine similarity over AGROVOC and
        Agriprod, precomputed in advance)
      - dpl_trans: DeepL translation of the original crop description
      - max_retries: number of generation attempts before giving up (default 5)
      - max_new_tokens: generation budget per attempt (default 100)

    Output:
      - a validated ``HCAT`` object {original_name, translated_name, HCAT_name}.
        If every attempt fails, a fallback with HCAT_name="not_known_and_other"
        and an empty translated_name is returned.

    Note: tokenization/decoding uses the module-level ``tokenizer``.
    """
    # The JSON template and example are wrapped in literal braces ({{ }} in the
    # f-string) so the model is asked for an actual JSON object that
    # HCAT.parse_raw can read.
    prompt = f"""
            [INST]
            You are an assistant that answers ONLY with the required json output to the user's question. No explanations.
            Based on the following agricultural information: {context}

            Given a description below:

            {data}


            First translate every part of the description to English, while taking into consideration that deepl's translation is {dpl_trans}
            This is "translated_name"


            Next use the "translated_name" to match based on the closest semantic meaning to an entry from this list of HCAT_names:

            {hcatn}
            The "HCAT_name" should be as specific as possible, prioritize species level semantic matches to the HCAT_names from the "translated_name"
            Consider the entire "translated_name" and look for the HCAT_name that is the closest match to the entire description.
            Also map to the more informative HCAT_name.
            E.g for root chicory, chicory is more informative than root so map to chicory_chicories
            Summer crops should be the spring equivalent in "HCAT_name", but keep the translated_name as the original English translation
            If the exact "translated_name" is not in the HCAT_names, find the upper class of crops that includes the crop in HCAT_names and map to its other class.
            For example for prickly pear, the upper crop class is fruit, "HCAT_name" is other_orchards_fruits.
            If the "translated_name" is not a crop or agriculture product, for example rocks or landscape features, "HCAT_name" is not_known_and_other
            If the "translated_name" is a mix of crops, find the upper class of crops that includes all the crops in HCAT_names and map to the upper class
            For example if "translated_name" is mixed alfalfa and clover, the common upper crop class is legumes, "HCAT_name" is legumes.
            Or if "translated_name" is mixed cultures, the common upper crop class is arable crops, "HCAT_name" is arable_crops
            Trees and flowers are also agriculture products in HCAT.


            Return only in json format defined below:
            '''
            {{"original_name": "{data}", "translated_name": "english translation of {data}", "HCAT_name": "closest semantic match of translated name in HCAT_names"}}
            '''
            Only use a "HCAT_name" from the provided HCAT_names.
            No other detail.

            **Example Json Output:**
            {{"original_name": "Erdbeeren", "translated_name": "strawberries", "HCAT_name": "strawberries"}}

            [/INST]
            """

    # The prompt is identical on every attempt, so tokenize it once.
    tokens = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = tokens["input_ids"].shape[1]

    for attempt in range(1, max_retries + 1):
        with torch.no_grad():
            out_tokens = model.generate(
                **tokens,
                max_new_tokens=max_new_tokens,
                temperature=0.35,
                top_p=0.9,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=True,
            )
        # Decode only the newly generated tokens (drop the echoed prompt).
        result = tokenizer.decode(out_tokens[0][prompt_len:], skip_special_tokens=True).strip()

        try:
            validated = HCAT.parse_raw(_extract_json_object(result))
        except (json.JSONDecodeError, ValidationError) as e:
            print(f"Parse failed (attempt {attempt}/{max_retries}). Re-prompting...\n")
            print(f"JSON Decode Error: {e}")
            continue

        print("Validated Crop:", validated)
        return validated

    # Every attempt failed: fall back to the catch-all HCAT class.
    return HCAT(original_name=str(data), HCAT_name="not_known_and_other", translated_name="")


In [None]:
# Run the mapping for the first crop description (position 0) and keep the
# resulting HCAT object in `output`.
output = perform_inference(
    model,
    data=original_name[0],
    hcatn=hcat_name.to_list(),
    context=contexts[0],
    dpl_trans=dpl_trans[0],
)

In [None]:
# Show the validated generation for the sample crop description.
print(str(output))