In [1]:
# TwHIN embeddings (projects/twhin) https://arxiv.org/abs/2202.05387
# https://github.com/twitter/the-algorithm-ml
# https://en.wikipedia.org/wiki/Knowledge_graph_embedding

# Dataset
# https://huggingface.co/datasets/Twitter/TwitterFaveGraph
# https://stackoverflow.com/questions/25962114/how-do-i-read-a-large-csv-file-with-pandas

# TorchRec
# https://github.com/pytorch/torchrec/blob/main/Torchrec_Introduction.ipynb

# venv setup https://github.com/twitter/the-algorithm-ml/blob/main/images/init_venv.sh
# https://stackoverflow.com/questions/42449814/running-jupyter-notebook-in-a-virtualenv-installed-sklearn-module-not-available

# loss
# https://zhang-yang.medium.com/how-is-pytorchs-binary-cross-entropy-with-logits-function-related-to-sigmoid-and-d3bd8fb080e7
# https://medium.com/dejunhuang/learning-day-57-practical-5-loss-function-crossentropyloss-vs-bceloss-in-pytorch-softmax-vs-bd866c8a0d23

In [2]:
import math
import torch
import torchrec
from torch import nn
import torch.nn.functional as F
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
import torch.distributed as dist
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend="nccl")

In [3]:
class MyKGE(nn.Module):
    def __init__(self, embedding_dim = 4, num_features = 2, feature_names=['user', 'tweet'], num_embeddings=[424_241, 72_543], num_relations = 4, in_batch_negatives=10, device = "cpu"):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.in_batch_negatives = in_batch_negatives
        self.num_relations = num_relations
        self.feature_names = feature_names
        self.device = device
        
        tables = []

        for i in range(num_features):
            tables.append(
                EmbeddingBagConfig(
                    embedding_dim=embedding_dim,
                    feature_names=[feature_names[i]],  # restricted to 1 feature per table for now
                    name=feature_names[i],
                    num_embeddings=num_embeddings[i],
                    pooling=torchrec.PoolingType.SUM,
                )
            )        

        # embedding for features
        self.ebc = EmbeddingBagCollection(
          #device="meta",
          device = device,
          tables=tables,
        )

        # embedding for relation (translation)
        self.relation_embedding = nn.Embedding(num_relations, embedding_dim)

    

    def forward(self, users, tweets, rels):
        batch_size = users.shape[0]
        mb = torchrec.KeyedJaggedTensor(
            keys = self.feature_names,
            values = torch.concat([user, tweet]).to(self.device),
            lengths = torch.tensor([1], dtype=torch.int64).repeat(batch_size*2).to(self.device), # user batch size + tweet batch size
        )        
        
        x = self.ebc(mb.to(device)).values() # B * 2D
        x = x.reshape(batch_size, 2, self.embedding_dim) #  B * 2 * D
        trans_embs = self.relation_embedding(rels) # B * D
        
        # translation
        translated = x[:, 1, :] + trans_embs

        # negative sampling
        negs = []
        if self.in_batch_negatives:
            for relation in range(self.num_relations):
                rel_mask = rels == relation
                rel_count = rel_mask.sum()

                if not rel_count:
                  continue        

                # R x D
                lhs_matrix = x[rel_mask, 0, :]
                rhs_matrix = x[rel_mask, 1, :]

                lhs_perm = torch.randperm(lhs_matrix.shape[0])
                # repeat until we have enough negatives
                lhs_perm = lhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
                lhs_indices = lhs_perm[: self.in_batch_negatives]
                sampled_lhs = lhs_matrix[lhs_indices]
        
                rhs_perm = torch.randperm(rhs_matrix.shape[0])
                # repeat until we have enough negatives
                rhs_perm = rhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
                rhs_indices = rhs_perm[: self.in_batch_negatives]
                sampled_rhs = rhs_matrix[rhs_indices]
        
                # RS
                negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))
                negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))
        
                negs.append(negs_lhs)
                negs.append(negs_rhs)   

        # dot product for positives' scoring
        x = (x[:, 0, :] * translated).sum(-1)

        # concat positives and negatives
        x = torch.cat([x, *negs])        

        return {
          "logits": x,
          "probabilities": torch.sigmoid(x),
        }        


In [4]:
# https://huggingface.co/datasets/Twitter/TwitterFaveGraph/blob/main/TwitterFaveGraph.csv.zip
filename = "./TwitterFaveGraph.csv"
df = pd.read_csv(filename, nrows=1024*100)
df

Unnamed: 0,user_index,tweet_index,time_chunk
0,2664563,4656426,134
1,9109889,3354941,32
2,10473449,4994572,31
3,4288842,5859769,133
4,6032357,6472618,10
...,...,...,...
102395,11791325,3499381,176
102396,7222675,2559397,110
102397,6612248,1656162,104
102398,5138976,5782149,102


In [5]:
class TwitterFaveGraphDataset(Dataset):
    def __init__(self, df):
        super().__init__()
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        user_index_data = torch.tensor(int(self.df["user_index"].iloc[idx]))
        tweet_index_data = torch.tensor(int(self.df["tweet_index"].iloc[idx]))

        return user_index_data, tweet_index_data

In [6]:
df_train = df.sample(frac=0.8,random_state=200)
df_test = df.drop(df_train.index)
dataset_train = TwitterFaveGraphDataset(df_train)
dataset_test = TwitterFaveGraphDataset(df_test)
dataset_test[0]

(tensor(4288842), tensor(5859769))

In [7]:
dataloader_train = DataLoader(dataset_train, batch_size = 512, shuffle = True)
dataloader_test = DataLoader(dataset_train, batch_size = 512, shuffle = False)

In [8]:
user, tweet = next(iter(dataloader_train))
#user, tweet

In [9]:
df.tweet_index.max()

6761642

In [10]:
device = "cuda"
#device = "cpu"
model = MyKGE(num_embeddings=[int(df.user_index.max()+1), int(df.tweet_index.max()+1)])
#model = torchrec.distributed.DistributedModelParallel(model, device=torch.device("cuda"))
model = model.to(device)

In [11]:
mb = torchrec.KeyedJaggedTensor(
    keys = ["user", "tweet"],
    values = torch.concat([user, tweet]).to(device),
    lengths = torch.tensor([1], dtype=torch.int64).repeat(user.shape[0]*2).to(device),
)

mb.to(torch.device("cpu"))

<torchrec.sparse.jagged_tensor.KeyedJaggedTensor at 0x7f4d745f91c0>

In [12]:
model.ebc.embedding_bags

ModuleDict(
  (user): EmbeddingBag(13130632, 4, mode=sum)
  (tweet): EmbeddingBag(6761643, 4, mode=sum)
)

In [13]:
model.ebc.embedding_bags['user']

EmbeddingBag(13130632, 4, mode=sum)

In [14]:
model.ebc(mb).values()

tensor([[-0.4515,  0.3983, -0.4965,  ..., -0.2964,  1.3926, -0.7841],
        [ 0.1097, -1.4804,  0.7915,  ..., -0.0192,  0.1354,  0.0383],
        [-2.1182, -1.8855,  0.6606,  ...,  0.0640, -0.7168, -0.8143],
        ...,
        [-0.5152, -0.4280,  0.0953,  ...,  0.8195, -0.2280, -0.0336],
        [ 0.6786, -1.8344, -1.0917,  ..., -1.1620, -1.2750, -1.1539],
        [-0.3536,  0.5567, -0.5723,  ..., -1.5122,  0.3058,  2.4866]],
       device='cuda:0', grad_fn=<CatBackward0>)

In [15]:
model.ebc(mb).to_dict()

{'user': tensor([[-0.4515,  0.3983, -0.4965, -2.0446],
         [ 0.1097, -1.4804,  0.7915, -1.2557],
         [-2.1182, -1.8855,  0.6606,  1.2509],
         ...,
         [-0.5152, -0.4280,  0.0953,  1.8395],
         [ 0.6786, -1.8344, -1.0917, -0.0861],
         [-0.3536,  0.5567, -0.5723, -0.1013]], device='cuda:0',
        grad_fn=<SplitWithSizesBackward0>),
 'tweet': tensor([[ 0.1382, -0.2964,  1.3926, -0.7841],
         [ 0.2004, -0.0192,  0.1354,  0.0383],
         [-0.2875,  0.0640, -0.7168, -0.8143],
         ...,
         [-1.0028,  0.8195, -0.2280, -0.0336],
         [ 0.1024, -1.1620, -1.2750, -1.1539],
         [-1.2999, -1.5122,  0.3058,  2.4866]], device='cuda:0',
        grad_fn=<SplitWithSizesBackward0>)}

In [16]:
pooled_embeddings = model.ebc(mb).to_dict()
print("user embeddings", pooled_embeddings["user"])
print("tweet embeddings", pooled_embeddings["tweet"])

user embeddings tensor([[-0.4515,  0.3983, -0.4965, -2.0446],
        [ 0.1097, -1.4804,  0.7915, -1.2557],
        [-2.1182, -1.8855,  0.6606,  1.2509],
        ...,
        [-0.5152, -0.4280,  0.0953,  1.8395],
        [ 0.6786, -1.8344, -1.0917, -0.0861],
        [-0.3536,  0.5567, -0.5723, -0.1013]], device='cuda:0',
       grad_fn=<SplitWithSizesBackward0>)
tweet embeddings tensor([[ 0.1382, -0.2964,  1.3926, -0.7841],
        [ 0.2004, -0.0192,  0.1354,  0.0383],
        [-0.2875,  0.0640, -0.7168, -0.8143],
        ...,
        [-1.0028,  0.8195, -0.2280, -0.0336],
        [ 0.1024, -1.1620, -1.2750, -1.1539],
        [-1.2999, -1.5122,  0.3058,  2.4866]], device='cuda:0',
       grad_fn=<SplitWithSizesBackward0>)


In [17]:
pooled_embeddings["user"][0]

tensor([-0.4515,  0.3983, -0.4965, -2.0446], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [18]:
lossFunc = F.binary_cross_entropy_with_logits
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_loop():
    model.train()
    
    step = 0
    losses = 0
    for users, tweets in dataloader_train:
        batch_size = users.shape[0]
        rels = torch.tensor([0]).repeat(batch_size)
        outp = model(users.to(device), tweets.to(device), rels.to(device) )
        #print(outp['logits'].shape)

        logits = outp['logits']
        num_negatives = 2 * batch_size * model.in_batch_negatives
        num_positives = batch_size
    
        neg_weight = float(num_positives) / num_negatives
    
        labels = torch.cat([torch.ones(num_positives).to(device), torch.ones(num_negatives).to(device)])
    
        weights = torch.cat(
          [torch.ones(num_positives).to(device), (torch.ones(num_negatives) * neg_weight).to(device)]
        )
       
        loss = lossFunc(logits, labels, weights)
        loss.backward()
    
        with torch.no_grad():
            optimizer.step()
            optimizer.zero_grad()
            
        step+=1
        losses += loss.item()
        #if step % 100 == 0:
            #print(loss)
    print(f"train epoch loss: {losses/step}")

@torch.no_grad()
def eval_loop():
    model.eval()
    step = 0
    losses = 0
    print(f"eval epoch loss: {losses/step}")
    
for i in range(10):
    print(f'epoch #{i}')
    train_loop()
    #eval_loop()

epoch #0
train epoch loss: 0.08896548664197326
epoch #1
train epoch loss: 0.07396449549123645
epoch #2
train epoch loss: 0.06490082058589905
epoch #3
train epoch loss: 0.05926096269395202
epoch #4
train epoch loss: 0.055423598387278616
epoch #5
train epoch loss: 0.052668127021752296
epoch #6
train epoch loss: 0.05005531164351851
epoch #7
train epoch loss: 0.047707011457532644
epoch #8
train epoch loss: 0.04520940040238201
epoch #9
train epoch loss: 0.042617363412864505
