In [124]:
#!/usr/bin/env python
# coding: utf-8

# # Imports, constants

# In[1]:
# imports
import pickle
import random
import argparse
from copy import deepcopy
import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from collections import defaultdict
from datetime import datetime
from pprint import pprint
from tqdm import tqdm
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# constants
DATA = "../../../data/created_data/seqs100000.tsv"
SEED = 566
# Results and models file paths
RESULTS_FILE = (
    "../../../data/out_metrics/results_{timestamp}_seq_{config_index}.pkl"
)
LOSSES_FILE = (
    "../../../data/out_metrics/losses_{timestamp}_seq_{config_index}.pkl"
)
MODELS_FILE = "../../../data/out_models/models_{timestamp}_seq_{config_index}.pkl"

# fix random seeds for reproducibility
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [125]:
def read_and_tokenize(file_path):
    # Step 1: Read the TSV file with variable columns
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            row = line.strip().split('\t')
            data.append(row)

    # Determine the maximum number of columns dynamically
    max_cols = max(len(row) for row in data)
    column_names = [f"col_{i}" for i in range(max_cols)]
    df = pd.DataFrame(data, columns=column_names)

    # Step 2: Replace '.' and ',' with spaces, and handle missing columns by filling with "0"
    df = df.applymap(lambda x: x.replace('.', ' ').replace(',', ' ') if pd.notnull(x) else "0")

    # Step 3: Tokenize nodes/edges uniquely
    EOS_TOKEN = "<EOS>"
    node_edge_vocab = {EOS_TOKEN: 1}  # Start with EOS token
    node_edge_counter = 2  # Start token IDs from 2 to reserve 1 for EOS

    def tokenize_node_edge(value):
        nonlocal node_edge_counter
        if value not in node_edge_vocab:
            node_edge_vocab[value] = node_edge_counter
            node_edge_counter += 1
        return node_edge_vocab[value]

    df_tokenized_node_edge = df.applymap(lambda x: tokenize_node_edge(x) if x != "0" else 0)
    df_tokenized_node_edge[f"col_{len(column_names)}"]=0

    # Step 4: Tokenize word-by-word (using "_" and spaces as separators)
    word_vocab = {EOS_TOKEN: 1}  # Start with EOS token
    word_counter = 2  # Start token IDs from 2 to reserve 1 for EOS

    def tokenize_word_by_word(value):
        nonlocal word_counter
        tokens = []
        if value != "0":
            for word in value.replace("_", " ").replace(".", " ").split():  # Split on spaces and "_"
                if word not in word_vocab:
                    word_vocab[word] = word_counter
                    word_counter += 1
                tokens.append(word_vocab[word])
            tokens.append(word_vocab[EOS_TOKEN])  # Append EOS token
        tokens.append(0)  # Append EOS token
        return tokens

    df_tokenized_word_by_word = df.applymap(lambda x: tokenize_word_by_word(x) if x != "0" else [0])

    # Step 5: Unflatten word-by-word into a new DataFrame
    word_by_word_expanded = []
    for index, row in df_tokenized_word_by_word.iterrows():
        expanded_row = []
        for cell in row:
            if isinstance(cell, list):
                expanded_row.extend(cell)
            else:
                expanded_row.append(cell)
        word_by_word_expanded.append(expanded_row)

    max_words = max(len(row) for row in word_by_word_expanded)
    df_word_by_word_unflattened = pd.DataFrame(word_by_word_expanded, columns=[f"word_{i}" for i in range(max_words)]).fillna(0).astype(int)

    return df, df_tokenized_node_edge, df_word_by_word_unflattened, node_edge_vocab, word_vocab


In [126]:
df_original, df_node_edge, df_word_by_word, node_edge_vocab, word_vocab = read_and_tokenize(DATA)

# Print a sample
print("Original DataFrame:")
display(df_original.head())
print("\nNode/Edge Tokenized DataFrame:")
display(df_node_edge.head())
print("\nWord-by-Word Tokenized DataFrame:")
display(df_word_by_word.head())
print("\nNode/Edge Vocabulary len:")
print(len(node_edge_vocab))
print("\nWord Vocabulary len:")
print(len(word_vocab))

Original DataFrame:


Unnamed: 0,col_0,col_1,col_2,col_3,col_4,col_5,col_6,col_7,col_8,col_9,col_10
0,Adolescent onset conduct-dissocial disorder,occurs_in,Adolescence,reversed_occurs_in,groups 3871486_1,occurs_in,Adolescence,0,0,0,0
1,groups 3530725_0,mapped_to,Neuropathic heredofamilial amyloidosis,reversed_mapped_to,groups 3731020_0,mapped_to,Neuropathic heredofamilial amyloidosis,0,0,0,0
2,Entire superior labial artery,reversed_has_entire_anatomy_structure,groups 3462550_0,has_laterality,Side,reversed_has_laterality,groups 3464901_0,has_laterality,Side,0,0
3,Phenobarbital 30 mg oral tablet,inactivation_indicator,723277005,reversed_inactivation_indicator,Thromboembolus of vein following surgical proc...,occurs_after,Surgical procedure,reversed_occurs_after,groups 3903443_3,occurs_after,Surgical procedure
4,Entire lower plate of the cochlear spiral lamina,has_laterality,Side,laterality_of,Structure of inferior sagittal sinus,reversed_laterality_of,Side,0,0,0,0



Node/Edge Tokenized DataFrame:


Unnamed: 0,col_0,col_1,col_2,col_3,col_4,col_5,col_6,col_7,col_8,col_9,col_10,col_11
0,2,3303,3482,5292,5397,3303,3482,0,0,0,0,0
1,3,3304,3483,3366,5398,3304,3483,0,0,0,0,0
2,4,3305,3484,3307,3486,5306,7189,3307,3486,0,0,0
3,5,3306,3485,5293,5399,3380,3703,5314,7760,3380,3703,0
4,6,3307,3486,5294,5400,3355,3486,0,0,0,0,0



Word-by-Word Tokenized DataFrame:


Unnamed: 0,word_0,word_1,word_2,word_3,word_4,word_5,word_6,word_7,word_8,word_9,...,word_178,word_179,word_180,word_181,word_182,word_183,word_184,word_185,word_186,word_187
0,2,3,4,5,1,0,5358,232,1,0,...,0,0,0,0,0,0,0,0,0,0
1,6,7,8,1,0,5359,28,1,0,5425,...,0,0,0,0,0,0,0,0,0,0
2,9,10,11,12,1,0,5360,227,5361,5362,...,0,0,0,0,0,0,0,0,0,0
3,13,14,15,16,17,1,0,5363,5364,1,...,0,0,0,0,0,0,0,0,0,0
4,9,18,19,20,21,22,23,24,1,0,...,0,0,0,0,0,0,0,0,0,0



Node/Edge Vocabulary len:
8693

Word Vocabulary len:
9762


In [127]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

##########################################
# 1) Simple Toy Dataset
##########################################
class SimpleSequenceDataset(Dataset):
    """
    Example dataset that returns integer sequences of various lengths,
    with 0 as the PAD token.
    """
    def __init__(self, sequences, pad_token=0):
        self.sequences = sequences
        self.pad_token = pad_token

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return torch.tensor(self.sequences[idx], dtype=torch.long)

def collate_fn(batch, pad_token=0):
    """
    Collate function to pad a list of variable-length sequences
    so that each batch is shape [batch_size, seq_len].
    """
    lengths = [len(seq) for seq in batch]
    max_len = max(lengths)
    padded = []
    for seq in batch:
        pad_size = max_len - len(seq)
        padded.append(torch.cat([seq, torch.full((pad_size,), pad_token, dtype=torch.long)]))
    return torch.stack(padded)  # shape = (batch_size, max_len)

##########################################
# 2) Causal Transformer Model
##########################################
class GPTLikeModel(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=128,
        n_heads=4,
        n_layers=2,
        max_seq_len=50,
        pad_token=0,
        activation_fn=nn.GELU,
    ):
        super().__init__()
        self.pad_token = pad_token
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token)
        # A simple trainable positional encoding table:
        self.pos_emb = nn.Embedding(max_seq_len, d_model)

        # Transformer layers in "decoder" style:
        self.transformer_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=n_heads,
                activation=activation_fn,
                batch_first=True  # <--- crucial so we can feed (B, T, E)
            ) for _ in range(n_layers)
        ])

        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x, causal_mask=None, key_padding_mask=None):
        """
        x shape: [batch_size, seq_len]
        causal_mask shape: [seq_len, seq_len] if batch_first=True
        key_padding_mask shape: [batch_size, seq_len]
        """
        batch_size, seq_len = x.shape
        
        # Embedding + (optional) positional encoding
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)  # [1, seq_len]
        # shape => (batch_size, seq_len, d_model)
        x = self.embedding(x) + self.pos_emb(positions)

        # Pass through each TransformerDecoderLayer
        # For a "GPT-like" approach, we typically feed the same x as both "tgt" and "memory"
        out = x
        for layer in self.transformer_layers:
            out = layer(
                tgt=out,
                memory=out,
                tgt_mask=causal_mask,                   # (seq_len, seq_len)
                memory_mask=None,                       # not needed for GPT-like
                tgt_key_padding_mask=key_padding_mask,   # (batch_size, seq_len)
                memory_key_padding_mask=key_padding_mask # same as above in GPT-like
            )

        logits = self.fc_out(out)  # [batch_size, seq_len, vocab_size]
        return logits

##########################################
# 3) Utilities: Generate Causal Mask
##########################################
def generate_causal_mask(seq_len, device):
    """
    Generates a causal mask for self-attention (upper-triangular True).
    If using batch_first=True, shape must be [seq_len, seq_len].
    True entries in the mask indicate positions that should be masked (blocked).
    """
    # shape: (seq_len, seq_len), True means "do not attend".
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
    return mask

##########################################
# 4) Training Function
##########################################
def train_one_epoch(
    model, 
    data_loader, 
    optimizer, 
    criterion, 
    device, 
    pad_token=0
):
    model.train()

    total_loss = 0
    correct = 0
    total_tokens = 0

    for batch in data_loader:
        # batch shape: [batch_size, seq_len]
        batch = batch.to(device)
        batch_size, seq_len = batch.shape

        # We do next-token prediction:
        #   Input  = batch[:, :-1]
        #   Target = batch[:, 1:]
        #   So the output should be shape (batch_size, seq_len-1, vocab_size)
        #   We'll generate a (seq_len-1, seq_len-1) causal mask
        #   Also we shift the key_padding_mask to exclude the last token
        if seq_len < 2:
            # If any sequence < 2 tokens, skip
            continue
        inp = batch[:, :-1]   # shape (B, T-1)
        tgt = batch[:, 1:]    # shape (B, T-1)

        # Prepare mask
        c_mask = generate_causal_mask(seq_len - 1, device=device)  # shape (T-1, T-1)
        # Key padding mask: True where PAD => (B, T-1)
        kp_mask = (inp == pad_token)

        # Forward pass
        logits = model(inp, causal_mask=c_mask, key_padding_mask=kp_mask)
        # logits shape = [batch_size, seq_len-1, vocab_size]
        
        # Flatten for cross-entropy
        logits_2d = logits.reshape(-1, logits.size(-1))   # [B*(T-1), vocab_size]
        tgt_1d    = tgt.reshape(-1)                       # [B*(T-1)]

        # Compute loss
        loss = criterion(logits_2d, tgt_1d)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate metrics
        total_loss += loss.item() * batch_size

        # Accuracy: ignore pad positions
        #  (pred == tgt) & (tgt != pad_token)
        preds = logits_2d.argmax(dim=-1)  # [B*(T-1)]
        valid_mask = (tgt_1d != pad_token)
        correct += (preds[valid_mask] == tgt_1d[valid_mask]).sum().item()
        total_tokens += valid_mask.sum().item()

    avg_loss = total_loss / len(data_loader.dataset)
    avg_acc = correct / (total_tokens + 1e-9)
    return avg_loss, avg_acc


def train_model(
    model,
    data_loader,
    epochs,
    lr=1e-3,
    pad_token=0,
    device='cpu'
):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # CrossEntropy with ignore_index = pad_token
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token)

    for e in range(1, epochs + 1):
        loss, acc = train_one_epoch(model, data_loader, optimizer, criterion, device, pad_token)
        print(f"Epoch [{e}/{epochs}] | Loss: {loss:.4f} | Acc: {acc:.4f}")


##########################################
# 5) Example Usage
##########################################
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Example 'vocab_size' includes pad_token=0 up to some max token
    vocab_size = len(node_edge_vocab)+1
    pad_token = 0

    # Let's create random sequences of length up to 8
    # Real code: you can supply your own sequences from df_node_edge, etc.
    random_data = df_node_edge.to_numpy()

    dataset = SimpleSequenceDataset(random_data, pad_token=pad_token)
    data_loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=lambda b: collate_fn(b, pad_token))

    # Build model
    model = GPTLikeModel(
        vocab_size=vocab_size,
        d_model=32,
        n_heads=4,
        n_layers=2,
        max_seq_len=50,   # can be bigger than any real seq
        pad_token=pad_token,
        activation_fn=nn.GELU()
    ).to(device)

    # Train
    train_model(
        model,
        data_loader=data_loader,
        epochs=1000,
        lr=1e-3,
        pad_token=pad_token,
        device=device
    )


Epoch [1/1000] | Loss: 6.6279 | Acc: 0.1297
Epoch [2/1000] | Loss: 4.7819 | Acc: 0.3365
Epoch [3/1000] | Loss: 3.8475 | Acc: 0.4601
Epoch [4/1000] | Loss: 3.3238 | Acc: 0.5094
Epoch [5/1000] | Loss: 2.9389 | Acc: 0.5423
Epoch [6/1000] | Loss: 2.6169 | Acc: 0.5610
Epoch [7/1000] | Loss: 2.3297 | Acc: 0.5831
Epoch [8/1000] | Loss: 2.0696 | Acc: 0.6002
Epoch [9/1000] | Loss: 1.8296 | Acc: 0.6245
Epoch [10/1000] | Loss: 1.6090 | Acc: 0.6454
Epoch [11/1000] | Loss: 1.3885 | Acc: 0.6769
Epoch [12/1000] | Loss: 1.2143 | Acc: 0.7051
Epoch [13/1000] | Loss: 1.0506 | Acc: 0.7384
Epoch [14/1000] | Loss: 0.9230 | Acc: 0.7692
Epoch [15/1000] | Loss: 0.7838 | Acc: 0.8010
Epoch [16/1000] | Loss: 0.6820 | Acc: 0.8231
Epoch [17/1000] | Loss: 0.6110 | Acc: 0.8377
Epoch [18/1000] | Loss: 0.5286 | Acc: 0.8590
Epoch [19/1000] | Loss: 0.4973 | Acc: 0.8648
Epoch [20/1000] | Loss: 0.4206 | Acc: 0.8868
Epoch [21/1000] | Loss: 0.4075 | Acc: 0.8867
Epoch [22/1000] | Loss: 0.3800 | Acc: 0.8948
Epoch [23/1000] | L

KeyboardInterrupt: 