# load data

In [8]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import os.path as osp
from tqdm import tqdm

In [5]:
N = 6111
dir_name = './data/'
datasets = []
for i in tqdm(range(N)):
	data = torch.load(osp.join(dir_name,f'data_{i}.pt'))
	datasets.append(data)

100%|██████████| 6111/6111 [01:19<00:00, 76.83it/s] 


In [99]:
for i, data in enumerate(datasets):
	data.edge_index = data.edge_index.type(torch.long)
	torch.save(data, osp.join(dir_name,f'data_{i}.pt'))

In [6]:
datasets[0]

Data(x=[327, 86], edge_index=[2, 6233], edge_attr=[6233, 5])

## train-test split

In [105]:
train_idx = list()
test_name = list()
test_idx = list()
y_train = list()
with open('./data/raw/graph_labels.txt', 'r') as f:
    for i,line in enumerate(f):
        t = line.split(',')
        if len(t[1][:-1]) == 0:
            test_name.append(t[0])
            test_idx.append(i)
        else:
            train_idx.append(i)
            y_train.append(int(t[1][:-1]))

In [106]:
len(train_idx),len(test_idx)

(4888, 1223)

In [107]:
train_val_set = [datasets[i] for i in train_idx]
test_set = [datasets[i] for i in test_idx]

## add label

In [108]:
train_set[0].y

6

In [109]:
for data, y in zip(train_val_set, y_train):
	data.y = y

## train-valid split

In [110]:
train_set, val_set = torch.utils.data.random_split(train_val_set, [0.9,0.1], generator=torch.Generator().manual_seed(42))

In [111]:
len(train_set), len(val_set)

(4400, 488)

## Hyperparameters

In [171]:
import numpy as np
num_node_features = train_set[0].num_node_features # 86
print(num_node_features)
num_classes = len(np.unique(y_train)) # 18
print(num_classes)
lr = 0.001
epochs = 100
batch_size = 64
hidden_dim = 64

86
18


## Dataloader

In [139]:
from torch_geometric.loader import DataLoader
from torch.utils.data.dataloader import default_collate
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
collate_fn = lambda x: tuple(x_.to(device) for x_ in default_collate(x))
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_set, batch_size=64, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, collate_fn=collate_fn)

In [170]:
train_val_loader =DataLoader(train_val_set, batch_size=64, shuffle=True, collate_fn=collate_fn)

In [140]:
for step, data in enumerate(train_loader):
	print(f'Step {step + 1}:')
	print('=======')
	print(f'Number of graphs in the current batch: {data.num_graphs}')
	print(data.is_cuda)
	print()
	break

Step 1:
Number of graphs in the current batch: 64
False



# GNN0: baseline

In [115]:
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn
from torch_geometric.utils import dropout_node

In [154]:
from torch import embedding


class GCN_0(torch.nn.Module):
    def __init__(self, hidden_channels):
        self.p = 0.2
        super(GCN_0, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = gnn.GCNConv(num_node_features, hidden_channels)
        self.conv2 = gnn.GCNConv(hidden_channels, hidden_channels)
        self.bn = gnn.BatchNorm(hidden_channels)
        self.dropout = nn.Dropout(self.p)
        self.lin1 = nn.Linear(hidden_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        edge_index = dropout_node(edge_index)[0] # edge_index, edge_mask, node_mask
        x = self.conv2(x, edge_index)
        x = x.relu()

        # 2. Readout layer
        x = gnn.global_add_pool(x, batch)  # [batch_size, hidden_channels]
        x = self.bn(x)
        x = x.relu()
        embedding = x

        # 3. Apply a final classifier
        x = self.lin1(x)
        x = x.relu()
        
        x = F.dropout(x, p=self.p, training=self.training)
        x = self.lin2(x)
        
        return F.log_softmax(x, dim=1), embedding

In [163]:
model_0 = GCN_0(hidden_channels=hidden_dim).to(device)
print(model_0)
optimizer_0 = torch.optim.Adam(model_0.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

GCN_0(
  (conv1): GCNConv(86, 64)
  (conv2): GCNConv(64, 64)
  (bn): BatchNorm(64)
  (dropout): Dropout(p=0.2, inplace=False)
  (lin1): Linear(in_features=64, out_features=64, bias=True)
  (lin2): Linear(in_features=64, out_features=18, bias=True)
)


In [181]:
def train(model, optimizer):
	model.train()
	for data in tqdm(train_loader):
		data = data.to(device)
		out,_ = model(data.x, data.edge_index, data.batch)
		loss = criterion(out, data.y)
		loss.backward()
		optimizer.step()
		optimizer.zero_grad()

def eval(model, loader):
	model.eval()
	correct = 0
	for data in loader:
		data = data.to(device)
		out,_ = model(data.x, data.edge_index, data.batch)
		loss = criterion(out, data.y)
		pred = out.argmax(dim=1)
		correct += int((pred == data.y).sum())
		acc = correct / len(loader.dataset)
	return acc, loss

def predict(model):
	pred_proba = []
	model.eval()
	for data in test_loader:
		data = data.to(device)
		out,_ = model(data.x, data.edge_index, data.batch)
		pred_proba.append(out)
	pred_proba = torch.cat(pred_proba, dim=0)
	pred_proba = torch.exp(pred_proba)
	return pred_proba.detach().cpu().numpy()

In [176]:
import os.path as osp
def load_checkpoint(model, optimizer, path):
	if osp.exists(path):
		checkpoint = torch.load(path)
		model.load_state_dict(checkpoint['model_state_dict'])
		optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
		print('checkpoint loaded')

In [169]:
import time
# pbar = tqdm(range(epochs))
for epoch in range(epochs):
	train(model_0, optimizer_0)
	train_acc, train_loss = eval(model_0, train_loader)
	val_acc, val_loss = eval(model_0, val_loader)
	time.sleep(0.5)
	print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')


100%|██████████| 69/69 [00:05<00:00, 11.95it/s]


Epoch: 000, Train Loss: 1.7957, Train Acc: 0.5325, Val Loss: 1.8774, Val Acc: 0.4713


100%|██████████| 69/69 [00:03<00:00, 19.72it/s]


Epoch: 001, Train Loss: 1.5660, Train Acc: 0.5255, Val Loss: 2.1350, Val Acc: 0.4406


100%|██████████| 69/69 [00:03<00:00, 18.70it/s]


Epoch: 002, Train Loss: 1.4955, Train Acc: 0.5514, Val Loss: 2.1288, Val Acc: 0.4385


100%|██████████| 69/69 [00:03<00:00, 20.34it/s]


Epoch: 003, Train Loss: 1.6860, Train Acc: 0.5441, Val Loss: 1.4618, Val Acc: 0.4549


100%|██████████| 69/69 [00:03<00:00, 20.30it/s]


Epoch: 004, Train Loss: 1.1824, Train Acc: 0.5273, Val Loss: 2.0306, Val Acc: 0.4631


100%|██████████| 69/69 [00:03<00:00, 20.46it/s]


Epoch: 005, Train Loss: 1.4549, Train Acc: 0.5334, Val Loss: 2.2162, Val Acc: 0.4549


100%|██████████| 69/69 [00:03<00:00, 18.71it/s]


Epoch: 006, Train Loss: 1.6254, Train Acc: 0.5025, Val Loss: 2.0711, Val Acc: 0.3955


100%|██████████| 69/69 [00:04<00:00, 16.18it/s]


Epoch: 007, Train Loss: 1.3229, Train Acc: 0.5466, Val Loss: 1.8376, Val Acc: 0.4262


100%|██████████| 69/69 [00:04<00:00, 16.26it/s]


Epoch: 008, Train Loss: 1.3447, Train Acc: 0.5136, Val Loss: 1.9028, Val Acc: 0.4098


100%|██████████| 69/69 [00:03<00:00, 19.43it/s]


Epoch: 009, Train Loss: 1.3562, Train Acc: 0.5441, Val Loss: 1.9657, Val Acc: 0.4611


100%|██████████| 69/69 [00:03<00:00, 18.08it/s]


Epoch: 010, Train Loss: 1.3333, Train Acc: 0.5343, Val Loss: 2.0186, Val Acc: 0.4262


100%|██████████| 69/69 [00:03<00:00, 17.92it/s]


Epoch: 011, Train Loss: 1.4582, Train Acc: 0.5452, Val Loss: 2.4766, Val Acc: 0.4344


100%|██████████| 69/69 [00:04<00:00, 16.52it/s]


Epoch: 012, Train Loss: 1.6616, Train Acc: 0.5459, Val Loss: 2.1696, Val Acc: 0.4795


100%|██████████| 69/69 [00:03<00:00, 19.37it/s]


Epoch: 013, Train Loss: 1.4664, Train Acc: 0.5602, Val Loss: 20.2301, Val Acc: 0.4631


100%|██████████| 69/69 [00:03<00:00, 20.59it/s]


Epoch: 014, Train Loss: 1.6254, Train Acc: 0.5511, Val Loss: 2.1169, Val Acc: 0.4672


100%|██████████| 69/69 [00:03<00:00, 18.00it/s]


Epoch: 015, Train Loss: 1.3629, Train Acc: 0.5391, Val Loss: 2.4051, Val Acc: 0.4488


100%|██████████| 69/69 [00:04<00:00, 14.82it/s]


Epoch: 016, Train Loss: 1.2688, Train Acc: 0.5468, Val Loss: 1.4115, Val Acc: 0.4754


100%|██████████| 69/69 [00:03<00:00, 19.35it/s]


Epoch: 017, Train Loss: 1.5058, Train Acc: 0.5543, Val Loss: 1.4754, Val Acc: 0.4344


100%|██████████| 69/69 [00:03<00:00, 18.35it/s]


Epoch: 018, Train Loss: 1.1597, Train Acc: 0.5580, Val Loss: 2.2193, Val Acc: 0.4672


100%|██████████| 69/69 [00:03<00:00, 19.65it/s]


Epoch: 019, Train Loss: 1.4914, Train Acc: 0.5618, Val Loss: 7.4209, Val Acc: 0.4303


100%|██████████| 69/69 [00:03<00:00, 19.65it/s]


Epoch: 020, Train Loss: 1.2274, Train Acc: 0.5107, Val Loss: 9.5664, Val Acc: 0.4385


100%|██████████| 69/69 [00:03<00:00, 18.60it/s]


Epoch: 021, Train Loss: 1.8078, Train Acc: 0.5277, Val Loss: 447.8732, Val Acc: 0.3955


100%|██████████| 69/69 [00:03<00:00, 19.16it/s]


Epoch: 022, Train Loss: 1.4272, Train Acc: 0.5432, Val Loss: 2.1306, Val Acc: 0.4488


100%|██████████| 69/69 [00:03<00:00, 19.44it/s]


Epoch: 023, Train Loss: 1.9424, Train Acc: 0.4784, Val Loss: 1.3820, Val Acc: 0.3934


100%|██████████| 69/69 [00:03<00:00, 19.73it/s]


Epoch: 024, Train Loss: 1.4654, Train Acc: 0.5632, Val Loss: 2.1456, Val Acc: 0.4836


100%|██████████| 69/69 [00:03<00:00, 19.89it/s]


Epoch: 025, Train Loss: 1.7299, Train Acc: 0.5080, Val Loss: 1.7852, Val Acc: 0.3996


100%|██████████| 69/69 [00:03<00:00, 19.43it/s]


Epoch: 026, Train Loss: 1.2108, Train Acc: 0.5439, Val Loss: 1.8527, Val Acc: 0.4447


100%|██████████| 69/69 [00:03<00:00, 20.25it/s]


Epoch: 027, Train Loss: 1.4919, Train Acc: 0.5664, Val Loss: 1.6013, Val Acc: 0.4447


100%|██████████| 69/69 [00:03<00:00, 19.97it/s]


Epoch: 028, Train Loss: 1.2606, Train Acc: 0.5734, Val Loss: 1.6618, Val Acc: 0.4303


100%|██████████| 69/69 [00:03<00:00, 19.43it/s]


Epoch: 029, Train Loss: 1.4567, Train Acc: 0.5473, Val Loss: 1.7317, Val Acc: 0.4385


100%|██████████| 69/69 [00:03<00:00, 20.24it/s]


Epoch: 030, Train Loss: 1.1428, Train Acc: 0.5520, Val Loss: 1.4062, Val Acc: 0.4262


100%|██████████| 69/69 [00:03<00:00, 20.10it/s]


Epoch: 031, Train Loss: 1.7074, Train Acc: 0.5568, Val Loss: 2.1313, Val Acc: 0.4365


100%|██████████| 69/69 [00:03<00:00, 19.89it/s]


Epoch: 032, Train Loss: 1.1715, Train Acc: 0.5716, Val Loss: 2.1458, Val Acc: 0.4693


100%|██████████| 69/69 [00:03<00:00, 19.98it/s]


Epoch: 033, Train Loss: 1.1071, Train Acc: 0.5682, Val Loss: 1.9989, Val Acc: 0.4447


100%|██████████| 69/69 [00:03<00:00, 20.60it/s]


Epoch: 034, Train Loss: 1.3973, Train Acc: 0.5270, Val Loss: 46.7864, Val Acc: 0.4016


100%|██████████| 69/69 [00:03<00:00, 20.48it/s]


Epoch: 035, Train Loss: 1.4944, Train Acc: 0.5366, Val Loss: 1.9438, Val Acc: 0.4221


100%|██████████| 69/69 [00:03<00:00, 20.53it/s]


Epoch: 036, Train Loss: 1.3390, Train Acc: 0.5232, Val Loss: 1.7086, Val Acc: 0.4057


100%|██████████| 69/69 [00:03<00:00, 20.31it/s]


Epoch: 037, Train Loss: 1.1825, Train Acc: 0.5707, Val Loss: 2.0521, Val Acc: 0.4467


100%|██████████| 69/69 [00:03<00:00, 20.46it/s]


Epoch: 038, Train Loss: 1.4943, Train Acc: 0.5445, Val Loss: 2.2496, Val Acc: 0.4406


100%|██████████| 69/69 [00:03<00:00, 20.08it/s]


Epoch: 039, Train Loss: 1.7873, Train Acc: 0.5534, Val Loss: 1.9457, Val Acc: 0.4078


100%|██████████| 69/69 [00:11<00:00,  6.25it/s]


Epoch: 040, Train Loss: 1.5925, Train Acc: 0.5327, Val Loss: 1.9567, Val Acc: 0.4344


100%|██████████| 69/69 [00:11<00:00,  6.11it/s]


Epoch: 041, Train Loss: 1.0400, Train Acc: 0.5780, Val Loss: 1.8994, Val Acc: 0.4611


100%|██████████| 69/69 [00:03<00:00, 17.66it/s]


Epoch: 042, Train Loss: 1.3028, Train Acc: 0.5607, Val Loss: 2.0937, Val Acc: 0.4406


100%|██████████| 69/69 [00:03<00:00, 18.15it/s]


Epoch: 043, Train Loss: 1.2415, Train Acc: 0.5666, Val Loss: 1.7756, Val Acc: 0.4447


100%|██████████| 69/69 [00:03<00:00, 17.98it/s]


Epoch: 044, Train Loss: 1.3755, Train Acc: 0.5527, Val Loss: 1.8814, Val Acc: 0.3914


100%|██████████| 69/69 [00:04<00:00, 16.98it/s]


Epoch: 045, Train Loss: 1.6188, Train Acc: 0.5416, Val Loss: 2.1065, Val Acc: 0.4385


100%|██████████| 69/69 [00:03<00:00, 17.47it/s]


Epoch: 046, Train Loss: 1.1530, Train Acc: 0.5432, Val Loss: 2.1718, Val Acc: 0.4242


100%|██████████| 69/69 [00:03<00:00, 17.47it/s]


Epoch: 047, Train Loss: 1.4181, Train Acc: 0.5250, Val Loss: 218.7119, Val Acc: 0.4652


100%|██████████| 69/69 [00:03<00:00, 17.28it/s]


Epoch: 048, Train Loss: 1.5120, Train Acc: 0.5480, Val Loss: 2.4486, Val Acc: 0.4180


100%|██████████| 69/69 [00:03<00:00, 17.46it/s]


Epoch: 049, Train Loss: 1.3727, Train Acc: 0.5430, Val Loss: 1.7742, Val Acc: 0.4160


100%|██████████| 69/69 [00:04<00:00, 15.74it/s]


Epoch: 050, Train Loss: 1.4071, Train Acc: 0.5698, Val Loss: 1.9198, Val Acc: 0.4508


100%|██████████| 69/69 [00:04<00:00, 16.25it/s]


Epoch: 051, Train Loss: 1.9284, Train Acc: 0.5514, Val Loss: 2.0438, Val Acc: 0.4242


100%|██████████| 69/69 [00:04<00:00, 17.12it/s]


Epoch: 052, Train Loss: 1.1916, Train Acc: 0.5768, Val Loss: 2.2313, Val Acc: 0.4672


100%|██████████| 69/69 [00:04<00:00, 15.06it/s]


Epoch: 053, Train Loss: 1.4046, Train Acc: 0.5230, Val Loss: 1.9201, Val Acc: 0.4160


100%|██████████| 69/69 [00:04<00:00, 14.98it/s]


Epoch: 054, Train Loss: 1.3406, Train Acc: 0.5748, Val Loss: 1.3864, Val Acc: 0.4549


100%|██████████| 69/69 [00:04<00:00, 17.00it/s]


Epoch: 055, Train Loss: 1.0728, Train Acc: 0.5664, Val Loss: 1.6912, Val Acc: 0.4549


100%|██████████| 69/69 [00:03<00:00, 17.89it/s]


Epoch: 056, Train Loss: 1.8009, Train Acc: 0.5839, Val Loss: 1.8745, Val Acc: 0.4549


100%|██████████| 69/69 [00:03<00:00, 17.65it/s]


Epoch: 057, Train Loss: 1.6800, Train Acc: 0.5184, Val Loss: 2.2160, Val Acc: 0.3996


100%|██████████| 69/69 [00:04<00:00, 16.74it/s]


Epoch: 058, Train Loss: 1.3395, Train Acc: 0.5739, Val Loss: 2.0726, Val Acc: 0.4529


100%|██████████| 69/69 [00:04<00:00, 16.17it/s]


Epoch: 059, Train Loss: 1.4360, Train Acc: 0.5520, Val Loss: 2.2005, Val Acc: 0.4385


100%|██████████| 69/69 [00:03<00:00, 19.01it/s]


Epoch: 060, Train Loss: 1.6336, Train Acc: 0.5557, Val Loss: 2.3950, Val Acc: 0.4303


100%|██████████| 69/69 [00:04<00:00, 14.57it/s]


Epoch: 061, Train Loss: 1.4734, Train Acc: 0.5900, Val Loss: 1.8023, Val Acc: 0.4570


100%|██████████| 69/69 [00:03<00:00, 17.88it/s]


Epoch: 062, Train Loss: 1.2535, Train Acc: 0.5984, Val Loss: 38.4535, Val Acc: 0.4754


 30%|███       | 21/69 [00:01<00:02, 18.28it/s]


KeyboardInterrupt: 

In [None]:
# PATH = './model/model_0.pt'
# torch.save({
#             'epoch': epochs,
#             'model_state_dict': model_0.state_dict(),
#             'optimizer_state_dict': optimizer_0.state_dict(),
#             }, PATH)

training sample model for over 150 epochs, no significant improvement. we keep the 100 epochs checkpoint.

retrain model together with validation data

In [172]:
model_0_full = GCN_0(hidden_channels=hidden_dim).to(device)
print(model_0_full)
optimizer_0_full = torch.optim.Adam(model_0_full.parameters(), lr=lr)

GCN_0(
  (conv1): GCNConv(86, 64)
  (conv2): GCNConv(64, 64)
  (bn): BatchNorm(64)
  (dropout): Dropout(p=0.2, inplace=False)
  (lin1): Linear(in_features=64, out_features=64, bias=True)
  (lin2): Linear(in_features=64, out_features=18, bias=True)
)


In [173]:
for epoch in range(epochs):
	train(model_0_full, optimizer_0_full)
	train_acc, train_loss = eval(model_0_full, train_val_loader)
	time.sleep(0.5)
	print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')

100%|██████████| 69/69 [00:03<00:00, 18.34it/s]


Epoch: 000, Train Loss: 2.5891, Train Acc: 0.2361


100%|██████████| 69/69 [00:03<00:00, 20.54it/s]


Epoch: 001, Train Loss: 2.0847, Train Acc: 0.2287


100%|██████████| 69/69 [00:03<00:00, 20.08it/s]


Epoch: 002, Train Loss: 2.1090, Train Acc: 0.2390


100%|██████████| 69/69 [00:03<00:00, 19.08it/s]


Epoch: 003, Train Loss: 2.1934, Train Acc: 0.2535


100%|██████████| 69/69 [00:03<00:00, 17.44it/s]


Epoch: 004, Train Loss: 2.3879, Train Acc: 0.2696


100%|██████████| 69/69 [00:03<00:00, 19.30it/s]


Epoch: 005, Train Loss: 2.2787, Train Acc: 0.2813


100%|██████████| 69/69 [00:03<00:00, 17.97it/s]


Epoch: 006, Train Loss: 2.3390, Train Acc: 0.3003


100%|██████████| 69/69 [00:03<00:00, 19.70it/s]


Epoch: 007, Train Loss: 2.1496, Train Acc: 0.3181


100%|██████████| 69/69 [00:03<00:00, 19.69it/s]


Epoch: 008, Train Loss: 2.0978, Train Acc: 0.3224


100%|██████████| 69/69 [00:03<00:00, 19.77it/s]


Epoch: 009, Train Loss: 2.3000, Train Acc: 0.3347


100%|██████████| 69/69 [00:03<00:00, 18.11it/s]


Epoch: 010, Train Loss: 2.3105, Train Acc: 0.3365


100%|██████████| 69/69 [00:04<00:00, 14.10it/s]


Epoch: 011, Train Loss: 2.1176, Train Acc: 0.3445


100%|██████████| 69/69 [00:03<00:00, 19.64it/s]


Epoch: 012, Train Loss: 2.4211, Train Acc: 0.3165


100%|██████████| 69/69 [00:03<00:00, 17.97it/s]


Epoch: 013, Train Loss: 1.9480, Train Acc: 0.3844


100%|██████████| 69/69 [00:03<00:00, 19.07it/s]


Epoch: 014, Train Loss: 2.0333, Train Acc: 0.3558


100%|██████████| 69/69 [00:03<00:00, 19.48it/s]


Epoch: 015, Train Loss: 2.2184, Train Acc: 0.4028


100%|██████████| 69/69 [00:03<00:00, 19.75it/s]


Epoch: 016, Train Loss: 1.8665, Train Acc: 0.3658


100%|██████████| 69/69 [00:03<00:00, 20.72it/s]


Epoch: 017, Train Loss: 2.1811, Train Acc: 0.3664


100%|██████████| 69/69 [00:03<00:00, 19.51it/s]


Epoch: 018, Train Loss: 2.0308, Train Acc: 0.4184


100%|██████████| 69/69 [00:03<00:00, 20.46it/s]


Epoch: 019, Train Loss: 1.9985, Train Acc: 0.4227


100%|██████████| 69/69 [00:03<00:00, 20.31it/s]


Epoch: 020, Train Loss: 2.5752, Train Acc: 0.3959


100%|██████████| 69/69 [00:03<00:00, 19.26it/s]


Epoch: 021, Train Loss: 1.6070, Train Acc: 0.4065


100%|██████████| 69/69 [00:03<00:00, 18.39it/s]


Epoch: 022, Train Loss: 1.6461, Train Acc: 0.4198


100%|██████████| 69/69 [00:03<00:00, 19.24it/s]


Epoch: 023, Train Loss: 1.6541, Train Acc: 0.4216


100%|██████████| 69/69 [00:03<00:00, 20.29it/s]


Epoch: 024, Train Loss: 2.1744, Train Acc: 0.4315


100%|██████████| 69/69 [00:03<00:00, 19.81it/s]


Epoch: 025, Train Loss: 1.9092, Train Acc: 0.4016


100%|██████████| 69/69 [00:03<00:00, 20.40it/s]


Epoch: 026, Train Loss: 1.6840, Train Acc: 0.4194


100%|██████████| 69/69 [00:03<00:00, 19.04it/s]


Epoch: 027, Train Loss: 1.9189, Train Acc: 0.4343


100%|██████████| 69/69 [00:03<00:00, 20.28it/s]


Epoch: 028, Train Loss: 1.6168, Train Acc: 0.4390


100%|██████████| 69/69 [00:03<00:00, 20.01it/s]


Epoch: 029, Train Loss: 1.7934, Train Acc: 0.4411


100%|██████████| 69/69 [00:03<00:00, 19.86it/s]


Epoch: 030, Train Loss: 2.4397, Train Acc: 0.4309


100%|██████████| 69/69 [00:03<00:00, 18.64it/s]


Epoch: 031, Train Loss: 1.3060, Train Acc: 0.4421


100%|██████████| 69/69 [00:03<00:00, 19.86it/s]


Epoch: 032, Train Loss: 1.7864, Train Acc: 0.4489


100%|██████████| 69/69 [00:03<00:00, 20.16it/s]


Epoch: 033, Train Loss: 1.7867, Train Acc: 0.4251


100%|██████████| 69/69 [00:03<00:00, 19.26it/s]


Epoch: 034, Train Loss: 1.8061, Train Acc: 0.3717


100%|██████████| 69/69 [00:03<00:00, 19.35it/s]


Epoch: 035, Train Loss: 1.7425, Train Acc: 0.4325


100%|██████████| 69/69 [00:03<00:00, 19.57it/s]


Epoch: 036, Train Loss: 1.8626, Train Acc: 0.4243


100%|██████████| 69/69 [00:03<00:00, 19.33it/s]


Epoch: 037, Train Loss: 1.6109, Train Acc: 0.4511


100%|██████████| 69/69 [00:03<00:00, 19.79it/s]


Epoch: 038, Train Loss: 2.2455, Train Acc: 0.4036


100%|██████████| 69/69 [00:03<00:00, 19.51it/s]


Epoch: 039, Train Loss: 1.9510, Train Acc: 0.4564


100%|██████████| 69/69 [00:03<00:00, 19.31it/s]


Epoch: 040, Train Loss: 2.0154, Train Acc: 0.4405


100%|██████████| 69/69 [00:03<00:00, 19.74it/s]


Epoch: 041, Train Loss: 1.6766, Train Acc: 0.4716


100%|██████████| 69/69 [00:03<00:00, 19.52it/s]


Epoch: 042, Train Loss: 1.4322, Train Acc: 0.4315


100%|██████████| 69/69 [00:03<00:00, 19.69it/s]


Epoch: 043, Train Loss: 1.6932, Train Acc: 0.4098


100%|██████████| 69/69 [00:03<00:00, 20.19it/s]


Epoch: 044, Train Loss: 2.0459, Train Acc: 0.4583


100%|██████████| 69/69 [00:03<00:00, 19.90it/s]


Epoch: 045, Train Loss: 3.3442, Train Acc: 0.3971


100%|██████████| 69/69 [00:03<00:00, 19.44it/s]


Epoch: 046, Train Loss: 1.7320, Train Acc: 0.4679


100%|██████████| 69/69 [00:03<00:00, 19.22it/s]


Epoch: 047, Train Loss: 1.6377, Train Acc: 0.4636


100%|██████████| 69/69 [00:03<00:00, 19.09it/s]


Epoch: 048, Train Loss: 2.2952, Train Acc: 0.4736


100%|██████████| 69/69 [00:03<00:00, 19.29it/s]


Epoch: 049, Train Loss: 1.6665, Train Acc: 0.4560


100%|██████████| 69/69 [00:03<00:00, 19.35it/s]


Epoch: 050, Train Loss: 1.7885, Train Acc: 0.4675


100%|██████████| 69/69 [00:03<00:00, 19.03it/s]


Epoch: 051, Train Loss: 1.9572, Train Acc: 0.4540


100%|██████████| 69/69 [00:03<00:00, 19.19it/s]


Epoch: 052, Train Loss: 1.4841, Train Acc: 0.4791


100%|██████████| 69/69 [00:03<00:00, 19.36it/s]


Epoch: 053, Train Loss: 1.1682, Train Acc: 0.4669


100%|██████████| 69/69 [00:03<00:00, 19.17it/s]


Epoch: 054, Train Loss: 1.9572, Train Acc: 0.4499


100%|██████████| 69/69 [00:03<00:00, 19.14it/s]


Epoch: 055, Train Loss: 1.6222, Train Acc: 0.4757


100%|██████████| 69/69 [00:03<00:00, 18.83it/s]


Epoch: 056, Train Loss: 1.3394, Train Acc: 0.4478


100%|██████████| 69/69 [00:03<00:00, 18.82it/s]


Epoch: 057, Train Loss: 1.1888, Train Acc: 0.4808


100%|██████████| 69/69 [00:03<00:00, 18.54it/s]


Epoch: 058, Train Loss: 1.6050, Train Acc: 0.4583


100%|██████████| 69/69 [00:03<00:00, 19.52it/s]


Epoch: 059, Train Loss: 1.8312, Train Acc: 0.4787


100%|██████████| 69/69 [00:04<00:00, 16.18it/s]


Epoch: 060, Train Loss: 1.8948, Train Acc: 0.4511


100%|██████████| 69/69 [00:03<00:00, 19.17it/s]


Epoch: 061, Train Loss: 1.2188, Train Acc: 0.4572


100%|██████████| 69/69 [00:03<00:00, 18.64it/s]


Epoch: 062, Train Loss: 1.8156, Train Acc: 0.4742


100%|██████████| 69/69 [00:03<00:00, 19.34it/s]


Epoch: 063, Train Loss: 2.0451, Train Acc: 0.4587


100%|██████████| 69/69 [00:03<00:00, 18.83it/s]


Epoch: 064, Train Loss: 1.8069, Train Acc: 0.4587


100%|██████████| 69/69 [00:03<00:00, 19.20it/s]


Epoch: 065, Train Loss: 1.6148, Train Acc: 0.4617


100%|██████████| 69/69 [00:03<00:00, 18.74it/s]


Epoch: 066, Train Loss: 1.8489, Train Acc: 0.4812


100%|██████████| 69/69 [00:03<00:00, 19.13it/s]


Epoch: 067, Train Loss: 2.1653, Train Acc: 0.4842


100%|██████████| 69/69 [00:03<00:00, 18.48it/s]


Epoch: 068, Train Loss: 2.0774, Train Acc: 0.5010


100%|██████████| 69/69 [00:03<00:00, 18.91it/s]


Epoch: 069, Train Loss: 1.3640, Train Acc: 0.4975


100%|██████████| 69/69 [00:03<00:00, 19.03it/s]


Epoch: 070, Train Loss: 1.7543, Train Acc: 0.4527


100%|██████████| 69/69 [00:03<00:00, 19.09it/s]


Epoch: 071, Train Loss: 48.1239, Train Acc: 0.4781


100%|██████████| 69/69 [00:03<00:00, 18.68it/s]


Epoch: 072, Train Loss: 1.5921, Train Acc: 0.4793


100%|██████████| 69/69 [00:03<00:00, 18.43it/s]


Epoch: 073, Train Loss: 1.6284, Train Acc: 0.4472


100%|██████████| 69/69 [00:03<00:00, 18.98it/s]


Epoch: 074, Train Loss: 1.9929, Train Acc: 0.4288


100%|██████████| 69/69 [00:03<00:00, 18.98it/s]


Epoch: 075, Train Loss: 1.4489, Train Acc: 0.4965


100%|██████████| 69/69 [00:03<00:00, 17.79it/s]


Epoch: 076, Train Loss: 1.7725, Train Acc: 0.5055


100%|██████████| 69/69 [00:03<00:00, 17.26it/s]


Epoch: 077, Train Loss: 1.8825, Train Acc: 0.4982


100%|██████████| 69/69 [00:03<00:00, 19.62it/s]


Epoch: 078, Train Loss: 1.7906, Train Acc: 0.4734


100%|██████████| 69/69 [00:03<00:00, 17.82it/s]


Epoch: 079, Train Loss: 1.7278, Train Acc: 0.4982


100%|██████████| 69/69 [00:03<00:00, 18.76it/s]


Epoch: 080, Train Loss: 1.2210, Train Acc: 0.4525


100%|██████████| 69/69 [00:03<00:00, 19.39it/s]


Epoch: 081, Train Loss: 1.8563, Train Acc: 0.4906


100%|██████████| 69/69 [00:03<00:00, 18.85it/s]


Epoch: 082, Train Loss: 1.4899, Train Acc: 0.4816


100%|██████████| 69/69 [00:03<00:00, 18.56it/s]


Epoch: 083, Train Loss: 1.5559, Train Acc: 0.4098


100%|██████████| 69/69 [00:04<00:00, 16.70it/s]


Epoch: 084, Train Loss: 1.9674, Train Acc: 0.5029


100%|██████████| 69/69 [00:05<00:00, 12.69it/s]


Epoch: 085, Train Loss: 1.5145, Train Acc: 0.4894


100%|██████████| 69/69 [00:04<00:00, 16.63it/s]


Epoch: 086, Train Loss: 1.3625, Train Acc: 0.5131


100%|██████████| 69/69 [00:03<00:00, 17.51it/s]


Epoch: 087, Train Loss: 1.4233, Train Acc: 0.5127


100%|██████████| 69/69 [00:03<00:00, 18.08it/s]


Epoch: 088, Train Loss: 1.9251, Train Acc: 0.4096


100%|██████████| 69/69 [00:03<00:00, 18.27it/s]


Epoch: 089, Train Loss: 1.7029, Train Acc: 0.4740


100%|██████████| 69/69 [00:03<00:00, 19.25it/s]


Epoch: 090, Train Loss: 1.6312, Train Acc: 0.5149


100%|██████████| 69/69 [00:03<00:00, 18.54it/s]


Epoch: 091, Train Loss: 1.9678, Train Acc: 0.4609


100%|██████████| 69/69 [00:03<00:00, 18.68it/s]


Epoch: 092, Train Loss: 1.9403, Train Acc: 0.4961


100%|██████████| 69/69 [00:03<00:00, 18.45it/s]


Epoch: 093, Train Loss: 1.6749, Train Acc: 0.4636


100%|██████████| 69/69 [00:03<00:00, 18.01it/s]


Epoch: 094, Train Loss: 1.5936, Train Acc: 0.5168


100%|██████████| 69/69 [00:03<00:00, 18.90it/s]


Epoch: 095, Train Loss: 1.4029, Train Acc: 0.5082


100%|██████████| 69/69 [00:03<00:00, 18.38it/s]


Epoch: 096, Train Loss: 1.4459, Train Acc: 0.4892


100%|██████████| 69/69 [00:03<00:00, 18.88it/s]


Epoch: 097, Train Loss: 1.5849, Train Acc: 0.5072


100%|██████████| 69/69 [00:03<00:00, 18.87it/s]


Epoch: 098, Train Loss: 2.2545, Train Acc: 0.4988


100%|██████████| 69/69 [00:03<00:00, 19.08it/s]


Epoch: 099, Train Loss: 2.1586, Train Acc: 0.5125


In [175]:
PATH = './model/model_0_full.pt'
torch.save({
            'epoch': epochs,
            'model_state_dict': model_0_full.state_dict(),
            'optimizer_state_dict': optimizer_0_full.state_dict(),
            }, PATH)

In [177]:
PATH = './model/model_0.pt'
PATH_full = './model/model_0_full.pt'
load_checkpoint(model_0, optimizer_0, PATH)
load_checkpoint(model_0_full, optimizer_0_full, PATH_full)

checkpoint loaded
checkpoint loaded


In [185]:
import csv
def write_csv(fname, pred_proba):
    with open(f'./prediction/{fname}.csv', 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter=',')
        lst = list()
        for i in range(18):
            lst.append('class'+str(i))
        lst.insert(0, "name")
        writer.writerow(lst)
        for i, protein in enumerate(test_name):
            lst = pred_proba[i,:].tolist()
            lst.insert(0, protein)
            writer.writerow(lst)

In [186]:
pred_proba_0 = predict(model_0)
write_csv('gnn0',pred_proba_0)


In [187]:
write_csv('gnn0_full', predict(model_0_full))