In [1]:
# Load the line_profiler IPython extension (provides the %lprun line-by-line profiling magic).
%load_ext line_profiler

In [2]:
from collections import Counter
from functools import partial

import numpy as np
# Shared NumPy Generator used by make_x / make_data. It is created without a seed,
# so every kernel start produces different random draws.
rng = np.random.default_rng()
# Print arrays without scientific notation and with wide lines (150 characters).
np.set_printoptions(suppress=True, linewidth=150)
from tqdm import tqdm, trange
from tqdm.contrib.concurrent import process_map

In [3]:
def softmax(x, b, axis=-1):
    """Softmax of ``b * x`` along ``axis``.

    Parameters
    ----------
    x : array_like
        Input values (any shape).
    b : float
        Coefficient multiplying ``x`` before exponentiation; larger ``b``
        gives a more peaked distribution.
    axis : int, optional
        Axis along which the result sums to 1 (default: last axis).

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``x`` whose slices along ``axis`` sum to 1.
    """
    scaled = b * np.asarray(x, dtype=float)
    # Shift by the maximum of each slice (not the global maximum): the largest
    # exponent in every slice is then exactly 0, so no slice can underflow to
    # an all-zero row (0/0 -> NaN) and nothing can overflow, whatever the sign of b.
    scaled = scaled - np.max(scaled, axis=axis, keepdims=True)
    weights = np.exp(scaled)
    return weights / weights.sum(axis=axis, keepdims=True)

def make_model(map_size=5, num_nodes=128, num_dense=4):
    """Build and compile a fully connected classifier over grid states.

    Parameters
    ----------
    map_size : int
        Side length of the square grid. The input vector has
        ``map_size*map_size + map_size + 1`` entries (31 for the default).
    num_nodes : int
        Width of each hidden ReLU layer.
    num_dense : int
        Number of hidden layers.

    Returns
    -------
    tf.keras.Model
        Model with a 5-way softmax output, compiled with Adam (lr=0.01),
        sparse categorical cross-entropy on probabilities, and accuracy.
    """
    # TensorFlow is imported inside the function rather than at notebook level.
    import tensorflow as tf
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

    # `shape` must be a tuple: `(n)` is just the int n, `(n,)` is a 1-tuple.
    inputs = tf.keras.layers.Input(shape=(map_size*map_size + map_size + 1,))
    x = tf.keras.layers.Flatten()(inputs)

    for _ in range(num_dense):
        x = tf.keras.layers.Dense(num_nodes, activation='relu')(x)

    # Single head: logits followed by an explicit Softmax layer, so the model
    # outputs probabilities (hence from_logits=False in the loss below).
    output1 = tf.keras.layers.Dense(5, name='Y0')(x)
    output1 = tf.keras.layers.Softmax()(output1)
    model = tf.keras.models.Model(inputs=inputs, outputs=output1)

    opt = tf.keras.optimizers.Adam(learning_rate=0.01)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    model.compile(optimizer=opt, loss=loss_fn, metrics=['accuracy'])
    return model

def softmax(x, b, axis=-1):
    """Softmax of ``b * x`` along ``axis``.

    NOTE: this re-definition shadows the ``softmax`` defined earlier in the
    same cell; it is the one actually used by the rest of the notebook.

    Parameters
    ----------
    x : array_like
        Input values (any shape).
    b : float
        Coefficient multiplying ``x`` before exponentiation; larger ``b``
        gives a more peaked distribution.
    axis : int, optional
        Axis along which the result sums to 1 (default: last axis).

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``x`` whose slices along ``axis`` sum to 1.
    """
    scaled = b * np.asarray(x, dtype=float)
    # Shift by the maximum of each slice (not the global maximum): the largest
    # exponent in every slice is then exactly 0, so no slice can underflow to
    # an all-zero row (0/0 -> NaN) and nothing can overflow, whatever the sign of b.
    scaled = scaled - np.max(scaled, axis=axis, keepdims=True)
    weights = np.exp(scaled)
    return weights / weights.sum(axis=axis, keepdims=True)

def make_x():
    """Return a randomly shuffled copy of the global reward grid ``arrs``.

    The entries of every row (last axis) are permuted independently using the
    notebook-level generator ``rng``. ``arrs`` itself is not modified because
    ``rng.permuted`` returns a new array when no ``out`` argument is given.
    """
    # The original also drew a uniform noise array on every call but never
    # added it to the grid; that dead computation has been removed.
    return rng.permuted(arrs, axis=-1)

def eval_model(model, n=1, disable=True):
    """Query ``model`` on ``n`` random grids and collect its outputs.

    For every grid from ``make_x`` a batch of 6 rows is built: a 6-way one-hot
    position marker followed by the flattened grid. Row 0 carries marker 0
    and rows 1..5 carry markers 1..5.

    Parameters
    ----------
    model : callable
        Called as ``model(batch)``; the result must support indexing and
        ``.numpy()``.
    n : int
        Number of random grids to evaluate.
    disable : bool
        Passed to ``trange`` to hide the progress bar.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        First array: the model output for row 0 of every batch.
        Second array: for every batch, the grid multiplied element-wise with
        the model outputs for rows 1..5 and summed over the last axis.
    """
    root_outputs, branch_values = [], []
    position_markers = np.eye(6).astype(int)

    for _ in trange(n, disable=disable):
        grid = make_x()
        repeated_grid = np.tile(grid.flatten(), (6, 1))
        batch = np.concatenate([position_markers, repeated_grid], axis=1)

        predictions = model(batch)

        # Recover the 5x5 grid from the batch itself (columns 6 onwards of row 0).
        grid_in_batch = batch[0, 6:].reshape(5, 5)
        child_outputs = predictions[1:].numpy()
        weighted = grid_in_batch * child_outputs

        root_outputs.append(predictions[0].numpy())
        branch_values.append(weighted.sum(axis=1))

    return np.array(root_outputs), np.array(branch_values)

def make_data(b, n=1, disable=True):
    """Sample ``n`` two-step trajectories from a softmax(b) agent.

    For each random grid from ``make_x`` two (state, action) pairs are
    produced:

    * root state ``[1,0,0,0,0,0] + grid`` with action ``y0`` drawn from
      ``softmax`` over the per-row expected rewards;
    * second state whose position marker has a 1 at index ``y0 + 1``, with
      action ``y1`` drawn from the softmax of row ``y0`` of the grid.

    Parameters
    ----------
    b : float
        Softmax coefficient of the sampling agent.
    n : int
        Number of grids (the result has ``2*n`` rows).
    disable : bool
        Passed to ``trange`` to hide the progress bar.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``X`` with one row per state and ``Y`` with the sampled actions.
    """
    X, Y = [], []

    for _ in trange(n, disable=disable):
        x = make_x()
        flat = list(x.flatten())

        # Within-row action distribution; computed once and reused for both steps.
        row_policy = softmax(x, b)
        # Expected reward of each row when acting with row_policy inside it.
        row_value = (row_policy * x).sum(axis=1)

        # Step 1: choose a row from the root state.
        y0 = rng.choice(np.arange(5), p=softmax(row_value, b))
        X.append([1, 0, 0, 0, 0, 0] + flat)
        Y.append(y0)

        # Step 2: choose a column inside the selected row.
        pos = [0, 0, 0, 0, 0, 0]
        pos[y0 + 1] = 1
        y1 = rng.choice(np.arange(5), p=row_policy[y0])
        X.append(pos + flat)
        Y.append(y1)

    X = np.array(X)
    Y = np.array(Y)

    return X, Y

In [4]:
# Curriculum schedule used by eval_training / exp: every non-zero entry i adds
# 2**i further training samples, and the model is evaluated after every entry.
# The leading 0 is an evaluation with no extra training; 5 appears twice so the
# cumulative sizes double (32, 64, 128, ...).
# NOTE(review): these exponents sum to 2**18 = 262144 samples, while a stored
# output further down shows ntotal == 1048576 -- confirm which schedule is intended.
ns = [0,5,5,6,7,8,9,10,11,12,13,14,15,16,17]

# Grid of softmax coefficients, log-spaced from 0.01 to 10 (used as both students and teachers).
# NOTE(review): later cells call exp(0.020), which loads models/starting_0.020.keras,
# but this 16-point grid contains no value that formats as 0.020, and a stored
# len(students) output shows 31 -- confirm the intended grid size.
students = np.logspace(-2, 1, 16)
display(students)

# Base 5x5 reward grid: row i holds 12+i on and above the diagonal and -20 below it.
arrs = np.triu(np.tile(np.arange(32, 37),(5,1)).T) - 20
display(arrs)
display(arrs.sum())

array([ 0.01      ,  0.01584893,  0.02511886,  0.03981072,  0.06309573,  0.1       ,  0.15848932,  0.25118864,  0.39810717,  0.63095734,  1.        ,
        1.58489319,  2.51188643,  3.98107171,  6.30957344, 10.        ])

array([[ 12,  12,  12,  12,  12],
       [-20,  13,  13,  13,  13],
       [-20, -20,  14,  14,  14],
       [-20, -20, -20,  15,  15],
       [-20, -20, -20, -20,  16]])

0

In [8]:
def train_model(student):
    """Train a fresh model on data from one student and save it.

    Parameters
    ----------
    student : float
        Softmax coefficient used by ``make_data`` to generate 2**20 training
        trajectories. Progress output is shown only for the middle entry of
        the global ``students`` grid.

    Returns
    -------
    tuple
        The result of ``eval_model(model, 10000)`` for the trained model.

    Side effects
    ------------
    Saves the model to ``models/starting_{student:.3f}.keras``.
    """
    import os

    verbose = False
    if student == students[len(students)//2]:
        print(f'student == {student:.3f}')
        verbose = True

    # Create the output directory up front so a missing folder fails (or is
    # fixed) before the expensive training run rather than after it.
    # exist_ok makes this safe when several worker processes race.
    os.makedirs('models', exist_ok=True)

    model = make_model()
    Xtrain, Ytrain = make_data(student, 2**20, disable=not verbose)
    model.fit(Xtrain, Ytrain, verbose=verbose)
    reward = eval_model(model, 10000, disable=not verbose)
    model.save(f'models/starting_{student:.3f}.keras')
    return reward

In [None]:
%%time
# Train one starting model per entry of `students`, in up to 8 worker processes.
# Each element of `rewards` is the tuple returned by eval_model for that student.
rewards = process_map(train_model, students, disable=True, max_workers=8)

In [12]:
# For each trained student: weight the per-branch values (r[1]) by the model's
# row-0 output (r[0]), sum over branches, and average over the evaluation grids.
for r, s in zip(rewards, students):
    
    res = r[0]*r[1]
    res = res.sum(axis=1).mean()
    
    print(f"{s:.2f} {res}")

0.01 0.2964804212139081
0.01 0.33785846288511184
0.02 0.4378278642599571
0.02 0.647389075434481
0.03 0.693736171869005
0.03 0.7874579513270769
0.04 7.078741750161363
0.05 7.795493925425227
0.06 9.658377430802128
0.08 11.098758554756069
0.10 11.640565835966449
0.13 12.342592899024144
0.16 12.941697855760658
0.20 13.488325845577119
0.25 13.93474442162644
0.32 13.7257961992912
0.40 13.984633652485046
0.50 14.146800002602752
0.63 14.408104944015532
0.79 14.844945047422428
1.00 14.989739154708671


In [13]:
%%time
# Reference values computed directly from the softmax policy (no neural network):
# for every coefficient b, average over 10000 random grids.
for b in students:
    r1s = []
    r2s = []
    for _ in trange(10000, disable=True):
        x = make_x()
        # r1: expected reward of each row under the within-row softmax(b) policy.
        r1 = (softmax(x, b) * x).sum(axis=1)
        # r2: expected reward when the row is itself chosen by softmax(b) over r1.
        r2 = (softmax(r1, b) * r1).sum()
        
        r1s.append(r1)
        r2s.append(r2)
        
    r1s = np.array(r1s)
    r2s = np.array(r2s)
    # Columns printed: b, mean overall reward, mean per-row reward.
    print(f"{b:.3f}, {r2s.mean():.3f}, {r1s.mean(axis=0)}")


0.010, 2.536, [ 12.           7.97248005   3.05845208  -2.98519532 -10.5030866 ]
0.013, 3.140, [12.          8.32589919  3.70121042 -2.19239231 -9.83719922]
0.016, 3.866, [12.          8.74106985  4.4790896  -1.19650299 -8.95987224]
0.020, 4.725, [12.          9.2186661   5.40557111  0.04450878 -7.79798594]
0.025, 5.718, [12.          9.75305389  6.48399883  1.56902907 -6.25605078]
0.032, 6.834, [12.         10.3294776   7.6987513   3.39690653 -4.21951701]
0.040, 8.044, [12.         10.92197276  9.00499434  5.50362179 -1.57805424]
0.050, 9.305, [12.         11.49381646 10.32199117  7.78703938  1.70798935]
0.063, 10.564, [12.         12.00263198 11.53914376 10.04749419  5.48399391]
0.079, 11.760, [12.         12.41085569 12.54299769 12.02067143  9.28803482]
0.100, 12.808, [12.         12.69849399 13.26000329 13.48333628 12.45304773]
0.126, 13.599, [12.         12.87102756 13.68920879 14.37096085 14.51479765]
0.158, 14.082, [12.         12.95589895 13.89676861 14.79651112 15.52717187]
0.

In [10]:
def eval_training(student, teacher, verbose=False):
    """Fine-tune a student model on a teacher's curriculum and track its reward.

    Parameters
    ----------
    student : float
        Coefficient identifying the saved starting model
        ``models/starting_{student:.3f}.keras``; ``-1`` means start from a
        freshly initialised model (``make_model()``).
    teacher : float
        Coefficient selecting the pre-generated curriculum
        ``currs[f"{teacher:.3f}"]``.
    verbose : bool
        Currently unused; kept for interface compatibility.

    Returns
    -------
    list of float
        One mean reward (over 50 evaluation grids) per entry of the global
        schedule ``ns``, measured after training on ``2**i`` further samples
        for each non-zero entry ``i``.
    """
    global rng

    import tensorflow as tf
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

    # This function is run through process_map. Forked workers inherit an
    # identical copy of the parent's random state, so without reseeding every
    # replicate would shuffle the curriculum and draw evaluation grids from the
    # same random stream. Rebinding the notebook-level generator to a freshly
    # entropy-seeded one gives each call independent randomness; make_x picks
    # up the new generator because it looks up `rng` at call time.
    rng = np.random.default_rng()

    if student == -1:
        model2 = make_model()
    else:
        model2 = tf.keras.models.load_model(f"models/starting_{student:.3f}.keras")

    # Shuffle the teacher's curriculum, then consume it in consecutive slices.
    X, Y = currs[f"{teacher:.3f}"]
    perm = rng.permutation(len(X))
    X, Y = X[perm], Y[perm]

    rewards = []
    nsum = 0  # number of curriculum samples consumed so far
    for i in ns:
        if i != 0:
            n = 2**i
            model2.fit(X[nsum:nsum+n], Y[nsum:nsum+n], verbose=False)
            nsum += n

        # Evaluate after every schedule entry (a 0 entry is evaluation only).
        r0, r1 = eval_model(model2, 50)
        reward = (r0*r1).sum(axis=1).mean()
        rewards.append(reward)
    return rewards

In [11]:
def exp(student, verbose=True, n=20):
    """Run the curriculum experiment for one student against every teacher.

    For each teacher in the global ``students`` grid, ``eval_training`` is run
    ``n`` times through ``process_map`` (up to 10 workers) and the resulting
    reward curves are averaged.

    Parameters
    ----------
    student : float
        Passed through to ``eval_training`` as the starting model identifier.
    verbose : bool
        When true, print a table with one row per teacher and one column per
        cumulative curriculum size.
    n : int
        Number of repetitions per teacher.

    Returns
    -------
    numpy.ndarray
        Array with one row per teacher and one column per entry of ``ns``.
    """
    teachers = students

    if verbose:
        # Cumulative number of training samples after each schedule entry.
        cumulative_sizes = []
        running_total = 0
        for power in ns:
            if power:
                running_total += 2**power
            cumulative_sizes.append(running_total)

        print(f'Original Student: {student:.3f}')
        print()
        print('         Curriculum Size')
        print('Teacher' + ''.join(f"{size:>8}" for size in cumulative_sizes))

    mean_curves = []
    for teacher in teachers:
        runs = process_map(eval_training, [student]*n, [teacher]*n, max_workers=10, disable=True)
        mean_curve = np.array(runs).mean(axis=0)
        mean_curves.append(mean_curve)

        if verbose:
            row_cells = ''.join(f"{reward: 6.2f}  " for reward in mean_curve)
            print(f"  {teacher:.3f}  " + row_cells)

    return np.array(mean_curves)

In [12]:
# Sanity check: number of coefficients in the student/teacher grid.
len(students)

31

In [13]:
# Total number of training samples consumed by the full schedule in `ns`
# (the leading 0 entry is skipped because it means "evaluate only").
ntotal = (2**np.array(ns[1:])).sum()
ntotal

1048576

In [14]:
# Total number of samples needed for one full pass through the schedule.
ntotal = (2**np.array(ns[1:])).sum()

# Pre-generate one curriculum of `ntotal` trajectories per teacher coefficient,
# in up to 4 worker processes.
currs = process_map(make_data, students, [ntotal]*len(students), max_workers=4)
# Key the (X, Y) pairs by the coefficient formatted to 3 decimals; eval_training
# looks them up with the same format string.
currs = dict(zip([f"{s:.3f}" for s in students], currs))

  0%|          | 0/31 [00:00<?, ?it/s]

In [15]:
%%time
# Baseline: student == -1 makes eval_training start from a freshly initialised model.
res_fresh = exp(-1, n=60)

Original Student: -1.000

         Curriculum Size
Teacher       0      32      64     128     256     512    1024    2048    4096    8192   16384   32768   65536  131072  262144  524288 1048576
  0.010   -0.35   -0.49    0.20    0.68    0.52    0.65    0.40    0.70    0.35    0.52    0.43    0.34    0.36    0.35    0.37    0.35    0.36  
  0.013    0.69   -0.40   -0.22    0.39    1.48    0.40    0.89    0.87    0.58    0.64    0.47    0.46    0.45    0.42    0.43    0.43    0.43  
  0.016   -0.56    0.35    0.63    0.52    1.06    0.42    0.50    0.86    0.58    0.75    0.65    0.56    0.52    0.49    0.52    0.50    0.51  
  0.020    0.65   -0.30   -0.02    1.04    1.63    0.54    0.95    1.08    0.75    0.94    0.73    0.65    0.60    0.63    0.61    0.61    0.67  
  0.025    0.19    0.58    1.32    0.89    0.20    0.77    1.06    0.66    1.02    0.97    0.83    0.77    0.75    0.76    0.72    0.76    1.14  
  0.032   -0.62    1.44    1.83    0.27    1.01    0.77    1.39    0.87    

In [None]:
%%time
# NOTE(review): the variable is named res020 but the student passed is 0.010 -- confirm the intended name.
res020 = exp(0.010, n=20)

Original Student: 0.010

         Curriculum Size
Teacher       0      32      64     128     256     512    1024    2048    4096    8192   16384   32768   65536  131072  262144  524288 1048576
  0.010    0.30    0.31    0.29    0.29    0.28    0.28    0.29    0.30    0.30    0.31    0.35    0.39    0.42    0.31    0.38    0.35    0.34  
  0.013    0.30    0.31    0.29    0.29    0.28    0.28    0.30    0.31    0.32    0.35    0.41    0.46    0.48    0.41    0.45    0.45    0.44  
  0.016    0.30    0.31    0.29    0.29    0.28    0.28    0.30    0.32    0.34    0.39    0.47    0.53    0.56    0.48    0.52    0.53    0.50  
  0.020    0.30    0.31    0.29    0.29    0.28    0.29    0.31    0.34    0.37    0.43    0.54    0.60    0.65    0.57    0.61    0.62    0.58  
  0.025    0.30    0.31    0.29    0.30    0.28    0.29    0.33    0.35    0.42    0.47    0.57    0.57    0.66    0.67    0.63    0.71    0.64  
  0.032    0.30    0.31    0.29    0.30    0.28    0.29    0.33    0.35    0

In [None]:
%%time
# NOTE(review): the variable is named res010 but the student passed is 0.020 -- confirm the intended name.
res010 = exp(0.020, n=20)

In [None]:
%%time
# Curriculum experiment starting from the saved student with coefficient 0.040, 20 repetitions per teacher.
res040 = exp(0.040, n=20)

In [13]:
%%time
# Curriculum experiment starting from the saved student with coefficient 0.020 (default n=20 repetitions).
res020 = exp(0.020)

Original Student: 0.020

         Curriculum Size
Teacher       0      32      64     128     256     512    1024    2048    4096    8192   16384   32768   65536  131072  262144
  0.010    0.65    0.65    0.61    0.65    0.69    0.67    0.64    0.62    0.58    0.44    0.39    0.35    0.38    0.27    0.36  
  0.013    0.65    0.65    0.61    0.65    0.69    0.67    0.65    0.63    0.59    0.47    0.44    0.43    0.46    0.35    0.44  
  0.016    0.65    0.65    0.61    0.65    0.69    0.67    0.65    0.63    0.58    0.41    0.43    0.45    0.50    0.46    0.50  
  0.020    0.65    0.65    0.61    0.65    0.69    0.67    0.66    0.65    0.61    0.45    0.49    0.53    0.56    0.55    0.59  
  0.025    0.65    0.65    0.61    0.65    0.69    0.68    0.67    0.71    0.72    0.62    0.67    0.68    0.65    0.62    0.68  
  0.032    0.65    0.65    0.61    0.65    0.69    0.68    0.68    0.72    0.74    0.65    0.71    0.76    0.70    0.69    0.74  
  0.040    0.65    0.65    0.61    0.65   

In [18]:
%%time
# Baseline: student == -1 makes eval_training start from a freshly initialised model (default n=20).
res_fresh = exp(-1)

Original Student: -1.000

         Curriculum Size
Teacher       0      32      64     128     256     512    1024    2048    4096    8192   16384   32768   65536  131072  262144
  0.010    1.88    0.66    0.41    0.45    0.85    1.04   -0.04   -0.02    0.43    0.43    0.42    0.38    0.41    0.26    0.34  
  0.013    0.60    1.12    1.19    0.86    0.41    1.27   -0.20    0.10    0.54    0.66    0.35    0.60    0.47    0.37    0.43  
  0.016    0.32   -1.52   -0.68    1.11    1.93    0.96    0.87    0.64    0.36    0.30    0.69    0.48    0.53    0.50    0.50  
  0.020   -0.00    0.64    0.66    0.79    2.07    1.34    0.82    0.59    0.10    0.37    0.63    0.59    0.63    0.60    0.60  
  0.025    1.92    0.77    0.51    1.13    0.32    0.81    1.17    1.19    0.75    0.84    0.77    0.76    0.65    0.62    0.72  
  0.032    2.64   -0.06    0.73    1.47   -0.61    0.93    0.90    0.90    1.08    0.83    0.69    0.84    0.90    0.78    0.86  
  0.040    1.08    0.04    0.99    2.99  

In [26]:
# NOTE(review): the stored output below matches the table printed for exp(-1)
# (res_fresh), not a 0.020 student -- res020 was probably stale or rebound when
# this cell ran; re-run to confirm.
res020.round(2)

array([[ 1.88,  0.66,  0.41,  0.45,  0.85,  1.04, -0.04, -0.02,  0.43,  0.43,  0.42,  0.38,  0.41,  0.26,  0.34],
       [ 0.6 ,  1.12,  1.19,  0.86,  0.41,  1.27, -0.2 ,  0.1 ,  0.54,  0.66,  0.35,  0.6 ,  0.47,  0.37,  0.43],
       [ 0.32, -1.52, -0.68,  1.11,  1.93,  0.96,  0.87,  0.64,  0.36,  0.3 ,  0.69,  0.48,  0.53,  0.5 ,  0.5 ],
       [-0.  ,  0.64,  0.66,  0.79,  2.07,  1.34,  0.82,  0.59,  0.1 ,  0.37,  0.63,  0.59,  0.63,  0.6 ,  0.6 ],
       [ 1.92,  0.77,  0.51,  1.13,  0.32,  0.81,  1.17,  1.19,  0.75,  0.84,  0.77,  0.76,  0.65,  0.62,  0.72],
       [ 2.64, -0.06,  0.73,  1.47, -0.61,  0.93,  0.9 ,  0.9 ,  1.08,  0.83,  0.69,  0.84,  0.9 ,  0.78,  0.86],
       [ 1.08,  0.04,  0.99,  2.99,  1.46,  1.41,  2.07,  0.72,  1.65,  0.85,  0.72,  0.91,  0.84,  0.95,  1.04],
       [ 1.8 ,  0.01,  0.93,  1.95,  1.11,  1.87,  1.51,  0.6 ,  1.3 ,  1.01,  0.7 ,  0.88,  1.15,  1.09,  1.13],
       [ 1.28,  0.96,  3.03,  0.88, -0.09,  1.72,  0.72,  0.77,  1.77,  1.38,  0.92,  0.

In [13]:
%%time
# Curriculum experiment starting from the saved student with coefficient 0.040 (default n=20 repetitions).
res040 = exp(0.040)

Original Student: 0.040

         Curriculum Size
Teacher       0      32      64     128     256     512    1024    2048    4096    8192   16384   32768   65536
  0.010    7.18    7.04    7.05    7.00    6.90    6.98    6.47    5.79    4.76    3.11    2.10    1.65    1.36  
  0.013    7.18    7.04    7.05    7.00    6.90    6.98    6.49    5.89    5.19    3.85    2.89    2.44    2.02  
  0.016    7.18    7.04    7.05    7.01    6.94    7.07    6.60    6.25    5.55    4.52    3.45    3.14    2.77  
  0.020    7.18    7.04    7.06    7.02    6.97    7.14    6.71    6.50    5.90    5.08    4.42    3.97    3.71  
  0.025    7.18    7.05    7.06    7.02    7.05    7.12    6.84    6.33    5.44    5.54    5.08    5.10    4.81  
  0.032    7.18    7.06    7.07    7.04    7.11    7.20    6.95    6.73    6.01    6.10    5.84    6.01    5.91  
  0.040    7.18    7.05    7.07    7.05    7.14    7.29    7.16    7.10    6.66    7.03    6.79    7.11    7.11  
  0.050    7.18    7.05    7.07    7.05 

In [14]:
%%time
# NOTE(review): the variable is named res020 but the student passed is 0.100;
# this overwrites any earlier res020 result -- confirm the intended name.
res020 = exp(0.100)

Original Student: 0.100

         Curriculum Size
Teacher       0      32      64     128     256     512    1024    2048    4096    8192   16384   32768   65536
  0.010   11.71   11.62   11.52   11.40   10.97    9.13    5.82    2.99    1.70    0.62    0.46    0.39    0.41  
  0.013   11.71   11.62   11.52   11.40   10.98    9.21    6.04    3.33    1.80    1.05    0.76    0.60    0.69  
  0.016   11.71   11.62   11.52   11.42   11.05    9.35    6.59    4.42    2.50    1.32    1.06    0.99    1.11  
  0.020   11.71   11.62   11.53   11.44   11.10    9.64    7.25    5.05    3.91    2.43    2.22    2.04    2.20  
  0.025   11.71   11.63   11.55   11.48   11.29   10.47    7.76    4.92    3.59    3.41    3.00    3.36    3.66  
  0.032   11.71   11.63   11.55   11.49   11.36   10.71    7.49    5.08    3.90    4.27    4.29    4.78    5.22  
  0.040   11.71   11.63   11.56   11.53   11.45   10.96    8.78    6.83    5.56    6.20    5.93    6.31    6.60  
  0.050   11.71   11.64   11.57   11.55 