# Set up environment

In [None]:
!pip install biopython
!git clone https://github.com/networkx/networkx-metis.git &> /dev/null
%cd networkx-metis
!python setup.py build &> /dev/null
!python setup.py install &> /dev/null

/content/networkx-metis


In [None]:
# Mount Google Drive at /content/drive so the project checkout and data stored
# under MyDrive are reachable from this Colab runtime.
# force_remount=True is passed so the mount is refreshed when the cell is re-run.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
!pip install bert-pytorch



In [None]:
#import sys
#sys.path.append('/usr/local/lib/python3.7/dist-packages')

In [None]:
import os

# Make the project checkout importable from child processes started by this
# notebook (the `!` shell commands inherit os.environ). This does not alter
# sys.path of the already-running kernel.
#
# The directory is appended to whatever PYTHONPATH already holds instead of
# overwriting it, and is only added once so re-running the cell is harmless.
_project_dir = '/content/drive/MyDrive/metagenomic-binning/nlp-bimeta'
_pythonpath_entries = [
    entry for entry in os.environ.get('PYTHONPATH', '').split(os.pathsep) if entry
]
if _project_dir not in _pythonpath_entries:
    _pythonpath_entries.append(_project_dir)
os.environ['PYTHONPATH'] = os.pathsep.join(_pythonpath_entries)
# Sanity check: list the variables exported to a child shell and keep only the
# PYTHONPATH entry, to confirm what subprocesses will actually see.
!export | grep PYTHONPATH

declare -x PYTHONPATH="PYTHONPATH:/content/drive/MyDrive/metagenomic-binning/nlp-bimeta"


In [None]:
# Display the kernel's current working directory. At this point it is still the
# networkx-metis checkout entered by the setup cell's %cd.
import os
os.getcwd()

'/content/networkx-metis'

In [None]:
# Leave the networkx-metis build directory and switch to the project root on
# Drive. The local imports (dataset, utils, debug) and the relative paths used
# below ('../data/...', 'config/...') are resolved against this directory.
%cd ..
%cd drive/MyDrive/metagenomic-binning/nlp-bimeta

/content
/content/drive/MyDrive/metagenomic-binning/nlp-bimeta


In [None]:
# List the project directory to confirm the Drive mount exposes the expected
# packages (config, dataset, debug, utils).
!ls -la /content/drive/MyDrive/metagenomic-binning/nlp-bimeta

total 43
-rw------- 1 root root 26881 Sep 10 23:36 bertbimeta.ipynb
drwx------ 2 root root  4096 Sep  8 08:15 config
drwx------ 2 root root  4096 Sep  8 08:15 dataset
drwx------ 2 root root  4096 Sep  8 08:15 debug
drwx------ 2 root root  4096 Sep  8 08:15 utils


# Begin bert-bimeta

In [None]:
import glob, os, time, sys
import json
import numpy as np
from collections import defaultdict
import argparse

from dataset.genome import GenomeDataset
from dataset.utils import load_meta_reads, build_bert_corpus_v2
import dataset.bertds as bertds
import utils.utils as utils
from debug.visualize import get_group_label, visualize
from utils.metrics import genome_acc, group_precision_recall

import torch
from bert_pytorch.dataset import BERTDataset, WordVocab
from torch.utils.data import DataLoader

from sklearn.cluster import KMeans

import tqdm

In [None]:
# Run configuration: input/output locations and which dataset(s) to process.
# All relative paths are resolved against the current working directory (the
# project root entered via %cd above).

# Put the current directory on the import path.
# NOTE(review): the project-local imports already ran in the previous cell, so
# they rely on the working directory being importable before this line executes.
sys.path.append('.')

# Directory holding the input FASTA files (<name>.fna).
DATASET_DIR = '../data/input/'
# Dataset selection: a single dataset name (e.g. 'S1'), a list of names,
# or the string 'all' to pick up every .fna file in DATASET_DIR.
DATASET_NAME = 'S1'
# Directory passed to utils.load_groups_seeds to read the groups/seeds.
BIMETAOUT_DIR = '../data/output/bimetaout/'
# Directory for the BERT corpus, vocabulary and model files.
BERTBASE_DIR = '../data/output/bertbase/'
# NOTE(review): not referenced by the cells shown here; presumably reserved for final results.
BERTBIMETAOUT_DIR = '../data/output/bertbimetaout/'

In [None]:
# Resolve DATASET_NAME into the list of FASTA files to process:
#   'all'          -> every *.fna file found in DATASET_DIR
#   list / tuple   -> one '<name>.fna' path per listed name
#   anything else  -> treated as a single dataset name
# Paths are built with os.path.join in every branch so they look the same
# regardless of how the dataset was selected.
if DATASET_NAME == 'all':
    raw_datasets = glob.glob(os.path.join(DATASET_DIR, '*.fna'))
elif isinstance(DATASET_NAME, (list, tuple)):
    raw_datasets = [os.path.join(DATASET_DIR, ds_name + '.fna') for ds_name in DATASET_NAME]
else:
    raw_datasets = [os.path.join(DATASET_DIR, DATASET_NAME + '.fna')]

# Fail here with a clear message rather than with an IndexError further down
# when the selection matched nothing.
if not raw_datasets:
    raise FileNotFoundError(
        'No dataset selected: DATASET_NAME=%r matched nothing in %r' % (DATASET_NAME, DATASET_DIR))

# Mapping of dataset and its corresponding number of clusters
with open('config/dataset_metadata.json', 'r') as f:
    n_clusters_mapping = json.load(f)['datasets']

# Deterministic processing order independent of filesystem listing order.
raw_datasets.sort()

In [None]:
# Inspect the resolved list of dataset paths.
raw_datasets

['../data/input/S1.fna']

In [None]:
# Work on the first resolved dataset only.
dataset = raw_datasets[0]
# Dataset name = file name up to the first occurrence of '.fna'.
dataset_name = os.path.basename(dataset).split('.fna')[0]

print("-------------------------------------------------------")
print('Processing dataset: ', dataset_name)

# Path of the per-dataset JSON under BIMETAOUT_DIR.
# NOTE(review): this variable is not referenced again in the cells shown here.
bimetaout_file = os.path.join(BIMETAOUT_DIR, dataset_name + '.json')

# Expected number of clusters for this dataset, taken from the
# 'datasets' section of config/dataset_metadata.json; the lookup fails if the
# dataset has no entry there.
n_clusters = n_clusters_mapping[dataset_name]

print('Prior number of clusters: ', n_clusters)

-------------------------------------------------------
Processing dataset:  S1
Prior number of clusters:  2


In [None]:
# t0 marks the start of this step for the elapsed-time report below.
t0 = time.time()
# Load group file (phase 1 of bimeta) according to dataset_name
print('Loading groups/seeds: ...')
# `groups` and `seeds` are produced by the project helper; later cells iterate
# each seed as a collection of read indices.
# NOTE(review): presumably groups[i] and seeds[i] describe the same group — confirm in utils.
groups, seeds = utils.load_groups_seeds(BIMETAOUT_DIR, dataset_name)
print('Total number of groups: ', len(groups))
print('Time to load groups: ', (time.time() - t0))

Loading groups/seeds: ...
Total number of groups:  152
Time to load groups:  1.1665575504302979


In [None]:
# Read fasta dataset
# t1 marks the start of this step for the elapsed-time report below.
t1 = time.time()
print('Loading reads ...')
# Returns the reads and their labels; the read count printed below is taken
# from len(labels).
# NOTE(review): presumably reads and labels are aligned one-to-one — confirm in dataset.utils.
reads, labels = load_meta_reads(dataset, type='fasta')
print('Total number of reads: ', len(labels))
print('Time to load reads: ', (time.time() - t1))

Loading reads ...
Total number of reads:  96367
Time to load reads:  6.485032320022583


In [None]:
# t2 marks the start of this step for the elapsed-time report below.
t2 = time.time()
# Creating bert corpus...
# Turns the reads into text lines for BERT using k-mers of length 4.
# NOTE(review): judging by the saved file, each line is a space-separated run of
# overlapping 4-mers, with a tab between the two parts of a read — confirm in dataset.utils.
corpus = build_bert_corpus_v2(reads, k_mer_length=4)
print('Time to create corpus from reads: ', (time.time() - t2))

Time to create corpus from reads:  3.588386058807373


In [None]:
# t3 marks the start of this step for the elapsed-time report below.
t3 = time.time()
print('Save corpus for current dataset ...')
# Toggle for writing the corpus to disk. The file path is defined either way
# because later cells refer to bert_corpus_file.
is_save_each_corpus = True
bert_corpus_file = os.path.join(BERTBASE_DIR, dataset_name + '.bert.corpus.txt')
if is_save_each_corpus:
    # The output directory may not exist yet on a fresh checkout/Drive.
    os.makedirs(BERTBASE_DIR, exist_ok=True)
    # writelines adds no separators: each corpus entry is written as-is.
    with open(bert_corpus_file, 'w') as f:
        f.writelines(corpus)
    # Report only after the file has been closed (and therefore flushed).
    print('Saved bert corpus to ', bert_corpus_file)
print('Time to save corpus for current dataset: ', (time.time() - t3))

Save corpus for current dataset ...
Saved bert corpus to  ../data/output/bertbase/S1.bert.corpus.txt
Time to save corpus for current dataset:  1.5450563430786133


In [None]:
!cat '../data/output/bertbase/S1.bert.corpus.txt' | head -2

ATAA TAAT AATT ATTG TTGG TGGC GGCA GCAA CAAG AAGT AGTG GTGT TGTT GTTT TTTT TTTA TTAG TAGT AGTC GTCT TCTT CTTA TTAG TAGA AGAG GAGA AGAG GAGA AGAT GATT ATTC TTCT TCTC CTCT TCTA CTAA TAAG AAGT AGTC GTCT TCTA CTAA TAAC AACT ACTT CTTG TTGA TGAA GAAC AACT ACTC CTCA TCAA CAAT AATT ATTT TTTG TTGG TGGA GGAA GAAT AATC ATCA TCAT CATT ATTT TTTC TTCC TCCC CCCA CCAA CAAT AATT ATTT TTTT TTTT TTTA TTAT 	 TATT ATTT TTTC TTCA TCAA CAAA AAAC AACA ACAC CACT ACTT CTTT TTTA TTAC TACA ACAC CACC ACCT CCTC CTCT TCTA CTAC TACC ACCA CCAT CATT ATTC TTCA TCAT CATT ATTC TTCA TCAA CAAT AATT ATTG TTGG TGGA GGAT GATC ATCA TCAC CACA ACAA CAAA AAAT AATA ATAC TACA ACAG CAGA AGAG GAGC AGCA GCAG CAGT AGTG GTGT TGTA GTAT TATT ATTT TTTG TTGA TGAG GAGA AGAT GATA ATAT TATC ATCC TCCT CCTG CTGA TGAA GAAA AAAG AAGA AGAT
TAAT AATT ATTA TTAG TAGT AGTT GTTA TTAG TAGG AGGT GGTA GTAA TAAA AAAG AAGG AGGA GGAA GAAC AACC ACCT CCTT CTTG TTGT TGTT GTTA TTAA TAAT AATA ATAA TAAG AAGA AGAC GACT ACTA CTAG TAGG AGGT GGTT GTTT TTTT TTTA TTAT TAT

In [None]:
!bert-vocab -c '../data/output/bertbase/S1.bert.corpus.txt' -o '../data/output/bertbase/vocab.txt'

Building Vocab
96367it [00:06, 15754.28it/s]
VOCAB SIZE: 261


In [None]:
!bert --train_dataset '../data/output/bertbase/S1.bert.corpus.txt' \
    --vocab_path '../data/output/bertbase/vocab.txt' \
    --output_path '../data/output/bertbase/S1.2.bert.model' \
    --epochs 2 \
    --hidden 32 \
    --layers 1 \
    --attn_heads 4 \
    --seq_len 64 \
    --batch_size 64 \
    --lr 1e-4 \
    --log_freq 200

Loading Vocab ../data/output/bertbase/vocab.txt
Vocab Size:  261
Loading Train Dataset ../data/output/bertbase/S1.bert.corpus.txt
Loading Dataset: 96367it [00:00, 225053.76it/s]
Loading Test Dataset None
Creating Dataloader
  cpuset_checked))
Building BERT model
Creating BERT Trainer
Total Parameters: 29831
Training Start
{'epoch': 0, 'iter': 0, 'avg_loss': 6.877734184265137, 'avg_acc': 57.8125, 'loss': 6.877734184265137}
{'epoch': 0, 'iter': 200, 'avg_loss': 6.72452566991398, 'avg_acc': 50.108830845771145, 'loss': 6.443744659423828}
{'epoch': 0, 'iter': 400, 'avg_loss': 6.369413800370366, 'avg_acc': 50.02727556109726, 'loss': 5.746042251586914}
{'epoch': 0, 'iter': 600, 'avg_loss': 6.119032285376119, 'avg_acc': 49.963602329450914, 'loss': 5.500411510467529}
{'epoch': 0, 'iter': 800, 'avg_loss': 5.928880637951112, 'avg_acc': 49.992197253433204, 'loss': 5.1761088371276855}
{'epoch': 0, 'iter': 1000, 'avg_loss': 5.782960463475276, 'avg_acc': 49.93444055944056, 'loss': 5.190054416656494}


In [None]:
# Compute one embedding vector per seed: run every read of the seed through the
# trained BERT encoder, average over the seed's reads, then over token positions.
# t4 marks the start of this step (model loading included) for the timing below.
t4 = time.time()
print('Load bert model ...')

# Choose the device first so the checkpoint can be mapped straight onto it;
# without map_location a checkpoint saved from GPU fails to load on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint written by the training cell (prefix + '.ep1'), derived from the
# configuration so it follows the selected dataset.
# NOTE(security): torch.load unpickles arbitrary objects — only load checkpoints
# produced by a trusted training run.
bert_model_file = os.path.join(BERTBASE_DIR, dataset_name + '.2.bert.model.ep1')
bert_model = torch.load(bert_model_file, map_location=device)

bert_vocab_file = os.path.join(BERTBASE_DIR, 'vocab.txt')
vocab = WordVocab.load_vocab(bert_vocab_file)
# seq_len matches the --seq_len used for training (64).
bert_dataset = bertds.BERTDataset_ForInference(bert_corpus_file, vocab, seq_len=64, on_memory=True)
# NOTE(review): the loader is not used by the loop below, which indexes
# bert_dataset directly; kept in case other cells rely on it.
bert_data_loader = DataLoader(bert_dataset, batch_size=1, num_workers=2)

bert_model = bert_model.to(device) # device=cuda --> infer faster
bert_model = bert_model.eval()

print('Get bert embeddings for read and seed ...')
all_seed_embeddings = []
for seed in seeds:
    # each element of this list is a read embedding
    seed_embedding = []
    for read_idx in seed:
        data = bert_dataset[read_idx]

        # Forward a batch of one read through the encoder; no gradients are needed.
        with torch.no_grad():
            bert_input = torch.unsqueeze(data["bert_input"], 0).to(device)
            segment_label = torch.unsqueeze(data["segment_label"], 0).to(device)
            # presumably (1, seq_len, hidden) = (1, 64, 32) given seq_len=64 here
            # and --hidden 32 at training time — TODO confirm
            read_embedding = bert_model(bert_input, segment_label)
            # drop the batch dimension -> presumably (seq_len, hidden)
            read_embedding = read_embedding.detach().cpu().numpy()[0]
            seed_embedding.append(read_embedding)

    # Stack the read embeddings and average over the reads of this seed:
    # (n_reads_in_seed, seq_len, hidden) -> (seq_len, hidden)
    seed_embedding = np.array(seed_embedding)
    seed_embedding = np.mean(seed_embedding, axis=0)

    # Continue to average across token positions: (seq_len, hidden) -> (hidden,)
    seed_embedding = np.mean(seed_embedding, axis=0)
    all_seed_embeddings.append(seed_embedding)

print(len(all_seed_embeddings))
print('Time to get bert embeddings for read and seed: ', (time.time() - t4))

Load bert model ...


Loading Dataset: 96367it [00:00, 243568.54it/s]


Get bert embeddings for read and seed ...
152
Time to get bert embeddings for read and seed:  41.27671790122986


In [None]:
# Clustering groups
# Cluster the per-seed embeddings into the prior number of clusters with KMeans.
# t6 marks the start of this step for the elapsed-time report below.
t6 = time.time()
print('Clustering ...')
skmeans = KMeans(
    init="random",
    n_clusters=n_clusters,
    n_init=100,
    max_iter=200,
    # fixed seed so repeated runs produce the same clustering
    random_state=20210905)
# KMeans is unsupervised and ignores `y`, so none is passed. The per-read
# `labels` do not line up with the per-seed rows of X anyway and would only
# suggest a supervised fit that does not happen.
skmeans.fit(X=all_seed_embeddings)
# One predicted cluster id per seed/group.
y_pred_skmeans = skmeans.predict(X=all_seed_embeddings)
print('Clustering time: ', (time.time() - t6))

Clustering ...
Clustering time:  0.27023887634277344


In [None]:
# Map read to group and compute F-measure
# t7 marks the start of this step for the elapsed-time report below.
t7 = time.time()
print('Compute F-measure ...')
# Only the first element of the helper's result is used and reported as the
# group precision. NOTE(review): confirm the tuple layout in utils.metrics.
groupPrec = group_precision_recall(labels, groups, n_clusters)[0]
# The third element of genome_acc's result is reported as the F1-score; it is
# computed from the groups, their predicted cluster ids and the per-read labels.
# NOTE(review): confirm the tuple layout in utils.metrics.
f1s = genome_acc(groups, y_pred_skmeans, labels, n_clusters)[2]
print('Group precision: ', groupPrec)
print('F1-score (using seed): ', f1s)
print('Compute F-measure time: ', (time.time() - t7))


Compute F-measure ...
Group precision:  0.9898824286322081
F1-score (using seed):  0.700545087530471
Compute F-measure time:  0.435455322265625
