In [None]:
import numpy as np
from pprint import pprint

import math
from transformers import GPT2Tokenizer,GPT2LMHeadModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
from datasets import load_dataset
from utils import GumbelSoftmax, get_matrices_expansions, get_embeddings
from functional import get_top_in_span,adaptive_sigmoid_probs,reverse_text_embeddings
# Seed torch's global RNG so the random initialisations in the cells below are repeatable.
torch.manual_seed(50)

In [None]:
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on %s" % device)

In [None]:
# Load the tokenizer first; the two from_pretrained calls are independent of each other.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # set the padding token (EOS is reused)

# nn.Module.eval() returns the module itself, so it can be chained onto the load.
net = GPT2LMHeadModel.from_pretrained("gpt2").eval()

In [None]:
def text_to_input(text, max_length=32):
    """Tokenize ``text`` and attach language-modelling labels.

    Args:
        text: Input passed straight to the module-level ``tokenizer``.
        max_length: Maximum number of tokens to keep; longer inputs are
            truncated. Defaults to 32.

    Returns:
        The tokenizer output (requested as PyTorch tensors) with an extra
        ``"labels"`` entry holding an independent copy of ``"input_ids"``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        # padding="max_length",  # would force every sequence to the same length
        max_length=max_length,  # honour the argument (was hard-coded to 32)
        truncation=True,
    )
    # Clone so later in-place edits of the labels never touch the inputs.
    encoded["labels"] = encoded["input_ids"].clone()
    return encoded

def label_to_onehot(label, pad_to_len=20, num_classes=50257):
    """Turn a sequence of token ids into one one-hot row per position.

    The previous implementation called ``scatter_(1, label, 1)`` with a
    ``(1, seq_len)`` index, which writes every token id into row 0 and leaves
    all other rows empty. Here row ``i`` gets a single 1 at column ``label[i]``.

    Args:
        label: Integer tensor of token ids. Any shape is accepted and is
            flattened, e.g. ``(seq_len,)`` or ``(1, seq_len)``.
        pad_to_len: Number of rows in the result. Rows past the end of
            ``label`` stay all-zero; ids beyond ``pad_to_len`` are dropped.
        num_classes: Number of columns (vocabulary size). Defaults to 50257.

    Returns:
        Float tensor of shape ``(pad_to_len, num_classes)`` on ``label``'s device.
    """
    token_ids = label.reshape(-1)[:pad_to_len].long()
    onehot = torch.zeros(pad_to_len, num_classes, device=label.device)
    rows = torch.arange(token_ids.numel(), device=label.device)
    # Pairwise (row, column) assignment: one hot entry per sequence position.
    onehot[rows, token_ids] = 1
    return onehot

def onehot_criterion(x, pred):
    """Binary cross-entropy summed over classes, then averaged over rows.

    Args:
        x: Tensor of shape (batch_size, 50257), used as the BCE *input*
            (BCE requires its values to lie in [0, 1]).
        pred: Tensor of shape (batch_size, 50257), used as the BCE *target*.

    Returns:
        Scalar tensor: the mean over dim 0 of the per-row sums of the
        element-wise binary cross-entropy.
    """
    elementwise = torch.nn.functional.binary_cross_entropy(x, pred, reduction="none")
    per_row = elementwise.sum(dim=1)
    return per_row.mean()


In [None]:
# Sentence used as the private sample -- replace it with your own.
text = "The muscles are your body's \"grand central station.\""
sample = text_to_input(text)
sample_len = sample["input_ids"].shape[1]  # size of dim 1 of input_ids

# One GumbelSoftmax module and one trainable (1, 50257) tensor per position.
gumbels = [GumbelSoftmax() for _ in range(sample_len)]
Us = [
    nn.Parameter(torch.randn(1, 50257).uniform_(-0.5, 0.5))
    for _ in range(sample_len)
]

# Overwrite, in place, every parameter whose name mentions "weight" or "bias"
# with values drawn from U(-0.5, 0.5). This replaces the values `net` held before.
for param_name, param in net.named_parameters():
    for tag in ("weight", "bias"):
        if tag in param_name:
            param.data.uniform_(-0.5, 0.5)


In [None]:
criterion = onehot_criterion
######### honest participant #########
# Compute the gradient of the loss on the real sample.
out = net(**sample)
# Pass each position's logit vector through adaptive_sigmoid_probs and restack row-wise.
logits = torch.stack(
    [adaptive_sigmoid_probs(logit) for logit in out.logits.squeeze(0)], dim=0
)
# The target needs one row per token. The default pad_to_len=20 only matches
# 20-token inputs and otherwise makes BCE fail on mismatched shapes.
orig_loss = criterion(logits, label_to_onehot(sample["labels"], pad_to_len=sample_len))
# These gradients are detached right below, so no higher-order graph
# (create_graph=True) has to be built or kept alive.
dy_dx = torch.autograd.grad(orig_loss, net.parameters(), allow_unused=True)

# Share the gradients with other clients (parameters without a gradient are dropped).
original_dy_dx = [g.detach().clone() for g in dy_dx if g is not None]

# Starting point for the reconstruction; its content is overwritten during optimisation.
dummy_text = "Any Dummy Text To Be Replaced initially"
dummy_sample = text_to_input(dummy_text)
dummy_ids = dummy_sample["input_ids"]
dummy_label = label_to_onehot(dummy_sample["labels"], pad_to_len=sample_len)

# optimizer = torch.optim.LBFGS(Us,lr=1e-2)
optimizer = torch.optim.AdamW(Us, lr=1)
# NOTE(review): with step_size=1 and gamma=0.5 the learning rate halves on every
# scheduler.step() call -- confirm that such a fast decay is intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

_, R_Qs = get_matrices_expansions(original_dy_dx)

In [None]:
history = []  # gradient-matching loss of every iteration, in order


def closure():
    """Run one forward/backward pass of the gradient-matching objective.

    Builds the per-position dummy distributions from ``Us``, measures the
    squared distance between the resulting gradients and ``original_dy_dx``,
    and back-propagates that distance. As a side effect the globals
    ``dummy_ids`` and ``dummy_label`` are replaced by the argmax decoding of
    the current dummy distributions.

    Returns:
        Scalar tensor with the summed squared gradient difference.

    Raises:
        RuntimeError: If the dummy loss yields no gradient for any parameter.
    """
    global dummy_label, dummy_ids
    optimizer.zero_grad()
    # Allocate on dummy_label's device. Moving only this tensor to `device`
    # while the label stays put breaks the loss on CUDA machines.
    dummy_logits = torch.zeros_like(dummy_label)
    for pos in range(sample_len):
        emb = get_embeddings(net, pos)
        probs = get_top_in_span(R_Qs[0], emb, 0.005, 'l2')
        dummy_logits[pos] = gumbels[pos](probs, Us[pos], hard=False)

    loss = criterion(dummy_logits, dummy_label)
    params = list(net.parameters())
    raw_grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)

    if len(original_dy_dx) == len(params):
        # The reference list has one entry per parameter, so pair by parameter
        # index. Dropping None entries first would shift later gradients onto
        # the wrong reference.
        pairs = [(g, ref) for g, ref in zip(raw_grads, original_dy_dx) if g is not None]
    else:
        # Reference gradients were filtered too; keep the positional pairing.
        # NOTE(review): this is only aligned if both passes skip the same parameters.
        pairs = list(zip((g for g in raw_grads if g is not None), original_dy_dx))
    if not pairs:
        raise RuntimeError("dummy loss produced no gradient for any model parameter")

    grad_diff = sum(((g - ref) ** 2).sum() for g, ref in pairs)

    # Decode the current guess; it becomes the target of the next iteration.
    dummy_ids = torch.argmax(dummy_logits.detach(), dim=-1).unsqueeze(0)
    dummy_label = label_to_onehot(dummy_ids, pad_to_len=sample_len)
    grad_diff.backward()
    return grad_diff


for iters in range(300):
    # The closure is called by hand because optimizer.step() is used without a
    # closure argument (the optimizer.step(closure) form is not used here).
    current_loss = closure()
    optimizer.step()
    scheduler.step()
    history.append(current_loss.item())
    print(f"[Iter {iters}] Current loss:", "%.4f" % current_loss.item())
    dummy_text = tokenizer.decode(dummy_ids[0])
    print("dummy_text:",dummy_text)
