In [1]:
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.datasets import IMDB
from torch_geometric.nn import HGTConv
from torch_geometric.transforms import ToUndirected

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 1) Load dataset
dataset = IMDB(root='data/IMDB', transform=ToUndirected())
data = dataset[0]  # HeteroData
print(data)

HeteroData(
  movie={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278],
  },
  director={ x=[2081, 3066] },
  actor={ x=[5257, 3066] },
  (movie, to, director)={ edge_index=[2, 4278] },
  (movie, to, actor)={ edge_index=[2, 12828] },
  (director, to, movie)={ edge_index=[2, 4278] },
  (actor, to, movie)={ edge_index=[2, 12828] },
  (director, rev_to, movie)={ edge_index=[2, 4278] },
  (actor, rev_to, movie)={ edge_index=[2, 12828] },
  (movie, rev_to, director)={ edge_index=[2, 4278] },
  (movie, rev_to, actor)={ edge_index=[2, 12828] }
)


In [3]:
print(type(data))

<class 'torch_geometric.data.hetero_data.HeteroData'>


In [4]:
for t in ['movie', 'actor']:
    x = data[t].x
    data[t].x = x / x.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)

In [5]:
for etype in list(data.edge_types):
    if etype[1] == 'rev_to':
        del data[etype]

In [6]:
metadata = data.metadata()
metadata

(['movie', 'director', 'actor'],
 [('movie', 'to', 'director'),
  ('movie', 'to', 'actor'),
  ('director', 'to', 'movie'),
  ('actor', 'to', 'movie')])

In [7]:
# data node types: 'movie', 'actor', 'director'
# labels exist on movie nodes: data['movie'].y (3 classes)
# masks: data['movie'].train_mask / val_mask / test_mask

In [8]:
hidden = 128

x_dict = {}
proj = nn.ModuleDict()
embeddings = nn.ModuleDict()

In [9]:
data.edge_index_dict['movie', 'to', 'actor'].shape


torch.Size([2, 12828])

In [10]:
data.edge_index_dict['movie', 'to', 'actor']

tensor([[   0,    0,    0,  ..., 4277, 4277, 4277],
        [ 674, 2394, 5129,  ...,  100, 1078, 1439]])

In [11]:
for k in ['train_mask','val_mask','test_mask']:
    m = data['movie'][k]
    print(k, m.shape, m.dtype, m.sum().item())  # count of True = split size

train_mask torch.Size([4278]) torch.bool 400
val_mask torch.Size([4278]) torch.bool 400
test_mask torch.Size([4278]) torch.bool 3478


In [12]:
"""
    Forces actor/director to use learnable embeddings.
    Keeps projection for node types with real features (e.g., movie).
"""
# Hybrid feature construction
hybrid_types = {'actor', 'director'}
embeddings = nn.ModuleDict()
proj = nn.ModuleDict()

for ntype in data.node_types:
    x = data[ntype].get('x', None)
    if ntype in hybrid_types:
        if x is not None:
            emb_dim = hidden // 2
            proj_dim = hidden - emb_dim
            proj[ntype] = nn.Linear(x.size(-1), proj_dim, bias=False)
            embeddings[ntype] = nn.Embedding(data[ntype].num_nodes, emb_dim)
        else:
            embeddings[ntype] = nn.Embedding(data[ntype].num_nodes, hidden)
    else:
        if x is not None:
            proj[ntype] = nn.Linear(x.size(-1), hidden, bias=False)
        else:
            embeddings[ntype] = nn.Embedding(data[ntype].num_nodes, hidden)

device = torch.device('cuda' if torch.cuda.is_available() else 'mps')
embeddings = embeddings.to(device)
proj = proj.to(device)

In [13]:
embeddings

ModuleDict(
  (director): Embedding(2081, 64)
  (actor): Embedding(5257, 64)
)

In [14]:
# 3) Define HGT model
class HGTNet(nn.Module):
    def __init__(self, metadata, hidden=64, heads=4, layers=2, out_dim=3):
        super().__init__()
        self.input_dropout = nn.Dropout(0.2)
        self.layers = nn.ModuleList([
            HGTConv(in_channels=hidden, # this is from hetregenous graph transformer
                    out_channels=hidden,
                    metadata=metadata,
                    heads=heads)
            for _ in range(layers)
        ])
        self.dropout = nn.Dropout(0.5)
        self.cls = nn.Linear(hidden, out_dim)

    def forward(self, x_dict, edge_index_dict):
        h = {k: self.input_dropout(v) for k, v in x_dict.items()}  # input feature dropout
        for conv in self.layers:
            h = conv(h, edge_index_dict)  # dict -> dict
            h = {k: F.elu(v) for k, v in h.items()}
            h = {k: self.dropout(v) for k, v in h.items()}
        logits = self.cls(h['movie'])
        return logits, h  # return movie logits + all-type embeddings

model = HGTNet(metadata, hidden=hidden, heads=4, layers=3, out_dim=int(data['movie'].y.max().item()+1))
#opt = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-3)

model = model.to(device)
for ntype in proj: proj[ntype] = proj[ntype].to(device)
for ntype in embeddings: embeddings[ntype] = embeddings[ntype].to(device)


In [15]:
# Move tensor storages to device
for ntype in data.node_types:
    for k, v in data[ntype].items():
        data[ntype][k] = v.to(device)
for etype in data.edge_types:
    data[etype].edge_index = data[etype].edge_index.to(device)

y = data['movie'].y
train_mask = data['movie'].train_mask
val_mask = data['movie'].val_mask
test_mask = data['movie'].test_mask

def get_xdict():
    xd = {}
    for ntype in data.node_types:
        x = data[ntype].get('x', None)
        if (ntype in hybrid_types) and (ntype in proj) and (ntype in embeddings) and (x is not None):
            xd[ntype] = torch.cat([proj[ntype](x), embeddings[ntype].weight], dim=-1)
        elif ntype in proj and x is not None:
            xd[ntype] = proj[ntype](x)
        elif ntype in embeddings:
            xd[ntype] = embeddings[ntype].weight
        else:
            raise KeyError(f"Missing modules for node type: {ntype}")
    return {k: v.to(device) for k, v in xd.items()}


def build_x_dict():
    # Use the same logic in training to avoid drift
    return get_xdict()

@torch.no_grad()
def eval_split(split='val'):
    model.eval()
    xd = get_xdict()
    logits, hdict = model(xd, data.edge_index_dict)
    if split == 'val':
        mask = val_mask
    elif split == 'test':
        mask = test_mask
    else:
        mask = train_mask
    pred = logits[mask].argmax(dim=-1)
    acc = (pred == y[mask]).float().mean().item()
    return acc

In [16]:
# 4) Train
def build_x_dict():
    out = {}
    for ntype in data.node_types:
        if ntype in embeddings:
            out[ntype] = embeddings[ntype].weight
        else:
            out[ntype] = proj[ntype](data[ntype].x)
    return out

# Make sure optimizer sees these params:
proj_params = [p for m in proj.values() for p in m.parameters()]
emb_params  = [p for m in embeddings.values() for p in m.parameters()]
#opt = torch.optim.AdamW(list(model.parameters()) + proj_params + emb_params, lr=3e-3, weight_decay=1e-4)
opt = torch.optim.AdamW(
    list(model.parameters())
    + [p for m in proj.values() for p in m.parameters()]
    + [p for m in embeddings.values() for p in m.parameters()],
    lr=3e-3, weight_decay=1e-3
)
from copy import deepcopy
best = dict(val=-1, model=None, proj={}, emb={})
best_val, best_state = 0.0, None

for epoch in range(1, 300):
    model.train()
    #x_dict = {k: v.to(device) for k, v in build_x_dict().items()}
    #edge_index_dict = {k: v.to(device) for k, v in data.edge_index_dict.items()}
    x_dict = get_xdict()
    logits, _ = model(x_dict, data.edge_index_dict)
    loss = F.cross_entropy(logits[train_mask], data['movie'].y[train_mask].to(device),label_smoothing=0.1)
    opt.zero_grad(); loss.backward(); opt.step()

    if epoch % 10 == 0:
        val_acc = eval_split('val')
        if val_acc > best['val']:
            best['val'] = val_acc
            best['model'] = deepcopy(model.state_dict())
            best['proj']  = {k: deepcopy(m.state_dict()) for k,m in proj.items()}
            best['emb']   = {k: deepcopy(m.state_dict()) for k,m in embeddings.items()}
        print(f"Epoch {epoch:03d} | loss {loss.item():.4f} | val acc {val_acc:.4f}")

model.load_state_dict(best['model'])
for k,m in proj.items(): m.load_state_dict(best['proj'][k])
for k,m in embeddings.items(): m.load_state_dict(best['emb'][k])


Epoch 010 | loss 1.0542 | val acc 0.3925
Epoch 020 | loss 0.5942 | val acc 0.5425
Epoch 030 | loss 0.3401 | val acc 0.5450
Epoch 040 | loss 0.3153 | val acc 0.5375
Epoch 050 | loss 0.3067 | val acc 0.5350
Epoch 060 | loss 0.3035 | val acc 0.5550
Epoch 070 | loss 0.3021 | val acc 0.5475
Epoch 080 | loss 0.3001 | val acc 0.5600
Epoch 090 | loss 0.3000 | val acc 0.5775
Epoch 100 | loss 0.2993 | val acc 0.5800
Epoch 110 | loss 0.2994 | val acc 0.5775
Epoch 120 | loss 0.2990 | val acc 0.5625
Epoch 130 | loss 0.2978 | val acc 0.5625
Epoch 140 | loss 0.2984 | val acc 0.5700
Epoch 150 | loss 0.2977 | val acc 0.5750
Epoch 160 | loss 0.2980 | val acc 0.5650
Epoch 170 | loss 0.2975 | val acc 0.5650
Epoch 180 | loss 0.2973 | val acc 0.5725
Epoch 190 | loss 0.2969 | val acc 0.5675
Epoch 200 | loss 0.2974 | val acc 0.5675
Epoch 210 | loss 0.2965 | val acc 0.5725
Epoch 220 | loss 0.2971 | val acc 0.5700
Epoch 230 | loss 0.2969 | val acc 0.5675
Epoch 240 | loss 0.2975 | val acc 0.5675
Epoch 250 | loss

In [17]:
# 5) Test
if best_state:
    model.load_state_dict({k: v.to(device) for k, v in best_state.items()})
test_acc = eval_split('test')
print(f"Test acc: {test_acc:.4f}")

Test acc: 0.5635


In [18]:
# 6) Get embeddings for downstream tasks (similarity, clustering, etc.)
@torch.no_grad()
def get_movie_embeddings():
    model.eval()
    xd = get_xdict()
    _, hdict = model(xd, data.edge_index_dict)
    return hdict['movie'].cpu()

movie_Z = get_movie_embeddings()
print("Movie embeddings shape:", movie_Z.shape)

Movie embeddings shape: torch.Size([4278, 128])


In [19]:
data.edge_index_dict.keys()

dict_keys([('movie', 'to', 'director'), ('movie', 'to', 'actor'), ('director', 'to', 'movie'), ('actor', 'to', 'movie')])

In [20]:
#data['movie'].x.shape, data['movie'].y.shape
data['movie'].y.unique(return_counts=True) 

(tensor([0, 1, 2], device='mps:0'), tensor([1135, 1584, 1559], device='mps:0'))

In [21]:
data['movie'].x.shape

torch.Size([4278, 3066])

In [22]:
data.node_types

['movie', 'director', 'actor']

In [23]:
data['actor'].x

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0')

In [24]:
data['actor'].x.shape #5257 are nodes , 3066 feature vector dimension


torch.Size([5257, 3066])

In [25]:
data['actor'].x.abs().sum()

tensor(17004.1543, device='mps:0')

In [26]:
x = data['actor'].x
x.shape, x.dtype

(torch.Size([5257, 3066]), torch.float32)

In [27]:
# Value range and a few sample rows
x.min().item(), x.max().item(), x.mean().item()
#x[0, :20], x[1, :20]  # peek

(0.0, 1.0, 0.0010549816070124507)

In [28]:
x = data['actor'].x
is_integer_like = bool((x == x.round()).all())        # counts?
nnz = x.count_nonzero().item()
density = nnz / x.numel()
is_integer_like, nnz, density

(False, 76501, 0.004746319664979977)