In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from tqdm import tqdm
from transformers import BertTokenizer
from gensim.models import KeyedVectors
from datasets import load_dataset
from sklearn.metrics import classification_report, accuracy_score
import numpy as np
import gensim.downloader as api

# Load AG News Dataset
dataset = load_dataset('ag_news')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize(text):
    return tokenizer(text, padding='max_length', truncation=True, return_tensors='pt', max_length=128)

# Load Word2Vec Embeddings
word2vec = api.load('word2vec-google-news-300')

def get_word2vec_embedding(tokens):
    embeddings = []
    for token in tokens:
        if token in word2vec:
            embeddings.append(word2vec[token])
    if len(embeddings) == 0:
        return np.zeros(word2vec.vector_size)
    return np.mean(embeddings, axis=0)

class AGNewsDataset(data.Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        tokens = tokenizer.tokenize(self.texts[idx])
        embedding = get_word2vec_embedding(tokens)
        label = self.labels[idx]
        return torch.tensor(embedding, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

train_texts = [item['text'] for item in dataset['train']]
train_labels = [item['label'] for item in dataset['train']]
test_texts = [item['text'] for item in dataset['test']]
test_labels = [item['label'] for item in dataset['test']]

train_dataset = AGNewsDataset(train_texts, train_labels)
test_dataset = AGNewsDataset(test_texts, test_labels)

train_loader = data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=64, shuffle=False)

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv1d(1, 100, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(2)
        self.fc1 = nn.Linear(100 * (word2vec.vector_size // 2), 4)  # 4 classes in AG News

    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dimension
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

# Initial model and optimizers
model = CNNModel()
criterion = nn.CrossEntropyLoss()
optimizer_net = optim.Adam(model.parameters(), lr=0.001)

# Synthetic data initialization
num_classes = 4
num_synthetic_per_class = 10
max_length = 128

synthetic_text_data = torch.randn(num_classes * num_synthetic_per_class, word2vec.vector_size, requires_grad=True)
synthetic_labels = torch.tensor([i for i in range(num_classes) for _ in range(num_synthetic_per_class)], dtype=torch.long)

# Optimizer for synthetic data
optimizer_syn = optim.SGD([synthetic_text_data], lr=0.01, momentum=0.9)

def compute_gradients(model, inputs, labels):
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    gradients = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    return gradients

def layerwise_matching_loss(gw_syn, gw_real):
    loss = 0
    for g_syn, g_real in zip(gw_syn, gw_real):
        loss += ((g_syn - g_real) ** 2).sum()
    return loss



In [None]:
from tqdm import tqdm
import torch
import torch.optim as optim
import torch.utils.data as data

# Training loop with synthetic data gradient matching
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(num_epochs):
    model.train()
    loss_avg = 0
    train_correct = 0
    train_total = 0

    # Update synthetic data
    for real_inputs, real_labels in tqdm(train_loader, desc=f'Epoch [{epoch + 1}/{num_epochs}] - Updating Synthetic Data', unit='batch'):
        real_inputs, real_labels = real_inputs.to(device), real_labels.to(device)

        # Compute gradients for real data
        gradients_real = compute_gradients(model, real_inputs, real_labels)
        
        # Compute gradients for synthetic data
        synthetic_data_batch = synthetic_text_data[epoch % num_classes * num_synthetic_per_class: (epoch % num_classes + 1) * num_synthetic_per_class]
        synthetic_labels_batch = synthetic_labels[epoch % num_classes * num_synthetic_per_class: (epoch % num_classes + 1) * num_synthetic_per_class]
        gradients_synthetic = compute_gradients(model, synthetic_data_batch.to(device), synthetic_labels_batch.to(device))
        
        # Compute and minimize matching loss
        loss_match = layerwise_matching_loss(gradients_synthetic, gradients_real)
        optimizer_syn.zero_grad()
        optimizer_net.zero_grad()
        loss_match.backward()
        optimizer_syn.step()
        optimizer_net.step()
        loss_avg += loss_match.item()

        # Calculate training accuracy
        outputs = model(real_inputs)
        _, predicted = torch.max(outputs.data, 1)
        train_total += real_labels.size(0)
        train_correct += (predicted == real_labels).sum().item()
    
    train_accuracy = 100 * train_correct / train_total

    # Evaluate model on test set
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc=f'Epoch [{epoch + 1}/{num_epochs}] - Evaluating on Test Set', unit='batch'):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
    
    test_accuracy = 100 * test_correct / test_total
    print(f"Epoch [{epoch + 1}/{num_epochs}], Average Matching Loss: {loss_avg / len(train_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%")

# Train model from scratch using synthetic data
model = CNNModel().to(device)
optimizer_net = optim.Adam(model.parameters(), lr=0.001)

# Create DataLoader for synthetic data
synthetic_dataset = data.TensorDataset(synthetic_text_data.detach(), synthetic_labels)
synthetic_loader = data.DataLoader(synthetic_dataset, batch_size=64, shuffle=True)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for inputs, labels in tqdm(synthetic_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}', unit='batch'):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_net.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_net.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    train_accuracy = 100 * correct / total
    print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')

# Final evaluation on test set
model.eval()
test_labels = []
test_preds = []
with torch.no_grad():
    for texts, labels in tqdm(test_loader, desc='Evaluating on test set', unit='batch'):
        texts, labels = texts.to(device), labels.to(device)
        outputs = model(texts)
        _, preds = torch.max(outputs, 1)
        test_labels.extend(labels.cpu().numpy())
        test_preds.extend(preds.cpu().numpy())

overall_accuracy = accuracy_score(test_labels, test_preds)
class_report = classification_report(test_labels, test_preds, target_names=dataset['test'].features['label'].names)

print(f'Test Accuracy: {overall_accuracy:.4f}')
print('Classification Report:')
print(class_report)


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from tqdm import tqdm
from transformers import BertTokenizer
from gensim.models import KeyedVectors
from datasets import load_dataset
from sklearn.metrics import classification_report, accuracy_score
import numpy as np
import gensim.downloader as api

# Load AG News Dataset
dataset = load_dataset('ag_news')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize(text):
    return tokenizer(text, padding='max_length', truncation=True, return_tensors='pt', max_length=128)

# Load Word2Vec Embeddings
word2vec = api.load('word2vec-google-news-300')

def get_word2vec_embedding(tokens):
    embeddings = []
    for token in tokens:
        if token in word2vec:
            embeddings.append(word2vec[token])
    if len(embeddings) == 0:
        return np.zeros(word2vec.vector_size)
    return np.mean(embeddings, axis=0)

class AGNewsDataset(data.Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        tokens = tokenizer.tokenize(self.texts[idx])
        embedding = get_word2vec_embedding(tokens)
        label = self.labels[idx]
        return torch.tensor(embedding, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

train_texts = [item['text'] for item in dataset['train']]
train_labels = [item['label'] for item in dataset['train']]
test_texts = [item['text'] for item in dataset['test']]
test_labels = [item['label'] for item in dataset['test']]

train_dataset = AGNewsDataset(train_texts, train_labels)
test_dataset = AGNewsDataset(test_texts, test_labels)

train_loader = data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=64, shuffle=False)

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv1d(1, 100, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(2)
        self.fc1 = nn.Linear(100 * (word2vec.vector_size // 2), 4)  # 4 classes in AG News

    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dimension
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

model = CNNModel()
criterion = nn.CrossEntropyLoss()
optimizer_net = optim.Adam(model.parameters(), lr=0.001)

# Synthetic data initialization
num_classes = 4
num_synthetic_per_class = 10
max_length = 128

synthetic_text_data = torch.randn(num_classes * num_synthetic_per_class, word2vec.vector_size, requires_grad=True)
synthetic_labels = torch.tensor([i for i in range(num_classes) for _ in range(num_synthetic_per_class)], dtype=torch.long)

# Optimizer for synthetic data
optimizer_syn = optim.SGD([synthetic_text_data], lr=0.01, momentum=0.9)

def compute_gradients(model, inputs, labels):
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    gradients = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    return gradients

def layerwise_matching_loss(gw_syn, gw_real):
    loss = 0
    for g_syn, g_real in zip(gw_syn, gw_real):
        loss += ((g_syn - g_real) ** 2).sum()
    return loss

# Training loop with synthetic data gradient matching
num_iterations = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

CNNModel(
  (conv1): Conv1d(1, 100, kernel_size=(3,), stride=(1,), padding=(1,))
  (relu): ReLU()
  (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=15000, out_features=4, bias=True)
)

In [3]:
from tqdm import tqdm
num_iterations = 10
for iteration in tqdm(range(num_iterations)):
    model.train()
    loss_avg = 0

    # Update synthetic data
    for real_inputs, real_labels in tqdm(train_loader):
        real_inputs, real_labels = real_inputs.to(device), real_labels.to(device)

        # Compute gradients for real data
        gradients_real = compute_gradients(model, real_inputs, real_labels)
        
        # Compute gradients for synthetic data
        synthetic_data_batch = synthetic_text_data[iteration % num_classes * num_synthetic_per_class: (iteration % num_classes + 1) * num_synthetic_per_class]
        synthetic_labels_batch = synthetic_labels[iteration % num_classes * num_synthetic_per_class: (iteration % num_classes + 1) * num_synthetic_per_class]
        gradients_synthetic = compute_gradients(model, synthetic_data_batch.to(device), synthetic_labels_batch.to(device))
        
        # Compute and minimize matching loss
        loss_match = layerwise_matching_loss(gradients_synthetic, gradients_real)
        optimizer_syn.zero_grad()
        loss_match.backward()
        optimizer_syn.step()
        loss_avg += loss_match.item()

    # Print loss every 10 iterations
    if iteration % 1 == 0:
        print(f"Iteration {iteration}, Average Matching Loss: {loss_avg / len(train_loader):.4f}")

# Train model from scratch using synthetic data
model = CNNModel().to(device)
optimizer_net = optim.Adam(model.parameters(), lr=0.001)



  0%|                                                                                                                                               | 0/100 [00:00<?, ?it/s]
  0%|                                                                                                                                              | 0/1875 [00:00<?, ?it/s][A
  0%|                                                                                                                                      | 1/1875 [00:00<13:53,  2.25it/s][A
  0%|▏                                                                                                                                     | 2/1875 [00:00<08:03,  3.88it/s][A
  0%|▎                                                                                                                                     | 4/1875 [00:00<04:53,  6.38it/s][A
  0%|▍                                                                                                                     

KeyboardInterrupt: 

In [None]:
# Create DataLoader for synthetic data
synthetic_dataset = data.TensorDataset(synthetic_text_data.detach(), synthetic_labels)
synthetic_loader = data.DataLoader(synthetic_dataset, batch_size=64, shuffle=True)

num_epochs = 10

for epoch in tqdm(range(num_epochs)):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for inputs, labels in tqdm(synthetic_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}'):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer_net.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_net.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    train_accuracy = 100 * correct / total
    print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')

# Final evaluation on test set
model.eval()
test_labels = []
test_preds = []
with torch.no_grad():
    for texts, labels in tqdm(test_loader, desc='Evaluating on test set', unit='batch'):
        texts, labels = texts.to(device), labels.to(device)
        outputs = model(texts)
        _, preds = torch.max(outputs, 1)
        test_labels.extend(labels.cpu().numpy())
        test_preds.extend(preds.cpu().numpy())

overall_accuracy = accuracy_score(test_labels, test_preds)
class_report = classification_report(test_labels, test_preds, target_names=dataset['test'].features['label'].names)

print(f'Test Accuracy: {overall_accuracy:.4f}')
print('Classification Report:')
print(class_report)


## OLD CODE

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.metrics import accuracy_score, classification_report
from datasets import load_dataset
import nltk
from nltk.tokenize import word_tokenize
import numpy as np
from tqdm import tqdm
from gensim.models import KeyedVectors
import gensim.downloader as api

# Load the Word2Vec model
word2vec_model = api.load('word2vec-google-news-300')

# Download NLTK data
nltk.download('punkt')

# Load the AG News dataset
dataset = load_dataset("ag_news")

# NLTK Tokenizer Function
def nltk_tokenizer(text):
    return word_tokenize(text.lower())

# Convert to PyTorch tensors
class AGNewsDataset(Dataset):
    def __init__(self, texts, labels, vocab, max_length):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        tokenized_text = [self.vocab.get(token, self.vocab['<UNK>']) for token in text]
        if len(tokenized_text) < self.max_length:
            tokenized_text += [self.vocab['<PAD>']] * (self.max_length - len(tokenized_text))
        else:
            tokenized_text = tokenized_text[:self.max_length]
        return torch.tensor(tokenized_text, dtype=torch.long), torch.tensor(label, dtype=torch.long)

# Build vocabulary
def build_vocab(dataset, tokenizer):
    vocab = {'<PAD>': 0, '<UNK>': 1}
    for example in dataset:
        tokens = tokenizer(example['text'])
        for token in tokens:
            if token not in vocab:
                vocab[token] = len(vocab)
    return vocab

# Tokenize and build vocab
train_texts = [nltk_tokenizer(example['text']) for example in dataset['train']]
train_labels = [example['label'] for example in dataset['train']]
vocab = build_vocab(dataset['train'], nltk_tokenizer)

# Set max length for padding
max_length = 128

# Create dataset
full_dataset = AGNewsDataset(train_texts, train_labels, vocab, max_length)

# Split training set into training and validation sets
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Prepare test dataset
test_texts = [nltk_tokenizer(example['text']) for example in dataset['test']]
test_labels = [example['label'] for example in dataset['test']]
test_dataset = AGNewsDataset(test_texts, test_labels, vocab, max_length)

# Initialize embedding matrix
def build_embedding_matrix(vocab, word2vec_model, embedding_dim):
    embedding_matrix = np.zeros((len(vocab), embedding_dim))
    for word, idx in vocab.items():
        if word in word2vec_model:
            embedding_matrix[idx] = word2vec_model[word]
        else:
            embedding_matrix[idx] = np.random.normal(size=(embedding_dim,))
    return torch.tensor(embedding_matrix, dtype=torch.float32)

# Build the embedding matrix
embedding_dim = 300  # Word2Vec uses 300-dimensional vectors
embedding_matrix = build_embedding_matrix(vocab, word2vec_model, embedding_dim)

# Define CNN model
class CNNModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, output_dim, embedding_matrix):
        super(CNNModel, self).__init__()
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)
        self.conv1 = nn.Conv2d(1, 100, (3, embedding_dim))
        self.conv2 = nn.Conv2d(1, 100, (4, embedding_dim))
        self.conv3 = nn.Conv2d(1, 100, (5, embedding_dim))
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(300, output_dim)
    
    def forward(self, x):
        x = self.embedding(x).unsqueeze(1)  # Add channel dimension
        x1 = torch.relu(self.conv1(x)).squeeze(3)
        x1 = torch.max_pool1d(x1, x1.size(2)).squeeze(2)
        x2 = torch.relu(self.conv2(x)).squeeze(3)
        x2 = torch.max_pool1d(x2, x2.size(2)).squeeze(2)
        x3 = torch.relu(self.conv3(x)).squeeze(3)
        x3 = torch.max_pool1d(x3, x3.size(2)).squeeze(2)
        x = torch.cat((x1, x2, x3), 1)
        x = self.dropout(x)
        return self.fc(x)

# Training Parameters
BATCH_SIZE = 64
EPOCHS = 3
OUTPUT_DIM = 4
LR = 0.001
num_iterations = 100

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model, loss function, and optimizer
model = CNNModel(len(vocab), embedding_dim, OUTPUT_DIM, embedding_matrix).to(device)
criterion = nn.CrossEntropyLoss()
optimizer_net = optim.Adam(model.parameters(), lr=LR)

# Synthetic data initialization
num_classes = 4
num_synthetic_per_class = 10

synthetic_text_data = torch.randint(0, len(vocab), (num_classes * num_synthetic_per_class, max_length), dtype=torch.long, device=device, requires_grad=True)
synthetic_labels = torch.tensor([i for i in range(num_classes) for _ in range(num_synthetic_per_class)], dtype=torch.long, device=device)

# Optimizer for synthetic data
optimizer_syn = optim.SGD([synthetic_text_data], lr=0.01, momentum=0.9)

def compute_gradients(model, inputs, labels):
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    gradients = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    return gradients

def layerwise_matching_loss(gw_syn, gw_real):
    loss = 0
    for g_syn, g_real in zip(gw_syn, gw_real):
        loss += ((g_syn - g_real) ** 2).sum()
    return loss

# Training loop with synthetic data gradient matching
for iteration in range(num_iterations):
    model.train()
    loss_avg = 0

    # Update synthetic data
    for real_inputs, real_labels in train_loader:
        real_inputs, real_labels = real_inputs.to(device), real_labels.to(device)

        # Compute gradients for real data
        gradients_real = compute_gradients(model, real_inputs, real_labels)
        
        # Compute gradients for synthetic data
        synthetic_data_batch = synthetic_text_data[iteration % num_classes * num_synthetic_per_class: (iteration % num_classes + 1) * num_synthetic_per_class]
        synthetic_labels_batch = synthetic_labels[iteration % num_classes * num_synthetic_per_class: (iteration % num_classes + 1) * num_synthetic_per_class]
        gradients_synthetic = compute_gradients(model, synthetic_data_batch, synthetic_labels_batch)
        
        # Compute and minimize matching loss
        loss_match = layerwise_matching_loss(gradients_synthetic, gradients_real)
        optimizer_syn.zero_grad()
        loss_match.backward()
        optimizer_syn.step()
        loss_avg += loss_match.item()

    # Print loss every 10 iterations
    if iteration % 10 == 0:
        print(f"Iteration {iteration}, Average Matching Loss: {loss_avg / len(train_loader):.4f}")

# Final evaluation on test set
model.eval()
test_labels = []
test_preds = []
with torch.no_grad():
    for texts, labels in tqdm(test_loader, desc='Evaluating on test set', unit='batch'):
        texts, labels = texts.to(device), labels.to(device)
        outputs = model(texts)
        _, preds = torch.max(outputs, 1)
        test_labels.extend(labels.cpu().numpy())
        test_preds.extend(preds.cpu().numpy())

overall_accuracy = accuracy_score(test_labels, test_preds)
class_report = classification_report(test_labels, test_preds, target_names=['World', 'Sports', 'Business', 'Sci/Tech'])

print(f'Test Accuracy: {overall_accuracy:.4f}')
print('Classification Report:')
print(class_report)


[nltk_data] Downloading package punkt to /home/IAIS/rrao/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


RuntimeError: Only Tensors of floating point and complex dtype can require gradients

In [2]:
# Synthetic data initialization
num_classes = 4
num_synthetic_per_class = 10

# Initialize synthetic text data with integer type
synthetic_text_data_int = torch.randint(0, len(vocab), (num_classes * num_synthetic_per_class, max_length), dtype=torch.long, device=device)
# Convert to float and set requires_grad=True
synthetic_text_data = synthetic_text_data_int.float().requires_grad_(True)
synthetic_labels = torch.tensor([i for i in range(num_classes) for _ in range(num_synthetic_per_class)], dtype=torch.long, device=device)

# Optimizer for synthetic data
optimizer_syn = optim.SGD([synthetic_text_data], lr=0.01, momentum=0.9)

def compute_gradients(model, inputs, labels):
    outputs = model(inputs.long())  # Convert inputs back to long for embedding lookup
    loss = criterion(outputs, labels)
    gradients = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    return gradients

# Training loop with synthetic data gradient matching
for iteration in range(num_iterations):
    model.train()
    loss_avg = 0

    # Update synthetic data
    for real_inputs, real_labels in train_loader:
        real_inputs, real_labels = real_inputs.to(device), real_labels.to(device)

        # Compute gradients for real data
        gradients_real = compute_gradients(model, real_inputs, real_labels)
        
        # Compute gradients for synthetic data
        synthetic_data_batch = synthetic_text_data[iteration % num_classes * num_synthetic_per_class: (iteration % num_classes + 1) * num_synthetic_per_class]
        synthetic_labels_batch = synthetic_labels[iteration % num_classes * num_synthetic_per_class: (iteration % num_classes + 1) * num_synthetic_per_class]
        gradients_synthetic = compute_gradients(model, synthetic_data_batch, synthetic_labels_batch)
        
        # Compute and minimize matching loss
        loss_match = layerwise_matching_loss(gradients_synthetic, gradients_real)
        optimizer_syn.zero_grad()
        loss_match.backward()
        optimizer_syn.step()
        loss_avg += loss_match.item()

    # Print loss every 10 iterations
    if iteration % 10 == 0:
        print(f"Iteration {iteration}, Average Matching Loss: {loss_avg / len(train_loader):.4f}")

# Final evaluation on test set
model.eval()
test_labels = []
test_preds = []
with torch.no_grad():
    for texts, labels in tqdm(test_loader, desc='Evaluating on test set', unit='batch'):
        texts, labels = texts.to(device), labels.to(device)
        outputs = model(texts)
        _, preds = torch.max(outputs, 1)
        test_labels.extend(labels.cpu().numpy())
        test_preds.extend(preds.cpu().numpy())

overall_accuracy = accuracy_score(test_labels, test_preds)
class_report = classification_report(test_labels, test_preds, target_names=['World', 'Sports', 'Business', 'Sci/Tech'])

print(f'Test Accuracy: {overall_accuracy:.4f}')
print('Classification Report:')
print(class_report)


NameError: name 'layerwise_matching_loss' is not defined