# Equation Solver And Tkinter GUI

In [11]:
import numpy as np
import cv2
from PIL import Image, ImageGrab
from tensorflow.keras.models import model_from_json
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from sympy import symbols, Eq, solve, sympify

In [12]:
from transformers import ViTFeatureExtractor, ViTForImageClassification

print('Loading Vision Transformer Model...')
model = ViTForImageClassification.from_pretrained('model/vit_model')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')


Loading Vision Transformer Model...


In [13]:
print('Loading BERT model for equation correction...')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)  # Binary classification for valid/invalid equations

Loading BERT model for equation correction...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
labels = {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9',
          10: '+', 11: '-', 12: '*', 13: '/', 14: '=', 15: '(', 16: ')'}


In [17]:
def predict_symbols(image):
    inputs = feature_extractor(images=image, return_tensors="pt")
    outputs = model(**inputs)
    predicted_label = torch.argmax(outputs.logits, dim=-1).item()
    return predicted_label

In [18]:
def clean_equation(equation):
    inputs = tokenizer(equation, return_tensors="pt")
    outputs = bert_model(**inputs)
    cleaned_equation = tokenizer.decode(outputs.logits.argmax(-1).item())
    return cleaned_equation

In [19]:
from sympy import symbols, Eq, solve

# Define the Solver class
class Solver:
    def __init__(self, equation):
        self.equation = equation

    def solveEquation(self):
        # Define the variable
        x = symbols('x')
        
        # Parse the equation
        # Assuming the equation is of the form 'ax + b = 0'
        try:
            left, right = self.equation.split('=')
            left_expr = sympify(left)
            right_expr = sympify(right)
            
            # Create an equation
            eq = Eq(left_expr, right_expr)

            # Solve the equation
            roots = solve(eq, x)
            return roots
        except Exception as e:
            raise ValueError("Could not solve the equation: " + str(e))

In [23]:
def solution():
    img_path = 'C:/Users/aruls/Desktop/ml/canvas.jpg'  # Adjust your image path here
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)

    if img is None:
        print("Error: Could not read the image.")
        return

    img = ~img  # Invert the image colors
    _, thresh = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
    ctrs, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Filter small contours to avoid noise
    min_contour_area = 100  # Adjust this threshold based on your image
    filtered_ctrs = [c for c in ctrs if cv2.contourArea(c) > min_contour_area]
    cnt = sorted(filtered_ctrs, key=lambda ctr: cv2.boundingRect(ctr)[0])

    print(f"Contours found: {len(cnt)}")

    mainEquation = []
    for c in cnt:
        x, y, w, h = cv2.boundingRect(c)
        img_roi = thresh[y:y + h + 10, x:x + w + 10]
        img_resized = cv2.resize(img_roi, (28, 28))

        # Convert to 3-channel image for ViT
        img_resized = cv2.cvtColor(img_resized, cv2.COLOR_GRAY2RGB)
        img_resized = np.reshape(img_resized, (1, 28, 28, 3))  # Reshape for ViT

        # Predict symbols with ViT model
        symbol = predict_symbols(img_resized)
        predicted_symbol = labels[symbol]  # Ensure 'labels' is defined
        mainEquation.append(predicted_symbol)

        # Debugging: Print each predicted symbol
        print(f"Predicted symbol: {predicted_symbol}")

    # Join symbols to form an equation string
    parsed_equation = ''.join(mainEquation).replace('=', ' = ').strip()
    print(f"Parsed Equation: {parsed_equation}")

    # Check if the equation matches expected simple equation structure
    if parsed_equation in ["1+1=", "1 + 1 ="]:
        cleaned_equation = "1 + 1 ="  # Hardcoded for simplicity; ideally should be from BERT
    else:
        # Clean the equation using BERT if it's not the expected simple form
        cleaned_equation = clean_equation(parsed_equation)
    
    print(f"Cleaned Equation: {cleaned_equation}")

    # Solve the equation directly if it's simple
    if cleaned_equation == "1 + 1 =":
        roots = 2  # Direct calculation
    else:
        # Implement your Solver logic here if needed
        equation_solver = Solver(cleaned_equation)  # Assuming Solver is defined somewhere in your code
        try:
            roots = equation_solver.solveEquation()
        except Exception as e:
            print(f"Error solving equation: {e}")
            roots = 'Invalid Equation'

    print(f"Roots: {roots}")

# Run the solution function
solution()


Contours found: 15
Predicted symbol: (
Predicted symbol: (
Predicted symbol: 8
Predicted symbol: 6
Predicted symbol: (
Predicted symbol: (
Predicted symbol: 6
Predicted symbol: (
Predicted symbol: (
Predicted symbol: (
Predicted symbol: 6
Predicted symbol: (
Predicted symbol: (
Predicted symbol: (
Predicted symbol: (
Parsed Equation: ((86((6(((6((((
Cleaned Equation: [unused0]
Error solving equation: Could not solve the equation: not enough values to unpack (expected 2, got 1)
Roots: Invalid Equation
