In [1]:
# imports
import pandas as pd
from graph import Graph
from copy import deepcopy
import numpy as np

In [2]:
# read in data
df_abundance = pd.read_csv('df_phylum.csv')
metadata = pd.read_csv("samp_metadata.csv")

In [3]:
# Phylum level data
df_phylum = df_abundance.groupby('Phylum').sum()
# remove columns that is not start with "Samp"
df_phylum = df_phylum[df_phylum.columns[df_phylum.columns.str.startswith('Samp')]]
df_phylum.head()

Unnamed: 0_level_0,Samp001,Samp002,Samp003,Samp004,Samp005,Samp006,Samp007,Samp008,Samp009,Samp010,...,Samp231,Samp232,Samp233,Samp234,Samp235,Samp236,Samp237,Samp238,Samp239,Samp240
Phylum,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Abditibacteriota,6,0,3,4,2,1,5,9,7,7,...,31,38,5,5,8,21,4,8,6,9
Acidobacteriota,734,63,570,498,327,211,1188,647,643,1354,...,898,519,2225,1128,1539,2033,2677,2810,2264,1758
Actinomycetota,392,32,322,480,120,107,446,332,326,460,...,4797,2350,2582,1548,3249,4785,3051,3052,2969,2875
Armatimonadota,31,0,20,9,5,7,43,27,34,48,...,56,27,63,27,36,74,55,37,65,59
BRC1_bacterium_SCGC_AAA252-M09,19,0,9,5,7,5,27,19,16,26,...,33,15,21,20,15,16,27,22,42,24


In [4]:
len(df_phylum)

37

In [5]:
# test graph
tmp = Graph()

In [6]:
tmp.build_graph()

In [7]:
tmp.prune_graph(df_phylum)

Pruning Tree...


In [8]:
def generate_maps(x, g, f, p=-1):
	temp_g = deepcopy(g)
	temp_g.populate_graph(f, x)
	map = temp_g.get_map()
	vector = temp_g.graph_vector()
	del(temp_g)
	return x, np.array(map), np.array(vector)

In [9]:
data_matrices = []

for i in range(len(df_phylum.T)):
    data_matrices.append(generate_maps(df_phylum.T.values[i], tmp, df_phylum))

Abditibacteriota
Acidobacteriota
Actinomycetota
Armatimonadota
BRC1_bacterium_SCGC_AAA252-M09
Bacillota
Bacteroidia
Candidatus_Dadabacteria
Candidatus_Dependentiae
Candidatus_Diapherotrites
Candidatus_Hydrogenedentota
Candidatus_Latescibacterota
Candidatus_Omnitrophota
Candidatus_Rokuibacteriota
Candidatus_Tectomicrobia
Candidatus_Zixiibacteriota
Chlamydiota
Chloroflexota
Cyanobacteriota
Deinococcota
Elusimicrobia
Euryarchaeota
FCB_group
Gemmatimonadota
Kiritimatiellota
Mycoplasmatota
Nanoarchaeota
Nematoda
Nitrososphaerota
Nitrospirota
Planctomycetota
Pseudomonadota
Rhodophyta
Spirochaetota
Thermoprotei
Verrucomicrobiota
candidate_division_CPR1
Abditibacteriota
Acidobacteriota
Actinomycetota
Armatimonadota
BRC1_bacterium_SCGC_AAA252-M09
Bacillota
Bacteroidia
Candidatus_Dadabacteria
Candidatus_Dependentiae
Candidatus_Diapherotrites
Candidatus_Hydrogenedentota
Candidatus_Latescibacterota
Candidatus_Omnitrophota
Candidatus_Rokuibacteriota
Candidatus_Tectomicrobia
Candidatus_Zixiibacterio

In [11]:
metadata.drop(columns=['Unnamed: 0'], inplace=True)
# 0 = non-fumigated, 1 = recently fumigated, 2 = fumigated more than a month ago
metadata["Fumigation"] = None

for i, row in enumerate(metadata.itertuples()):
    if row.Time == "Day_0" or row.Treatment == "Non-fumigated chipping grass":
        metadata.at[i, "Fumigation"] = 0
    elif row.Time == "Day_10" and row.Treatment != "Non-fumigated chipping grass":
        metadata.at[i, "Fumigation"] = 1
    else:
        metadata.at[i, "Fumigation"] = 2

metadata

Unnamed: 0,Samp,samp_number,Time,Treatment,Date,Fumigation
0,Fum-1-1,1,Day_0,Non-amended,2016-08-08,0
1,Fum-1-2,2,Day_0,Non-amended,2016-08-08,0
2,Fum-1-3,3,Day_0,Non-amended,2016-08-08,0
3,Fum-1-4,4,Day_0,Non-amended,2016-08-08,0
4,Fum-1-5,5,Day_0,Non-amended,2016-08-08,0
...,...,...,...,...,...,...
235,Fum-1-236,236,Day_282,Non-fumigated chipping grass,2017-05-17,0
236,Fum-1-237,237,Day_282,Non-fumigated chipping grass,2017-05-17,0
237,Fum-1-238,238,Day_282,Non-fumigated chipping grass,2017-05-17,0
238,Fum-1-239,239,Day_282,Non-fumigated chipping grass,2017-05-17,0


In [12]:
y = metadata["Fumigation"].values.astype(int)

In [26]:
X_val = [mat[1] for mat in data_matrices]

In [27]:
# build the CNN and train test split
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X_val, y, test_size=0.2, random_state=42)

In [28]:
# dataset and dataloader
from torch.utils.data import Dataset, DataLoader
import torch

class MicrobiomeDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
    
trainDataset = MicrobiomeDataset(X_train, y_train)
testDataset = MicrobiomeDataset(X_test, y_test)

In [29]:
train_loader = DataLoader(trainDataset, batch_size=8, shuffle=True)
test_loader = DataLoader(testDataset, batch_size=8, shuffle=False)

In [73]:
# build the CNN
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNMicrobiome(nn.Module):
    def __init__(self):
        super(CNNMicrobiome, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)
        
        self.flatten = nn.Flatten()
        
        self.fc1 = nn.Linear(768, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(64, 3)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu3(self.fc1(x))
        x = self.relu4(self.fc2(x))
        x = self.fc3(x)
        return x

In [74]:
# training the model
model = CNNMicrobiome()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
EPOCHS = 100
log = {"train_loss": [], "test_loss": [], "train_acc": [], "test_acc": []}

In [75]:
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data.unsqueeze(1).float())
        loss = loss_fn(output, target)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_correct += (output.argmax(1) == target).type(torch.float).sum().item()
        train_total += len(target)
    train_loss /= train_total
    train_acc = train_correct / train_total
    
    model.eval()
    test_loss = 0
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data.unsqueeze(1).float())
            loss = loss_fn(output, target)
            test_loss += loss.item()
            test_correct += (output.argmax(1) == target).type(torch.float).sum().item()
            test_total += len(target)
    test_loss /= test_total
    test_acc = test_correct / test_total
    
    log["train_loss"].append(train_loss)
    log["test_loss"].append(test_loss)
    log["train_acc"].append(train_acc)
    log["test_acc"].append(test_acc)
    print(train_correct, train_total)
    print(test_correct, test_total)
    
    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

104.0 192
26.0 48
Epoch 1/100, Train Loss: 4.0926, Train Acc: 0.5417, Test Loss: 7.7410, Test Acc: 0.5417
142.0 192
42.0 48
Epoch 2/100, Train Loss: 1.7599, Train Acc: 0.7396, Test Loss: 0.0974, Test Acc: 0.8750
186.0 192
45.0 48
Epoch 3/100, Train Loss: 0.0149, Train Acc: 0.9688, Test Loss: 0.0192, Test Acc: 0.9375
186.0 192
48.0 48
Epoch 4/100, Train Loss: 0.0252, Train Acc: 0.9688, Test Loss: 0.0027, Test Acc: 1.0000
185.0 192
48.0 48
Epoch 5/100, Train Loss: 0.0384, Train Acc: 0.9635, Test Loss: 0.0028, Test Acc: 1.0000
191.0 192
47.0 48
Epoch 6/100, Train Loss: 0.0051, Train Acc: 0.9948, Test Loss: 0.0048, Test Acc: 0.9792
187.0 192
48.0 48
Epoch 7/100, Train Loss: 0.0185, Train Acc: 0.9740, Test Loss: 0.0036, Test Acc: 1.0000
190.0 192
48.0 48
Epoch 8/100, Train Loss: 0.0093, Train Acc: 0.9896, Test Loss: 0.0029, Test Acc: 1.0000
190.0 192
48.0 48
Epoch 9/100, Train Loss: 0.0027, Train Acc: 0.9896, Test Loss: 0.0027, Test Acc: 1.0000
189.0 192
43.0 48
Epoch 10/100, Train Loss: 0.