In [16]:
import glob, os, time, sys
import json
import numpy as np
from collections import defaultdict
import argparse

from dataset.genome import GenomeDataset
from dataset.utils import load_meta_reads, build_bert_corpus_v2
import utils.utils as utils
from debug.visualize import get_group_label, visualize
from utils.metrics import genome_acc, group_precision_recall

import torch
from bert_pytorch.dataset import BERTDataset, WordVocab
from torch.utils.data import DataLoader

from sklearn.cluster import KMeans

import tqdm

In [6]:
from torch.utils.data import Dataset
import tqdm
import torch
import random


class BERTDataset_ForInference(Dataset):
    """Dataset that encodes corpus lines for BERT inference (no masking, no sentence shuffling).

    Every corpus line holds two tab-separated sentences whose tokens are
    separated by whitespace.  ``__getitem__`` maps the tokens to vocabulary
    ids, adds SOS/EOS markers, truncates/pads everything to ``seq_len`` and
    returns the result as tensors.

    Args:
        corpus_path: Path of the corpus text file.
        vocab: Vocabulary object exposing ``stoi`` (token -> id mapping) and
            the ``unk_index``, ``sos_index``, ``eos_index`` and ``pad_index``
            attributes.
        seq_len: Fixed length of every returned sequence.
        encoding: Text encoding used to read the corpus.
        corpus_lines: Number of lines in the corpus, if already known.  When
            ``on_memory`` is True it is only used as the progress-bar total
            and is then replaced by the real line count.
        on_memory: If True, the whole corpus is loaded into ``self.lines``.
            If False, lines are streamed from an open file handle and
            ``get_corpus_line`` ignores the requested index.
    """

    def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
        self.vocab = vocab
        self.seq_len = seq_len

        self.on_memory = on_memory
        self.corpus_lines = corpus_lines
        self.corpus_path = corpus_path
        self.encoding = encoding

        if on_memory:
            with open(corpus_path, "r", encoding=encoding) as f:
                # rstrip("\n") instead of line[:-1]: a final line without a
                # trailing newline must not lose its last character.
                self.lines = [line.rstrip("\n").split("\t")
                              for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]
            self.corpus_lines = len(self.lines)
            return

        if self.corpus_lines is None:
            # Count the lines once; start from 0 because the counter is None here.
            with open(corpus_path, "r", encoding=encoding) as f:
                self.corpus_lines = sum(1 for _ in tqdm.tqdm(f, desc="Loading Dataset"))

        # Streaming mode: handles stay open for the lifetime of the dataset.
        self.file = open(corpus_path, "r", encoding=encoding)
        # Second handle advanced by a random offset; no method of this class reads it.
        self.random_file = open(corpus_path, "r", encoding=encoding)

        # random.randint needs both bounds; skip between 0 and min(corpus_lines, 1000) lines.
        for _ in range(random.randint(0, min(self.corpus_lines, 1000))):
            if next(self.random_file, None) is None:
                break

    def __len__(self):
        """Return the number of corpus lines."""
        return self.corpus_lines

    def __getitem__(self, item):
        """Encode one corpus line.

        Returns:
            dict of tensors with keys ``bert_input`` (token ids),
            ``bert_label`` (all padding/zero, nothing is masked),
            ``segment_label`` (1 for the first sentence, 2 for the second,
            ``pad_index`` for padding) and ``is_next`` (always 1).
        """
        t1, t2 = self.get_corpus_line(item)
        # Sentence pairs are never shuffled here, so the pair is always "next".
        is_next_label = 1

        t1_ids, t1_label = self.token2token_id(t1)
        t2_ids, t2_label = self.token2token_id(t2)

        pad = self.vocab.pad_index

        # [CLS] tag = SOS tag, [SEP] tag = EOS tag
        t1 = [self.vocab.sos_index] + t1_ids + [self.vocab.eos_index]
        t2 = t2_ids + [self.vocab.eos_index]

        t1_label = [pad] + t1_label + [pad]
        t2_label = t2_label + [pad]

        segment_label = ([1] * len(t1) + [2] * len(t2))[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]

        # Right-pad all three sequences up to seq_len.
        padding = [pad] * (self.seq_len - len(bert_input))
        bert_input.extend(padding)
        bert_label.extend(padding)
        segment_label.extend(padding)

        output = {"bert_input": bert_input,
                  "bert_label": bert_label,
                  "segment_label": segment_label,
                  "is_next": is_next_label}

        return {key: torch.tensor(value) for key, value in output.items()}

    def token2token_id(self, sentence):
        """Map a whitespace-separated sentence to vocabulary ids.

        Unknown tokens map to ``vocab.unk_index``.

        Returns:
            (token_ids, labels) where ``labels`` is a list of zeros of the
            same length (no token is masked at inference time).
        """
        token_ids = [self.vocab.stoi.get(token, self.vocab.unk_index)
                     for token in sentence.split()]
        return token_ids, [0] * len(token_ids)

    def get_corpus_line(self, item):
        """Return the two sentences ``(t1, t2)`` of a corpus line.

        In on-memory mode the line at index ``item`` is returned.  In
        streaming mode ``item`` is ignored: the next line of the open file is
        returned and reading restarts from the top once the end is reached.

        Raises:
            IndexError: in streaming mode, if the corpus file is empty.
        """
        if self.on_memory:
            return self.lines[item][0], self.lines[item][1]

        # next(..., None) returns None at EOF instead of raising StopIteration,
        # which is what makes the wrap-around below reachable.
        line = next(self.file, None)
        if line is None:
            self.file.close()
            self.file = open(self.corpus_path, "r", encoding=self.encoding)
            line = next(self.file, None)
            if line is None:
                raise IndexError("corpus file is empty: " + str(self.corpus_path))

        t1, t2 = line.rstrip("\n").split("\t")
        return t1, t2

In [7]:
# Make modules in the current working directory importable.
sys.path.append('.')

# Directory holding the input '<dataset>.fna' files.
DATASET_DIR = '../../Data/test/input/'
# Specific dataset name(s) to process, or the string 'all' for every .fna file in DATASET_DIR
DATASET_NAME = ['S1_test', 'hmp_test']
# DATASET_NAME = 'all'
# Directory passed to utils.load_groups_seeds to read the groups/seeds of each dataset.
BIMETAOUT_DIR = '../../Data/test/output/bimetaout/'
# Directory where this notebook writes its log file and the generated BERT corpus.
RESULT_DIR = '../../Data/test/output/bertbimetaout/'

In [8]:
# Hyperparameters
# NOTE(review): KMERS, BATCH_SIZE and PRETRAIN_EPOCHS are not referenced by any cell visible in
# this notebook (the corpus is built with a literal k_mer_length=4 and training is launched
# through the `bert` command line) -- confirm whether they are still needed.
KMERS = [4]
# Training batchsize
BATCH_SIZE = 256
# Number of epochs for pretraining
PRETRAIN_EPOCHS = 2000
# Whether the BERT corpus generated for the current dataset is written to a file under RESULT_DIR
is_save_each_corpus = True

In [4]:
import random
def build_bert_corpus(reads, k_mer_length=4):
    '''
    Build one corpus line per read.

    A read is cut at its midpoint (len(read) // 2). Each half is turned into a
    sentence of overlapping k-mers (stride 1) that lie entirely inside that
    half, so k-mers spanning the midpoint are not emitted.

    Returns a list of strings of the form
        "<k-mers of first half>\\t<k-mers of second half>\\n"
    where the k-mers of a sentence are separated by single spaces.
    '''
    def sentence(read, begin, end):
        # Space-joined k-mers whose start lies in [begin, end - k].
        last_start = end - k_mer_length
        return ' '.join(read[pos:pos + k_mer_length] for pos in range(begin, last_start + 1))

    documents = []
    for read in reads:
        total = len(read)
        half = total // 2
        documents.append(sentence(read, 0, half) + '\t' + sentence(read, half, total) + '\n')

    return documents

In [17]:
def build_bert_corpus_v2(reads, k_mer_length=4):
    """Build one corpus line per read from all of its overlapping k-mers.

    Every k-mer of the read (stride 1) is extracted first; the resulting list
    is then split at its midpoint (len // 2) into two sentences.

    Returns a list of strings of the form
        "<first sentence> \\t <second sentence>\\n"
    (note the spaces around the tab), k-mers being separated by single spaces.
    """
    def to_line(read):
        n_kmers = len(read) - k_mer_length + 1
        kmers = [read[start:start + k_mer_length] for start in range(n_kmers)]
        cut = len(kmers) // 2
        left, right = kmers[:cut], kmers[cut:]
        return ' '.join(left) + ' \t ' + ' '.join(right) + '\n'

    return [to_line(read) for read in reads]

In [18]:
# Resolve the FASTA files to process: every *.fna file of DATASET_DIR when
# DATASET_NAME is 'all', otherwise only the named dataset(s).
if DATASET_NAME == 'all':
    # os.path.join avoids the doubled separator produced by DATASET_DIR + '/*.fna'
    # when DATASET_DIR already ends with '/'.
    raw_datasets = glob.glob(os.path.join(DATASET_DIR, '*.fna'))
else:
    # Accept a single name (str) as well as a list/tuple of names.
    dataset_names = [DATASET_NAME] if isinstance(DATASET_NAME, str) else list(DATASET_NAME)
    raw_datasets = [os.path.join(DATASET_DIR, ds_name + '.fna') for ds_name in dataset_names]

# Mapping of dataset and its corresponding number of clusters
with open('config/dataset_metadata.json', 'r') as f:
    n_clusters_mapping = json.load(f)['datasets']

# Sort so that the processing order is deterministic.
raw_datasets.sort()

In [6]:
# Display the resolved dataset paths as a sanity check.
raw_datasets

['../../Data/test/input/S1_test.fna', '../../Data/test/input/hmp_test.fna']

In [11]:
# Only the first resolved dataset is processed by the cells below.
dataset = raw_datasets[0]
dataset_name = os.path.basename(dataset).split('.fna')[0]

print("-------------------------------------------------------")
print('Processing dataset: ', dataset_name)

bimetaout_file = os.path.join(BIMETAOUT_DIR, dataset_name + '.json')

# open(..., "w") does not create missing parent directories, so make sure the
# result directory exists before the log file is created.
os.makedirs(RESULT_DIR, exist_ok=True)
log_file = os.path.join(RESULT_DIR, dataset_name + '.log.txt')
# The handle is intentionally left open: the following cells keep writing to it.
log = open(log_file, "w")
log.write('------------------------------------------------------- ')
log.write('\nProcessing dataset ' + dataset_name)

# Prior number of clusters for this dataset, read from the metadata mapping
# (raises KeyError if the dataset is not listed there).
n_clusters = n_clusters_mapping[dataset_name]

print('Prior number of clusters: ', n_clusters)
log.write('\nPrior number of clusters: ' + str(n_clusters))

-------------------------------------------------------
Processing dataset:  S1_test
Prior number of clusters:  2


28

In [12]:
# t0 marks the start of the pipeline; it is reused later to report the total time.
t0 = time.time()
# Load group file (phase 1 of bimeta) according to dataset_name
print('Loading groups/seeds: ...')
groups, seeds = utils.load_groups_seeds(BIMETAOUT_DIR, dataset_name)
print('Total number of groups: ', len(groups))
log.write('\nTotal number of groups/seeds: ' + str(len(groups)))
# Note: the elapsed time is measured twice, so the printed and logged values differ slightly.
print('Time to load groups: ', (time.time() - t0))
log.write('\nTime to load groups: ' + str((time.time() - t0)))

Loading groups/seeds: ...
Total number of groups:  396
Time to load groups:  0.0046346187591552734


41

In [13]:
# Read fasta dataset
t1 = time.time()
print('Loading reads ...')
# load_meta_reads returns the reads together with one label per read
# (presumably the ground-truth genome of each read -- TODO confirm).
reads, labels = load_meta_reads(dataset, type='fasta')
print('Total number of reads: ', len(labels))
log.write('\nTotal number of reads: ' + str(len(labels)))
# Note: the elapsed time is measured twice, so the printed and logged values differ slightly.
print('Time to load reads: ', (time.time() - t1))
log.write('\nTime to load reads: ' + str((time.time() - t1)))

Loading reads ...
Total number of reads:  400
Time to load reads:  0.03836774826049805


40

In [19]:
t2 = time.time()
# Creating bert corpus...
# One corpus line (two k-mer sentences) is produced per read.
# NOTE(review): the k-mer length is the literal 4 here, not taken from KMERS -- confirm this is intended.
corpus = build_bert_corpus_v2(reads, k_mer_length=4)
print('Time to create corpus from reads: ', (time.time() - t2))
log.write('\nTime to create corpus from reads: ' + str((time.time() - t2)))

Time to create corpus from reads:  0.02024102210998535


55

In [21]:
t3 = time.time()
print('Save corpus for current dataset ...')
# Save corpus for current dataset
# bert_corpus_file is defined even when nothing is saved; later cells read the corpus from it.
bert_corpus_file = os.path.join(RESULT_DIR, dataset_name + '.bert_corpus.txt')
if is_save_each_corpus:
    with open(bert_corpus_file, 'w') as f:
        # Every corpus entry already ends with '\n', so writelines produces one read per line.
        f.writelines(corpus)
        print('Saved bert corpus to ', bert_corpus_file)
print('Time to save corpus for current dataset: ', (time.time() - t3))
log.write('\nTime to save corpus for current dataset: ' + str((time.time() - t3)))

Save corpus for current dataset ...
Saved bert corpus to  ../../Data/test/output/bertbimetaout/S1_test.bert_corpus.txt
Time to save corpus for current dataset:  0.0028955936431884766


62

In [27]:
# Make modules in the current working directory importable
# (the same call is also made in the path-configuration cell).
sys.path.append('.')

In [28]:
# Show the current working directory, against which the relative paths above are resolved.
os.getcwd()

'/home/hoangqd/Projects/nlp-bimeta-binning'

In [1]:
# Build the vocabulary file from the saved corpus with the `bert-vocab` command line tool.
# NOTE(review): the paths are absolute and machine-specific; they point at the same files as
# RESULT_DIR only when the notebook runs from the working directory shown above.
!bert-vocab -c '/home/hoangqd/Data/test/output/bertbimetaout/S1_test.bert_corpus.txt' -o '/home/hoangqd/Data/test/output/bertbimetaout/vocab'

Building Vocab
400it [00:00, 12586.15it/s]
VOCAB SIZE: 656


In [2]:
# Peek at the first 10 lines of the saved corpus (two tab-separated k-mer sentences per line).
!cat '/home/hoangqd/Data/test/output/bertbimetaout/S1_test.bert_corpus.txt' | head -10

ATAA TAAT AATT ATTG TTGG TGGC GGCA GCAA CAAG AAGT AGTG GTGT TGTT GTTT TTTT TTTA TTAG TAGT AGTC GTCT TCTT CTTA TTAG TAGA AGAG GAGA AGAG GAGA AGAT GATT ATTC TTCT TCTC CTCT TCTA CTAA TAAG AAGT AGTC GTCT TCTA CTAA TAAC AACT ACTT CTTG TTGA TGAA GAAC AACT ACTC CTCA TCAA CAAT AATT ATTT TTTG TTGG TGGA GGAA GAAT AATC ATCA TCAT CATT ATTT TTTC TTCC TCCC CCCA CCAA CAAT AATT ATTT TTTT TTTT TTTA	TTTC TTCA TCAA CAAA AAAC AACA ACAC CACT ACTT CTTT TTTA TTAC TACA ACAC CACC ACCT CCTC CTCT TCTA CTAC TACC ACCA CCAT CATT ATTC TTCA TCAT CATT ATTC TTCA TCAA CAAT AATT ATTG TTGG TGGA GGAT GATC ATCA TCAC CACA ACAA CAAA AAAT AATA ATAC TACA ACAG CAGA AGAG GAGC AGCA GCAG CAGT AGTG GTGT TGTA GTAT TATT ATTT TTTG TTGA TGAG GAGA AGAT GATA ATAT TATC ATCC TCCT CCTG CTGA TGAA GAAA AAAG AAGA AGAT
TAAT AATT ATTA TTAG TAGT AGTT GTTA TTAG TAGG AGGT GGTA GTAA TAAA AAAG AAGG AGGA GGAA GAAC AACC ACCT CCTT CTTG TTGT TGTT GTTA TTAA TAAT AATA ATAA TAAG AAGA AGAC GACT ACTA CTAG TAGG AGGT GGTT GTTT TTTT TTTA TTAT TATT ATTA TTAA TAAC 

In [4]:
# Train the BERT model on the saved corpus with the `bert` command line tool.
# Per the output below, a checkpoint '<output_path>.ep<N>' is saved after every epoch;
# the embedding cell further down loads the '.ep19' checkpoint, i.e. the last of the 20 epochs.
# NOTE(review): the paths are absolute and machine-specific.
!bert --train_dataset '/home/hoangqd/Data/test/output/bertbimetaout/S1_test.bert_corpus.txt' \
    --vocab_path '/home/hoangqd/Data/test/output/bertbimetaout/vocab' \
    --output_path '/home/hoangqd/Data/test/output/bertbimetaout/S1_test.bert.model' \
    --epochs 20 \
    --hidden 128 \
    --layers 4 \
    --attn_heads 4 \
    --seq_len 256 \
    --batch_size 16 \
    --lr 1e-4 \
    --log_freq 100

Loading Vocab /home/hoangqd/Data/test/output/bertbimetaout/vocab
Vocab Size:  656
Loading Train Dataset /home/hoangqd/Data/test/output/bertbimetaout/S1_test.bert_corpus.txt
Loading Dataset: 400it [00:00, 213152.28it/s]
Loading Test Dataset None
Creating Dataloader
Building BERT model
Creating BERT Trainer
Total Parameters: 962322
Training Start
{'epoch': 0, 'iter': 0, 'avg_loss': 7.96380615234375, 'avg_acc': 50.0, 'loss': 7.96380615234375}
EP_train:0: 100%|| 25/25 [00:23<00:00,  1.07it/s]
EP0_train, avg_loss= 8.09812650680542 total_acc= 49.25
EP:0 Model Saved on: /home/hoangqd/Data/test/output/bertbimetaout/S1_test.bert.model.ep0
{'epoch': 1, 'iter': 0, 'avg_loss': 8.26927661895752, 'avg_acc': 25.0, 'loss': 8.26927661895752}
EP_train:1: 100%|| 25/25 [00:29<00:00,  1.18s/it]
EP1_train, avg_loss= 7.81920804977417 total_acc= 49.5
EP:1 Model Saved on: /home/hoangqd/Data/test/output/bertbimetaout/S1_test.bert.model.ep1
{'epoch': 2, 'iter': 0, 'avg_loss': 7.7504496574401855, 'avg_acc': 50.0,

In [22]:
# Load the trained BERT model and compute one embedding vector per seed.
t4 = time.time()
print('Train bert  model ...')

# Pick the device first so the checkpoint can be mapped onto it while loading:
# without map_location, a checkpoint saved on a GPU cannot be loaded on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Artifacts written by the `bert-vocab` / `bert` commands above, located through RESULT_DIR
# like bert_corpus_file instead of a machine-specific absolute path.
# '.ep19' is the checkpoint of the last of the 20 training epochs.
bert_model_file = os.path.join(RESULT_DIR, dataset_name + '.bert.model.ep19')
vocab_file = os.path.join(RESULT_DIR, 'vocab')

# NOTE(review): torch.load unpickles a whole model object here -- only load trusted
# checkpoints. Newer PyTorch releases may additionally require weights_only=False; confirm
# against the installed version.
bert_model = torch.load(bert_model_file, map_location=device)

vocab = WordVocab.load_vocab(vocab_file)
# seq_len must match the --seq_len used for training (256).
bert_dataset = BERTDataset_ForInference(bert_corpus_file, vocab, seq_len=256, on_memory=True)
bert_data_loader = DataLoader(bert_dataset, batch_size=1, num_workers=1)

bert_model = bert_model.to(device) # device=cuda --> infer faster
bert_model = bert_model.eval()

all_embedding = []
for seed in seeds:
    # each element of this list is a read embedding
    seed_embedding = []
    for read_idx in seed:
        # Index the dataset directly: seeds hold read indices into the corpus.
        data = bert_dataset[read_idx]

        # Forward pass only; no gradients are needed for inference.
        with torch.no_grad():
            # Output assumed to be (1, seq_len, hidden) = (1, 256, 128) given the training
            # command (--seq_len 256 --hidden 128) -- TODO confirm.
            read_embedding = bert_model(data["bert_input"].unsqueeze(0).to(device),
                                        data["segment_label"].unsqueeze(0).to(device))
            read_embedding = read_embedding.cpu().numpy()[0] # (seq_len, hidden)
        seed_embedding.append(read_embedding)

    # Convert to numpy array, and get the average of read embedding
    seed_embedding = np.array(seed_embedding) # (len(seed), seq_len, hidden)
    seed_embedding = np.mean(seed_embedding, axis=0) # (seq_len, hidden)

    # Continue to average across tokens: (seq_len, hidden) -> (hidden,)
    # NOTE(review): padded positions are included in this average.
    seed_embedding = np.mean(seed_embedding, axis=0)
    all_embedding.append(seed_embedding)


print(len(all_embedding))
print('Train bert model time: ', (time.time() - t4))
log.write('\nTrain bert model time: ' + str(time.time() - t4))

Loading Dataset: 400it [00:00, 135716.03it/s]

Train bert  model ...





396
Train bert model time:  9.322391986846924


42

In [23]:
# Clustering groups
t6 = time.time()
print('Clustering ...')
kmeans = KMeans(
    init="random",
    n_clusters=n_clusters,
    n_init=100,
    max_iter=200,
    random_state=20210905)  # fixed seed for reproducible clustering
# One row per group/seed embedding.
group_embeddings = np.asarray(all_embedding)
# KMeans is unsupervised and ignores `y`; the per-read `labels` are therefore not passed
# (they are per read, whereas there is one embedding per group/seed).
kmeans.fit(group_embeddings)
y_pred_kmeans = kmeans.predict(group_embeddings)
print('Clustering time: ', (time.time() - t6))
log.write('\nClustering time: ' + str(time.time() - t6))

Clustering ...
Clustering time:  0.5102109909057617


36

In [24]:
# Map read to group and compute F-measure
t7 = time.time()
print('Compute F-measure ...')
log.write('\nCompute measures: ')
# First element of the group_precision_recall result is used as the group precision.
groupPrec = group_precision_recall(labels, groups, n_clusters)[0]
# Third element of the genome_acc result is used as the F1-score.
f1 = genome_acc(groups, y_pred_kmeans, labels, n_clusters)[2]
print('Group precision: ', groupPrec)
print('F1-score: ', f1)
# Total time is measured from t0, i.e. since the groups/seeds started loading.
print('Total time: ', (time.time() - t0))
log.write('\nGroup precision: ' + str(groupPrec))
log.write('\nF1-score: ' + str(f1))
log.write('\nTotal time: ' + str(time.time() - t0))
# The log handle is never closed by the cells of this notebook; flush so the buffered
# entries actually reach the file. Call log.close() once no further cell writes to it.
log.flush()


Compute F-measure ...
Group precision:  1.0
F1-score:  0.5255819477434679
Total time:  702.6991295814514


30