In [None]:
import os
import os.path as osp
import time

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F 
from torch.nn import Linear, ReLU, BatchNorm1d
from torch_geometric.data import Dataset, Data, DataLoader
from torch_geometric.nn import Set2Set, TransformerConv
from tqdm import tqdm

In [None]:
class EColi_Prot():
    """Read-only accessor for pre-processed protein graphs stored as ``<id>.pt`` files.

    Args:
        processed_dir: directory holding the ``.pt`` files
            (defaults to ``'iso_hyd_dataset'``).
    """

    def __init__(self, processed_dir='iso_hyd_dataset'):
        self.processed_dir = processed_dir

    def processed_file_names(self):
        """Return the sorted names of the ``.pt`` files in ``processed_dir``.

        Only ``.pt`` files are listed because ids are recovered by stripping
        that suffix; sorting makes the order independent of ``os.listdir``.
        """
        return sorted(fname for fname in os.listdir(self.processed_dir)
                      if fname.endswith('.pt'))

    def len(self):
        """Number of ``.pt`` files available."""
        return len(self.processed_file_names())

    def get(self, prot_ch):
        """Load and return the object saved in ``<processed_dir>/<prot_ch>.pt``."""
        # NOTE(review): torch.load unpickles arbitrary objects, so only use it on
        # trusted files. Since torch 2.6 it defaults to weights_only=True, which may
        # refuse to load pickled PyG Data objects; pass weights_only=False if so.
        return torch.load(osp.join(self.processed_dir, '{}.pt'.format(prot_ch)))

In [None]:
# Model hyper-parameters; GCNN's constructor uses them as its defaults.
x_num_features = 3          # width of the node features `x`
edge_attr_num_features = 4  # width of `edge_attr`; NOTE: not consumed by GCNN below
embedding_size = 64
num_layers = 5


class GCNN(torch.nn.Module):
    """Siamese graph network that scores a pair of protein graphs.

    Each graph goes through the same stack of ``TransformerConv`` layers and
    the same ``Set2Set`` readout. The two pooled vectors are concatenated and
    fed to a 3-layer MLP whose sigmoid output is one value in (0, 1) per pair.

    Args:
        embedding_size: hidden width of every conv layer.
        num_layers: number of ``TransformerConv`` layers.
        in_features: width of the node features ``x`` of the input graphs.
        dropout: dropout probability applied after every conv layer but the last.

    The defaults are bound from the module-level constants above when this cell
    runs, so ``GCNN()`` with no arguments (as ``load_model`` calls it) builds the
    standard model.
    """

    def __init__(self, embedding_size=embedding_size, num_layers=num_layers,
                 in_features=x_num_features, dropout=0.1):
        super().__init__()
        self.sig = torch.nn.Sigmoid()
        self.tanh = torch.nn.Tanh()  # NOTE(review): not used in forward
        self.relu = torch.nn.ReLU()
        self.dropout = dropout
        # Graph conv layers: in_features -> embedding_size -> ... -> embedding_size
        self.convs = torch.nn.ModuleList()
        self.convs.append(TransformerConv(in_features, embedding_size))
        for _ in range(num_layers - 1):
            self.convs.append(TransformerConv(embedding_size, embedding_size))
        # Normalization layers: one BatchNorm per conv except the last one.
        self.norms = torch.nn.ModuleList(
            BatchNorm1d(embedding_size) for _ in range(num_layers - 1))
        # Pooling layer: Set2Set yields 2 * embedding_size features per graph.
        self.pool = Set2Set(embedding_size, 4, num_layers=2)
        # Output MLP over the concatenated pair (2 graphs x 2 * embedding_size).
        pair_width = embedding_size * 4
        self.lins = torch.nn.ModuleList([
            Linear(pair_width, pair_width),
            Linear(pair_width, pair_width),
            Linear(pair_width, 1),
        ])

    def convolutional_pass(self, x, edge_index, edge_type):
        """Run one graph through the conv stack and return its node embeddings.

        ``edge_type`` is accepted for interface compatibility but is currently
        ignored: the convs are built without ``edge_dim``.
        """
        for conv, norm in zip(self.convs[:-1], self.norms):
            x = norm(self.relu(conv(x, edge_index)))
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, edge_index)

    @staticmethod
    def _batch_vector(batch):
        """Node-to-graph assignment vector of a mini-batch."""
        # `x_batch` only exists when the loader used follow_batch=['x']; with a
        # single node type it is identical to the default `batch.batch`.
        x_batch = getattr(batch, 'x_batch', None)
        return batch.batch if x_batch is None else x_batch

    def forward(self, batch1, batch2):
        """Score each (batch1[i], batch2[i]) graph pair; returns a (num_pairs, 1) tensor."""
        pooled = []
        for batch in (batch1, batch2):
            node_emb = self.convolutional_pass(batch.x, batch.edge_index, batch.edge_attr)
            pooled.append(self.pool(node_emb, self._batch_vector(batch)))
        x = torch.cat(pooled, dim=1)

        # Final regression layers.
        for lin in self.lins[:-1]:
            x = self.relu(lin(x))
        return self.sig(self.lins[-1](x))

model = GCNN(embedding_size)
# Total count of scalar parameters across all layers.
n_params = sum(param.numel() for param in model.parameters())
print(model)
print("Number of parameters: ", n_params)

In [None]:
def get_link_sets2(links, p_train=0.9, seed=None):
    """Shuffle ``links`` and split it into train/test DataFrames.

    Args:
        links: DataFrame with one row per protein pair.
        p_train: fraction of the shuffled rows assigned to the train set; the
            cut point is ``int(len(links) * p_train)``.
        seed: ``random_state`` forwarded to ``DataFrame.sample``.

    Returns:
        ``(train_data, test_data)``, each re-indexed from 0.
    """
    shuffled = links.sample(frac=1, random_state=seed).reset_index(drop=True)
    cut = int(len(shuffled) * p_train)
    head, tail = shuffled.iloc[:cut], shuffled.iloc[cut:]
    return head.reset_index(drop=True), tail.reset_index(drop=True)

def load_dataset():
    """Load every processed protein graph into memory.

    Returns:
        dict mapping protein id (file name without its ``.pt`` extension) to the
        object returned by ``EColi_Prot.get`` for that id.
    """
    print('Loading Dataset...')
    prots = {}
    ec = EColi_Prot()
    for name in tqdm(ec.processed_file_names()):
        prot_id, ext = osp.splitext(name)
        # EColi_Prot.get can only open '<id>.pt'; blindly chopping 3 chars off
        # other file names would produce bogus ids and a failing load.
        if ext != '.pt':
            continue
        prots[prot_id] = ec.get(prot_id)
    return prots
    
def make_data_loader(data_links, batch_size, prots=None):
    """Build two index-aligned PyG DataLoaders for a table of protein pairs.

    Args:
        data_links: DataFrame with columns ``P1``, ``P2`` (protein ids) and
            ``INTERACT`` (target score).
        batch_size: number of pairs per batch.
        prots: optional id -> graph dict as returned by ``load_dataset``; pass it
            to avoid re-loading the whole dataset on every call.

    Returns:
        ``(loader1, loader2)``: the i-th batches of the two loaders hold the first
        and second proteins of the same pairs, so iterate them with ``zip``.
        Both carry the pair's score as ``y``.
    """
    if prots is None:
        prots = load_dataset()
    print('Prepering Lists...')
    data_list1 = []
    data_list2 = []
    # Iterate the columns directly: stacking them with np.array(...).T coerces
    # all three columns to one common dtype before the values are read back.
    for id1, id2, score in zip(data_links.P1, data_links.P2, data_links.INTERACT):
        p1 = prots[id1]
        p2 = prots[id2]
        y = torch.tensor([[float(score)]], dtype=torch.float32)
        data_list1.append(Data(x=p1.x, edge_index=p1.edge_index, edge_attr=p1.edge_attr, y=y))
        data_list2.append(Data(x=p2.x, edge_index=p2.edge_index, edge_attr=p2.edge_attr, y=y))
    # shuffle must stay False on both loaders: shuffling them independently
    # would mis-pair the two proteins of each link.
    loader1 = DataLoader(data_list1, batch_size=batch_size, follow_batch=['x'], shuffle=False)
    loader2 = DataLoader(data_list2, batch_size=batch_size, follow_batch=['x'], shuffle=False)
    return loader1, loader2

def load_model(name, epoch, PATH='model', map_location='cpu', **model_kwargs):
    """Rebuild a model and load the checkpoint written by ``save_model``.

    Args:
        name: model class (or factory) to instantiate.
        epoch: epoch number of the checkpoint ``<PATH>/CNN_epoch<epoch>.pt``.
        PATH: checkpoint directory.
        map_location: forwarded to ``torch.load``. The default ``'cpu'`` lets
            checkpoints saved during a GPU run load on CPU-only machines. The
            weights are copied into the freshly built model, so the resulting
            parameters are the same either way.
        **model_kwargs: forwarded to ``name(...)`` for constructors that need
            arguments.

    Returns:
        The instantiated model with the checkpoint's weights loaded.
    """
    model = name(**model_kwargs)
    ckpt_path = osp.join(PATH, 'CNN_epoch{}.pt'.format(epoch))
    model.load_state_dict(torch.load(ckpt_path, map_location=map_location))
    return model

def save_model(model, epoch, PATH='model'):
    """Save ``model.state_dict()`` to ``<PATH>/CNN_epoch<epoch>.pt``.

    The directory is created if it does not exist yet, so the first checkpoint
    of a fresh run does not fail with FileNotFoundError.
    """
    if PATH:  # '' means the current directory, which always exists
        os.makedirs(PATH, exist_ok=True)
    torch.save(model.state_dict(), osp.join(PATH, 'CNN_epoch{}.pt'.format(epoch)))

def train(loader1, loader2):
    """Run one optimisation epoch over two index-aligned loaders.

    Relies on the notebook-level ``model``, ``optimizer``, ``loss_fn`` and
    ``device``. Returns the accumulated loss divided by the number of pairs in
    ``loader1``'s dataset.
    """
    model.train()
    epoch_loss = 0
    for left, right in zip(loader1, loader2):
        # Move both halves of the pair to the training device.
        left, right = left.to(device), right.to(device)
        optimizer.zero_grad()
        prediction = model(left, right)
        batch_loss = loss_fn(prediction, left.y)
        epoch_loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()
    n_pairs = len(loader1.dataset)
    return epoch_loss / n_pairs

def test(loader1, loader2):
    """Evaluate the notebook-level ``model`` on two index-aligned loaders.

    Uses the globals ``model``, ``loss_fn`` and ``device``; no parameters are
    updated. Returns the accumulated loss divided by the number of pairs in
    ``loader1``'s dataset.
    """
    model.eval()
    total_loss = 0
    # Gradients are never used during evaluation; disabling autograd avoids
    # building (and holding memory for) a graph on every batch.
    with torch.no_grad():
        for batch1, batch2 in zip(loader1, loader2):
            batch1 = batch1.to(device)
            batch2 = batch2.to(device)
            output = model(batch1, batch2)
            loss = loss_fn(output, batch1.y)
            total_loss += loss.item()
    return total_loss / len(loader1.dataset)

In [None]:
# Use the GPU for training when one is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch (weight init, dropout) so runs are repeatable; the same seed is
# used for the train/test split below.
SEED = 11
torch.manual_seed(SEED)

model = GCNN(embedding_size)
model = model.to(device)
# Summed squared error per batch (not RMSE). train()/test() divide the epoch
# total by the number of pairs, so the reported losses are mean squared errors.
loss_fn = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

num_epochs = 1000
batch_size = 16
save_every = 10  # checkpoint interval in epochs

links = pd.read_csv('ecoli_pdb_links_revisited.csv')
train_links, test_links = get_link_sets2(links, seed=SEED)
train_loader1, train_loader2 = make_data_loader(train_links, batch_size)
test_loader1, test_loader2 = make_data_loader(test_links, batch_size)

print("Starting training...")
tr_losses = []
losses = []
for epoch in range(1, num_epochs + 1):
    start_time = time.time()
    tr_loss = train(train_loader1, train_loader2)
    tr_losses.append(tr_loss)
    loss = test(test_loader1, test_loader2)
    losses.append(loss)
    print('Epoch: {}, Loss: {:.4f}, Test Loss: {:.4f} [{:.2f} sec/epoch]'.format(
        epoch, tr_loss, loss, time.time() - start_time))
    if epoch % save_every == 0:
        save_model(model, epoch)