In [0]:
import data.dataset
from data.dataset import ImageDataSet,collate_fn
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable
from torch.optim import lr_scheduler
import torch.utils.data as data
import torch.optim as optim
from models.FOTS import FOTS
from loss import *
from config import opt
from utils.bbox import Toolbox
import logging
import pathlib
import traceback
import os
import models
import torch
import time
import cv2

# Root logger at DEBUG level. An empty format string makes the stdlib fall back
# to its default message-only format.
logging.basicConfig(level=logging.DEBUG, format='')
# Expose only GPU 0 to CUDA. This must be set before CUDA is first initialised
# in this process to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [0]:
# for training

def train(epochs, model, trainloader, crit, optimizer,scheduler, save_step, weight_decay):
  """Run the training loop and periodically checkpoint the model weights.

  Args:
    epochs: Number of passes over ``trainloader``.
    model: Network called as ``model(img)``; it must return three values, of
      which the first two (score and geometry predictions) are passed to
      ``crit`` and the third is ignored.
    trainloader: Sized iterable yielding
      ``(img, score_map, geo_map, training_mask)`` tensor batches.
    crit: Callable invoked as
      ``crit(score_map, f_score, geo_map, f_geometry, training_mask)`` that
      returns a scalar loss tensor.
    optimizer: Optimizer that updates the parameters of ``model``.
    scheduler: Learning-rate scheduler, stepped once per batch.
    save_step: Save a checkpoint every ``save_step`` epochs to
      ``./save_model/model_<epoch>.pth``. A falsy value disables saving.
    weight_decay: Unused by this function; kept so existing callers that pass
      it keep working.
  """
  # Follow the device the model actually lives on instead of forcing CUDA, so
  # the loop also works when the model was left on the CPU. A parameter-less
  # model falls back to the previous behaviour (CUDA).
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cuda')

  num_batches = len(trainloader)

  # Iterate over the `epochs` argument rather than the global config value.
  for e in range(epochs):
    print('Epoch - {} / {}'.format(e + 1, epochs))
    model.train()
    start = time.time()
    loss = 0.0

    for img, score_map, geo_map, training_mask in trainloader:
      optimizer.zero_grad()

      img = img.to(device)
      score_map = score_map.to(device)
      geo_map = geo_map.to(device)
      training_mask = training_mask.to(device)
      f_score, f_geometry, _ = model(img)

      loss1 = crit(score_map, f_score, geo_map, f_geometry, training_mask)
      loss += loss1.item()

      loss1.backward()
      optimizer.step()
      # The scheduler must be stepped after the optimizer; the reverse order
      # skips the first value of the learning-rate schedule.
      scheduler.step()

    during = time.time() - start
    # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
    print("Loss : {:.6f}, Time:{:.2f} s ".format(loss / max(num_batches, 1), during))

    if save_step and (e + 1) % save_step == 0:
      # exist_ok avoids the race between checking for and creating the folder.
      os.makedirs('./save_model', exist_ok=True)
      torch.save(model.state_dict(), './save_model/model_{}.pth'.format(e + 1))

In [0]:
# Build the model named in the config and optionally resume from a checkpoint.
opt.parse({})
model = getattr(models, opt.model)()
# Guard against an unset path: os.path.exists(None) raises TypeError.
if opt.load_model_path and os.path.exists(opt.load_model_path):
  model.load(opt.load_model_path)

if opt.use_gpu:
  model.cuda()

# Dataset layout: <root_path>/images and <root_path>/labels.
# os.path.join inserts the separator that plain string concatenation dropped
# (which produced 'icdar_dataimages' / 'icdar_datalabels').
root_path = 'icdar_data'
train_img = os.path.join(root_path, 'images')
train_txt = os.path.join(root_path, 'labels')

trainset = ImageDataSet(train_img, train_txt)
trainloader = DataLoader(
    trainset, batch_size=opt.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=opt.num_workers)

crit = LossFunc()
weight_decay = 0
# Pass weight_decay through so the setting above actually reaches the optimizer
# (0 is also Adam's default, so the current behaviour is unchanged).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=weight_decay)
# The scheduler is stepped once per batch inside train(), so step_size counts
# batches: the learning rate is multiplied by 0.94 every 10000 batches.
scheduler = lr_scheduler.StepLR(optimizer, step_size=10000,gamma=0.94)

train(epochs=opt.epoch_num, model=model, trainloader=trainloader,
      crit=crit, optimizer=optimizer, scheduler=scheduler,
      save_step=5, weight_decay=weight_decay)

In [0]:
# for testing

# Checkpoint to evaluate; the file name follows the 'model_<epoch>.pth' pattern
# that train() writes under ./save_model.
model_path = "save_model/model_185.pth"
# Output directory handed to Toolbox.predict; a non-empty value also switches
# on the with_image flag in the inference cell.
op_dir = "test_results/"
# Directory whose entries are listed and passed one by one to Toolbox.predict.
ip_dir = "text_detection/images/"

In [0]:
def load_model(model_path, with_gpu):
  """Load FOTS weights from ``model_path`` and return a model ready for inference.

  Args:
    model_path: Path to a checkpoint holding a ``state_dict``, as written by
      ``torch.save(model.state_dict(), ...)``.
    with_gpu: If true, the model is moved to CUDA; otherwise the checkpoint is
      mapped onto the CPU while loading.

  Returns:
    A ``FOTS`` instance with the weights loaded, in evaluation mode.

  Raises:
    RuntimeError: If the loaded checkpoint is empty.
  """
  # Resolve the logger here so the function does not depend on a global
  # `logger` that is only defined in a later notebook cell.
  log = logging.getLogger()
  log.info("Loading checkpoint: {} ...".format(model_path))
  # Without map_location, a checkpoint saved from a GPU fails to load on a
  # CPU-only host.
  checkpoints = torch.load(model_path, map_location=None if with_gpu else 'cpu')
  if not checkpoints:
    raise RuntimeError('No checkpoint found.')
  FOTS_model = FOTS()
  FOTS_model.load_state_dict(checkpoints)
  if with_gpu:
    FOTS_model = FOTS_model.cuda()
  # Inference should run with layers such as dropout / batch-norm in eval mode.
  FOTS_model.eval()
  return FOTS_model

In [0]:
# Root logger, used by load_model() for its progress message.
logger = logging.getLogger()

# with_image is True exactly when an output directory is configured.
with_image = bool(op_dir)
# Use the GPU whenever CUDA reports one.
with_gpu = torch.cuda.is_available()

model = load_model(model_path, with_gpu)

# Run the predictor on every entry of the input directory. A failure on one
# entry is printed and the loop moves on to the next.
for image_fn in os.listdir(ip_dir):
  try:
    # No gradients are needed at inference time.
    with torch.no_grad():
      ploy, im = Toolbox.predict(image_fn, ip_dir, model, with_image, op_dir, with_gpu)
  except Exception:
    traceback.print_exc()