In [None]:
!pip install transformers pandas matplotlib seaborn




In [20]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import torch
import seaborn as sns
import matplotlib.pyplot as plt

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

# Load the Iris dataset
from sklearn.datasets import load_iris

iris = load_iris()
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
iris_df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)


def query_dataset(query):
    """
    Process a query and retrieve relevant information or visualizations for the Iris dataset.

    Args:
        query (str): The user's query.

    Returns:
        str: The result of the query or a default message.
    """
    if "columns" in query.lower():
        return f"Available columns: {', '.join(iris_df.columns)}"
    elif "summary" in query.lower() and "visualize" in query.lower():
        visualize_summary(iris_df)
        return "Generated a visualization for the dataset summary."
    elif "correlation" in query.lower() and "visualize" in query.lower():
        visualize_correlation(iris_df)
        return "Generated a correlation heatmap for the dataset."
    elif "summary" in query.lower():
        return f"Dataset summary:\n{iris_df.describe(include='all')}"
    elif "correlation" in query.lower():
        return f"Correlation matrix:\n{iris_df.iloc[:, :-1].corr()}"
    elif "rows" in query.lower():
        return f"The Iris dataset contains {len(iris_df)} rows."
    elif "species" in query.lower():
        return f"Iris species in the dataset: {', '.join(iris.target_names)}"
    return "I'm not sure how to process that query with the Iris dataset."


def visualize_summary(data):
    """
    Visualize the summary statistics of the dataset.

    Args:
        data (pd.DataFrame): The dataset.
    """
    plt.figure(figsize=(10, 6))
    sns.boxplot(data=data.iloc[:, :-1])  # Exclude the species column for boxplots
    plt.title("Summary Statistics (Boxplot)")
    plt.xlabel("Features")
    plt.ylabel("Values")
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()


def visualize_correlation(data):
    """
    Visualize the correlation heatmap of the dataset.

    Args:
        data (pd.DataFrame): The dataset.
    """
    correlation_matrix = data.iloc[:, :-1].corr()  # Exclude the species column
    plt.figure(figsize=(8, 6))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt=".2f", cbar=True)
    plt.title("Correlation Heatmap")
    plt.tight_layout()
    plt.show()


def chatbot_response(user_input, chat_history_ids):
    """
    Generate a response from the chatbot given user input and chat history.

    Args:
        user_input (str): The user's input.
        chat_history_ids (torch.Tensor or None): Chat history tensor.

    Returns:
        tuple: (str, torch.Tensor) The bot's reply and updated chat history.
    """
    # Encode the user input
    new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')

    # Combine with chat history
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if chat_history_ids is not None else new_user_input_ids

    # Generate response
    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)

    # Decode the response
    bot_reply = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    return bot_reply, chat_history_ids


# Chat loop
chat_history_ids = None

print("Chatbot is ready! Type 'exit' or 'quit' to end the chat.")

for step in range(5):
    user_input = input(">> User: ")

    if user_input.lower() in ["exit", "quit"]:
        print("DialoGPT: Goodbye!")
        break

    # Check if user input relates to the dataset
    if any(keyword in user_input.lower() for keyword in ["iris", "dataset", "visualize", "summary", "correlation", "rows", "species"]):
        result = query_dataset(user_input)
        print(f"DialoGPT: {result}")
    else:
        # Generate chatbot response
        bot_reply, chat_history_ids = chatbot_response(user_input, chat_history_ids)
        print(f"DialoGPT: {bot_reply}")


Chatbot is ready! Type 'exit' or 'quit' to end the chat.
>> User: visualize
DialoGPT: I'm not sure how to process that query with the Iris dataset.
>> User: summary
DialoGPT: Dataset summary:
        sepal length (cm)  sepal width (cm)  petal length (cm)  \
count          150.000000        150.000000         150.000000   
unique                NaN               NaN                NaN   
top                   NaN               NaN                NaN   
freq                  NaN               NaN                NaN   
mean             5.843333          3.057333           3.758000   
std              0.828066          0.435866           1.765298   
min              4.300000          2.000000           1.000000   
25%              5.100000          2.800000           1.600000   
50%              5.800000          3.000000           4.350000   
75%              6.400000          3.300000           5.100000   
max              7.900000          4.400000           6.900000   

        petal w