# 1. PyTorch Installation

_NOTE_: If you run this notebook on Google Colab or a similar service, the `pytorch` package is probably already installed. On a local machine, run whichever of the following cells matches your hardware.

## 1.1. CPU

In [None]:
!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu


## 1.2. GPU (an alternative to 1.1, if you have a high-end NVIDIA GPU with CUDA support)

In [None]:
!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
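
After either install cell, it can help to confirm what actually got installed. The following is a minimal sketch, not part of the original notebook: `describe_torch_install` is a hypothetical helper name, and it only relies on `torch.__version__` and `torch.cuda.is_available()`.

```python
import importlib.util


def describe_torch_install():
    """Return a short description of the local PyTorch install (hypothetical helper)."""
    # find_spec lets us check for the package without raising ImportError
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed"

    import torch

    # cuda.is_available() is True only for a CUDA build with a usable GPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"torch {torch.__version__} on {device}"


print(describe_torch_install())
```

If this reports `cpu` after running the cell in 1.2, the GPU build was probably not picked up, or no CUDA-capable device is visible.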

# 2. Checking for models and downloading them

_NOTE_: If you have already downloaded the models from our [repository], just create a folder called `models` and put the model files inside it. Otherwise, run the following cell.

In [8]:
import os

# Create the models folder if it does not exist yet
os.makedirs('models', exist_ok=True)

if os.listdir('models'):
    print('Models are there, where are you?')
else:
    # The folder is empty, so download both model files into it
    !wget -P models https://persianocr.cam/models/letters.pt
    !wget -P models https://persianocr.cam/models/numbers.pt

Models are there, where are you?
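
The cell above depends on `wget`, which is not available on every system. As an alternative, here is a rough pure-Python sketch of the same check-then-download logic. `ensure_models` and its `fetch` parameter are hypothetical names introduced for illustration. The demo at the bottom uses a stub instead of a real download, so nothing is fetched when the block runs.

```python
import os
import tempfile
from urllib.request import urlretrieve

MODEL_URLS = [
    "https://persianocr.cam/models/letters.pt",
    "https://persianocr.cam/models/numbers.pt",
]


def ensure_models(folder="models", urls=MODEL_URLS, fetch=urlretrieve):
    """Download the model files into `folder` unless it already has content.

    Returns the list of file paths that were requested (empty if skipped).
    """
    os.makedirs(folder, exist_ok=True)
    if os.listdir(folder):
        # Same rule as the notebook cell: a non-empty folder means "already there"
        return []

    downloaded = []
    for url in urls:
        target = os.path.join(folder, os.path.basename(url))
        fetch(url, target)  # by default: urllib.request.urlretrieve(url, target)
        downloaded.append(target)
    return downloaded


# Demo with a stub fetcher, so no network access is needed
with tempfile.TemporaryDirectory() as tmp:
    requested = ensure_models(tmp, fetch=lambda url, target: None)
    print([os.path.basename(path) for path in requested])
```

In the notebook itself, calling `ensure_models()` with no arguments would download into `models` using `urlretrieve`.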


# 3. Loading the models for inference and testing

## 3.1. Importing libraries

In [9]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch

## 3.2. Resolving SSL problem (Optional and Occasional)

Only run this cell if:

1. You're on a Mac. Most macOS users run into SSL errors here (so did I).
2. You get SSL-related errors on another OS.

If you're on Colab or a similar service, you don't need this cell.

In [10]:
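import ssl

# Work around certificate errors by making unverified HTTPS the default.
# Note that this turns off certificate verification for the whole session.
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    # This Python build has no such helper, so keep the default context
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context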

## 3.3. Functions for text extraction

### 3.3.1. Numbers

In [11]:
# Translation table that maps each Latin digit to its Persian counterpart
_translator = str.maketrans("1234567890", "۱۲۳۴۵۶۷۸۹۰")

In [12]:
def latin_to_persian(number):
    return number.translate(_translator)
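
As a quick illustration of the two cells above, the sketch below repeats the same definitions so that it runs on its own, then converts a sample string. The input `"1402"` is just an example value.

```python
# Same table and function as in the notebook, repeated so the example is standalone
_translator = str.maketrans("1234567890", "۱۲۳۴۵۶۷۸۹۰")


def latin_to_persian(number):
    # str.translate replaces each character found in the table, leaving others as-is
    return number.translate(_translator)


print(latin_to_persian("1402"))  # ۱۴۰۲
```

Characters that are not Latin digits pass through unchanged, so mixed strings are safe to feed in.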

In [13]:
def extract_chars(result):
    df = result.pandas().xyxy[0]  # we only need the first set of detections
    df = df.sort_values('xmin')  # order the detections from left to right

    # Keep only the characters detected with a confidence above 0.8
    output_string = []
    for name, confidence in zip(df['name'], df['confidence']):
        if confidence > 0.8:
            output_string.append(name)

    output_string = ''.join(output_string)
    output_string = latin_to_persian(output_string)
    return output_string
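
To see what `extract_chars` does without loading a model, the sketch below applies the same three steps (sort by `xmin`, drop detections at or below 0.8 confidence, convert the digits) to hand-written tuples. The helper name `extract_chars_from_rows` and the sample detections are made up for illustration. Real input would come from the detection result's DataFrame.

```python
_translator = str.maketrans("1234567890", "۱۲۳۴۵۶۷۸۹۰")


def latin_to_persian(number):
    return number.translate(_translator)


def extract_chars_from_rows(rows, threshold=0.8):
    """rows: iterable of (xmin, name, confidence) tuples (hypothetical stand-in data)."""
    # Sort by the left edge of each box so characters read left to right
    ordered = sorted(rows, key=lambda row: row[0])
    # Drop low-confidence detections, mirroring the `confidence > 0.8` check
    kept = [name for _, name, confidence in ordered if confidence > threshold]
    return latin_to_persian("".join(kept))


# Made-up detections, deliberately out of order, with one weak detection ("2")
sample = [(120, "3", 0.95), (10, "1", 0.91), (60, "2", 0.42), (200, "7", 0.88)]
print(extract_chars_from_rows(sample))  # ۱۳۷
```

The weak detection is discarded before the join, so it leaves no gap or placeholder in the output string.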