In [8]:
import json
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm
from typing import List, Dict, Union, Optional

from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from langchain_community.llms import Ollama

In [2]:
# Development split; the cells below read its 'claim', 'questions' and 'actual' columns.
dev_data_df = pd.read_csv(filepath_or_buffer="dev_data_df.csv")

In [9]:
class JSONLLoader(BaseLoader):
    """Load evidence records from a JSON Lines file as LangChain ``Document`` objects.

    Each non-blank line of the file must be a JSON object with the keys
    ``claim_id``, ``type``, ``query``, ``url`` and ``url2text``.

    Args:
        file_path: Path to the JSON Lines file; resolved to an absolute path.
        content_key: Stored on the instance but not consulted by ``load``.
    """

    def __init__(
        self,
        file_path: Union[str, Path],
        content_key: Optional[str] = None,
    ):
        self.file_path = Path(file_path).resolve()
        self._content_key = content_key

    def load(self) -> List[Document]:
        """Parse the file and return one ``Document`` per JSON record.

        Returns:
            Documents whose ``page_content`` is the space-joined ``url2text``
            and whose metadata holds ``claim_id``, ``type``, ``query`` and
            ``source`` (the record's ``url``).

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record lacks one of the required keys.
        """
        docs = []
        with open(self.file_path, 'r', encoding="utf8") as file:
            for raw_line in file:
                line = raw_line.strip()
                # Blank lines (e.g. a trailing newline at end of file) are not
                # records; json.loads('') would raise on them.
                if not line:
                    continue

                data = json.loads(line)
                url2text = data['url2text']

                # Joining a bare string would interleave spaces between its
                # characters, so only join when given a sequence of fragments.
                if isinstance(url2text, str):
                    text = url2text
                else:
                    text = ' '.join(url2text)

                metadata = dict(
                    claim_id=data['claim_id'],
                    type=data['type'],
                    query=data['query'],
                    source=data['url'],
                )
                docs.append(Document(page_content=text, metadata=metadata))

        return docs

In [6]:
# Ollama-served "phi3:14b-instruct" model with temperature 0.0; this one
# instance is used by both the extraction chain and the classification chain.
llm = Ollama(model="phi3:14b-instruct", temperature=0.0)

# Extraction prompt: given a question and one evidence text, ask the model to
# return only the passage of the evidence that answers the question.
# Template variables: {question}, {evidence}.
prompt_template = ChatPromptTemplate.from_messages([
    ("system", "You are strictly prohibited from generating any text of your own."),
    ("system", "Limit your answer to 50 words."),
    ("system", "Your task is to extract a part of the given text that directly answers the given question. The extracted information should contain a conclusive answer to the question and should be either positive or negative in relation to the question. It should also be concise without irrelevant words"),
    ("human", "Question: {question}"),
    ("human", "Pay more attention to the later parts of the evidences, as the intitial sentences are only the introduction"),
    ("human", "Evidence: {evidence}"),
    # Closing system message repeats the output constraint after the evidence text.
    ("system", "You do not need to explain your answer. Only return the extracted sentence and follow the system instructions strictly"),
])


def get_relevant_sentence(question, evidence):
    """Extract the passage of one evidence document that answers ``question``.

    Args:
        question: Question text substituted into the extraction prompt.
        evidence: Object exposing ``page_content``; that text is the evidence
            the model extracts from.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    extraction_chain = prompt_template | llm | StrOutputParser()
    prompt_inputs = {"question": question, "evidence": evidence.page_content}
    raw_reply = extraction_chain.invoke(prompt_inputs)
    return raw_reply.strip()

In [7]:
# Classification prompt: given a claim and the extracted statements, ask the
# model for exactly one of four verdict labels.
# Template variables: {claim}, {relevant_sentences}.
prompt_template_2 = ChatPromptTemplate.from_messages([
    ("system", "You are tasked with classifying a given claim based on provided statements into one of the following four categories:"),
    ("system", "1. Supported: If there is sufficient evidence indicating that the claim is legitimate, classify it as Supported."),
    ("system", "2. Refuted: If there is any evidence contradicting the claim, classify it as Refuted."),
    ("system", "3. Not Enough Evidence: If you cannot find any conclusive factual evidence either supporting or refuting the claim, classify it as Not Enough Evidence. This means the available information is insufficient to make a definitive judgment."),
    ("system", "4. Conflicting Evidence/Cherrypicking: If there is factual evidence both supporting and refuting the claim, classify it as Conflicting Evidence/Cherrypicking. This indicates that the evidence is mixed, and there are evidences both supporting and refuting the claim."),

    # Two worked examples sharing one claim, differing only in the statements.
    ("system", "Examples:"),
    ("system", "Example 1:"),
    ("system", "Claim: The new drug is effective in treating diabetes."),
    ("system", "Statements: ['Clinical trials have shown no significant reduction in blood sugar levels among patients.', 'Several patients reported no change in their blood sugar levels after using the drug', 'The drug has not been widely tested in clinical trials.'] "),
    ("system", "Final Classification: Refuted"),

    ("system", "Example 2:"),
    ("system", "Claim: The new drug is effective in treating diabetes."),
    ("system", "Statements: ['Clinical trials have shown a significant reduction in blood sugar levels among patients.', 'Many patients reported positive results in managing their blood sugar levels.', 'Experts in the field have praised the effectiveness of the drug.']"),
    ("system", "Final Classification: Supported"),

    ("human", "Claim: {claim}"),
    ("human", "Statements: {relevant_sentences}"),

    # The allowed-label list must spell every label exactly as the category
    # definitions above do: the model's reply is compared verbatim against the
    # gold labels later, so a differently spelled option can never match.
    # NOTE(review): confirm these strings match the values in dev_data_df['actual'].
    ("system", "Pick only one from ['Supported', 'Refuted', 'Not Enough Evidence', 'Conflicting Evidence/Cherrypicking']"),
    ("system", "Do not print anything other than the final classification."),
])

In [8]:
# Classification pipeline: fill prompt_template_2, run it through the shared
# llm, and parse the model output to a plain string.
chain = prompt_template_2|llm|StrOutputParser()

In [9]:
def generate_claim_dicts(claim: str, docs: List[Document], answers: List[str], pred_label: str) -> Dict:
    """Build the output record for one claim.

    Args:
        claim: The claim text.
        docs: Evidence documents for the claim; each must carry ``query``,
            ``source`` and ``claim_id`` in its metadata.
        answers: Extracted answers, aligned positionally with ``docs``.
            Entries beyond ``len(docs)`` are ignored.
        pred_label: Predicted verdict label for the claim.

    Returns:
        A dict with ``claim_id`` (as int, taken from the last document),
        ``claim``, ``evidence`` (one question/answer/url dict per document)
        and ``pred_label``.

    Raises:
        ValueError: If ``docs`` is empty (no ``claim_id`` can be determined)
            or if there are fewer answers than documents.
    """
    if not docs:
        raise ValueError("generate_claim_dicts requires at least one evidence document")
    if len(answers) < len(docs):
        raise ValueError(
            f"expected one answer per document, got {len(answers)} answers for {len(docs)} documents"
        )

    evidences = [
        {'question': doc.metadata['query'], 'answer': answer, 'url': doc.metadata['source']}
        for doc, answer in zip(docs, answers)
    ]

    claim_dict = {
        # The id is read from the last document, as all documents passed in
        # are expected to belong to the same claim.
        'claim_id': int(docs[-1].metadata['claim_id']),
        'claim': claim,
        'evidence': evidences,
        'pred_label': pred_label
    }

    return claim_dict

In [10]:
# Accumulators appended to by the prediction loop: one output record and one
# predicted label per processed claim.
claim_dicts, pred_labels = [], []

In [None]:
# Two-stage prediction per dev row: extract one answer from each evidence
# document, then classify the claim from the concatenated answers.
for row_idx in tqdm(range(len(dev_data_df))):
    claim, question = dev_data_df['claim'][row_idx], dev_data_df['questions'][row_idx]

    # Evidence for each row is read from a file named after the row index.
    data = JSONLLoader(file_path=f'top_three/{row_idx}.json').load()

    relevant_sentences = []
    for evidence in data:
        relevant_sentences.append(get_relevant_sentence(question, evidence))

    classifier_inputs = {
        "claim": claim,
        "relevant_sentences": ' '.join(relevant_sentences),
    }
    pred_label = chain.invoke(classifier_inputs).strip()

    claim_dicts.append(generate_claim_dicts(claim, data, relevant_sentences, pred_label))
    pred_labels.append(pred_label)

In [86]:
# Save the per-claim prediction records as JSON indented by four spaces.
with open('dev_preds.json', 'w') as out_file:
    out_file.write(json.dumps(claim_dicts, indent=4))

In [2]:
import plotly.graph_objects as go
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import LabelEncoder

# Attach the predictions to the dev frame next to the gold labels.
dev_data_df['predicted'] = pred_labels

# Integer-encode both columns using the classes found in the gold labels.
# NOTE: transform() raises ValueError for any predicted string that is not one
# of those classes, so the model output must match a gold label exactly.
label_encoder = LabelEncoder()
true_labels_int = label_encoder.fit_transform(dev_data_df['actual'])
pred_labels_int = label_encoder.transform(dev_data_df['predicted'])

# scikit-learn convention: cm[t][p] counts samples of true class t that were
# predicted as class p, i.e. rows are true labels and columns are predictions.
cm = confusion_matrix(true_labels_int, pred_labels_int)
labels = label_encoder.classes_

# Predicted classes are plotted in reverse class order along the x axis.
reversed_labels = labels[::-1]
n_labels = len(labels)

# Heatmap rows map to y and columns map to x, so y carries the true labels and
# x carries the (column-flipped) predicted labels.
fig = go.Figure(data=go.Heatmap(
                   z=np.flip(cm, axis=1),
                   x=reversed_labels,
                   y=labels,
                   hoverongaps=False,
                   colorscale='Viridis',
                   showscale=True))

fig.update_layout(
    title='Confusion Matrix',
    # Axis titles follow the matrix orientation: columns (x) are predictions,
    # rows (y) are true labels.
    xaxis_title='Predicted Label',
    yaxis_title='True Label',
    xaxis=dict(
        tickmode='array',
        tickvals=np.arange(len(reversed_labels)),
        ticktext=reversed_labels
    ),
    yaxis=dict(
        tickmode='array',
        tickvals=np.arange(len(labels)),
        ticktext=labels
    ),
    margin=dict(l=100, r=100, t=100, b=100),
    width=800,
    height=700
)

# Write each cell's count on top of the heatmap. Counts below half of the
# maximum get white text, the rest get black text.
half_max = np.max(cm) / 2
annotations = []
for i in range(len(reversed_labels)):
    for j in range(len(labels)):
        # Column i of the flipped matrix is column (n_labels - 1 - i) of cm.
        count = cm[j][n_labels - 1 - i]
        annotations.append(
            go.layout.Annotation(
                text=str(count),
                x=reversed_labels[i],
                y=labels[j],
                xref='x1',
                yref='y1',
                showarrow=False,
                font=dict(
                    color="white" if count < half_max else "black"
                )
            )
        )

fig.update_layout(annotations=annotations)
fig.show()