In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from tqdm import tqdm
from transformers import BertTokenizer
from gensim.models import KeyedVectors
from datasets import load_dataset
from sklearn.metrics import classification_report, accuracy_score
import numpy as np
import gensim.downloader as api

# Load AG News Dataset
dataset = load_dataset('ag_news')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize(text):
    return tokenizer(text, padding='max_length', truncation=True, return_tensors='pt', max_length=128)

# Load Word2Vec Embeddings
word2vec = api.load('word2vec-google-news-300')

def get_word2vec_embedding(tokens):
    embeddings = []
    for token in tokens:
        if token in word2vec:
            embeddings.append(word2vec[token])
    if len(embeddings) == 0:
        return np.zeros(word2vec.vector_size)
    return np.mean(embeddings, axis=0)

class AGNewsDataset(data.Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        tokens = tokenizer.tokenize(self.texts[idx])
        embedding = get_word2vec_embedding(tokens)
        label = self.labels[idx]
        return torch.tensor(embedding, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

train_texts = [item['text'] for item in dataset['train']]
train_labels = [item['label'] for item in dataset['train']]
test_texts = [item['text'] for item in dataset['test']]
test_labels = [item['label'] for item in dataset['test']]

train_dataset = AGNewsDataset(train_texts, train_labels)
test_dataset = AGNewsDataset(test_texts, test_labels)

train_loader = data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=64, shuffle=False)

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv1d(1, 100, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(2)
        self.fc1 = nn.Linear(100 * (word2vec.vector_size // 2), 4)  # 4 classes in AG News

    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dimension
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

# Initial model and optimizers
model = CNNModel()
criterion = nn.CrossEntropyLoss()
optimizer_net = optim.Adam(model.parameters(), lr=0.001)

# Synthetic data initialization
num_classes = 4
num_synthetic_per_class = 10
max_length = 128

synthetic_text_data = torch.randn(num_classes * num_synthetic_per_class, word2vec.vector_size, requires_grad=True)
synthetic_labels = torch.tensor([i for i in range(num_classes) for _ in range(num_synthetic_per_class)], dtype=torch.long)

# Optimizer for synthetic data
optimizer_syn = optim.SGD([synthetic_text_data], lr=0.01, momentum=0.9)

def compute_gradients(model, inputs, labels):
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    gradients = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    return gradients

def layerwise_matching_loss(gw_syn, gw_real):
    loss = 0
    for g_syn, g_real in zip(gw_syn, gw_real):
        loss += ((g_syn - g_real) ** 2).sum()
    return loss



In [2]:
import torch
import torch.optim as optim
import torch.utils.data as data
from tqdm import tqdm

# Training loop with synthetic data gradient matching
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

def compute_gradients_and_outputs(model, inputs, labels):
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    gradients = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    return gradients, outputs, loss

for epoch in range(num_epochs):
    model.train()
    loss_avg = 0
    train_correct = 0
    train_total = 0

    for real_inputs, real_labels in tqdm(train_loader, desc=f'Epoch [{epoch + 1}/{num_epochs}]', unit='batch'):
        real_inputs, real_labels = real_inputs.to(device), real_labels.to(device)

        # Compute gradients and outputs for real data
        gradients_real, outputs, loss_real = compute_gradients_and_outputs(model, real_inputs, real_labels)
        
        # Compute gradients for synthetic data
        synthetic_data_batch = synthetic_text_data[epoch % num_classes * num_synthetic_per_class: (epoch % num_classes + 1) * num_synthetic_per_class]
        synthetic_labels_batch = synthetic_labels[epoch % num_classes * num_synthetic_per_class: (epoch % num_classes + 1) * num_synthetic_per_class]
        gradients_synthetic, _, _ = compute_gradients_and_outputs(model, synthetic_data_batch.to(device), synthetic_labels_batch.to(device))
        
        # Compute matching loss
        loss_match = layerwise_matching_loss(gradients_synthetic, gradients_real)
        
        # Zero the gradients
        optimizer_syn.zero_grad()
        optimizer_net.zero_grad()

        # Backward pass and optimize
        loss_match.backward(retain_graph=True)
        loss_real.backward()
        optimizer_syn.step()
        optimizer_net.step()

        loss_avg += loss_match.item()

        # Calculate training accuracy
        _, predicted = torch.max(outputs.data, 1)
        train_total += real_labels.size(0)
        train_correct += (predicted == real_labels).sum().item()
    
    train_accuracy = 100 * train_correct / train_total

    # Evaluate model on test set
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc=f'Epoch [{epoch + 1}/{num_epochs}] - Evaluating on Test Set', unit='batch'):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
    
    test_accuracy = 100 * test_correct / test_total
    print(f"Epoch [{epoch + 1}/{num_epochs}], Average Matching Loss: {loss_avg / len(train_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%")

# Train model from scratch using synthetic data
model = CNNModel().to(device)
optimizer_net = optim.Adam(model.parameters(), lr=0.001)

# Create DataLoader for synthetic data
synthetic_dataset = data.TensorDataset(synthetic_text_data.detach(), synthetic_labels)
synthetic_loader = data.DataLoader(synthetic_dataset, batch_size=64, shuffle=True)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for inputs, labels in tqdm(synthetic_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}', unit='batch'):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_net.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_net.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    train_accuracy = 100 * correct / total
    print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')

# Final evaluation on test set
model.eval()
test_labels = []
test_preds = []
with torch.no_grad():
    for texts, labels in tqdm(test_loader, desc='Evaluating on test set', unit='batch'):
        texts, labels = texts.to(device), labels.to(device)
        outputs = model(texts)
        _, preds = torch.max(outputs, 1)
        test_labels.extend(labels.cpu().numpy())
        test_preds.extend(preds.cpu().numpy())

overall_accuracy = accuracy_score(test_labels, test_preds)
class_report = classification_report(test_labels, test_preds, target_names=dataset['test'].features['label'].names)

print(f'Test Accuracy: {overall_accuracy:.4f}')
print('Classification Report:')
print(class_report)


Epoch [1/10]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [02:15<00:00, 13.82batch/s]
Epoch [1/10] - Evaluating on Test Set: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:07<00:00, 16.13batch/s]


Epoch [1/10], Average Matching Loss: 44.0947, Train Accuracy: 26.41%, Test Accuracy: 25.25%


Epoch [2/10]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [02:15<00:00, 13.82batch/s]
Epoch [2/10] - Evaluating on Test Set: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:07<00:00, 15.79batch/s]


Epoch [2/10], Average Matching Loss: 12.6114, Train Accuracy: 25.01%, Test Accuracy: 25.00%


Epoch [3/10]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [02:15<00:00, 13.87batch/s]
Epoch [3/10] - Evaluating on Test Set: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:07<00:00, 16.05batch/s]


Epoch [3/10], Average Matching Loss: 3.3065, Train Accuracy: 25.03%, Test Accuracy: 25.00%


Epoch [4/10]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [02:15<00:00, 13.89batch/s]
Epoch [4/10] - Evaluating on Test Set: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:07<00:00, 16.07batch/s]


Epoch [4/10], Average Matching Loss: 1.7271, Train Accuracy: 24.99%, Test Accuracy: 25.00%


Epoch [5/10]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [02:15<00:00, 13.83batch/s]
Epoch [5/10] - Evaluating on Test Set: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:07<00:00, 15.88batch/s]


Epoch [5/10], Average Matching Loss: 0.9181, Train Accuracy: 25.77%, Test Accuracy: 32.47%


Epoch [6/10]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [02:15<00:00, 13.87batch/s]
Epoch [6/10] - Evaluating on Test Set: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:07<00:00, 16.08batch/s]


Epoch [6/10], Average Matching Loss: 0.4367, Train Accuracy: 40.82%, Test Accuracy: 56.63%


Epoch [7/10]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [02:14<00:00, 13.96batch/s]
Epoch [7/10] - Evaluating on Test Set: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:07<00:00, 16.07batch/s]


Epoch [7/10], Average Matching Loss: 0.1368, Train Accuracy: 43.48%, Test Accuracy: 58.78%


Epoch [9/10]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [02:14<00:00, 13.94batch/s]
Epoch [9/10] - Evaluating on Test Set: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:07<00:00, 16.09batch/s]


Epoch [9/10], Average Matching Loss: 0.6300, Train Accuracy: 29.21%, Test Accuracy: 43.05%


Epoch [10/10]: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [02:14<00:00, 13.92batch/s]
Epoch [10/10] - Evaluating on Test Set: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:07<00:00, 16.12batch/s]


Epoch [10/10], Average Matching Loss: 0.0626, Train Accuracy: 70.20%, Test Accuracy: 78.51%


Training Epoch 1/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 227.79batch/s]


Epoch [1/10], Train Loss: 1.3812, Train Accuracy: 17.50%


Training Epoch 2/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 252.30batch/s]


Epoch [2/10], Train Loss: 1.3514, Train Accuracy: 55.00%


Training Epoch 3/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 272.76batch/s]


Epoch [3/10], Train Loss: 1.3556, Train Accuracy: 40.00%


Training Epoch 4/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 281.63batch/s]


Epoch [4/10], Train Loss: 0.4338, Train Accuracy: 92.50%


Training Epoch 5/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 281.53batch/s]


Epoch [5/10], Train Loss: 0.5421, Train Accuracy: 75.00%


Training Epoch 6/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 280.27batch/s]


Epoch [6/10], Train Loss: 0.6750, Train Accuracy: 75.00%


Training Epoch 7/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 265.77batch/s]


Epoch [7/10], Train Loss: 0.4658, Train Accuracy: 75.00%


Training Epoch 8/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 253.37batch/s]


Epoch [8/10], Train Loss: 0.1819, Train Accuracy: 100.00%


Training Epoch 9/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 263.68batch/s]


Epoch [9/10], Train Loss: 0.0922, Train Accuracy: 100.00%


Training Epoch 10/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 254.76batch/s]


Epoch [10/10], Train Loss: 0.1277, Train Accuracy: 100.00%


Evaluating on test set: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 119/119 [00:07<00:00, 15.99batch/s]


Test Accuracy: 0.2500
Classification Report:
              precision    recall  f1-score   support

       World       0.00      0.00      0.00      1900
      Sports       0.00      0.00      0.00      1900
    Business       0.25      1.00      0.40      1900
    Sci/Tech       0.00      0.00      0.00      1900

    accuracy                           0.25      7600
   macro avg       0.06      0.25      0.10      7600
weighted avg       0.06      0.25      0.10      7600



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [4]:
synthetic_data_batch.shape

torch.Size([10, 300])

In [6]:
synthetic_labels_batch

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [9]:
synthetic_text_data.shape

torch.Size([40, 300])

In [11]:
synthetic_labels

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

In [17]:
import pandas as pd  
pd.Series(train_dataset.labels).value_counts()

2    30000
3    30000
1    30000
0    30000
Name: count, dtype: int64

In [18]:
preds

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       device='cuda:0')