In [22]:
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from parameters import *
from prompt import *

# Load the GLUE task selected by DATASET (from `parameters`) and the tokenizer
# that matches the pretrained checkpoint MODEL.
#
# The GLUE 'mnli' config does not ship a plain 'validation' split; it only has
# 'validation_matched' and 'validation_mismatched', so asking for 'validation'
# fails for that task. The recorded output of this cell shows the
# `mnli_mismatched` cache being used for evaluation, so the mismatched dev set
# is selected here.
# NOTE(review): switch to 'validation_matched' if the matched dev set is wanted.
val_split = 'validation_mismatched' if DATASET == 'mnli' else 'validation'

dataset_train = load_dataset('glue', DATASET, split='train')
dataset_val = load_dataset('glue', DATASET, split=val_split)
tokenizer = AutoTokenizer.from_pretrained(MODEL)

Reusing dataset glue (/home/t-chuhanwu/.cache/huggingface/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Reusing dataset glue (/home/t-chuhanwu/.cache/huggingface/datasets/glue/mnli_mismatched/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Reusing dataset glue (/home/t-chuhanwu/.cache/huggingface/datasets/glue/mnli_mismatched/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


In [None]:
def encode(examples):
    """Tokenize a batch of GLUE examples for the task named by ``DATASET``.

    The text column(s) fed to the tokenizer depend on the task:
    ``mnli`` -> (premise, hypothesis), ``qnli`` -> (question, sentence),
    ``qqp`` -> (question1, question2), ``sst2`` -> (sentence,).
    Every input is truncated and padded to exactly 128 tokens.

    Args:
        examples: batch mapping of column name -> list of values, as passed
            by ``Dataset.map(..., batched=True)``.

    Returns:
        Whatever the module-level ``tokenizer`` returns for the batch.

    Raises:
        ValueError: if ``DATASET`` is not one of the supported tasks. The
            original fell through and returned ``None``, so the failure only
            surfaced later as missing columns.
    """
    text_columns = {
        'mnli': ('premise', 'hypothesis'),
        'qnli': ('question', 'sentence'),
        'qqp': ('question1', 'question2'),
        'sst2': ('sentence',),
    }
    if DATASET not in text_columns:
        raise ValueError(
            f"Unsupported DATASET {DATASET!r}; expected one of {sorted(text_columns)}"
        )
    texts = [examples[column] for column in text_columns[DATASET]]
    return tokenizer(*texts, truncation=True, padding='max_length', max_length=128)

def _tokenize_and_batch(split):
    """Run the shared preprocessing on one dataset split.

    Applies ``encode`` batch-wise, copies the ``label`` column into a
    ``labels`` column, restricts the output format to the torch columns
    listed below, and wraps the result in a DataLoader with
    ``batch_size=CLIENT_NUM``.

    Returns:
        (prepared_dataset, dataloader)
    """
    prepared = split.map(encode, batched=True)
    prepared = prepared.map(lambda batch: {'labels': batch['label']}, batched=True)
    prepared.set_format(
        type='torch',
        columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
    )
    loader = torch.utils.data.DataLoader(prepared, batch_size=CLIENT_NUM)
    return prepared, loader


# Train and validation go through the identical pipeline.
dataset_train, dataloader_train = _tokenize_and_batch(dataset_train)
dataset_val, dataloader_val = _tokenize_and_batch(dataset_val)


In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer,AutoConfig


config = AutoConfig.from_pretrained(MODEL)
# Alternatives noted by the author:
# BertPromptForSequenceClassification / BertForSequenceClassification
model = BertPrefixForSequenceClassification.from_pretrained(MODEL)

from tqdm import tqdm
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.train()
model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

# Only parameters whose name mentions the prefix encoder, the classifier or
# the pooler stay trainable; every other parameter is frozen.
for name, para in model.named_parameters():
    para.requires_grad = any(
        key in name for key in ('prefix_encoder', 'classifier', 'pooler')
    )


In [None]:
from tqdm import tqdm
from sklearn.metrics import *
import numpy as np
norm_l2=NORMCLIP
eps=10.
delta=1e-3
# NOTE(review): norm_l2 / eps / delta above are not read by the loop below,
# which uses NORMCLIP, EPS and DELTA instead -- confirm which set is intended.

# Standard deviation of the Gaussian noise added to every gradient
# ("equivalent noise after aggregation"). It is loop-invariant, so it is
# computed once instead of once per parameter per step.
noise_std = float(2*NORMCLIP*np.sqrt(2*np.log(1.25/DELTA))/EPS/np.sqrt(CLIENT_NUM))

all_loss=0.
for epoch in range(EPOCH):
    # Reset per epoch: the running mean below divides by the per-epoch step
    # index, so carrying the sum across epochs inflated the printed loss.
    all_loss = 0.
    for i, batch in enumerate(tqdm(dataloader_train)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs[0]
        loss.backward()
        all_loss += loss.detach()
        torch.nn.utils.clip_grad_norm_(model.parameters(), NORMCLIP)
        for p in model.parameters():
            if p.grad is not None:
                # add equivalent noise after aggregation
                # Sampled on the gradient's own device/dtype, so this also
                # runs when `device` is 'cpu' (the original forced .cuda()).
                p.grad += torch.randn_like(p.grad) * noise_std
        optimizer.step()
        optimizer.zero_grad()
        if i % 10 == 0:
            print(f"loss: {all_loss/(i+1)}")

    # Evaluate on the validation set after each epoch.
    model.eval()
    all_pred=[]
    all_label=[]
    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader_val)):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            # outputs[1] is argmax-ed over its last axis to get the prediction.
            all_pred += np.argmax(outputs[1].detach().cpu().numpy(),axis=-1).tolist()
            all_label+=batch['labels'].detach().cpu().numpy().tolist()
    print(accuracy_score(all_label,all_pred))
    model.train()
        