In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai import *
from fastai.text import *
from fastai.callbacks import CSVLogger, SaveModelCallback

import torch.backends.cudnn as cudnn

from train_search import DartsRnnSearch, ArchParamUpdate, PrintGenotype
from darts_callbacks import HidInit, Regu, SaveModel, ResumeModel, GcCol 

In [2]:
# Reproducibility: seed every RNG in play (Python, NumPy, PyTorch CPU and all CUDA devices).
seed = 135
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

gpu = 0
torch.cuda.set_device(gpu)
cudnn.enabled = True
# Deterministic cuDNN kernels are needed for run-to-run reproducibility.
# `benchmark` must be False: when True, cuDNN times several algorithms and may
# select a different (possibly non-deterministic) one on each run, defeating
# the seeding above. This trades some speed for reproducible results.
cudnn.benchmark = False
cudnn.deterministic = True

# Language-model data: training batches of bs_train sequences of length bptt;
# the validation loader is then re-batched to bs_val.
bptt = 35
bs_train, bs_val = 256, 64
dat = load_data('data', 'penn_db', bs=bs_train, bptt=bptt)
# NOTE(review): this only resizes the DataLoader; confirm the LM pre-loader
# underneath still yields contiguous streams at bs_val.
dat.valid_dl.batch_size = bs_val
# A second, independent load whose validation batches keep the *training*
# batch size; materialised once into a list and handed to ArchParamUpdate
# (as `search_dat`) for the architecture-parameter steps.
search_dat = list(iter(load_data('data', 'penn_db', bs=bs_train, bptt=bptt).valid_dl))

# Architecture-parameter optimiser settings (passed to ArchParamUpdate).
arch_lr = 3e-3
arch_wdecay = 1e-3

# Model sizes passed to DartsRnnSearch.
vocab_sz = len(dat.train_ds.x.vocab.itos)
emb_sz = 300
hid_sz = 300

# Weight decay for the network weights: used as Learner(wd=...) and also
# passed to ArchParamUpdate.
wdecay = 5e-7

# Dropout rates forwarded to DartsRnnSearch (presumably the DARTS/AWD-LSTM
# output / hidden / input-to-hidden / input / embedding dropouts — confirm
# against train_search.DartsRnnSearch).
dropout = 0.75
dropouth = 0.25
dropoutx = 0.75
dropouti = 0.2
dropoute = 0.

# Threshold handed to fastai's GradientClipping callback.
clip = 0.25

In [3]:
# Base names for this run's artifacts: the CSVLogger log file and the
# SaveModelCallback checkpoint share the same stem.
csv_name, model_name = 'train_search', 'train_search'
# To resume a previous search, uncomment and enable ResumeModel in the Learner cell:
# resume_model = 'train_search'

In [4]:
# DARTS RNN search network; the embedding size is also used as the cell input size (ninp).
search_model = DartsRnnSearch(
    emb_sz=emb_sz, vocab_sz=vocab_sz,
    ninp=emb_sz, nhid=hid_sz,
    dropout=dropout, dropouth=dropouth, dropoutx=dropoutx,
    dropouti=dropouti, dropoute=dropoute,
    bs_train=bs_train, bs_val=bs_val,
)

# Callback classes / partials handed to the Learner via `callback_fns`.
search_callback_fns = [
    HidInit,
    partial(ArchParamUpdate, search_dat=search_dat,
            arch_lr=arch_lr, arch_wdecay=arch_wdecay, wdecay=wdecay),
    Regu,
    PrintGenotype,
    partial(GradientClipping, clip=clip),
    partial(CSVLogger, filename=csv_name, append=True),
    GcCol,
    # To resume a previous search, define `resume_model` and enable:
    # partial(ResumeModel, name=resume_model),
]

learn = Learner(
    dat,
    search_model,
    opt_func=torch.optim.SGD,
    callback_fns=search_callback_fns,
    wd=wdecay,
)

# Report model size and the genotype decoded from the initial architecture weights.
total_params = sum(p.numel() for p in learn.model.parameters())
print('Total params:', total_params)
print(learn.model.genotype_parse())
# (disabled) learn.data.valid_dl = None

Total params: 4810000
Genotype(recurrent=[('sigmoid', 0), ('tanh', 1), ('identity', 1), ('relu', 3), ('sigmoid', 4), ('sigmoid', 3), ('identity', 6), ('identity', 1)], concat=range(1, 9))


In [5]:
# Run the search for 50 epochs at lr=20. SaveModelCallback writes a checkpoint
# named `model_name` whenever the validation loss improves.
# Alternative (disabled): SaveModel(learn, gap=1, name=model_name) from darts_callbacks.
learn.fit(
    epochs=50,
    lr=20,
    callbacks=[SaveModelCallback(learn, name=model_name)],
)

epoch,train_loss,valid_loss,time
0,9.377739,8.9144,05:15
1,9.0244,7.8746,05:13
2,8.465348,7.4553,05:16
3,8.251266,7.209497,05:15
4,8.137778,7.105574,05:14
5,8.029912,6.999118,05:15
6,7.972732,6.931962,05:13
7,7.771575,6.910496,05:16
8,7.630065,6.853231,05:14
9,7.517681,6.790563,05:15


Genotype(recurrent=[('identity', 0), ('identity', 0), ('relu', 2), ('sigmoid', 2), ('sigmoid', 2), ('tanh', 2), ('relu', 2), ('relu', 5)], concat=range(1, 9))
Better model found at epoch 0 with val_loss value: 8.914400100708008.
Genotype(recurrent=[('relu', 0), ('tanh', 1), ('identity', 2), ('relu', 3), ('sigmoid', 0), ('tanh', 2), ('tanh', 5), ('relu', 4)], concat=range(1, 9))
Better model found at epoch 1 with val_loss value: 7.874600410461426.
Genotype(recurrent=[('sigmoid', 0), ('tanh', 0), ('tanh', 2), ('identity', 0), ('sigmoid', 1), ('tanh', 2), ('identity', 5), ('relu', 4)], concat=range(1, 9))
Better model found at epoch 2 with val_loss value: 7.455300331115723.
Genotype(recurrent=[('sigmoid', 0), ('tanh', 0), ('tanh', 2), ('identity', 2), ('tanh', 2), ('tanh', 2), ('tanh', 4), ('sigmoid', 2)], concat=range(1, 9))
Better model found at epoch 3 with val_loss value: 7.209496974945068.
Genotype(recurrent=[('sigmoid', 0), ('tanh', 0), ('tanh', 2), ('identity', 2), ('tanh', 2), ('t