In [1]:
# Root directory (relative to the notebook's working directory) that the data path below is built from.
local_path = '../'

"""## Prepare fastai"""
from fastai import *
from fastai.text import *
from fastai.metrics import *
from fastai.callbacks.tensorboard import LearnerTensorboardWriter
from fastai.callbacks.misc import StopAfterNBatches
from fastai.callbacks.oversampling import OverSamplingCallback
import datetime
from pytz import timezone

# Run on GPU index 3 (machine-specific choice; change it to a device that exists on your host).
torch.cuda.set_device(3)

# Seed every RNG that influences the run. Seeding NumPy alone left PyTorch's
# generator (weight initialisation, dropout, ...) unseeded, so two runs of the
# notebook could not reproduce each other.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

"""## Prepare Dataset"""
local_project_path = local_path + 'data/proteinnet/'
# exist_ok=True creates the directory only when it is missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(local_project_path, exist_ok=True)
print('local_project_path:', local_project_path)

local_project_path: ../data/proteinnet/


In [3]:
"""## Tokenization"""
class dna_tokenizer(BaseTokenizer):
    """Tokenizer that explodes only the sequence field into single characters.

    The text is split on single spaces; the second-to-last field is treated as
    the sequence and the last field as a trailing marker.
    """

    def tokenizer(self, t):
        """Return the tokens of `t` with the sequence split per character.

        Fields before the sequence are kept as whole tokens, the sequence
        (second-to-last field) becomes one token per character, and the last
        field is appended unchanged. Raises IndexError when `t` contains fewer
        than two space-separated fields.
        """
        fields = t.split(' ')
        sequence, trailing = fields[-2], fields[-1]
        return fields[:-2] + list(sequence) + [trailing]
# Wrap the custom tokenizer class; empty rule lists and no special cases are passed explicitly.
tokenizer = Tokenizer(
    tok_func=dna_tokenizer,
    pre_rules=[],
    post_rules=[],
    special_cases=[],
)
# Processing pipeline: tokenize with BOS/EOS enabled, then numericalize with max_vocab=30000.
tokenize_step = TokenizeProcessor(tokenizer=tokenizer, include_bos=True, include_eos=True)
numericalize_step = NumericalizeProcessor(max_vocab=30000)
processor = [tokenize_step, numericalize_step]

In [4]:
# batch size
bs = 128
# Language-model data bunch built from the 'seq' column of test.csv, holding out 10% for validation.
data_lm = TextLMDataBunch.from_csv(local_project_path, 'test.csv',
                                   text_cols='seq', valid_pct=0.1, tokenizer=tokenizer,
                                   include_bos=True, include_eos=True, bs=bs)
# These sizes belong to the language-model bunch (data_lm); the labels
# previously said 'data_cls', which is a different object created later.
print('data_lm Training set size', len(data_lm.train_ds))
print('data_lm Validation set size', len(data_lm.valid_ds))

data_cls Training set size 99908
data_cls Validation set size 11101


In [5]:
# NOTE: unpickling can execute arbitrary code; only load a 'train_df' file
# that comes from a trusted source.
with open('train_df', 'rb') as train_df_file:  # the context manager closes the handle
    train_df_full = pickle.load(train_df_file)
# Work on an independent copy of just the sequence and GO-annotation columns.
train_df = train_df_full[['seq','GO']].copy()

In [6]:
def _go0005525_flag(row):
    """Return 'T' when the row's GO string contains 'GO:0005525', otherwise 'F'."""
    if 'GO:0005525' in row.GO:
        return 'T'
    return 'F'

# Row-wise binary label column derived from the GO annotation string.
train_df['is_GO0005525'] = train_df.apply(_go0005525_flag, axis=1)

In [7]:
# Boolean mask of the rows whose label is 'F'.
(train_df['is_GO0005525'] == 'F')

Unnamed: 0_level_0,seq,GO,is_GO0005525
accession,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
A0A060X6Z0,MPISSSSSSSTKSMRRAASELERSDSVTSPRFIGRRQSLIEDARKE...,GO:0070852 GO:0043204 GO:0004511,F
A0A068FIK2,MEVGGGSEECCVKVAVHVRPLIGDEKVQGCKDCVTVIPGKPQVQIG...,GO:0055028 GO:0005737,F
A0A075F932,MVSESHHEALAAPPATTVAAAPPSNVTEPASPGGGGGKEDAFSKLK...,GO:0048609 GO:0046883,F
A0A078CGE6,MARQMTSSQFHKSKTLDNKYMLGDEIGKGAYGRVYIGLDLENGDFV...,GO:0005730 GO:0004674 GO:0046777 GO:0051302,F
A0A086F3E3,MTKGRLEAFSDGVLAIIITIMVLELKVPEGSSWASLQPILPRFLAY...,GO:0022841 GO:0071805,F


In [None]:
# Write the labelled frame to train_df.csv; index=False leaves the index out of the file.
train_df.to_csv('train_df.csv', index=False)

In [9]:
# Number of rows labelled 'F'.
(train_df['is_GO0005525'] == 'F').sum()

66550

In [None]:
# bs = 512
# data_cls = (TextList.from_df(train_df, path = local_project_path, cols='seq', vocab=data_lm.vocab, processor=processor)
#                     .split_by_rand_pct(0.10)
#                    .label_from_df(cols='is_GO0005525', label_delim=' ')
#                    .databunch(bs=bs))

In [None]:
bs = 512
# Give the numericalizer the language-model vocab explicitly. A
# NumericalizeProcessor created without a vocab builds its own from the
# training texts (fastai v1 behaviour - verify for your version), which would
# leave token ids out of step with the encoder loaded later.
cls_processor = [TokenizeProcessor(tokenizer=tokenizer, include_bos=True, include_eos=True),
                 NumericalizeProcessor(vocab=data_lm.vocab, max_vocab=30000)]
# 'is_GO0005525' holds exactly one label per row ('T' or 'F'), so it is read
# as a plain single-label column. Passing label_delim makes fastai build a
# multi-label (one-hot) target, which does not fit the accuracy / arg-max F1
# metrics, OverSamplingCallback and ClassificationInterpretation used with
# this data elsewhere in the notebook.
data_cls = (TextList.from_csv('./', 'train_df.csv', cols='seq', vocab=data_lm.vocab, processor=cls_processor)
                    .split_by_rand_pct(0.10)
                    .label_from_df(cols='is_GO0005525')
                    .databunch(bs=bs))

In [None]:
# Print the size of the classification training split, then the validation split.
for split in (data_cls.train_ds, data_cls.valid_ds):
    print(len(split))

In [None]:
# Inspect the text of the first training item.
data_cls.train_ds.x[0].text

In [None]:
# Inspect the label of the first training item.
data_cls.train_ds.y[0]

In [None]:
# # batch size
# bs = 256
# data_cls = TextClasDataBunch.from_csv(local_project_path, 'uniprot_sprot_exp_go_F.csv',
#                                    text_cols ='seq', valid_pct= 0.1, tokenizer=tokenizer,
#                                    include_bos= True, include_eos=True, classes='labels', bs=bs)
# print('data_cls Training set size', len(data_lm.train_ds))
# print('data_cls Validation set size', len(data_lm.valid_ds)) 

In [None]:
# Display a sample batch from the classification data.
data_cls.show_batch()

In [None]:
# Number of entries in the language-model vocabulary.
len(data_lm.vocab.itos)

In [None]:
# acc_02 = partial(accuracy_thresh, thresh=0.5)
# f_score = partial(fbeta, thresh=0.5, beta=1)
from sklearn.metrics import f1_score
@np_func
def f1(inp, targ):
    """F1 score of the arg-max predictions in `inp` against `targ`.

    The index of the largest value along the last axis of `inp` is taken as
    the predicted class and passed, together with `targ`, to sklearn's
    `f1_score` (called as `f1_score(targ, predictions)`).
    """
    predicted = np.argmax(inp, axis=-1)
    return f1_score(targ, predicted)

In [None]:
# Metrics handed to the learner.
cls_metrics = [accuracy, f1]
# AWD_LSTM text classifier created with pretrained=False; OverSamplingCallback
# is registered as a callback factory.
learn_cls = text_classifier_learner(
    data_cls,
    AWD_LSTM,
    drop_mult=0.5,
    pretrained=False,
    metrics=cls_metrics,
    callback_fns=[OverSamplingCallback],
)
# Convert the learner with to_fp16() and keep whatever it returns.
learn_cls = learn_cls.to_fp16()


In [None]:
from pathlib import Path

# Point the learner at the project directory (presumably so load_encoder
# resolves the saved encoder relative to it), load the encoder, then freeze
# the learner before the first training pass.
learn_cls.path = Path(local_project_path)
learn_cls.load_encoder('lm-gpu3-sp-40M-v2-enc');
learn_cls.freeze();

In [None]:
# learn_cls.data.batch_size = 512

In [None]:
def add_tensorboard_callback(learn):
    """Register a LearnerTensorboardWriter factory on `learn`.

    A run-specific log directory `log/cafa<Y-M-D-h-m-s>` is derived from the
    current US/Eastern time (fields are not zero-padded).
    `remove_tensorboard_callback(learn)` is called first, then a `partial`
    building the writer with that directory and the name 'CafaLearner' is
    appended to `learn.callback_fns`.
    """
    stamp = datetime.datetime.now().astimezone(timezone('US/Eastern'))
    run_suffix = '-'.join(
        str(part)
        for part in (stamp.year, stamp.month, stamp.day, stamp.hour, stamp.minute, stamp.second)
    )
    log_dir = Path('log/' + 'cafa' + run_suffix)

    remove_tensorboard_callback(learn)
    writer_factory = partial(LearnerTensorboardWriter, base_dir=log_dir, name='CafaLearner')
    learn.callback_fns.append(writer_factory)

def remove_tensorboard_callback(learn):
    """Remove every LearnerTensorboardWriter factory from `learn.callback_fns`.

    Only entries that are `LearnerTensorboardWriter` itself, or a `partial`
    wrapping it, are dropped. All other callback factories are kept in their
    original order. The list is updated in place, and the call is a no-op when
    no writer is registered.
    """
    # The earlier "pop the last entry whenever there is more than one"
    # heuristic also discarded unrelated callbacks (for example one supplied
    # through `callback_fns=` when the learner was created) if no writer had
    # been added yet.
    learn.callback_fns[:] = [
        fn for fn in learn.callback_fns
        if getattr(fn, 'func', fn) is not LearnerTensorboardWriter
    ]


In [None]:
# add_tensorboard_callback(learn_cls)

In [None]:
# remove_tensorboard_callback(learn_cls)
# learn_cls.lr_find()
# add_tensorboard_callback(learn_cls)

In [None]:
# learn_cls.recorder.plot(skip_start=5, skip_end=10, suggestion = True)

In [None]:
# learn_cls.recorder.plot_losses()

In [None]:
# learn_cls.recorder.plot_lr(show_moms=True)

In [None]:
# learn_cls.recorder.plot_metrics()

In [None]:
# First training pass (the learner was frozen in an earlier cell): one one-cycle epoch at lr = 2e-2.
lr = 2e-2
learn_cls.fit_one_cycle(1, lr, moms=(0.8,0.7))

In [None]:
# (leftover `%debug` post-mortem cell removed: after a failed cell it opens an
# interactive debugger, which blocks Restart & Run All)

In [None]:
# (stray `bug` expression removed: the name is never defined, so this cell
# raised NameError and stopped Restart & Run All)

In [None]:
# Ten more one-cycle epochs at slice(1e-3); unfreeze() has not been called yet at this point.
learn_cls.fit_one_cycle(10, slice(1e-3), moms=(0.8,0.7))

In [None]:
# Unfreeze the learner before the following training cells.
learn_cls.unfreeze()

In [None]:
# Ten one-cycle epochs at slice(1e-3), run after unfreeze().
learn_cls.fit_one_cycle(10, slice(1e-3), moms=(0.8,0.7))

In [None]:
# Continue for ten epochs with the smaller rate slice(1e-4).
learn_cls.fit_one_cycle(10, slice(1e-4), moms=(0.8,0.7))

In [None]:
# Same call as the previous cell: another ten epochs at slice(1e-4).
learn_cls.fit_one_cycle(10, slice(1e-4), moms=(0.8,0.7))

In [None]:
# unfreeze() was already called in an earlier cell; this repeats the call.
learn_cls.unfreeze()

In [None]:
# Validate with the same single-label metrics the learner was built with.
# The previous list mixed multi-label metrics (accuracy_thresh, fbeta with a
# threshold) with top_k_accuracy, whose default k=5 exceeds the number of
# classes of this 'T'/'F' classifier.
learn_cls.validate(metrics=[accuracy, f1])

In [None]:
# Store the result of get_preds() in `pred`; no later cell in view reads it.
pred = learn_cls.get_preds()

In [None]:
# Show the learner's summary.
learn_cls.summary()

In [None]:
# Display the learner's model object.
learn_cls.model

In [None]:
# Predictions and targets from the classifier trained in this notebook
# (`learn_cls`); the previous name `learn_c` is never defined here and
# raised NameError.
preds, targs = learn_cls.get_preds(ordered=True)
# Show accuracy and F1 computed on those predictions.
accuracy(preds, targs), f1(preds, targs)

In [None]:
# Build an interpretation object from the trained classifier.
interp = ClassificationInterpretation.from_learner(learn_cls)

In [None]:
# Unpack the two values returned by top_losses() as losses and their indices.
losses,idxs = interp.top_losses()

In [None]:
# Sanity check: the validation set, the losses and the indices all have the same length.
len(data_cls.valid_ds)==len(losses)==len(idxs)

In [None]:
# Plot the 9 top-loss items.
# NOTE(review): confirm plot_top_losses supports text data in this fastai version - TODO verify.
interp.plot_top_losses(9, figsize=(15,11))

In [None]:
# Draw the confusion matrix at the given figure size and dpi.
interp.plot_confusion_matrix(figsize=(15,15), dpi=120)

In [None]:
# List the most-confused class pairs, passing min_val=2.
interp.most_confused(min_val=2)

In [None]:
# Count the items whose loss exceeds 1.
(losses > 1).sum()

In [None]:
# Total number of loss values, for comparison with the count above.
len(losses)

In [None]:
# Number of classes in the classification data.
len(data_cls.classes)