### Embedding Extraction

# In [None]:
import torch
from sentence_transformers import SentenceTransformer, util
import argparse
import pandas as pd
from datasets import load_dataset, Dataset
from tqdm import tqdm
from itertools import chain
import numpy as np
import os

# Encode the first NUM_PAIRS (original, compressed) sentence pairs of the
# sentence-compression dataset with all-mpnet-base-v2 and save both embedding
# matrices as whitespace-delimited text, one row per sentence.
NUM_PAIRS = 10000
ENCODE_BATCH_SIZE = 128
# NOTE(review): the Training cell loads these files from "data/", not "m/";
# move the files (or align the two paths) before training.
EMBED_DIR = 'm'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
semantic_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
semantic_model.to(device)
data_name = 'embedding-data/sentence-compression'
original_data = load_dataset(data_name)

# Slice rows first, then pick the column: indexing ['set'] first would
# materialize the entire column of the train split only to keep NUM_PAIRS rows.
# Each 'set' entry is indexed as [original, compressed].
pairs = original_data['train'][:NUM_PAIRS]['set']
orig_text = [pair[0] for pair in pairs]
comp_text = [pair[1] for pair in pairs]

# Encode in batches (one forward pass per batch instead of one per sentence).
# encode() on a list returns a 2-D ndarray with one row per input text,
# equivalent to the per-sentence results stacked with np.vstack.
orig_embeddings = semantic_model.encode(
    orig_text, batch_size=ENCODE_BATCH_SIZE, convert_to_numpy=True, show_progress_bar=True)
comp_embeddings = semantic_model.encode(
    comp_text, batch_size=ENCODE_BATCH_SIZE, convert_to_numpy=True, show_progress_bar=True)

# np.savetxt does not create missing directories.
os.makedirs(EMBED_DIR, exist_ok=True)
np.savetxt(os.path.join(EMBED_DIR, 'orig_embeddings_sc.text'), orig_embeddings, delimiter=" ")
np.savetxt(os.path.join(EMBED_DIR, 'comp_embeddings_sc.text'), comp_embeddings, delimiter=" ")



### Training

# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset, ConcatDataset
import numpy as np
import os
import argparse
from torch.optim import lr_scheduler
from torch.autograd import Variable

# Run on the (default) CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class ResidualBlock(nn.Module):
    """Width-preserving residual unit: ``relu(fc(x)) + x``.

    ``fc`` maps ``dim -> dim`` so the identity skip connection can be added
    without any projection.
    """

    def __init__(self, dim):
        super().__init__()
        # Attribute names are part of the state_dict keys ("fc.weight", ...).
        self.fc = nn.Linear(dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        skip = x
        activated = self.relu(self.fc(x))
        return activated + skip

class TransformModel(nn.Module):
    """Feed-forward network mapping an ``input_dim`` vector to ``output_dim``.

    Layers, applied in order:
      1. ``Linear(input_dim, hidden_dim)``
      2. ``num_layers - 2`` ``ResidualBlock(hidden_dim)`` units
      3. ``Linear(hidden_dim, output_dim)``

    ``num_layers`` counts the two linear projections, so any value below 2
    still yields exactly those two layers. No activation is applied after
    the final projection.
    """

    def __init__(self, num_layers=4, input_dim=768, hidden_dim=512, output_dim=384):
        super().__init__()

        # Kept as a ModuleList named ``layers`` so saved state_dict keys
        # ("layers.0.weight", ...) stay compatible.
        self.layers = nn.ModuleList()

        self.layers.append(nn.Linear(input_dim, hidden_dim))

        for _ in range(num_layers - 2):
            self.layers.append(ResidualBlock(hidden_dim))

        self.layers.append(nn.Linear(hidden_dim, output_dim))

    def forward(self, x):
        # Iterate the ModuleList directly instead of indexing via range(len(...)).
        for layer in self.layers:
            x = layer(x)
        return x

class VectorDataset(Dataset):
    """Map-style dataset that hands out items of an indexable container.

    ``vectors`` is stored by reference (no copy or conversion). In this file
    it is a 2-D tensor, so each item is one row.
    """

    def __init__(self, vectors):
        self.vectors = vectors

    def __len__(self):
        size = len(self.vectors)
        return size

    def __getitem__(self, idx):
        item = self.vectors[idx]
        return item


def sequence_similarity(x, y):
    """Return the Euclidean distance between ``x`` and ``y`` along the last axis.

    Despite the name, this is a *distance*: 0 for identical vectors and
    larger for more dissimilar ones. ``x`` and ``y`` must be broadcast-
    compatible; the result has their broadcast shape with the last axis
    reduced.

    ``torch.linalg.vector_norm`` is used instead of ``sqrt(sum(d ** 2))``.
    The forward value is the same, but the derivative of ``sqrt`` at 0 is
    infinite, which produced NaN gradients whenever two rows coincided.
    ``vector_norm`` defines the gradient there as 0.
    """
    return torch.linalg.vector_norm(x - y, ord=2, dim=-1)

def sign_loss(x, factor):
    """Penalty encouraging balanced signs in a 2-D tensor ``x``.

    ``tanh(factor * x)`` is a differentiable stand-in for ``sign(x)``; a
    larger ``factor`` makes it closer to a hard sign. In this file ``x`` is
    a batch of model outputs, so dim 0 indexes samples and dim 1 features.

    * ``row``: for each column, |mean over dim 0|, then averaged. This is
      small when every column is positive for about half of the rows.
    * ``col``: for each row, |mean over dim 1|, then averaged. This is
      small when every row has about as many positive as negative entries.

    Returns ``(row + col) / 2``, a value in ``[0, 1]``.
    """
    smooth_sign = torch.tanh(x * factor)
    # abs() must be taken per column / per row *before* averaging. Applying
    # it after the outer mean made both terms equal |global mean|, so they
    # were identical and enforced neither per-column nor per-row balance.
    row = torch.abs(smooth_sign.mean(dim=0)).mean()
    col = torch.abs(smooth_sign.mean(dim=1)).mean()
    return (row + col) / 2

def loss_fn(
        output_a,
        output_b,
        input_a,
        input_b,
        output_a_hat,
        lambda_1,
        lambda_2,
        lambda_3):
    """Weighted training objective for the transform model.

    Args:
        output_a, output_b: model outputs for two batches of inputs.
        input_a, input_b: the inputs that produced ``output_a`` / ``output_b``.
        output_a_hat: model output to be pulled towards ``output_a`` (in
            this file, the output for the compressed counterpart of
            ``input_a``).
        lambda_1, lambda_2, lambda_3: weights of the three loss terms.

    Returns:
        ``(total_loss, loss_1, loss_2, loss_3)``:
          * ``loss_1``: mean |rescaled input distance - output distance|
          * ``loss_2``: average ``sign_loss`` of ``output_a`` and ``output_b``
          * ``loss_3``: mean distance between ``output_a`` and ``output_a_hat``
    """
    # Linear remap of the input distance: the range [0, 2] is mapped onto
    # [-2, 4] before it is compared with the output distance.
    lo_in, hi_in, lo_out, hi_out = 0, 2, -2, 4

    input_dist = sequence_similarity(input_a, input_b)
    input_dist = (input_dist - lo_in) / (hi_in - lo_in) * (hi_out - lo_out) + lo_out
    output_dist = sequence_similarity(output_a, output_b)

    loss_1 = torch.abs(input_dist - output_dist).mean()

    sign_a = sign_loss(output_a, 1000)
    sign_b = sign_loss(output_b, 1000)
    loss_2 = (sign_a + sign_b) / 2

    loss_3 = torch.abs(sequence_similarity(output_a, output_a_hat)).mean()

    weighted = (lambda_1 * loss_1, lambda_2 * loss_2, lambda_3 * loss_3)
    total_loss = weighted[0] + weighted[1] + weighted[2]
    return total_loss, loss_1, loss_2, loss_3


def parser(args=None):
    """Parse command-line options for watermark-model training.

    Args:
        args: list of argument strings; ``None`` parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``orig_input_file``, ``comp_input_file``,
        ``output_model``, ``epochs``, ``lr`` and ``input_dim``.
    """
    cli = argparse.ArgumentParser(description="Train watermark model")
    option_specs = (
        ("--orig_input_file", str, "orig_embeddings_sc.text"),
        ("--comp_input_file", str, "comp_embeddings_sc.text"),
        ("--output_model", str, "transform_model.pth"),
        ("--epochs", int, 1000),
        ("--lr", float, 1e-5),
        ("--input_dim", int, 768),
    )
    for flag, value_type, default in option_specs:
        cli.add_argument(flag, type=value_type, default=default)
    return cli.parse_args(args=args)


# Train TransformModel on pre-computed sentence embeddings and save its weights.
DATA_DIR = "data"
MODEL_DIR = "model"

args = parser(['--epochs', '500',
               '--lr', '1e-5',
               '--output_model', 'transform_model_1.pth'])
print(args)

# Rows [0, SPLIT) of the original embeddings are paired with rows [SPLIT, ...)
# as (input_a, input_b). Compressed row i is the counterpart of original row i,
# so comp rows [0, SPLIT) line up with input_a.
SPLIT = 5000
BATCH_SIZE = 128

orig_embedding_file = os.path.join(DATA_DIR, args.orig_input_file)
comp_embedding_file = os.path.join(DATA_DIR, args.comp_input_file)
orig_embedding_data = np.loadtxt(orig_embedding_file)
comp_embedding_data = np.loadtxt(comp_embedding_file)

# Use the detected device rather than a hard-coded 'cuda', so the script also
# runs on CPU-only machines.
orig_data_1 = torch.tensor(orig_embedding_data[:SPLIT], device=device, dtype=torch.float32)
orig_data_2 = torch.tensor(orig_embedding_data[SPLIT:], device=device, dtype=torch.float32)
comp_data = torch.tensor(comp_embedding_data[:SPLIT], device=device, dtype=torch.float32)

orig_dataset_1 = VectorDataset(orig_data_1)
orig_dataset_2 = VectorDataset(orig_data_2)
comp_dataset = VectorDataset(comp_data)

# shuffle=False keeps the three loaders aligned batch-for-batch.
orig_dataloader_1 = DataLoader(orig_dataset_1, batch_size=BATCH_SIZE, shuffle=False)
orig_dataloader_2 = DataLoader(orig_dataset_2, batch_size=BATCH_SIZE, shuffle=False)
comp_dataloader = DataLoader(comp_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = TransformModel(input_dim=args.input_dim).to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=0.2)
scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.1)

num_steps = len(orig_dataloader_1)
for epoch in range(args.epochs):
    # zip stops at the shortest loader instead of raising StopIteration when
    # the two halves of the original data have different lengths. Tensors are
    # already on `device`, so no per-batch .to(device) is needed.
    batches = zip(orig_dataloader_1, orig_dataloader_2, comp_dataloader)
    for step, (input_a, input_b, input_c) in enumerate(batches):
        output_a = model(input_a)
        output_b = model(input_b)
        output_c = model(input_c)

        loss, loss_1, loss_2, loss_3 = loss_fn(output_a, output_b, input_a, input_b, output_c, 1, 1, 1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 100 == 0:
            print(f"Epoch [{epoch + 1}/{args.epochs}], Step [{step + 1}/{num_steps}], Loss: {loss.item()}")

    # StepLR counts epochs; without this call the configured decay
    # (x0.1 every 200 epochs) never took effect.
    scheduler.step()

# torch.save does not create missing directories.
os.makedirs(MODEL_DIR, exist_ok=True)
model_path = os.path.join(MODEL_DIR, args.output_model)
torch.save(model.state_dict(), model_path)
