In [12]:
import pandas as pd
import torch
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from matplotlib.pylab import plt
from numpy import arange
import tqdm as tqdm


torch.manual_seed(0)

<torch._C.Generator at 0x115ba6610>

In [13]:
from torch.utils.data import Dataset

class PhageDataset(Dataset):
    def __init__(self):
        self.data = np.load("../one-hot/onehot_tr.npy",allow_pickle=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        host_id = self.data[idx][1]
        onehot = self.data[idx][0]
        host_vector = np.zeros(58, dtype=np.float64)
        host_vector[host_id] = 1.0
        sample = {"onehot": onehot, "host_vector": host_vector}
        return sample

In [3]:
import torch
import torch.nn as nn

class DNA_CNN(nn.Module):
    def __init__(self, seq_len=19044, num_filters=1, kernel_size=6):
        super().__init__()
        self.seq_len = seq_len

        self.conv_net = nn.Sequential(
            nn.Conv1d(1, num_filters, kernel_size=kernel_size),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(num_filters*(seq_len-kernel_size+1), 58)
        )


    def forward(self, x):
        out = self.conv_net(x)
        return out

In [14]:
data_te = np.load("../one-hot/onehot_te.npy",allow_pickle=True)

In [15]:
class PhageTrainer:
    def __init__(self):
        phage_dataset = PhageDataset()  # DISDataset()
        print("Total dataset size:", len(phage_dataset))
        self.phage_dataloader = DataLoader(
            phage_dataset, batch_size=128, shuffle=True, num_workers=1
        )

        self.model = DNA_CNN()#.cuda()

        self._setup_optimizers()

        self.phage_iter = iter(self.phage_dataloader)

        self.bce_loss = nn.BCEWithLogitsLoss(reduction="sum")

    def _clip_weights(self):
        """
        Performs clipping of weights.
        """
        for p in self.model.parameters():
            p.data.clamp_(-1.0 * self.clip_value, self.clip_value)

    def _setup_optimizers(self):
        self.iter_size = 1
        self.optimizer = torch.optim.Adam(
            [param for name, param in self.model.named_parameters()],
            lr=3e-4,
            weight_decay=0.00001)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[30, 80], gamma=0.1
        )

    def save(self, model_path):
        torch.save(self.model.state_dict(), model_path + ".pth")

    @torch.no_grad()
    def predict(self):

        self.model.load_state_dict(
            torch.load("./checkpoints/cnn.pth"),
           strict=False,
        )


        self.hosts = []
        self.features = []

        for i in data_te:
          self.features.append(i[0])
          self.hosts.append(i[1])


        self.model.eval()
        # output = self.model(torch.tensor(np.array(self.features)).cuda().float().unsqueeze(1))
        output = self.model(torch.tensor(np.array(self.features)).float().unsqueeze(1))
        output = torch.sigmoid(output)
        return output, self.hosts

    def step(self):
        self.optimizer.zero_grad()
        seg_loss = 0.0
        for _ in range(self.iter_size):
            try:
                phage_sample = next(self.phage_iter)
            except StopIteration:
                print("bbox dataloader reset.")
                self.phage_iter = iter(self.phage_dataloader)
                phage_sample = next(self.phage_iter)

            # labels = phage_sample["host_vector"].float().cuda()#.unsqueeze(1)
            # output = self.model(phage_sample["onehot"].float().cuda().unsqueeze(1))

            labels = phage_sample["host_vector"].float()#.unsqueeze(1)
            output = self.model(phage_sample["onehot"].float().unsqueeze(1))

            print(output.shape, labels.shape)
            loss = self.bce_loss(output, labels) / output.shape[0]
            loss.backward()

        seg_loss = loss.detach().item()
        self.optimizer.step()

        return [
            seg_loss / self.iter_size,
        ]


In [16]:
def do_training():
    trainer = PhageTrainer()
    max_iters = 20000
    save_iter = 1000
    snap_iter = 1000

    for iter_no in range(max_iters):
        batch_loss = trainer.step()

        print(
            "[Iter %d/%d] seg_loss = %f"
            % (iter_no, max_iters, batch_loss[0])
        )
        if (iter_no + 1) % save_iter == 0:
            trainer.save("./checkpoints/cnn")

In [17]:
do_training()

Total dataset size: 16636


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'PhageDataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 

In [18]:
import copy

@torch.no_grad()
def do_prediction(treshold):
    trainer = PhageTrainer()
    output, hosts = trainer.predict()

    output=output.tolist()

    output_ = [0 for i in range(len(output))]


    for k in treshold:
      for i in range(len(output)):
          p1 = max(output[i])
          ind1 = output[i].index(p1)

          x = copy.deepcopy(output[i])
          x.remove(p1)

          p2 = max(x)

          if p1 - p2 >= k:
              output_[i] = ind1
          else:
              output_[i] = -1

      final_op = []

      for i in output_:
          final_op.append(i)

      class_weights = {19: 0.15956136027599802, -1: 0.12148841793987186, 46: 0.07787087235091178, 54: 0.06419418432725481, 40: 0.06320847708230655, 55: 0.05704780680137999, 24: 0.04891572203055693, 18: 0.03523903400689995, 31: 0.030556924593395762, 50: 0.030064070970921637, 5: 0.025135534746180386, 42: 0.02168555938886151, 49: 0.017619517003449974, 16: 0.017496303597831445, 9: 0.015278462296697881, 27: 0.01232134056185313, 1: 0.012074913750616067, 47: 0.011951700344997535, 2: 0.011705273533760474, 57: 0.01145884672252341, 52: 0.009980285855101035, 43: 0.009117792015771316, 33: 0.008748151798915723, 56: 0.007146377525874815, 21: 0.0065303104977821585, 14: 0.006407097092163627, 36: 0.006407097092163627, 17: 0.006037456875308034, 20: 0.005914243469689502, 53: 0.005298176441596846, 41: 0.005174963035978314, 51: 0.005051749630359783, 37: 0.0049285362247412515, 8: 0.004805322819122721, 13: 0.004435682602267127, 11: 0.004189255791030064, 44: 0.004189255791030064, 34: 0.004066042385411533, 30: 0.004066042385411533, 6: 0.0038196155741744703, 29: 0.0034499753573188764, 4: 0.0032035485460818135, 10: 0.0030803351404632825, 12: 0.002957121734844751, 26: 0.0028339083292262196, 35: 0.0022178413011335633, 0: 0.002094627895515032, 39: 0.002094627895515032, 38: 0.001971414489896501, 32: 0.001971414489896501, 23: 0.001971414489896501, 22: 0.0016017742730409068, 45: 0.0016017742730409068, 3: 0.0014785608674223755, 48: 0.0013553474618038443, 7: 0.0012321340561853129, 25: 0.0012321340561853129, 28: 0.0012321340561853129, 15: 0.0012321340561853129}

      print("k = "+str(k)+"\n")
      print("f1_score:" + str(f1_score(hosts,final_op, average="weighted", sample_weight=[class_weights[i] for i in hosts])) + "\n")
      print("accuracy:"+ str(accuracy_score(hosts,final_op)) + "\n")
      print("\n")

In [19]:
do_prediction([0.6,0.7,0.8,0.9,1])

Total dataset size: 16636


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'PhageDataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 