In [1]:
import sys
sys.path.append('../src/') # for graph_data

import graph_data as gd
import time
import torch
from models import TreeSupport
from torch import optim
from torch.nn import MSELoss
from torch.utils.data import random_split, ConcatDataset
from torch_geometric.data import DataLoader
from Bio import Phylo as phy

import pandas as pd
import numpy as np
import altair as alt

alt.data_transformers.enable('data_server')
alt.renderers.enable('mimetype')

torch.manual_seed(245)

<torch._C.Generator at 0x7f2b66aed750>

# Datasets
## Prepare trees and reconstructed sequences

In [2]:
data_path = '../data/'

In [3]:
# Fungi data set: read the reference tree (newick) and build the tree data set from the
# tree directory and the alignment directory.
fungi_reference_path = data_path+'tree/Fungi45_infer/Fungi.tre'
target_tree = phy.read(fungi_reference_path, 'newick')
#all_data = gd.Trees.load_ready_trees(data_path+'tree/Fungi45_infer/fml_output/', data_path+'tree/Fungi45_infer/fml_output/*.tre', target_tree)  # INITIALIZE NEW DS
fungi_data = gd.Trees(
    data_path+'tree/Fungi45_infer/',
    data_path+'alns/Fungi45/',
    target_tree,
)

# Normalise the dtypes of the stored tensors: x -> float32, edge_index -> int64.
fungi_store = fungi_data.data
fungi_store.x = fungi_store.x.float()
# Flip the edge attributes around their maximum: the largest original value becomes
# 0.001 and smaller original values become proportionally larger.
fungi_edge_attr = fungi_store.edge_attr
fungi_store.edge_attr = (torch.max(fungi_edge_attr)+0.001 - fungi_edge_attr).float()
fungi_store.edge_index = fungi_store.edge_index.long()

In [4]:
# Archaea data set: read the reference tree (newick) and build the tree data set from the
# tree directory and the alignment directory. Note that `target_tree` is rebound here.
archaea_reference_path = data_path+'tree/Archaea/Archaea.tre'
target_tree = phy.read(archaea_reference_path, 'newick')
#archaea_data = gd.Trees.load_ready_trees(data_path+'tree/Archaea/fml_output/', data_path+'tree/Archaea/fml_output/*.tre', target_tree)  # INITIALIZE NEW DS //delete temp_tree
archaea_data = gd.Trees(
    data_path+"tree/Archaea/",
    data_path+"alns/Archaea/",
    target_tree,
)

# Normalise the dtypes of the stored tensors: x -> float32, edge_index -> int64.
archaea_store = archaea_data.data
archaea_store.x = archaea_store.x.float()
# Flip the edge attributes around this data set's own maximum: the largest original
# value becomes 0.001 and smaller original values become proportionally larger.
archaea_edge_attr = archaea_store.edge_attr
archaea_store.edge_attr = (torch.max(archaea_edge_attr)+0.001 - archaea_edge_attr).float()
archaea_store.edge_index = archaea_store.edge_index.long()

In [5]:
# Chain the Fungi and Archaea data sets (in that order) into one training pool.
component_datasets = [fungi_data, archaea_data]
all_data = ConcatDataset(component_datasets)

In [6]:
# Batch sizes used by the data loaders further down.
train_batch = test_batch = 50

# Random 80/20 split of the pooled data. The test size is taken as the remainder,
# which is equal to round(n * 0.2) for every integer n and always makes the two
# lengths add up to len(all_data), as random_split requires.
n_total = len(all_data)
n_train = round(n_total * 0.8)
train, test = random_split(all_data, [n_train, n_total - n_train])

In [7]:
# For every training sample compute Q = -(sum(log2(y)) / len(y)), i.e. the negated
# mean base-2 logarithm of its targets, then plot the distribution of Q.
neg_mean_log2 = np.array(
    [
        -(torch.log2(sample.y.squeeze()).sum() / len(sample.y)).item()
        for sample in train
    ],
    dtype=float,
)
lst = pd.DataFrame(neg_mean_log2, columns=['Q'])

q_axis = alt.X("Q:Q", bin=alt.Bin(maxbins=100))
alt.Chart(lst).mark_bar().encode(q_axis, y='count()')

<VegaLite 4 object>

If you see this message, it means the renderer has not been properly enabled
for the frontend that you are using. For more information, see
https://altair-viz.github.io/user_guide/troubleshooting.html


# Train model

In [8]:
# Batched loaders over the two splits.
# NOTE(review): the training loader does not pass shuffle=True, so batches are visited
# in the same order every epoch - confirm this is intended.
train_dl = DataLoader(train, num_workers=3, pin_memory=True, batch_size=train_batch)
test_dl = DataLoader(test, num_workers=2, batch_size=test_batch)

In [9]:
# Model on the GPU. NOTE(review): the meaning of the two constructor arguments
# (231, 400) is defined by TreeSupport and is not visible here.
model = TreeSupport(231, 400).cuda()

# Run settings.
silent = False       # when False, the training loop prints one line per epoch
num_epochs = 200     # overridden again by the training cell
loss_fn = MSELoss()
test_batches = len(test_dl)   # number of validation batches, used to average the loss

# Per-step / per-epoch history filled in by the training loop.
losses, val_losses, learning_rates = [], [], []

In [10]:
# Train `model` with plain SGD; the learning rate is halved whenever the validation
# loss has not improved for `patience` epochs (ReduceLROnPlateau).
#
# Reads:   model, train_dl, test_dl, loss_fn, silent
# Appends: losses         - training loss of every optimisation step
#          learning_rates - learning rate in effect for every optimisation step
#          val_losses     - mean validation loss, one entry per epoch
optimizer = optim.SGD(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, verbose=True, cooldown=2, patience=5
)
num_epochs = 600
device = torch.device('cuda')
start = time.time()
for epoch in range(num_epochs):
    model.train()
    for data in train_dl:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = loss_fn(out, data.y)
        losses.append(loss.item())  # history
        # optimize
        loss.backward()
        optimizer.step()

        # stats: record the learning rate actually used for this step (the scheduler
        # lowers it over time), instead of a constant placeholder
        learning_rates.append(optimizer.param_groups[0]['lr'])

    # evaluation
    model.eval()
    val_loss_sum = 0.0
    n_val_batches = 0
    with torch.no_grad():
        for data in test_dl:
            data = data.to(device)
            out = model(data)
            # accumulate plain floats so no tensor is kept alive across batches
            val_loss_sum += loss_fn(out, data.y).item()
            n_val_batches += 1
    if n_val_batches == 0:
        raise ValueError("test_dl yielded no batches; cannot compute a validation loss")
    # Average over the batches actually seen, so the value stays correct even if
    # `test_dl` was rebound after `test_batches` was computed in another cell.
    val_loss = val_loss_sum / n_val_batches
    val_losses.append(val_loss)
    # The scheduler's default relative threshold is scale-invariant, so stepping on the
    # mean makes the same decisions as stepping on the summed loss.
    scheduler.step(val_loss)
    if not silent:
        print(
            "Epoch [{}/{}], Loss (last training batch/val): {:.4f}/{:.4f}. Time elapsed: {:.2f}".format(
                epoch + 1,
                num_epochs,
                losses[-1],
                val_losses[-1],
                time.time() - start,
            )
        )

Epoch [1/600], Loss (last training batch/val): 0.2635/0.2633. Time elapsed: 2.58
Epoch [2/600], Loss (last training batch/val): 0.2527/0.2518. Time elapsed: 4.77
Epoch [3/600], Loss (last training batch/val): 0.2472/0.2466. Time elapsed: 6.94
Epoch [4/600], Loss (last training batch/val): 0.2413/0.2415. Time elapsed: 9.36
Epoch [5/600], Loss (last training batch/val): 0.2353/0.2359. Time elapsed: 11.54
Epoch [6/600], Loss (last training batch/val): 0.2297/0.2305. Time elapsed: 13.76
Epoch [7/600], Loss (last training batch/val): 0.2259/0.2270. Time elapsed: 15.98
Epoch [8/600], Loss (last training batch/val): 0.2203/0.2218. Time elapsed: 18.19
Epoch [9/600], Loss (last training batch/val): 0.2141/0.2152. Time elapsed: 20.38
Epoch [10/600], Loss (last training batch/val): 0.2090/0.2098. Time elapsed: 22.57
Epoch [11/600], Loss (last training batch/val): 0.2030/0.2042. Time elapsed: 24.75
Epoch [12/600], Loss (last training batch/val): 0.2172/0.2055. Time elapsed: 26.95
Epoch [13/600], L

# Learning stats

In [11]:
# One row per epoch: take the entry recorded for the last training batch of every epoch
# from the per-step histories and pair it with the per-epoch validation loss.
batches_per_epoch = len(train_dl)
last_batch_of_epoch = slice(batches_per_epoch - 1, None, batches_per_epoch)
adata = pd.DataFrame(
    {
        "LR": learning_rates[last_batch_of_epoch],
        "Training": losses[last_batch_of_epoch],
        "Validation": val_losses,
    }
).reset_index()

# Training and validation loss curves over the epochs.
loss_curves = alt.Chart(adata, width=900, height=600).transform_fold(["Training", "Validation"])
loss_curves.mark_line().encode(
    x=alt.X("index:Q", axis=alt.Axis(title='Epoch')),
    y=alt.Y("value:Q", title="Loss"),
    color='key:N',
)

<VegaLite 4 object>

If you see this message, it means the renderer has not been properly enabled
for the frontend that you are using. For more information, see
https://altair-viz.github.io/user_guide/troubleshooting.html


In [12]:
# Persist the trained weights (state dict only, not the full module object).
trained_weights = model.state_dict()
torch.save(trained_weights, '../models/NewTry.dct')

In [13]:
from sklearn import metrics

## Test with completely different data

In [14]:
# Independent evaluation data, prepared the same way as the training data sets.
# Note: this rebinds both `target_tree` and `test_dl` (the latter was the validation
# loader used by the training loop).
eukaryota_reference_path = data_path+'tree/Other_eukaryota_2018.tre'
target_tree = phy.read(eukaryota_reference_path, 'newick')
test_dataset = gd.Trees(data_path+"tree/Eukaryota/", data_path+"alns/other_eukaryota_2018/", target_tree)

# Normalise the dtypes of the stored tensors: x -> float32, edge_index -> int64.
eval_store = test_dataset.data
eval_store.x = eval_store.x.float()
# Flip the edge attributes around this data set's own maximum (largest value -> 0.001).
eval_edge_attr = eval_store.edge_attr
eval_store.edge_attr = (torch.max(eval_edge_attr)+0.001 - eval_edge_attr).float()
eval_store.edge_index = eval_store.edge_index.long()

test_dl = DataLoader(test_dataset, num_workers=2, batch_size=test_batch)

In [31]:
# Shape of all collected result rows stacked along axis 0.
# NOTE(review): `nplist` is first assigned in the cell *below* this one; the execution
# counts (In [31] here, In [24] below) show the cells were run out of order, so this
# cell raises NameError on a fresh Restart & Run All. Move it after the inference cell.
np.concatenate(nplist).shape

(271930, 2)

In [24]:
# Run the model over every batch of `test_dl` and collect one array per batch whose
# columns are [actual (data.y), predicted (model output)].
device = torch.device('cuda')
# Set inference mode explicitly rather than relying on the training cell having
# finished with model.eval().
model.eval()
nplist = list()
with torch.no_grad():
    for data in test_dl:
        data.x, data.edge_index, data.edge_attr = data.x.float(), data.edge_index.long(), data.edge_attr.float()  # CLEAN IT!!!!!
        # Rebind the moved batch so the code does not depend on .to() working in place.
        data = data.to(device)
        out = model(data)
        nplist.append(np.concatenate([data.y.cpu().numpy(), out.detach().cpu().numpy()], axis=1))

In [32]:
# Stack the per-batch [actual, predicted] arrays into one two-column frame.
stacked_results = np.concatenate(nplist)
test_results = pd.DataFrame(stacked_results, columns=['actual', 'predicted'])

In [39]:
# Box plot of predicted against actual values; axes are fixed to [0, 1] and [0, 2],
# marks outside are clipped and outlier points are not drawn.
actual_axis = alt.X('actual', scale=alt.Scale(domain=(0, 1)))
predicted_axis = alt.Y('predicted', scale=alt.Scale(domain=(0, 2)))
alt.Chart(test_results, width=900, height=600).mark_boxplot(clip=True, outliers=False).encode(
    x=actual_axis,
    y=predicted_axis,
)

<VegaLite 4 object>

If you see this message, it means the renderer has not been properly enabled
for the frontend that you are using. For more information, see
https://altair-viz.github.io/user_guide/troubleshooting.html


# Further assessment

In [41]:
# Evaluate tree by tree, keeping only the first `leading_rows` rows of each tree's
# targets and predictions (the original note calls this "first half of the nodes";
# NOTE(review): confirm that 60 really is half for these trees).
leading_rows = 60
cuda_device = torch.device('cuda')
nplist = list()
with torch.no_grad():
    for data in test_dataset:
        # CLEAN IT!!!!! (original note): dtype casts repeated per sample
        data.x = data.x.float()
        data.edge_index = data.edge_index.long()
        data.edge_attr = data.edge_attr.float()
        data = data.to(cuda_device)
        out = model(data)[:leading_rows]
        actual = data.y.cpu()[:leading_rows].numpy()
        predicted = out.detach().cpu().numpy()
        nplist.append(np.concatenate([actual, predicted], axis=1))
test_results = pd.DataFrame(np.concatenate(nplist, 0), columns=['actual', 'predicted'])
test_results.describe()

Unnamed: 0,actual,predicted
count,137880.0,137880.0
mean,0.4496009,0.454507
std,0.4964652,0.43897
min,1.92593e-34,0.0
25%,1.525879e-05,0.0
50%,0.0009765625,0.370812
75%,1.0,0.888189
max,1.0,2.197329


In [42]:
# ROC curve: a row counts as positive when its actual value is exactly 1; the raw
# prediction is used as the score.
is_positive = test_results["actual"] == 1
prediction_score = test_results["predicted"]
fpr, tpr, threshold = metrics.roc_curve(is_positive, prediction_score)
roc_data = pd.DataFrame({"True Positive Rate": tpr, "False Positive Rate": fpr})

In [43]:
# ROC curve: true positive rate against false positive rate.
roc_line = alt.Chart(roc_data, width=900, height=600).mark_line()
roc_line.encode(x="False Positive Rate", y="True Positive Rate")

<VegaLite 4 object>

If you see this message, it means the renderer has not been properly enabled
for the frontend that you are using. For more information, see
https://altair-viz.github.io/user_guide/troubleshooting.html


In [44]:
# Evaluate tree by tree over *all* rows of every tree. The original note here read
# "second half", but nothing is sliced in this cell - every target/prediction pair
# is kept.
cuda_device = torch.device('cuda')
nplist = list()
with torch.no_grad():
    for data in test_dataset:
        # CLEAN IT!!!!! (original note): dtype casts repeated per sample
        data.x = data.x.float()
        data.edge_index = data.edge_index.long()
        data.edge_attr = data.edge_attr.float()
        data = data.to(cuda_device)
        out = model(data)
        actual = data.y.cpu().numpy()
        predicted = out.detach().cpu().numpy()
        nplist.append(np.concatenate([actual, predicted], axis=1))
test_results = pd.DataFrame(np.concatenate(nplist, 0), columns=['actual', 'predicted'])
test_results.describe()

Unnamed: 0,actual,predicted
count,271930.0,271930.0
mean,0.5098314,0.513201
std,0.4988907,0.437022
min,1.92593e-34,0.0
25%,6.103516e-05,0.0
50%,1.0,0.509334
75%,1.0,0.898723
max,1.0,2.197329


In [45]:
# Precision/recall curve: a row counts as positive when its actual value is exactly 1;
# the raw prediction is used as the score.
# Note: the result is stored under the existing name `roc_data`, replacing the ROC
# table, because the plotting cell below reads that name.
is_positive = test_results["actual"] == 1
prediction_score = test_results["predicted"]
prec, rec, threshold = metrics.precision_recall_curve(is_positive, prediction_score)
roc_data = pd.DataFrame({"Precision": prec, "Recall": rec})

In [46]:
# Precision/recall curve: precision against recall.
pr_line = alt.Chart(roc_data, width=900, height=600).mark_line()
pr_line.encode(x="Recall", y="Precision")

<VegaLite 4 object>

If you see this message, it means the renderer has not been properly enabled
for the frontend that you are using. For more information, see
https://altair-viz.github.io/user_guide/troubleshooting.html


In [47]:
# Summary statistics of the current evaluation table. `test_results` has not been
# reassigned since the previous summary, so this repeats the same numbers.
test_results.describe()

Unnamed: 0,actual,predicted
count,271930.0,271930.0
mean,0.5098314,0.513201
std,0.4988907,0.437022
min,1.92593e-34,0.0
25%,6.103516e-05,0.0
50%,1.0,0.509334
75%,1.0,0.898723
max,1.0,2.197329


In [48]:
# Confusion matrix for "actual == 1" against "prediction above the cut-off".
# NOTE(review): how the 0.628 cut-off was chosen is not shown in this notebook.
decision_threshold = 0.628
is_positive = test_results["actual"] == 1
predicted_positive = test_results["predicted"] > decision_threshold
metrics.confusion_matrix(is_positive, predicted_positive)

array([[133463,    128],
       [  8434, 129905]])

In [None]:
# For every collected [actual, predicted] array, print the mean of -log2(actual) over
# the rows where that value is strictly positive (i.e. actual < 1).
# The logarithm is taken of the `actual` column only, and only once: the predicted
# column contains zeros, so taking log2 of the whole array produced -inf values and
# divide-by-zero warnings for numbers that were discarded anyway.
for pairs in nplist:
    neg_log2_actual = -np.log2(pairs[:, 0])
    print(np.average(neg_log2_actual[neg_log2_actual > 0]))

In [None]:
# Scratch check: value clamped at zero from below, for y = 1.
y = 1
max(0, -100 * y * (1 - 0))