In [22]:
import torch
from torch import nn, optim
import RVQE
import numpy as np

In [8]:
# Restrict PyTorch's intra-op CPU parallelism to 2 threads
# (presumably so several runs can share one machine — TODO confirm).
torch.set_num_threads(2)

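Our goal is to create an RNN (or LSTM) with roughly 837 parameters and compare it against the QRNN on the DNA long-sequence task implemented in RVQE.
In either case, the batch size is 128. The configuration used below, a single-layer RNN with 22 hidden units, has 888 trainable parameters.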

In [9]:
def dataset_t(length, batch_size=128):
    """Construct the RVQE "dna" dataset for sentences of ``length`` tokens.

    Args:
        length: forwarded as ``sentence_length``.
        batch_size: forwarded as ``batch_size`` (default 128, as used throughout
            this notebook).

    Returns:
        The object built by ``RVQE.datasets.all_datasets["dna"]``; the training
        loop draws batches from it with ``next_batch``.
    """
    # NOTE(review): the leading positional 0 and num_shards=0 are passed through
    # unchanged; their meaning is defined inside RVQE — confirm there.
    return RVQE.datasets.all_datasets["dna"](0, num_shards=0, batch_size=batch_size, sentence_length=length)

In [10]:
def count_parameters(model):
    """Return the total number of scalar entries in the trainable parameters of ``model``.

    Parameters with ``requires_grad == False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def to_one_hot(labels, num_classes=2**3):
    """Map integer class labels to one-hot vectors of the default float dtype.

    Each label selects the matching row of a ``num_classes x num_classes``
    identity matrix, so the result has shape ``labels.shape + (num_classes,)``.
    """
    identity = torch.eye(num_classes)
    one_hot = identity[labels]
    return one_hot

In [11]:
# Torch seeds (one independent training run per seed) and the DNA sentence
# lengths to sweep over.
SEEDS = [
    9120, 2783, 2057, 6549, 3201, 7063, 5243, 3102, 5303, 5819,
    3693, 4884, 2231, 5514, 8850, 6861, 3106, 2378, 8697, 1821,
    9480, 8483, 1633, 9678, 6596, 4509, 8618, 9765, 6346, 2969,
]
LENGTHS = [5, 10, 20, 50, 100, 200, 500, 1000]

# RNN

In [12]:
# The "837" suffix is the target parameter count. One RNN layer with 22 hidden
# units actually has 888 trainable parameters with 8-wide one-hot I/O:
#   RNN: 22*8 + 22*22 + 2*22 = 704, read-out Linear: 22*8 + 8 = 184.
HIDDEN_SIZE_837, NUM_LAYERS_837 = 22, 1
ARGS_837 = (HIDDEN_SIZE_837, NUM_LAYERS_837)  # positional args for SimpleRNN

class SimpleRNN(nn.Module):
    """
        This is a very simplistic RNN setup. We found a single layer performs
        much better than two layers with a smaller hidden size.
        Without doubt one can improve the performance of this model.
        Yet we didn't optimize the QRNN setup for the task at hand either.

        Architecture: a plain ``nn.RNN`` (batch_first) over inputs of width
        ``io_size``, followed by a linear read-out back to ``io_size`` logits at
        every time step.

        Args:
            hidden_size: number of hidden units per RNN layer.
            num_layers: number of stacked RNN layers.
            io_size: input feature width and number of output logits
                (default 2**3 = 8).
    """
    def __init__(self, hidden_size: int, num_layers: int, io_size=2**3):
        super().__init__()

        self.rnn = nn.RNN(input_size=io_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.lin = nn.Linear(hidden_size, io_size)

    def reset(self):
        """Re-initialize all weights in place.

        Hidden-to-hidden matrices get an orthogonal start, input-to-hidden
        matrices Xavier-uniform, RNN biases zero, and the read-out layer its
        PyTorch default initialization. The ``nn.init`` helpers modify the
        parameters without recording autograd history, so there is no need to
        go through ``param.data``. RNG consumption is the same as writing
        through ``.data``, so seeded runs reproduce.

        Raises:
            ValueError: if the RNN has a parameter this scheme does not cover.
        """
        self.lin.reset_parameters()
        for name, param in self.rnn.named_parameters():
            if "weight_hh" in name:
                # orthogonal recurrent matrix: a well-conditioned start
                nn.init.orthogonal_(param)
            elif "bias" in name:
                nn.init.zeros_(param)
            elif "weight_ih" in name:
                nn.init.xavier_uniform_(param)
            else:
                raise ValueError(f"cannot initialize {name}")

    @property
    def num_parameters(self):
        """Number of trainable parameters (RNN plus read-out layer)."""
        return count_parameters(self.rnn) + count_parameters(self.lin)

    def forward(self, sentence):
        """Return per-step logits.

        ``sentence`` is (batch, seq_len, io_size) because the RNN is built with
        ``batch_first=True``. The result is (batch, seq_len, io_size).
        """
        rnn_out, _ = self.rnn(sentence)
        return self.lin(rnn_out)

In [14]:
# Sanity check: trainable parameter count of the configuration above (displayed below).
SimpleRNN(*ARGS_837).num_parameters

888

In [42]:
def _train_one_seed(model_args: tuple, dataset, lr: float, seed: int, criterion, length: int,
                    max_steps: int, loss_threshold: float, log_every: int):
    """Train a freshly seeded SimpleRNN on ``dataset`` until its loss drops below ``loss_threshold``.

    Returns:
        The 1-based step at which the training loss first fell below
        ``loss_threshold``, ``-1`` if that did not happen within ``max_steps``
        steps, or ``np.nan`` if the loss became NaN. ``length`` is only used in
        log messages.
    """
    torch.manual_seed(seed)
    model = SimpleRNN(*model_args)
    model.reset()
    optimizer = optim.Adam(model.parameters(), lr=lr)   # this has been found to converge fastest

    for step in range(1, max_steps + 1):
        sentence, target = dataset.next_batch(0, RVQE.data.TrainingStage.TRAIN)

        # transform sentence to one-hot as in the qrnn case
        sentence = to_one_hot(RVQE.data.targets_for_loss(sentence))

        optimizer.zero_grad()
        out = model(sentence.float())

        # unlike the qrnn case, we use the entire output as loss
        # this gives the rnn an advantage!
        # CrossEntropyLoss expects the class dimension at index 1, so move it
        # there from the last axis.
        out = out.transpose(1, 2)
        target = RVQE.data.targets_for_loss(target)
        loss = criterion(out, target)

        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        if np.isnan(loss_value):
            print("nan")
            return np.nan

        if loss_value < loss_threshold:
            print(f"length {length} converged after {step} steps.")
            return step

        if step % log_every == 0:
            print(f"{step:06d} {loss_value:.2e}")

    print(f"length {length} did not converge after {max_steps} steps.")
    return -1


def run_model(lrs: list, lengths: list, seeds: list, results: dict, model_args: tuple,
              max_steps: int = 100 * 1000, loss_threshold: float = 0.001, log_every: int = 500):
    """Train one SimpleRNN per (learning rate, sentence length, seed) and record convergence.

    ``results`` is updated in place with the layout
    ``results[lr][length] == [[seed, steps], ...]``. ``steps`` is the step at which
    the training loss first fell below ``loss_threshold``, ``-1`` if it never did
    within ``max_steps`` steps, or ``np.nan`` if the loss became NaN. Seeds already
    recorded for a (lr, length) pair are skipped. An interrupted call can therefore
    be resumed by calling again with the same ``results`` dict.

    Args:
        lrs: Adam learning rates to try.
        lengths: sentence lengths passed to ``dataset_t``.
        seeds: torch seeds; one independent run per seed.
        results: dict updated in place (see above).
        model_args: positional arguments for ``SimpleRNN``.
        max_steps: training-step cap per run. The default of 100,000 steps amounts
            to the same number of samples seen as for the QRNN. The previous
            ``range(1, 100*1000)`` only ran 99,999 steps.
        loss_threshold: a run counts as converged once its loss drops below this value.
        log_every: print the loss every this many steps.
    """
    # The parameter count depends only on model_args, so report it once per call
    # instead of building a throwaway model for every length.
    print(f"created RNN with {SimpleRNN(*model_args).num_parameters} parameters")
    criterion = nn.CrossEntropyLoss()

    for lr in lrs:
        results_for_lr = results.setdefault(lr, {})

        for length in lengths:
            dataset = dataset_t(length)
            runs = results_for_lr.setdefault(length, [])
            finished_seeds = {seed for seed, _ in runs}

            for seed in seeds:
                if seed in finished_seeds:
                    continue
                steps = _train_one_seed(model_args, dataset, lr, seed, criterion, length,
                                        max_steps, loss_threshold, log_every)
                runs.append([seed, steps])
                finished_seeds.add(seed)

## Small Net

In [35]:
# Learning-rate sweep on the short lengths. Create the results dict only on a fresh
# kernel. If it already exists (e.g. after interrupting this cell), run_model skips
# the seeds that are already recorded and resumes.
if "lr_results_small" not in globals():
    lr_results_small = {}
run_model([.03, .01, .003], LENGTHS[:4], SEEDS[:5], lr_results_small, ARGS_837)

created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
000500 6.50e-01
001000 2.66e-01
001500 2.45e-01
002000 2.54e-01
002500 1.05e-02
003000 4.35e-03
003500 2.53e-03
004000 1.47e-03
length 50 converged after 4431 steps.
000500 5.80e-01
001000 5.81e-01
001500 6.13e-01
002000 1.36e-02
002500 5.21e-03
003000 2.56e-03
003500 1.79e-03
004000 1.12e-03
length 50 converged after 4067 steps.
000500 6.36e-01
001000 2.54e-01
001500 1.23e-01
002000 9.49e-03
002500 3.57e-03
003000 2.07e-03
003500 1.25e-03
length 50 converged after 3661 steps.
000500 5.72e-01
001000 2.40e-02
001500 4.81e-03
002000 2.28e-03
002500 1.34e-03
length 50 converged after 2761 steps.
000500 8.01e-01


In [36]:
lr_results_small

{0.03: {5: [[9120, 72], [2783, 59], [2057, 54], [6549, 58], [3201, 53]],
  10: [[9120, 189], [2783, 212], [2057, 265], [6549, 266], [3201, 205]],
  20: [[9120, 474], [2783, 1554], [2057, 1447], [6549, 529], [3201, 884]],
  50: [[9120, 8021], [2783, 5704], [2057, -1], [6549, -1], [3201, 14176]]},
 0.01: {5: [[9120, 414], [2783, 417], [2057, 374], [6549, 382], [3201, 386]],
  10: [[9120, 582], [2783, 573], [2057, 569], [6549, 554], [3201, 555]],
  20: [[9120, 803], [2783, 1118], [2057, 1109], [6549, 835], [3201, 1055]],
  50: [[9120, 3187], [2783, 1709], [2057, 5477], [6549, 5048], [3201, 8736]]},
 0.003: {5: [[9120, 1163],
   [2783, 1219],
   [2057, 1158],
   [6549, 1224],
   [3201, 1178]],
  10: [[9120, 1577], [2783, 1434], [2057, 1521], [6549, 1553], [3201, 1534]],
  20: [[9120, 1973], [2783, 2167], [2057, 2096], [6549, 1660], [3201, 1838]],
  50: [[9120, 4431], [2783, 4067], [2057, 3661], [6549, 2761], [3201, 4616]]}}

In [37]:
# Mean steps-to-convergence per learning rate: average over seeds for each sentence
# length, then over lengths. Runs that hit the step cap (-1) or diverged (nan) are
# excluded; `steps > 0` drops both, whereas `!= -1` let NaN poison the mean. Because
# failed runs are dropped, learning rates with failures look better than they are
# (e.g. lr=0.03 lost two length-50 runs).
{
    lr: np.mean([
        np.mean([steps for _seed, steps in runs if steps > 0])
        for runs in runs_by_length.values()
    ])
    for lr, runs_by_length in lr_results_small.items()
}

{0.03: 2641.1333333333337, 0.01: 1694.1499999999999, 0.003: 2141.55}

In [53]:
# Accumulated lr=0.01 results (lengths 5-500 at the time of saving).
# NOTE(review): this display was executed after the run_model(...) call further
# below, which is what fills `results_small`; on a fresh kernel run that cell first.
results_small

{0.01: {5: [[9120, 414], [2783, 417], [2057, 374], [6549, 382], [3201, 386]],
  10: [[9120, 582], [2783, 573], [2057, 569], [6549, 554], [3201, 555]],
  20: [[9120, 803], [2783, 1118], [2057, 1109], [6549, 835], [3201, 1055]],
  50: [[9120, 3187], [2783, 1709], [2057, 5477], [6549, 5048], [3201, 8736]],
  100: [[9120, 6362], [2783, 7928], [2057, 4586], [6549, 2477], [3201, 7326]],
  200: [[9120, 3752], [2783, 2476], [2057, 11872], [6549, 5613], [3201, 6212]],
  500: [[9120, 3283], [2783, 14294]]}}

In [None]:
# Full length sweep at lr=0.01. On a fresh kernel, start from the lr=0.01 runs of the
# sweep above. Each per-length list is copied so that new runs are not also appended
# to `lr_results_small`. When resuming an interrupted run, keep the existing results.
if "results_small" not in globals():
    results_small = {.01: {length: list(runs) for length, runs in lr_results_small[.01].items()}}
run_model([.01], LENGTHS, SEEDS[:5], results_small, ARGS_837)

created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
created RNN with 888 parameters
000500 1.32e+00
001000 1.47e+00
001500 6.34e-01
002000 3.27e-01
002500 2.89e-01


In [46]:
import pandas as pd

In [55]:
# Export the per-seed lr=0.01 results as a flat table:
# - "hparams/epoch" holds the convergence step (-1 = hit the step cap, nan = diverged).
# - "hparams/validate_best" is a 0.0 placeholder.
# index=False keeps the CSV to the four named columns. Passing index=None to the
# DataFrame constructor does not stop to_csv from writing the row index.
small_rnn_runs = pd.DataFrame(
    [
        [length, seed, steps, .0]
        for length, runs in results_small[.01].items()
        for seed, steps in runs
    ],
    columns=["sentence_length", "seed", "hparams/epoch", "hparams/validate_best"],
)
small_rnn_runs.to_csv("~/small-rnn.csv", index=False)

In [30]:
# NOTE(review): `results` is not defined anywhere in this notebook as saved. It
# presumably comes from a deleted cell, so this cell fails on a fresh kernel.
# Its non-converged runs are recorded as 15999 rather than -1, which suggests it was
# produced by an earlier version of the training loop — verify before using it.
results.items() 

dict_items([(5, [[9120, 48], [2783, 45], [2057, 43], [6549, 57], [3201, 53], [7063, 65], [5243, 44], [3102, 41], [5303, 53], [5819, 50], [3693, 47], [4884, 47], [2231, 49], [5514, 46], [8850, 58], [6861, 42], [3106, 40], [2378, 68], [8697, 44], [1821, 46], [9480, 47], [8483, 53], [1633, 53], [9678, 49], [6596, 43], [4509, 43], [8618, 46], [9765, 46], [6346, 44], [2969, 49]]), (10, [[9120, 386], [2783, 276], [2057, 285], [6549, 304], [3201, 387], [7063, 432], [5243, 216], [3102, 352], [5303, 298], [5819, 415], [3693, 262], [4884, 317], [2231, 386], [5514, 342], [8850, 436], [6861, 424], [3106, 294], [2378, 285], [8697, 331], [1821, 348], [9480, 299], [8483, 419], [1633, 374], [9678, 401], [6596, 412], [4509, 422], [8618, 385], [9765, 277], [6346, 602], [2969, 302]]), (20, [[9120, 15999], [2783, 15999], [2057, 15999], [6549, 15999], [3201, 15999], [7063, 15999], [5243, 15999], [3102, 15999], [5303, 15999], [5819, 15999], [3693, 15999], [4884, 15999], [2231, 15999], [5514, 15999], [8850, 