In [1]:
import tensorflow as tf
import common

In [2]:
class OutputLayer(tf.keras.layers.Layer):
    """Dense regression head producing 4 raw (pre-activation) values per sample.

    The incoming feature map must have spatial size 1x1: axes 1 and 2 are
    squeezed away (``tf.squeeze`` raises if either is not 1) before the
    256 -> 32 -> 4 dense stack. The first two Dense layers use ``common.mish``;
    the last one has no activation.
    """

    def __init__(self):
        super().__init__()
        # Attribute names act as checkpoint keys for saved weights - keep them stable.
        self.dense1 = tf.keras.layers.Dense(256, activation=common.mish)
        self.dense2 = tf.keras.layers.Dense(32, activation=common.mish)
        self.dense3 = tf.keras.layers.Dense(4)

    def call(self, fm, training):
        # `training` is accepted for a uniform call signature but unused here.
        out = tf.squeeze(fm, axis=[1, 2])
        for dense in (self.dense1, self.dense2, self.dense3):
            out = dense(out)
        return out

In [3]:
class DwSubBlock(tf.keras.layers.Layer):
    """Depthwise vision field -> depth change -> depthwise vision field -> depth change.

    The first ``common.DepthControl`` is built with ``r`` and the second with
    ``-r``; presumably the second reverses the first's channel change so the
    block keeps its input depth (TODO confirm against ``common.DepthControl``).
    Both DepthControl layers use the 'mish' activation.
    """

    def __init__(self, r=-1):
        super().__init__()
        # Construction order and attribute names are unchanged (checkpoint keys).
        self.depthwise_vf_outer = common.DepthwiseVisionField()
        self.decrease = common.DepthControl(r=r, a='mish')
        self.depthwise_vf_inner = common.DepthwiseVisionField()
        self.increase = common.DepthControl(r=-r, a='mish')

    def call(self, inp, training):
        stages = (
            self.depthwise_vf_outer,
            self.decrease,
            self.depthwise_vf_inner,
            self.increase,
        )
        out = inp
        for stage in stages:
            out = stage(out, training=training)
        return out
    
class SequentialDwSubBlock(tf.keras.layers.Layer):
    """Applies ``repeat`` DwSubBlock layers in order, each built with the same ``r``.

    With ``repeat == 0`` the input is returned unchanged.
    """

    def __init__(self, repeat, r=-1):
        super().__init__()
        chain = []
        for _ in range(repeat):
            chain.append(DwSubBlock(r=r))
        self.subblocks = chain

    def call(self, inp, training):
        out = inp
        for block in self.subblocks:
            out = block(out, training=training)
        return out

In [4]:
class Model(tf.keras.Model):
    """Three-branch network mapping an image to 4 raw outputs per sample.

    Branch 0 uses the image as-is; branches 1 and 2 first pass it through
    ``common.DownSample(1)`` / ``common.DownSample(2)`` (presumably
    progressively lower resolutions - TODO confirm). Each branch is tiled along
    the channel axis (x16, x32, x64) and refined by a SequentialDwSubBlock.
    The branches are then fused: the running map goes through a
    ResBlockContract, is concatenated with the next branch on the channel axis
    (so their other dimensions must match), and is mixed by a DepthControl.
    The result passes through the ``Net`` stack and the dense ``OutputLayer``.

    NOTE(review): attribute names (including the misspelled ``dowmsamples``)
    and the order in which sub-layers are constructed are deliberately
    unchanged. They determine the keys of saved checkpoints, such as the one
    restored with ``load_weights`` in this notebook; renaming them would stop
    existing weights from being restored.
    """

    def __init__(self, levels=5, neck=2):
        # NOTE(review): `levels` and `neck` are accepted but never used.
        super().__init__()
        self.dowmsamples = [common.DownSample(k) for k in (1, 2)]
        self.DwSubBlocks = [SequentialDwSubBlock(n) for n in (2, 4, 4)]
        self.ResBlockContracts = [common.ResBlockContract(n) for n in (2, 4)]
        self.mixes = [common.DepthControl(r=-1, a='mish') for _ in range(2)]
        # (repeat, r) for each ResBlockContract in the main stack, in order.
        net_spec = ((4, -1), (4, -1), (6, -2), (6, -2), (2, -2))
        self.Net = [common.ResBlockContract(repeat=rep, r=ratio) for rep, ratio in net_spec]
        self.OutputLayer = OutputLayer()

    def call(self, image, training):
        # Build the three branches: optional DownSample, channel tiling, DwSubBlock stack.
        branches = []
        for level, reps in enumerate((16, 32, 64)):
            if level == 0:
                branch = image
            else:
                branch = self.dowmsamples[level - 1](image, training=training)
            branch = tf.tile(branch, [1, 1, 1, reps])
            branches.append(self.DwSubBlocks[level](branch, training=training))

        # Fuse: contract the running map, concatenate the next branch on the
        # channel axis, then apply the corresponding DepthControl mix.
        fm = branches[0]
        for contract, mix, skip in zip(self.ResBlockContracts, self.mixes, branches[1:]):
            fm = contract(fm, training=training)
            fm = mix(tf.concat([fm, skip], axis=-1), training=training)

        for block in self.Net:
            fm = block(fm, training=training)
        return self.OutputLayer(fm, training=training)

    def predict(self, image):
        """Resize to 128x128, run the model and return sigmoid outputs as a numpy array.

        Accepts a single image (rank 3, which gets a leading batch axis) or a
        batch (rank 4). Returns an array of shape (batch, 4) with values in (0, 1),
        named xywh here.

        NOTE(review): this shadows ``tf.keras.Model.predict`` with a different
        signature, and ``tf.rank(...) == 3`` used as a Python condition only
        works in eager mode.
        """
        resized = tf.image.resize(image, [128, 128])
        if tf.rank(resized) == 3:
            resized = resized[None, ...]
        raw_xywh = self(resized)
        return tf.keras.activations.sigmoid(raw_xywh).numpy()

In [5]:
# Build the network with its default settings (NOTE(review): Model.__init__ ignores levels/neck).
model = Model(levels=5, neck=2)

In [6]:
# Restore the final trained weights before fine-tuning. The returned CheckpointLoadStatus is
# only displayed, never checked. The model has not been called yet, so in TF checkpoint format
# its variables are restored lazily when first created; presumably a key mismatch only shows
# up as a warning rather than an error here.
model.load_weights(filepath=r'D:\Competitions\ComputerVision\OCR\ChinaSteel\code\Segment\final_weights\model_1\weights')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x18212c1d580>

In [7]:
import os

import FLAGS

# Root of the local ChinaSteel checkout. The default reproduces the original
# hard-coded Windows location exactly; set the CHINASTEEL_ROOT environment
# variable to run elsewhere without editing every path in this cell.
PROJECT_ROOT = os.environ.get('CHINASTEEL_ROOT', r'D:\Competitions\ComputerVision\OCR\ChinaSteel')
DATASET_ROOT = os.path.join(PROJECT_ROOT, 'dataset')

# Data: official training table/images with a small validation hold-out.
FLAGS.DATA.TRAIN.TABLE_PATH = os.path.join(DATASET_ROOT, 'train', 'official', 'train_table.csv')
FLAGS.DATA.TRAIN.DROP_BAD_BBOX_DATA = True
FLAGS.DATA.TRAIN.VALIDATION_SPLIT_RATIO = 0.05
FLAGS.DATA.TRAIN.VALIDATION_SPLIT_RANDOM_STATE = 100  # presumably the seed of the train/val split
FLAGS.DATA.TRAIN.IMAGE_PATH = os.path.join(DATASET_ROOT, 'train', 'official', 'train')
FLAGS.DATA.TRAIN.TRAIN_BATCH_SIZE = 4
FLAGS.DATA.TRAIN.VAL_BATCH_SIZE = 8
FLAGS.DATA.TEST.IMAGE_PATH = os.path.join(DATASET_ROOT, 'test', 'official', 'sample')
FLAGS.DATA.TEST.BATCH_SIZE = 8

# Model: the Model instance whose weights were loaded in the previous cell.
FLAGS.MODEL = model

# Optimizer: a very small peak learning rate, since this run fine-tunes trained weights.
FLAGS.OPTIMIZER.TYPE = tf.keras.optimizers.Adam  # the class itself; presumably instantiated by train.py
FLAGS.OPTIMIZER.MAX_LEARNING_RATE = 1e-8
FLAGS.OPTIMIZER.SCHEDULE_GAMMA = -1.5  # NOTE(review): schedule semantics live in train.py - confirm

FLAGS.EPOCHS.TOTAL = 100
FLAGS.EPOCHS.WARMUP = 2

# Logging / experiment bookkeeping.
FLAGS.LOGGING.PATH = 'logs'
FLAGS.LOGGING.MODEL_NAME = 1
FLAGS.LOGGING.TRIAL_NUMBER = 2
FLAGS.LOGGING.NOTE = 'final weights finetune'
FLAGS.LOGGING.SAMPLES_PER_LOG = 128
FLAGS.LOGGING.TEST_IMAGE_COLUMNS = 5

In [None]:
# Launch training. Without -i, IPython's %run executes train.py in a fresh namespace rather
# than this notebook's globals. The script therefore presumably reads its configuration by
# importing the FLAGS module configured above, which works because module objects are shared
# through sys.modules. TODO confirm in train.py.
%run train.py