In [1]:
"""## Prepare fastai"""
from fastai import *
from fastai.text import *
from fastai.metrics import *
from fastai.callbacks.tensorboard import LearnerTensorboardWriter
from fastai.callbacks.misc import StopAfterNBatches
from fastai.callbacks.oversampling import OverSamplingCallback
import datetime
import pickle
from pytz import timezone

torch.cuda.set_device(3)
np.random.seed(0)

"""## Prepare Dataset"""
local_project_path = local_path + 'data/proteinnet/'
if not os.path.exists(local_project_path):
    os.makedirs(local_project_path)
print('local_project_path:', local_project_path)

local_project_path: ../data/proteinnet/


In [2]:
"""## Tokenization"""
class dna_tokenizer(BaseTokenizer):
    def tokenizer(self, t):
#         return list(t)
        res = []
        tokens = t.split(' ')
        before_seq = tokens[:-2]
        seq = tokens[-2]
        eos = tokens[-1]
        
        res = before_seq
        res += list(seq) # sequence string to list
        res.append(eos)
        
        return res
tokenizer = Tokenizer(tok_func=dna_tokenizer, pre_rules=[], post_rules=[], special_cases=[])
processor = [TokenizeProcessor(tokenizer=tokenizer, include_bos= True, include_eos=True), NumericalizeProcessor(max_vocab=30000)]

In [3]:
# # batch size
# bs = 128
# data_lm = TextLMDataBunch.from_csv(local_project_path, 'test.csv',
#                                    text_cols ='seq', valid_pct= 0.1, tokenizer=tokenizer,
#                                    include_bos= True, include_eos=True, bs=bs)
# print('data_cls Training set size', len(data_lm.train_ds))
# print('data_cls Validation set size', len(data_lm.valid_ds))

In [15]:
train_df_full = pickle.load(open('../data/uniprot_sprot/train_df.p', 'rb'))
train_df = train_df_full[['seq','GO', 'GO_ancestors_C', 'GO_ancestors_P', 'GO_ancestors_F']].copy()

In [16]:
train_df_full.head(1)

Unnamed: 0_level_0,GO,domain,seq,taxonomy_id,GO_ancestors_C,GO_ancestors_P,GO_ancestors_F,Tax_ancestors,seq_taxons,seq_tax
accession,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
A0A060X6Z0,GO:0070852 GO:0043204 GO:0004511,C C F,MPISSSSSSSTKSMRRAASELERSDSVTSPRFIGRRQSLIEDARKE...,8022,"{GO:0043005, GO:0005575, GO:0097458, GO:004446...",{},"{GO:0003674, GO:0016491, GO:0016714, GO:000449...","{NCBITaxon:32443, NCBITaxon:8006, NCBITaxon:1,...",131567 33154 504568 33511 89593 2759 8006 7776...,8022 MPISSSSSSSTKSMRRAASELERSDSVTSPRFIGRRQSLIE...


In [17]:
del train_df_full
gc.collect()

len(train_df)

66840

In [81]:
go_id = 'GO:0017076'

In [82]:
def find_go(row):
    if go_id in row.GO or go_id in row.GO_ancestors_C or go_id in row.GO_ancestors_P or go_id in row.GO_ancestors_F :
        res = 'T'
    else:
        res = 'F'
    return res
train_df[go_id] = train_df.apply(find_go, axis=1)

In [83]:
# train_df['GO'].value_counts()

In [84]:
available_T = (train_df[go_id] == 'T').sum(); available_T

1087

In [85]:
train_df_undersampled_F = train_df[train_df[go_id] == 'F'][:available_T].copy()
train_df_undersampled_T = train_df[train_df[go_id] == 'T'].copy()

In [86]:
len(train_df)

66840

In [87]:
train_df_undersampled = pd.concat([train_df_undersampled_F, train_df_undersampled_T])
print(len(train_df_undersampled))
train_df_undersampled.sample(5)

2174


Unnamed: 0_level_0,seq,GO,GO_ancestors_C,GO_ancestors_P,GO_ancestors_F,GO:0005886,GO:0017076
accession,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
O15648,MAVESRSRVTSKLVKAHRAMLNSVTQEDLKVDRLPGADYPNPSKKY...,GO:0020015 GO:0003872 GO:0005524 GO:0006096,"{GO:0043227, GO:0005777, GO:0043229, GO:004322...","{GO:0009117, GO:0044763, GO:0006757, GO:190157...","{GO:0019200, GO:0017076, GO:0032550, GO:001677...",F,T
A2VDJ0,MAGLRRPQPGCYCRTAAAVNLLLGVFQVLLPCCRPGGAQGQAIEPL...,GO:0005737 GO:0005886 GO:0090090 GO:0033088,"{GO:0044464, GO:0016020, GO:0044424, GO:0005575}","{GO:0050672, GO:0065007, GO:0051250, GO:001064...",{},T,F
A8MTZ0,MLKAAAKRPELSGKNTISNNSDMAEVKSMFREVLPKQGPLFVEDIM...,GO:0034464 GO:0005737 GO:0005829 GO:0042384,"{GO:0032991, GO:0005575, GO:0044444, GO:004442...","{GO:0044763, GO:0030030, GO:0010927, GO:004478...",{},F,F
A4FU49,MVQSELQLQPRAGGRAEAASWGDRGNDKGGLGNPDMPSVSPGPQRP...,GO:0070062,"{GO:0031982, GO:0043230, GO:0043227, GO:004322...",{},{},F,F
Q99NG0,MSDESASGSDPDLDPDVELEDEEEEEEEEEVAVEEHDRDDEEGLLD...,GO:0016607 GO:0005634 GO:0005524 GO:0016887 GO...,"{GO:0044446, GO:0044422, GO:0044428, GO:004322...","{GO:0006357, GO:0045935, GO:0065007, GO:004852...","{GO:0017076, GO:0032550, GO:0000988, GO:003563...",F,T


In [88]:
# train_df_undersampled.to_csv('train_df_undersampled.csv', index=False)

In [89]:
# bs = 512
# data_cls = (TextList.from_df(train_df, path = local_project_path, cols='seq', vocab=data_lm.vocab, processor=processor)
#                     .split_by_rand_pct(0.10)
#                    .label_from_df(cols='is_GO0005525') #, label_delim=' '
#                    .databunch(bs=bs))

In [90]:
vocab = pickle.load(open(local_project_path + 'lm-whole-sp-v2-vocab.pkl', 'rb')); 

In [91]:
bs = 256
data_cls = (TextList.from_df(train_df_undersampled, path = local_project_path, cols='seq', vocab=vocab, processor=processor)
                    .split_by_rand_pct(0.10)
                   .label_from_df(cols=go_id) # , label_delim=' '
                   .databunch(bs=bs))

In [92]:
print(len(data_cls.train_ds))
print(len(data_cls.valid_ds))

1957
217


In [93]:
_, train_class_counts = np.unique(data_cls.train_ds.y.items, return_counts=True)
_, valid_class_counts = np.unique(data_cls.valid_ds.y.items, return_counts=True)
len(train_class_counts)

2

In [94]:
w0 = 1/train_class_counts[0]*1000
w1 = 1/train_class_counts[1]*1000

In [95]:
weight = torch.FloatTensor([w0,w1]).cuda();weight

tensor([1.0194, 1.0246], device='cuda:3')

In [96]:
valid_class_counts

array([106, 111])

In [97]:
torch.FloatTensor(train_class_counts/np.sum(train_class_counts))

tensor([0.5013, 0.4987])

In [98]:
valid_class_counts/np.sum(valid_class_counts)

array([0.488479, 0.511521])

In [99]:
# pos_weight = torch.ones(len(train_class_counts)) * (loss_weights[0] / loss_weights[1])
# pos_weight # for multi-label!?

In [100]:
data_cls.train_ds.x[0].text[:40]

'xxbos M P I S S S S S S S T K S M R R A '

In [101]:
data_cls.train_ds.y[0]

Category F

In [102]:
# # batch size
# bs = 256
# data_cls = TextClasDataBunch.from_csv(local_project_path, 'uniprot_sprot_exp_go_F.csv',
#                                    text_cols ='seq', valid_pct= 0.1, tokenizer=tokenizer,
#                                    include_bos= True, include_eos=True, classes='labels', bs=bs)
# print('data_cls Training set size', len(data_lm.train_ds))
# print('data_cls Validation set size', len(data_lm.valid_ds)) 

In [103]:
data_cls.show_batch()

text,target
xxbos M T T Q A P M F T Q P L Q S V V V L E G S T A T F E A H V S G S P V P E V S W F R D G Q V I S T S T L P G V Q I S F S D G R A R L M I P A,F
xxbos M D H S F S G A P R F L T R P K A F V V S V G K D A T L S C Q I V G N P T P H V S W E K D R Q P V E A G A R F R L A Q D G D V Y R L T I L,F
xxbos M E L Y L S A C S K T A N V A A N K A A S S T V A E D S Q Q C V D G R H K T P I P G V G A A Q L L D L P L G V K L P M I P G T D T V Y F T,F
xxbos M K C P K C S H E A L E K A P K F C S E C G H K L Q S Q S Y E T T Q G T P H D K S Q T P S I V P Q I T N A E M D E T G S E S K S L E I Q N,F
xxbos M L L P A L L F G M A W A L A D G R W C E W T E T I R V E E E V A P R Q E D L V P C A S L D H Y S R L G W R L D L P W S G R S G L T R S P,F


In [104]:
len(data_cls.vocab.itos)

40

In [105]:
# acc_02 = partial(accuracy_thresh, thresh=0.5)
# f_score = partial(fbeta, thresh=0.5, beta=1)
from sklearn.metrics import f1_score
@np_func
def f1(inp,targ):
    return f1_score(targ, np.argmax(inp, axis=-1))

In [106]:
def one_hot_embedding(labels, num_classes):
    return torch.eye(num_classes, device= 'cuda')[labels.data]


class FocalLoss(nn.Module):
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, logit, target, reduction='elementwise_mean'):
        target = one_hot_embedding(target, logit.size(-1))
        target = target.float()
        max_val = (-logit).clamp(min=0)
        loss = logit - logit * target + max_val + \
               ((-max_val).exp() + (-logit - max_val).exp()).log()

        invprobs = F.logsigmoid(-logit * (target * 2.0 - 1.0))
        loss = (invprobs * self.gamma).exp() * loss
        if len(loss.size())==2:
            loss = loss.sum(dim=1)
        
        if reduction == 'elementwise_mean': return loss.mean()
        elif reduction == 'sum': return loss.sum()
        return loss

In [107]:

# class myCCELoss(nn.Module):

#     def __init__(self):
#         super(myCCELoss, self).__init__()

#     def forward(self, input, target):
#         y = one_hot_embedding(target, input.size(-1))
#         logit = F.softmax(input)
       
#         loss = -1 * V(y) * torch.log(logit) # cross entropy loss
#         return loss.sum(dim=1).mean()
    
# class FocalLoss(nn.Module):

#     def __init__(self, gamma=0, eps=1e-7):
#         super(FocalLoss, self).__init__()
#         self.gamma = gamma
#         self.eps = eps

#     def forward(self, input, target):
#         y = one_hot_embedding(target, input.size(-1))
#         logit = F.softmax(input)
#         logit = logit.clamp(self.eps, 1. - self.eps)
        
#         loss = -1 * V(y) * torch.log(logit) # cross entropy
#         loss = loss * (1 - logit) ** self.gamma # focal loss
#         return loss.sum(dim=1).mean()

In [108]:
learn_cls = text_classifier_learner(data_cls, AWD_LSTM, drop_mult=0.5, pretrained=False, 
                                    metrics=[accuracy, f1], # f1
#                                     callback_fns=[
#                                         OverSamplingCallback
#                                     ],
#                                     loss_func = FocalLoss()
#                                     loss_func=CrossEntropyFlat(weight=weight)
                                   ).to_fp16()


In [109]:
learn_cls.load_encoder('lm-whole-sp-v4-enc');

In [110]:
def add_tensorboard_callback(learn):
    now = datetime.datetime.now().astimezone(timezone('US/Eastern'))
    time_for_different_run = f'{now.year}-{now.month}-{now.day}-{now.hour}-{now.minute}-{now.second}'

    proj_id = 'cafa' + time_for_different_run
    tboard_path = Path('log/' + proj_id)
    remove_tensorboard_callback(learn)
    learn.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=tboard_path, name='CafaLearner'))

def remove_tensorboard_callback(learn):
    if len(learn.callback_fns) > 1: # not the best way to check this !!
        learn.callback_fns.pop()


In [111]:
# add_tensorboard_callback(learn_cls)

In [112]:
# remove_tensorboard_callback(learn_cls)
# learn_cls.lr_find()
# add_tensorboard_callback(learn_cls)

In [113]:
# learn_cls.recorder.plot(skip_start=5, skip_end=10, suggestion = True)

In [114]:
# learn_cls.recorder.plot_losses()

In [115]:
# learn_cls.recorder.plot_lr(show_moms=True)

In [116]:
# learn_cls.recorder.plot_metrics()

In [None]:
lr = 1e-3
learn_cls.freeze()
learn_cls.fit_one_cycle(3, lr, moms=(0.8,0.7))

In [118]:
lr = 2e-2
learn_cls.freeze_to(-2)
learn_cls.fit_one_cycle(2, slice(lr/(2.6**4),lr), moms=(0.8,0.7))

epoch,train_loss,valid_loss,accuracy,f1,time
0,0.670117,0.713908,0.511521,0.676829,00:19
1,0.652687,0.707731,0.516129,0.678899,00:20


In [119]:
learn_cls.fit_one_cycle(2, slice(lr/(2.6**4),lr), moms=(0.8,0.7))

epoch,train_loss,valid_loss,accuracy,f1,time
0,0.635074,0.715662,0.534562,0.687307,00:20
1,0.614265,0.695494,0.548387,0.69375,00:19


In [120]:
learn_cls.freeze_to(-3)
learn_cls.fit_one_cycle(2, slice(lr/2/(2.6**4),lr/2), moms=(0.8,0.7))

epoch,train_loss,valid_loss,accuracy,f1,time
0,0.58486,0.669643,0.580645,0.709265,00:21
1,0.565205,0.636826,0.645161,0.722022,00:20


In [121]:
learn_cls.fit_one_cycle(2, slice(lr/2/(2.6**4),lr/2), moms=(0.8,0.7))

epoch,train_loss,valid_loss,accuracy,f1,time
0,0.514095,0.612678,0.658986,0.739437,00:21
1,0.531782,0.591656,0.728111,0.737778,00:21


In [122]:
learn_cls.unfreeze()
learn_cls.fit_one_cycle(4, slice(lr/10/(2.6**4),lr/10), moms=(0.8,0.7))

epoch,train_loss,valid_loss,accuracy,f1,time
0,0.489113,0.564866,0.746544,0.755556,00:22
1,0.482449,0.537624,0.746544,0.755556,00:22
2,0.469587,0.522763,0.75576,0.760181,00:22
3,0.465013,0.526218,0.75576,0.75576,00:22


In [None]:
learn_cls.fit_one_cycle(10, slice(lr/10/(2.6**4),lr/10), moms=(0.8,0.7))

In [None]:
learn_cls.fit_one_cycle(1, slice(lr/10/(2.6**4),lr/10), moms=(0.8,0.7))

In [None]:
# learn_cls.save('cls-is_GO0005886-v0');

In [None]:
# learn_cls.load('cls-is_GO0005525-v0');

In [None]:
%debug

In [None]:
bug

In [None]:
learn_cls.validate(metrics=[partial(accuracy_thresh, thresh=0.5), partial(fbeta, thresh=0.5, beta = 1), top_k_accuracy])

In [None]:
pred = learn_cls.get_preds()

In [None]:
learn_cls.summary()

In [None]:
learn_cls.model

In [None]:
interp = ClassificationInterpretation.from_learner(learn_cls)

In [None]:
losses,idxs = interp.top_losses()

In [None]:
len(data_cls.valid_ds)==len(losses)==len(idxs)

In [None]:
# interp.plot_top_losses(9, figsize=(15,11))

In [None]:
interp.plot_confusion_matrix(figsize=(6,6), dpi=120)

In [None]:
interp.most_confused(min_val=2)

In [None]:
(losses > 1).sum()

In [None]:
len(losses)

In [None]:
len(data_cls.classes)

In [None]:
preds,targs = learn_cls.get_preds(ordered=True)
accuracy(preds,targs),f1(preds,targs)