In [None]:
#export
from local.torch_basics import *
from local.test import *
from local.core import *
from local.layers import *
from local.data.all import *
from local.notebook.showdoc import show_doc
from local.optimizer import *
from local.learner import *
from local.metrics import *
from local.text.core import *
from local.text.data import *
from local.text.models.core import *
from local.text.models.awdlstm import *
from local.callback.rnn import *
from local.callback.all import *

In [None]:
#default_exp text.learner

# Text learner

> Convenience functions to easily create a `Learner` for text applications

In [None]:
#export
def match_embeds(old_wgts, old_vocab, new_vocab):
    "Remap the embedding rows (and decoder bias, if present) in `old_wgts` from `old_vocab` order to `new_vocab` order."
    dec_bias = old_wgts.get('1.decoder.bias', None)
    emb = old_wgts['0.encoder.weight']
    has_bias = dec_bias is not None
    n_new = len(new_vocab)
    # Tokens that are missing from `old_vocab` receive the mean over the old rows
    emb_mean = emb.mean(0)
    remapped_emb = emb.new_zeros((n_new, emb.size(1)))
    if has_bias:
        bias_mean = dec_bias.mean(0)
        remapped_bias = dec_bias.new_zeros((n_new,))
    # Use the vocab's own `o2i` lookup when it has one, otherwise build token -> index from the sequence
    if hasattr(old_vocab, 'o2i'): tok2idx = old_vocab.o2i
    else: tok2idx = {tok:pos for pos,tok in enumerate(old_vocab)}
    for row,tok in enumerate(new_vocab):
        src = tok2idx.get(tok, -1)
        if src >= 0:
            remapped_emb[row] = emb[src]
            if has_bias: remapped_bias[row] = dec_bias[src]
        else:
            remapped_emb[row] = emb_mean
            if has_bias: remapped_bias[row] = bias_mean
    # The results are written back into the dict that was passed in, which is also the return value
    old_wgts['0.encoder.weight'] = remapped_emb
    if '0.encoder_dp.emb.weight' in old_wgts: old_wgts['0.encoder_dp.emb.weight'] = remapped_emb.clone()
    old_wgts['1.decoder.weight'] = remapped_emb.clone()
    if has_bias: old_wgts['1.decoder.bias'] = remapped_bias
    return old_wgts

In [None]:
# Without bias: new vocab ['a', 'c', 'd', 'b'] should pick old rows 0, 2, the mean ('d' is unseen) and 1
wgts = {'0.encoder.weight': torch.randn(5,3)}
new_wgts = match_embeds(wgts.copy(), ['a', 'b', 'c'], ['a', 'c', 'd', 'b'])
old,new = wgts['0.encoder.weight'],new_wgts['0.encoder.weight']
expected_rows = [old[0], old[2], old.mean(0), old[1]]
for row,expected in enumerate(expected_rows): test_eq(new[row], expected)

In [None]:
#With bias
# New vocab ['a', 'c', 'd', 'b'] should pick old entries 0, 2, the mean ('d' is unseen) and 1,
# for both the embedding rows and the decoder bias
wgts = {'0.encoder.weight': torch.randn(5,3), '1.decoder.bias': torch.randn(5)}
new_wgts = match_embeds(wgts.copy(), ['a', 'b', 'c'], ['a', 'c', 'd', 'b'])
old_w,new_w = wgts['0.encoder.weight'],new_wgts['0.encoder.weight']
old_b,new_b = wgts['1.decoder.bias'],  new_wgts['1.decoder.bias']
expected_w = [old_w[0], old_w[2], old_w.mean(0), old_w[1]]
for row,expected in enumerate(expected_w): test_eq(new_w[row], expected)
expected_b = [old_b[0], old_b[2], old_b.mean(0), old_b[1]]
for row,expected in enumerate(expected_b): test_eq(new_b[row], expected)

In [None]:
#export
@delegates(Learner.__init__)
class RNNLearner(Learner):
    "Basic class for a `Learner` in NLP."
    def __init__(self, model, dbunch, loss_func, alpha=2., beta=1., **kwargs):
        "Forward `model`, `dbunch`, `loss_func` and `kwargs` to `Learner`, then add an `RNNTrainer(alpha, beta)` callback."
        # NOTE(review): `language_model_learner` and `text_classifier_learner` in this file call
        # `RNNLearner(dbunch, model, ...)`, the reverse of the parameter names used here. The first
        # three arguments are forwarded positionally, so confirm the order against `Learner.__init__`.
        super().__init__(model, dbunch, loss_func, **kwargs)
        self.add_cb(RNNTrainer(alpha=alpha, beta=beta))

    def save_encoder(self, file):
        "Save the encoder to `self.path/self.model_dir/file`"
        encoder = get_model(self.model)[0]
        # Unwrap a wrapper that exposes the real module as `.module` so the wrapped module's state dict is saved
        if hasattr(encoder, 'module'): encoder = encoder.module
        torch.save(encoder.state_dict(), join_path_file(file,self.path/self.model_dir, ext='.pth'))

    def load_encoder(self, file, device=None):
        "Load the encoder `file` from `self.path/self.model_dir`, freeze the learner and return it."
        encoder = get_model(self.model)[0]
        # Default to the device of the data when none is given
        if device is None: device = self.dbunch.device
        if hasattr(encoder, 'module'): encoder = encoder.module
        encoder.load_state_dict(torch.load(join_path_file(file,self.path/self.model_dir, ext='.pth'), map_location=device))
        self.freeze()
        return self

    #TODO: When access is easier, grab new_vocab from self.dbunch
    def load_pretrained(self, wgts_fname, vocab_fname, new_vocab, strict=True):
        """Load a pretrained model and adapt it to the data vocabulary.

        `vocab_fname` is a pickled vocabulary and `wgts_fname` a `torch.save`d state dict (optionally nested
        under a `'model'` key). The embeddings are remapped to `new_vocab` with `match_embeds` before being
        loaded with `strict` passed to `load_state_dict`. The learner is then frozen and returned.
        """
        # NOTE: unpickling can execute arbitrary code, so only load vocab files from a trusted source.
        # The context manager closes the file handle, which the bare `open(...)` call used to leak.
        with open(vocab_fname, 'rb') as f: old_vocab = pickle.load(f)
        # Keep every tensor on the storage it is deserialized to (no device remapping)
        wgts = torch.load(wgts_fname, map_location = lambda storage,loc: storage)
        if 'model' in wgts: wgts = wgts['model'] #Just in case the pretrained model was saved with an optimizer
        wgts = match_embeds(wgts, old_vocab, new_vocab)
        self.model.load_state_dict(wgts, strict=strict)
        self.freeze()
        return self

In [None]:
#export
from local.text.models.core import _model_meta

In [None]:
#export
#TODO: When access is easier, grab vocab from dbunch
@delegates(Learner.__init__)
def language_model_learner(dbunch, arch, vocab, config=None, drop_mult=1., pretrained=True, pretrained_fnames=None, **kwargs):
    "Create a `Learner` with a language model from `dbunch` and `arch`."
    model = get_language_model(arch, len(vocab), config=config, drop_mult=drop_mult)
    meta = _model_meta[arch]
    # NOTE(review): `RNNLearner.__init__` in this file names its first two parameters (model, dbunch);
    # confirm that the positional order used here is the intended one.
    learn = RNNLearner(dbunch, model, loss_func=CrossEntropyLossFlat(), splitter=meta['split_lm'], **kwargs)
    #TODO: add backward
    #url = 'url_bwd' if data.backwards else 'url'
    # Nothing to load: return the learner as created
    if not (pretrained or pretrained_fnames): return learn
    exts = ['pth', 'pkl']
    if pretrained_fnames is None:
        # No explicit files were given, so fall back to the archive registered for this architecture
        if 'url' not in meta:
            warn("There are no pretrained weights for that architecture yet!")
            return learn
        model_path = untar_data(meta['url'] , c_key=ConfigKey.Model)
        fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in exts]
    else:
        # User-supplied stems (weights, then vocab) are resolved inside the learner's model directory
        model_dir = learn.path/learn.model_dir
        fnames = [model_dir/f'{fn}.{ext}' for fn,ext in zip(pretrained_fnames, exts)]
    return learn.load_pretrained(*fnames, vocab)

In [None]:
#export
#TODO: When access is easier, grab vocab from dbunch
@delegates(Learner.__init__)
def text_classifier_learner(dbunch, arch, vocab, bptt=72, config=None, pretrained=True, drop_mult=1.,
                            lin_ftrs=None, ps=None, **kwargs):
    "Create a `Learner` with a text classifier from `dbunch` and `arch`."
    vocab_sz = len(vocab)
    n_out = get_c(dbunch)
    model = get_text_classifier(arch, vocab_sz, n_out, bptt=bptt, config=config,
                                drop_mult=drop_mult, lin_ftrs=lin_ftrs, ps=ps)
    meta = _model_meta[arch]
    # NOTE(review): `RNNLearner.__init__` in this file names its first two parameters (model, dbunch);
    # confirm that the positional order used here is the intended one.
    learn = RNNLearner(dbunch, model, loss_func=CrossEntropyLossFlat(), splitter=meta['split_clas'], **kwargs)
    if not pretrained: return learn
    if 'url' not in meta:
        warn("There are no pretrained weights for that architecture yet!")
        return learn
    model_path = untar_data(meta['url'], c_key=ConfigKey.Model)
    # First the weights file, then the pickled vocab, as `load_pretrained` expects them
    fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
    # `strict=False` is forwarded to `load_state_dict`; presumably the checkpoint lacks the
    # classifier head's weights (confirm)
    learn = learn.load_pretrained(*fnames, vocab, strict=False)
    learn.freeze()
    return learn

## Export -

In [None]:
#hide
# Run the notebook-to-script export; the recorded output of this cell lists each notebook converted.
# `all_fs=True` presumably processes every notebook rather than only this one (confirm against `notebook2script`).
from local.notebook.export import notebook2script
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 01a_torch_core.ipynb.
Converted 02_script.ipynb.
Converted 03_dataloader.ipynb.
Converted 04_transform.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_transforms.ipynb.
Converted 07_vision_core.ipynb.
Converted 08_pets_tutorial.ipynb.
Converted 09_vision_augment.ipynb.
Converted 11_layers.ipynb.
Converted 11a_vision_models_xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 14_callback_schedule.ipynb.
Converted 15_callback_hook.ipynb.
Converted 16_callback_progress.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 18_callback_fp16.ipynb.
Converted 19_callback_mixup.ipynb.
Converted 20_metrics.ipynb.
Converted 21_tutorial_imagenette.ipynb.
Converted 22_vision_learner.ipynb.
Converted 23_tutorial_transfer_learning.ipynb.
Converted 30_text_core.ipynb.
Converted 31_text_data.ipynb.
Converted 32_text_models_awdlstm.ipynb.
Converted 33_text_models_core.ipynb.
Converted 34_callback_rnn.ipynb.