In [4]:
import time
import dgl
import numpy as np
from functools import partial
from torch.nn import CrossEntropyLoss
import torch
from typing import Iterator
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import LambdaLR

In [6]:
from packages.data_management.reddit_data import get_dataloader, load_reddit_adj, load_reddit_feats, load_reddit_labels, load_reddit_masks
from packages.transformer.optimizer import rate, SimpleLossCompute
from packages.utils.checkpointing import load_model, checkpoint_model
from packages.transformer.data import cora_data_gen, TransformerGraphBundleInput
from packages.transformer.encoder_decoder import make_model, EncoderDecoder

In [9]:
def run_train_epoch(subgraph_bundle_generator: Iterator[TransformerGraphBundleInput], model: EncoderDecoder, loss_compute: SimpleLossCompute,
                    optimizer: torch.optim.Optimizer, scheduler: LambdaLR):
    """Run one training pass over every bundle the generator yields.

    For each bundle: clear gradients, run the model forward, compute the loss,
    back-propagate, then step the optimizer and the learning-rate scheduler.

    Args:
        subgraph_bundle_generator: Yields the bundles to train on; each bundle
            is read for ``src_feats``, ``src_mask``, ``train_inds``,
            ``trg_labels`` and ``ntokens``.
        model: Model whose ``forward`` is called on each bundle.
        loss_compute: Callable returning ``(loss, loss_node)``; ``loss`` is
            accumulated for reporting and ``loss_node`` is back-propagated.
        optimizer: Stepped once per bundle.
        scheduler: Stepped once per bundle, right after the optimizer.

    Returns:
        Tuple ``(accumulated loss / accumulated ntokens, elapsed seconds)``.
    """
    epoch_start = time.time()
    running_loss, token_count = 0, 0
    for bundle in subgraph_bundle_generator:
        optimizer.zero_grad(set_to_none=True)
        # Original author's shape note: B x B_out x model_D.
        decoded = model.forward(bundle.src_feats, bundle.src_mask, bundle.train_inds)
        # TODO: need to think about this loss computation carefully. Is it even possible?
        batch_loss, loss_for_backward = loss_compute(decoded, bundle.trg_labels, bundle.ntokens)
        token_count += bundle.ntokens
        running_loss += batch_loss
        loss_for_backward.backward()
        optimizer.step()
        scheduler.step()
    elapsed = time.time() - epoch_start
    mean_loss = running_loss / token_count
    print(f"Train loss on epoch: {mean_loss}; time taken: {elapsed}")
    return mean_loss, elapsed

def run_eval_epoch(subgraph_bundle_generator: Iterator[TransformerGraphBundleInput], model: EncoderDecoder,
                   loss_compute: SimpleLossCompute):
    """Compute the token-normalised loss over every bundle the generator yields.

    No optimisation happens here. The loop runs under ``torch.no_grad()`` so it
    never builds autograd graphs, whether or not the caller already disabled
    gradients.

    Args:
        subgraph_bundle_generator: Yields the bundles to evaluate; each bundle
            is read for ``src_feats``, ``src_mask``, ``train_inds``,
            ``trg_labels`` and ``ntokens``.
        model: Model whose ``forward`` is called on each bundle. Switching it
            to eval mode is left to the caller.
        loss_compute: Callable returning ``(loss, loss_node)``; only ``loss``
            is used.

    Returns:
        Accumulated loss divided by accumulated ``ntokens``.
    """
    total_loss = 0
    ntokens = 0
    with torch.no_grad():
        for subgraph_bundle in subgraph_bundle_generator:
            # Original author's shape note: B x B_out x model_D.
            out = model.forward(subgraph_bundle.src_feats, subgraph_bundle.src_mask,
                                subgraph_bundle.train_inds)
            # TODO: need to think about this loss computation carefully. Is it even possible?
            loss, _ = loss_compute(out, subgraph_bundle.trg_labels, subgraph_bundle.ntokens)
            ntokens += subgraph_bundle.ntokens
            total_loss += loss
    mean_loss = total_loss / ntokens
    print(f"Validation loss on epoch: {mean_loss}")
    return mean_loss

# TODO: need to think about this carefully. We'll be evaluating on packed batches.
def eval_accuracy(subggraph_bundle_generator: Iterator[TransformerGraphBundleInput], model: EncoderDecoder):
    """Compute classification accuracy over every bundle the generator yields.

    For each bundle the model output is passed through ``model.generator``,
    the leading dimension is squeezed away, and the arg-max over dimension 1
    is compared with ``trg_labels``. The running accuracy is printed after
    every bundle.

    Args:
        subggraph_bundle_generator: Yields the bundles to score; each bundle is
            read for ``src_feats``, ``src_mask``, ``train_inds`` and
            ``trg_labels``.
        model: Model providing ``forward`` and ``generator``.

    Returns:
        The accuracy accumulated over all bundles, or ``float("nan")`` when
        the generator yields nothing (previously an UnboundLocalError).
    """
    total = 0
    num_correct = 0
    # Defined up front so an empty generator yields a value instead of raising.
    test_accuracy = float("nan")
    # Pure inference: never build autograd graphs here.
    with torch.no_grad():
        for graph_bundle in subggraph_bundle_generator:
            # Original author's shape note: B x B_out x model_D.
            out = model.forward(graph_bundle.src_feats, graph_bundle.src_mask, graph_bundle.train_inds)
            out = model.generator(out)  # author's note: B x num_nodes x num_classes
            predictions = out.squeeze(0).argmax(dim=1)  # author's note: num_nodes
            mb_test_labels = graph_bundle.trg_labels.squeeze(0)
            total += mb_test_labels.shape[0]
            num_correct += (predictions == mb_test_labels).sum()
            test_accuracy = num_correct / total
            print(test_accuracy)
    return test_accuracy

def build_dataloader(graph: dgl.DGLHeteroGraph, sampler: dgl.dataloading.MultiLayerNeighborSampler, ids: torch.Tensor,
                     batch_size: int = 64, shuffle: bool = False):
    """Wrap ``graph`` in a DGL ``DataLoader`` that iterates over ``ids``.

    Args:
        graph: Graph to sample from.
        sampler: Sampler handed to the DataLoader.
        ids: Seed node ids to iterate over.
            NOTE(review): callers in this notebook appear to pass NumPy index
            arrays rather than tensors — confirm DGL accepts them.
        batch_size: Number of seed nodes per batch. Defaults to 64, the value
            that used to be hard-coded.
        shuffle: Whether to shuffle ``ids`` each epoch. Defaults to False,
            the previous hard-coded behaviour.

    Returns:
        The constructed ``dgl.dataloading.DataLoader`` (single process,
        keeps the final partial batch).
    """
    return dgl.dataloading.DataLoader(
        graph, ids, sampler,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,
        num_workers=0)

def train_model(bs: int, num_sg: int):
    """Train the model on Reddit data, logging to TensorBoard and checkpointing on improvement.

    Loads the Reddit adjacency, features, masks and labels, builds one
    neighbour-sampling dataloader per split, then alternates a training epoch
    and a validation epoch for 30 epochs. The model is checkpointed whenever
    the validation loss improves on the best value seen so far.

    Args:
        bs (int): Batch size used to derive the per-epoch batch counts handed
            to ``cora_data_gen`` and to name the TensorBoard run directory.
        num_sg (int): Number of subgraphs handed to ``cora_data_gen`` for the
            training generator (validation always uses 1); also part of the
            run directory name.
    """
    adj_sparse = load_reddit_adj()
    feats = load_reddit_feats()
    masks = load_reddit_masks()
    labels = load_reddit_labels()
    num_classes = len(labels.unique())

    all_ids = np.arange(masks.shape[1])
    # Each split reads its own mask row (previously all three read row 0, so
    # validation and test silently reused the training nodes).
    # NOTE(review): assumes the rows are ordered (train, val, test) — confirm
    # against load_reddit_masks.
    train_ids = all_ids[masks[0, :]]
    val_ids = all_ids[masks[1, :]]
    test_ids = all_ids[masks[2, :]]

    graph = dgl.graph((adj_sparse.row, adj_sparse.col))
    sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 5])
    _build_dataloader = partial(build_dataloader, graph, sampler)

    train_dataloader = _build_dataloader(train_ids)
    val_dataloader = _build_dataloader(val_ids)
    # Built for completeness; nothing in this function consumes it yet.
    test_dataloader = _build_dataloader(test_ids)

    criterion = CrossEntropyLoss(reduction='sum').cuda()
    model = make_model(feats.shape[1], num_classes, N=2).cuda()

    # lr=1.0 so that the LambdaLR multiplier below is the effective learning rate.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9
    )
    lr_scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=lambda step: rate(
            step, 512, factor=1, warmup=3000
        ),
    )
    batch_size = bs
    num_subgraphs = num_sg

    tb_log_dir = f"runs/batch-{batch_size}_num_sg-{num_subgraphs}_reddit"
    tb_sw = SummaryWriter(tb_log_dir)

    device = 'cuda'
    nepochs = 30

    # NOTE(review): `bs` only drives these batch counts and the log directory;
    # the DataLoader batch size is whatever build_dataloader uses — confirm
    # the two agree with what cora_data_gen expects.
    train_nbatches = train_ids.shape[0] // batch_size
    val_nbatches = val_ids.shape[0] // batch_size

    # Must exist before the first comparison in the loop below.
    best_loss = float("inf")
    best_loss_epoch = None

    try:
        for nepoch in range(nepochs):
            model.train()
            epoch_loss, train_epoch_elapsed = run_train_epoch(
                cora_data_gen(train_dataloader, train_nbatches, num_subgraphs, feats, labels, device), model,
                SimpleLossCompute(model.generator, criterion), optimizer, lr_scheduler)

            tb_sw.add_scalar('Loss/train', epoch_loss, nepoch)
            tb_sw.add_scalar('Duration/train', train_epoch_elapsed, nepoch)

            model.eval()
            with torch.no_grad():
                validation_loss = run_eval_epoch(
                    cora_data_gen(val_dataloader, val_nbatches, 1, feats, labels, device), model,
                    SimpleLossCompute(model.generator, criterion, None))
                tb_sw.add_scalar('Loss/validation', validation_loss, nepoch)
                if validation_loss < best_loss:
                    checkpoint_model(model)
                    best_loss = validation_loss
                    best_loss_epoch = nepoch
    finally:
        # Flush pending events and release the log file even if training aborts.
        tb_sw.close()

    print(f"Best validation loss: {best_loss} (epoch {best_loss_epoch})")

In [10]:
# Kick off a training run: batch size 32, one subgraph per training bundle.
train_model(bs=32, num_sg=1)

(232965, 232965)


(3, 232965)


(232965,)


In [8]:
# Two sampling layers with a fan-out of 5 each.
sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 5])

# NOTE(review): `graph` and `train_ids` are not defined at notebook level in
# any visible cell (they are locals of train_model) — this cell relies on
# kernel state from elsewhere.
dataloader = dgl.dataloading.DataLoader(
    graph,
    train_ids,
    sampler,
    num_workers=0,
    drop_last=False,
    shuffle=True,
    batch_size=32,
)

In [82]:
from torch import square

def convert_mfg_to_sg_adj(mfg: "DGLBlock", square_shape: int):
    """Embed a block's sparse adjacency in a dense square matrix.

    The annotation is quoted because ``DGLBlock`` is not imported in the
    visible cells; an unquoted name would be evaluated at definition time.

    Args:
        mfg: Object whose ``adj()`` returns a torch sparse COO tensor.
        square_shape: Side length of the square output. Must be large enough
            to hold every index present in the adjacency.

    Returns:
        Dense ``(square_shape, square_shape)`` tensor holding the adjacency
        values at their original (row, col) positions and zeros elsewhere.
    """
    # coalesce() merges duplicate entries and unlocks the public
    # indices()/values() accessors (the underscore variants are private API).
    sparse_adj = mfg.adj().coalesce()
    square_adj = torch.sparse_coo_tensor(
        sparse_adj.indices(), sparse_adj.values(), size=(square_shape, square_shape)
    )
    return square_adj.to_dense()

def construct_batch(target_nodes, subgraph_nodes, mfgs, features, labels, device):
    """Assemble a single-element TransformerGraphBundleInput from two sampled blocks.

    The adjacency of ``mfgs[0]`` and ``mfgs[1]`` is densified to a square
    matrix sized by the number of subgraph nodes, the identity is added to
    each, and the pair is stacked. Features are gathered for
    ``subgraph_nodes`` and labels for ``target_nodes``. Every tensor gets a
    leading dimension of size 1 before being handed to the bundle constructor.

    NOTE(review): the notebook's recorded output for a call to this function
    shows ``TypeError: __init__() missing 1 required positional argument:
    'device'``; presumably TransformerGraphBundleInput takes one more
    positional argument than the five passed here — confirm against its
    definition.
    """
    n_subgraph_nodes = subgraph_nodes.shape[0]

    # One dense adjacency (plus identity) per layer, first layer first.
    per_layer_adjacency = [
        convert_mfg_to_sg_adj(layer_mfg, n_subgraph_nodes) + torch.eye(n_subgraph_nodes, device=device)
        for layer_mfg in (mfgs[0], mfgs[1])
    ]

    batched_adjacencies = torch.stack(tuple(per_layer_adjacency)).unsqueeze(0)
    batched_feats = features[subgraph_nodes, :].unsqueeze(0)
    batched_labels = labels[target_nodes].unsqueeze(0)
    batched_target_inds = target_nodes.unsqueeze(0)

    return TransformerGraphBundleInput(batched_feats, batched_labels, batched_adjacencies, batched_target_inds, device)

In [9]:
# Pull a single batch from the dataloader for interactive inspection.
dataloader_iter = iter(dataloader)
input_nodes, output_nodes, mfgs = next(dataloader_iter) # Author's note: input_nodes give the required features; mfgs give the required attention mask.

In [70]:
# Distinct label values, via a tensor copy of `labels`.
unique_labels = torch.unique(torch.Tensor(labels))

In [83]:
# Load previously saved blocks from disk rather than sampling new ones.
# NOTE(review): `load_pkl_from_path`, `feats` and `labels` are not defined in
# any visible cell — this cell relies on kernel state from elsewhere. If the
# loader unpickles the file, only use it on trusted data.
mfgs = load_pkl_from_path("data/reddit_cpr/mfg")
# Node ids on the source side of the first block.
all_sg_nodes = mfgs[0].srcdata[dgl.NID]

# Node ids on the destination side of the second block.
all_tgt_nodes = mfgs[1].dstdata[dgl.NID]
minibatch = construct_batch(all_tgt_nodes, all_sg_nodes, mfgs, torch.Tensor(feats), torch.Tensor(labels), 'cpu')


TypeError: __init__() missing 1 required positional argument: 'device'

In [74]:
# Two-layer model sized by the feature width and the number of distinct labels.
model = make_model(feats.shape[1], len(unique_labels), N=2) # TODO: is a +1 on the class count needed here? Probably not.

  nn.init.xavier_uniform(p)


In [79]:
# Smoke-test one forward pass on the hand-built minibatch.
# NOTE(review): the recorded run of this cell failed with a cuda:0 / cpu
# device mismatch — the model and the minibatch tensors need to be on the
# same device.
model.forward(minibatch.src_feats, minibatch.src_mask, minibatch.train_inds)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_mm)

In [32]:
# Rebind `adj_sparse` to the result of convert_scipy_sparse_to_torch, targeting 'cpu'.
# NOTE(review): the name is overwritten with the converted value, so re-running
# this cell feeds the converted object back into the converter; also,
# `convert_scipy_sparse_to_torch` is not imported in any visible cell.
adj_sparse = convert_scipy_sparse_to_torch(adj_sparse, 'cpu')

In [33]:
# Build a minibatch from the batch sampled by the dataloader.
# construct_batch takes (target_nodes, subgraph_nodes, mfgs, features, labels,
# device); the previous call passed the adjacency as `features` and the
# features as `labels`.
minibatch = construct_batch(output_nodes, input_nodes, mfgs, torch.Tensor(feats), torch.Tensor(labels), 'cpu')

: 

: 

In [3]:
def get_max_nodes_in_batch(batch_size, fanout, num_layers=2):
    """Print and return the worst-case node count of a sampled batch.

    Starting from ``batch_size`` seed nodes, each sampling layer can add up to
    ``fanout`` new nodes per node already present, so the bound is
    ``batch_size * (1 + fanout) ** num_layers``.

    Args:
        batch_size: Number of seed nodes in the batch.
        fanout: Maximum number of neighbours sampled per node per layer.
        num_layers: Number of sampling layers. Defaults to 2, the depth the
            original formula was written for.

    Returns:
        The upper bound on the number of nodes (also printed).

    Raises:
        ValueError: If ``num_layers`` is negative.
    """
    if num_layers < 0:
        raise ValueError(f"num_layers must be non-negative, got {num_layers}")
    max_nodes = batch_size
    for _ in range(num_layers):
        # Every node present so far may contribute `fanout` new neighbours.
        max_nodes += max_nodes * fanout
    print(max_nodes)
    return max_nodes

# Compare the node-count bound for three (batch_size, fanout) settings.
get_max_nodes_in_batch(64, 5)
get_max_nodes_in_batch(64, 4)
get_max_nodes_in_batch(32, 5)

2304
1600
1152
