In [None]:
import re
import io
import os
import sys
import math
import requests

import numpy as np

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.datasets as datasets

from torchvision.utils import save_image
import torchvision.transforms.functional as TF

from pathlib import Path
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, random_split

In [None]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook
# still runs end-to-end on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's global RNG so weight initialisation and the dataset split
# further down are repeatable.
seed = 2146
torch.manual_seed(seed)

In [None]:
# Width of one input token in bits; the token vocabulary therefore has 2**bits entries.
bits = 8
bits_vocab_len = 1 << bits

print(f"bits vocab len: {bits_vocab_len}")

## GPT2

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Checkpoint name on the Hugging Face hub.
# Larger alternatives: gpt2-medium / gpt2-large / gpt2-xl
llm = "gpt2"

# Fetch the pre-trained tokenizer and language model for that checkpoint.
llm_tokenizer = GPT2Tokenizer.from_pretrained(llm)
model = GPT2LMHeadModel.from_pretrained(llm)

In [None]:
# Vocabulary size and hidden width, as reported by the model configuration.
llm_vocab_len = model.config.vocab_size
llm_feature_dim = model.config.hidden_size

# Weight matrix of the LM head, used further down as the token embedding table
# (the input-embedding alternative would be model.transformer.wte.weight).
embeddings = model.lm_head.weight

model.to(device)

In [None]:
# The language model is kept fixed: disable gradient tracking on each of its parameters.
for llm_weight in model.parameters():
    llm_weight.requires_grad = False

In [None]:
# Summarise the dimensions of the loaded language model.
for label, value in (
    ("gpt2 feature dim length:", llm_feature_dim),
    ("gpt2 vocabulary length:", llm_vocab_len),
    ("gpt2 embedding shape:", embeddings.shape),
):
    print(label, value)

## Mapper Network

In [None]:
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd

In [None]:
class TokenMapper(nn.Module):
    """Bias-free linear projection from one-hot token vectors to feature vectors.

    Args:
        input_dim: Size of the one-hot input (number of distinct tokens).
        output_dim: Size of the produced feature vector.
        device: Device the projection layer is moved to after construction.
    """

    def __init__(self, input_dim, output_dim, device="cpu"):
        super().__init__()
        # Module.to() returns the module itself, so the layer can be built and
        # placed on its device in a single expression.
        self.mapper = nn.Linear(input_dim, output_dim, bias=False).to(device)

    def forward(self, one_hot_token):
        """Apply the projection to `one_hot_token` and return the result."""
        projected = self.mapper(one_hot_token)
        return projected

In [None]:
# Two independent projections with the same shape (bit-token vocabulary -> LLM feature size).
# `mapper` embeds input tokens; the weight matrix of `reverseMapper` is used
# directly in the training loop to turn LLM features back into token logits.
mapper = TokenMapper(
    input_dim=bits_vocab_len,
    output_dim=llm_feature_dim,
    device=device,
)
reverseMapper = TokenMapper(
    input_dim=bits_vocab_len,
    output_dim=llm_feature_dim,
    device=device,
)

## Ground Truth

In [None]:
def generate_next_token_predictions(token_sequences):
    """Run the module-level `model` on token ids and return its last hidden state.

    Args:
        token_sequences: Token ids, passed to the model as `input_ids`.

    Returns:
        The final entry of the model's `hidden_states` output.
    """
    llm_outputs = model(input_ids=token_sequences, output_hidden_states=True)
    last_hidden_state = llm_outputs.hidden_states[-1]
    return last_hidden_state

In [None]:
def generate_next_token_predictions_withfv(token_fv):
    """Run the module-level `model` on feature vectors and return its last hidden state.

    Unlike `generate_next_token_predictions`, the input is handed to the model
    through `inputs_embeds` instead of `input_ids`.

    Args:
        token_fv: Feature vectors passed to the model as `inputs_embeds`.

    Returns:
        The final entry of the model's `hidden_states` output.
    """
    llm_outputs = model(inputs_embeds=token_fv, output_hidden_states=True)
    last_hidden_state = llm_outputs.hidden_states[-1]
    return last_hidden_state

In [None]:
def translate(batch_feature_vectors, embeddings, temperature=1.0):
    """Softly snap feature vectors onto the rows of an embedding table.

    Every feature vector is scored against every embedding row with a dot
    product; a temperature-scaled softmax over those scores then weights the
    rows to produce a mixed embedding. A small temperature pushes the result
    towards the single best-scoring row.

    Args:
        batch_feature_vectors: Tensor whose last dimension is the embedding
            size, e.g. (batch, seq, dim). Any number of leading dimensions is
            accepted.
        embeddings: Embedding table of shape (vocab, dim).
        temperature: Divisor applied to the scores before the softmax.

    Returns:
        A tuple `(mixed, similarities, closest_tokens)`:
            mixed: Softmax-weighted combination of embedding rows, shape (..., dim).
            similarities: Raw dot-product scores, shape (..., vocab). These are
                unnormalised inner products, not cosine similarities.
            closest_tokens: Index of the largest softmax weight, shape (...).
    """
    similarities = torch.matmul(batch_feature_vectors, embeddings.T)

    # Reduce over the last (vocabulary) axis so inputs of any rank are handled.
    weights = torch.softmax(similarities / temperature, dim=-1)
    closest_tokens = torch.argmax(weights, dim=-1)
    mixed = torch.matmul(weights, embeddings)

    return mixed, similarities, closest_tokens

## Dataset

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torch.utils.data.sampler import BatchSampler, SequentialSampler

import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [None]:
# Data pipeline settings
bits, batch_size, seq_len = (
    8,    # bits per token
    1,    # files per DataLoader batch
    256,  # tokens per chunk
)

In [None]:
import os
import torch
from torch.utils.data import Dataset

class BinaryDataset(Dataset):
    """Dataset yielding the raw bytes of every matching file in a set of directories."""

    def __init__(self, directories, filetypes=".jpg",transform=None, seq_len=256):
        """
        Args:
            directories (str or iterable of str): Directory, or directories, to scan
                (non-recursively) for files.
            filetypes (str or iterable of str): File-name suffix, or suffixes, a file
                must end with to be included.
            transform (callable, optional): Optional transform applied to the bytes
                of a sample.
            seq_len (int): Not used by this class; kept so existing callers that
                pass it keep working.
        """
        # A bare string would otherwise be iterated character by character, e.g.
        # the default ".jpg" would match any name ending in ".", "j", "p" or "g".
        if isinstance(filetypes, str):
            filetypes = (filetypes,)
        suffixes = tuple(filetypes)

        if isinstance(directories, (str, os.PathLike)):
            directories = [directories]
        else:
            directories = list(directories)

        self.directories = directories
        self.transform = transform

        # Full path of every selected file, grouped by directory.
        self.filenames = []
        for directory in directories:
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> file mapping (and any seeded split on top of it) stable.
            self.filenames.extend(
                os.path.join(directory, entry)
                for entry in sorted(os.listdir(directory))
                if entry.endswith(suffixes)
            )

    def __len__(self):
        """Return the number of files selected at construction time."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return `(sample, path)` where `sample` is the (optionally transformed) file content."""
        path = self.filenames[idx]

        # Read the whole file as bytes
        with open(path, 'rb') as file:
            b = file.read()

        # If a transform is specified, apply it
        if self.transform:
            b = self.transform(b)

        return b, path

class ToBinaryString:
    """Transform raw bytes into a 2-D integer tensor of fixed-width tokens.

    The bytes are viewed as one long bit string, cut into groups of `bits` bits
    (a shorter trailing group is parsed as-is), and each group becomes one
    integer token. The token sequence is zero-padded to a multiple of
    `segment_length` and returned with shape (num_chunks, segment_length).
    """

    def __init__(self, bits=8, segment_length=256):
        self.bits = bits  # Number of bits to group together into an integer
        self.chunk_size = segment_length  # Number of integers per segment

    def __call__(self, image_bytes):
        """Return the int64 token tensor for `image_bytes`; empty input gives shape (0, segment_length)."""
        if self.bits == 8 and isinstance(image_bytes, (bytes, bytearray)):
            # With 8-bit groups each token is simply the byte value, so the
            # intermediate bit string is not needed.
            integers = list(image_bytes)
        else:
            binary_string = ''.join(f'{byte:08b}' for byte in image_bytes)
            integers = [
                int(binary_string[start:start + self.bits], 2)
                for start in range(0, len(binary_string), self.bits)
            ]

        # Explicit dtype: an empty list would otherwise become a float tensor.
        tensor = torch.tensor(integers, dtype=torch.long)

        # Number of zeros needed to reach the next multiple of chunk_size.
        padding_size = (-tensor.size(0)) % self.chunk_size
        if padding_size > 0:
            tensor = torch.cat([tensor, torch.zeros(padding_size, dtype=tensor.dtype)])

        # Spell out the row count instead of using -1, which is ambiguous (and
        # raises) for a zero-element tensor.
        return tensor.view(tensor.size(0) // self.chunk_size, self.chunk_size)

In [None]:
def pad_collate(batch, seq_len = 256):
    """Collate `(chunks, filename)` samples whose tensors have different row counts.

    Every 2-D `chunks` tensor is extended with all-zero rows until it has as
    many rows as the longest one in the batch, then the tensors are stacked.

    Args:
        batch: Sequence of `(chunks, filename)` pairs, where `chunks` is a 2-D
            tensor of shape (num_chunks, width).
        seq_len: Kept for backward compatibility and ignored; the row width is
            read from each tensor, so padding always matches the data.

    Returns:
        A tuple `(stacked, filenames)`: `stacked` has shape
        (len(batch), max_chunks, width) and `filenames` is a tuple of the
        corresponding file names.
    """
    # Largest number of rows found in this batch
    max_chunks = max(chunks.size(0) for chunks, _ in batch)

    padded, filenames = [], []
    for chunks, filename in batch:
        missing = max_chunks - chunks.size(0)
        if missing > 0:
            filler = torch.zeros(
                (missing, chunks.size(1)), dtype=chunks.dtype, device=chunks.device
            )
            chunks = torch.cat([chunks, filler])
        padded.append(chunks)
        filenames.append(filename)

    # Stack along a new leading dimension; file names are returned separately
    return torch.stack(padded), tuple(filenames)

In [None]:
from torch.utils.data import DataLoader, random_split

# Source directories (text, image and audio files) and the suffixes to pick up from them
dirs = [
    '../data/enwik9/ascii/',
    '../data/imagenet/train/png_small/',
    '../data/librispeech/train/wav/',
]
filetypes = ['.txt', '.png', '.wav']

# Every file is read as bytes and cut into rows of `seq_len` tokens
dataset = BinaryDataset(
    directories=dirs,
    filetypes=filetypes,
    transform=ToBinaryString(segment_length=seq_len),
)

# 80% of the files go to training, the remainder to validation
train_size = int(len(dataset)*0.8)
val_size = len(dataset) - train_size
train_dataset, validation_dataset = random_split(dataset, [train_size, val_size])

# Both loaders share the same batching, shuffling and collation settings
loader_kwargs = dict(batch_size=batch_size, shuffle=True, collate_fn=pad_collate)
dataloader = DataLoader(train_dataset, **loader_kwargs)
testloader = DataLoader(validation_dataset, **loader_kwargs)


## Train Model

In [None]:
# Hyper Parameters
epochs = 1            # passes over the training DataLoader
learning_rate = 5e-6  # learning rate handed to the Adam optimizer
temperature = 0.001   # softmax temperature passed to translate()
gamma = 0.1           # recorded in the experiment name
alpha = 1             # not referenced elsewhere in the visible notebook

In [None]:
# Components of the run identifier. The resulting path is used for both the
# TensorBoard log directory and the checkpoint directory.
experiment = "test"
algo = "base"
exp_type = "hybrid"
name = f"{bits}bits"

# NOTE(review): the name used to contain a `promptlen={prompt_len}` segment, but
# `prompt_len` is not defined anywhere in this notebook, so the cell raised a
# NameError on a fresh kernel. The segment is omitted; re-add it together with
# a `prompt_len` definition if a prompt length is introduced.
experiment_name = f"{exp_type}/{algo}/{experiment}/{name}/{llm}/lr={learning_rate}/gamma={gamma}/temp={temperature}/seq_len={seq_len}"
experiment_name

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files for this run go under ../runs_test/<experiment_name>
log_dir = f'../runs_test/{experiment_name}'
writer = SummaryWriter(log_dir=log_dir)

In [None]:
# Cross-entropy between predicted token logits and the true next tokens,
# averaged over all positions (the default reduction, spelled out).
criterion = nn.CrossEntropyLoss(reduction="mean")

In [None]:
# Only the two mapper networks are optimised.
trainable_params = [*mapper.parameters(), *reverseMapper.parameters()]
optimizer = optim.Adam(trainable_params, lr=learning_rate)

In [None]:
# NOTE(review): the loop below calls `scheduler.step()` once per epoch, but no
# scheduler was ever created in this notebook, so a fresh-kernel run crashed
# with a NameError after the first epoch (before `writer.close()`). A per-epoch
# exponential decay by `gamma` is assumed here -- confirm the intended schedule.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

global_step = 0
for epoch in range(epochs):
    mapper.train()
    reverseMapper.train()
    for i, dd in enumerate(dataloader):

        optimizer.zero_grad()
        data = dd[0]  # dd is (stacked chunk tensor, filenames)

        # Flatten all chunks of all files in the batch into rows of seq_len tokens
        ground_truth_tokens = data.reshape(-1, seq_len).to(device)
        one_hot_tokens = F.one_hot(ground_truth_tokens, num_classes=bits_vocab_len).float()

        # Logits are to be compared with the next ground truth tokens
        ground_truth_tokens = ground_truth_tokens[:,1:]
        inputs_feature_vector = mapper(one_hot_tokens)

        # Snap the mapped vectors onto the (detached) LLM embedding table
        translated_feature_vector, translated_logits, translated_text_tokens = translate(inputs_feature_vector, embeddings.detach(), temperature=temperature)

        # Calculate Representation of Last Layer in LLM
        final_layer_fv = generate_next_token_predictions_withfv(translated_feature_vector)

        # Project LLM features onto the bit-token vocabulary with the reverse mapper's weights
        logits = torch.matmul(final_layer_fv, reverseMapper.mapper.weight)
        # The last position has no next token to be compared with
        logits = logits[:,:-1]
        logits_ = logits.reshape(-1, bits_vocab_len)
        ground_truth_tokens = ground_truth_tokens.reshape(-1)
        ce_loss = criterion(logits_, ground_truth_tokens)

        loss_value = ce_loss.item()
        writer.add_scalar("training/cross_entropy", loss_value, global_step)
        ce_loss.backward()
        optimizer.step()
        torch.cuda.empty_cache()

        if global_step%100==0:
            print(f"Epoch {epoch+1}, Batch {global_step}, CE Loss: {loss_value}")
        global_step+=1

    scheduler.step()
    print(f"Epoch {epoch+1}/{epochs} completed.")
writer.close()

In [None]:
# Create the checkpoint directory for this run, then write both mapper state dicts into it.
Path(f"../models/{experiment_name}").mkdir(parents=True, exist_ok=True)
for trained_module, checkpoint_path in (
    (mapper, f"../models/{experiment_name}/mapper.pt"),
    (reverseMapper, f"../models/{experiment_name}/reversemapper.pt"),
):
    torch.save(trained_module.state_dict(), checkpoint_path)

## Evaluation

In [None]:
# Load Models (Optional)
# Restores both mappers from the checkpoints written by the save cell above.

mapper_state = torch.load(f"../models/{experiment_name}/mapper.pt")
mapper.load_state_dict(mapper_state)

reverse_mapper_state = torch.load(f"../models/{experiment_name}/reversemapper.pt")
reverseMapper.load_state_dict(reverse_mapper_state)

In [None]:
global_step = 0
total = []  # per-batch cross-entropy values
mapper.eval()
reverseMapper.eval()

# Evaluation never calls backward(), so disable autograd: this avoids building
# a graph through the LLM for every batch without changing the loss values.
with torch.no_grad():
    for i, dd in enumerate(testloader):

        data = dd[0]  # dd is (stacked chunk tensor, filenames)

        # Flatten all chunks of all files in the batch into rows of seq_len tokens
        ground_truth_tokens = data.reshape(-1, seq_len).to(device)
        one_hot_tokens = F.one_hot(ground_truth_tokens, num_classes=bits_vocab_len).float()

        # Logits are to be compared with the next ground truth tokens
        ground_truth_tokens = ground_truth_tokens[:,1:]
        inputs_feature_vector = mapper(one_hot_tokens)

        # Snap the mapped vectors onto the (detached) LLM embedding table
        translated_feature_vector, translated_logits, translated_text_tokens = translate(inputs_feature_vector, embeddings.detach(), temperature=temperature)

        # Calculate Representation of Last Layer in LLM
        final_layer_fv = generate_next_token_predictions_withfv(translated_feature_vector)

        # Project LLM features onto the bit-token vocabulary with the reverse mapper's weights
        logits = torch.matmul(final_layer_fv, reverseMapper.mapper.weight)
        # The last position has no next token to be compared with
        logits = logits[:,:-1]
        logits_ = logits.reshape(-1, bits_vocab_len)
        ground_truth_tokens = ground_truth_tokens.reshape(-1)
        ce_loss = criterion(logits_, ground_truth_tokens)

        loss_value = ce_loss.item()
        total.append(loss_value)
        if global_step%100==0:
            print(f" Batch {global_step}, CE Loss: {loss_value}")
        global_step+=1

        # Stop once 100 batches have been evaluated
        if global_step % 100 == 0:
            break

        torch.cuda.empty_cache()

testing = np.array(total)

# Mean and standard deviation of the per-batch cross-entropy
print(testing.mean())
print(testing.std())