In [None]:
# note:
# if you're training on a server through SSH, I would recommend using the .py file
# for training, as this lets you reattach to a tmux terminal if you disconnect.

In [None]:
###   CONFIGURATION   ###

In [None]:
from gensim.models import Word2Vec
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from IPython.display import clear_output
import tokenizer
import os
import math
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime, timedelta

In [None]:
# word2vec
# path of the gensim Word2Vec model used to turn tokens into embedding vectors (and back)
model_file = fr"./embedding_models/YOUR_GENSIM_MODEL_NAME.model"
embeddings_model = Word2Vec.load(model_file)

vector_size = embeddings_model.vector_size        # aka embedding dim 

# neural net settings
# path of the saved network file that gets loaded with torch.load
weights_file = fr"./REAN_weights/YOUR_PYTORCH_FILE.pth"

context_length = 128                              # tokens to consider
attn_heads = 8                                    # num attention heads per mechanism (per transformer block)
dropout_prob = 0.0                                # 0.0 ---> everything normal   |   1.0 ---> everything is random

# pytorch
# run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still works on machines without CUDA
run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
storage_device = torch.device("cpu")

In [None]:
###   NEURAL NET ARCHITECTURE   ###

In [None]:
class leaky_tanh_smart(nn.Module):
    """
    trainable "leaky tanh" activation: tanh(squishyness * x) + leakyness * x

    both coefficients are single-element trainable parameters, initialised uniformly
    at random inside the ranges handed to the constructor.
    for more info on leaky tanh and its parameters go to: https://www.desmos.com/calculator/kpzsfbtqww
    """

    def __init__(self, leaky_range=(0, 3), squishy_range=(0, 3)):
        super().__init__()

        leaky_low, leaky_high = leaky_range[0], leaky_range[1]
        squishy_low, squishy_high = squishy_range[0], squishy_range[1]

        # draw the starting values (leakyness first, squishyness second)
        leaky_start = torch.rand(1, dtype=torch.float32) * (leaky_high - leaky_low) + leaky_low
        squishy_start = torch.rand(1, dtype=torch.float32) * (squishy_high - squishy_low) + squishy_low

        # nn.Parameter makes both values trainable
        self.leakyness = nn.Parameter(leaky_start)
        self.squishyness = nn.Parameter(squishy_start)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        applies the leaky tanh activation function element-wise over x.

        Args:
            x (torch.Tensor): tensor over which to apply activation function.

        Returns:
            torch.Tensor: x after the function was applied, same shape as x.
        """

        squashed = torch.tanh(self.squishyness * x)     # bounded tanh part
        leak = x * self.leakyness                       # linear "leak" part

        return squashed + leak

In [None]:
class attention_mech(nn.Module):
    """
    causal multi-head self-attention followed by a layer normalization.
    """

    def __init__(self, vector_size=vector_size, attn_heads=attn_heads):
        super().__init__()

        # self-attention module (fed with sequence-first tensors, see forward)
        self.multihead_attn = nn.MultiheadAttention(embed_dim=vector_size, num_heads=attn_heads)

        # normalizes the attention output
        self.norm = nn.LayerNorm(vector_size)

    def forward(self, x):
        """
        runs masked self-attention over x, so a position cannot attend to later positions.

        Args:
            x (torch.Tensor): tensor of shape (batch_size, sentence_len, embedding_dim)

        Returns:
            tuple: (normalized attention output in the same layout as x, attention weights as returned by nn.MultiheadAttention)
        """

        # the attention module is fed (sentence_len, batch_size, embedding_dim)
        seq_first = x.transpose(0, 1)
        num_positions = seq_first.size(0)

        # True strictly above the diagonal ---> those (future) positions are blocked
        future_mask = torch.ones((num_positions, num_positions), device=seq_first.device).triu(diagonal=1).bool()

        # query, key and value are all the same sequence (self-attention)
        attended, attn_weights = self.multihead_attn(seq_first, seq_first, seq_first, attn_mask=future_mask)

        # normalize, then go back to (batch_size, sentence_len, embedding_dim)
        return self.norm(attended).transpose(0, 1), attn_weights

In [None]:
class positional_encoding(nn.Module):
    """
    adds fixed sinusoidal positional encodings to a batch of embedded sequences.
    has no trainable parameters, the encoding is rebuilt from the input's shape on every call.
    """

    def __init__(self):
        super(positional_encoding, self).__init__()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): tensor of shape (batch_size, context_length, vector_size)

        Returns:
            torch.Tensor: x with the positional encoding added, same shape as x
        """

        batch_size, context_length, vector_size = x.size()

        # positions 0..context_length-1 as a column (shape: [context_length, 1])
        position = torch.arange(0, context_length, dtype=torch.float, device=x.device).unsqueeze(1)

        # one frequency per sin/cos pair (shape: [ceil(vector_size / 2)])
        div_term = torch.exp(torch.arange(0, vector_size, 2).float() * (-math.log(10000.0) / vector_size)).to(x.device)

        # angle of every (position, frequency) combination
        angles = position * div_term

        # positional encoding tensor (shape: [context_length, vector_size])
        pe = torch.zeros(context_length, vector_size, device=x.device)

        # sine goes to the even indices, cosine to the odd indices
        pe[:, 0::2] = torch.sin(angles)
        # with an odd vector_size there is one odd index less than there are frequencies,
        # so the cosine part is cut to fit (for even sizes this slice changes nothing)
        pe[:, 1::2] = torch.cos(angles[:, : vector_size // 2])

        # broadcast over the batch dimension
        return x + pe.unsqueeze(0)

In [None]:
class transformer_block(nn.Module):
    """
    one transformer block: causal self-attention, then a single linear layer with the
    leaky tanh activation. each of the two steps is wrapped in a residual connection
    followed by a layer normalization.
    """

    def __init__(self, vector_size=vector_size):
        super(transformer_block, self).__init__()
        
        self.activ_func = leaky_tanh_smart()
        
        # hand the block's width to the attention too, so a non-default vector_size
        # gives sub-layers that all agree on the embedding dim
        self.attn = attention_mech(vector_size=vector_size)
        
        self.fc = nn.Linear(vector_size, vector_size)
        
        self.norm1 = nn.LayerNorm(vector_size)
        self.norm2 = nn.LayerNorm(vector_size)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor whose last dim is vector_size

        Returns:
            torch.Tensor: tensor of the same shape as x
        """
        
        # attention returns (output, weights), only the output is used here
        x = self.norm1(x + self.attn(x)[0])
        
        # feed-forward step with residual connection
        x = self.norm2(x + self.activ_func(self.fc(x)))
        
        return x

In [None]:
class REAN(nn.Module):
    """
    the full network: positional encoding followed by four transformer blocks.
    """

    def __init__(self):
        super().__init__()

        self.pos_encoding = positional_encoding()

        # the blocks stay separate attributes (tblock1 .. tblock4)
        self.tblock1 = transformer_block()
        self.tblock2 = transformer_block()
        self.tblock3 = transformer_block()
        self.tblock4 = transformer_block()

    def forward(self, segment: torch.Tensor) -> torch.Tensor:
        """
        primarily used for training, where the network has to predict the next token for every token in the sequence

        shapes: (batches, context_len, vector_size) ---> (batches, context_len, vector_size)

        Args:
            segment (torch.Tensor): tensor of size (batches, context_length, vector_size) representing a sequence of tokens (from the tokenizer, embedded with the matching word2vec model)

        Returns:
            torch.Tensor: tensor of shape (batches, context_length, vector_size) (same as segment) representing the sequence predicted by the network, shifted future-way
        """

        hidden = self.pos_encoding(segment)

        # run the blocks one after another, in order
        for block in (self.tblock1, self.tblock2, self.tblock3, self.tblock4):
            hidden = block(hidden)

        return hidden

    def predict(self, segment: torch.Tensor) -> torch.Tensor:
        """
        predicts the embeddings vector of the next token of a given sequence

        shapes: (batches, context_len, vector_size) ---> (batches, vector_size)

        Args:
            segment (torch.Tensor): tensor of size (batches, context_length, vector_size) representing a sequence of tokens (from the tokenizer, embedded with the matching word2vec model)

        Returns:
            torch.Tensor: tensor of shape (batches, vector_size) representing the embeddings vector of the next token to be added into the sequence
        """

        # the output at the last position is the prediction for the upcoming token
        return self.forward(segment)[:, -1, :]

In [None]:
###   BUILD NET & DEPENDENCIES   ###

In [None]:
# the file is used as a complete network object (not a bare state dict), so the classes
# above have to be defined before this cell runs, and full unpickling has to be allowed
# explicitly: newer pytorch versions default to weights_only=True and would refuse the file.
# SECURITY: unpickling can run arbitrary code, only load weight files you trust.
# map_location loads everything onto storage_device first, so a file saved on another
# machine's GPU can still be opened here.
net = torch.load(weights_file, map_location=storage_device, weights_only=False)

# move the network to the device it will run on
net.to(run_device)

# size of all parameters in memory (element count * bytes per element), in GB
print(f"neural net weight: {sum(param.numel() * param.element_size() for param in net.parameters()) / (1024 ** 3):.4f}GB")

In [None]:
###   UTIL FUNCS   ###

In [None]:
def vectorize_segment(segment: list[str], model: Word2Vec=embeddings_model, default: int = 0, used_device=storage_device) -> torch.Tensor:
    """
    encodes all words in a given list to corresponding vectors in given model.
    words not found in the model will be given a vector filled with "default"
    
    Args:
        segment (list[str]): list of strings (tokenized sentence)
        model (Word2Vec): model to use when encoding
        default (int): fill vector with this value if word is not found in model
        used_device: torch device on which the resulting tensor is created
    
    Returns:
        torch.Tensor: float32 tensor with dim1 = len(segment) and dim2 = model.vector_size
    """
    
    # start with every row filled with the default value
    vectorized = np.full((len(segment), model.vector_size), default, dtype=np.float32)
    
    # overwrite the rows of the words the model knows
    for current_word_idx, current_word in enumerate(segment):
        # only add correct values if word is in model, otherwise leave as default
        if current_word in model.wv:
            # the membership check sometimes gives a false positive, so a failed lookup
            # is tolerated and the row stays at the default. only lookup errors are
            # swallowed, so things like KeyboardInterrupt still get through
            try:
                vectorized[current_word_idx] = model.wv.get_vector(current_word, norm=False)
            except LookupError:
                pass
    
    return torch.tensor(vectorized, dtype=torch.float32, device=used_device)

In [None]:
def devectorize_segment(vectorized_segment: torch.Tensor, model: Word2Vec=embeddings_model, not_in_vocab_token="[NIV]", NIV_threshold=0.01) -> list:
    """
    decodes vectors into the nearest word found in model, if no near enough word is found, adds a not in vocab token
    
    Args:
        vectorized_segment (torch.Tensor): 2d tensor, one row per word vector to be decoded
        model (Word2Vec): model to use when decoding
        not_in_vocab_token (str): token used when no word is similar enough
        NIV_threshold (float): the best match must have a similarity above this value to be accepted
    
    Returns:
        list: list of strings (words) whose vectors most closely match those provided
    """
    
    result = []
    
    # make sure vectors are ready to be processed
    # (detach so tensors that still track gradients can be converted too)
    vectors = vectorized_segment.detach().cpu().numpy()
    
    # go over all vectors and find the closest match in model
    for current_vector in vectors:
        # only the single best match is ever used, so only ask for one
        nearest = model.wv.similar_by_vector(current_vector, topn=1)
        
        # accept the match only if there is one and it is similar enough
        if nearest and nearest[0][1] > NIV_threshold:
            result.append(nearest[0][0])
        else:
            result.append(not_in_vocab_token)
    
    return result

In [None]:
def pad_or_truncate(suspected_tensor: torch.Tensor, target_length: int, default: int=0) -> torch.Tensor:
    """
    pads or truncates a given tensor along dim 0 to target_length with "default" as padding.
    padding is added at the front and truncation keeps the last target_length entries,
    so the end of the sequence always stays in place.
    
    Args:
        suspected_tensor (torch.Tensor): tensor to pad or truncate (at least 1 dimension)
        target_length (int): target length of tensor along dim 0, must not be negative
        default (int): value to use for padding
    
    Returns:
        torch.Tensor: tensor whose dim 0 is exactly target_length
    
    Raises:
        ValueError: if target_length is negative
    """
    
    if target_length < 0:
        raise ValueError(f"target_length must be >= 0, got {target_length}")
    
    current_length = len(suspected_tensor)
    
    if current_length < target_length:
        # pad: build the missing rows with the same trailing shape as the input
        # (whatever its rank) and put them in front
        padding = torch.full(
            (target_length - current_length, *suspected_tensor.shape[1:]),
            default,
            dtype=torch.float32,
            device=suspected_tensor.device,
        )
        
        return torch.cat((padding, suspected_tensor))
    
    # truncate: slice from an explicit start index, because [-target_length:] would
    # return the whole tensor when target_length is 0
    return suspected_tensor[current_length - target_length:]

In [None]:
def prepare_segment_for_net(segment: list[str], length: int=context_length, used_device: torch.DeviceObjType=storage_device):
    """
    takes a tokenized sentence and does everything needed to make it a valid input for the net:
    embed the tokens, bring the sequence to a fixed length, add a batch dimension
    
    Args:
        segment (list[str]): a list of tokens (ideally from the tokenizer) of a sentence / text
        length (int): the number of tokens to pad or truncate to. for correct operation: keep at the net's context length
        used_device: device on which the embedded segment is created
    
    Returns:
        torch.Tensor: embedded segment in the correct length, with a leading batch dimension of size 1
    """
    
    # tokens ---> embedding vectors ---> fixed length
    fixed_length = pad_or_truncate(vectorize_segment(segment, used_device=used_device), length)
    
    # the net works on batches, so wrap the single sequence in a batch of one
    return fixed_length.unsqueeze(0)

In [None]:
def predict_word(segment: list[str], net: REAN=net):
    """
    predicts the token that follows a tokenized text
    
    Args:
        segment (list[str]): list of tokens to continue
        net (REAN): network to run the prediction with
    
    Returns:
        list: list holding the predicted token (one entry per row the net returned, which is one here)
    """
    
    # build the input directly on the device the net lives on
    net_device = next(net.parameters()).device
    prepared_segment = prepare_segment_for_net(segment, used_device=net_device)
    
    # inference only: no_grad skips building the autograd graph, and the
    # result comes out without gradient tracking
    with torch.no_grad():
        prediction_vector = net.predict(prepared_segment)
    
    # turn vector back into token
    return devectorize_segment(prediction_vector)

In [None]:
def predict_sequence(segment: list[str], num_tokens: int, net: REAN=net, display_tqdm=False):
    """
    generates tokens one prediction at a time, feeding every prediction back in as input
    
    Args:
        segment (list[str]): list of tokens to start from (is not modified)
        num_tokens (int): number of prediction steps to run
        net (REAN): network to predict with
        display_tqdm (bool): show a progress bar while generating
    
    Returns:
        list: only the newly generated tokens, without the starting segment
    """
    
    prompt_length = len(segment)
    
    # work on a copy so the caller's list stays untouched
    running_tokens = segment.copy()
    
    progress = tqdm(range(num_tokens), disable=not display_tqdm)
    for _step in progress:
        # predict_word returns a list, so its content gets appended to the sequence
        running_tokens.extend(predict_word(running_tokens, net=net))
    
    # cut the prompt off again
    return running_tokens[prompt_length:]

In [None]:
###   EVAL   ###

In [None]:
# put the network into evaluation mode before generating
net.eval()
# clear this cell's output area
clear_output()

In [None]:
# the prompt is the user's text wrapped between the "human: " and " network: " markers
prompt = "human: " + "write a list of the top 10 sports cars" + " network: "
tokens_to_predict = 128
display_tqdm = True

# text ---> tokens ---> predicted tokens ---> text
prompt_tokens = tokenizer.tokenize_segment(prompt)
generated_tokens = predict_sequence(prompt_tokens, tokens_to_predict, display_tqdm=display_tqdm)

print(tokenizer.detokenize_segment(generated_tokens))