In [None]:
!pip install streamlit_chat pix2tex transformers accelerate

In [None]:
import torch
import streamlit as st
from PIL import Image
from pix2tex.cli import LatexOCR
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
from streamlit_chat import message

In [None]:
# Smoke-test cell: load the LaTeX OCR model and the Phi-3 text-generation
# pipeline inside the notebook kernel. The Streamlit app written below loads
# its own copies in a separate process (see `load_model` in app.py), so this
# cell holds a second copy of the weights while the kernel is alive.
# NOTE(review): this checkpoint id differs from the one hard-coded in app.py
# ("oz1115/phi3_fine_tuning") — confirm which checkpoint is intended.
model_id = "wonik-hi/phi3_fine_tuning"
math_model = LatexOCR()

torch.random.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # same placement as app.py's load_model (uses accelerate)
    torch_dtype="auto",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
streamer = TextStreamer(tokenizer, skip_prompt=True)
# Rebinds `model` from the raw causal LM to the pipeline that wraps it.
model = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    streamer=streamer,
)

In [None]:
%%writefile app.py

import torch
import streamlit as st
from PIL import Image
from pix2tex.cli import LatexOCR
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer
from streamlit_chat import message

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

@st.cache_resource()
def load_model(model_id="oz1115/phi3_fine_tuning", seed=0):
    """Load the LaTeX OCR model and the Phi-3 text-generation pipeline.

    ``st.cache_resource`` memoises the return value keyed on the arguments, so
    Streamlit reruns reuse the same objects instead of reloading the weights.
    As a consequence the seed below is applied only on the first (uncached)
    call, not before every generation.

    Args:
        model_id: Hugging Face Hub id of the causal-LM checkpoint (and its
            tokenizer) to load.
        seed: Value passed to ``torch.random.manual_seed`` before loading.

    Returns:
        ``(math_model, pipe)``: a ``LatexOCR`` instance and a ``transformers``
        "text-generation" pipeline wrapping the model and its tokenizer.
    """
    math_model = LatexOCR()

    torch.random.manual_seed(seed)
    causal_lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # TextStreamer writes tokens to the server process's stdout as they are
    # generated; it does not stream to the browser.
    streamer = TextStreamer(tokenizer, skip_prompt=True)
    pipe = pipeline(
        "text-generation",
        model=causal_lm,
        tokenizer=tokenizer,
        streamer=streamer,
    )
    return math_model, pipe

def image_math(image, ocr_model=None):
    """OCR ``image`` into LaTeX and trim two patterns from the result.

    Args:
        image: Image passed straight to the OCR callable (in this app, a PIL
            image opened from the uploaded file).
        ocr_model: Callable mapping an image to a LaTeX string. Defaults to the
            module-level ``math_model`` returned by ``load_model``.

    Returns:
        The LaTeX string, after these rules:
        * if it contains ``\\stackrel``, everything from the first occurrence
          onward is dropped;
        * otherwise, if it starts with ``\\left[``, that prefix is removed.
        NOTE(review): presumably these strip OCR artefacts — confirm intent.
    """
    if ocr_model is None:
        ocr_model = math_model
    latex_code = ocr_model(image)

    # Raw strings: '\s' and '\l' are invalid escape sequences in normal
    # literals (SyntaxWarning on Python 3.12+); the values are unchanged.
    stackrel = r'\stackrel'
    left_bracket = r'\left['
    if stackrel in latex_code:
        latex_code = latex_code.split(stackrel)[0]
    elif latex_code.startswith(left_bracket):
        # Strip only the leading prefix. split(...)[1] would also discard
        # everything after a second '\left[' in the expression.
        latex_code = latex_code[len(left_bracket):]
    return latex_code

def generate_response(user_input, uploaded_file):
    """Ask the module-level text-generation pipeline a single-turn question.

    Without an upload, a '?' is appended to the user's text. With an upload,
    the image is OCR'd via ``image_math``, a resized preview is shown with the
    recognised LaTeX as its caption, and the LaTeX is appended after a colon.

    Args:
        user_input: The question typed by the user.
        uploaded_file: The Streamlit upload, or ``None`` when nothing was uploaded.

    Returns:
        The ``generated_text`` of the first pipeline result, produced with the
        module-level ``generation_args``.
    """
    if uploaded_file is None:
        prompt = f"{user_input}?"
    else:
        picture = Image.open(uploaded_file)
        latex = image_math(picture)
        preview = picture.resize((500, 150))
        st.image(preview, caption=f" OCR Result:  {latex} ")
        prompt = f"{user_input}: {latex}"

    chat = [{"role": "user", "content": prompt}]
    results = model(chat, **generation_args)
    first_result = results[0]
    return first_result['generated_text']


# ---- model & generation settings ----
# load_model is wrapped in st.cache_resource, so reruns reuse the loaded objects.
math_model, model = load_model()
# Keyword arguments forwarded to the text-generation pipeline on every call.
generation_args = {
    "max_new_tokens": 500,
    "return_full_text": False,  # return only the completion, not the prompt
    "temperature": 0.5,
    "do_sample": True,  # sample using the temperature above
}


# ---- page layout ----
st.header("🤖Phi-3 for math (영어버전)")
st.sidebar.markdown("## Information")
st.sidebar.info("- 수학문제 풀이를 위한 'Phi-3'.\n - 현) 한글로 질문시 답변이 어려움.\n - This tool uses 'Phi-3' to solve math problems.\n - It only understands 'English'")
st.sidebar.markdown("### Guide")
st.sidebar.write("1. (Optional) 이미지 파일 업로드 / Upload an image file.")
st.sidebar.write("2. 영어로 질문하기 / Ask something.\n\t - ex) solve the problem.")
st.sidebar.write("3. 버튼 클릭 / Click 'Send' button to activate the AI.")

# ---- chat state ----
# 'past' holds the user's messages and 'generated' the model replies; index i
# of both lists belongs to the same exchange. Session state persists across reruns.
for history_key in ('generated', 'past'):
    if history_key not in st.session_state:
        st.session_state[history_key] = []

with st.form('form', clear_on_submit=True):
    uploaded_file = st.file_uploader("Upload a file", type=['png', 'jpg'])
    user_input = st.text_input('You: ', '', key='input')
    submitted = st.form_submit_button('Send')

# Only query the model when the form was submitted with non-empty text.
if submitted and user_input:
    with st.spinner('Thinking...'):
        output = generate_response(user_input, uploaded_file)
        st.session_state.past.append(user_input)
        st.session_state.generated.append(output)

if st.session_state['generated']:
    # Wipes the whole session state; the initialisation loop above recreates
    # the empty history lists on the next rerun.
    st.button("Clear History", on_click=lambda: st.session_state.clear())
    # Render newest exchange first.
    for i in reversed(range(len(st.session_state['generated']))):
        message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
        message(st.session_state["generated"][i], key=str(i))


In [None]:
!streamlit run /content/app.py &>/content/logs.txt &
!npx localtunnel --port 8501 & curl ipv4.icanhazip.com