In [1]:
import sys
sys.path.append('../scripts')

import pandas as pd
import numpy as np
import torch

from langchain_core.documents import Document
from langchain_community.embeddings.openai import OpenAIEmbeddings

from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate

import nltk
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords

import re
from datetime import datetime
from tqdm import tqdm

import os
import getpass

from sklearn.metrics import f1_score
nltk.download('punkt')
stop_words = stopwords.words("french")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

[nltk_data] Downloading package punkt to /Users/SamuelLP/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


# Import datas

In [8]:
data = pd.read_csv("../datas/passages_dates.csv")
train = pd.read_csv("../datas/train_data.csv")

In [9]:
data['texte'] = data['texte'].apply(lambda x: ' '.join([word for word in x.split() if word not in (stop_words)]))

# Insert your API key

In [11]:
os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI key:")

# Function for preprocess your datas

# After that, we define our retriveal function

In [12]:
def retriver(df, query, openai_model, embedding_model_name):
    res_date = []

    for _, row in tqdm(df.iterrows(), total=df.shape[0]):

        texte = row["texte"]
        filename = row["filename"]

        # content = preprocess_documents(pd.Series({"texte": texte, "filename": filename}))
        content = PreprocessDocuments(pd.Series({"texte": texte, "filename": filename})).preprocess_documents()
        if openai_model != None:
            model = ChatOpenAI(model=openai_model, temperature=0.2)

        else :
            model = ChatOpenAI(temperature=0.2)

        if embedding_model_name == None:
            embedding_model = OpenAIEmbeddings()
            faiss_index = FAISS.from_documents(content, embedding_model)

        else:
            embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name)
            faiss_index = FAISS.from_documents(content, embedding_model)
        
        retriever = faiss_index.as_retriever(search_kwargs={'k': 7}, search_type="mmr")

        if query == "what is the date of accident?":
            template = """You are a legal assistant. You will look for information in the following document::
            {context}.
            You will search for one of the following information :
            - the date of the accident : The date of the accident refers to the date on which the accident occurred
            You will only return the date with the following format : YYYY-MM-DD. I don't want text around it
            The user is going to ask you the following thing: {question}
            """

            prompt = ChatPromptTemplate.from_template(template)

            chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | model
            | StrOutputParser()
            )

            val = chain.invoke(query)
            res_date.append(val)

        elif query == "what is the date of consolidation?":
            template = """You are a legal assistant. You will look for information in the following document::
            {context}.

            You will search for one of the following information :
            - the date of consolidation : The consolidation date is the date on which the victim's injuries
            became stable and were declared definitive by a doctor.
            The information should be present in most cases, but sometimes it is either missing (you will display the value "n.c.") or not applicable (you will display the value "n.a.") if the injury did not stabilize before the victim's death.

            You will return the date with the following format : YYYY-MM-DD. I don't want text around it
            If the date of consolidation is missing in the document, you will return the value n.c.
            If the injury did not stabilize before the victim's death, you will return the value n.a.

            The user is going to ask you the following thing: {question}
            """

            prompt = ChatPromptTemplate.from_template(template)

            chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | model
            | StrOutputParser()
            )

            val = chain.invoke(query)
            res_date.append(val)
    return res_date

# We normalize our outputs

In [24]:
french_months = {
    'janvier': 1, 'février': 2, 'mars': 3, 'avril': 4, 'mai': 5, 'juin': 6,
    'juillet': 7, 'août': 8, 'septembre': 9, 'octobre': 10, 'novembre': 11, 'décembre': 12
}

def extract_and_reformat_date(text):
    if 'n.c.' in text:
        return "n.c."
    elif "n.a." in text: 
        return 'n.a.'
    
    date_pattern_ymd = r"\d{4}-\d{2}-\d{2}"
    date_pattern_dmy = r"(\d{1,2}),?\s([a-zA-Zéû]+)\s(\d{4})"
    date_pattern_slash = r"(\d{1,2})/(\d{1,2})/(\d{4})"
    date_pattern_mdy = r"([a-zA-Z]+) (\d{1,2}), (\d{4})"
    
    match_ymd = re.search(date_pattern_ymd, text)
    match_dmy = re.search(date_pattern_dmy, text)
    match_slash = re.search(date_pattern_slash, text)
    match_mdy = re.search(date_pattern_mdy, text)
    
    if match_ymd:
        return match_ymd.group(0)
    elif match_dmy:
        day, month_name, year = match_dmy.groups()
        try:
            month_number = datetime.strptime(month_name, "%B").month
        except ValueError:
            month_number = french_months.get(month_name.lower())
            if month_number is None:
                return None
        return f"{year}-{month_number:02d}-{day.zfill(2)}"
    elif match_slash:
        part1, part2, year = match_slash.groups()
        day, month = part1, part2  # Assumant le format JJ/MM/AAAA
        return f"{year}-{int(month):02d}-{int(day):02d}"
    elif match_mdy:
        month_name, day, year = match_mdy.groups()
        try:
            month_number = datetime.strptime(month_name, "%B").month
        except ValueError:
            return None  # En cas de mois inconnu en anglais
        return f"{year}-{month_number:02d}-{day.zfill(2)}"
    else:
        return None

# Now, we apply our RAG

In [13]:
OPENAI_MODEL = "gpt-3.5-turbo"
EMBEDDING_MODEL_NAME = "xlm-r-distilroberta-base-paraphrase-v1"

In [14]:
train = train[:2]
train

Unnamed: 0,ID,filename,texte,sexe,date_accident,date_consolidation
0,0,Agen_100515.txt,Le : 12/11/2019 Cour d’appel d’Agen chambre so...,homme,1991-04-09,n.c.
1,1,Agen_1100752.txt,Le : 12/11/2019 Cour d’appel d’Agen chambre ci...,homme,2005-06-10,2010-01-19


In [15]:
query = "what is the date of accident?"
train = train[:2]
res_date_accident = retriver(train, query, OPENAI_MODEL, EMBEDDING_MODEL_NAME)
train["date_accident_pred"] = res_date_accident

  0%|          | 0/2 [00:00<?, ?it/s]

  from .autonotebook import tqdm as notebook_tqdm
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 2/2 [00:16<00:00,  8.27s/it]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train["date_accident_pred"] = res_date_accident


In [16]:
from date_normalization import DateNormalization

train['date_accident_pred'] = train['date_accident_pred'].apply(lambda x: DateNormalization(x).extract_and_reformat_date())
train

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train['date_accident_pred'] = train['date_accident_pred'].apply(lambda x: DateNormalization(x).extract_and_reformat_date())


Unnamed: 0,ID,filename,texte,sexe,date_accident,date_consolidation,date_accident_pred
0,0,Agen_100515.txt,Le : 12/11/2019 Cour d’appel d’Agen chambre so...,homme,1991-04-09,n.c.,1997-04-24
1,1,Agen_1100752.txt,Le : 12/11/2019 Cour d’appel d’Agen chambre ci...,homme,2005-06-10,2010-01-19,n.c.


In [None]:
query = "what is the date of consolidation?"

res_date_consolidation = retriver(train, query, OPENAI_MODEL, EMBEDDING_MODEL_NAME)
train["date_consolidation_pred"] = res_date_consolidation

In [None]:
train['date_accident_pred'] = train['date_accident_pred'].apply(extract_and_reformat_date)
train['date_consolidation_pred'] = train['date_consolidation_pred'].apply(extract_and_reformat_date)
train = train.fillna(value=np.nan)
train = train.fillna(value='n.c.')

# Let's compute our metrics

In [36]:
def compute_accuracy(df, y_true, y_pred):

    good_preds = (df[y_true] == df[y_pred]).sum()
    lenght = df.shape[0]

    accuracy = good_preds / lenght
    return accuracy

accuracy_accident = compute_accuracy(train, 'date_accident', 'date_accident_pred')
accuracy_consolidation = compute_accuracy(train, 'date_consolidation', 'date_consolidation_pred')

print(f"Accuracy accident: {accuracy_accident}")
print(f"Accuracy consolidation: {accuracy_consolidation}")

Accuracy accident: 0.712987012987013
Accuracy consolidation: 0.5467532467532468


In [40]:
def compute_f1(df, y_true, y_pred):

    y_t = (df[y_true] == df[y_true]).astype(int)
    y_p = (df[y_pred] == df[y_true]).astype(int)

    f1 = f1_score(y_t, y_p)

    return f1

f1_accident = compute_f1(train, 'date_accident', 'date_accident_pred')
f1_consolidation = compute_f1(train, 'date_consolidation', 'date_consolidation_pred')

print(f"F1 accident: {f1_accident}")
print(f"F1 consolidation: {f1_consolidation}")

F1 accident: 0.8324488248673237
F1 consolidation: 0.7069689336691856


In [54]:
train.to_csv("../datas/preds_rag_openai.csv")