In [17]:
# !pip install easydict
# !pip install tensorboardX
# !pip install shapely

In [1]:
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F

import warnings
warnings.filterwarnings('ignore')

import torchvision
import torchvision.models as models

import torch.backends.cudnn as cudnn

import warnings
warnings.filterwarnings("ignore")

import torchvision.models.resnet as resnet
import torch.nn as nn
import torch.nn.functional as F

import torch
import torchvision
import torchvision.models as models

import torch.backends.cudnn as cudnn

import warnings
warnings.filterwarnings("ignore")

import os
import time
import argparse
from datetime import datetime

import torch.utils.data as data
from torch.optim import lr_scheduler
from util.shedule import FixLR

from dataset.total_text import TotalText
from dataset.synth_text import SynthText

from util.augmentation import BaseTransform, Augmentation
from util.config import config as cfg, update_config, print_config
from util.misc import AverageMeter
from util.misc import mkdirs, to_device
from util.option import BaseOptions
from util.visualize import visualize_network_output
from util.summary import LogSummary

from easydict import EasyDict
import matplotlib.pyplot as plt

import torch.multiprocessing as mp

warnings.filterwarnings("ignore", message="indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (function expandTensors)")

In [2]:
# Download URLs of VGG checkpoints hosted on download.pytorch.org, keyed by
# variant name. Only the 'vgg16' entry is read in this notebook (by
# VGG16.__init__ when pretrain=True).
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}

In [3]:
class VGG(nn.Module):
    """VGG network: a feature extractor followed by a three-layer linear head.

    Args:
        features: module applied to the input first. Its output is flattened
            per sample and must hold 512 * 7 * 7 values to fit the head.
        num_classes: output width of the last linear layer.
        init_weights: when true, every conv / batch-norm / linear layer is
            re-initialised by ``_initialize_weights``.
    """

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        # Layer positions (0, 3, 6 for the linear layers) define the
        # state-dict keys, so the order below must not change.
        head_layers = [
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run the feature extractor, flatten per sample, apply the head."""
        feature_maps = self.features(x)
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flattened)

    def _initialize_weights(self):
        """Initialise all conv, batch-norm and linear sub-modules in place."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)


def make_layers(cfg, batch_norm=False):
    """Build a sequential conv stack from a layout list.

    Args:
        cfg: iterable of layout entries. ``'M'`` adds a 2x2 max-pool with
            stride 2; any other value is used as the output channel count of
            a 3x3 convolution with padding 1.
        batch_norm: when true, a ``BatchNorm2d`` is inserted between each
            convolution and its ReLU.

    Returns:
        ``nn.Sequential`` of the generated layers; the first convolution
        expects 3 input channels.
    """
    modules = []
    channels_in = 3
    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels_in, entry, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(entry))
        modules.append(nn.ReLU(inplace=True))
        channels_in = entry
    return nn.Sequential(*modules)


# VGG layer layouts consumed by make_layers: an integer is the output channel
# count of a 3x3 convolution and 'M' is a 2x2 max-pool. VGG16 below builds its
# network from the 'D' layout.
# NOTE(review): this assignment rebinds the name `cfg`, which the import cell
# bound to util.config.config. The later update_config(cfg, args) calls
# therefore receive this dict (the printed config lists the A/B/D/E keys) —
# confirm the shadowing is intended before renaming anything.
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class VGG16(nn.Module):
    """VGG-16 feature extractor split into five stages.

    The layers of a ``VGG`` built from ``cfg['D']`` are regrouped into
    ``stage1`` .. ``stage5``; every stage ends with a max-pool layer.

    Args:
        pretrain: when true, the weights referenced by ``model_urls['vgg16']``
            are downloaded and loaded before the stages are extracted.
    """

    # Layer index boundaries inside VGG.features delimiting the five stages.
    _STAGE_BOUNDS = (0, 5, 10, 17, 24, 31)

    def __init__(self, pretrain=True):
        super().__init__()
        vgg = VGG(make_layers(cfg['D']), init_weights=False)
        if pretrain:
            vgg.load_state_dict(model_zoo.load_url(model_urls['vgg16']))

        # Re-wrap the layers positionally so each stage is indexed from 0
        # (slicing the Sequential directly would keep the original indices
        # and change the state-dict keys).
        layers = [vgg.features[index] for index in range(self._STAGE_BOUNDS[-1])]
        bounds = zip(self._STAGE_BOUNDS[:-1], self._STAGE_BOUNDS[1:])
        for number, (first, last) in enumerate(bounds, start=1):
            setattr(self, 'stage{}'.format(number), nn.Sequential(*layers[first:last]))

    def forward(self, x):
        """Return the outputs of the five stages as a tuple (C1, ..., C5)."""
        stage_outputs = []
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4, self.stage5):
            x = stage(x)
            stage_outputs.append(x)
        return tuple(stage_outputs)


# Smoke test: print the output shape of every VGG16 stage for a dummy batch.
# The batch is no longer called `input`, which rebound the builtin input()
# for the rest of the kernel session.
dummy_images = torch.randn((4, 3, 512, 512))
net = VGG16()
with torch.no_grad():  # shape check only; no autograd graph is needed
    C1, C2, C3, C4, C5 = net(dummy_images)
for stage_output in (C1, C2, C3, C4, C5):
    print(stage_output.size())


torch.Size([4, 64, 256, 256])
torch.Size([4, 128, 128, 128])
torch.Size([4, 256, 64, 64])
torch.Size([4, 512, 32, 32])
torch.Size([4, 512, 16, 16])


In [4]:
class Upsample(nn.Module):
    """Fuse two feature maps and double the spatial resolution.

    Args:
        in_channels: channel count of the concatenated input
            (``upsampled`` channels + ``shortcut`` channels).
        out_channels: channel count of the returned feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # kernel 4 / stride 2 / padding 1 doubles height and width.
        self.deconv = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, upsampled, shortcut):
        """Concatenate along channels, then conv1x1 -> ReLU -> conv3x3 -> ReLU -> deconv."""
        merged = torch.cat([upsampled, shortcut], dim=1)
        mixed = F.relu(self.conv1x1(merged))
        reduced = F.relu(self.conv3x3(mixed))
        return self.deconv(reduced)

In [5]:
class TextNet(nn.Module):
    """VGG16 encoder with a top-down decoder producing a per-pixel map.

    Args:
        backbone: encoder name; only ``'vgg'`` is implemented.
        output_channel: number of channels of the predicted map.
        is_training: stored on the module and also passed to ``VGG16`` as its
            ``pretrain`` flag.

    Raises:
        ValueError: if ``backbone`` is not ``'vgg'``. Previously such a model
            was built without any layers and only failed inside ``forward``.
    """

    def __init__(self, backbone='vgg', output_channel=7, is_training=True):
        super().__init__()

        if backbone != 'vgg':
            raise ValueError(
                "Unsupported backbone {!r}; only 'vgg' is implemented.".format(backbone))

        self.is_training = is_training
        self.backbone_name = backbone
        self.output_channel = output_channel

        self.backbone = VGG16(pretrain=self.is_training)

        # Bring the deepest feature map (512 channels) up to C4's resolution.
        self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)

        # Each merge takes (encoder channels + decoder channels) as input.
        self.merge4 = Upsample(512 + 256, 128)
        self.merge3 = Upsample(256 + 128, 64)
        self.merge2 = Upsample(128 + 64, 32)
        self.merge1 = Upsample(64 + 32, 16)

        self.predict = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(16, self.output_channel, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x):
        """Return the prediction map for an image batch ``x``."""
        C1, C2, C3, C4, C5 = self.backbone(x)
        up5 = self.deconv5(C5)
        up5 = F.relu(up5)

        # The encoder feature is passed first, so it comes first in the
        # channel concatenation performed by Upsample.forward.
        up4 = self.merge4(C4, up5)
        up4 = F.relu(up4)

        up3 = self.merge3(C3, up4)
        up3 = F.relu(up3)

        up2 = self.merge2(C2, up3)
        up2 = F.relu(up2)

        up1 = self.merge1(C1, up2)
        output = self.predict(up1)

        return output

In [6]:
# Smoke test: print the output shape of every VGG16 stage for a dummy batch.
# NOTE(review): this cell repeats the check that follows the VGG16 class
# definition; consider keeping only one of them.
# The batch is no longer called `input`, which rebound the builtin input().
dummy_images = torch.randn((4, 3, 512, 512))
net = VGG16()
with torch.no_grad():  # shape check only; no autograd graph is needed
    C1, C2, C3, C4, C5 = net(dummy_images)
for stage_output in (C1, C2, C3, C4, C5):
    print(stage_output.size())

torch.Size([4, 64, 256, 256])
torch.Size([4, 128, 128, 128])
torch.Size([4, 256, 64, 64])
torch.Size([4, 512, 32, 32])
torch.Size([4, 512, 16, 16])


In [7]:
# Smoke test: the TextNet output keeps the input resolution with 7 channels.
# Falls back to the CPU when CUDA is unavailable instead of failing on
# .cuda(), and avoids rebinding the builtin input().
smoke_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dummy_images = torch.randn((4, 3, 512, 512))
net = TextNet().to(smoke_device)
with torch.no_grad():  # shape check only; no autograd graph is needed
    output = net(dummy_images.to(smoke_device))
print(output.size())

torch.Size([4, 7, 512, 512])


In [8]:
class TextLoss(nn.Module):
    """Multi-task loss over the 7-channel TextNet prediction.

    Channels 0-1 are text-region logits, 2-3 text-centre-line logits,
    4 and 5 the sin / cos predictions and 6 the radius prediction.
    """

    def __init__(self):
        super().__init__()

    def ohem(self, predict, target, train_mask, negative_ratio=3.):
        """Cross-entropy with online hard example mining on the negatives.

        Args:
            predict: ``(N, 2)`` logits.
            target: ``(N,)`` 0/1 labels of any integer, bool or float dtype.
            train_mask: ``(N,)`` 0/1 mask of the pixels that take part.
            negative_ratio: maximum number of negatives kept per positive.

        Returns:
            Scalar tensor: summed positive loss plus the hardest negatives,
            divided by the number of selected pixels (0 if none is selected).
        """
        valid = train_mask.bool()
        is_text = target.bool()
        pos = is_text & valid
        neg = ~is_text & valid
        # cross_entropy needs int64 class indices; the masks usually arrive
        # as uint8 / float tensors.
        labels = target.long()

        n_pos = int(pos.sum().item())
        n_neg_available = int(neg.sum().item())

        if n_pos > 0:
            loss_pos = F.cross_entropy(predict[pos], labels[pos], reduction='sum')
            n_neg = min(n_neg_available, int(negative_ratio * n_pos))
        else:
            loss_pos = predict.new_zeros(())
            # Without positives keep at most 100 negatives, but never more
            # than exist (torch.topk raises when k exceeds the input size).
            n_neg = min(n_neg_available, 100)

        if n_neg > 0:
            neg_losses = F.cross_entropy(predict[neg], labels[neg], reduction='none')
            loss_neg = torch.topk(neg_losses, n_neg)[0].sum()
        else:
            loss_neg = predict.new_zeros(())

        # max(..., 1) avoids 0 / 0 when the train mask selects nothing.
        return (loss_pos + loss_neg) / max(n_pos + n_neg, 1)

    def forward(self, input, tr_mask, tcl_mask, sin_map, cos_map, radii_map, train_mask):
        """Compute the five loss terms.

        Args:
            input: ``(B, 7, H, W)`` network output.
            tr_mask, tcl_mask, train_mask: 0/1 masks with ``B * H * W`` elements.
            sin_map, cos_map, radii_map: regression targets with
                ``B * H * W`` elements.

        Returns:
            Tuple ``(loss_tr, loss_tcl, loss_radii, loss_sin, loss_cos)`` of
            scalar tensors. Terms without any supervised pixel are zero
            tensors, so callers can always use ``.item()``.
        """
        tr_pred = input[:, :2].permute(0, 2, 3, 1).contiguous().view(-1, 2)
        tcl_pred = input[:, 2:4].permute(0, 2, 3, 1).contiguous().view(-1, 2)
        sin_pred = input[:, 4].contiguous().view(-1)
        cos_pred = input[:, 5].contiguous().view(-1)

        # Normalise so that sin^2 + cos^2 == 1.
        scale = torch.sqrt(1.0 / (sin_pred ** 2 + cos_pred ** 2))
        sin_pred = sin_pred * scale
        cos_pred = cos_pred * scale

        radii_pred = input[:, 6].contiguous().view(-1)

        # Boolean masks make the indexing below true masking regardless of
        # the dtype the data loader produced.
        train_mask = train_mask.reshape(-1).bool()
        tr_mask = tr_mask.reshape(-1)
        tcl_mask = tcl_mask.reshape(-1)
        sin_map = sin_map.reshape(-1)
        cos_map = cos_map.reshape(-1)
        radii_map = radii_map.reshape(-1)

        loss_tr = self.ohem(tr_pred, tr_mask, train_mask)

        zero = input.new_zeros(())

        # The centre-line loss is only evaluated inside text regions.
        loss_tcl = zero
        tr_train_mask = train_mask & tr_mask.bool()
        if tr_train_mask.any():
            loss_tcl = F.cross_entropy(tcl_pred[tr_train_mask], tcl_mask[tr_train_mask].long())

        # Geometry losses are only evaluated on centre-line pixels.
        loss_radii, loss_sin, loss_cos = zero, zero, zero
        tcl_train_mask = train_mask & tcl_mask.bool()
        if tcl_train_mask.any():
            # The radius is regressed as a ratio against the target value 1.
            ratio = radii_pred[tcl_train_mask] / radii_map[tcl_train_mask]
            loss_radii = F.smooth_l1_loss(ratio, torch.ones_like(ratio))
            loss_sin = F.smooth_l1_loss(sin_pred[tcl_train_mask], sin_map[tcl_train_mask])
            loss_cos = F.smooth_l1_loss(cos_pred[tcl_train_mask], cos_map[tcl_train_mask])

        return loss_tr, loss_tcl, loss_radii, loss_sin, loss_cos

In [9]:
# Module-level training state.
# `lr` is assigned from cfg['lr'] inside the training main() and read by
# train() when it prints the epoch header.
lr=None
# `train_step` counts processed training batches across epochs; train()
# increments it and passes it to the logger as n_iter.
train_step=0

def save_model(model, epoch, lr, optimzer):
    """Write a training checkpoint to ``cfg['save_dir']/cfg['exp_name']``.

    The file is named ``textsnake_<backbone>_<epoch>.pth`` and holds the keys
    ``lr``, ``epoch``, ``model`` and ``optimizer``.

    Args:
        model: the network, optionally wrapped in a (Distributed)DataParallel.
        epoch: epoch number stored in the checkpoint and used in the file name.
        lr: learning-rate value(s) stored as-is.
        optimzer: optimizer whose ``state_dict()`` is stored.
    """
    # Unwrap parallel wrappers first: they do not expose `backbone_name`, and
    # their state-dict keys carry a "module." prefix that load_model does not
    # expect.
    parallel_types = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    net = model.module if isinstance(model, parallel_types) else model

    save_dir = os.path.join(cfg['save_dir'], cfg['exp_name'])
    if not os.path.exists(save_dir):
        mkdirs(save_dir)

    save_path = os.path.join(save_dir, 'textsnake_{}_{}.pth'.format(net.backbone_name, epoch))
    print('Saving to {}.'.format(save_path))
    state_dict = {
        'lr': lr,
        'epoch': epoch,
        'model': net.state_dict(),
        'optimizer': optimzer.state_dict()
    }
    torch.save(state_dict, save_path)

In [10]:
def load_model(model, model_path):
    """Load the ``'model'`` weights of a checkpoint into ``model`` in place.

    Args:
        model: target network, optionally wrapped in a (Distributed)DataParallel.
        model_path: checkpoint file containing a ``'model'`` state dict.
    """
    print('Loading from {}'.format(model_path))
    # Deserialise onto the CPU so checkpoints written on a GPU also load on
    # CPU-only machines; load_state_dict copies the values onto whatever
    # device the model's parameters live on.
    state_dict = torch.load(model_path, map_location='cpu')
    # Checkpoints hold un-prefixed keys, so load into the wrapped module when
    # the model is a parallel wrapper.
    parallel_types = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    target = model.module if isinstance(model, parallel_types) else model
    target.load_state_dict(state_dict['model'])

In [11]:
def plot_losses(train_losses, val_losses):
    """Plot training and validation loss curves in one figure and show it.

    Args:
        train_losses: sequence of training-loss values, one per epoch.
        val_losses: sequence of validation-loss values, one per epoch.
    """
    curves = (
        (train_losses, "Training Loss"),
        (val_losses, "Validation Loss"),
    )
    plt.figure(figsize=(10, 5))
    plt.title("Training and Validation Losses")
    for values, label in curves:
        plt.plot(values, label=label)
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

In [12]:
def train(model, train_loader, criterion, scheduler, optimizer, epoch, logger):
    """Run one training epoch and return the average total loss.

    Args:
        model: network to optimise.
        train_loader: yields ``(img, train_mask, tr_mask, tcl_mask,
            radius_map, sin_map, cos_map, meta)`` batches.
        criterion: loss module returning
            ``(tr, tcl, radii, sin, cos)`` loss terms, in that order.
        scheduler: learning-rate scheduler, stepped once at the end of the epoch.
        optimizer: optimizer updating ``model``.
        epoch: current epoch index (printing and checkpoint naming).
        logger: object exposing ``write_scalars(dict, tag=..., n_iter=...)``.

    Returns:
        Average of the per-batch total loss.
    """
    global train_step

    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    end = time.time()
    model.train()

    # Report the rate the optimizer really uses; the module-level `lr` only
    # holds the initial value.
    print('Epoch: {} : LR = {}'.format(epoch, optimizer.param_groups[0]['lr']))

    for i, (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map, meta) in enumerate(train_loader):
        data_time.update(time.time() - end)

        train_step += 1

        img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device(
            img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map)

        output = model(img)
        # Unpack in the order the criterion returns the terms; the previous
        # (sin, cos, radii) order mislabelled the three geometry losses.
        tr_loss, tcl_loss, radii_loss, sin_loss, cos_loss = \
            criterion(output, tr_mask, tcl_mask, sin_map, cos_map, radius_map, train_mask)
        loss = tr_loss + tcl_loss + sin_loss + cos_loss + radii_loss

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item())
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if cfg['viz'] and i % cfg['viz_freq'] == 0:
            visualize_network_output(output, tr_mask, tcl_mask, mode='train')

        # float() accepts both 0-dim tensors and the plain 0. the criterion
        # may return for a term without supervised pixels.
        if i % cfg['display_freq'] == 0:
            print('({:d} / {:d}) - Loss: {:.4f} - tr_loss: {:.4f} - tcl_loss: {:.4f} - sin_loss: {:.4f} - cos_loss: {:.4f} - radii_loss: {:.4f}'.format(
                i, len(train_loader), loss.item(), float(tr_loss), float(tcl_loss), float(sin_loss), float(cos_loss), float(radii_loss))
            )

        if i % cfg['log_freq'] == 0:
            logger.write_scalars({
                'loss': loss.item(),
                'tr_loss': float(tr_loss),
                'tcl_loss': float(tcl_loss),
                'sin_loss': float(sin_loss),
                'cos_loss': float(cos_loss),
                'radii_loss': float(radii_loss)
            }, tag='train', n_iter=train_step)

    if epoch % cfg['save_freq'] == 0:
        # Store the rate(s) used during this epoch, read from the optimizer.
        current_lr = [group['lr'] for group in optimizer.param_groups]
        save_model(model, epoch, current_lr, optimizer)

    # Step the schedule after the epoch's optimizer updates; stepping first
    # shifts the whole schedule by one epoch.
    scheduler.step()

    print('Training Loss: {}'.format(losses.avg))
    return losses.avg

In [13]:
def validation(model, valid_loader, criterion, epoch, logger):
    """Evaluate the model on the validation set and return the average loss.

    Args:
        model: network to evaluate (switched to eval mode here).
        valid_loader: yields ``(img, train_mask, tr_mask, tcl_mask,
            radius_map, sin_map, cos_map, meta)`` batches.
        criterion: loss module returning
            ``(tr, tcl, radii, sin, cos)`` loss terms, in that order.
        epoch: current epoch index, used as the logging step.
        logger: object exposing ``write_scalars(dict, tag=..., n_iter=...)``.

    Returns:
        Average of the per-batch total loss.
    """
    with torch.no_grad():
        model.eval()
        losses = AverageMeter()
        tr_losses = AverageMeter()
        tcl_losses = AverageMeter()
        sin_losses = AverageMeter()
        cos_losses = AverageMeter()
        radii_losses = AverageMeter()

        for i, (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map, meta) in enumerate(valid_loader):

            img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device(
                img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map)

            output = model(img)

            # Unpack in the order the criterion returns the terms; the
            # previous (sin, cos, radii) order mislabelled the geometry losses.
            tr_loss, tcl_loss, radii_loss, sin_loss, cos_loss = \
                criterion(output, tr_mask, tcl_mask, sin_map, cos_map, radius_map, train_mask)
            loss = tr_loss + tcl_loss + sin_loss + cos_loss + radii_loss

            # float() accepts both 0-dim tensors and the plain 0. the
            # criterion may return for a term without supervised pixels.
            batch_values = {
                'tr': float(tr_loss),
                'tcl': float(tcl_loss),
                'sin': float(sin_loss),
                'cos': float(cos_loss),
                'radii': float(radii_loss),
            }

            # update losses
            losses.update(loss.item())
            tr_losses.update(batch_values['tr'])
            tcl_losses.update(batch_values['tcl'])
            sin_losses.update(batch_values['sin'])
            cos_losses.update(batch_values['cos'])
            radii_losses.update(batch_values['radii'])

            if cfg['viz'] and i % cfg['viz_freq'] == 0:
                visualize_network_output(output, tr_mask, tcl_mask, mode='val')

            if i % cfg['display_freq'] == 0:
                print(
                    'Validation: - Loss: {:.4f} - tr_loss: {:.4f} - tcl_loss: {:.4f} - sin_loss: {:.4f} - cos_loss: {:.4f} - radii_loss: {:.4f}'.format(
                        loss.item(), batch_values['tr'], batch_values['tcl'], batch_values['sin'],
                        batch_values['cos'], batch_values['radii'])
                )

        logger.write_scalars({
            'loss': losses.avg,
            'tr_loss': tr_losses.avg,
            'tcl_loss': tcl_losses.avg,
            'sin_loss': sin_losses.avg,
            'cos_loss': cos_losses.avg,
            'radii_loss': radii_losses.avg
        }, tag='val', n_iter=epoch)

        print('Validation Loss: {}'.format(losses.avg))
        return losses.avg

In [26]:
def main():
    """Train TextNet on the dataset selected by ``cfg['dataset']``.

    Builds the data loaders, model, loss, optimizer and scheduler, runs the
    epochs from ``cfg['start_epoch']`` to ``cfg['max_epoch']`` (validating
    after each epoch when a validation set exists) and finally plots the
    loss curves.

    Raises:
        ValueError: if ``cfg['dataset']`` names an unsupported dataset.
    """
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)

    train_losses = []
    val_losses = []

    global lr
    print(cfg['dataset'])

    if cfg['dataset'] == 'total-text':

        trainset = TotalText(
            data_root='data/total-text',
            ignore_list=None,
            is_training=True,
            transform=Augmentation(size=cfg['input_size'], mean=cfg['means'], std=cfg['stds'])
        )

        valset = TotalText(
            data_root='data/total-text',
            ignore_list=None,
            is_training=False,
            transform=BaseTransform(size=cfg['input_size'], mean=cfg['means'], std=cfg['stds'])
        )

    elif cfg['dataset'] == 'synth-text':
        trainset = SynthText(
            data_root='data/SynthText',
            is_training=True,
            transform=Augmentation(size=cfg['input_size'], mean=cfg['means'], std=cfg['stds'])
        )
        valset = None
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Unsupported dataset: {!r}'.format(cfg['dataset']))

    train_loader = data.DataLoader(trainset, batch_size=cfg['batch_size'], shuffle=True, num_workers=cfg['num_workers'])
    # Validate only when a non-empty validation set exists.
    has_validation = valset is not None and len(valset) > 0
    val_loader = None
    if has_validation:
        val_loader = data.DataLoader(valset, batch_size=cfg['batch_size'], shuffle=False, num_workers=cfg['num_workers'])

    log_dir = os.path.join(cfg['log_dir'], datetime.now().strftime('%b%d_%H-%M-%S_') + cfg['exp_name'])
    logger = LogSummary(log_dir)

    # Model
    model = TextNet(is_training=True, backbone='vgg', output_channel=7)

    # Resume before wrapping: checkpoints hold un-prefixed keys, which do not
    # match the "module."-prefixed keys of a DataParallel wrapper.
    if cfg['resume']:
        load_model(model, cfg['resume'])

    if cfg['mgpu']:
        model = nn.DataParallel(model)

    model = model.to(cfg['device'])

    if cfg['cuda']:
        cudnn.benchmark = True

    criterion = TextLoss()
    lr = cfg['lr']
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['lr'])

    if cfg['dataset'] == 'synth-text':
        scheduler = FixLR(optimizer)
    else:
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    print('Start training TextSnake.')

    for epoch in range(cfg['start_epoch'], cfg['max_epoch']):
        train_loss = train(model, train_loader, criterion, scheduler, optimizer, epoch, logger)
        train_losses.append(train_loss)

        if has_validation:
            val_loss = validation(model, val_loader, criterion, epoch, logger)
            val_losses.append(val_loss)

    print('End.')

    plot_losses(train_losses, val_losses)

In [13]:
# Launch training: parse a hard-coded argument list, merge the result into
# cfg, print the merged configuration and call main().
# NOTE(review): this cell's execution count (13) repeats the validation
# cell's and follows In [26]; re-check with Restart Kernel -> Run All.
option = BaseOptions()
# The first item is echoed as `exp_name` in the printed configuration below.
command_line_args = ["vgg123456_78", "--viz", "--net", "vgg", "--cuda", "True", "--dataset", "total-text", "--vis_dir", "./vis/"]

args = option.initialize(command_line_args)

# cfg is the dict that the VGG layout assignment bound to this name, so the
# parsed options end up next to the A/B/D/E layout keys.
update_config(cfg, args)
print_config(cfg)

# main
main()

total-text
A: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
B: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
D: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
E: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
exp_name: vgg123456_78
net: vgg
dataset: total-text
resume: None
num_workers: 8
cuda: True
mgpu: False
save_dir: ./save/
vis_dir: ./vis/
log_dir: ./logs/
loss: CrossEntropyLoss
input_channel: 1
pretrain: False
verbose: True
viz: True
start_iter: 0
max_epoch: 201
start_epoch: 0
lr: 0.0001
lr_adjust: fix
stepvalues: []
weight_decay: 0.0
gamma: 0.1
momentum: 0.9
batch_size: 8
optim: SGD
display_freq: 50
viz_freq: 50
save_freq: 50
log_freq: 100
val_freq: 100
rescale: 255.0
means: (0.485, 0.456, 0.406)
stds: (0.229, 0.224, 0.225)
input_size: 512
checkepoch: -1
img_root: None
device: cuda
total-text
Start training TextSnake.
Epoch: 0 : LR = 0.

terminate called without an active exception
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f81de0fc310>
Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 1430, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.9/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.9/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.9/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.9/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/signal_han

RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED

In [14]:
import os
import time
import cv2
import numpy as np
import torch
import subprocess
import torch.backends.cudnn as cudnn
import torch.utils.data as data
import multiprocessing
# from network.textnet import TextNet
from util.detection import TextDetector
# from util.augmentation import BaseTransform
# from util.option import BaseOptions
from util.visualize import visualize_detection
from util.misc import rescale_result

In [15]:
def write_to_file(contours, file_path):
    """Write detected contours to a text file, one contour per line.

    According to the Total-Text evaluation method the output file should be
    formatted as ``y0,x0,...,yn,xn``, so the first two columns of every
    contour are swapped before writing.

    Args:
        contours: iterable of 2-D arrays whose columns 0 and 1 hold x and y.
        file_path: destination path; an existing file is overwritten.
    """
    with open(file_path, 'w') as out_file:
        for contour in contours:
            y_then_x = np.stack([contour[:, 1], contour[:, 0]], 1)
            fields = y_then_x.flatten().astype(str).tolist()
            out_file.write(','.join(fields) + '\n')


In [16]:
def inference(detector, test_loader, output_dir):
    """Run detection over the test set, saving visualisations and results.

    For every batch (batch size 1 is assumed) the detector is timed, a
    prediction / ground-truth visualisation is written below
    ``cfg['vis_dir']`` and the detected contours are written to
    ``output_dir`` as ``<image name>.txt``.

    Args:
        detector: object exposing ``detect(image) -> (contours, output)``
            where ``output`` has the keys ``'tr'`` and ``'tcl'``.
        test_loader: yields ``(image, train_mask, tr_mask, tcl_mask,
            radius_map, sin_map, cos_map, meta)`` batches.
        output_dir: directory receiving the per-image result files.
    """
    total_time = 0.
    idx = 0  # test mode can only run with batch_size == 1
    # Synchronising is only meaningful (and only possible) with CUDA.
    sync_cuda = torch.cuda.is_available()

    # Create the output locations once; cv2.imwrite fails silently when the
    # target directory is missing.
    vis_dir = os.path.join(cfg['vis_dir'], '{}_test'.format(cfg['exp_name']))
    os.makedirs(vis_dir, exist_ok=True)
    mkdirs(output_dir)

    for i, (image, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map, meta) in enumerate(test_loader):

        image, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device(
            image, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map)

        if sync_cuda:
            torch.cuda.synchronize()
        start = time.time()

        # get detection result; gradients are never needed at test time
        with torch.no_grad():
            contours, output = detector.detect(image)

        if sync_cuda:
            torch.cuda.synchronize()
        end = time.time()
        total_time += end - start
        fps = (i + 1) / total_time if total_time > 0 else float('inf')
        image_id = meta['image_id'][idx]
        print('detect {} / {} images: {}. ({:.2f} fps)'.format(i + 1, len(test_loader), image_id, fps))

        # visualization: undo the input normalisation to get a uint8 image
        tr_pred, tcl_pred = output['tr'], output['tcl']
        img_show = image[idx].permute(1, 2, 0).cpu().numpy()
        img_show = ((img_show * cfg['stds'] + cfg['means']) * 255).astype(np.uint8)

        pred_vis = visualize_detection(img_show, contours, tr_pred[1], tcl_pred[1])
        gt_contour = []
        for annot, n_annot in zip(meta['annotation'][idx], meta['n_annotation'][idx]):
            if n_annot.item() > 0:
                gt_contour.append(annot[:n_annot].int().cpu().numpy())
        gt_vis = visualize_detection(img_show, gt_contour, tr_mask[idx].cpu().numpy(), tcl_mask[idx].cpu().numpy())
        im_vis = np.concatenate([pred_vis, gt_vis], axis=0)
        cv2.imwrite(os.path.join(vis_dir, image_id), im_vis)

        H, W = meta['Height'][idx].item(), meta['Width'][idx].item()
        img_show, contours = rescale_result(img_show, contours, H, W)

        # write to file; swap only the extension (str.replace('jpg', 'txt')
        # also rewrote "jpg" inside the name and missed other extensions)
        result_name = os.path.splitext(image_id)[0] + '.txt'
        write_to_file(contours, os.path.join(output_dir, result_name))


In [17]:
def main():
    """Evaluate a saved TextSnake checkpoint on the Total-Text test split.

    Loads ``textsnake_<backbone>_<cfg['checkepoch']>.pth`` from
    ``cfg['save_dir']/cfg['exp_name']``, runs inference and then calls the
    DetEval script with two threshold settings.

    NOTE(review): this definition replaces the training ``main`` defined in
    an earlier cell of the notebook.
    """
    import sys  # local import: only needed to locate the running interpreter

    if multiprocessing.get_start_method(allow_none=True) is None:
        multiprocessing.set_start_method('spawn')
    testset = TotalText(
        data_root='data/total-text',
        ignore_list=None,
        is_training=False,
        transform=BaseTransform(size=cfg['input_size'], mean=cfg['means'], std=cfg['stds'])
    )
    test_loader = data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=cfg['num_workers'])

    # Model
    model = TextNet(is_training=False, backbone=cfg['net'])
    model_path = os.path.join(cfg['save_dir'], f"{cfg['exp_name']}/textsnake_{model.backbone_name}_{cfg['checkepoch']}.pth")
    load_model(model, model_path)

    # copy to cuda
    model = model.to(cfg['device'])
    model.eval()  # inference only
    if cfg['cuda']:
        cudnn.benchmark = True
    detector = TextDetector(model, tr_thresh=cfg['tr_thresh'], tcl_thresh=cfg['tcl_thresh'])

    print('Start testing TextSnake.')
    output_dir = os.path.join(cfg['output_dir'], cfg['exp_name'])
    inference(detector, test_loader, output_dir)

    # compute DetEval
    print('Computing DetEval in {}/{}'.format(cfg['output_dir'], cfg['exp_name']))
    deteval_script = 'dataset/total_text/Evaluation_Protocol/Python_scripts/Deteval.py'
    for tr_value in ('0.6', '0.8'):
        # Use the kernel's interpreter and the experiment name held in cfg
        # rather than whatever `python` is on PATH and the global `args`.
        return_code = subprocess.call(
            [sys.executable, deteval_script, cfg['exp_name'], '--tr', tr_value, '--tp', '0.4'])
        if return_code != 0:
            print('DetEval (tr={}) exited with code {}.'.format(tr_value, return_code))
    print('End.')


In [18]:
# Launch evaluation: parse a hard-coded argument list, merge it into cfg,
# add the evaluation-only settings and call main().
option = BaseOptions()

# The first item is echoed as `exp_name` and "--checkepoch" as `checkepoch`
# in the printed configuration below; main() uses both to build the
# checkpoint path.
command_line_args = ["vgg123456", "--net", "vgg", "--checkepoch", "200"]  
args = option.initialize(command_line_args)

update_config(cfg, args)
print_config(cfg)

# Directory that inference() writes the visualisation images into.
vis_dir = os.path.join(cfg['vis_dir'], '{}_test'.format(cfg['exp_name']))
if not os.path.exists(vis_dir):
    mkdirs(vis_dir)

# VGG16.__init__ reads cfg['D'] to build the backbone.
cfg['D'] = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
# Thresholds handed to TextDetector by main().
cfg['tr_thresh'] = 0.6
cfg['tcl_thresh'] = 0.4
# config['post_process_expand'] = 0.3
# Root directory for the result files; main() appends the experiment name.
cfg['output_dir'] = "./output/"
# Resolves to the evaluation main() defined in the cell above, which
# replaced the training main().
main()

total-text
A: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
B: [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
D: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
E: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
exp_name: vgg123456
net: vgg
dataset: total-text
resume: None
num_workers: 8
cuda: True
mgpu: False
save_dir: ./save/
vis_dir: ./vis/
log_dir: ./logs/
loss: CrossEntropyLoss
input_channel: 1
pretrain: False
verbose: True
viz: False
start_iter: 0
max_epoch: 201
start_epoch: 0
lr: 0.0001
lr_adjust: fix
stepvalues: []
weight_decay: 0.0
gamma: 0.1
momentum: 0.9
batch_size: 4
optim: SGD
display_freq: 50
viz_freq: 50
save_freq: 50
log_freq: 100
val_freq: 100
rescale: 255.0
means: (0.485, 0.456, 0.406)
stds: (0.229, 0.224, 0.225)
input_size: 512
data_root: data/total-text
data_custom: False
checkepoch: 200
img_root: None
device: cuda
Loading fr

  return lib.intersection(a, b, **kwargs)
  if (gt[5] == '#') and (gt[1].shape[1] > 1):
  dc_id = np.where(groundtruths[:, 5] == '#')
100%|██████████| 301/301 [00:42<00:00,  7.12it/s]


Skipped directory .ipynb_checkpoints
Input: output/vgg123456
Config: tr: 0.6 - tp: 0.4
Precision = 0.8367 - Recall = 0.6616 - Fscore = 0.7389

Done.


  return lib.intersection(a, b, **kwargs)
  if (gt[5] == '#') and (gt[1].shape[1] > 1):
  dc_id = np.where(groundtruths[:, 5] == '#')
 99%|█████████▊| 297/301 [00:41<00:00, 11.41it/s]

Skipped directory .ipynb_checkpoints
Input: output/vgg123456
Config: tr: 0.8 - tp: 0.4
Precision = 0.7802 - Recall = 0.6244 - Fscore = 0.6937

Done.
End.


100%|██████████| 301/301 [00:41<00:00,  7.24it/s]
