<h1><img width='30' src='../resources/gradio-icon.png'> Iris Gradio</h1>

## 📥 Importing

In [151]:
! pip install gradio



In [23]:
import json
import gradio as gr
import matplotlib.pyplot as plt

# Thanks to **dhaitz**!
plt.style.use('https://github.com/dhaitz/matplotlib-stylesheets/raw/master/pitayasmoothie-dark.mplstyle')

# 한글폰트 설정
import matplotlib as mpl
mpl.rc('font', family='D2Coding')
mpl.rcParams['axes.unicode_minus'] = False

# Gradio 테마 설정
theme = gr.themes.Base(
    primary_hue='violet',
    secondary_hue='teal',
    neutral_hue='slate',
    font_mono=[gr.themes.GoogleFont('JetBrains Mono'), 'ui-monospace', 'Consolas', 'monospace'],
)
set_darkmode = '''
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.href = url.href;
    }
}
'''

## 🖐️ Gradio 살펴보기

In [2]:
def greet(name):
    return f'환영합니다, {name}님!' if name else '👀'

In [3]:
demo = gr.Interface(fn=greet, inputs='text', outputs='text', theme=theme, js=set_darkmode)
demo.launch()

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




In [4]:
with gr.Blocks(theme=theme, js=set_darkmode) as demo:
    gr.Markdown('# 🖐️ Hello, World!')
    
    gr.Markdown('## 이름을 입력하고 버튼을 누르면...?')
    with gr.Row():
        input_textbox = gr.Textbox(label='📛 이름을 입력하세요.')
        output_textbox = gr.Textbox(label='🖐️ 어서오세요!', interactive=False)
    send_button = gr.Button('전송', variant='primary')
    send_button.click(fn=greet, inputs=input_textbox, outputs=output_textbox)
    
    gr.Markdown('## 이번엔 이름을 입력하기만 해도?')
    with gr.Row():
        input_textbox = gr.Textbox(label='📛 이름을 입력하세요.')
        output_textbox = gr.Textbox(label='🖐️ 어서오세요!', interactive=False)
    input_textbox.change(fn=greet, inputs=input_textbox, outputs=output_textbox)
    
demo.launch()

* Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.




## ⚜️ 붓꽃 품종 예측

### 📡 Azure 엔드포인트와 통신

In [None]:
import requests

endpoint = 'http://c8036432-a374-4b6f-b89d-067ef0d15b8a.koreacentral.azurecontainer.io/score'
headers = {
    'Content-Type': 'application/json',
    'Authorization': '!!! THIS IS SECRET KEY !!!'
}

def predict(data):
    response = requests.post(endpoint, headers=headers, json={'Inputs': {'input1': data}})
    if response.status_code == 200:
        return response.json()["Results"]["WebServiceOutput0"]
    return '문제가 발생했습니다.'

### 📈 Plot 이미지 출력

In [38]:
def save_plot(data_points):
    # 센터로이드의 평균 위치를 계산하기 위한 변수 초기화
    centroid_positions = {0: [0, 0], 1: [0, 0], 2: [0, 0]}
    centroid_colors = {0: 'b', 1: 'r', 2: 'g'}  # 클러스터 색상
    
    # 데이터 포인트를 기반으로 센터로이드 위치 계산
    for point in data_points:
        assignment = point["Assignments"]
    
        # 각 클러스터별로 거리 데이터 가져오기
        for i in range(3):
            dist_key = f"DistancesToClusterCenter no.{i}"
            if dist_key in point:
                # 위치의 평균 계산
                centroid_positions[i][0] += (point["sepal_length_cm"] + point[dist_key]) / 2
                centroid_positions[i][1] += (point["sepal_width_cm"] + point[dist_key]) / 2
    
    # 평균값으로 센터로이드 위치 계산
    for i in range(3):
        centroid_positions[i][0] /= len(data_points)
        centroid_positions[i][1] /= len(data_points)
    
    plt.figure(figsize=(8, 6))
    
    point_index = 0
    # 데이터 포인트 그리기
    for point in data_points:
        point_index += 1
        plt.scatter(point["sepal_length_cm"], point["sepal_width_cm"],
                    c='b' if point["Assignments"] == 0 else 'r' if point["Assignments"] == 1 else 'g')
        plt.text(point["sepal_length_cm"], point["sepal_width_cm"], f"{point_index}")

    # 클러스터 센터로이드 그리기
    for cluster, (x, y) in centroid_positions.items():
        plt.scatter(x, y, c=centroid_colors[cluster], marker='X', s=200)

    plt.title('Data Points and Cluster Centroids')
    plt.xlabel('Sepal Length (cm)')
    plt.ylabel('Sepal Width (cm)')
    plt.savefig('output/iris_clusters.png')
    plt.tight_layout()
    plt.close()  # plt.show() 대신 plt.close()를 사용
    return 'output/iris_clusters.png'  # 현재 figure 반환

### 👆 사용자 인터페이스 동작

In [39]:
data = [{
    'sepal_length_cm': 0.0,
    'sepal_width_cm': 0.0,
    'petal_length_cm': 0.0,
    'petal_width_cm': 0.0,
    'class': 'Iris-setosa'
}]


def click_add():
    print(f'@click_add()')
    data.append({
        'sepal_length_cm': 0.0,
        'sepal_width_cm': 0.0,
        'petal_length_cm': 0.0,
        'petal_width_cm': 0.0,
        'class': 'Iris-setosa'
    })
    return len(data)

def click_delete(index):
    print(f'@click_delete({index=})')
    if len(data) > 1:
        data.pop(index)
    return len(data)

def click_copy(index):
    print(f'@click_copy({index=})')
    data.append(dict(data[index]))
    return len(data)

def change_data(index, sepal_length, sepal_width, petal_length, petal_width):
    print(f'@click_copy({index=}, {sepal_length=}, {sepal_width=}, {petal_length=}, {petal_width=})')
    data[index]['sepal_length_cm'] = float(sepal_length)
    data[index]['sepal_width_cm'] = float(sepal_width)
    data[index]['petal_length_cm'] = float(petal_length)
    data[index]['petal_width_cm'] = float(petal_width)

def click_predict():
    result = predict(data)
    if type(result) is str:
        return result, None
    return json.dumps(result, ensure_ascii=False, indent=4), save_plot(result)

<h3><img width='22' src='../resources/gradio-icon.png'> Gradio 뷰 및 이벤트 리스터 정의</h1>

In [40]:
with gr.Blocks(theme=theme, js=set_darkmode) as demo:
    gr.Markdown('# ⚜️ 붓꽃 품종 예측')
    
    data_count = gr.State(0)
    @gr.render(inputs=data_count)
    def data_input_view(_):
        for i, datum in enumerate(data):
            row_idx = gr.State(i)
            with gr.Row():
                sepal_lenght_textbox = gr.Textbox(label='꽃받침 길이', value=datum['sepal_length_cm'], key=f'sl-{i}', interactive=True)
                sepal_width_textbox = gr.Textbox(label='꽃받침 넓이', value=datum['sepal_width_cm'], key=f'sw-{i}', interactive=True)
                petal_length_textbox = gr.Textbox(label='꽃잎 길이', value=datum['petal_length_cm'], key=f'pl-{i}', interactive=True)
                petal_width_textbox = gr.Textbox(label='꽃잎 넓이', value=datum['petal_width_cm'], key=f'pw-{i}', interactive=True)
                
                with gr.Column():
                    delete_button = gr.Button('🗑️ 이 데이터 삭제하기', size='md')
                    copy_button = gr.Button('📋 이 데이터 복제하기', size='md')
                
            """ 이벤트 처리 """
            change_inputs = [row_idx, sepal_lenght_textbox, sepal_width_textbox, petal_length_textbox, petal_width_textbox ]
            sepal_lenght_textbox.change(change_data, inputs=change_inputs)
            sepal_width_textbox.change(change_data, inputs=change_inputs)
            petal_length_textbox.change(change_data, inputs=change_inputs)
            petal_width_textbox.change(change_data, inputs=change_inputs)
            
            delete_button.click(click_delete, inputs=row_idx, outputs=data_count)
            copy_button.click(click_copy, inputs=row_idx, outputs=data_count)
    
    plus_button = gr.Button('➕ 새 데이터 추가하기', variant='secondary')
    send_button = gr.Button('🧠 품종 예측하기', variant='primary')
    
    gr.Markdown('## 🖨️ 예측 결과')
    plot_image = gr.Image(label="Plot", interactive=False)
    species_textbox = gr.TextArea(label='Json', interactive=False)
    
    """ 이벤트 처리 """
    plus_button.click(click_add, outputs=data_count)
    send_button.click(click_predict, outputs=[species_textbox, plot_image])

demo.launch()

* Running on local URL:  http://127.0.0.1:7870

To create a public link, set `share=True` in `launch()`.




@click_copy(index=0, sepal_length='1', sepal_width='0.0', petal_length='0.0', petal_width='0.0')
@click_copy(index=0, sepal_length='1', sepal_width='1', petal_length='0.0', petal_width='0.0')
@click_copy(index=0, sepal_length='1', sepal_width='1', petal_length='1', petal_width='0.0')
@click_copy(index=0, sepal_length='1', sepal_width='1', petal_length='1', petal_width='1')
@click_copy(index=0)
@click_copy(index=1, sepal_length='1.0', sepal_width='2', petal_length='1.0', petal_width='1.0')
@click_copy(index=1, sepal_length='1.0', sepal_width='2', petal_length='1.0', petal_width='2')
@click_copy(index=1, sepal_length='1.0', sepal_width='2', petal_length='2', petal_width='2')
@click_copy(index=1, sepal_length='2', sepal_width='2', petal_length='2', petal_width='2')
@click_copy(index=1)
@click_copy(index=2, sepal_length='3', sepal_width='2.0', petal_length='2.0', petal_width='2.0')
@click_copy(index=2, sepal_length='3', sepal_width='2.0', petal_length='3', petal_width='2.0')
@click_copy(in