# Evotuning UniRep

Getting familiar with the UniRep model and how to fine-tune it.

In [1]:
import os
import random
from typing import Iterator

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split

from unirep import babbler64, babbler1900

In [2]:
def format_sequences(seq_file:str, formatted_file:str, model, use_stop:bool=False, max_len:int=275) -> None:
    """
    formats the sequences stored in seq_file and writes them to formatted_file

    Every line of seq_file is stripped of surrounding whitespace. A sequence is
    kept only if model.is_valid_seq accepts it and it is strictly shorter than
    max_len; all other lines are skipped silently. Kept sequences are converted
    with model.format_seq and written as one comma-separated line each.
    -------------------------------
    seq_file - path of the input file, one sequence per line
    formatted_file - path of the output file (overwritten if it exists)
    model - object providing is_valid_seq(seq) and format_seq(seq, stop=...)
    use_stop - forwarded to model.format_seq as its `stop` argument
    max_len - exclusive upper bound on the length of a kept sequence

    returns None
    """
    with open(seq_file, "r") as source, open(formatted_file, "w") as destination:
        for line in source:
            seq = line.strip()
            if model.is_valid_seq(seq) and len(seq) < max_len:
                formatted = ",".join(map(str, model.format_seq(seq, stop=use_stop)))
                destination.write(formatted + "\n")

def format_dataset(seqs, model, max_len:int=275, use_stop=True) -> list:
    """
    converts raw sequences into the model's encoding

    Each entry of seqs is stripped of surrounding whitespace. Entries that
    model.is_valid_seq rejects, or whose length is not strictly below max_len,
    are left out. The remaining ones are encoded with model.format_seq, with
    use_stop forwarded as its `stop` argument.

    returns a list holding one model.format_seq result per retained sequence
    """
    def _keep(candidate):
        # same two conditions, in the same order, as the filter applied to files
        return model.is_valid_seq(candidate) and len(candidate) < max_len

    stripped = (raw.strip() for raw in seqs)
    return [model.format_seq(candidate, stop=use_stop) for candidate in stripped if _keep(candidate)]

def batch_generator(data:list, batch_size:int, shuffle:bool=True) -> Iterator[np.ndarray]:
    """
    creates a batch generator over the input dataset that right-pads sequences
    with zeros to the length of the longest seq in a batch. Optionally shuffles
    the input dataset in place at the start. Remainders are dropped.

    Because this is a generator, the shuffle and the size checks only run when
    the first batch is requested, not when the function is called.
    -------------------------------
    data - list of integer lists (mutated in place when shuffle is True)
    batch_size - number of sequences per batch, must be positive
    shuffle - whether to shuffle data with random.shuffle before batching

    yields np.ndarray of dtype int32 and shape (batch_size, max_len), where
    max_len is the length of the longest sequence in that batch
    raises AssertionError if batch_size is not positive or exceeds len(data)
    """
    if shuffle:
        random.shuffle(data)
    data_size = len(data)
    # a non-positive batch size would otherwise divide by zero or yield nothing
    assert batch_size > 0, "batch_size must be positive"
    assert data_size >= batch_size, "dataset must be larger than batch_size"
    n_batches = data_size // batch_size # will drop any remainders
    for batch_idx in range(n_batches):
        # make a batch
        start_idx = batch_idx * batch_size
        batch = data[start_idx:start_idx + batch_size]
        max_len = max(len(s) for s in batch)
        batch_pad = np.zeros((batch_size, max_len), dtype=np.int32)
        # distinct index name: the original reused the outer loop variable here
        for row, seq in enumerate(batch):
            batch_pad[row, :len(seq)] = seq
        yield batch_pad

def calc_val_loss(model, sess, val_data:list) -> float:
    """
    computes the validation set loss. tf Variables should all be initialized within the session

    The loss is evaluated batch by batch with the model's batch size and the
    per-batch values are averaged. batch_generator drops the trailing
    sequences that do not fill a whole batch, so those are not scored.
    -------------------------------
    model - object providing _batch_size, _zero_state, get_babbler_ops() and split_to_tuple()
    sess - session in which the loss is evaluated
    val_data - list of integer lists; left in its original order

    returns the mean of the per-batch losses
    """
    batch_sz = model._batch_size
    # get training operations
    _, loss, x_placeholder, y_placeholder, batch_size_placeholder, initial_state_placeholder = model.get_babbler_ops()
    val_losses = []
    # Do not shuffle: shuffling would reorder the caller's list in place and,
    # because remainders are dropped, score a different subset on every call,
    # which makes successive validation losses incomparable.
    for batch in batch_generator(val_data, batch_sz, shuffle=False):
        batch_x, batch_y = model.split_to_tuple(batch)
        feed = {
            x_placeholder: batch_x,
            y_placeholder: batch_y,
            batch_size_placeholder: batch_sz,
            initial_state_placeholder: model._zero_state,
        }
        val_losses.append(sess.run(loss, feed_dict=feed))
    return np.mean(val_losses)

def train_model(model, train_data:list, val_data:list, n_epochs=1, lr=1e-5, shuffle=True,
                ckpt_dir=None, val_freq=50):
    """
    trains the model using specified settings

    Creates a new Adam optimizer on the model's loss, opens a fresh session,
    runs tf.global_variables_initializer() and then iterates over the training
    batches for n_epochs epochs, printing the loss of every iteration. Every
    val_freq iterations the validation loss is computed; whenever it improves
    on the best value seen so far and ckpt_dir is given, the weights are
    written with model.dump_weights.

    NOTE(review): every call adds a new optimizer and its ops to the graph.
    -------------------------------
    model - object providing _batch_size, _zero_state, get_babbler_ops(),
            split_to_tuple() and dump_weights()
    train_data - list of integer lists (shuffled in place when shuffle is True)
    val_data - list of integer lists used for the periodic validation
    n_epochs - number of passes over train_data
    lr - learning rate of the Adam optimizer
    shuffle - whether to shuffle train_data at the start of every epoch
    ckpt_dir - directory for the weights; None disables checkpointing
    val_freq - validate every val_freq iterations; 0 or None disables validation

    returns None
    """
    batch_sz = model._batch_size
    # get training operations
    _, loss, x_placeholder, y_placeholder, batch_size_placeholder, initial_state_placeholder = model.get_babbler_ops()
    # create an optimizer to fine-tune the model
    optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    fine_tuning_op = optimizer.minimize(loss)
    # train model
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        i = 0 # global iteration
        min_val_loss = np.inf
        for e in range(n_epochs):
            print("Running epoch: %i" % (e + 1))
            # train on the training set batches
            for batch in batch_generator(train_data, batch_sz, shuffle=shuffle):
                i += 1
                batch_x, batch_y = model.split_to_tuple(batch)
                feed = {
                    x_placeholder: batch_x,
                    y_placeholder: batch_y,
                    batch_size_placeholder: batch_sz,
                    initial_state_placeholder: model._zero_state,
                }
                loss_, __ = sess.run([loss, fine_tuning_op], feed_dict=feed)
                print("Iteration {0}: {1}".format(i, loss_))
                # a falsy val_freq skips validation instead of failing on the modulo
                if val_freq and i % val_freq == 0:
                    # calculate validation set performance
                    val_loss = calc_val_loss(model, sess, val_data)
                    print("Validation set loss: {}".format(val_loss))
                    if val_loss < min_val_loss:
                        min_val_loss = val_loss
                        if ckpt_dir is not None:
                            # store model weights
                            print("Checkpointing model weights")
                            # dump_weights is not visible here; creating the
                            # directory first is harmless if it already exists
                            os.makedirs(ckpt_dir, exist_ok=True)
                            model.dump_weights(sess, ckpt_dir)


In [23]:
# global variables
RANDOM_SEED = 42  # seed for the train/validation splits and for batch shuffling
BATCH_SZ = 256  # batch size the model is built with and fed during training
UNIREP_WEIGHTS = "unirep_weights/1900_weights/"  # model_path of the starting weights
eUNIREP_WEIGHTS = "eunirep_weights"  # directory the fine-tuned weights are written to / loaded from

# seed Python's RNG so that the random.shuffle call in batch_generator is reproducible
random.seed(RANDOM_SEED)

In [4]:
# instantiate babbler1900 from the weights at UNIREP_WEIGHTS, using batches of BATCH_SZ sequences
model1900 = babbler1900(model_path=UNIREP_WEIGHTS, batch_size=BATCH_SZ)

  from ._conv import register_converters as _register_converters


In [6]:
# re-create model1900 from the eUNIREP_WEIGHTS directory; this replaces the instance built from UNIREP_WEIGHTS
# NOTE(review): eUNIREP_WEIGHTS is the ckpt_dir handed to train_model further down, so this cell
# presumably only works after a training run has written weights there — confirm on a fresh kernel
model1900 = babbler1900(model_path=eUNIREP_WEIGHTS, batch_size=BATCH_SZ)

  from ._conv import register_converters as _register_converters


## Basic model testing

Formatting data, producing batches for training

In [4]:
# example amino-acid sequence used to try out the model's sequence formatting
test_seq = "MRKGEELFTGVVPILVELDGDVNGHKFSVRGEGEGDATNGKLTLKFICTTGKLPVPWPTLVTTLTYGVQCFARYPDHMKQHDFFKSAMPEGYVQERTISFKDDGTYKTRAEVKFEGDTLVNRIE"

In [5]:
# show the encoding that model1900.format_seq produces for test_seq, wrapped in an array for display
np.array(model1900.format_seq(test_seq))

array([24,  1,  2,  4, 13,  6,  6, 21, 18,  8, 13, 16, 16, 14, 17, 21, 16,
        6, 21,  5, 13,  5, 16,  9, 13,  3,  4, 18,  7, 16,  2, 13,  6, 13,
        6, 13,  5, 15,  8,  9, 13,  4, 21,  8, 21,  4, 18, 17, 11,  8,  8,
       13,  4, 21, 14, 16, 14, 20, 14,  8, 21, 16,  8,  8, 21,  8, 19, 13,
       16, 10, 11, 18, 15,  2, 19, 14,  5,  3,  1,  4, 10,  3,  5, 18, 18,
        4,  7, 15,  1, 14,  6, 13, 19, 16, 10,  6,  2,  8, 17,  7, 18,  4,
        5,  5, 13,  8, 19,  4,  8,  2, 15,  6, 16,  4, 18,  6, 13,  5,  8,
       21, 16,  9,  2, 17,  6])

In [7]:
# need to use stop tokens when fine-tuning the model, hence use_stop=True
# writes one comma-separated line per retained sequence to data/vh_vl_hum_pair_seq_fmt.txt
format_sequences("data/vh_vl_hum_pair_seq.txt", "data/vh_vl_hum_pair_seq_fmt.txt", model1900, use_stop=True)

In [7]:
# bucketing: use a large bucket interval because most Fv sequences are of very similar length
bucket_op = model1900.bucket_batch_pad("data/vh_vl_hum_pair_seq_fmt.txt", interval=1000)

In [9]:
# test producing batches: evaluate bucket_op once and inspect the resulting array
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = sess.run(bucket_op)

# NOTE(review): the recorded output shows 128 rows while BATCH_SZ is 256, so the
# stored output looks stale — re-run to confirm
print(batch)
print(batch.shape)

[[24  6 16 ...  0  0  0]
 [24 10 21 ...  0  0  0]
 [24  6 16 ...  0  0  0]
 ...
 [24 10 16 ...  0  0  0]
 [24 10 16 ...  0  0  0]
 [24 10 16 ...  0  0  0]]
(128, 247)


In [10]:
# this is how to produce input/target pairs from a batch (shift input sequence 1 position to the left)
# batch_x is later fed to x_placeholder and batch_y to y_placeholder
batch_x, batch_y = model1900.split_to_tuple(batch)

## Testing model fine-tuning

Fine-tune on a set of test sequences using the published code.

In [8]:
# get training operations: the loss tensor plus the placeholders that are fed at every step
# (the first returned value is discarded here)
_, loss, x_placeholder, y_placeholder, batch_size_placeholder, initial_state_placeholder = model1900.get_babbler_ops()

In [9]:
# create an optimizer to fine-tune the model
# (learning rate 1e-4 for this quick test; train_model defaults to 1e-5)
optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)
fine_tuning_op = optimizer.minimize(loss)

In [10]:
# test training: run num_iters optimisation steps on batches drawn from bucket_op
num_iters = 100
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(num_iters):
        # fetch the next padded batch and split it into inputs and targets
        batch = sess.run(bucket_op)
        batch_x, batch_y = model1900.split_to_tuple(batch)
        # one optimisation step; the loss is fetched alongside it for logging
        loss_, __, = sess.run([loss, fine_tuning_op],
                feed_dict={
                     x_placeholder: batch_x,
                     y_placeholder: batch_y,
                     batch_size_placeholder: BATCH_SZ,
                     initial_state_placeholder: model1900._zero_state
                }
        )
        print("Iteration {0}: {1}".format(i,loss_))


Iteration 0: 2.3585262298583984
Iteration 1: 2.1982650756835938
Iteration 2: 2.0852699279785156
Iteration 3: 1.9932196140289307
Iteration 4: 1.8987133502960205
Iteration 5: 1.8282603025436401
Iteration 6: 1.7320784330368042
Iteration 7: 1.6901164054870605
Iteration 8: 1.6452765464782715
Iteration 9: 1.5979597568511963
Iteration 10: 1.5181970596313477
Iteration 11: 1.451095461845398
Iteration 12: 1.4329373836517334
Iteration 13: 1.3730506896972656
Iteration 14: 1.353617787361145
Iteration 15: 1.3298561573028564
Iteration 16: 1.266579270362854
Iteration 17: 1.2201709747314453
Iteration 18: 1.2068146467208862
Iteration 19: 1.1697735786437988
Iteration 20: 1.1552271842956543
Iteration 21: 1.1199240684509277
Iteration 22: 1.1223554611206055
Iteration 23: 1.0888433456420898
Iteration 24: 1.057464361190796
Iteration 25: 1.0247516632080078
Iteration 26: 1.0601595640182495
Iteration 27: 1.027543544769287
Iteration 28: 1.0130473375320435
Iteration 29: 1.0150216817855835
Iteration 30: 0.960443735

## Make train / validation split

### Paired VH.VL sequences

Split using a 90:10 random split and persist the files.

In [6]:
# read the paired sequences (file has no header row) and split them 90:10 with a fixed seed
data = pd.read_csv("data/vh_vl_hum_pair_seq.txt", header=None)
train_data, val_data = train_test_split(data, test_size=0.1, shuffle=True, random_state=RANDOM_SEED)
# save train/val sets to disk
train_data.to_csv("data/vh_vl_hum_pair_seq_train.csv", index=False, header=False)
val_data.to_csv("data/vh_vl_hum_pair_seq_val.csv", index=False, header=False)

### Unpaired VH, VL sequences

Use a **99:1** random split for the VH data and a **98:2** random split for the VL data so that each validation set contains approx. **10k** sequences. Such small validation fractions are sufficient because these datasets are considerably larger than the paired dataset.

In [5]:
# read the unpaired VH sequences (gzipped, no header row) and hold out 1% for validation
vh_data = pd.read_csv("data/vh_human_unpaired.csv.gz", header=None)
vh_train, vh_val = train_test_split(vh_data, test_size=0.01, shuffle=True, random_state=RANDOM_SEED)
print("Train set size: %i" % len(vh_train))
print("Val set size: %i" % len(vh_val))
# save train/val sets to disk
vh_train.to_csv("data/vh_human_unpaired_train.csv.gz", index=False, header=False, compression="gzip")
vh_val.to_csv("data/vh_human_unpaired_val.csv.gz", index=False, header=False, compression="gzip")

In [15]:
# read the unpaired VL sequences (gzipped, no header row) and hold out 2% for validation
vl_data = pd.read_csv("data/vl_human_unpaired.csv.gz", header=None)
vl_train, vl_val = train_test_split(vl_data, test_size=0.02, shuffle=True, random_state=RANDOM_SEED)
print("Train set size: %i" % len(vl_train))
print("Val set size: %i" % len(vl_val))
# save train/val sets to disk
vl_train.to_csv("data/vl_human_unpaired_train.csv.gz", index=False, header=False, compression="gzip")
vl_val.to_csv("data/vl_human_unpaired_val.csv.gz", index=False, header=False, compression="gzip")

Train set size: 542851
Val set size: 11079


## Load training / validation data

### Paired VH.VL sequences

In [4]:
# load the persisted paired train/val splits (no header row, sequences in column 0)
vh_vl_train = pd.read_csv("data/vh_vl_hum_pair_seq_train.csv", header=None)
vh_vl_val = pd.read_csv("data/vh_vl_hum_pair_seq_val.csv", header=None)

# format data: encode column 0 with the model; sequences that format_dataset
# filters out are dropped, so the sizes printed below can be smaller than the frames
vh_vl_train_fmt = format_dataset(vh_vl_train[0], model1900)
vh_vl_val_fmt = format_dataset(vh_vl_val[0], model1900)

print("Train set size: %i" % len(vh_vl_train_fmt))
print("Val set size %i" % len(vh_vl_val_fmt))

### Unpaired VH, VL sequences

In [16]:
# load the persisted unpaired VH and VL train/val splits (gzipped, no header row)
vh_train = pd.read_csv("data/vh_human_unpaired_train.csv.gz", header=None)
vh_val = pd.read_csv("data/vh_human_unpaired_val.csv.gz", header=None)
vl_train = pd.read_csv("data/vl_human_unpaired_train.csv.gz", header=None)
vl_val = pd.read_csv("data/vl_human_unpaired_val.csv.gz", header=None)

In [4]:
# format data: encode the VH sequences in column 0; the printed sizes count
# only the sequences that format_dataset retained
vh_train_fmt = format_dataset(vh_train[0], model1900)
vh_val_fmt = format_dataset(vh_val[0], model1900)
print("Train set size: %i" % len(vh_train_fmt))
print("Val set size %i" % len(vh_val_fmt))

In [21]:
# format data: encode the VL sequences in column 0; the printed sizes count
# only the sequences that format_dataset retained
vl_train_fmt = format_dataset(vl_train[0], model1900)
vl_val_fmt = format_dataset(vl_val[0], model1900)
print("Train set size: %i" % len(vl_train_fmt))
print("Val set size %i" % len(vl_val_fmt))

Train set size: 542851
Val set size 11079


## Perform evotuning

### Paired VH.VL sequences

In [25]:
# fine-tune on the paired set for 2 epochs, validating every 20 iterations and writing
# the weights to eUNIREP_WEIGHTS whenever the validation loss improves
train_model(model1900, vh_vl_train_fmt, vh_vl_val_fmt, n_epochs=2, ckpt_dir=eUNIREP_WEIGHTS, val_freq=20)

Running epoch: 1
Iteration 1: 2.3604955673217773
Iteration 2: 2.352841377258301
Iteration 3: 2.32208514213562
Iteration 4: 2.3070454597473145
Iteration 5: 2.2848551273345947
Iteration 6: 2.2726306915283203
Iteration 7: 2.257080554962158
Iteration 8: 2.2407965660095215
Iteration 9: 2.224301338195801
Iteration 10: 2.2126011848449707
Iteration 11: 2.189615249633789
Iteration 12: 2.180088758468628
Iteration 13: 2.164987564086914
Iteration 14: 2.159646987915039
Iteration 15: 2.1344480514526367
Iteration 16: 2.1224279403686523
Iteration 17: 2.1162428855895996
Iteration 18: 2.1007754802703857
Iteration 19: 2.0874247550964355
Iteration 20: 2.0763068199157715
Validation set loss: 2.0638530254364014
Checkpointing model weights
embed_matrix:0
[[ 9.45316315e-01 -6.63625717e-01 -2.47278929e-01 -7.91995525e-01
   2.66303062e-01 -6.48389816e-01  4.97818947e-01  3.40559959e-01
   2.15496302e-01 -8.56987476e-01]
 [ 2.47897342e-01  2.91691691e-01 -1.54972672e+00  1.97414660e+00
  -1.40773439e+00  2.2510

Iteration 21: 2.060631513595581
Iteration 22: 2.0527069568634033
Iteration 23: 2.0321450233459473
Iteration 24: 2.0320181846618652
Iteration 25: 2.015864849090576
Iteration 26: 2.010056972503662
Iteration 27: 2.002638339996338
Iteration 28: 1.9847220182418823
Iteration 29: 1.9726496934890747
Iteration 30: 1.9570761919021606
Iteration 31: 1.948577642440796
Iteration 32: 1.9332184791564941
Iteration 33: 1.9207336902618408
Iteration 34: 1.9123426675796509
Iteration 35: 1.8973829746246338
Iteration 36: 1.8911224603652954
Iteration 37: 1.885469913482666
Iteration 38: 1.8687690496444702
Iteration 39: 1.8479735851287842
Iteration 40: 1.849135160446167
Validation set loss: 1.8384064435958862
Checkpointing model weights
embed_matrix:0
[[ 9.4531631e-01 -6.6362572e-01 -2.4727893e-01 -7.9199553e-01
   2.6630306e-01 -6.4838982e-01  4.9781895e-01  3.4055996e-01
   2.1549630e-01 -8.5698748e-01]
 [ 2.4813329e-01  2.9190114e-01 -1.5496962e+00  1.9739761e+00
  -1.4075466e+00  2.2508352e+00 -1.0549065e+0

Iteration 42: 1.8186644315719604
Iteration 43: 1.8181188106536865
Iteration 44: 1.8053663969039917
Iteration 45: 1.7989122867584229
Iteration 46: 1.7744907140731812
Iteration 47: 1.7724850177764893
Iteration 48: 1.7684277296066284
Iteration 49: 1.7621151208877563
Iteration 50: 1.7495207786560059
Iteration 51: 1.7281476259231567
Iteration 52: 1.7078548669815063
Iteration 53: 1.707136869430542
Iteration 54: 1.6997097730636597
Iteration 55: 1.709835171699524
Iteration 56: 1.6955236196517944
Iteration 57: 1.674385905265808
Iteration 58: 1.6794599294662476
Iteration 59: 1.6657766103744507
Iteration 60: 1.6403164863586426
Validation set loss: 1.638649582862854
Checkpointing model weights
embed_matrix:0
[[ 9.4531631e-01 -6.6362572e-01 -2.4727893e-01 -7.9199553e-01
   2.6630306e-01 -6.4838982e-01  4.9781895e-01  3.4055996e-01
   2.1549630e-01 -8.5698748e-01]
 [ 2.4838099e-01  2.9214755e-01 -1.5495369e+00  1.9736888e+00
  -1.4073393e+00  2.2505920e+00 -1.0551145e+00  1.9344575e+00
   4.7912890e

Iteration 63: 1.6090056896209717
Iteration 64: 1.6020092964172363
Iteration 65: 1.6028627157211304
Iteration 66: 1.6065953969955444
Iteration 67: 1.5923277139663696
Iteration 68: 1.577838659286499
Iteration 69: 1.5730609893798828
Iteration 70: 1.5724117755889893
Iteration 71: 1.554697871208191
Iteration 72: 1.5580081939697266
Iteration 73: 1.5116350650787354
Iteration 74: 1.5447421073913574
Iteration 75: 1.5318025350570679
Iteration 76: 1.5167207717895508
Iteration 77: 1.4968770742416382
Iteration 78: 1.49821138381958
Iteration 79: 1.4695228338241577
Iteration 80: 1.501794695854187
Validation set loss: 1.4710005521774292
Checkpointing model weights
embed_matrix:0
[[ 9.4531631e-01 -6.6362572e-01 -2.4727893e-01 -7.9199553e-01
   2.6630306e-01 -6.4838982e-01  4.9781895e-01  3.4055996e-01
   2.1549630e-01 -8.5698748e-01]
 [ 2.4860415e-01  2.9240626e-01 -1.5492815e+00  1.9734262e+00
  -1.4071857e+00  2.2503958e+00 -1.0553311e+00  1.9342350e+00
   4.7892049e-01  2.1527829e+00]
 [-2.3108695e+

Iteration 81: 1.4790693521499634
Iteration 82: 1.4762099981307983
Iteration 83: 1.4504681825637817
Iteration 84: 1.4370183944702148
Iteration 85: 1.4206838607788086
Iteration 86: 1.4419970512390137
Iteration 87: 1.429807424545288
Iteration 88: 1.3959097862243652
Iteration 89: 1.4226740598678589
Iteration 90: 1.4041531085968018
Iteration 91: 1.3835899829864502
Iteration 92: 1.3889468908309937
Iteration 93: 1.383147954940796
Iteration 94: 1.3858243227005005
Iteration 95: 1.368599534034729
Iteration 96: 1.3832664489746094
Iteration 97: 1.351119875907898
Iteration 98: 1.3589686155319214
Iteration 99: 1.3413363695144653
Iteration 100: 1.338550090789795
Validation set loss: 1.3386447429656982
Checkpointing model weights
embed_matrix:0
[[ 9.4531631e-01 -6.6362572e-01 -2.4727893e-01 -7.9199553e-01
   2.6630306e-01 -6.4838982e-01  4.9781895e-01  3.4055996e-01
   2.1549630e-01 -8.5698748e-01]
 [ 2.4880153e-01  2.9263237e-01 -1.5490030e+00  1.9732169e+00
  -1.4071124e+00  2.2503138e+00 -1.0555607

Iteration 101: 1.3316149711608887
Iteration 102: 1.3109371662139893
Iteration 103: 1.326864242553711
Iteration 104: 1.3233047723770142
Iteration 105: 1.31251859664917
Iteration 106: 1.296633005142212
Iteration 107: 1.3066174983978271
Iteration 108: 1.3147526979446411
Iteration 109: 1.2783615589141846
Iteration 110: 1.2821468114852905
Iteration 111: 1.2728285789489746
Iteration 112: 1.2717379331588745
Iteration 113: 1.256028413772583
Iteration 114: 1.2749866247177124
Iteration 115: 1.2980268001556396
Iteration 116: 1.264419674873352
Iteration 117: 1.2487927675247192
Iteration 118: 1.2299460172653198
Iteration 119: 1.2580543756484985
Iteration 120: 1.2316769361495972
Validation set loss: 1.2335293292999268
Checkpointing model weights
embed_matrix:0
[[ 9.45316315e-01 -6.63625717e-01 -2.47278929e-01 -7.91995525e-01
   2.66303062e-01 -6.48389816e-01  4.97818947e-01  3.40559959e-01
   2.15496302e-01 -8.56987476e-01]
 [ 2.49004647e-01  2.92832613e-01 -1.54874659e+00  1.97303510e+00
  -1.40707

Iteration 121: 1.2613993883132935
Iteration 122: 1.2293028831481934
Iteration 123: 1.2197444438934326
Iteration 124: 1.234028935432434
Iteration 125: 1.2203621864318848
Iteration 126: 1.236161470413208
Iteration 127: 1.1709864139556885
Iteration 128: 1.2308762073516846
Running epoch: 2
Iteration 129: 1.1882468461990356
Iteration 130: 1.2107423543930054
Iteration 131: 1.1676459312438965
Iteration 132: 1.1851141452789307
Iteration 133: 1.189477562904358
Iteration 134: 1.184917688369751
Iteration 135: 1.1724234819412231
Iteration 136: 1.155107021331787
Iteration 137: 1.1561074256896973
Iteration 138: 1.1805483102798462
Iteration 139: 1.1576683521270752
Iteration 140: 1.1512322425842285
Validation set loss: 1.1459368467330933
Checkpointing model weights
embed_matrix:0
[[ 9.4531631e-01 -6.6362572e-01 -2.4727893e-01 -7.9199553e-01
   2.6630306e-01 -6.4838982e-01  4.9781895e-01  3.4055996e-01
   2.1549630e-01 -8.5698748e-01]
 [ 2.4921446e-01  2.9302832e-01 -1.5485163e+00  1.9728621e+00
  -1.4

Iteration 141: 1.156868815422058
Iteration 142: 1.1617822647094727
Iteration 143: 1.1407716274261475
Iteration 144: 1.1306432485580444
Iteration 145: 1.1366240978240967
Iteration 146: 1.1178923845291138
Iteration 147: 1.1217576265335083
Iteration 148: 1.1237945556640625
Iteration 149: 1.1258301734924316
Iteration 150: 1.11794912815094
Iteration 151: 1.1027498245239258
Iteration 152: 1.1088948249816895
Iteration 153: 1.0870909690856934
Iteration 154: 1.1014472246170044
Iteration 155: 1.0728273391723633
Iteration 156: 1.0803096294403076
Iteration 157: 1.0880398750305176
Iteration 158: 1.0806398391723633
Iteration 159: 1.0851517915725708
Iteration 160: 1.0841621160507202
Validation set loss: 1.0737414360046387
Checkpointing model weights
embed_matrix:0
[[ 9.45316315e-01 -6.63625717e-01 -2.47278929e-01 -7.91995525e-01
   2.66303062e-01 -6.48389816e-01  4.97818947e-01  3.40559959e-01
   2.15496302e-01 -8.56987476e-01]
 [ 2.49385744e-01  2.93178707e-01 -1.54832935e+00  1.97276568e+00
  -1.40

Iteration 161: 1.1093132495880127
Iteration 162: 1.0780775547027588
Iteration 163: 1.0575995445251465
Iteration 164: 1.0633379220962524
Iteration 165: 1.0503755807876587
Iteration 166: 1.057457447052002
Iteration 167: 1.041911244392395
Iteration 168: 1.0422340631484985
Iteration 169: 1.063398838043213
Iteration 170: 1.0638158321380615
Iteration 171: 1.0504807233810425
Iteration 172: 1.0326154232025146
Iteration 173: 1.0509617328643799
Iteration 174: 1.0345159769058228
Iteration 175: 1.0105550289154053
Iteration 176: 1.033798336982727
Iteration 177: 1.0276439189910889
Iteration 178: 1.0253477096557617
Iteration 179: 1.0154099464416504
Iteration 180: 1.015563726425171
Validation set loss: 1.0117703676223755
Checkpointing model weights
embed_matrix:0
[[ 9.4531631e-01 -6.6362572e-01 -2.4727893e-01 -7.9199553e-01
   2.6630306e-01 -6.4838982e-01  4.9781895e-01  3.4055996e-01
   2.1549630e-01 -8.5698748e-01]
 [ 2.4954446e-01  2.9331377e-01 -1.5481584e+00  1.9726956e+00
  -1.4069918e+00  2.250

Iteration 181: 1.0389901399612427
Iteration 182: 0.9965644478797913
Iteration 183: 0.9881734848022461
Iteration 184: 1.0161374807357788
Iteration 185: 1.017624855041504
Iteration 186: 0.9815428853034973
Iteration 187: 1.002522349357605
Iteration 188: 0.997539758682251
Iteration 189: 0.9883897304534912
Iteration 190: 0.9706243276596069
Iteration 191: 0.9736074805259705
Iteration 192: 1.001194953918457
Iteration 193: 0.9918186664581299
Iteration 194: 0.9867278337478638
Iteration 195: 0.9929357767105103
Iteration 196: 0.9736071825027466
Iteration 197: 0.9694329500198364
Iteration 198: 0.9532836675643921
Iteration 199: 0.9778310060501099
Iteration 200: 0.9805277585983276
Validation set loss: 0.9544306397438049
Checkpointing model weights
embed_matrix:0
[[ 9.45316315e-01 -6.63625717e-01 -2.47278929e-01 -7.91995525e-01
   2.66303062e-01 -6.48389816e-01  4.97818947e-01  3.40559959e-01
   2.15496302e-01 -8.56987476e-01]
 [ 2.49679133e-01  2.93436021e-01 -1.54799139e+00  1.97264469e+00
  -1.407

Iteration 201: 0.9597222805023193
Iteration 202: 0.9489302635192871
Iteration 203: 0.982284665107727
Iteration 204: 0.9380064010620117
Iteration 205: 0.9448215961456299
Iteration 206: 0.9515045285224915
Iteration 207: 0.9287463426589966
Iteration 208: 0.9240758419036865
Iteration 209: 0.9024471044540405
Iteration 210: 0.9207359552383423
Iteration 211: 0.9251551628112793
Iteration 212: 0.9580086469650269
Iteration 213: 0.9192925691604614
Iteration 214: 0.9623205065727234
Iteration 215: 0.9208948016166687
Iteration 216: 0.9455440640449524
Iteration 217: 0.9184147119522095
Iteration 218: 0.9102000594139099
Iteration 219: 0.8999665975570679
Iteration 220: 0.8845711946487427
Validation set loss: 0.9022896885871887
Checkpointing model weights
embed_matrix:0
[[ 9.45316315e-01 -6.63625717e-01 -2.47278929e-01 -7.91995525e-01
   2.66303062e-01 -6.48389816e-01  4.97818947e-01  3.40559959e-01
   2.15496302e-01 -8.56987476e-01]
 [ 2.49806002e-01  2.93531418e-01 -1.54782975e+00  1.97260547e+00
  -1.

Iteration 221: 0.8851616382598877
Iteration 222: 0.9021520018577576
Iteration 223: 0.9117578864097595
Iteration 224: 0.9057590961456299
Iteration 225: 0.8949304819107056
Iteration 226: 0.8987687826156616
Iteration 227: 0.8809919357299805
Iteration 228: 0.8968148231506348
Iteration 229: 0.9168075919151306
Iteration 230: 0.8762199878692627
Iteration 231: 0.8730795979499817
Iteration 232: 0.8808779120445251
Iteration 233: 0.8844360113143921
Iteration 234: 0.8854647874832153
Iteration 235: 0.8696203827857971
Iteration 236: 0.8509671688079834
Iteration 237: 0.8942853212356567
Iteration 238: 0.8774974346160889
Iteration 239: 0.8709871768951416
Iteration 240: 0.8461918830871582
Validation set loss: 0.8531339764595032
Checkpointing model weights
embed_matrix:0
[[ 9.45316315e-01 -6.63625717e-01 -2.47278929e-01 -7.91995525e-01
   2.66303062e-01 -6.48389816e-01  4.97818947e-01  3.40559959e-01
   2.15496302e-01 -8.56987476e-01]
 [ 2.49924064e-01  2.93612301e-01 -1.54768550e+00  1.97255850e+00
  -1

Iteration 241: 0.8399404883384705
Iteration 242: 0.8593591451644897
Iteration 243: 0.8609640598297119
Iteration 244: 0.8599852323532104
Iteration 245: 0.8646277785301208
Iteration 246: 0.8265372514724731
Iteration 247: 0.8379403948783875
Iteration 248: 0.8425906896591187
Iteration 249: 0.8413389325141907
Iteration 250: 0.8188079595565796
Iteration 251: 0.8380677700042725
Iteration 252: 0.8485528230667114
Iteration 253: 0.8415624499320984
Iteration 254: 0.8258153200149536
Iteration 255: 0.8388980627059937
Iteration 256: 0.8213112950325012


### Unpaired VL sequences

In [26]:
# fine-tune for 1 epoch on the first 50,000 formatted VL training sequences, validating every
# 20 iterations on the first 5,000 validation sequences; no checkpoints are written (ckpt_dir=None)
train_model(model1900, vl_train_fmt[:50000], vl_val_fmt[:5000], n_epochs=1, ckpt_dir=None, val_freq=20)

Running epoch: 1
Iteration 1: 2.373396396636963
Iteration 2: 2.352776050567627
Iteration 3: 2.3582653999328613
Iteration 4: 2.3323936462402344
Iteration 5: 2.29256272315979
Iteration 6: 2.2453246116638184
Iteration 7: 2.2492876052856445
Iteration 8: 2.21063232421875
Iteration 9: 2.187523365020752
Iteration 10: 2.1535391807556152
Iteration 11: 2.135596752166748
Iteration 12: 2.136962413787842
Iteration 13: 2.108055591583252
Iteration 14: 2.102959156036377
Iteration 15: 2.0731234550476074
Iteration 16: 2.0773251056671143
Iteration 17: 2.0528900623321533
Iteration 18: 2.042029619216919
Iteration 19: 2.047182083129883
Iteration 20: 1.9991891384124756
Validation set loss: 1.9893437623977661
Iteration 21: 2.0254154205322266
Iteration 22: 1.961630940437317
Iteration 23: 1.9676557779312134
Iteration 24: 1.9528197050094604
Iteration 25: 1.913130521774292
Iteration 26: 1.891546368598938
Iteration 27: 1.8947314023971558
Iteration 28: 1.890244483947754
Iteration 29: 1.864936113357544
Iteration 30: