In [9]:
import pandas as pd
import numpy as np
import torch
from models.gnn import GNN
from models.mlp import MLP
from utils.dataloader import GetDataloader
from tqdm import trange
from tqdm import tqdm
import torch.nn.functional as F
import os
import os.path as osp


In [10]:
class Trainer:
    def __init__(self, dataset_name="cora", sentence_encoder="ST", model_type="mlp", device=0, state_dict_path="./state_dicts"):
        self.dataset_name = dataset_name
        self.sentence_encoder = sentence_encoder
        self.model_type = model_type.lower()
        self.device = "cpu" if device==123 else f"cuda:{device}"

        self.state_dict_path = osp.join(state_dict_path, f"{self.dataset_name}_{self.sentence_encoder}", f"{model_type}")
        if not osp.exists(self.state_dict_path):
            os.makedirs(self.state_dict_path)

        dataloader = GetDataloader(dataset_name=self.dataset_name, sentence_encoder=self.sentence_encoder, device=self.device)
        self.data = dataloader.get_data()
        self.num_classes = len(self.data.y.squeeze().unique())

        if self.model_type == "mlp":
            self.model = MLP(num_classes=self.num_classes)
        elif self.model_type in ["gcn", "gat", "sage", "graphsage"]:
            self.model = GNN(name=self.model_type, num_classes=self.num_classes)
        else:
            raise NotImplementedError

        self.data = self.data.to(device=self.device)
        self.model = self.model.to(device=self.device)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.001, weight_decay=1e-4)

    def train(self, mask_idx):
        best_val_acc = 0

        # total of 10 training masks are present for each dataset
        for e in range(1, 201):
            self.model.train()
            self.optimizer.zero_grad()

            out = self.model(self.data)
            train_pred = out.argmax(dim=1)

            train_ypred = train_pred[self.data.train_masks[mask_idx]]
            train_ytrue = self.data.y[self.data.train_masks[mask_idx]]
            train_correct = (train_ypred == train_ytrue).sum()

            train_acc = int(train_correct) / int(self.data.train_masks[mask_idx].sum())
            train_loss = F.cross_entropy(out[self.data.train_masks[mask_idx]], train_ytrue)

            if e % 10 == 0:
                val_ypred = train_pred[self.data.val_masks[mask_idx]]
                val_ytrue = self.data.y[self.data.val_masks[mask_idx]]
                val_correct = (val_ypred == val_ytrue).sum()

                val_acc = int(val_correct) / int(self.data.val_masks[mask_idx].sum())
                val_loss = F.cross_entropy(out[self.data.val_masks[mask_idx]], val_ytrue)

                # print(f"Epoch {e} => Train Accuracy : {train_acc} | Train Loss : {train_loss}")
                # print(f"Validation Accuracy : {val_acc} | Validation Loss : {val_loss}")

                if val_acc > best_val_acc:
                    best_val_acc = val_acc

                    save_path = osp.join(self.state_dict_path, f"Mask_{mask_idx}_best_state_dict.pt")
                    if osp.exists(save_path):
                        os.remove(save_path)

                    model_info = {"state_dict" : self.model.state_dict(),
                                    "optimizer_state_dict" : self.optimizer.state_dict(),
                                    "val_accuracy" : best_val_acc,
                                    "val_loss" : val_loss}

                    torch.save(model_info, save_path)

            train_loss.backward()
            self.optimizer.step()


        self.model.eval()
        with torch.inference_mode():
            out = self.model(self.data)
            pred = out.argmax(dim=1)

            ypred = pred[self.data.test_masks[mask_idx]]
            ytrue = self.data.y[self.data.test_masks[mask_idx]]
            test_correct = (ypred == ytrue).sum()

            test_acc = int(test_correct) / int(self.data.test_masks[mask_idx].sum())
            test_loss = float(F.cross_entropy(out[self.data.test_masks[mask_idx]], ytrue))

        return test_acc, test_loss

In [11]:
datasets = ["cora_ST", "cora_roberta", "pubmed_ST", "pubmed_roberta"]
model_types = ["mlp", "gcn", "gat", "sage"]

final_results = dict.fromkeys(datasets)

for dataset in tqdm(datasets):
    dataset_name = dataset.split("_")[0]
    sent_encoder = dataset.split("_")[1]

    results = dict.fromkeys(model_types)

    for model_t in model_types:
        test_acc = np.zeros(10)
        test_losses = np.zeros(10)

        for idx in range(10):
            trainer = Trainer(dataset_name=dataset_name, sentence_encoder=sent_encoder, model_type=model_t, device=0)
            acc, loss = trainer.train(idx)

            test_acc[idx] = acc
            test_losses[idx] = loss

        avg_test_acc = test_acc.mean()
        std_test_acc = test_acc.std()

        avg_test_loss = test_losses.mean()
        std_test_loss = test_losses.std()

        results[model_t] = {"Test Accuracy (avg)" : avg_test_acc,
                            "Test Accuracy (std)" : std_test_acc,
                            "Test Loss (avg)" : avg_test_loss,
                            "Test Loss (std)" : std_test_loss}

    final_results[dataset] = results

100%|██████████| 4/4 [13:25<00:00, 201.38s/it]


In [12]:
final_results_ = {(i, j) : final_results[i][j] for i in final_results.keys() for j in final_results[i].keys()}
df = pd.DataFrame.from_dict(final_results_, orient="index")
df

Unnamed: 0,Unnamed: 1,Test Accuracy (avg),Test Accuracy (std),Test Loss (avg),Test Loss (std)
cora_ST,mlp,0.640377,0.024404,1.378926,0.123853
cora_ST,gcn,0.806093,0.009473,0.618293,0.027128
cora_ST,gat,0.796422,0.007839,0.639135,0.039659
cora_ST,sage,0.790232,0.011012,1.322649,0.01197
cora_roberta,mlp,0.484623,0.026222,2.717465,0.260291
cora_roberta,gcn,0.743182,0.016231,0.899008,0.055134
cora_roberta,gat,0.717892,0.014368,1.060394,0.103915
cora_roberta,sage,0.729014,0.020994,1.390071,0.017474
pubmed_ST,mlp,0.71791,0.017322,1.028179,0.105397
pubmed_ST,gcn,0.783071,0.014785,0.556836,0.032271
