In [None]:
import os
# Work from the cancer-net checkout; the next cell imports local modules
# (TCGAData, arch.net), which presumably resolve relative to this directory.
# Set CANCER_NET_DIR to run on a machine with a different layout.
os.chdir(os.environ.get("CANCER_NET_DIR", "/mnt/home/cpedersen/Codes/cancer-net/"))

In [None]:
import TCGAData
import torch, torch_geometric.transforms as T, torch.nn.functional as F
import matplotlib.pyplot as plt, numpy as np
from sklearn.metrics import roc_auc_score, roc_curve, auc
from torch_geometric.loader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from arch.net import *
import wandb
import optuna

In [None]:
class Objective(object):
    """Optuna objective that trains GCNNet on the TCGA brain data (LGG vs GBM).

    The dataset and the 75/25 train/test split do not depend on the trial's
    hyperparameters, so they are built once in ``__init__`` and reused by
    every trial. Calling the instance with an Optuna ``trial`` runs one full
    training run and returns the final-epoch test loss.

    Args:
        arch: architecture name; only logged to wandb. NOTE: ``__call__``
            always builds ``GCNNet`` regardless of this value.
        root: dataset root; ``root + "/samples.txt"`` is passed as the sample list.
        rng: integer seed for the train/test permutation.
        batch: batch size used by both loaders.
        epochs: number of training epochs per trial.
        device: torch device (or device string) for the model and batches.
    """

    def __init__(self, arch, root, rng, batch, epochs, device):
        self.arch = arch
        self.root = root
        self.rng = rng
        self.batch = batch
        self.epochs = epochs
        self.device = device

        # Hardcoded off for now; the True branch (presumably list-of-batches
        # data-parallel mode) is kept so train/test mirror the standalone code.
        self.parall = False

        # Built once: the data and split are identical for every trial.
        label_mapping = ["LGG", "GBM"]
        self.dataset = TCGAData.TCGADataset(
            root=self.root,
            files=self.root + "/samples.txt",
            label_mapping=label_mapping,
            gene_graph="brain.geneSymbol.gz",
        )

        split_rng = np.random.default_rng(self.rng)
        rnd_perm = split_rng.permutation(len(self.dataset))
        n_train = 3 * len(self.dataset) // 4
        self.train_indices = list(rnd_perm[:n_train])
        self.test_indices = list(rnd_perm[n_train:])
        self.train_loader = DataLoader(
            self.dataset,
            batch_size=self.batch,
            sampler=SubsetRandomSampler(self.train_indices),
            drop_last=True,
        )
        # Evaluate every test sample: drop_last=True would silently skip a
        # different random remainder (up to batch-1 samples) on every epoch.
        self.test_loader = DataLoader(
            self.dataset,
            batch_size=self.batch,
            sampler=SubsetRandomSampler(self.test_indices),
            drop_last=False,
        )

        assert len(self.train_indices) + len(self.test_indices) == len(
            self.dataset
        ), "Train test split with overlap or unused samples!"

    def _forward(self, data):
        """Run the model on one batch and return ``(output, y)``.

        The output is squeezed, then given back a leading batch dimension if
        squeezing removed it, so a single-sample batch still yields a 2-D
        tensor for ``criterion``/``max(1)``. Assumes the model emits
        (batch, n_classes) — TODO confirm for GCNNet.
        """
        if not self.parall:
            data = data.to(self.device)
        output = self.model(data).squeeze()
        if self.parall:
            y = torch.cat([d.y for d in data]).to(output.device)
        else:
            y = data.y
        if output.dim() == 1:
            output = output.unsqueeze(0)
        return output, y

    def train(self, epoch, report=True):
        """Run one training epoch.

        The learning rate follows a step schedule relative to ``self.lr``:
        it is set to ``0.5 * lr`` at epoch 30 and ``0.1 * lr`` at epoch 60
        (the change persists in the optimizer's param groups).

        Returns:
            (mean per-sample loss, accuracy) as Python floats.
        """
        self.model.train()

        if epoch == 30:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = self.lr * 0.5

        if epoch == 60:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = self.lr * 0.1

        total_loss = 0.0
        correct = 0
        num_samps = 0
        for data in self.train_loader:
            self.optimizer.zero_grad()
            output, y = self._forward(data)
            loss = self.criterion(output, y)
            loss.backward()
            self.optimizer.step()

            # .item() detaches: summing the tensor itself would keep every
            # batch's autograd graph alive for the whole epoch. Weight by
            # batch size because the criterion averages within the batch.
            total_loss += loss.item() * len(y)
            correct += output.max(1)[1].eq(y).sum().item()
            num_samps += len(y)

        mean_loss = total_loss / num_samps
        accuracy = correct / num_samps
        if report:
            print(
                "Epoch: {:02d}, Loss: {:.3g}, Train Acc: {:.4f}".format(
                    epoch, mean_loss, accuracy
                )
            )

        return mean_loss, accuracy

    def test(self):
        """Evaluate on the test split without tracking gradients.

        Returns:
            (mean per-sample loss, accuracy) as Python floats.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        num_samps = 0
        with torch.no_grad():
            for data in self.test_loader:
                output, y = self._forward(data)
                total_loss += self.criterion(output, y).item() * len(y)
                correct += output.max(1)[1].eq(y).sum().item()
                num_samps += len(y)
        return total_loss / num_samps, correct / num_samps

    def __call__(self, trial):
        """Train one model with hyperparameters suggested by ``trial``.

        Logs per-epoch metrics to wandb and returns the final-epoch test loss
        for Optuna to minimize (the ``create_study`` default direction).
        """
        print("Suggesting trial")
        self.lr = trial.suggest_float("lr", 1e-5, 5e-2, log=True)
        # TODO: weight decay ("wd", 1e-8..1e-1, log) and dropout ("dr", 0..0.9)
        # were considered as additional search dimensions.
        print("Suggested trial")

        # Hyperparameters recorded in the wandb run config.
        config = {"learning rate": self.lr,
                  "epochs": self.epochs,
                  "batch size": self.batch,
                  "arch": self.arch}

        print('\nTrial number: {}'.format(trial.number))
        print('lr: {}'.format(self.lr))
        wandb.login()
        wandb.init(project="brain-test", entity="chris-pedersen", config=config)
        test_loss = float("nan")
        try:
            self.model = GCNNet().to(self.device)
            wandb.watch(self.model, log_freq=1)
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
            self.criterion = F.nll_loss
            # range end is inclusive of self.epochs (the original ran epochs-1).
            for epoch in range(1, self.epochs + 1):
                report = epoch % 10 == 0
                train_loss, train_acc = self.train(epoch, report=report)
                test_loss, test_acc = self.test()
                wandb.log({"train loss": train_loss,
                           "test loss": test_loss,
                           "train accuracy": train_acc,
                           "test accuracy": test_acc})
                if report:
                    print("Test Loss: {:.3g}, Acc: {:.4f}".format(test_loss, test_acc))
        finally:
            # Close the run even if training fails so the next trial can init.
            wandb.finish()
        return test_loss

In [None]:
# Settings for every Optuna trial.
arch = "GCN"
batch = 10
rng = 2022
# NOTE: not passed to Objective, which hardcodes its own parall=False.
parall = False
epochs = 100

# GPU when available; later cells also read `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

root = "/mnt/home/sgolkar/projects/cancer-net/data/brain"

## Optuna params
study_name = "optuna/brain_test"  # Unique identifier of the study.
storage_name = "sqlite:///{}.db".format(study_name)
n_trials = 10

# SQLite will not create the parent directory of the .db file, so make sure
# "optuna/" exists before create_study opens the storage.
os.makedirs(os.path.dirname(study_name), exist_ok=True)

# train networks with bayesian optimization
objective = Objective(arch, root, rng, batch, epochs, device)
# NOTE: n_startup_trials (20) > n_trials (10), so a single run of this cell is
# pure random search; TPE only takes over once the persisted study
# (load_if_exists=True) has accumulated 20 trials — presumably completed ones,
# confirm against the TPESampler docs.
sampler = optuna.samplers.TPESampler(n_startup_trials=20)
study = optuna.create_study(study_name=study_name, sampler=sampler, storage=storage_name,
                            load_if_exists=True)
study.optimize(objective, n_trials, gc_after_trial=False)

In [None]:
# Training on brain data using GCN architecture (single run, fixed lr).

# Hyperparameters etc:

arch = "GCN"
batch = 10
parall = False
lr = 0.005

# Defined here too so this cell does not depend on running the Optuna cell first.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

root = "/mnt/home/sgolkar/projects/cancer-net/data/brain"
files = os.path.join(root, "samples.txt")
label_mapping = ["LGG", "GBM"]
dataset = TCGAData.TCGADataset(
    root=root,
    files=files,
    label_mapping=label_mapping,
    gene_graph="brain.geneSymbol.gz",
)

# Same seeded 75/25 split as the Optuna objective. The Generator gets its own
# name instead of overwriting the integer seed `rng`.
split_rng = np.random.default_rng(2022)
rnd_perm = split_rng.permutation(len(dataset))
n_train = 3 * len(dataset) // 4
train_indices = list(rnd_perm[:n_train])
test_indices = list(rnd_perm[n_train:])
train_loader = DataLoader(
    dataset,
    batch_size=batch,
    sampler=SubsetRandomSampler(train_indices),
    drop_last=True,
)
# NOTE(review): drop_last=True skips up to batch-1 test samples per epoch (a
# different random subset each time); the ROC section below uses all of them.
test_loader = DataLoader(
    dataset,
    batch_size=batch,
    sampler=SubsetRandomSampler(test_indices),
    drop_last=True,
)

assert len(train_indices) + len(test_indices) == len(
    dataset
), "Train test split with overlap or unused samples!"

model = GCNNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = F.nll_loss


def _forward(data):
    """Run `model` on one batch and return ``(output, y)``.

    Moves the batch to `device` (unless `parall`), squeezes the output and
    restores a leading batch dimension if squeezing removed it, so a
    single-sample batch still gives a 2-D output for `criterion`/`max(1)`.
    Assumes `model` emits (batch, n_classes) — TODO confirm for GCNNet.
    """
    if not parall:
        data = data.to(device)
    output = model(data).squeeze()
    if parall:
        y = torch.cat([d.y for d in data]).to(output.device)
    else:
        y = data.y
    if output.dim() == 1:
        output = output.unsqueeze(0)
    return output, y


def train(epoch, report=True):
    """Run one training epoch over `train_loader`.

    Step schedule relative to the global `lr`: set to 0.5*lr at epoch 30 and
    0.1*lr at epoch 60 (the value persists in the optimizer's param groups).

    Returns:
        (mean per-sample loss, accuracy) as Python floats.
    """
    model.train()

    if epoch == 30:
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr * 0.5

    if epoch == 60:
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr * 0.1

    total_loss = 0.0
    correct = 0
    num_samps = 0
    for data in train_loader:
        optimizer.zero_grad()
        output, y = _forward(data)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        # .item() detaches (summing tensors would retain every batch's graph);
        # scale by batch size because the criterion averages within the batch.
        total_loss += loss.item() * len(y)
        correct += output.max(1)[1].eq(y).sum().item()
        num_samps += len(y)

    mean_loss = total_loss / num_samps
    accuracy = correct / num_samps
    if report:
        print(
            "Epoch: {:02d}, Loss: {:.3g}, Train Acc: {:.4f}".format(
                epoch, mean_loss, accuracy
            )
        )

    return mean_loss, accuracy


def test():
    """Evaluate on `test_loader` without tracking gradients.

    Returns:
        (mean per-sample loss, accuracy) as Python floats.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    num_samps = 0
    with torch.no_grad():
        for data in test_loader:
            output, y = _forward(data)
            total_loss += criterion(output, y).item() * len(y)
            correct += output.max(1)[1].eq(y).sum().item()
            num_samps += len(y)
    return total_loss / num_samps, correct / num_samps


n_epochs = 100

train_losses = []
train_acces = []
test_acces = []
test_losses = []
for epoch in range(1, n_epochs + 1):
    report = epoch % 10 == 0
    train_loss, train_acc = train(epoch, report=report)
    test_loss, test_acc = test()
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    train_acces.append(train_acc)
    test_acces.append(test_acc)
    if report:
        print("Test Loss: {:.3g}, Acc: {:.4f}".format(test_loss, test_acc))

# Learning curves; x is the 1-based epoch number used by the training loop.
epochs_axis = range(1, len(train_acces) + 1)

fig, ax = plt.subplots()
ax.plot(epochs_axis, train_acces, label="train acc", linewidth=3)
ax.plot(epochs_axis, test_acces, label="test acc", linewidth=3)
ax.legend(prop={"size": 16})
ax.set_xlabel("epoch", fontsize=16)
ax.set_ylabel("accuracy", fontsize=16)
ax.grid()
plt.show()

fig, ax = plt.subplots()
ax.plot(epochs_axis, train_losses, c="tab:blue", label="train loss", linewidth=3)
ax.plot(epochs_axis, test_losses, c="tab:orange", label="test loss", linewidth=3)
ax.legend(prop={"size": 16})
ax.set_xlabel("epoch", fontsize=16)
ax.set_ylabel("loss", fontsize=16)
ax.grid()
plt.show()


## ROC curves and AUC of the trained model on the train and test splits.


def _roc_on_split(indices):
    """Return ``(fpr, tpr, auc)`` for `model` over the samples at `indices`.

    Uses every sample (drop_last=False). Scores are exp(model output)[:, 1];
    the model is trained with nll_loss, so its output is presumably
    log-probabilities and column 1 is the probability of label 1 ("GBM").
    """
    loader = DataLoader(
        dataset,
        batch_size=batch,
        sampler=SubsetRandomSampler(indices),
        drop_last=False,
    )
    # Don't rely on test() having left the model in eval mode.
    model.eval()
    outs = []
    ys = []
    with torch.no_grad():
        for tb in loader:
            # .to(device) instead of .cuda() so this also runs on CPU-only hosts.
            tb = tb.to(device)
            outs.append(torch.exp(model(tb)))
            ys.append(tb.y)
    scores = torch.cat(outs).cpu().numpy()
    labels = torch.cat(ys).cpu().numpy()
    fpr, tpr, _ = roc_curve(labels, scores[:, 1])
    return fpr, tpr, auc(fpr, tpr)


fpr_train, tpr_train, train_auc = _roc_on_split(train_indices)
fpr_test, tpr_test, test_auc = _roc_on_split(test_indices)

fig, ax = plt.subplots()
ax.plot(
    fpr_train, tpr_train, lw=2, label="train ROC curve (area = %0.3f)" % train_auc,
)
ax.plot(
    fpr_test, tpr_test, lw=2, label="test ROC curve (area = %0.3f)" % test_auc,
)
ax.plot([0, 1], [0, 1], color="black", lw=1, linestyle="--")
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("Receiver operating characteristic")
ax.legend(loc="lower right")
plt.show()

In [None]:
# Sanity check: display the working directory (set by os.chdir in the first cell).
os.path.abspath(os.curdir)