In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
from torch.utils import data
import numpy as np
import torch.optim as optim
import sklearn.metrics as metrics
import matplotlib.pyplot as plt

In [None]:
# G4-GAN generator: hyper-parameters and weights used to synthesize positive samples.
# NOTE(review): `load_g4gan_generator` is not defined in any visible cell; it must come
# from an omitted cell or module -- TODO confirm, otherwise Restart & Run All fails here.
# NOTE(review): the .pkl suffix suggests pickle; only load weight files you trust.
weights_path = 'g4gan_big.pkl'
model_size = 32
seq_len = 512
onehot_len = 5           # length of each one-hot code passed to the generator
hidden_state_len = 128   # width of the uniform noise vector later fed to g4gan
generator_config = (weights_path, model_size, seq_len, onehot_len)
g4gan = load_g4gan_generator(*generator_config)
g4gan.eval()

In [None]:
# G4detector
class G4Detector(nn.Module):
    """Three-branch 1-D CNN that maps a one-hot sequence to a probability in (0, 1).

    Each branch is a ``Conv1d`` with a different kernel width (2, 3, 6). Its output
    is max-pooled over the whole sequence axis. The concatenated pooled features
    (80 + 80 + 96 = 256) feed a small MLP that ends in a sigmoid.

    Args:
        onehot_len: number of input channels, i.e. the size of each one-hot code.
    """

    def __init__(self, onehot_len):
        super().__init__()
        # Attribute names and creation order are kept stable so that state_dict
        # keys and seeded weight initialization stay the same.
        self.conv1 = nn.Conv1d(onehot_len, 80, 2)
        self.conv2 = nn.Conv1d(onehot_len, 80, 3)
        self.conv3 = nn.Conv1d(onehot_len, 96, 6)
        pooled_width = sum(branch.out_channels
                           for branch in (self.conv1, self.conv2, self.conv3))
        self.linear_block = nn.Sequential(
            nn.Linear(pooled_width, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of sequences.

        Args:
            x: (batch, seq_len, onehot_len) tensor. The channel axis is moved to
                dim 1 for ``Conv1d``. seq_len must be >= 6, the widest kernel,
                because the convolutions are unpadded.

        Returns:
            (batch, 1) tensor of sigmoid probabilities.
        """
        channels_first = x.transpose(1, 2)
        # Global max-pool each branch over the sequence axis.
        pooled = [branch(channels_first).max(dim=2).values
                  for branch in (self.conv1, self.conv2, self.conv3)]
        return self.linear_block(torch.cat(pooled, dim=1))


# Data
# Builds a binary dataset. Positives (label 1) are G4 sequences: real, G4-GAN
# generated, or both. Negatives (label 0) are synthetic sequences.
g4_np_data_path = 'G4_Chip_seq_quadruplex_norm.npy'
batch_size = 128
test_ratio = 0.1
shuffle_train_each_epoch = True
# Seeds numpy (train/test split, negatives) and torch (GAN noise, later weight
# init) so that Restart & Run All reproduces the same dataset and model.
seed = 0
# Positive-class source:
# 0: real data
# 1: fake data (sampled from the G4-GAN generator)
# 2: real+fake data
pos_class_strategy = 2
# Negative-class source:
# 0: random -- i.i.d. draws over the first four one-hot codes (the fifth has p=0)
# 1: dishuffle -- NOTE(review): np.random.permutation shuffles positions, i.e. a
#    mononucleotide shuffle. Dinucleotide frequencies are NOT preserved despite
#    the name.
negative_class_strategy = 0
if pos_class_strategy not in (0, 1, 2):
    raise ValueError('pos_class_strategy must be 0, 1 or 2, got %r'
                     % (pos_class_strategy,))
if negative_class_strategy not in (0, 1):
    raise ValueError('negative_class_strategy must be 0 or 1, got %r'
                     % (negative_class_strategy,))

np.random.seed(seed)
torch.manual_seed(seed)

# Named x_real (not `input`) to avoid shadowing the builtin.
# NOTE(review): presumably one-hot (num_seqs, seq_len, onehot_len) -- confirm.
x_real = np.load(g4_np_data_path)
num_g4 = x_real.shape[0]
if pos_class_strategy in (1, 2):
    with torch.no_grad():
        noise = torch.empty(num_g4, hidden_state_len).uniform_(-1, 1)
        fake_scores = g4gan(noise)
        # Harden the generator output: one-hot at the arg-max code of each position.
        x_fake = (F.one_hot(fake_scores.argmax(dim=2),
                            num_classes=fake_scores.shape[2])
                  .to(fake_scores.dtype).cpu().numpy())

num_test = int(np.ceil(num_g4 * test_ratio))
num_train = num_g4 - num_test
indices = np.random.permutation(num_g4)
train_data_indices = indices[:num_train]
test_data_indices = indices[num_train:]
if pos_class_strategy == 0:
    x_train_pos = x_real[train_data_indices]
elif pos_class_strategy == 1:
    x_train_pos = x_fake[train_data_indices]
else:
    x_train_pos = np.concatenate((x_real[train_data_indices],
                                  x_fake[train_data_indices]))
# The test set always uses held-out real sequences only.
x_test_pos = x_real[test_data_indices]

codes = np.eye(onehot_len)
# Sampling weights over `codes` for random negatives (the fifth code is never drawn).
neg_code_probs = [0.25, 0.25, 0.25, 0.25, 0]


def _make_negatives(x_pos):
    """Return one synthetic negative per sequence in ``x_pos``, using negative_class_strategy."""
    if negative_class_strategy == 0:
        # One vectorized draw for every (sequence, position) pair.
        drawn = np.random.choice(codes.shape[0], size=x_pos.shape[:2],
                                 p=neg_code_probs)
        return codes[drawn]
    shuffled = np.empty_like(x_pos)
    for k, seq in enumerate(x_pos):
        shuffled[k] = np.random.permutation(seq)
    return shuffled


x_train_neg = _make_negatives(x_train_pos)
x_test_neg = _make_negatives(x_test_pos)
y_train_pos = np.ones((x_train_pos.shape[0], 1))
y_test_pos = np.ones((x_test_pos.shape[0], 1))
y_train_neg = np.zeros((x_train_neg.shape[0], 1))
y_test_neg = np.zeros((x_test_neg.shape[0], 1))
x_np_train = np.concatenate((x_train_pos, x_train_neg))
y_np_train = np.concatenate((y_train_pos, y_train_neg))
x_np_test = np.concatenate((x_test_pos, x_test_neg))
y_np_test = np.concatenate((y_test_pos, y_test_neg))
assert x_np_train.shape[0] == y_np_train.shape[0]
assert x_np_test.shape[0] == y_np_test.shape[0]
x_train_t = torch.as_tensor(x_np_train, dtype=torch.float32)
y_train_t = torch.as_tensor(y_np_train, dtype=torch.float32)
x_test_t = torch.as_tensor(x_np_test, dtype=torch.float32)
y_test_t = torch.as_tensor(y_np_test, dtype=torch.float32)
# Fully-qualified so that a loop variable named `data` elsewhere in the notebook
# cannot shadow the torch.utils.data module on re-run.
train_data = torch.utils.data.TensorDataset(x_train_t, y_train_t)
test_data = torch.utils.data.TensorDataset(x_test_t, y_test_t)
train_dataloader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=shuffle_train_each_epoch)
# Evaluation order does not affect loss/accuracy/ROC, so keep it deterministic.
test_dataloader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=False)
# Model, loss and optimizer.
lr = 0.0001
beta_1 = 0.9
beta_2 = 0.99
# Input channels must match the one-hot code length used by the data.
net = G4Detector(onehot_len)
# G4Detector already ends in nn.Sigmoid, so plain BCE (not BCEWithLogits) is correct.
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr=lr, betas=(beta_1, beta_2))


def accuracy(output, target):
    """Fraction of elements whose 0.5-thresholded prediction matches the label.

    Both ``output`` and ``target`` are binarized at 0.5 and compared
    element-wise, so soft labels are thresholded too.

    Returns:
        Python float in [0, 1].

    Raises:
        ZeroDivisionError: if ``target`` has no elements.
    """
    predicted_positive = output >= 0.5
    actually_positive = target >= 0.5
    n_correct = (predicted_positive == actually_positive).sum().item()
    n_total = target.numel()
    return n_correct / n_total


# Training
epoches = 15
stat_every_batch = 10
use_cuda = True
# Fall back to the CPU when CUDA is requested but unavailable.
device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')
# Module.to() updates dense parameters in place (param.data), so the optimizer
# built earlier keeps tracking the same Parameter objects.
net.to(device)
for epoch in range(epoches):
    net.train()
    running_loss = 0.
    running_acc = 0.
    # `batch`, not `data`: `data` is the torch.utils.data module used to build loaders.
    for i, batch in enumerate(train_dataloader):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Statistics are averaged over the last `stat_every_batch` mini-batches.
        running_loss += loss.item()
        running_acc += accuracy(outputs, labels)
        if i % stat_every_batch == stat_every_batch - 1:
            print('[%d, %5d] loss: %.3f acc: %.3f' % (
                epoch + 1, i + 1,
                running_loss / stat_every_batch,
                running_acc / stat_every_batch))
            running_loss = 0.
            running_acc = 0.
# The evaluation cells feed CPU tensors, so return the model to the CPU.
net.to('cpu')

# Held-out evaluation: loss and accuracy averaged over all test samples.
net.eval()
eval_device = next(net.parameters()).device
total_loss = 0.
total_correct = 0.
num_samples = 0
with torch.no_grad():
    for inputs, labels in test_dataloader:
        inputs = inputs.to(eval_device)
        labels = labels.to(eval_device)
        outputs = net(inputs)
        batch_n = labels.shape[0]
        # BCELoss defaults to a per-batch mean and accuracy() returns a fraction.
        # Weight both by batch size so a smaller final batch is not over-counted.
        total_loss += criterion(outputs, labels).item() * batch_n
        total_correct += accuracy(outputs, labels) * batch_n
        num_samples += batch_n
if num_samples == 0:
    raise ValueError('test_dataloader yielded no samples')
print('Test_loss: %.3f Test_acc: %.3f' % (total_loss / num_samples,
                                          total_correct / num_samples))
# Test-set scores, summary metrics and ROC curve.
net.eval()
eval_device = next(net.parameters()).device
total_loss = 0.
total_correct = 0.
num_samples = 0
label_chunks = []
score_chunks = []
with torch.no_grad():
    for inputs, labels in test_dataloader:
        # Bring outputs back to the CPU so they can be converted to numpy.
        outputs = net(inputs.to(eval_device)).cpu()
        batch_n = labels.shape[0]
        # Weight per-batch means by batch size to get true per-sample averages.
        total_loss += criterion(outputs, labels).item() * batch_n
        total_correct += accuracy(outputs, labels) * batch_n
        num_samples += batch_n
        label_chunks.append(labels.numpy())
        score_chunks.append(outputs.numpy())
if num_samples == 0:
    raise ValueError('test_dataloader yielded no samples')
print('Test loss: %.3f acc: %.3f' % (total_loss / num_samples,
                                     total_correct / num_samples))
# Concatenate once rather than per batch (which is quadratic).
all_labels = np.concatenate(label_chunks)
all_outputs = np.concatenate(score_chunks)
# Labels/scores are (N, 1) columns; sklearn expects 1-D arrays.
(fpr, tpr, threshold) = metrics.roc_curve(all_labels.ravel(), all_outputs.ravel())
roc_auc = metrics.auc(fpr, tpr)
fig, ax = plt.subplots()
ax.set_title('Receiver Operating Characteristic')
ax.plot(fpr, tpr, 'b', label='AUC = %0.5f' % roc_auc)
ax.plot([0, 1], [0, 1], 'r--')  # chance-level reference
ax.legend(loc='lower right')
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_ylabel('True Positive Rate')
ax.set_xlabel('False Positive Rate')
plt.show()