In [None]:
# Download the SceneTrialTrain archive from Google Drive (by file id) and
# extract it quietly into ./datasets.
# NOTE(review): the checkpoints read later from ./weight/ are not fetched by
# any cell visible in this notebook - confirm where they come from.
!gdown 1cBiff7HEbaQsUGVhbdi5wmj83e9dDgU8
!unzip -q SceneTrialTrain.zip -d datasets

Downloading...
From (original): https://drive.google.com/uc?id=1cBiff7HEbaQsUGVhbdi5wmj83e9dDgU8
From (redirected): https://drive.google.com/uc?id=1cBiff7HEbaQsUGVhbdi5wmj83e9dDgU8&confirm=t&uuid=6f11f397-f28b-49be-b165-b6a5b37a3753
To: /content/SceneTrialTrain.zip
100% 45.7M/45.7M [00:00<00:00, 65.3MB/s]


In [None]:
!pip install ultralytics


Collecting ultralytics
  Downloading ultralytics-8.3.185-py3-none-any.whl.metadata (37 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.16-py3-none-any.whl.metadata (14 kB)
Downloading ultralytics-8.3.185-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ultralytics_thop-2.0.16-py3-none-any.whl (28 kB)
Installing collected packages: ultralytics-thop, ultralytics
Successfully installed ultralytics-8.3.185 ultralytics-thop-2.0.16


In [None]:
from ultralytics import YOLO
import torch.nn as nn
from torchvision import transforms
import torch
import torchvision
import timm
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import os
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


# **Data Preprocessing**

# **Text Recognition Model**

In [None]:
class CRNN(nn.Module):
  """CNN + bidirectional GRU text recogniser emitting per-timestep log-probabilities.

  A single-channel timm ResNet-50 (minus its last two child modules) extracts a
  feature map whose height is pooled down to 1, so every remaining column
  becomes one timestep of a sequence. The sequence is projected to 512
  features, run through a bidirectional GRU and classified per timestep.

  Args:
    vocab_size: Number of output classes per timestep.
    hidden_size: GRU hidden size (per direction).
    n_layers: Number of stacked GRU layers.
    dropout: Dropout probability used after the 2048->512 projection and
      between GRU layers (only applied to the GRU when ``n_layers > 1``).
    unfreeze_layers: Number of trailing backbone stages whose parameters are
      explicitly marked trainable.
    pretrained: Whether timm should load pretrained backbone weights. Set to
      ``False`` when a full checkpoint is loaded right after construction to
      skip the download. Defaults to ``True`` (the original behaviour).
  """

  def __init__(
      self,vocab_size,hidden_size,n_layers,dropout=0.2,unfreeze_layers=3,pretrained=True
  ):
    super().__init__()
    backbone = timm.create_model('resnet50',pretrained=pretrained,in_chans=1)

    # Keep everything except the last two children of the timm model
    # (presumably its pooling + classifier head - the projection below
    # expects 2048 feature channels).
    modules = list(backbone.children())[:-2]

    # NOTE(review): nothing in this class sets requires_grad=False, so unless
    # the backbone arrives with frozen parameters this loop changes nothing and
    # the whole backbone stays trainable. If partial fine-tuning is intended,
    # the backbone has to be frozen first - confirm against the training code.
    # The `> 0` guard avoids `modules[-0:]`, which would select every stage.
    n_unfrozen = int(unfreeze_layers)
    for layer in (modules[-n_unfrozen:] if n_unfrozen > 0 else []):
      for param in layer.parameters():
        param.requires_grad = True

    # Collapse the feature-map height to 1 while keeping its width.
    modules.append(nn.AdaptiveAvgPool2d((1,None)))
    self.backbone = nn.Sequential(*modules)
    self.mapSeq = nn.Sequential(
        nn.Linear(2048,512),nn.ReLU(),nn.Dropout(dropout)
    )
    self.gru = nn.GRU(
        512,
        hidden_size,
        n_layers,
        bidirectional=True,
        batch_first=True,
        # Inter-layer dropout is meaningless for a single layer and makes
        # PyTorch emit a warning, so only enable it for stacked GRUs.
        dropout=dropout if n_layers > 1 else 0.0
    )
    self.layer_norm = nn.LayerNorm(hidden_size*2)
    self.ff = nn.Sequential(
        nn.Linear(hidden_size*2,hidden_size),nn.ReLU()
    )
    self.batch_norm = nn.BatchNorm1d(hidden_size)
    self.out = nn.Sequential(
        nn.Linear(hidden_size,vocab_size), nn.LogSoftmax(dim=2)
    )

  def forward(self,x):
    """Map an image batch to log-probabilities of shape (seq_len, batch, vocab_size)."""
    x = self.backbone(x)    # (batch, channels, 1, width)
    x = x.squeeze(2)        # (batch, channels, width)
    x = x.permute(0,2,1)    # (batch, width, channels): one timestep per column
    x = self.mapSeq(x)
    x,_ = self.gru(x)       # (batch, width, 2*hidden_size)
    x = self.layer_norm(x)
    x = self.ff(x)
    # BatchNorm1d normalises over dim 1, so move the features there and back.
    x = x.permute(0,2,1)
    x = self.batch_norm(x)
    x = x.permute(0,2,1)
    x = self.out(x)
    x = x.permute(1,0,2)    # time-major: (seq_len, batch, vocab_size)
    return x

In [None]:
# Recognition alphabet: digits, lowercase letters and '-' (passed to decode()
# as the blank character). Character indices start at 1; index 0 is left
# unassigned (decode() skips it) - presumably reserved for padding/blank,
# which is why the vocabulary is one larger than the alphabet.
chars = "0123456789abcdefghijklmnopqrstuvwxyz-"
vocab_size = 1 + len(chars)
chars_2_idx = {symbol: position for position, symbol in enumerate(chars, start=1)}
idx_2_chars = {position: symbol for symbol, position in chars_2_idx.items()}


In [None]:
# Architecture hyper-parameters. load_state_dict() below is strict by default,
# so these have to describe the same network the checkpoint was saved from.
hidden_size=256
n_layers = 3
dropout = 0.2
unfreeze_layers = 3
text_reg_model = CRNN(vocab_size,hidden_size,n_layers,dropout,unfreeze_layers).to(device)
# map_location lets a checkpoint that was saved on a GPU load on the CPU
# fallback as well; without it torch.load tries to restore the tensors onto
# the device they were saved from.
text_reg_state = torch.load('./weight/text_reg.pt',map_location=device,weights_only=True)
text_reg_model.load_state_dict(text_reg_state)

<All keys matched successfully>

In [None]:
def decode(encoded_outputs,idx_2_chars,blank_char='-'):
  """Greedy-decode the model output for one image into a string.

  For every timestep the most likely class is taken; consecutive repeats are
  collapsed and blanks are dropped. A blank between two identical characters
  keeps both of them.

  Args:
    encoded_outputs: Tensor of shape (seq_len, 1, vocab_size) with per-class
      scores for a single sample.
    idx_2_chars: Mapping from class index to character. Index 0 does not need
      an entry; it is always treated as a blank.
    blank_char: Character that represents the blank class.

  Returns:
    The decoded text.

  Raises:
    ValueError: If ``encoded_outputs`` is not a single-sample 3-D tensor.
  """
  if encoded_outputs.dim() != 3 or encoded_outputs.size(1) != 1:
    raise ValueError(
        'decode expects a (seq_len, 1, vocab_size) tensor, got shape '
        f'{tuple(encoded_outputs.shape)}'
    )

  # One argmax over the class axis instead of one call per timestep.
  pred_indices = encoded_outputs[:,0,:].argmax(dim=-1).tolist()

  decoded_chars = []
  prev_char = None
  for pred_idx in pred_indices:
    # Index 0 has no entry in idx_2_chars, so it must be handled before the
    # lookup (the original looked it up first and raised KeyError).
    char = blank_char if pred_idx == 0 else idx_2_chars[pred_idx]
    if char != blank_char and char != prev_char:
      decoded_chars.append(char)
    prev_char = char
  return ''.join(decoded_chars)


In [None]:
# Preprocessing applied to every cropped word image before recognition.
data_transforms = transforms.Compose(
    [
        # Single channel, matching the CRNN backbone's in_chans=1.
        transforms.Grayscale(num_output_channels=1),
        # Fixed (height, width) so every crop yields the same sequence length.
        transforms.Resize((100,420)),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 on the single channel.
        transforms.Normalize((0.5,),(0.5,)),
    ]
)

In [None]:
def text_regconition(cropped_img,data_transforms,text_reg_model,idx_2_chars,blank_char):
  """Recognise the text in one cropped word image.

  Args:
    cropped_img: Image of a single text region, in whatever form
      ``data_transforms`` accepts.
    data_transforms: Callable turning the image into a tensor without a batch
      dimension.
    text_reg_model: Recognition model; it is switched to eval mode here.
    idx_2_chars: Class-index -> character table forwarded to ``decode``.
    blank_char: Blank character forwarded to ``decode``.

  Returns:
    The decoded string.
  """
  batch = data_transforms(cropped_img).unsqueeze(0)

  # Follow the model's own device rather than a notebook-global `device`, so
  # the function does not depend on hidden state and cannot disagree with
  # where the weights actually live.
  first_param = next(text_reg_model.parameters(), None)
  if first_param is not None:
    batch = batch.to(first_param.device)

  text_reg_model.eval()
  with torch.no_grad():
    log_probs = text_reg_model(batch)
  return decode(log_probs,idx_2_chars,blank_char)


In [None]:
def predict(img_path,data_transforms,text_det_model,text_reg_model,idx_2_chars,blank_char,device):
  """Detect text regions in an image and recognise the text inside each one.

  Args:
    img_path: Path of the image to process.
    data_transforms: Preprocessing applied to each crop before recognition.
    text_det_model: Detector called with ``img_path``; the first result's
      ``boxes`` (xyxy, conf, cls) are used.
    text_reg_model: Recognition model applied to every crop.
    idx_2_chars: Class-index -> character table used for decoding.
    blank_char: Blank character used for decoding.
    device: Unused here; kept so existing callers keep working.

  Returns:
    A ``zip`` over ``(bbox, class, confidence, text)`` tuples, one per
    detection. It is a one-shot iterator: wrap it in ``list(...)`` if it has
    to be traversed more than once.
  """
  text_det_results = text_det_model(img_path)[0]
  boxes = text_det_results.boxes
  bboxes = boxes.xyxy.tolist()
  confs = boxes.conf.tolist()
  classes = boxes.cls.tolist()

  outputs = []
  if bboxes:
    # Open the image once for all boxes and close it deterministically; the
    # original reopened the file for every box and never closed it.
    with Image.open(img_path) as img:
      for x1,y1,x2,y2 in bboxes:
        cropped_img = img.crop((x1,y1,x2,y2))
        outputs.append(
            text_regconition(cropped_img,data_transforms,text_reg_model,idx_2_chars,blank_char)
        )
  return zip(bboxes,classes,confs,outputs)



In [None]:
def outputs_visualization(img_path,outputs):
  """Show an image with its detected boxes and recognised text drawn on top.

  Args:
    img_path: Path of the image to display.
    outputs: Iterable of ``(bbox, class, confidence, text)`` tuples, where
      ``bbox`` is ``(x1, y1, x2, y2)`` in pixel coordinates. It is consumed
      once.
  """
  fig,ax = plt.subplots(figsize=(12,8))
  # imshow copies the pixel data, so the file can be closed right away
  # (the original left the handle open).
  with Image.open(img_path) as img:
    ax.imshow(img)
  ax.axis('off')

  for bbox,cls,conf,output in outputs:
    x1,y1,x2,y2 = bbox
    ax.add_patch(
        plt.Rectangle(
            (x1,y1), x2-x1,y2-y1, fill=False,edgecolor='blue',linewidth=2
        )
    )
    # Label sits just above the top-left corner of the box.
    ax.text(
        x1,y1-10,f'text ({round(conf,2)}) {output}',fontsize=10,bbox=dict(facecolor='red',alpha=0.5)
    )
  plt.show()
  # Release the figure explicitly so calling this in a loop does not pile up
  # open figures on backends that keep them alive after show().
  plt.close(fig)

# **Inference**

In [None]:
text_det_model_path = './weight/text_detect.pt'
text_det_model = YOLO(text_det_model_path,verbose=False)
img_dir = 'datasets/SceneTrialTrain/apanar_06.08.2002'

# Same number of images the original `if idx > 20: break` loop displayed
# (indices 0..21).
max_images = 22
image_exts = ('.jpg','.jpeg','.png','.bmp')

# os.listdir returns entries in arbitrary order: sort for a reproducible run
# and keep only image files so stray files are never fed to the detector.
img_filenames = sorted(
    name for name in os.listdir(img_dir) if name.lower().endswith(image_exts)
)

for img_filename in img_filenames[:max_images]:
  img_path = os.path.join(img_dir,img_filename)
  outputs = predict(img_path,data_transforms,text_det_model,text_reg_model,idx_2_chars,'-',device)
  outputs_visualization(img_path,outputs)