In [7]:
@torch.no_grad()
def evaluate_graphs_accuracy(
        test_loader,
        model,
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        require_cm=False):

    model.to(device)
    model.eval()
    correct = 0
    all_labels = []
    if require_cm:
        all_predictions = []

    for data in test_loader:
        output = model(
            data.x.to(device),
            torch.tensor([[], []], dtype=torch.long).to(device),
            data.batch.to(device),
        )
        predictions = output.argmax(dim=1).cpu().numpy().reshape(-1)
        labels = data.y.cpu().numpy().reshape(-1)
        all_labels.extend(labels)
        correct += float((predictions == labels).sum())
        if require_cm:
            all_predictions.extend(predictions)
    accuracy = correct / len(test_loader.dataset)
    if require_cm:
        cm = confusion_matrix(all_labels, all_predictions)
        return accuracy, cm
    else:
        return accuracy

In [1]:
from models import GcnEncoderGraph
import torch
from torch_geometric.loader import DataLoader
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from immune import Immune

# training setting
batch_size = 128
lr = 1e-3
epochs = 200
num_workers = 16

train_set = Immune(mode='training')
test_set = Immune(mode="testing")
val_set = Immune(mode='evaluation')

test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

  from .autonotebook import tqdm as notebook_tqdm
Processing...
  edge_index = torch.tensor(edge_index, dtype=torch.long)
Done!


In [24]:
from utils import evaluate_graphs_accuracy
from models import GcnEncoderGraph
import torch
from torch_geometric.loader import DataLoader
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from immune import Immune

# training setting
batch_size = 128
lr = 1e-2
epochs = 200
num_workers = 16

train_set = Immune(mode='training')
test_set = Immune(mode="testing")
val_set = Immune(mode='evaluation')

test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

model = GcnEncoderGraph(input_dim=train_set.num_features,
                        hidden_dim=8,
                        embedding_dim=16,
                        num_layers=3,
                        pred_hidden_dims=[],
                        label_dim=2)

optimizer = torch.optim.Adam(model.parameters(),
                                lr=lr,
                                weight_decay=1e-4)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


criterion = CrossEntropyLoss(weight=torch.tensor([0.4, 0.6]).to(device))

model.to(device)
best_accuracy = 0

for epoch in range(1, epochs+1):
    model.train()
    loss_all = 0
    optimizer.zero_grad()
    for data in train_loader:
        x = data.x.to(device)
        edge_index = data.edge_index.to(device)
        # edge_index = torch.tensor([[], []], dtype=torch.long).to(device)
        y = data.y.to(device)
        batch = data.batch.to(device)
        y_pred = model(x, edge_index, batch)
        loss = criterion(y_pred, y)

        loss.backward()
        loss_all += loss.item() * data.num_graphs
    optimizer.step()

    if epoch % 1 == 0:
        accuracy_test = evaluate_graphs_accuracy(test_loader, model, device)
        if accuracy_test > best_accuracy:
            torch.save(model.state_dict(), './params/immune.ckpt')
            best_accuracy = accuracy_test
            accuracy_val = evaluate_graphs_accuracy(val_loader, model, device)
            print(f'Epoch: {epoch:03d}, Loss: {loss_all:.4f}, Curr_Best: {best_accuracy:.4f}, Val: {accuracy_val:.4f}')

Epoch: 001, Loss: 135.1032, Curr_Best: 0.5217, Val: 0.6087
Epoch: 006, Loss: 127.5765, Curr_Best: 0.6957, Val: 0.6087
Epoch: 019, Loss: 127.5195, Curr_Best: 0.7391, Val: 0.5217


In [8]:
y.shape

torch.Size([128])

In [7]:
from torch_geometric.utils import to_dense_adj, dense_to_sparse

@torch.no_grad()
def random_test(test_loader, model, device=('cuda' if torch.cuda.is_available() else 'cpu')):
    model.eval()
    correct = 0

    for data in test_loader:
        adj = to_dense_adj(data.edge_index)
        random_adj = torch.randn_like(adj)
        random_adj = torch.where(random_adj > 0, random_adj, torch.zeros_like(random_adj))
        random_edge, weight = dense_to_sparse(random_adj)
        topk = data.edge_index.shape[1]
        threshold = weight.sort(descending=True).values.topk(topk).values[-1]
        random_edge = random_edge.T[weight>threshold].T
        output = model(
            data.x.to(device),
            random_edge.to(device),
            data.batch.to(device),
        )
        correct += float(output.argmax(dim=1).eq(data.y.to(device)).sum().item())
    return correct / (len(test_loader.dataset))

In [33]:
torch.cuda.empty_cache()

In [19]:
test_loader = DataLoader(test_set, batch_size=16, shuffle=False, num_workers=num_workers)
val_loader = DataLoader(val_set, batch_size=16, shuffle=False, num_workers=num_workers)
train_loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=num_workers)

In [27]:
model.load_state_dict(torch.load('./params/immune.ckpt'))

<All keys matched successfully>

In [None]:
random_test(train_loader, model, device)

In [11]:
random_test(test_loader, model, device)

0.6956521739130435

In [31]:
random_test(val_loader, model, device)

0.6086956521739131

In [30]:
evaluate_graphs_accuracy(train_loader, model, device)

0.6032608695652174

In [28]:
evaluate_graphs_accuracy(test_loader, model, device)

0.7391304347826086

In [29]:
evaluate_graphs_accuracy(val_loader, model, device)

0.5217391304347826

In [13]:
count = 0
for data in train_set:
    if data.y == 0:
        count += 1
count/len(train_set)

0.6304347826086957

In [14]:
count = 0
for data in test_set:
    if data.y == 0:
        count += 1
count/len(test_set)

0.6956521739130435

In [15]:
count = 0
for data in val_set:
    if data.y == 0:
        count += 1
count/len(val_set)

0.6086956521739131