In [None]:
class ControlsModel:
    """Control-type classification head built on top of an A3C model's graph.

    The head reads ``a3c_model.net`` and predicts, per image, an index into
    ``Types.all_types``. Training minimizes over every trainable variable in
    the graph, so the A3C trunk is fine-tuned as well.
    """

    def __init__(self, a3c_model):
        """Wrap an already-built A3C model.

        Args:
            a3c_model: model whose ``session``, ``lr``, ``net``, ``img``,
                ``dropout`` and ``l2`` attributes are used by this class.
        """
        self.a3c_model = a3c_model
        self.session = a3c_model.session
        self.opt = tf.train.AdamOptimizer(a3c_model.lr)

    def add_types(self):
        """Add the control-type classification head to the graph.

        Creates:
            control_types: int32 placeholder with one label (an index into
                ``Types.all_types``) per image in the batch.
            types_logits: unnormalized scores, one per entry of
                ``Types.all_types``.
            types_loss: mean sparse softmax cross-entropy (no L2 terms).
            types_reg_loss: sum of this head's L2 weight penalties.
            types_train: Adam step minimizing
                ``types_loss + types_reg_loss``.
        """
        num_controls = len(Types.all_types)

        # (None,) declares a 1-D batch of labels; a bare `(None)` is just
        # `None`, i.e. a placeholder of completely unknown rank.
        self.control_types = tf.placeholder(tf.int32, (None,), "control_types")

        he_init = tf.contrib.layers.variance_scaling_initializer(mode="FAN_AVG")
        xavier_init = tf.contrib.layers.xavier_initializer()
        zero_init = tf.constant_initializer(0)

        l2_reg = slim.l2_regularizer(self.a3c_model.l2)

        # Remember how many regularization losses already exist so that only
        # the ones created by this head's layers are added to its objective.
        reg_key = tf.GraphKeys.REGULARIZATION_LOSSES
        num_prior_reg = len(tf.get_collection(reg_key))

        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            weights_initializer=xavier_init,
                            biases_initializer=zero_init,
                            weights_regularizer=l2_reg):
            fc2 = slim.fully_connected(self.a3c_model.net, 100, weights_initializer=he_init)
            # a3c_model.dropout is passed as slim.dropout's keep_prob
            flat = slim.dropout(fc2, self.a3c_model.dropout, scope='dropout')

            self.types_logits = slim.fully_connected(flat, num_controls, activation_fn=None)

        xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.control_types,
            logits=self.types_logits)
        self.types_loss = tf.reduce_mean(xent)

        # slim only *records* weight penalties in the regularization
        # collection; they do nothing unless added to the minimized loss.
        head_reg = tf.get_collection(reg_key)[num_prior_reg:]
        if head_reg:
            self.types_reg_loss = tf.add_n(head_reg)
        else:
            # e.g. the regularizer returned None for a zero scale
            self.types_reg_loss = tf.constant(0.)

        self.types_train = self.opt.minimize(self.types_loss + self.types_reg_loss)

    def split(self, arr, batch_size):
        """Yield consecutive lists of up to ``batch_size`` items from ``arr``.

        The final batch may be shorter. Nothing is yielded for an empty input.

        Raises:
            ValueError: if ``batch_size`` < 1 (raised on first iteration).
        """
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))
        batch = []
        for item in arr:
            batch.append(item)
            if len(batch) >= batch_size:
                yield batch
                batch = []

        if batch:
            yield batch

    def batch_as_feed_dict(self, batch, lr, dropout, l2):
        """Load a batch of controls into a feed_dict for the graph.

        Args:
            batch: control dicts, each with an ``'img_file'`` path and a
                ``'type'`` that must be a member of ``Types.all_types``.
            lr, dropout, l2: values fed to ``a3c_model.lr``,
                ``a3c_model.dropout`` and ``a3c_model.l2``.

        Returns:
            dict mapping graph tensors to values. Controls whose image is not
            224x224 are skipped (with a printed message), so the image and
            label lists may be shorter than ``batch``, or even empty.
        """
        imgs = []
        types = []
        for ctrl in batch:
            img = misc.imread(ctrl['img_file'])
            shape = img.shape
            if shape[0] != 224 or shape[1] != 224:
                print('skip, image: {}, control {}'.format(shape, ctrl))
                continue

            # centre/scale pixel values (maps [0, 255] to [-1, 1) assuming
            # 8-bit images)
            imgs.append((img - 128.) / 128.)
            types.append(Types.all_types.index(ctrl['type']))

        return {
            self.a3c_model.img: imgs,
            self.a3c_model.lr: lr,
            self.a3c_model.dropout: dropout,
            self.a3c_model.l2: l2,
            self.control_types: types,
        }

    def train_types(self, controls, batch_size=10, lr=0.01, dropout=0.7, l2=0.001, print_loss=True):
        """Run one shuffled epoch of training on the control-type head.

        Args:
            controls: sequence of control dicts (not modified).
            batch_size: controls per optimizer step.
            lr: learning rate fed to the optimizer.
            dropout: keep probability fed to the dropout layer.
            l2: L2 regularization scale.
            print_loss: print the mean cross-entropy of every batch.
        """
        # Shuffle a copy so the caller's list keeps its order.
        shuffled = list(controls)
        random.shuffle(shuffled)

        for batch in self.split(shuffled, batch_size):
            feed_dict = self.batch_as_feed_dict(batch, lr, dropout, l2)
            # Every image of the batch may have been skipped; feeding an
            # empty batch would fail inside session.run.
            if len(feed_dict[self.control_types]) == 0:
                continue

            _, loss = self.session.run([self.types_train, self.types_loss], feed_dict=feed_dict)

            if print_loss:
                print(loss)

    def measure_types_acc(self, controls):
        """Return the fraction of controls whose type is predicted correctly.

        Controls skipped by ``batch_as_feed_dict`` are excluded from both
        numerator and denominator. Returns 0. if nothing could be evaluated.
        """
        correct = 0
        total = 0
        for batch in self.split(controls, 10):
            # keep_prob 1. disables dropout; lr and l2 presumably only matter
            # for training but are fed for completeness.
            feed_dict = self.batch_as_feed_dict(batch, 0.1, 1., 0.1)
            labels = np.asarray(feed_dict[self.control_types])
            if labels.size == 0:
                continue

            logits = self.session.run(self.types_logits, feed_dict=feed_dict)
            predicted = np.argmax(logits, axis=-1)
            total += len(predicted)
            correct += int(np.sum(predicted == labels))

        if total == 0:
            return 0.

        return float(correct) / total

In [None]:
tf.reset_default_graph()
# Build a fresh graph and session: the A3C trunk plus the control-type head.
session = tf.Session()

a3c = A3CModel(len(Actions.actions), session = session)
a3c.init()

model = ControlsModel(a3c)
# Must run before the initializer below so that the head's variables and the
# Adam slot variables created by opt.minimize() are initialized too.
model.add_types()

# NOTE(review): this (re)initializes *every* global variable, including the
# A3C trunk's. If a3c.init() loads/restores weights, they are overwritten
# here. Confirm that is intended.
session.run(tf.global_variables_initializer())

In [None]:
# Load the labelled controls dataset: one JSON object per line (JSON Lines).
# Each record needs at least 'img_file' and 'type', the keys read by
# ControlsModel.
controls = []
with open('controls_dataset.jsonl') as f:
    for line in f:
        # json.loads rejects blank lines (e.g. stray empty lines at the end)
        if not line.strip():
            continue
        controls.append(json.loads(line))

In [None]:
# Train the control-type head for `num_epochs` passes over the dataset and
# report accuracy after each epoch. Accuracy is measured on the same data the
# model was trained on, so it is training accuracy.
num_epochs = 1
for _ in range(num_epochs):
    model.train_types(controls, batch_size = 10, lr = 0.01, dropout = 0.7)
    # evaluate once per epoch (previously computed twice, first result discarded)
    print('acc: ', model.measure_types_acc(controls))

In [None]:
# Count how often each control type occurs. The share of the most frequent
# type is the accuracy of always predicting that type: the baseline the
# model's accuracy should beat.
stat = {}
for ctrl in controls:
    t = ctrl['type']
    stat[t] = stat.get(t, 0) + 1

max_cnt = max(stat.values()) if stat else 0
# float() keeps a true ratio even under Python 2 integer division; an empty
# dataset reports 0. instead of raising ZeroDivisionError.
baseline = float(max_cnt) / len(controls) if controls else 0.

print('baseline: ', baseline)
print('stat: ', stat)
