In [14]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
import pandas as pd

In [15]:
# Load the bundled Iris dataset as a scikit-learn Bunch; later cells use its
# .data, .target, .feature_names and .target_names attributes.
iris = load_iris()

In [16]:
# Feature matrix: one row per flower, columns in iris.feature_names order.
iris_data = iris.data

In [17]:
# Integer class code for every sample, plus the species names those codes index.
iris_label = iris.target
print(f'iris target값: {iris_label}')
print(f'iris target명: {iris.target_names}')

iris target값: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
iris target명: ['setosa' 'versicolor' 'virginica']


In [18]:
# One row per sample: the four measurements, the integer class code ("label")
# and the species name that code maps to.
iris_df = pd.DataFrame(iris_data, columns=iris.feature_names).assign(
    label=iris_label,
    species=lambda frame: frame['label'].map(dict(enumerate(iris.target_names))),
)
iris_df.tail(3)

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),label,species
147,6.5,3.0,5.2,2.0,2,virginica
148,6.2,3.4,5.4,2.3,2,virginica
149,5.9,3.0,5.1,1.8,2,virginica


In [19]:
# Hold out 20% of the samples for evaluation; a fixed random_state makes the
# shuffle (and therefore the reported accuracy) reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    iris_data,
    iris_label,
    test_size=0.2,
    random_state=11,
)

In [20]:
# Inspect the shuffled training labels produced by the split above.
y_train

array([0, 2, 2, 0, 0, 2, 2, 1, 0, 1, 1, 2, 0, 1, 2, 1, 1, 0, 2, 0, 2, 2,
       1, 2, 1, 0, 0, 1, 0, 0, 2, 2, 2, 0, 0, 0, 1, 0, 1, 2, 2, 1, 1, 2,
       2, 0, 1, 1, 2, 2, 2, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 1, 1, 2,
       1, 0, 0, 0, 1, 1, 1, 2, 1, 0, 1, 2, 0, 2, 2, 1, 0, 0, 0, 2, 1, 0,
       2, 1, 2, 0, 0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 0, 1, 2, 0, 2, 2, 0,
       1, 2, 0, 1, 1, 1, 0, 1, 1, 1])

In [21]:
# Decision-tree classifier (named df_clf throughout this notebook); the fixed
# random_state keeps training reproducible across runs.
df_clf = DecisionTreeClassifier(random_state = 11)

In [22]:
# Train the tree on the 80% training split.
df_clf.fit(X_train, y_train)

In [23]:
# Predicted class codes for the held-out test samples.
pred = df_clf.predict(X_test)

In [24]:
from sklearn.metrics import accuracy_score

# Fraction of test samples whose predicted class equals the true label.
# ".4f" rounds to 4 decimal places; a bare "4f" would only set a minimum
# field width of 4 and still print 6 decimals.
print('예측 정확도 {0:.4f}'.format(accuracy_score(y_test, pred)))

예측 정확도 0.933333


In [25]:
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns

In [26]:
# Preview the first rows; the integer "label" column is still present here.
iris_df.head()

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),label,species
0,5.1,3.5,1.4,0.2,0,setosa
1,4.9,3.0,1.4,0.2,0,setosa
2,4.7,3.2,1.3,0.2,0,setosa
3,4.6,3.1,1.5,0.2,0,setosa
4,5.0,3.6,1.4,0.2,0,setosa


In [27]:
# Drop the integer "label" column so iris_df holds just the four measurements
# plus "species" (the frame the plotting cells below use).
# errors="ignore" makes the cell idempotent: re-running it after the column is
# already gone no longer raises KeyError.
iris_df = iris_df.drop(columns="label", errors="ignore")
iris_df.head()

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),species
0,5.1,3.5,1.4,0.2,setosa
1,4.9,3.0,1.4,0.2,setosa
2,4.7,3.2,1.3,0.2,setosa
3,4.6,3.1,1.5,0.2,setosa
4,5.0,3.6,1.4,0.2,setosa


In [28]:
def visualize_data(plot_type):
    """Draw an exploratory plot of the Iris dataset for a Gradio "plot" output.

    Reads the notebook-level ``iris_df`` (measurements + "species") and ``iris``.

    Parameters
    ----------
    plot_type : str or None
        Plot chosen in the dropdown. Matched case-insensitively on its leading
        word: "Histogram", "Correlation Matrix" or "Pairplot".

    Returns
    -------
    matplotlib.figure.Figure or None
        A newly created figure, or ``None`` when nothing is selected.

    Raises
    ------
    ValueError
        If ``plot_type`` is non-empty but not a supported plot type.
    """
    if not plot_type:
        return None

    # Dispatch on the normalised leading word so label case/wording differences
    # (e.g. "pairplot" vs "Pairplot") cannot make a selection silently fall through.
    key = plot_type.strip().lower()
    # Select measurement columns by name instead of by position, so the result
    # does not depend on whether "label"/"species" are present or where they sit.
    features = iris_df[iris.feature_names]

    if key.startswith("hist"):
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        # .flat walks the 2x2 axes grid row by row: one panel per feature.
        for ax, feature in zip(axes.flat, iris.feature_names):
            sns.histplot(features[feature], kde=True, ax=ax)
            ax.set_title(feature)
        fig.tight_layout()
    elif key.startswith("corr"):
        # A fresh figure is required; otherwise the heatmap lands on whatever
        # figure was last active (e.g. the previous histogram grid).
        fig, ax = plt.subplots()
        sns.heatmap(features.corr(), annot=True, cmap="coolwarm", ax=ax)
        fig.tight_layout()
    elif key.startswith("pair"):
        # pairplot builds and lays out its own figure; restrict it to the measurements.
        grid = sns.pairplot(iris_df, hue="species", vars=iris.feature_names)
        fig = grid.figure
    else:
        raise ValueError(
            f"Unknown plot type {plot_type!r}; "
            "expected 'Histogram', 'Correlation Matrix' or 'Pairplot'."
        )

    # NOTE(review): each call creates a new pyplot figure that is never closed;
    # consider closing old figures if the app is used heavily.
    return fig

In [29]:
def predict_iris(sepal_length, sepal_width, petal_length, petal_width):
    """Predict the Iris species of a single flower with the trained ``df_clf``.

    The argument order matches ``iris.feature_names`` (sepal length, sepal width,
    petal length, petal width), which is the column order of ``iris.data`` that
    the model was trained on.

    Returns
    -------
    str
        The predicted species name taken from ``iris.target_names``.
    """
    # predict() expects a 2-D input: one row (sample) with four feature columns.
    input_data = [[sepal_length, sepal_width, petal_length, petal_width]]
    pred = df_clf.predict(input_data)
    class_names = iris.target_names
    # target_names holds numpy strings; convert so the UI receives a plain str.
    return str(class_names[pred[0]])

In [30]:
# User-facing dropdown labels; the selected one is passed verbatim to
# visualize_data as plot_type.
plot_types = ["Histogram", "Correlation Matrix", "Pairplot"]
visualization_interface = gr.Interface(
    fn=visualize_data,
    inputs=gr.Dropdown(choices=plot_types, label="Select Plot Type"),
    outputs="plot",
    title="Iris Data Visualization",
    description="Select a plot type to visualize the Iris dataset",
)

In [31]:
prediction_interface = gr.Interface(
    fn=predict_iris,
    # Slider order must match predict_iris's parameters (and the model's feature
    # order): sepal length, sepal width, petal length, petal width.
    inputs=[
        gr.Slider(4.0, 8.0, step=0.1, label="Sepal Length (cm)"),
        gr.Slider(2.0, 4.5, step=0.1, label="Sepal Width (cm)"),
        gr.Slider(1.0, 7.0, step=0.1, label="Petal Length (cm)"),
        gr.Slider(0.1, 2.5, step=0.1, label="Petal Width (cm)"),
    ],
    outputs="text",
    title="Iris Species Classifier",
    description="Adjust the sliders to predict the Iris species",
    # Re-run the prediction whenever a slider value changes.
    live=True,
)

In [32]:
# Combine both apps into one UI; tab names are listed in the same order as the interfaces.
iface = gr.TabbedInterface(
    [visualization_interface, prediction_interface],
    tab_names=["Data Visualization", "Species Prediction"],
)

In [34]:
# Serve the tabbed app on a local URL; pass share=True to launch() for a public link.
iface.launch()

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


