## WideResNet 01

Load the `wideresnet-01_3` checkpoint and score the `eval1_unlabelled.tsv` evaluation set.

In [23]:
from torch.utils.data import DataLoader
from helpers import TextTransform, SearchEngineDataset, pad_collate
from fastai.basic_train import Learner, LearnerCallback
from fastai.basic_data import DataBunch
from models import WideResNetEmbedding
from config import WideResnetConfig
from fastai.basic_data import DatasetType

In [44]:
# All run settings come from WideResnetConfig; these names are shortcuts used below.
wrn_config = WideResnetConfig()

# Data files.
TRN_PATH, VAL_PATH = wrn_config.trn_50_20_data_path, wrn_config.val_50_20_data_path
# NOTE(review): VAL_PATH is not read anywhere in this notebook; the loaders use EVAL_01.

# DataLoader settings.
BS = wrn_config.bs
TRN_WORKERS, VAL_WORKERS = wrn_config.trn_workers, wrn_config.val_workers

# Model / embedding settings.
EMB_SIZE, VOCAB_SIZE = wrn_config.emb_dim, wrn_config.vocab_size
PRETRAINED_PATH = wrn_config.emb_pretrained

# Evaluation file that gets scored at the end of the notebook.
EVAL_01 = wrn_config.eval_data_path_01

In [45]:
# Sanity check: show the evaluation-file path resolved from WideResnetConfig.
EVAL_01

'/media/saqib/ni/Projects/Microsoft/AI_Challenge_18/data/eval1_unlabelled.tsv'

In [46]:
# Sanity check: show the training-file path resolved from WideResnetConfig.
TRN_PATH

'/media/saqib/ni/Projects/Microsoft/AI_Challenge_18/data/data_0.5_20_trn.tsv'

In [47]:
# Token-to-id mapping used to numericalise query/passage text.
# TODO(review): this itos.pkl path is hard-coded and relative, unlike the
# other paths, which come from WideResnetConfig.
text_transform = TextTransform('../../data/train_lm_data/itos.pkl')


def _make_search_dataset(tsv_path):
    """Build a SearchEngineDataset over `tsv_path` with the shared column layout."""
    # A fresh column list per call, exactly as the original inline calls did.
    return SearchEngineDataset(tsv_path,
                               ['query', 'passage', 'label'],
                               transform=text_transform.text_to_ints)


trn_dataset = _make_search_dataset(TRN_PATH)
# The "validation" dataset is the eval file (EVAL_01), not VAL_PATH, so that
# get_preds() on the validation split scores the eval set.
# NOTE(review): the file name says "unlabelled" but a 'label' column is still
# declared; confirm how SearchEngineDataset handles a missing label column.
val_dataset = _make_search_dataset(EVAL_01)

In [48]:
# Training loader: reshuffled every epoch, padded per batch by pad_collate.
trn_dl = DataLoader(trn_dataset, batch_size=BS, shuffle=True,
                    num_workers=TRN_WORKERS, collate_fn=pad_collate)

# Validation/eval loader: shuffle=False keeps predictions in dataset order.
# It uses the validation worker count from the config (VAL_WORKERS), not
# the training one.
val_dl = DataLoader(val_dataset, batch_size=BS, shuffle=False,
                    num_workers=VAL_WORKERS, collate_fn=pad_collate)

In [49]:
# Wrap the loaders for fastai. The eval set is already served as the
# validation split (val_dl), so no test loader is passed. No cell in this
# notebook defines `tst_dl`; referencing it raises NameError on a fresh kernel.
databunch = DataBunch(train_dl=trn_dl, valid_dl=val_dl, collate_fn=pad_collate)

In [50]:
model_name = 'wideresnet-01'

# Architecture hyper-parameters come from WideResnetConfig. They must match
# the architecture saved in the checkpoint that is loaded into this learner.
wrn_model = WideResNetEmbedding(
    vocab_size=VOCAB_SIZE,
    pretrained_wts_pth=PRETRAINED_PATH,
    emb_dim=EMB_SIZE,
    n_grps=wrn_config.n_grps,
    N=wrn_config.n_blocks,
    k=wrn_config.widening,
)

# NOTE(review): no loss_func is given; this learner is used only for inference here.
learner = Learner(databunch, wrn_model)

In [51]:
# Restore saved weights; the checkpoint name is model_name plus a "_3" suffix.
learner = learner.load(f'{model_name}_3')

In [None]:
# Score the validation split, which holds the eval set (EVAL_01).
# fastai returns a [predictions, targets] pair.
preds = learner.get_preds(ds_type=DatasetType.Valid)

In [42]:
# Inspect a window of the raw model outputs (element 0 of get_preds' result).
# NOTE(review): every value shown below is 1.2402, meaning the model emits a
# constant score. Investigate (checkpoint/architecture mismatch, collapsed
# training) before using these predictions.
pred_scores = preds[0]
pred_scores[200:300]

tensor([[1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1.2402],
        [1