In [1]:
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from importlib import reload

import ssl
ssl._create_default_https_context = ssl._create_stdlib_context


import matplotlib.pyplot as plt
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.tuner import Tuner

import random
L.seed_everything(random.randint(0, 10000))

[rank: 0] Seed set to 292


292

In [2]:
dataset = PyGLinkPropPredDataset(name="tgbl-wiki", root="datasets")

raw file found, skipping download
Dataset directory is  /home/mila/s/soroush.omranpour/Projects/my_env/lib/python3.10/site-packages/tgb/datasets/tgbl_wiki
loading processed file


In [3]:
class TGSeq(Dataset):
    def __init__(self, src, dst, ts, max_length=512):
        super().__init__()
        self.max_length = max_length
        
        self.src = src + 1
        self.dst = dst + 1
        self.ts = ts

    def __len__(self):
        return self.src.shape[0] - self.max_length + 1
        

    def __getitem__(self, idx):
        src = self.src[idx:idx+self.max_length]
        dst = self.dst[idx:idx+self.max_length]
        ts = self.ts[idx:idx+self.max_length]
        seq = torch.stack([torch.zeros_like(src), src, dst]).T.flatten()
        x = seq[:-1]
        y = seq[1:]
        t = torch.stack([torch.zeros_like(ts), ts, ts]).T.flatten()[:-1]
        
        return x, y, t

context_window = 512
train_ds = TGSeq(
    dataset.src[dataset.train_mask],
    dataset.dst[dataset.train_mask],
    dataset.ts[dataset.train_mask],
    max_length=context_window
)

val_idx = torch.where(dataset.val_mask)[0]
min_id, max_id = val_idx.min(), val_idx.max()
val_ds = TGSeq(
    dataset.src[min_id - context_window + 1:max_id+1],
    dataset.dst[min_id - context_window + 1:max_id+1],
    dataset.ts[min_id - context_window + 1:max_id+1],
    max_length=context_window
)

test_idx = torch.where(dataset.test_mask)[0]
min_id = val_idx.min()
test_ds = TGSeq(
    dataset.src[min_id - context_window + 1:],
    dataset.dst[min_id - context_window + 1:],
    dataset.ts[min_id - context_window + 1:],
    max_length=context_window
)
print(len(train_ds), len(val_ds), len(test_ds))
train_loader = DataLoader(train_ds, batch_size=128, num_workers=4, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=128, num_workers=4)
test_loader = DataLoader(test_ds, batch_size=128, num_workers=4)

109721 23621 47242


In [4]:
# b = next(iter(train_loader))
# x, y, t = b
# x.shape, y.shape, t.shape

In [5]:
import src.modules
reload(src.modules)
import src.lm
reload(src.lm)
from src.lm import TGTransformer

model = TGTransformer(
    n_vocab=9226+2,
    d_hidden=128,
    d_mlp=512, 
    n_blocks = 8, 
    n_head = 8, 
    dropout=0.1
)

In [6]:
model.count_parameters()

4477196

In [7]:
project = 'tgbl-wiki-tglm'
name = 'test'
wandb_logger = WandbLogger(
    project=project,
    name=name,
    save_dir='/home/mila/s/soroush.omranpour/scratch/wandb'
)
trainer = L.Trainer(
    max_epochs=10,
    devices=1,
    accelerator="gpu", 
    logger=wandb_logger,
    accumulate_grad_batches=1,
    gradient_clip_val=1.,
    num_nodes=1
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
trainer.fit(model, train_loader, val_loader)

You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[34m[1mwandb[0m: Currently logged in as: [33msoorism[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type             | Params
----------------------------------------------------
0 | node_embedding | Embedding        | 1.2 M 
1 | time_embedding | TimeEmbedding    | 64    
2 | blocks         | ModuleList       | 2.1 M 
3 | head           | Sequential       | 1.2 M 
4 | criterion      | CrossEntropyLoss | 0     
----------------------------------------------------
4.5 M     Trainable params
64        Non-trainable params
4.5 M     Total params
17.909    Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

/home/mila/s/soroush.omranpour/Projects/my_env/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [9]:
trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: |                                                                                                    …

[{'test_loss': 7.314603328704834, 'test_mrr': 0.4213508069515228}]