In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from tqdm import tqdm
from transformers import BertTokenizer
from gensim.models import KeyedVectors
from datasets import load_dataset
from sklearn.metrics import classification_report, accuracy_score
import numpy as np
import gensim.downloader as api

## Prepare Data

In [2]:
# Load AG News Dataset
dataset = load_dataset('ag_news')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize(text):
    return tokenizer(text, padding='max_length', truncation=True, return_tensors='pt', max_length=128)

# Load Word2Vec Embeddings
word2vec = api.load('word2vec-google-news-300')

def get_word2vec_embedding(tokens):
    embeddings = []
    for token in tokens:
        if token in word2vec:
            embeddings.append(word2vec[token])
    if len(embeddings) == 0:
        return np.zeros(word2vec.vector_size)
    return np.mean(embeddings, axis=0)

class AGNewsDataset(data.Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        tokens = tokenizer.tokenize(self.texts[idx])
        embedding = get_word2vec_embedding(tokens)
        label = self.labels[idx]
        return torch.tensor(embedding, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

train_texts = [item['text'] for item in dataset['train']]
train_labels = [item['label'] for item in dataset['train']]
test_texts = [item['text'] for item in dataset['test']]
test_labels = [item['label'] for item in dataset['test']]

train_dataset = AGNewsDataset(train_texts, train_labels)
test_dataset = AGNewsDataset(test_texts, test_labels)

train_loader = data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=64, shuffle=False)


## Model

In [3]:
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv1d(1, 100, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(2)
        self.fc1 = nn.Linear(100 * (word2vec.vector_size // 2), 4)  # 4 classes in AG News

    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dimension
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

## Init params

In [4]:
# Initialize model and criterion
model = CNNModel()
criterion = nn.CrossEntropyLoss()

In [6]:
# Synthetic data initialization
num_classes = 4
num_synthetic_per_class = 10
max_length = 128

synthetic_text_data = torch.randn(num_classes * num_synthetic_per_class, word2vec.vector_size, requires_grad=True)
synthetic_labels = torch.tensor([i for i in range(num_classes) for _ in range(num_synthetic_per_class)], dtype=torch.long)

init_synthetic_dataset = data.TensorDataset(synthetic_text_data.detach(), synthetic_labels)

## util functions

In [7]:
def compute_gradients(model, inputs, labels):
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    gradients = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    return gradients

def layerwise_matching_loss(gw_syn, gw_real):
    loss = 0
    for g_syn, g_real in zip(gw_syn, gw_real):
        loss += ((g_syn - g_real) ** 2).sum()
    return loss

def compute_gradients_and_outputs(model, inputs, labels):
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    gradients = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    return gradients, outputs, loss

def get_real_data_batch_for_class(train_dataset, target_class, batch_size, device):
    """
    Function to get a batch of real data belonging to a specific class from the dataset.
    """
    class_indices = [i for i, label in enumerate(train_dataset.labels) if label == target_class]
    selected_indices = random.sample(class_indices, batch_size)
    real_data_batch = [train_dataset[i][0] for i in selected_indices]
    real_labels_batch = [train_dataset[i][1] for i in selected_indices]
    
    # Stack tensors to create batch and move to device
    real_data_batch = torch.stack(real_data_batch).to(device)
    real_labels_batch = torch.tensor(real_labels_batch).to(device)
    
    return real_data_batch, real_labels_batch

## Main Loop

In [8]:
optimizer_syn = optim.SGD([synthetic_text_data], lr=0.01, momentum=0.9)
optimizer_net = optim.Adam(model.parameters(), lr=0.001)

# Training loop with synthetic data gradient matching
num_epochs = 10
K = 10
T = 5
inner_loop1_range = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64

In [9]:
from tqdm.notebook import tqdm
import random

# Define a function to sample real data for a given class from the dataset


# Main training loop
for epoch in range(num_epochs):
    model.train()
    loss_avg = 0
    train_correct = 0
    train_total = 0

    model = CNNModel()
    model.to(device)
    
    # Randomly initialize the model weights
    model.apply(lambda m: m.reset_parameters() if hasattr(m, 'reset_parameters') else None)
    
    for k in tqdm(range(K), desc=f'Outer Loop Epoch {epoch+1}/{num_epochs}', unit='loop', leave=False):  # Outer loop
        for inner_loop1 in range(inner_loop1_range):
            for class_idx in range(num_classes):  # Per-class synthetic data update
                synthetic_data_batch = synthetic_text_data[class_idx * num_synthetic_per_class:(class_idx + 1) * num_synthetic_per_class]
                synthetic_labels_batch = synthetic_labels[class_idx * num_synthetic_per_class:(class_idx + 1) * num_synthetic_per_class]
                
                # Sample real data for the current class
                real_data_batch, real_labels_batch = get_real_data_batch_for_class(train_dataset, class_idx, batch_size, device)
                
                # Compute gradients for real data
                gradients_real, _, _ = compute_gradients_and_outputs(model, real_data_batch, real_labels_batch)
                
                # Compute gradients for synthetic data
                gradients_synthetic, _, _ = compute_gradients_and_outputs(model, synthetic_data_batch.to(device), synthetic_labels_batch.to(device))
                
                # Compute matching loss
                loss_match = layerwise_matching_loss(gradients_synthetic, gradients_real)
                
                # Zero the gradients
                optimizer_syn.zero_grad()
                
                # Backward pass and optimize synthetic data
                loss_match.backward()
                optimizer_syn.step()

        synthetic_dataset = data.TensorDataset(synthetic_text_data.detach(), synthetic_labels)
        synthetic_loader = data.DataLoader(synthetic_dataset, batch_size=64, shuffle=True)

        for t in range(T):  # Inner loop for network training
            for synthetic_data_batch, synthetic_labels_batch in tqdm(synthetic_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}', unit='batch', leave=False):
                synthetic_data_batch, synthetic_labels_batch = synthetic_data_batch.to(device), synthetic_labels_batch.to(device)
           
                optimizer_net.zero_grad()
                outputs = model(synthetic_data_batch.to(device))
                loss = criterion(outputs, synthetic_labels_batch.to(device))
                loss.backward()
                optimizer_net.step()
        
                loss_avg += loss.item()
        
                # Calculate training accuracy
                _, predicted = torch.max(outputs.data, 1)
                train_total += synthetic_labels_batch.size(0)
                train_correct += (predicted == synthetic_labels_batch.to(device)).sum().item()

    train_accuracy = 100 * train_correct / train_total

    # Evaluate model on test set
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc=f'Evaluating Epoch {epoch+1}/{num_epochs}', unit='batch', leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_accuracy = 100 * test_correct / test_total
    print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {loss_avg / (K * T):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%")


Outer Loop Epoch 1/10:   0%|          | 0/10 [00:00<?, ?loop/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 1/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Evaluating Epoch 1/10:   0%|          | 0/119 [00:00<?, ?batch/s]

Epoch [1/10], Average Loss: 0.7218, Train Accuracy: 85.00%, Test Accuracy: 25.00%


Outer Loop Epoch 2/10:   0%|          | 0/10 [00:00<?, ?loop/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 2/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Evaluating Epoch 2/10:   0%|          | 0/119 [00:00<?, ?batch/s]

Epoch [2/10], Average Loss: 1.0775, Train Accuracy: 65.75%, Test Accuracy: 25.00%


Outer Loop Epoch 3/10:   0%|          | 0/10 [00:00<?, ?loop/s]

Training Epoch 3/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 3/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 3/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 3/10:   0%|          | 0/1 [00:00<?, ?batch/s]

Training Epoch 3/10:   0%|          | 0/1 [00:00<?, ?batch/s]

KeyboardInterrupt: 

In [40]:
synthetic_dataset.tensors

(tensor([[ 2.3587e+03,  3.7563e+03,  3.5735e+03,  ..., -1.2252e+02,
          -8.0514e+01, -4.8536e+01],
         [-1.1317e+05,  1.0093e+05, -1.1070e+05,  ...,  1.2400e+05,
           4.5166e+04,  2.6488e+04],
         [-1.1095e+05, -9.8747e+04, -9.8204e+04,  ...,  1.8240e+05,
           1.8433e+05,  1.0339e+05],
         ...,
         [ 5.9470e+04,  8.7636e+04,  9.2561e+04,  ..., -3.2375e+04,
           4.4454e+03, -4.9598e+04],
         [-4.3381e+03, -8.1358e+03, -6.5124e+03,  ...,  2.7077e+03,
          -1.1521e+04, -3.7069e+03],
         [ 1.3742e+05,  1.3538e+05,  1.8370e+05,  ..., -2.3106e+04,
           8.5923e+04,  2.9153e+04]]),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]))

In [20]:
torch.max(outputs.data, 1)

torch.return_types.max(
values=tensor([ 0.0111,  0.0044,  0.0011,  0.0031,  0.0196,  0.0055,  0.0028,  0.0017,
         0.0061,  0.0095,  0.0088,  0.0065,  0.0082, -0.0055,  0.0048,  0.0098,
         0.0043,  0.0040,  0.0127,  0.0038,  0.0044, -0.0071,  0.0002,  0.0023,
         0.0009,  0.0151,  0.0188,  0.0030,  0.0009,  0.0088,  0.0023,  0.0063,
         0.0032,  0.0072, -0.0020,  0.0058, -0.0020,  0.0071,  0.0066,  0.0026,
         0.0012, -0.0025,  0.0026,  0.0113,  0.0140,  0.0172, -0.0034,  0.0037],
       device='cuda:0'),
indices=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       device='cuda:0'))

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       device='cuda:0')

## Train model from scratch on synthetic data

In [None]:
# Train model from scratch using synthetic data
model = CNNModel().to(device)
optimizer_net = optim.Adam(model.parameters(), lr=0.001)

# Create DataLoader for synthetic data
synthetic_dataset = data.TensorDataset(synthetic_text_data.detach(), synthetic_labels)
synthetic_loader = data.DataLoader(synthetic_dataset, batch_size=64, shuffle=True)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for inputs, labels in tqdm(synthetic_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}', unit='batch'):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_net.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_net.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_accuracy = 100 * correct / total
    print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')

# Final evaluation on test set
model.eval()
test_labels = []
test_preds = []
with torch.no_grad():
    for texts, labels in tqdm(test_loader, desc='Evaluating on test set', unit='batch'):
        texts, labels = texts.to(device), labels.to(device)
        outputs = model(texts)
        _, preds = torch.max(outputs, 1)
        test_labels.extend(labels.cpu().numpy())
        test_preds.extend(preds.cpu().numpy())

overall_accuracy = accuracy_score(test_labels, test_preds)
class_report = classification_report(test_labels, test_preds, target_names=dataset['test'].features['label'].names)

print(f'Test Accuracy: {overall_accuracy:.4f}')
print('Classification Report:')
print(class_report)
