In [56]:
import torch
import torch.nn as nn
import altair as alt
import pandas as pd
import numpy as np
import warnings
import tokenizer
import tokenizers
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, random_split
warnings.filterwarnings("ignore")

In [None]:
# Pick the compute device: a CUDA GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)
# Release unoccupied memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

In [106]:
from train import DataSetLoader, Training
from dataset import TranslationDataset
from config import get_config
from transformer import Transformer
from transformer import TransformerBuilder

In [107]:
def load_model(config):
    """Build the transformer described by ``config`` and restore the newest checkpoint, if any.

    Steps performed:
      1. Make sure ``TRAIN_DIRECTORY/datasource`` and the checkpoint directory
         below it exist.
      2. Seed PyTorch with ``config['SEED']``.
      3. Load the datasets and tokenizers through ``DataSetLoader.get_dataset``
         (the tokenizers supply the vocabulary sizes for the model).
      4. Build the model with ``TransformerBuilder.build_transformer`` and move
         it to the selected device.
      5. If the checkpoint directory contains files, load the one that sorts
         last by path and restore the model weights from it.

    Args:
        config: Subscriptable configuration (as returned by ``get_config``).
            Keys read here: MAX_SUPPORTED_SENTENCE_TOKEN_LENGTH, SEED,
            TRAIN_DIRECTORY, datasource, CHECKPOINT_DIRECTORY, MODEL_DIMENSIONS,
            NUM_ENCODER_BLOCKS, NUM_HEADS, DROPOUT.

    Returns:
        The transformer model on the selected device, with checkpoint weights
        loaded when a checkpoint was found.
    """
    # Banner text kept as-is; this function was adapted from the training entry point.
    print("=== SUMMIT Training Process ===\n")

    max_tokens = int(config['MAX_SUPPORTED_SENTENCE_TOKEN_LENGTH'])
    seed = int(config['SEED'])

    # folders
    dataset_folder = Path(config["TRAIN_DIRECTORY"]) / Path(config["datasource"])
    dataset_folder.mkdir(parents=True, exist_ok=True)
    print(f"Base directory for model-related data: {str(dataset_folder)}")
    checkpoint_folder = dataset_folder / Path(config["CHECKPOINT_DIRECTORY"])
    checkpoint_folder.mkdir(parents=True, exist_ok=True)
    print(f"Checkpoint directory: {str(checkpoint_folder)}")

    # get device
    print("Checking devices...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device for training: {device}")

    # fix seed
    print(f"Random seed: {seed}")
    torch.manual_seed(seed)

    # get dataset
    print("Loading dataset...")
    train_ds, validation_ds, test_ds, tokenizer_source, tokenizer_target = DataSetLoader.get_dataset(config)

    print(f"Maximum token length found: {max_tokens}")

    # data points printed are the amount of sentence pairs
    print(f"Train dataset size: {len(train_ds)}")
    print(f"Validation dataset size: {len(validation_ds)}")
    print(f"Test dataset size: {len(test_ds)}\n")

    # Print one example entry; clamp the index so small datasets do not raise.
    if len(train_ds) > 0:
        print(f"Example data entry: {train_ds[min(621, len(train_ds) - 1)]}\n")

    print("Loading model")
    # TODO: make use of different configurations ?????
    model = TransformerBuilder.build_transformer(
        tokenizer_source.get_vocab_size(),
        tokenizer_target.get_vocab_size(),
        max_tokens,
        max_tokens,
        False,
        True,
        config["MODEL_DIMENSIONS"],
        config["NUM_ENCODER_BLOCKS"],
        config["NUM_HEADS"],
        config["DROPOUT"],
    ).to(device)

    # Checkpoints are ordered by path; the one that sorts last is treated as the latest.
    # NOTE(review): this only picks the newest file if checkpoint names sort
    # chronologically (e.g. zero-padded epoch numbers) — confirm against the trainer.
    old_train_files = sorted(checkpoint_folder.glob('*'), reverse=True)
    if old_train_files:
        old_train_filename = old_train_files[0]
        print(f"Found latest model at: {old_train_filename}")

        # map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
        state = torch.load(old_train_filename, map_location=device)
        model.load_state_dict(state['model_states'])
        epoch = state['epoch']

        print(f"Successfully loaded existing state, at epoch {epoch}")

    return model

In [None]:
# Load the run configuration, the datasets/tokenizers and the (checkpointed) model.
config = get_config()
# NOTE(review): load_model() unpacks the first three values returned by
# get_dataset() as datasets rather than DataLoaders; the names are kept here
# because later cells refer to them.
train_dataloader, validation_dataloader, test_dataloader, vocab_source, vocab_target = DataSetLoader.get_dataset(config)
# The notebook only runs inference: eval() turns dropout off so that decoding and
# the attention scores recorded during it are deterministic. eval() returns the module.
model = load_model(config).eval()

In [108]:
# Quick look at one validation item: fetch it, move the inputs to the device,
# and print the source-side tokens together with the tensor shape.
batch = next(iter(validation_dataloader))
encoder_input = batch["to_encoder"].to(device)  # token ids fed to the encoder
decoder_input = batch["to_decoder"].to(device)  # token ids fed to the decoder

# Flatten the ids on the CPU so each one can be mapped back to its token string.
encoder_ids = encoder_input.cpu().numpy().flatten()
encoder_input_tokens = [vocab_source.id_to_token(token_id) for token_id in encoder_ids]
print(encoder_input_tokens)

print(encoder_input.size())


def load_batch():
    """Fetch one validation item, run it through the model, and return it with its tokens.

    Reads the notebook globals ``validation_dataloader``, ``device``,
    ``vocab_source``, ``vocab_target``, ``model`` and ``config``.

    A fresh iterator is created on every call, so for a sequential (unshuffled)
    source this always yields the first item.

    Returns:
        Tuple ``(batch, encoder_input_tokens, decoder_input_tokens)``: the raw
        item plus the encoder-side and decoder-side token strings.
    """
    batch = next(iter(validation_dataloader))

    # Move every tensor the model needs onto the compute device.
    source_ids = batch["to_encoder"].to(device)
    target_ids = batch["to_decoder"].to(device)
    source_mask = batch["mask_encoder"].to(device)
    target_mask = batch["mask_decoder"].to(device)

    # Map the flattened ids back to token strings for axis labels.
    encoder_input_tokens = [
        vocab_source.id_to_token(token_id)
        for token_id in source_ids.cpu().numpy().flatten()
    ]
    decoder_input_tokens = [
        vocab_target.id_to_token(token_id)
        for token_id in target_ids.cpu().numpy().flatten()
    ]

    # Add a leading batch dimension (flagged in the original as needing rework).
    source_ids = source_ids.unsqueeze(0)
    assert source_ids.size(0) == 1, "Batch size must be 1 for validation"

    # The decoded sequence itself is not used.
    # NOTE(review): presumably this forward pass is what fills the
    # `attention_scores` attributes read by get_attn_map — confirm in transformer.py.
    model_out = decode_model(
        model,
        source_ids,
        source_mask,
        vocab_source,
        vocab_target,
        config["MAX_SUPPORTED_SENTENCE_TOKEN_LENGTH"],
        device,
    )

    return batch, encoder_input_tokens, decoder_input_tokens

# Upper bound on the number of tokens per sentence, read from the run configuration.
# decode_model reads this module-level name as its decoding step limit.
max_tokens = config['MAX_SUPPORTED_SENTENCE_TOKEN_LENGTH']

def decode_model(model, to_encoder, mask_encoder, vocab_source, vocab_target, config, device):
    """Greedily decode one sentence with ``model``.

    Starting from the ``<S>`` token, the decoder is run repeatedly; at each step
    the highest-scoring token of the last position is appended to the decoder
    input. Decoding stops when ``<E>`` is predicted or the step limit is reached.

    Args:
        model: Object exposing ``encode``, ``decode`` and ``project``.
        to_encoder: Encoder input token ids; also fixes the dtype of the
            decoder input built here.
        mask_encoder: Encoder mask, passed through to ``encode`` and ``decode``.
        vocab_source: Source tokenizer (unused; kept for interface compatibility).
        vocab_target: Target tokenizer; supplies the ids of ``<S>`` and ``<E>``.
        config: Either the maximum number of decoding steps itself (as the
            notebook passes it) or a configuration mapping containing
            ``MAX_SUPPORTED_SENTENCE_TOKEN_LENGTH``.
        device: Device on which the decoder input and masks are created.

    Returns:
        1-D tensor of decoded token ids, beginning with the ``<S>`` id. The
        ``<E>`` id is not included.
    """
    # Resolve the step limit from the argument instead of a notebook global.
    try:
        token_limit = int(config["MAX_SUPPORTED_SENTENCE_TOKEN_LENGTH"])
    except (TypeError, IndexError):
        # `config` is already the limit (plain number or numeric string).
        token_limit = int(config)

    # Special-token ids come from the target tokenizer handed in by the caller.
    s_token = vocab_target.token_to_id("<S>")
    e_token = vocab_target.token_to_id("<E>")

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        encoded = model.encode(to_encoder, mask_encoder)

        # Decoder input starts as a (1, 1) tensor holding only the start token.
        to_decoder = torch.full((1, 1), s_token, dtype=to_encoder.dtype, device=device)

        for _ in range(token_limit):
            # Triangular mask sized to the current decoder input; grows by one each step.
            mask_decoder = TranslationDataset.triangular_mask(to_decoder.size(1)).type_as(mask_encoder).to(device)

            output = model.decode(encoded, mask_encoder, to_decoder, mask_decoder)

            # Project only the last position to vocabulary scores and take the argmax.
            p = model.project(output[:, -1])
            _, most_likely = torch.max(p, dim=1)
            next_id = most_likely.item()

            if next_id == e_token:
                break  # we reached the end

            # Append the predicted token along the sequence dimension (dim=1).
            next_token = torch.full((1, 1), next_id, dtype=to_encoder.dtype, device=device)
            to_decoder = torch.cat([to_decoder, next_token], dim=1)

    return to_decoder.squeeze(0)

['<S>', ',,', 'Wann', 'mögen', 'Sie', 'uns', 'vermißt', 'haben', ',', 'Tom', '?"', '<E>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', '<P>', 

In [122]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten a 2-D attention matrix into a long-format pandas DataFrame.

    One record is produced per cell ``(row, column)`` with ``row < max_row`` and
    ``column < max_col``, in row-major order.

    Args:
        m: 2-D matrix supporting ``m.shape`` and ``m[row, column]``.
        max_row: Only rows with an index below this value are included.
        max_col: Only columns with an index below this value are included.
        row_tokens: Token strings labelling the rows.
        col_tokens: Token strings labelling the columns.

    Returns:
        DataFrame with columns ``row``, ``column``, ``value``, ``row_token`` and
        ``col_token``. The token columns hold labels of the form ``"007 token"``
        (zero-padded position plus token), or ``"<blank>"`` as the token when the
        position lies beyond the supplied token list.
    """

    def _label(position, tokens):
        # Zero-padded three-digit position followed by the token (or a placeholder).
        token = tokens[position] if position < len(tokens) else "<blank>"
        return "%.3d %s" % (position, token)

    row_count = min(m.shape[0], max_row)
    col_count = min(m.shape[1], max_col)

    records = []
    for r in range(row_count):
        row_label = _label(r, row_tokens)
        for c in range(col_count):
            records.append((r, c, float(m[r, c]), row_label, _label(c, col_tokens)))

    return pd.DataFrame(
        records,
        columns=["row", "column", "value", "row_token", "col_token"],
    )

def get_attn_map(attn_type: str, layer: int, head: int):
    """Return the stored attention scores of one head in one layer.

    Reads the notebook-global ``model`` and looks up the ``attention_scores``
    attribute of the requested attention layer.

    Args:
        attn_type: ``"encoder"`` (encoder self-attention), ``"decoder"``
            (decoder self-attention) or ``"encoder-decoder"`` (decoder
            cross-attention).
        layer: Index of the encoder/decoder block to read from.
        head: Index of the attention head.

    Returns:
        ``attention_scores[0, head].data`` — the scores of the first batch
        entry for the given head. Per the original note, the scores are laid
        out as (batch_size, num_heads, query_len, key_len).

    Raises:
        ValueError: If ``attn_type`` is not one of the three supported values.
    """
    # Blocks are registered under their index as a string key.
    block_key = str(layer)
    if attn_type == "encoder":
        attn = model.encoder.encoder_module_list._modules[block_key].self_attention_layer.attention_scores
    elif attn_type == "decoder":
        attn = model.decoder.decoder_module_list._modules[block_key].self_attention_layer.attention_scores
    elif attn_type == "encoder-decoder":
        attn = model.decoder.decoder_module_list._modules[block_key].cross_attention_layer.attention_scores
    else:
        raise ValueError(
            f"Unknown attn_type {attn_type!r}; expected 'encoder', 'decoder' or 'encoder-decoder'"
        )
    return attn[0, head].data

def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):
    """Build an interactive Altair heatmap for one attention head.

    Args:
        attn_type: Attention kind forwarded to ``get_attn_map``.
        layer: Layer index forwarded to ``get_attn_map``; also shown in the title.
        head: Head index forwarded to ``get_attn_map``; also shown in the title.
        row_tokens: Token strings used to label the rows (y axis).
        col_tokens: Token strings used to label the columns (x axis).
        max_sentence_len: Upper bound applied to both rows and columns.

    Returns:
        A 400x400 interactive rect chart, coloured by attention value, with a
        tooltip showing position, value and token labels.
    """
    # Long-format table of the attention values, clipped to max_sentence_len in both directions.
    scores = get_attn_map(attn_type, layer, head)
    frame = mtx2df(scores, max_sentence_len, max_sentence_len, row_tokens, col_tokens)

    # One rectangle per (row_token, col_token) pair; colour encodes the value.
    heatmap = alt.Chart(data=frame).mark_rect().encode(
        x=alt.X("col_token", axis=alt.Axis(title="")),
        y=alt.Y("row_token", axis=alt.Axis(title="")),
        color="value",
        tooltip=["row", "column", "value", "row_token", "col_token"],
    )

    # Fixed size plus a title naming the layer and head.
    titled = heatmap.properties(height=400, width=400, title=f"Layer {layer} Head {head}")

    # interactive() enables zooming and panning.
    return titled.interactive()

def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):
    """Arrange the attention heatmaps of the given layers and heads in a grid.

    Each layer becomes one horizontal row of charts (one chart per head); the
    rows are stacked vertically in the order the layers are given.

    Args:
        attn_type: Attention kind forwarded to ``attn_map``.
        layers: Layer indices to plot, one grid row each.
        heads: Head indices to plot, one chart per head in every row.
        row_tokens: Token strings labelling the chart rows.
        col_tokens: Token strings labelling the chart columns.
        max_sentence_len: Upper bound on rows/columns shown in each chart.

    Returns:
        The vertically concatenated Altair chart.
    """
    layer_rows = [
        alt.hconcat(
            *[
                attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len)
                for head in heads
            ]
        )
        for layer in layers
    ]
    return alt.vconcat(*layer_rows)

In [115]:
# Run one validation item through the model and keep its token labels.
batch, encoder_input_tokens, decoder_input_tokens = load_batch()
print(batch.keys())  # Shows all available keys in the batch

# Double-quoted f-strings: reusing the outer quote character inside the
# replacement field is a SyntaxError before Python 3.12.
print(f"Source: {batch['text_source']}")
print(f"Target: {batch['text_target']}")

# The first padding token marks the end of the real sentence; a sentence that
# fills the whole window has no padding, so fall back to the full length.
if "<P>" in encoder_input_tokens:
    sentence_len = encoder_input_tokens.index("<P>")
else:
    sentence_len = len(encoder_input_tokens)

dict_keys(['to_encoder', 'to_decoder', 'label', 'text_source', 'text_target', 'mask_encoder', 'mask_decoder'])
Source: ,,Wann mögen Sie uns vermißt haben, Tom?"
Target: "When would they miss us, Tom?"


In [123]:
# Layers and heads to visualise in every grid below.
layers = list(range(3))
heads = list(range(8))

# Encoder self-attention: source tokens on both axes, at most 20 positions shown.
get_all_attention_maps(
    attn_type="encoder",
    layers=layers,
    heads=heads,
    row_tokens=encoder_input_tokens,
    col_tokens=encoder_input_tokens,
    max_sentence_len=min(20, sentence_len),
)

In [117]:
# Decoder self-attention: target tokens on both axes, at most 20 positions shown.
get_all_attention_maps(
    attn_type="decoder",
    layers=layers,
    heads=heads,
    row_tokens=decoder_input_tokens,
    col_tokens=decoder_input_tokens,
    max_sentence_len=min(20, sentence_len),
)

In [124]:
# Cross-Attention
# The score matrix is indexed (query, key). Rows are labelled with the decoder
# tokens and columns with the encoder tokens.
# NOTE(review): this assumes the cross-attention queries come from the decoder and
# the keys from the encoder output — confirm in transformer.py.
get_all_attention_maps("encoder-decoder", layers, heads, decoder_input_tokens, encoder_input_tokens, min(20, sentence_len))

In [None]:
"""
Do I need this?
config = get_config()
file_path = str(Path('.').parent.resolve() / config["TRAIN_DIRECTORY"] / config["datasource"] / config["CHECKPOINT_DIRECTORY"] / config["model_name"])
print(file_path)

#Load pretrained weights
state = torch.load(file_path)
"""