In [2]:
from utils import evaluate_graphs_accuracy
from models import GcnEncoderGraph
import torch
from torch_geometric.loader import DataLoader
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from immune import Immune

# training setting
batch_size = 128
lr = 1e-4
epochs = 100
num_workers = 8

train_set = Immune(mode='training')
test_set = Immune(mode="testing")
val_set = Immune(mode='evaluation')

test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

model = GcnEncoderGraph(input_dim=train_set.num_features,
                        hidden_dim=128,
                        embedding_dim=32,
                        num_layers=2,
                        pred_hidden_dims=[64, 32],
                        label_dim=2)

optimizer = torch.optim.Adam(model.parameters(),
                                lr=lr,
                                weight_decay=1e-4)

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')


criterion = CrossEntropyLoss(weight=torch.tensor([0.4, 0.6]).to(device))

model.to(device)
best_accuracy = 0

for epoch in range(1, epochs+1):
    model.train()
    loss_all = 0
    optimizer.zero_grad()
    for data in train_loader:
        x = data.x.to(device)
        edge_index = data.edge_index.to(device)
        # edge_index = torch.tensor([[], []], dtype=torch.long).to(device)
        y = data.y.to(device)
        batch = data.batch.to(device)
        y_pred = model(x, edge_index, batch)
        loss = criterion(y_pred, y)

        loss.backward()
        loss_all += loss.item() * data.num_graphs
    optimizer.step()

    if epoch % 1 == 0:
        accuracy_test = evaluate_graphs_accuracy(test_loader, model, device)
        if accuracy_test > best_accuracy:
            torch.save(model.state_dict(), './params/immune_net.ckpt')
            best_accuracy = accuracy_test
            accuracy_val = evaluate_graphs_accuracy(val_loader, model, device)
            print(f'Epoch: {epoch:03d}, Loss: {loss_all:.4f}, Curr_Best: {best_accuracy:.4f}, Val: {accuracy_val:.4f}')

Epoch: 001, Loss: 127.7875, Curr_Best: 0.3043, Val: 0.3913


In [27]:
model.load_state_dict(torch.load('./params/immune.ckpt'))

<All keys matched successfully>

In [13]:
count = 0
for data in train_set:
    if data.y == 0:
        count += 1
count/len(train_set)

0.6304347826086957

In [14]:
count = 0
for data in test_set:
    if data.y == 0:
        count += 1
count/len(test_set)

0.6956521739130435

In [1]:
count = 0
for data in val_set:
    if data.y == 0:
        count += 1
count/len(val_set)

NameError: name 'val_set' is not defined