<a href="https://colab.research.google.com/github/sleepyzzpanda/Environment-RAG-Chatbot/blob/main/Climate_RAG_Chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GPT-2 RAG Chatbot for Climate Information
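This notebook sets up a retrieval-augmented generation (RAG) chatbot for climate information, with an interactive cell-based interface. Climate passages (IPCC statements from the ClimateX dataset plus climate news articles) are embedded with OpenAI's `text-embedding-3-small` model and indexed with FAISS; answers are generated by an OpenAI chat model (`gpt-4.1`). An earlier version generated answers with GPT-2; that code is kept commented out below.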

In [8]:
!pip install torch transformers datasets faiss-cpu sentence-transformers ipywidgets
!pip install openai




In [22]:
import torch
import requests
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from ipywidgets import interact_manual, widgets
from datasets import load_dataset
import pandas as pd
import openai
import os

# Load the OpenAI API key. On Colab it is read from the notebook's secrets.
# Elsewhere it falls back to the OPENAI_API_KEY environment variable, so the
# notebook can also run outside Colab.
try:
    from google.colab import userdata
except ImportError:
    OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "").strip()
    if not OPENAI_API_KEY:
        raise RuntimeError(
            "OPENAI_API_KEY is not set: add it as a Colab secret or export it as an environment variable."
        )
else:
    OPENAI_API_KEY = userdata.get('OPENAI_API_KEY').strip()
# !unzip archive.zip -d climate_news_data



In [10]:
# # Load GPT-2 model
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# model = GPT2LMHeadModel.from_pretrained('gpt2')

In [11]:
# # Load the ClimateBERT environmental claims dataset
# claims_dataset = load_dataset("climatebert/environmental_claims")
# print(claims_dataset['train'].column_names)
# # Load the NER dataset
# ner_dataset = load_dataset("ibm-research/Climate-Change-NER")
# print(ner_dataset['train'].column_names)


# news_df = pd.read_csv("climate_news_data/climate-change-news.csv")
# print(news_df.columns)
# print(news_df.head())

climate_x = load_dataset("rlacombe/ClimateX")
print(climate_x['train'].column_names)

# IPCC statements from ClimateX, whitespace-stripped, with empty/missing entries dropped.
raw_statements = [example["statement"] for example in climate_x["train"]]
clean_passages = [s for s in (p.strip() for p in raw_statements if p) if s]

# Texts aligned one-to-one with the FAISS index rows. The embedding cell fills
# this in the same order it adds vectors (ClimateX passages first, then news).
# It must start empty: pre-filling it with the raw statements shifts every later
# entry, so a search hit on row i would return the wrong passage.
passages = []




The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

ipcc_statements_dataset.tsv: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/8094 [00:00<?, ? examples/s]

['statement_idx', 'report', 'page_num', 'sent_num', 'statement', 'confidence', 'score', 'split']


In [12]:
# from google.colab import drive
# drive.mount('/content/drive')

In [13]:

# Climate news headlines CSV stored on Google Drive.
file_path = "/content/drive/MyDrive/IAT360FinalProject/climate_headlines_sentiment.csv"

# The path only exists once Drive is mounted. Mount on demand when running on
# Colab so a fresh "Run all" does not fail here. Elsewhere, read_csv reports the
# missing file.
if not os.path.exists(file_path):
    try:
        from google.colab import drive
    except ImportError:
        pass  # not on Colab: nothing to mount
    else:
        drive.mount('/content/drive')

news_df = pd.read_csv(file_path)
print(news_df.columns)
print(news_df.head())

# Fill NaNs with empty strings so the row-wise join below never sees a float NaN.
text_columns = ['Headline', 'Content', 'Justification']
news_df[text_columns] = news_df[text_columns].fillna('')

# One passage per row: the text columns joined with spaces, stripped, and
# whitespace-only rows dropped.
combined = news_df[text_columns].agg(' '.join, axis=1)
news_passages = [p for p in (text.strip() for text in combined) if p]

print(f"{len(news_passages)} combined passages ready for embedding")
print(news_passages[:3])


Index(['Unnamed: 0', 'Headline', 'Link', 'Content', 'Sentiment',
       'Justification'],
      dtype='object')
   Unnamed: 0                                           Headline  \
0           0  Australia's year ahead in climate and environm...   
1           1  Projections reveal the vulnerability of freshw...   
2           2  Record heat in 2023 worsened global droughts, ...   
3           3  It's not just the total rainfall "“ why is eas...   
4           4  Expert Commentary: 2023 was the warmest year o...   

                                                Link  \
0  https://www.abc.net.au/news/science/2024-01-23...   
1  https://news.griffith.edu.au/2024/01/09/projec...   
2  https://www.anu.edu.au/news/all-news/record-he...   
3  https://www.theguardian.com/australia-news/202...   
4  https://www.csiro.au/en/news/all/news/2024/jan...   

                                             Content  Sentiment  \
0   The year has barely started and extreme weath...        0.0   
1   “Wat

In [24]:
from openai import OpenAI
import numpy as np
import faiss

# OpenAI API client; embed_texts() below calls client.embeddings.create with it.
client = OpenAI(api_key=OPENAI_API_KEY)

# Helper function to embed a list of texts using text-embedding-3-small
def embed_texts(texts, model="text-embedding-3-small", batch_size=128):
    """
    Embed a list of texts with the OpenAI Embeddings API, in batches.

    Args:
        texts: List of strings to embed (the API rejects empty strings).
        model: OpenAI embedding model name.
        batch_size: Number of texts sent per API request; must be >= 1.

    Returns:
        A float32 NumPy array with one row per input text, in input order
        (shape ``(len(texts), dim)``; an empty 1-D array when ``texts`` is empty).

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # A negative step would make range() empty and silently embed nothing.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        response = client.embeddings.create(model=model, input=batch)
        # Each result carries the position of its input in the batch; order by
        # it explicitly rather than relying on the response order.
        ordered = sorted(response.data, key=lambda item: item.index)
        all_embeddings.extend(item.embedding for item in ordered)

    # FAISS works on float32 vectors; convert once here instead of passing float64.
    return np.asarray(all_embeddings, dtype=np.float32)


# ---- Embed both corpora and build the FAISS index ----
embeddings = embed_texts(clean_passages, batch_size=128)
news_embeddings = embed_texts(news_passages, batch_size=128)

index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)        # rows 0 .. len(clean_passages)-1
index.add(news_embeddings)   # followed by the news rows

# Rebuild (not extend) the lookup list so passages[i] is exactly the text of
# index row i, whatever `passages` held before. Extending a pre-filled list
# shifts the news entries and makes retrieval return the wrong text. It would
# also duplicate entries every time this cell is re-run.
passages = clean_passages + news_passages
assert index.ntotal == len(passages), "FAISS index and passage list are out of sync"



In [28]:
# Retrieval function
def retrieve_passages(query, k=2):
    """
    Return up to ``k`` stored passages nearest to ``query``.

    The query is embedded with ``embed_texts`` and searched against the global
    FAISS ``index``. Result row ids are mapped back through the global
    ``passages`` list, which must be aligned with the index rows.

    Args:
        query: Free-text question.
        k: Maximum number of passages to return.

    Returns:
        A list of at most ``k`` passage strings, in FAISS result order (nearest
        first). Fewer are returned when the index holds fewer than ``k``
        vectors; ``[]`` when ``k < 1``.
    """
    if k < 1:
        return []

    # FAISS expects a float32 matrix of shape (n_queries, dim).
    query_emb = np.asarray(embed_texts([query])[0], dtype=np.float32).reshape(1, -1)

    _, indices = index.search(query_emb, k)
    # FAISS pads missing results with -1. Without this filter, passages[-1]
    # would silently return the last stored passage.
    return [passages[i] for i in indices[0] if i >= 0]


In [16]:
# # RAG generation function
# def generate_answer(query, k=2, max_new_tokens=75):
#     context_passages = retrieve_passages(query, k)
#     context = ' '.join(context_passages)
#     prompt = f"Question: {query}\nProvide accurate information concisely in 1-2 sentences based on the following context (in natural language, with a conversational tone): {context}. Do not repeat any sentences you have have already said in the same response."

#     # Encode input
#     inputs = tokenizer(prompt, return_tensors="pt")

#     # Generate output
#     output = model.generate(
#         **inputs,
#         max_new_tokens=max_new_tokens,
#         pad_token_id=tokenizer.eos_token_id  # avoids padding issues
#     )

#     return tokenizer.decode(output[0], skip_special_tokens=True)


In [30]:
import os
from IPython.display import display, clear_output
from ipywidgets import widgets
from openai import OpenAI
import re

# OpenAI client for the chat cell, authenticated with the previously loaded OPENAI_API_KEY.
client = OpenAI(api_key=OPENAI_API_KEY)

# Transcript rendered by the chat widget, as (speaker, text) tuples.
chat_history = list()


# ---------------------------
# CLEANUP: REMOVE LABEL TAGS
# ---------------------------
def remove_labels(text):
    """
    Remove BIO-style tag tokens such as ``B-hazard`` or ``I-location`` from ``text``.

    A tag is an ``I`` or ``B`` at a word boundary, a hyphen, then one or more
    letters, digits, underscores or hyphens. Leading/trailing whitespace is
    stripped from the result. Interior spacing left behind is not collapsed.

    Note: ordinary tokens of the same shape (e.g. "B-52") are removed too.
    """
    # [IB], not [I|B]: inside a character class "|" is a literal pipe, so the
    # old pattern also deleted text like "a|-b".
    return re.sub(r'\b[IB]-[A-Za-z0-9_-]+\b', '', text).strip()


# ---------------------------
# GPT-4 GENERATION USING RAG
# ---------------------------
def generate_answer_clean(query, k=2, model="gpt-4.1", temperature=0.3):
    """
    Answer ``query`` with retrieval-augmented generation.

    Retrieves the top-``k`` passages with ``retrieve_passages`` and places them
    as context in a prompt. The prompt is sent as a single user message to the
    OpenAI chat completions API, and I-/B- label tags are stripped from the
    reply with ``remove_labels``.

    Args:
        query: The user's question.
        k: Number of context passages to retrieve.
        model: Chat-completions model name (e.g. "gpt-4.1", "gpt-4o").
        temperature: Sampling temperature passed to the API.

    Returns:
        The cleaned answer text ("" if the model returned no text content).
    """
    context_passages = retrieve_passages(query, k)
    # Blank lines keep passage boundaries visible to the model.
    context = "\n\n".join(context_passages)

    prompt = f"""
You are a helpful assistant specializing in knowledge about climate change and the environment. Use the context below to answer the question.
Ensure the answer ends with a complete sentence.

Context:
{context}

Question: {query}
Answer:
""".strip()

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
    )

    # message.content is Optional (e.g. None on a refusal); treat that as an empty answer.
    answer = response.choices[0].message.content or ""
    return remove_labels(answer)


# ---------------------------
# CHAT WIDGET LOGIC
# ---------------------------
def chat_interface_widget(user_input):
    """
    Handle one submitted question.

    Answers the question and records both turns in ``chat_history``. It then
    redraws the whole transcript followed by the input widgets.
    Whitespace-only input is ignored.
    """
    if not user_input.strip():
        return

    reply = generate_answer_clean(user_input)
    chat_history.extend([("You", user_input), ("Bot", reply)])

    # Replace the previous rendering rather than appending below it.
    clear_output(wait=True)
    transcript = "".join(f"{who}: {said}\n\n" for who, said in chat_history)
    print(transcript, end="")
    display(input_widget, run_button)


# ---------------------------
# INPUT WIDGET
# ---------------------------
# Text box for the question plus a "Send" button that submits it.
input_widget = widgets.Text(
    description='Your Question:',
    placeholder='Ask something...',
    value='',
)
run_button = widgets.Button(description="Send")


def on_button_click(b):
    """Button callback: send the typed question to the chat, then clear the box."""
    question = input_widget.value
    chat_interface_widget(question)
    input_widget.value = ""


run_button.on_click(on_button_click)
display(input_widget, run_button)


You: tell me about the effects of forest fires

Bot: Forest fires can lead to significant declines in forest productivity, exacerbate water stress, and cause severe impacts on natural ecosystems by increasing the burned area and severity of wildland fires.

You: what can we do to counteract the effects of climate change?

Bot: To counteract the effects of climate change, we can implement conservation approaches such as establishing marine protected areas (MPAs) and identifying climate refugia, undertake habitat restoration efforts, and adopt ecosystem-based management policies to help alleviate or adapt to climate-change impacts.

You: what causes rising sea levels?

Bot: Rising sea levels are primarily caused by ocean warming, which leads to seawater expansion, and by the addition of water from melting glaciers and ice sheets, with ocean warming accounting for more than 90% of the energy accumulated in the climate system between 1971 and 2010.



Text(value='what causes rising sea levels?', description='Your Question:', placeholder='Ask something...')

Button(description='Send', style=ButtonStyle())