In [None]:
!pip install -Uqq fastbook
!pip install -Uqq wandb

In [None]:
import math
import os
import pickle
from pathlib import Path
from typing import List

import fastbook
import sentencepiece as spm
import wandb

from fastbook import *
from fastai.text.all import *
from fastai.callback.wandb import *

In [None]:
# Authenticate this kernel with Weights & Biases before any run is started;
# no API key is hard-coded in the notebook.
wandb.login()

In [None]:
# Settings for the W&B run started before training.
# `entity` is read from the WANDB_ENTITY environment variable instead of a
# '<your account name>' placeholder that had to be hand-edited. If the
# variable is unset it is None, and wandb uses the logged-in user's default entity.
wandb_init_kwargs = {
    'reinit': True,
    'project': "ml-base",
    'entity': os.environ.get('WANDB_ENTITY'),
    'group': 'ml-base-001',
    'name': 'ml-base-001-001',
    'notes': 'Finetuning ml-base with fastai',
    'tags': ['malayalam', 'ml-base', 'fastai'],
}

In [None]:
# Root directory holding the pretrained tokenizer/LM files and outputs.
# Override it with the NLP_MALAYALAM_PATH environment variable instead of
# editing the notebook. Later cells concatenate sub-paths directly onto this
# string (e.g. LCL_PATH + "models/..."), so a trailing separator is enforced.
LCL_PATH = os.path.join(os.environ.get("NLP_MALAYALAM_PATH", "/nlp-for-malyalam/"), "")

# Feature Engineering

In [None]:
# Filtered Malayalam Common Crawl dump (2020-10), downloaded as a tarball.
URL_MAL = ('https://calicut.qburst.in/commoncrawl/malayalam/'
           '2020-10/malayalam_filtered_html_body.tar.gz')

In [None]:
# Download (if needed) and extract the corpus; `path` is the extracted folder.
path = untar_data(url=URL_MAL)

In [None]:
# List the contents of the extracted dataset directory.
path.ls()

In [None]:
# Collect every text file in the corpus folder and show the collection.
files = get_text_files(path=path)
files

In [None]:
# Peek at the beginning of the first document.
# `read_text` closes the file; the original `open().read()` left the handle
# for garbage collection. The encoding is explicit instead of the platform default.
# NOTE(review): assumes the corpus files are UTF-8 encoded; confirm.
txt = files[0].read_text(encoding='utf-8')
txt[:75]

In [None]:
class MalyalamTokenizer(BaseTokenizer):
    """SentencePiece tokenizer for Malayalam text, usable as a fastai `tok`.

    fastai's tokenization pipeline invokes a tokenizer as `tok(items)` and
    expects an iterable with one token list per input string. The class
    previously defined only `tokenizer()`, which that pipeline never calls,
    so the inherited space-splitting `__call__` would have run instead.
    `__call__` now encodes with the SentencePiece model.

    Args:
        split_char: kept for `BaseTokenizer` compatibility; not used by the
            SentencePiece encoding.
        lang: language code stored on the instance (default 'ml').
        model_path: path to the SentencePiece `.model` file. When None, it
            defaults to LCL_PATH + "models/tokenizer/malyalam_lm.model".
    """
    def __init__(self, split_char=' ', lang: str = 'ml', model_path: str = None):
        self.split_char = split_char
        self.lang = lang
        if model_path is None:
            # Resolved at construction time (not as a default argument) so the
            # current value of LCL_PATH is used.
            model_path = LCL_PATH + "models/tokenizer/malyalam_lm.model"
        self.model_path = model_path
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_path)
        # NOTE(review): fastai may tokenize in worker processes, which pickles
        # this object; confirm the loaded SentencePieceProcessor survives that.

    def tokenizer(self, t: str) -> List[str]:
        """Encode a single string into a list of SentencePiece pieces."""
        return self.sp.EncodeAsPieces(t)

    def __call__(self, items):
        """Lazily yield one list of SentencePiece pieces per string in `items`."""
        return (self.tokenizer(t) for t in items)

In [None]:
# Seed Python/NumPy/PyTorch RNGs so that shuffling and later model
# initialisation are repeatable. RandomSplitter additionally gets its own seed.
SEED = 42
set_seed(SEED)

# DataLoader hyper-parameters. Other batch sizes previously considered: 16, 24, 48, 128.
bs = 64
SEQ_LEN = 80       # tokens per language-model sequence
MAX_VOCAB = 9998   # cap on the vocab fastai builds from the corpus

# `get_text_files` already takes the source path DataBlock passes to
# `get_items`, so it needs no `partial` wrapper; the alias name is kept.
get_mal = get_text_files

wiki_ml = DataBlock(
    blocks=TextBlock.from_folder(path, is_lm=True, seq_len=SEQ_LEN, max_vocab=MAX_VOCAB, extensions='.txt'),
    get_items=get_mal,
    splitter=RandomSplitter(0.1, seed=SEED),
)

dls_lm = wiki_ml.dataloaders(path, path=path, bs=bs, seq_len=SEQ_LEN)

In [None]:
# Save the DataLoaders vocab so the same token -> id mapping can be reused.
# Path(...) joining avoids the double slash that LCL_PATH + '/data/...'
# produced, and the parent folder is created if missing.
vocab_path = Path(LCL_PATH) / 'data' / 'ml_001.vocab.pkl'
vocab_path.parent.mkdir(parents=True, exist_ok=True)
# `with` guarantees the file is flushed and closed; the bare `open(...)` did not.
with vocab_path.open('wb') as f:
    pickle.dump(dls_lm.vocab, f)

In [None]:
# Sanity-check the tokenized data: display two sample sequences from a batch.
dls_lm.show_batch(max_n=2)

In [None]:
# AWD-LSTM language-model config; n_hid is overridden from fastai's default.
# NOTE(review): presumably 1150 matches the hidden size of the pretrained
# checkpoint below; confirm, since a mismatch fails to load the weights.
config = awd_lstm_lm_config.copy()
config['n_hid'] = 1150

# [pretrained weights, pretrained vocab] of the Malayalam ULMFiT model.
# NOTE(review): fastai appends '.pth' / '.pkl' to these names when loading,
# and reads the second as a pickled vocab list; confirm those files exist.
lm_fns = [LCL_PATH + 'models/language-model/ULMFiT/third_ml_lm', LCL_PATH + 'models/tokenizer/malyalam_lm.vocab']

# Set True to train in mixed precision. This replaces the commented-out
# duplicate of the learner construction with `.to_fp16()`.
USE_FP16 = False

learn_lm = language_model_learner(dls_lm, AWD_LSTM, config=config, pretrained_fnames=lm_fns, drop_mult=0.3)
if USE_FP16:
    learn_lm = learn_lm.to_fp16()

In [None]:
# Start the W&B run configured in `wandb_init_kwargs`; the WandbCallback
# passed to each fit call below logs into it.
wandb.init(**wandb_init_kwargs)

In [None]:
# Stage 1: freeze all but the last layer group before the first short fine-tune.
learn_lm.freeze()

In [None]:
# Stage 1 training: 3 one-cycle epochs on the frozen model, logged to W&B
# (no prediction samples or model artifacts uploaded).
frozen_cbs = [WandbCallback(log_preds=False, log_model=False)]
learn_lm.fit_one_cycle(n_epoch=3, lr_max=5e-5, cbs=frozen_cbs)

In [None]:
# Learning-rate range test.
# NOTE(review): this runs while the model is still frozen (unfreeze is the
# next cell), and the following fit call hard-codes lr_max=1e-3 rather than
# using the suggestion. If this is meant to pick the LR for the unfrozen
# stage, run it after `learn_lm.unfreeze()`.
learn_lm.lr_find()

In [None]:
# Stage 2: make every layer group trainable for full fine-tuning.
learn_lm.unfreeze()

In [None]:
# Stage 2 training: 5 one-cycle epochs on the unfrozen model, logged to W&B.
unfrozen_cbs = [WandbCallback(log_preds=False, log_model=False)]
learn_lm.fit_one_cycle(n_epoch=5, lr_max=1e-3, cbs=unfrozen_cbs)

# Evaluate

In [None]:
# Evaluate on the validation split. `validate()` returns the loss first,
# followed by one value per entry of `learn_lm.metrics`, in the same order.
val_res = learn_lm.validate()

val_res_d = {'loss': val_res[0]}
for metric, value in zip(learn_lm.metrics, val_res[1:]):
    val_res_d[metric.name] = value

# Perplexity = exp(mean loss). NOTE(review): assumes the learner's loss is
# the default mean (natural-log) cross-entropy of language_model_learner.
# Guard against overflow for a diverged model.
val_loss = val_res_d['loss']
try:
    val_res_d['perplexity'] = math.exp(val_loss)
except OverflowError:
    val_res_d['perplexity'] = float('inf')

val_res_d

In [None]:
# Raw validation-set (ds_idx=1) predictions, targets and per-item losses.
# NOTE(review): `preds` presumably holds a full-vocabulary score row per
# token, so this can need a lot of memory on a large validation split.
preds, targs, losses = learn_lm.get_preds(ds_idx=1, with_loss=True)
print(preds.shape, targs.shape, losses.shape)
mean_loss = losses.mean()
val_acc = accuracy(preds, targs)
print(mean_loss, val_acc)

In [None]:
# Close the W&B run started before training so later cells are not logged to it.
wandb.finish()

# Predict

In [None]:
# Continue a (SentencePiece-segmented) Malayalam prompt by 10 tokens.
learn_lm.predict(text='മലയാള ികളായ ▁വിമാന യാത്ര ക്കാര', n_words=10)

In [None]:
# Generation settings: prompt, tokens to generate per sample, number of samples.
TEXT, N_WORDS, N_SENTENCES = "ബംഗാളിലെ ▁ഭരണം ▁കമ്പനി", 40, 2

In [None]:
# Sample N_SENTENCES continuations of TEXT at temperature 0.75, one per line.
samples = [learn_lm.predict(TEXT, N_WORDS, temperature=0.75) for _ in range(N_SENTENCES)]
print("\n".join(samples))

# Save Model

In [None]:
# Save only the fine-tuned encoder, e.g. for a downstream classifier.
learn_lm.save_encoder(f'{LCL_PATH}/data/fine_tuned_enc_001')

In [None]:
# Checkpoint the full language model together with its optimizer state.
ckpt_name = LCL_PATH + '/models/language-model/ml-001epoch'
learn_lm.save(ckpt_name, with_opt=True)

# Reload Model

In [None]:
# Restore the checkpointed model and optimizer state into `learn_lm`.
ckpt_name = LCL_PATH + '/models/language-model/ml-001epoch'
learn_lm.load(ckpt_name, with_opt=True)