In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
import sys
sys.path.insert(1, "/home/oru2/project/project")

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time
# import matplotlib.pyplot as plt
import attacks
from privacy_accountant import PrivacyAccountant
from tqdm import tqdm
from torchvision import datasets, transforms

from collections import Counter
import torch
from torch.utils.data import DataLoader, TensorDataset


In [4]:
import pandas as pd

data = pd.read_csv("../imdb_data/IMDB Dataset.csv")

In [5]:
#credit: https://www.kaggle.com/code/m0hammadjavad/imdb-sentiment-classifier-pytorch/notebook
import nltk
import os
import spacy
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

nltk.download('punkt')
nltk.download('stopwords')
nltk.download('punkt_tab')

nlp = spacy.load("en_core_web_sm")

stop_words = set(stopwords.words('english'))

def preprocess_text(text):
    tokens = word_tokenize(text.lower())
    
    filtered_tokens = [word for word in tokens if word.isalpha() and word not in stop_words]

    doc = nlp(' '.join(filtered_tokens))
    
    lemmetized_tokens = [token.lemma_ for token in doc]
    
    return ' '.join(lemmetized_tokens)

if os.path.exists("../imdb_data/IMDB Dataset_with_cleaned_reviews.csv"):
    data = pd.read_csv("../imdb_data/IMDB Dataset_with_cleaned_reviews.csv")
else:  
    data["cleaned_reviews"] = data["review"].apply(preprocess_text)
    data.to_csv("../imdb_data/IMDB Dataset_with_cleaned_reviews.csv")

data["sentiment"] = data["sentiment"].map({"positive": 1, "negative": 0})

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(data["cleaned_reviews"], data["sentiment"], test_size=0.2, random_state=42)
print(f'Training samples: {len(X_train)}, Test samples: {len(X_test)}')


# Create a vocabulary based on the training data
def build_vocab(texts):
    
    # this makes a dict with unique words and their count as the value
    # although this is not going to be used directly, it only gives us unique words without repeatition
    counter = Counter()
    for text in texts:
        counter.update(text.split())
        
    # this makes a dict of unique words and their index as the value
    vocab = {word: idx for idx, (word, _) in enumerate(counter.items(), 1)}  # Reserve index 0 for padding
    
    # this is a convention which is going to be used to convert batches to a fixed size
    vocab['<PAD>'] = 0
    
    return vocab

# Build the vocabulary
vocab = build_vocab(X_train)

# Encode text sequences into integer sequences
def encode_text(text, vocab, max_length=200):
    tokens = text.split()
    encoded = [vocab.get(token, 0) for token in tokens]  # 0 for unknown tokens
    if len(encoded) < max_length:
        encoded += [vocab['<PAD>']] * (max_length - len(encoded))  # Padding
    return encoded[:max_length]  # Truncate to max_length

# Encode all reviews
X_train_encoded = torch.tensor([encode_text(text, vocab) for text in X_train])
X_test_encoded = torch.tensor([encode_text(text, vocab) for text in X_test])
y_train_tensor = torch.nn.functional.one_hot(torch.tensor(y_train.values))
y_test_tensor =  torch.nn.functional.one_hot(torch.tensor(y_test.values))

# Create DataLoader for batching
train_data = TensorDataset(X_train_encoded, y_train_tensor)
test_data = TensorDataset(X_test_encoded, y_test_tensor)

train_loader = DataLoader(train_data, batch_size=100, shuffle=True)
test_loader = DataLoader(test_data, batch_size=100, shuffle=False)

[nltk_data] Downloading package punkt to /home/oru2/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/oru2/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/oru2/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Training samples: 40000, Test samples: 10000


In [6]:
use_cuda = True
device = torch.device("cuda:1" if use_cuda else "cpu")
batch_size = 100

np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7fadea37e290>

In [None]:
def train_model(model, train_loader, num_epochs):
    # TODO: implement this function that trains a given model on the MNIST dataset.
    # this is a general-purpose function for both standard training and adversarial training.
    # (toggle enable_defense parameter to switch between training schemes)
    model.train()
    # epsilons_clean = []
    lr = 1e-2
    losses = []
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in tqdm(range(num_epochs)):
        for index, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)

            loss = nn.BCELoss()(logits.squeeze(), labels.float())
            loss.backward()
            optimizer.zero_grad()
            optimizer.step()
            losses.append(loss.item())
            
            embeddings = model.embedding(inputs)
            adversary_embeddings = attacks.perturb_data(model, embeddings, labels).to(device)

            
            predictions = model.lstm(adversary_embeddings)[0][:, -1, :]
            predictions = torch.sigmoid(model.fc(predictions))

            loss = nn.BCELoss()(predictions.squeeze(), labels.float())
            loss.backward()
            optimizer.zero_grad()
            optimizer.step()

        print(f'Epoch [{epoch}/{num_epochs}] Loss = {loss.item():.3f}')

In [114]:
from model import SentimentModel
vocab_size = len(vocab)
embed_size = 100
hidden_size = 128
output_size = 2  # Binary classification (positive/negative)
num_layers = 2
fc_model = SentimentModel(vocab_size, embed_size, hidden_size, output_size, num_layers).to(device)
num_epochs = 20



In [115]:

train_model(fc_model, train_loader, num_epochs)

  5%|▌         | 1/20 [00:36<11:30, 36.37s/it]

Epoch [0/20] Loss = 0.690


 10%|█         | 2/20 [01:11<10:44, 35.80s/it]

Epoch [1/20] Loss = 0.686


 15%|█▌        | 3/20 [01:47<10:05, 35.63s/it]

Epoch [2/20] Loss = 0.434


 20%|██        | 4/20 [02:22<09:28, 35.55s/it]

Epoch [3/20] Loss = 0.354


 25%|██▌       | 5/20 [02:58<08:52, 35.50s/it]

Epoch [4/20] Loss = 0.247


 30%|███       | 6/20 [03:33<08:16, 35.44s/it]

Epoch [5/20] Loss = 0.239


 35%|███▌      | 7/20 [04:08<07:40, 35.41s/it]

Epoch [6/20] Loss = 0.251


 40%|████      | 8/20 [04:44<07:04, 35.42s/it]

Epoch [7/20] Loss = 0.260


 45%|████▌     | 9/20 [05:19<06:29, 35.41s/it]

Epoch [8/20] Loss = 0.176


 50%|█████     | 10/20 [05:54<05:54, 35.40s/it]

Epoch [9/20] Loss = 0.126


 55%|█████▌    | 11/20 [06:30<05:18, 35.40s/it]

Epoch [10/20] Loss = 0.153


 60%|██████    | 12/20 [07:05<04:43, 35.42s/it]

Epoch [11/20] Loss = 0.093


 65%|██████▌   | 13/20 [07:41<04:08, 35.43s/it]

Epoch [12/20] Loss = 0.153


 70%|███████   | 14/20 [08:16<03:32, 35.46s/it]

Epoch [13/20] Loss = 0.143


 75%|███████▌  | 15/20 [08:52<02:57, 35.46s/it]

Epoch [14/20] Loss = 0.069


 80%|████████  | 16/20 [09:27<02:21, 35.43s/it]

Epoch [15/20] Loss = 0.075


 85%|████████▌ | 17/20 [10:02<01:46, 35.42s/it]

Epoch [16/20] Loss = 0.169


 90%|█████████ | 18/20 [10:38<01:10, 35.43s/it]

Epoch [17/20] Loss = 0.066


 95%|█████████▌| 19/20 [11:13<00:35, 35.44s/it]

Epoch [18/20] Loss = 0.163


100%|██████████| 20/20 [11:49<00:00, 35.46s/it]

Epoch [19/20] Loss = 0.057





In [116]:
torch.save(fc_model, 'models/adv.pt')

In [8]:
fc_model = torch.load('models/baseline.pt')

  fc_model = torch.load('models/baseline.pt')


In [122]:
correct = 0
fc_model.eval()
for j, (inputs, labels) in enumerate(test_loader):
  inputs = inputs.to(device)
  labels = labels.to(device)
  
  logits = fc_model(inputs)

  prediction = torch.argmax(logits, 1)
  correct += (prediction == torch.argmax(labels, dim=1)).sum().item()
  # print('Batch [{}/{}]'.format(j+1, len(test_loader)))
fc_model.train()
print('Accuracy = {}%'.format(float(correct) * 100 / 10000))

Accuracy = 85.27%


In [123]:
correct = 0
eps = 0.1
# attack = attacks.LSTMPGD(fc_model, epsilon=eps, num_steps=10, embedding_layer=fc_model.embedding)
torch.backends.cudnn.enabled = False
for j, (inputs, labels) in enumerate(test_loader):
  inputs = inputs.to(device)
  labels = labels.to(device)
  embeddings = fc_model.embedding(inputs)
  adv_inputs = attacks.perturb_data(fc_model, embeddings, labels, epsilon = eps)
  # print(adv_images)
  logits = fc_model(inputs)

  adv_logits = fc_model.lstm(adv_inputs)[0][:, -1, :]
  adv_logits = torch.sigmoid(fc_model.fc(adv_logits))
  # adv_logits = fc_model(adv_images)

  prediction = torch.argmax(logits, 1)

  adv_prediction = torch.argmax(adv_logits, 1)


  correct += (prediction == torch.argmax(labels)).sum().item()
  correct += (adv_prediction == torch.argmax(labels)).sum().item()
  # print('Batch [{}/{}]'.format(j+1, len(test_loader)))
print('Accuracy = {}%'.format(float(correct) * 100 / 20000))

Accuracy = 50.465%


In [124]:
from art.attacks.inference.membership_inference import MembershipInferenceBlackBox
from art.estimators.classification import PyTorchClassifier

In [125]:
from model import SentimentModelNoEmbed
no_embed_model = SentimentModelNoEmbed(embed_size, hidden_size, output_size, num_layers).to(device)

In [126]:
print(fc_model.state_dict().keys())
state_dict = {k: fc_model.state_dict()[k] for k in filter(lambda x: not x.startswith('embedding'), fc_model.state_dict())}
no_embed_model.load_state_dict(state_dict, strict=False) 


odict_keys(['embedding.weight', 'lstm.weight_ih_l0', 'lstm.weight_hh_l0', 'lstm.bias_ih_l0', 'lstm.bias_hh_l0', 'lstm.weight_ih_l1', 'lstm.weight_hh_l1', 'lstm.bias_ih_l1', 'lstm.bias_hh_l1', 'fc.weight', 'fc.bias'])


<All keys matched successfully>

In [127]:
optimizer = torch.optim.Adam(fc_model.parameters())
criterion = nn.BCELoss()


# Wrap the PyTorch model in ART's PyTorchClassifier
art_classifier = PyTorchClassifier(
    model=no_embed_model,
    loss=criterion,
    optimizer=optimizer,
    input_shape=(200,),
    nb_classes=2
)
attack_train_size = 10000
attack_test_size = 5000

x_train = train_data.tensors[0]
y_train = train_data.tensors[1].detach().numpy()


x_test = test_data.tensors[0]
y_test = test_data.tensors[1].detach().numpy()

x_train = nn.Embedding(vocab_size, embed_size)(x_train).detach().numpy()
x_test = nn.Embedding(vocab_size, embed_size)(x_test).detach().numpy()

attack = MembershipInferenceBlackBox(estimator=art_classifier, attack_model_type="nn")
attack.fit(x_train[:attack_train_size], y_train[:attack_train_size], x_test[:attack_test_size], y_test[:attack_test_size])


In [128]:
mlp_inferred_train_bb = attack.infer(x_train[attack_train_size:], y_train[attack_train_size:])
mlp_inferred_test_bb = attack.infer(x_test[attack_test_size:], y_test[attack_test_size:])

# check accuracy
mlp_train_acc_bb = np.sum(mlp_inferred_train_bb) / len(mlp_inferred_train_bb)
mlp_test_acc_bb = 1 - (np.sum(mlp_inferred_test_bb) / len(mlp_inferred_test_bb))
mlp_acc_bb = (mlp_train_acc_bb * len(mlp_inferred_train_bb) + mlp_test_acc_bb * len(mlp_inferred_test_bb)) / (len(mlp_inferred_train_bb) + len(mlp_inferred_test_bb))

print(f"Members Accuracy: {mlp_train_acc_bb:.4f}")
print(f"Non Members Accuracy {mlp_test_acc_bb:.4f}")
print(f"Attack Accuracy {mlp_acc_bb:.4f}")

Members Accuracy: 0.8790
Non Members Accuracy 0.4380
Attack Accuracy 0.8160


In [63]:
x_train.detach()

tensor([[[ 1.2008e+00,  1.5396e+00,  1.4179e+00,  ...,  5.5041e-01,
          -4.3193e-01, -2.4356e-01],
         [ 2.2167e-01, -2.7137e-01,  2.1014e-01,  ..., -1.0977e+00,
          -7.6189e-01, -1.6361e+00],
         [-8.4064e-03, -9.4465e-02, -4.9066e-01,  ...,  7.0453e-01,
           6.9287e-01, -1.4205e+00],
         ...,
         [ 7.2725e-02,  7.8139e-01,  6.2422e-01,  ..., -2.6550e+00,
           3.7356e-01, -4.9829e-01],
         [ 7.2725e-02,  7.8139e-01,  6.2422e-01,  ..., -2.6550e+00,
           3.7356e-01, -4.9829e-01],
         [ 7.2725e-02,  7.8139e-01,  6.2422e-01,  ..., -2.6550e+00,
           3.7356e-01, -4.9829e-01]],

        [[ 1.4441e+00,  1.1737e+00, -1.4373e+00,  ..., -1.0749e+00,
          -1.6892e+00, -1.3481e+00],
         [ 6.1772e-01, -6.9171e-01,  1.2618e+00,  ...,  1.9872e+00,
          -3.9961e-01, -5.8853e-01],
         [-1.2826e-01, -4.8372e-01, -5.2664e-01,  ..., -5.0066e-01,
          -1.2826e-01, -1.1033e+00],
         ...,
         [ 7.2725e-02,  7