# load data

In [1]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import os.path as osp
from tqdm import tqdm

In [105]:
N = 6111
# N = 10
dir_name = './data/'
raw_datasets = []
for i in tqdm(range(N)):
	raw_data = torch.load(osp.join(dir_name,f'data_{i}.pt'))
	raw_datasets.append(raw_data)

100%|██████████| 6111/6111 [01:33<00:00, 65.09it/s] 


In [106]:
def restrict_to_one_feature_col(raw_dataset, feature_col=1):
	dataset = []
	for raw_data in tqdm(raw_dataset):
		data = raw_data.clone()
		edge_mask = raw_data.edge_attr[:,feature_col].nonzero()[0]
		data.edge_index = raw_data.edge_index[:,edge_mask]
		del(data.edge_attr)
		dataset.append(data)
	return dataset

In [107]:
datasets = restrict_to_one_feature_col(raw_datasets, feature_col=2)

100%|██████████| 6111/6111 [00:07<00:00, 813.38it/s] 


## train-test split

In [110]:
train_idx = list()
test_name = list()
test_idx = list()
y_train = list()
with open('./data/raw/graph_labels.txt', 'r') as f:
    for i,line in enumerate(f):
        t = line.split(',')
        if len(t[1][:-1]) == 0:
            test_name.append(t[0])
            test_idx.append(i)
        else:
            train_idx.append(i)
            y_train.append(int(t[1][:-1]))

In [111]:
len(train_idx),len(test_idx)

(4888, 1223)

In [112]:
train_val_set = [datasets[i] for i in train_idx]
test_set = [datasets[i] for i in test_idx]

sequential embedding contains a beginning and an ending

In [113]:
train_val_set[0].x.shape

torch.Size([185, 86])

## add label

In [114]:
for data, y in zip(train_val_set, y_train):
	data.y = y

## train-valid split

In [115]:
train_set, val_set = torch.utils.data.random_split(train_val_set, [4400,488], generator=torch.Generator().manual_seed(42))

In [116]:
len(train_set), len(val_set)

(4400, 488)

## Hyperparameters

In [141]:
import numpy as np
num_node_features = train_set[0].num_node_features # 86
print(num_node_features)
num_classes = len(np.unique(y_train)) # 18
print(num_classes)
lr = 1e-3
epochs = 100
batch_size = 128
hidden_dim = 64

86
18


## Dataloader

In [142]:
from torch.utils.data.dataloader import default_collate
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
collate_fn = lambda x: tuple(x_.to(device) for x_ in default_collate(x))
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [143]:
train_val_loader =DataLoader(train_val_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [144]:
for step, data in enumerate(train_loader):
	print(f'Step {step + 1}:')
	print('=======')
	print(f'Number of graphs in the current batch: {data.num_graphs}')
	print(data.is_cuda)
	print()
	break

Step 1:
Number of graphs in the current batch: 128
False



In [121]:
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn
from torch_geometric.utils import dropout_adj

In [129]:
def train(model, optimizer, criterion):
	model.train()
	for data in tqdm(train_loader):
		data = data.to(device)
		out,_ = model(data.x, data.edge_index, data.batch)
		loss = criterion(out, data.y)
		loss.backward()
		optimizer.step()
		optimizer.zero_grad()

def eval(model, loader, criterion):
	model.eval()
	correct = 0
	for data in loader:
		data = data.to(device)
		out,_ = model(data.x, data.edge_index, data.batch)
		loss = criterion(out, data.y)
		pred = out.argmax(dim=1)
		correct += int((pred == data.y).sum())
		acc = correct / len(loader.dataset)
	return acc, loss

def predict(model):
	pred_proba = []
	model.eval()
	for data in test_loader:
		data = data.to(device)
		out,_ = model(data.x, data.edge_index, data.batch)
		pred_proba.append(out.detach().cpu().numpy())
		del out
		torch.cuda.empty_cache()
	pred_proba = np.vstack(pred_proba)
	pred_proba = np.exp(pred_proba)
	return pred_proba

def embed(model, loader):
	embeddings = []
	model.eval()
	for data in loader:
		data = data.to(device)
		_, embedding = model(data.x, data.edge_index, data.batch)
		embeddings.append(embedding.detach().cpu().numpy())
		del embedding
		torch.cuda.empty_cache()
	embeddings = np.vstack(embeddings)
	return embeddings

### load & save checkpoint 

In [18]:
import os.path as osp
def load_checkpoint(model, optimizer, path):
	if osp.exists(path):
		checkpoint = torch.load(path)
		model.load_state_dict(checkpoint['model_state_dict'])
		optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
		print('checkpoint loaded')
		return checkpoint.get('loss')

In [19]:
def save_checkpoint(model, optimizer, loss, path):
	torch.save({
							'epoch': epochs,
							'model_state_dict': model.state_dict(),
							'optimizer_state_dict': optimizer.state_dict(),
							'loss': loss,
							}, path)

### train epochs and save checkpoints

In [123]:
import time 
def train_epochs(model, optimizer, criterion, path, patience=10, from_scratch=True):
	checkpoint_loss = None
	if not from_scratch:
		checkpoint_loss = load_checkpoint(model, optimizer, path=path)
	best_val_loss = checkpoint_loss if checkpoint_loss else 2.8 #2.8 is the loss of equi-ignorant guess
	best_val_acc = 0
	for epoch in range(epochs):
		train(model, optimizer, criterion)
		train_acc, train_loss = eval(model, train_loader, criterion)
		val_acc, val_loss = eval(model, val_loader, criterion)
		print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
		if val_loss < best_val_loss:
			save_checkpoint(model, optimizer, loss=val_loss, path=path)
			best_val_loss = val_loss
			print('checkpoint saved')
			patience = patience
		else:
			patience -= 1
			if patience == 0:
				print('early stopping')
				break

In [124]:
def train_epochs_full(model, optimizer, criterion, path, patience=10, from_scratch=True):
	checkpoint_loss = None
	if not from_scratch:
		checkpoint_loss = load_checkpoint(model, optimizer, path=path)
	best_train_loss = checkpoint_loss if checkpoint_loss else 2.8 #2.8 is the loss of equi-ignorant guess
	best_train_acc = 0
	for epoch in range(epochs):
		train(model, optimizer, criterion)
		train_acc, train_loss = eval(model, train_val_loader, criterion)
		time.sleep(0.5)
		print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
		if train_loss < best_train_loss:
			save_checkpoint(model, optimizer, loss=train_loss, path=path)
			best_train_loss = train_loss
			print('checkpoint saved')
			patience = patience
		else:
			patience -= 1
			if patience == 0:
				print('early stopping')
				break

### write

In [22]:
import csv
def write_csv(fname, pred_proba):
    with open(f'./predictions/{fname}.csv', 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter=',')
        lst = list()
        for i in range(18):
            lst.append('class'+str(i))
        lst.insert(0, "name")
        writer.writerow(lst)
        for i, protein in enumerate(test_name):
            lst = pred_proba[i,:].tolist()
            lst.insert(0, protein)
            writer.writerow(lst)

In [23]:
import gc
def clean_gpu_memory():
	gc.collect()
	torch.cuda.empty_cache()

# GNN2: GIN

In [132]:
class GIN_2(torch.nn.Module):
    def __init__(self, hidden_channels, edge_feature_col=1):
        super(GIN_2, self).__init__()
        torch.manual_seed(12345)
        self.p = 0.2
        self.feature_col = edge_feature_col
        self.bn = gnn.BatchNorm(hidden_channels)
        self.gin = gnn.GIN(in_channels=num_node_features, 
                            hidden_channels=hidden_channels,
                            num_layers=2,
                            dropout= 0.2,
                            norm=self.bn,
                            )
        self.lin = gnn.MLP(in_channels=hidden_channels, 
                            hidden_channels=hidden_channels,
                            out_channels=num_classes, 
                            num_layers=2, 
                            dropout=0.2)

    def forward(self, x, edge_index, batch):
        x = self.gin(x=x, edge_index=edge_index)
        x = gnn.global_add_pool(x, batch)
        x = self.bn(x)
        x = x.relu()
        emb = x
        x = self.lin(x)
        
        return F.log_softmax(x, dim=1), emb

In [147]:
model_gin = GIN_2(hidden_channels=hidden_dim).to(device)
print(model_gin)
optimizer_1 = torch.optim.Adam(model_gin.parameters(),lr=1e-4,weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

GIN_2(
  (bn): BatchNorm(64)
  (gin): GIN(86, 64, num_layers=2)
  (lin): MLP(64, 64, 18)
)


### Train

In [148]:
train_epochs(model_gin, optimizer_1, criterion, path='model_gin.pt', patience=50, from_scratch=True)

100%|██████████| 35/35 [00:02<00:00, 16.57it/s]


Epoch: 000, Train Loss: 2.8334, Train Acc: 0.1332, Val Loss: 2.8648, Val Acc: 0.1270


100%|██████████| 35/35 [00:01<00:00, 20.73it/s]


Epoch: 001, Train Loss: 2.9140, Train Acc: 0.2139, Val Loss: 2.8042, Val Acc: 0.2090


100%|██████████| 35/35 [00:01<00:00, 20.60it/s]


Epoch: 002, Train Loss: 2.7058, Train Acc: 0.2514, Val Loss: 2.7404, Val Acc: 0.2172
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 18.29it/s]


Epoch: 003, Train Loss: 2.8054, Train Acc: 0.2650, Val Loss: 2.7284, Val Acc: 0.2316
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 21.08it/s]


Epoch: 004, Train Loss: 2.5268, Train Acc: 0.2673, Val Loss: 2.5901, Val Acc: 0.2520
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 18.56it/s]


Epoch: 005, Train Loss: 2.6021, Train Acc: 0.2673, Val Loss: 2.4910, Val Acc: 0.2480
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 21.23it/s]


Epoch: 006, Train Loss: 2.4871, Train Acc: 0.2711, Val Loss: 2.7457, Val Acc: 0.2582


100%|██████████| 35/35 [00:01<00:00, 21.01it/s]


Epoch: 007, Train Loss: 2.4052, Train Acc: 0.2761, Val Loss: 2.4356, Val Acc: 0.2582
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 21.73it/s]


Epoch: 008, Train Loss: 2.3192, Train Acc: 0.2809, Val Loss: 2.8502, Val Acc: 0.2602


100%|██████████| 35/35 [00:01<00:00, 21.59it/s]


Epoch: 009, Train Loss: 2.2959, Train Acc: 0.2802, Val Loss: 2.6851, Val Acc: 0.2602


100%|██████████| 35/35 [00:01<00:00, 22.25it/s]


Epoch: 010, Train Loss: 2.4958, Train Acc: 0.2839, Val Loss: 2.4177, Val Acc: 0.2684
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 20.33it/s]


Epoch: 011, Train Loss: 2.1976, Train Acc: 0.2868, Val Loss: 2.4731, Val Acc: 0.2623


100%|██████████| 35/35 [00:01<00:00, 22.52it/s]


Epoch: 012, Train Loss: 2.3374, Train Acc: 0.2811, Val Loss: 2.4635, Val Acc: 0.2582


100%|██████████| 35/35 [00:01<00:00, 21.83it/s]


Epoch: 013, Train Loss: 2.3660, Train Acc: 0.2814, Val Loss: 2.3785, Val Acc: 0.2561
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 20.32it/s]


Epoch: 014, Train Loss: 2.3201, Train Acc: 0.2975, Val Loss: 2.6378, Val Acc: 0.2787


100%|██████████| 35/35 [00:01<00:00, 21.98it/s]


Epoch: 015, Train Loss: 2.4468, Train Acc: 0.3041, Val Loss: 2.6056, Val Acc: 0.2807


100%|██████████| 35/35 [00:01<00:00, 21.58it/s]


Epoch: 016, Train Loss: 2.2322, Train Acc: 0.2986, Val Loss: 2.2638, Val Acc: 0.2869
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 20.76it/s]


Epoch: 017, Train Loss: 2.0948, Train Acc: 0.3130, Val Loss: 2.1939, Val Acc: 0.3053
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 22.43it/s]


Epoch: 018, Train Loss: 1.9762, Train Acc: 0.3159, Val Loss: 2.5506, Val Acc: 0.2828


100%|██████████| 35/35 [00:01<00:00, 19.88it/s]


Epoch: 019, Train Loss: 2.0823, Train Acc: 0.3218, Val Loss: 2.2043, Val Acc: 0.3053


100%|██████████| 35/35 [00:01<00:00, 22.50it/s]


Epoch: 020, Train Loss: 2.1586, Train Acc: 0.3189, Val Loss: 2.3347, Val Acc: 0.2889


100%|██████████| 35/35 [00:01<00:00, 22.23it/s]


Epoch: 021, Train Loss: 2.1183, Train Acc: 0.3259, Val Loss: 2.1229, Val Acc: 0.3012
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 20.71it/s]


Epoch: 022, Train Loss: 2.3688, Train Acc: 0.3266, Val Loss: 2.3227, Val Acc: 0.2725


100%|██████████| 35/35 [00:01<00:00, 22.79it/s]


Epoch: 023, Train Loss: 2.1137, Train Acc: 0.3077, Val Loss: 2.0409, Val Acc: 0.3053
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 20.50it/s]


Epoch: 024, Train Loss: 2.0698, Train Acc: 0.3359, Val Loss: 2.2815, Val Acc: 0.3115


100%|██████████| 35/35 [00:01<00:00, 21.63it/s]


Epoch: 025, Train Loss: 2.3187, Train Acc: 0.3370, Val Loss: 2.1476, Val Acc: 0.3217


100%|██████████| 35/35 [00:01<00:00, 20.52it/s]


Epoch: 026, Train Loss: 2.1439, Train Acc: 0.3166, Val Loss: 2.3606, Val Acc: 0.3176


100%|██████████| 35/35 [00:01<00:00, 22.43it/s]


Epoch: 027, Train Loss: 2.2252, Train Acc: 0.3457, Val Loss: 2.3187, Val Acc: 0.3135


100%|██████████| 35/35 [00:01<00:00, 22.28it/s]


Epoch: 028, Train Loss: 2.0931, Train Acc: 0.3441, Val Loss: 2.2865, Val Acc: 0.3217


100%|██████████| 35/35 [00:01<00:00, 21.27it/s]


Epoch: 029, Train Loss: 2.0942, Train Acc: 0.3520, Val Loss: 2.3383, Val Acc: 0.3340


100%|██████████| 35/35 [00:01<00:00, 22.04it/s]


Epoch: 030, Train Loss: 2.1954, Train Acc: 0.3455, Val Loss: 2.3558, Val Acc: 0.3094


100%|██████████| 35/35 [00:01<00:00, 21.59it/s]


Epoch: 031, Train Loss: 2.0423, Train Acc: 0.3514, Val Loss: 2.1297, Val Acc: 0.3340


100%|██████████| 35/35 [00:01<00:00, 22.10it/s]


Epoch: 032, Train Loss: 2.1663, Train Acc: 0.3293, Val Loss: 2.3152, Val Acc: 0.3094


100%|██████████| 35/35 [00:01<00:00, 21.99it/s]


Epoch: 033, Train Loss: 2.1266, Train Acc: 0.3575, Val Loss: 2.3004, Val Acc: 0.3299


100%|██████████| 35/35 [00:01<00:00, 22.55it/s]


Epoch: 034, Train Loss: 1.6902, Train Acc: 0.3525, Val Loss: 2.2063, Val Acc: 0.3238


100%|██████████| 35/35 [00:01<00:00, 22.29it/s]


Epoch: 035, Train Loss: 2.0513, Train Acc: 0.3443, Val Loss: 2.2272, Val Acc: 0.2910


100%|██████████| 35/35 [00:01<00:00, 22.55it/s]


Epoch: 036, Train Loss: 2.1653, Train Acc: 0.3377, Val Loss: 2.5471, Val Acc: 0.3197


100%|██████████| 35/35 [00:01<00:00, 22.28it/s]


Epoch: 037, Train Loss: 2.1650, Train Acc: 0.3586, Val Loss: 2.2140, Val Acc: 0.3238


100%|██████████| 35/35 [00:01<00:00, 21.64it/s]


Epoch: 038, Train Loss: 2.1150, Train Acc: 0.3559, Val Loss: 2.0337, Val Acc: 0.3443
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 20.29it/s]


Epoch: 039, Train Loss: 2.0199, Train Acc: 0.3720, Val Loss: 2.1456, Val Acc: 0.3422


100%|██████████| 35/35 [00:01<00:00, 22.19it/s]


Epoch: 040, Train Loss: 1.9908, Train Acc: 0.3170, Val Loss: 2.3233, Val Acc: 0.3135


100%|██████████| 35/35 [00:01<00:00, 21.99it/s]


Epoch: 041, Train Loss: 2.0809, Train Acc: 0.3627, Val Loss: 2.1195, Val Acc: 0.3422


100%|██████████| 35/35 [00:01<00:00, 22.34it/s]


Epoch: 042, Train Loss: 2.3294, Train Acc: 0.3666, Val Loss: 2.2398, Val Acc: 0.3463


100%|██████████| 35/35 [00:01<00:00, 22.62it/s]


Epoch: 043, Train Loss: 2.0843, Train Acc: 0.3675, Val Loss: 2.0680, Val Acc: 0.3422


100%|██████████| 35/35 [00:01<00:00, 22.58it/s]


Epoch: 044, Train Loss: 2.0567, Train Acc: 0.3620, Val Loss: 2.2731, Val Acc: 0.3176


100%|██████████| 35/35 [00:01<00:00, 22.06it/s]


Epoch: 045, Train Loss: 1.9913, Train Acc: 0.3759, Val Loss: 2.2455, Val Acc: 0.3484


100%|██████████| 35/35 [00:01<00:00, 22.58it/s]


Epoch: 046, Train Loss: 2.1210, Train Acc: 0.3766, Val Loss: 2.1484, Val Acc: 0.3545


100%|██████████| 35/35 [00:01<00:00, 21.66it/s]


Epoch: 047, Train Loss: 2.0047, Train Acc: 0.3814, Val Loss: 2.0932, Val Acc: 0.3668


100%|██████████| 35/35 [00:01<00:00, 21.62it/s]


Epoch: 048, Train Loss: 2.2100, Train Acc: 0.3795, Val Loss: 2.2135, Val Acc: 0.3750


100%|██████████| 35/35 [00:01<00:00, 21.60it/s]


Epoch: 049, Train Loss: 2.0298, Train Acc: 0.3811, Val Loss: 2.1294, Val Acc: 0.3770


100%|██████████| 35/35 [00:01<00:00, 21.98it/s]


Epoch: 050, Train Loss: 1.8539, Train Acc: 0.3732, Val Loss: 2.0645, Val Acc: 0.3299


100%|██████████| 35/35 [00:01<00:00, 22.39it/s]


Epoch: 051, Train Loss: 1.7484, Train Acc: 0.3698, Val Loss: 2.0833, Val Acc: 0.3340


100%|██████████| 35/35 [00:01<00:00, 22.08it/s]


Epoch: 052, Train Loss: 2.0061, Train Acc: 0.3893, Val Loss: 2.0379, Val Acc: 0.3566


100%|██████████| 35/35 [00:01<00:00, 22.36it/s]


Epoch: 053, Train Loss: 2.3761, Train Acc: 0.3780, Val Loss: 2.1590, Val Acc: 0.3627


100%|██████████| 35/35 [00:01<00:00, 22.18it/s]


Epoch: 054, Train Loss: 1.9049, Train Acc: 0.3800, Val Loss: 2.2190, Val Acc: 0.3402


100%|██████████| 35/35 [00:01<00:00, 21.16it/s]


Epoch: 055, Train Loss: 1.8814, Train Acc: 0.3770, Val Loss: 2.1290, Val Acc: 0.3176


100%|██████████| 35/35 [00:01<00:00, 19.31it/s]


Epoch: 056, Train Loss: 2.3589, Train Acc: 0.3832, Val Loss: 2.1056, Val Acc: 0.3709


100%|██████████| 35/35 [00:01<00:00, 19.65it/s]


Epoch: 057, Train Loss: 1.7938, Train Acc: 0.3970, Val Loss: 2.0329, Val Acc: 0.3730
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 21.07it/s]


Epoch: 058, Train Loss: 1.7355, Train Acc: 0.3923, Val Loss: 2.0056, Val Acc: 0.3770
checkpoint saved


100%|██████████| 35/35 [00:01<00:00, 21.45it/s]


Epoch: 059, Train Loss: 2.2700, Train Acc: 0.4009, Val Loss: 2.2572, Val Acc: 0.3730


100%|██████████| 35/35 [00:01<00:00, 20.32it/s]


Epoch: 060, Train Loss: 2.2155, Train Acc: 0.4057, Val Loss: 2.0330, Val Acc: 0.3668


100%|██████████| 35/35 [00:01<00:00, 22.45it/s]


Epoch: 061, Train Loss: 1.9613, Train Acc: 0.4025, Val Loss: 2.1398, Val Acc: 0.3730


100%|██████████| 35/35 [00:01<00:00, 22.23it/s]


Epoch: 062, Train Loss: 1.8857, Train Acc: 0.3802, Val Loss: 2.1267, Val Acc: 0.3586


100%|██████████| 35/35 [00:01<00:00, 21.92it/s]


Epoch: 063, Train Loss: 1.9554, Train Acc: 0.4000, Val Loss: 2.0670, Val Acc: 0.3545
early stopping


: 

In [24]:
#train_epochs_full(model_gin, optimizer_1, path='model_gin.pt', patience=50, from_scratch=True)

## Load and write results

In [25]:
#load_checkpoint(model_gin, optimizer_1, path='./model_gin.pt')
#write_csv('model_gin',predict(model_gin))

## 

In [26]:
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.nn import global_mean_pool, global_add_pool

In [31]:
class GIN(torch.nn.Module):
    """GIN"""
    def __init__(self, dim_h):
        super(GIN, self).__init__()
        self.conv1 = GINConv(
            Sequential(Linear(num_node_features, dim_h),
                       BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.conv2 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.conv3 = GINConv(
            Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(),
                       Linear(dim_h, dim_h), ReLU()))
        self.lin1 = Linear(dim_h*3, dim_h*3)
        self.lin2 = Linear(dim_h*3, num_classes)

    def forward(self, x, edge_index, edge_attr, batch):
        # Node embeddings 
        h1 = self.conv1(x, edge_index)
        h2 = self.conv2(h1, edge_index)
        h3 = self.conv3(h2, edge_index)

        # Graph-level readout
        h1 = global_add_pool(h1, batch)
        h2 = global_add_pool(h2, batch)
        h3 = global_add_pool(h3, batch)

        # Concatenate graph embeddings
        h = torch.cat((h1, h2, h3), dim=1)

        # Classifier
        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.lin2(h)
        
        return h, F.log_softmax(h, dim=1)

In [37]:
gin = GIN(dim_h=32).to(device)
optimizer_1 = torch.optim.Adam(gin.parameters(),lr=lr,weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

In [35]:
print(device)

cuda


In [38]:
train_epochs(gin, optimizer_1, path='model_gin2.pt', patience=50, from_scratch=True)

100%|██████████| 138/138 [00:03<00:00, 45.89it/s]


Epoch: 000, Train Loss: 2.4834, Train Acc: 0.1934, Val Loss: 2.6702, Val Acc: 0.1516
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 82.65it/s]


Epoch: 001, Train Loss: 2.3889, Train Acc: 0.2645, Val Loss: 2.2086, Val Acc: 0.2541
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 81.75it/s]


Epoch: 002, Train Loss: 2.2660, Train Acc: 0.2677, Val Loss: 2.4611, Val Acc: 0.2787


100%|██████████| 138/138 [00:01<00:00, 83.00it/s]


Epoch: 003, Train Loss: 2.6520, Train Acc: 0.2214, Val Loss: 2.3324, Val Acc: 0.2316


100%|██████████| 138/138 [00:01<00:00, 84.79it/s]


Epoch: 004, Train Loss: 2.3110, Train Acc: 0.2666, Val Loss: 3.2186, Val Acc: 0.2561


100%|██████████| 138/138 [00:01<00:00, 85.44it/s]


Epoch: 005, Train Loss: 2.7508, Train Acc: 0.2480, Val Loss: 2.3856, Val Acc: 0.2520


100%|██████████| 138/138 [00:01<00:00, 84.42it/s]


Epoch: 006, Train Loss: 2.7496, Train Acc: 0.2205, Val Loss: 2.0121, Val Acc: 0.2357
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 83.59it/s]


Epoch: 007, Train Loss: 2.0694, Train Acc: 0.2732, Val Loss: 1.6203, Val Acc: 0.2807
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 83.86it/s]


Epoch: 008, Train Loss: 2.0114, Train Acc: 0.2955, Val Loss: 2.0832, Val Acc: 0.2746


100%|██████████| 138/138 [00:01<00:00, 82.74it/s]


Epoch: 009, Train Loss: 2.1409, Train Acc: 0.2955, Val Loss: 2.5292, Val Acc: 0.2664


100%|██████████| 138/138 [00:01<00:00, 84.32it/s]


Epoch: 010, Train Loss: 2.4401, Train Acc: 0.3018, Val Loss: 2.9175, Val Acc: 0.3033


100%|██████████| 138/138 [00:01<00:00, 84.91it/s]


Epoch: 011, Train Loss: 3.2522, Train Acc: 0.2443, Val Loss: 3.2102, Val Acc: 0.2357


100%|██████████| 138/138 [00:01<00:00, 83.45it/s]


Epoch: 012, Train Loss: 6.1355, Train Acc: 0.2416, Val Loss: 3.4455, Val Acc: 0.2582


100%|██████████| 138/138 [00:01<00:00, 82.65it/s]


Epoch: 013, Train Loss: 1.9272, Train Acc: 0.3457, Val Loss: 2.0120, Val Acc: 0.3340


100%|██████████| 138/138 [00:01<00:00, 79.98it/s]


Epoch: 014, Train Loss: 2.0971, Train Acc: 0.3370, Val Loss: 3.2703, Val Acc: 0.3381


100%|██████████| 138/138 [00:01<00:00, 79.03it/s]


Epoch: 015, Train Loss: 2.6610, Train Acc: 0.3455, Val Loss: 2.6897, Val Acc: 0.3340


100%|██████████| 138/138 [00:01<00:00, 75.78it/s]


Epoch: 016, Train Loss: 3.1605, Train Acc: 0.3132, Val Loss: 2.5086, Val Acc: 0.3197


100%|██████████| 138/138 [00:01<00:00, 77.86it/s]


Epoch: 017, Train Loss: 2.5670, Train Acc: 0.3418, Val Loss: 1.9713, Val Acc: 0.3627


100%|██████████| 138/138 [00:01<00:00, 80.98it/s]


Epoch: 018, Train Loss: 1.8292, Train Acc: 0.3916, Val Loss: 1.2029, Val Acc: 0.4180
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 79.84it/s]


Epoch: 019, Train Loss: 1.6277, Train Acc: 0.2393, Val Loss: 2.1192, Val Acc: 0.2377


100%|██████████| 138/138 [00:01<00:00, 79.39it/s]


Epoch: 020, Train Loss: 2.0389, Train Acc: 0.4066, Val Loss: 2.4649, Val Acc: 0.4016


100%|██████████| 138/138 [00:01<00:00, 80.12it/s]


Epoch: 021, Train Loss: 2.4594, Train Acc: 0.4295, Val Loss: 2.7186, Val Acc: 0.4242


100%|██████████| 138/138 [00:01<00:00, 80.15it/s]


Epoch: 022, Train Loss: 3.0204, Train Acc: 0.2909, Val Loss: 2.7983, Val Acc: 0.2971


100%|██████████| 138/138 [00:01<00:00, 84.75it/s]


Epoch: 023, Train Loss: 1.9319, Train Acc: 0.3961, Val Loss: 2.4170, Val Acc: 0.3852


100%|██████████| 138/138 [00:01<00:00, 83.90it/s]


Epoch: 024, Train Loss: 1.8517, Train Acc: 0.4384, Val Loss: 1.9956, Val Acc: 0.4344


100%|██████████| 138/138 [00:01<00:00, 83.74it/s]


Epoch: 025, Train Loss: 1.4932, Train Acc: 0.4380, Val Loss: 2.7643, Val Acc: 0.4344


100%|██████████| 138/138 [00:01<00:00, 83.91it/s]


Epoch: 026, Train Loss: 1.7472, Train Acc: 0.4286, Val Loss: 1.6790, Val Acc: 0.4098


100%|██████████| 138/138 [00:01<00:00, 83.84it/s]


Epoch: 027, Train Loss: 2.0846, Train Acc: 0.4295, Val Loss: 2.4791, Val Acc: 0.4119


100%|██████████| 138/138 [00:01<00:00, 83.60it/s]


Epoch: 028, Train Loss: 1.8589, Train Acc: 0.4414, Val Loss: 1.3791, Val Acc: 0.4385


100%|██████████| 138/138 [00:01<00:00, 82.24it/s]


Epoch: 029, Train Loss: 1.5001, Train Acc: 0.4000, Val Loss: 0.9832, Val Acc: 0.3996
checkpoint saved


100%|██████████| 138/138 [00:01<00:00, 83.85it/s]


Epoch: 030, Train Loss: 2.3004, Train Acc: 0.4345, Val Loss: 2.5005, Val Acc: 0.4119


100%|██████████| 138/138 [00:01<00:00, 84.25it/s]


Epoch: 031, Train Loss: 2.6682, Train Acc: 0.4377, Val Loss: 1.8871, Val Acc: 0.4180


100%|██████████| 138/138 [00:01<00:00, 81.81it/s]


Epoch: 032, Train Loss: 1.5271, Train Acc: 0.4530, Val Loss: 1.3719, Val Acc: 0.4713


100%|██████████| 138/138 [00:01<00:00, 83.87it/s]


Epoch: 033, Train Loss: 2.7651, Train Acc: 0.4068, Val Loss: 2.4066, Val Acc: 0.4098


100%|██████████| 138/138 [00:01<00:00, 83.54it/s]


Epoch: 034, Train Loss: 1.7591, Train Acc: 0.4425, Val Loss: 1.9784, Val Acc: 0.4119


100%|██████████| 138/138 [00:01<00:00, 83.05it/s]


Epoch: 035, Train Loss: 1.7446, Train Acc: 0.4634, Val Loss: 1.8950, Val Acc: 0.4324


100%|██████████| 138/138 [00:01<00:00, 84.81it/s]


Epoch: 036, Train Loss: 1.8353, Train Acc: 0.4450, Val Loss: 3.0539, Val Acc: 0.4242


100%|██████████| 138/138 [00:01<00:00, 83.35it/s]


Epoch: 037, Train Loss: 3.4449, Train Acc: 0.3470, Val Loss: 2.7658, Val Acc: 0.3443


100%|██████████| 138/138 [00:01<00:00, 82.06it/s]


Epoch: 038, Train Loss: 1.3075, Train Acc: 0.4107, Val Loss: 2.7773, Val Acc: 0.4201


100%|██████████| 138/138 [00:01<00:00, 82.39it/s]


Epoch: 039, Train Loss: 1.7835, Train Acc: 0.4000, Val Loss: 2.9929, Val Acc: 0.3975


100%|██████████| 138/138 [00:01<00:00, 81.68it/s]


Epoch: 040, Train Loss: 1.9242, Train Acc: 0.4527, Val Loss: 1.9740, Val Acc: 0.4221


100%|██████████| 138/138 [00:01<00:00, 83.49it/s]


Epoch: 041, Train Loss: 2.3492, Train Acc: 0.4595, Val Loss: 1.7629, Val Acc: 0.4652


100%|██████████| 138/138 [00:01<00:00, 78.08it/s]


Epoch: 042, Train Loss: 1.6356, Train Acc: 0.4359, Val Loss: 2.3129, Val Acc: 0.4201


100%|██████████| 138/138 [00:01<00:00, 82.75it/s]


Epoch: 043, Train Loss: 1.6400, Train Acc: 0.4332, Val Loss: 2.0622, Val Acc: 0.4078


100%|██████████| 138/138 [00:01<00:00, 82.37it/s]


Epoch: 044, Train Loss: 1.1617, Train Acc: 0.4198, Val Loss: 2.6522, Val Acc: 0.4242


100%|██████████| 138/138 [00:01<00:00, 79.71it/s]


Epoch: 045, Train Loss: 1.9080, Train Acc: 0.4511, Val Loss: 1.8947, Val Acc: 0.4262


100%|██████████| 138/138 [00:01<00:00, 80.33it/s]


Epoch: 046, Train Loss: 1.4280, Train Acc: 0.4614, Val Loss: 2.1774, Val Acc: 0.4467


100%|██████████| 138/138 [00:01<00:00, 83.11it/s]


Epoch: 047, Train Loss: 1.5540, Train Acc: 0.4693, Val Loss: 2.2827, Val Acc: 0.4385


100%|██████████| 138/138 [00:01<00:00, 82.80it/s]


Epoch: 048, Train Loss: 1.3029, Train Acc: 0.3916, Val Loss: 2.7375, Val Acc: 0.3934


100%|██████████| 138/138 [00:01<00:00, 83.88it/s]


Epoch: 049, Train Loss: 1.4272, Train Acc: 0.4682, Val Loss: 1.9271, Val Acc: 0.4467


100%|██████████| 138/138 [00:01<00:00, 82.28it/s]


Epoch: 050, Train Loss: 1.1539, Train Acc: 0.4707, Val Loss: 1.4797, Val Acc: 0.4508


100%|██████████| 138/138 [00:01<00:00, 83.96it/s]


Epoch: 051, Train Loss: 1.8474, Train Acc: 0.4218, Val Loss: 2.8702, Val Acc: 0.3934


100%|██████████| 138/138 [00:01<00:00, 83.56it/s]


Epoch: 052, Train Loss: 1.6180, Train Acc: 0.4739, Val Loss: 1.7549, Val Acc: 0.4467


100%|██████████| 138/138 [00:01<00:00, 83.59it/s]


Epoch: 053, Train Loss: 2.2509, Train Acc: 0.4314, Val Loss: 1.9488, Val Acc: 0.4180


100%|██████████| 138/138 [00:01<00:00, 83.33it/s]


Epoch: 054, Train Loss: 1.7103, Train Acc: 0.4495, Val Loss: 2.1500, Val Acc: 0.4037


100%|██████████| 138/138 [00:01<00:00, 81.35it/s]


Epoch: 055, Train Loss: 2.0469, Train Acc: 0.4786, Val Loss: 2.3809, Val Acc: 0.4857
early stopping


In [None]:
#train_epochs_full(model_gin, optimizer_1, path='model_gin2.pt', patience=50, from_scratch=True)