# load data

In [11]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import os.path as osp
from tqdm import tqdm

In [12]:
N = 6111
dir_name = './data/'
datasets = []
for i in tqdm(range(N)):
	data = torch.load(osp.join(dir_name,f'data_{i}.pt'))
	datasets.append(data)

100%|██████████| 6111/6111 [00:03<00:00, 1679.53it/s]


## train-test split

In [13]:
train_idx = list()
test_name = list()
test_idx = list()
y_train = list()
with open('./data/graph_labels.txt', 'r') as f:
    for i,line in enumerate(f):
        t = line.split(',')
        if len(t[1][:-1]) == 0:
            test_name.append(t[0])
            test_idx.append(i)
        else:
            train_idx.append(i)
            y_train.append(int(t[1][:-1]))

In [14]:
len(train_idx),len(test_idx)

(4888, 1223)

In [15]:
train_val_set = [datasets[i] for i in train_idx]
test_set = [datasets[i] for i in test_idx]

sequential embedding contains a beginning and an ending

In [16]:
train_val_set[0].x.shape

torch.Size([185, 86])

## add label

In [17]:
for data, y in zip(train_val_set, y_train):
	data.y = y

## train-valid split

In [18]:
train_set, val_set = torch.utils.data.random_split(train_val_set, [4400,488], generator=torch.Generator().manual_seed(42))

In [19]:
len(train_set), len(val_set)

(4400, 488)

## Hyperparameters

In [22]:
import numpy as np
num_node_features = train_set[0].num_node_features # 86
print(num_node_features)
num_classes = len(np.unique(y_train)) # 18
print(num_classes)
lr = 1e-3
epochs = 100
batch_size = 32
hidden_dim = 64

86
18


## Dataloader

In [23]:
from torch.utils.data.dataloader import default_collate
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
collate_fn = lambda x: tuple(x_.to(device) for x_ in default_collate(x))
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [24]:
train_val_loader =DataLoader(train_val_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [25]:
for step, data in enumerate(train_loader):
	print(f'Step {step + 1}:')
	print('=======')
	print(f'Number of graphs in the current batch: {data.num_graphs}')
	print(data.is_cuda)
	print()
	break

Step 1:
Number of graphs in the current batch: 32
False



In [26]:
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn
from torch_geometric.utils import dropout_adj

In [27]:
def train(model, optimizer):
	model.train()
	for data in tqdm(train_loader):
		data = data.to(device)
		out,_ = model(data.x, data.edge_index, data.edge_attr, data.batch)
		loss = criterion(out, data.y)
		loss.backward()
		optimizer.step()
		optimizer.zero_grad()

def eval(model, loader):
	model.eval()
	correct = 0
	for data in loader:
		data = data.to(device)
		out,_ = model(data.x, data.edge_index, data.edge_attr, data.batch)
		loss = criterion(out, data.y)
		pred = out.argmax(dim=1)
		correct += int((pred == data.y).sum())
		acc = correct / len(loader.dataset)
	return acc, loss

def predict(model):
	pred_proba = []
	model.eval()
	for data in test_loader:
		data = data.to(device)
		out,_ = model(data.x, data.edge_index, data.edge_attr, data.batch)
		pred_proba.append(out.detach().cpu().numpy())
		del out
		torch.cuda.empty_cache()
	pred_proba = np.vstack(pred_proba)
	pred_proba = np.exp(pred_proba)
	return pred_proba

def embed(model, loader):
	embeddings = []
	model.eval()
	for data in loader:
		data = data.to(device)
		_, embedding = model(data.x, data.edge_index, data.edge_attr, data.batch)
		embeddings.append(embedding.detach().cpu().numpy())
		del embedding
		torch.cuda.empty_cache()
	embeddings = np.vstack(embeddings)
	return embeddings

### load & save checkpoint 

In [28]:
import os.path as osp
from tabnanny import check
def load_checkpoint(model, optimizer, path):
	if osp.exists(path):
		checkpoint = torch.load(path)
		model.load_state_dict(checkpoint['model_state_dict'])
		optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
		print('checkpoint loaded')
		return checkpoint.get('loss')

In [29]:
def save_checkpoint(model, optimizer, loss, path):
	torch.save({
							'epoch': epochs,
							'model_state_dict': model.state_dict(),
							'optimizer_state_dict': optimizer.state_dict(),
							'loss': loss,
							}, path)

### train epochs and save checkpoints

In [30]:
import time
def train_epochs(model, optimizer, path, patience=10, from_scratch=True):
	checkpoint_loss = None
	if not from_scratch:
		checkpoint_loss = load_checkpoint(model, optimizer, path=path)
	best_val_loss = checkpoint_loss if checkpoint_loss else 2.8 #2.8 is the loss of equi-ignorant guess
	best_val_acc = 0
	for epoch in range(epochs):
		train(model, optimizer)
		train_acc, train_loss = eval(model, train_loader)
		val_acc, val_loss = eval(model, val_loader)
		print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
		if val_loss < best_val_loss:
			save_checkpoint(model, optimizer, loss=val_loss, path=path)
			best_val_loss = val_loss
			print('checkpoint saved')
			patience = patience
		else:
			patience -= 1
			if patience == 0:
				print('early stopping')
				break

In [31]:
def train_epochs_full(model, optimizer, path, patience=10, from_scratch=True):
	checkpoint_loss = None
	if not from_scratch:
		checkpoint_loss = load_checkpoint(model, optimizer, path=path)
	best_train_loss = checkpoint_loss if checkpoint_loss else 2.8 #2.8 is the loss of equi-ignorant guess
	best_train_acc = 0
	for epoch in range(epochs):
		train(model, optimizer)
		train_acc, train_loss = eval(model, train_val_loader)
		time.sleep(0.5)
		print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
		if train_loss < best_train_loss:
			save_checkpoint(model, optimizer, loss=train_loss, path=path)
			best_train_loss = train_loss
			print('checkpoint saved')
			patience = patience
		else:
			patience -= 1
			if patience == 0:
				print('early stopping')
				break

### write

In [38]:
import csv
def write_csv(fname, pred_proba):
    with open(f'./predictions/{fname}.csv', 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter=',')
        lst = list()
        for i in range(18):
            lst.append('class'+str(i))
        lst.insert(0, "name")
        writer.writerow(lst)
        for i, protein in enumerate(test_name):
            lst = pred_proba[i,:].tolist()
            lst.insert(0, protein)
            writer.writerow(lst)

In [33]:
import gc
def clean_gpu_memory():
	gc.collect()
	torch.cuda.empty_cache()

# GNN2: GIN

In [34]:
class GIN_2(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GIN_2, self).__init__()
        torch.manual_seed(12345)
        self.p = 0.2
        self.bn = gnn.BatchNorm(hidden_channels)
        self.gin = gnn.GIN(in_channels=num_node_features, 
                            hidden_channels=hidden_channels,
                            num_layers=2,
                            dropout= 0.2,
                            norm=self.bn,
                            )
        self.lin = gnn.MLP(in_channels=hidden_channels, 
                            hidden_channels=hidden_channels,
                            out_channels=num_classes, 
                            num_layers=2, 
                            dropout=0.2)

    def forward(self, x, edge_index, edge_attr, batch):
        x = self.gin(x=x, edge_index=edge_index, 
                    #edge_attr=edge_attr
                    )
        x = gnn.global_add_pool(x, batch)
        x = self.bn(x)
        x = x.relu()
        emb = x
        x = self.lin(x)
        
        return F.log_softmax(x, dim=1), emb

In [35]:
model_gin = GIN_2(hidden_channels=64).to(device)
print(model_gin)
optimizer_1 = torch.optim.Adam(model_gin.parameters(),lr=lr)
criterion = nn.CrossEntropyLoss()

GIN_2(
  (bn): BatchNorm(64)
  (gin): GIN(86, 64, num_layers=2)
  (lin): MLP(64, 64, 18)
)


### Train

In [36]:
train_epochs_full(model_gin, optimizer_1, path='model_gin.pt', patience=50, from_scratch=True)

100%|██████████| 138/138 [00:03<00:00, 43.05it/s]


Epoch: 000, Train Loss: 2.1345, Train Acc: 0.2985
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 98.80it/s]


Epoch: 001, Train Loss: 1.8681, Train Acc: 0.3286
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 95.48it/s]


Epoch: 002, Train Loss: 1.9757, Train Acc: 0.3099


100%|██████████| 138/138 [00:01<00:00, 97.59it/s]


Epoch: 003, Train Loss: 1.9462, Train Acc: 0.3905


100%|██████████| 138/138 [00:01<00:00, 99.57it/s] 


Epoch: 004, Train Loss: 2.0490, Train Acc: 0.4206


100%|██████████| 138/138 [00:01<00:00, 100.12it/s]


Epoch: 005, Train Loss: 1.9050, Train Acc: 0.3805


100%|██████████| 138/138 [00:01<00:00, 103.49it/s]


Epoch: 006, Train Loss: 1.7940, Train Acc: 0.3615
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 99.65it/s] 


Epoch: 007, Train Loss: 1.7909, Train Acc: 0.4290
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 100.01it/s]


Epoch: 008, Train Loss: 1.7833, Train Acc: 0.4693
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 98.02it/s]


Epoch: 009, Train Loss: 1.8772, Train Acc: 0.4296


100%|██████████| 138/138 [00:01<00:00, 100.74it/s]


Epoch: 010, Train Loss: 1.5948, Train Acc: 0.4685
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 101.35it/s]


Epoch: 011, Train Loss: 2.1249, Train Acc: 0.3179


100%|██████████| 138/138 [00:01<00:00, 99.75it/s] 


Epoch: 012, Train Loss: 1.7663, Train Acc: 0.4603


100%|██████████| 138/138 [00:01<00:00, 101.36it/s]


Epoch: 013, Train Loss: 1.9118, Train Acc: 0.4284


100%|██████████| 138/138 [00:01<00:00, 100.36it/s]


Epoch: 014, Train Loss: 1.6249, Train Acc: 0.4126


100%|██████████| 138/138 [00:01<00:00, 98.65it/s] 


Epoch: 015, Train Loss: 1.3299, Train Acc: 0.4521
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 99.86it/s] 


Epoch: 016, Train Loss: 1.9339, Train Acc: 0.4210


100%|██████████| 138/138 [00:01<00:00, 101.51it/s]


Epoch: 017, Train Loss: 1.6237, Train Acc: 0.3795


100%|██████████| 138/138 [00:01<00:00, 100.86it/s]


Epoch: 018, Train Loss: 1.8978, Train Acc: 0.4675


100%|██████████| 138/138 [00:01<00:00, 102.24it/s]


Epoch: 019, Train Loss: 2.1479, Train Acc: 0.3014


100%|██████████| 138/138 [00:01<00:00, 101.30it/s]


Epoch: 020, Train Loss: 1.9505, Train Acc: 0.4583


100%|██████████| 138/138 [00:01<00:00, 99.89it/s] 


Epoch: 021, Train Loss: 1.6991, Train Acc: 0.4847


100%|██████████| 138/138 [00:01<00:00, 102.01it/s]


Epoch: 022, Train Loss: 1.4591, Train Acc: 0.4894


100%|██████████| 138/138 [00:01<00:00, 101.94it/s]


Epoch: 023, Train Loss: 1.4469, Train Acc: 0.5346


100%|██████████| 138/138 [00:01<00:00, 99.93it/s] 


Epoch: 024, Train Loss: 1.7184, Train Acc: 0.4967


100%|██████████| 138/138 [00:01<00:00, 101.27it/s]


Epoch: 025, Train Loss: 1.3024, Train Acc: 0.5211
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 96.80it/s] 


Epoch: 026, Train Loss: 1.8073, Train Acc: 0.4343


100%|██████████| 138/138 [00:01<00:00, 101.00it/s]


Epoch: 027, Train Loss: 1.4280, Train Acc: 0.5376


100%|██████████| 138/138 [00:01<00:00, 101.32it/s]


Epoch: 028, Train Loss: 1.7905, Train Acc: 0.4699


100%|██████████| 138/138 [00:01<00:00, 100.64it/s]


Epoch: 029, Train Loss: 1.4355, Train Acc: 0.4816


100%|██████████| 138/138 [00:01<00:00, 100.87it/s]


Epoch: 030, Train Loss: 1.4870, Train Acc: 0.5143


100%|██████████| 138/138 [00:01<00:00, 99.78it/s] 


Epoch: 031, Train Loss: 1.4756, Train Acc: 0.3918


100%|██████████| 138/138 [00:01<00:00, 101.00it/s]


Epoch: 032, Train Loss: 1.3659, Train Acc: 0.5188


100%|██████████| 138/138 [00:01<00:00, 99.82it/s] 


Epoch: 033, Train Loss: 1.2433, Train Acc: 0.4767
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 100.82it/s]


Epoch: 034, Train Loss: 1.5562, Train Acc: 0.3891


100%|██████████| 138/138 [00:01<00:00, 99.33it/s]


Epoch: 035, Train Loss: 1.2157, Train Acc: 0.5383
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 101.54it/s]


Epoch: 036, Train Loss: 1.4564, Train Acc: 0.4969


100%|██████████| 138/138 [00:01<00:00, 99.89it/s] 


Epoch: 037, Train Loss: 1.5126, Train Acc: 0.5074


100%|██████████| 138/138 [00:01<00:00, 101.40it/s]


Epoch: 038, Train Loss: 1.2014, Train Acc: 0.5677
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 99.56it/s] 


Epoch: 039, Train Loss: 1.6802, Train Acc: 0.4865


100%|██████████| 138/138 [00:01<00:00, 101.41it/s]


Epoch: 040, Train Loss: 1.4206, Train Acc: 0.5344


100%|██████████| 138/138 [00:01<00:00, 99.30it/s] 


Epoch: 041, Train Loss: 1.5545, Train Acc: 0.5153


100%|██████████| 138/138 [00:01<00:00, 100.01it/s]


Epoch: 042, Train Loss: 1.5158, Train Acc: 0.4832


100%|██████████| 138/138 [00:01<00:00, 100.43it/s]


Epoch: 043, Train Loss: 1.3556, Train Acc: 0.5741


100%|██████████| 138/138 [00:01<00:00, 100.40it/s]


Epoch: 044, Train Loss: 1.2908, Train Acc: 0.5689


100%|██████████| 138/138 [00:01<00:00, 99.40it/s] 


Epoch: 045, Train Loss: 1.2585, Train Acc: 0.5276


100%|██████████| 138/138 [00:01<00:00, 99.94it/s] 


Epoch: 046, Train Loss: 1.1093, Train Acc: 0.5520
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 99.51it/s] 


Epoch: 047, Train Loss: 1.6050, Train Acc: 0.4202


100%|██████████| 138/138 [00:01<00:00, 100.34it/s]


Epoch: 048, Train Loss: 1.1311, Train Acc: 0.5665


100%|██████████| 138/138 [00:01<00:00, 98.05it/s] 


Epoch: 049, Train Loss: 0.9984, Train Acc: 0.5849
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 99.57it/s] 


Epoch: 050, Train Loss: 1.4262, Train Acc: 0.4564


100%|██████████| 138/138 [00:01<00:00, 100.45it/s]


Epoch: 051, Train Loss: 1.4366, Train Acc: 0.5659


100%|██████████| 138/138 [00:01<00:00, 100.48it/s]


Epoch: 052, Train Loss: 1.1424, Train Acc: 0.5520


100%|██████████| 138/138 [00:01<00:00, 98.45it/s] 


Epoch: 053, Train Loss: 1.0242, Train Acc: 0.6170


100%|██████████| 138/138 [00:01<00:00, 97.69it/s] 


Epoch: 054, Train Loss: 1.3174, Train Acc: 0.6168


100%|██████████| 138/138 [00:01<00:00, 98.63it/s] 


Epoch: 055, Train Loss: 1.0116, Train Acc: 0.5966


100%|██████████| 138/138 [00:01<00:00, 98.73it/s] 


Epoch: 056, Train Loss: 1.0970, Train Acc: 0.6193


100%|██████████| 138/138 [00:01<00:00, 99.35it/s] 


Epoch: 057, Train Loss: 0.9822, Train Acc: 0.6299
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 99.68it/s] 


Epoch: 058, Train Loss: 1.0888, Train Acc: 0.5501


100%|██████████| 138/138 [00:01<00:00, 99.93it/s] 


Epoch: 059, Train Loss: 0.9990, Train Acc: 0.5239


100%|██████████| 138/138 [00:01<00:00, 99.93it/s] 


Epoch: 060, Train Loss: 1.0668, Train Acc: 0.5428


100%|██████████| 138/138 [00:01<00:00, 99.51it/s] 


Epoch: 061, Train Loss: 0.8304, Train Acc: 0.6140
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 99.07it/s] 


Epoch: 062, Train Loss: 1.0749, Train Acc: 0.6565


100%|██████████| 138/138 [00:01<00:00, 98.88it/s] 


Epoch: 063, Train Loss: 1.2726, Train Acc: 0.5698


100%|██████████| 138/138 [00:01<00:00, 99.58it/s] 


Epoch: 064, Train Loss: 1.3302, Train Acc: 0.5996
early stopping


## Load and write results

In [39]:
load_checkpoint(model_gin, optimizer_1, path='./model_gin.pt')
write_csv('model_gin',predict(model_gin))

checkpoint loaded
