In [1]:
import torch
import cl_graph_bert as cgm
from torch import nn
import torch.nn.functional as F

import dgl
g = dgl.load_graphs("./graphs/industrial_and_scientific_5_core.dgl")[0][0]

model = cgm.CLIPGraphModel(
    rel_types = g.etypes,
    emb_types = {x: g.number_of_nodes(x) for x in g.ntypes} 
)

device = "cpu"

model.load_state_dict(torch.load("./base_statedict_6668.011407389138.pt", map_location=torch.device(device)))
epochs = 20

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Full Graph - Frozen GCN

In [2]:
class Classifier(nn.Module):
    def __init__(self, gnn_model, out_dim, freeze_base=True):
        super().__init__()
        self.mdl = gnn_model
        self.freeze = freeze_base
        self.linear = nn.Linear(model.graph_projection.projection_dim, out_dim)
        self.act = nn.ReLU()
        self.soft = nn.Softmax()

    def forward(self, g):
        if self.freeze:
            with torch.no_grad():
                x = self.mdl.graph_projection(self.mdl.graph_model(g)["Review"].double()).float()
        else:
            x = self.mdl.graph_projection(self.mdl.graph_model(g)["Review"].double()).float()
        x = self.act(x)
        x = self.linear(x)
        out = self.soft(x)
        return out

In [3]:
gnn_cls = Classifier(model, 2, freeze_base=True)

labels = g.nodes["Review"].data["Positive"]
train_mask = g.nodes["Review"].data["train_mask"]
test_mask = g.nodes["Review"].data["test_mask"]

opt = torch.optim.Adam(gnn_cls.parameters())

for epoch in range(epochs):
    model.train()
    logits = gnn_cls(g)
    loss = F.cross_entropy(logits[train_mask == 1], labels[train_mask == 1].type(torch.long))
    opt.zero_grad()
    loss.backward()
    opt.step()
    print("Epoch:", epoch, "Loss:", loss.item())

  out = self.soft(x)


Epoch: 0 Loss: 1.1762685775756836
Epoch: 1 Loss: 1.1353669166564941
Epoch: 2 Loss: 1.1285780668258667
Epoch: 3 Loss: 0.420133501291275
Epoch: 4 Loss: 0.42006874084472656
Epoch: 5 Loss: 0.42010149359703064
Epoch: 6 Loss: 0.4199940860271454
Epoch: 7 Loss: 0.42008885741233826
Epoch: 8 Loss: 0.42007431387901306
Epoch: 9 Loss: 0.42003756761550903
Epoch: 10 Loss: 0.4199889302253723
Epoch: 11 Loss: 0.4200495481491089
Epoch: 12 Loss: 0.4200839102268219
Epoch: 13 Loss: 0.42009052634239197
Epoch: 14 Loss: 0.42005789279937744
Epoch: 15 Loss: 0.4200589954853058
Epoch: 16 Loss: 0.4200672209262848
Epoch: 17 Loss: 0.4199753999710083
Epoch: 18 Loss: 0.42000052332878113
Epoch: 19 Loss: 0.42002803087234497


In [4]:
gnn_cls.eval()
preds = gnn_cls(g)[test_mask == 1].detach().numpy()
y_test = labels[test_mask == 1]

  out = self.soft(x)


In [5]:
from sklearn.metrics import f1_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import accuracy_score

def metrics(y_test, preds):
    mets = {}
    mets["f1"] = f1_score(y_test, preds.argmax(1))
    mets["recall"] = recall_score(y_test, preds.argmax(1))
    mets["precision"] = precision_score(y_test, preds.argmax(1))
    mets["acc"] = accuracy_score(y_test, preds.argmax(1))
    return mets

metrics(y_test, preds)

{'f1': 0.9343215065955035,
 'recall': 1.0,
 'precision': 0.8767386340045672,
 'acc': 0.8767386340045672}

## Full Graph - Fine-tunable GCN

In [6]:
gnn_cls_unfrozen = Classifier(model, 2, freeze_base=False)

labels = g.nodes["Review"].data["Positive"]
train_mask = g.nodes["Review"].data["train_mask"]
test_mask = g.nodes["Review"].data["test_mask"]

opt =  torch.optim.Adam(
        [
            {"params":gnn_cls_unfrozen.mdl.parameters(), "lr": 0.00001}, 
            {"params":gnn_cls_unfrozen.linear.parameters(), "lr": 0.001}      
    ])

for epoch in range(epochs):
    model.train()
    logits = gnn_cls_unfrozen(g)
    loss = F.cross_entropy(logits[train_mask == 1], labels[train_mask == 1].type(torch.long))
    opt.zero_grad()
    loss.backward()
    opt.step()
    print("Epoch:", epoch, "Loss:", loss.item())

gnn_cls_unfrozen.eval()
preds = gnn_cls_unfrozen(g)[test_mask == 1].detach().numpy()
y_test = labels[test_mask == 1]

metrics(y_test, preds)

  out = self.soft(x)


Epoch: 0 Loss: 0.4409789443016052
Epoch: 1 Loss: 0.4202536344528198
Epoch: 2 Loss: 0.42011967301368713
Epoch: 3 Loss: 0.420083612203598
Epoch: 4 Loss: 0.4200218617916107
Epoch: 5 Loss: 0.42001932859420776
Epoch: 6 Loss: 0.4200417399406433
Epoch: 7 Loss: 0.4200417399406433
Epoch: 8 Loss: 0.42007875442504883
Epoch: 9 Loss: 0.4200665354728699
Epoch: 10 Loss: 0.4199981093406677
Epoch: 11 Loss: 0.4199976623058319
Epoch: 12 Loss: 0.4200596511363983
Epoch: 13 Loss: 0.4200640022754669
Epoch: 14 Loss: 0.42006736993789673
Epoch: 15 Loss: 0.42004090547561646
Epoch: 16 Loss: 0.42000678181648254
Epoch: 17 Loss: 0.42009252309799194
Epoch: 18 Loss: 0.41996869444847107
Epoch: 19 Loss: 0.4200308322906494


{'f1': 0.9343215065955035,
 'recall': 1.0,
 'precision': 0.8767386340045672,
 'acc': 0.8767386340045672}

In [17]:
review_out = model.graph_projection(model.graph_model(g)["Review"].double()).float()

In [16]:
prod_id = 1
review_id = g.successors(prod_id, 'rev_REVIEW_OF')
review_id

tensor([ 3,  4,  5,  7,  8, 11, 12, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 28,
        29, 30, 32, 33])

In [19]:
review_out[review_id]

tensor([[ 2.7271e+02,  1.3971e+01,  4.9848e-02,  ...,  2.9182e+02,
          3.4543e+03,  2.4162e+02],
        [ 2.5599e+02,  1.3208e+01, -2.7077e+00,  ...,  2.7823e+02,
          3.3170e+03,  2.2543e+02],
        [ 2.7113e+02,  1.8487e+01,  1.0898e+01,  ...,  2.8357e+02,
          3.2502e+03,  2.4358e+02],
        ...,
        [ 3.2835e+02,  3.2599e+01,  3.3280e+01,  ...,  3.3396e+02,
          3.6118e+03,  3.0059e+02],
        [ 2.7287e+02, -2.1590e-01, -3.2043e+01,  ...,  3.1011e+02,
          3.9872e+03,  2.3210e+02],
        [ 2.8199e+02,  2.8142e+01,  2.6989e+01,  ...,  2.8999e+02,
          3.1588e+03,  2.5726e+02]], grad_fn=<IndexBackward0>)

In [None]:


model.language_model(input_ids = tokens['input_ids'], 
           attention_mask=tokens['attention_mask'], 
           token_type_ids=tokens['token_type_ids']).last_hidden_state[:,0].type(torch.float64)