In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gensim.downloader as api
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from tqdm import tqdm
import pickle
import random

print("Loading Word2Vec embeddings...")
word2vec = api.load('word2vec-google-news-300')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load AG_NEWS dataset using Hugging Face datasets library
print("Loading AG_NEWS dataset...")
dataset = load_dataset('ag_news')

# Tokenizer function
def tokenize(text):
    return text.lower().split()

# Build vocabulary
def build_vocab(texts, tokenizer):
    vocab = {"<unk>": 0}
    for text in texts:
        for token in tokenizer(text):
            if token not in vocab:
                vocab[token] = len(vocab)
    return vocab

# Custom Dataset class
class AGNewsDataset(Dataset):
    def __init__(self, texts, labels, vocab, tokenizer):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        tokenized_text = [self.vocab[token] if token in self.vocab else self.vocab["<unk>"]
                          for token in self.tokenizer(text)]
        return torch.tensor(tokenized_text, dtype=torch.long), torch.tensor(label, dtype=torch.long)

def pad_sequences(sequences, max_len=None, pad_value=0):
    if not max_len:
        max_len = max(len(seq) for seq in sequences)
    padded_seqs = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
    for i, seq in enumerate(sequences):
        length = len(seq)
        padded_seqs[i, :length] = seq
    return padded_seqs

# Extract texts and labels
train_texts = [item['text'] for item in dataset['train']]
train_labels = [item['label'] for item in dataset['train']]
test_texts = [item['text'] for item in dataset['test']]
test_labels = [item['label'] for item in dataset['test']]

# Build vocabulary from training texts
print("Building vocabulary...")
vocab = build_vocab(train_texts, tokenize)

# Create Dataset and DataLoader
train_dataset = AGNewsDataset(train_texts, train_labels, vocab, tokenize)
test_dataset = AGNewsDataset(test_texts, test_labels, vocab, tokenize)

def collate_fn(batch):
    texts, labels = zip(*batch)
    texts_padded = pad_sequences(texts)
    labels = torch.tensor(labels, dtype=torch.long)
    return texts_padded, labels

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Create embedding matrix
print("Creating embeddings matrix...")
embedding_dim = 300
embedding_matrix = np.zeros((len(vocab), embedding_dim))
for word, idx in vocab.items():
    if word in word2vec:
        embedding_matrix[idx] = word2vec[word]
    else:
        embedding_matrix[idx] = np.random.normal(scale=0.6, size=(embedding_dim,))

embedding_matrix = torch.tensor(embedding_matrix, dtype=torch.float32).to(device)

# Define the LSTM model
class LSTMClassifier(nn.Module):
    def __init__(self, embedding_matrix, hidden_dim, output_dim, n_layers, drop_prob=0.5):
        super(LSTMClassifier, self).__init__()
        num_embeddings, embedding_dim = embedding_matrix.shape
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)  # Freeze embeddings
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, dropout=drop_prob, batch_first=True)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.embedding(x)
        lstm_out, (ht, ct) = self.lstm(x)
        out = self.dropout(ht[-1])
        out = self.fc(out)
        return out

hidden_dim = 256
output_dim = 4  # AG_NEWS has 4 classes
n_layers = 2
model = LSTMClassifier(embedding_matrix, hidden_dim, output_dim, n_layers).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def evaluate_syn(model, synthetic_text_data, synthetic_labels, test_loader, device, num_epochs=10, lr=0.001):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer_net = optim.Adam(model.parameters(), lr=lr)

    synthetic_dataset = DataLoader(torch.utils.data.TensorDataset(synthetic_text_data.detach(), synthetic_labels), batch_size=32, shuffle=True)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        for inputs, labels in tqdm(synthetic_dataset, desc=f'Training Epoch {epoch+1}/{num_epochs}', unit='batch'):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer_net.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer_net.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_accuracy = 100 * correct / total
        if epoch % 9 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')

    model.eval()
    total_acc, total_count = 0, 0
    with torch.no_grad():
        for text, labels in tqdm(test_loader, desc='Evaluating on test set', unit='batch'):
            text, labels = text.to(device), labels.to(device)
            output = model(text)
            _, predicted = torch.max(output, 1)
            total_acc += (predicted == labels).sum().item()
            total_count += labels.size(0)

    overall_accuracy = total_acc / total_count
    print(f'Test Accuracy: {overall_accuracy:.4f}')

    results = {
        'test_accuracy': overall_accuracy,
    }
    with open('results.pkl', 'wb') as f:
        pickle.dump(results, f)

    return overall_accuracy

n_experiments = 1
num_classes = len(set(train_labels))
n_spc = 10  # samples_per_class
n_iterations = 1000
batch_train = 32
batch_real = 32
batch_syn = 32
n_outer_loop = 1
n_inner_loop = 1
num_eval_epochs = 20

data_save = []
accs_hist = []

def get_class_indices(labels, n_classes):
    class_indices = [[] for _ in range(n_classes)]
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)
    return class_indices

train_labels_list = [item['label'] for item in dataset['train']]
class_indices = get_class_indices(train_labels_list, num_classes)

for exp_n in range(n_experiments):
    print(f'\n================== Exp {exp_n} ==================\n')

    # Randomly select synthetic data from train dataset per class
    synthetic_text_data = []
    synthetic_labels = []
    for class_idx in range(num_classes):
        selected_indices = random.sample(class_indices[class_idx], n_spc)
        for idx in selected_indices:
            synthetic_text_data.append(train_dataset[idx][0])
            synthetic_labels.append(train_dataset[idx][1])
    synthetic_text_data = pad_sequences(synthetic_text_data).to(device)
    synthetic_labels = torch.tensor(synthetic_labels, dtype=torch.long).to(device)

    optimizer_syn = optim.SGD([synthetic_text_data], lr=0.01, momentum=0.9)
    optimizer_syn.zero_grad()
    criterion = nn.CrossEntropyLoss().to(device)

    print('Training begins')

    for it in range(n_iterations + 1):
        model = LSTMClassifier(embedding_matrix, hidden_dim, output_dim, n_layers).to(device)
        model.train()
        model_parameters = list(model.parameters())
        optimizer_net = optim.SGD(model.parameters(), lr=0.01)
        optimizer_net.zero_grad()
        loss_avg = 0

        accs_hist.append(evaluate_syn(model, synthetic_text_data, synthetic_labels, test_loader, device, num_epochs=10, lr=0.001))

        for ol in tqdm(range(n_outer_loop)):
            print('Starting outer loop')
            loss = torch.tensor(0.0).to(device)

            for c in range(num_classes):
                text_real = next(iter(train_loader))[0]
                lab_real = torch.ones((text_real.shape[0],), device=device, dtype=torch.long) * c
                text_syn = synthetic_text_data[c * n_spc:(c + 1) * n_spc].to(device)
                lab_syn = torch.ones((n_spc,), device=device, dtype=torch.long) * c
                output_real = model(text_real)
                loss_real = criterion(output_real, lab_real)
                gw_real = torch.autograd.grad(loss_real, model_parameters)
                gw_real = [_.detach().clone() for _ in gw_real]
                output_syn = model(text_syn)
                loss_syn = criterion(output_syn, lab_syn)
                gw_syn = torch.autograd.grad(loss_syn, model_parameters, create_graph=True)
                loss += sum(torch.nn.functional.mse_loss(gw_real_, gw_syn_) for gw_real_, gw_syn_ in zip(gw_real, gw_syn))

            optimizer_syn.zero_grad()
            loss.backward()
            optimizer_syn.step()
            loss_avg += loss.item()

            text_syn_train, label_syn_train = synthetic_text_data.detach(), synthetic_labels.detach()
            dst_syn_train = DataLoader(torch.utils.data.TensorDataset(text_syn_train, label_syn_train), batch_size=batch_train, shuffle=True)

            for il in tqdm(range(n_inner_loop)):
                epoch_loss, epoch_acc = 0, 0
                model.train()
                for batch_text, batch_labels in dst_syn_train:
                    batch_text, batch_labels = batch_text.to(device), batch_labels.to(device)
                    optimizer_net.zero_grad()
                    outputs = model(batch_text)
                    loss = criterion(outputs, batch_labels)
                    loss.backward()
                    optimizer_net.step()
                    epoch_loss += loss.item()
                    _, preds = torch.max(outputs, 1)
                    epoch_acc += torch.sum(preds == batch_labels).item()

                epoch_loss /= len(dst_syn_train.dataset)
                epoch_acc /= len(dst_syn_train.dataset)
                print(f'Inner loop {il + 1}/{n_inner_loop}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

        loss_avg /= (num_classes * n_outer_loop)
        if it % 10 == 0:
            print(f'Iter = {it:04d}, Loss = {loss_avg:.4f}')

    data_save.append([synthetic_text_data.detach().cpu(), synthetic_labels.detach().cpu()])
    torch.save({'data': data_save}, 'results_synthetic_data.pth')


Loading Word2Vec embeddings...
Using device: cuda
Loading AG_NEWS dataset...
Building vocabulary...
Creating embeddings matrix...


Training begins


Training Epoch 1/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 106.70batch/s]


Epoch [1/10], Train Loss: 2.8262, Train Accuracy: 17.50%


Training Epoch 2/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 105.09batch/s]
Training Epoch 3/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 104.23batch/s]
Training Epoch 4/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 105.34batch/s]
Training Epoch 5/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 105.16batch/s]
Training Epoch 6/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 98.76batch/s]
Training Epoch 7/10: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████

Epoch [10/10], Train Loss: 2.8152, Train Accuracy: 32.50%


Evaluating on test set: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 238/238 [00:01<00:00, 194.55batch/s]


Test Accuracy: 0.2537


  0%|                                                                                                                                                 | 0/1 [00:00<?, ?it/s]


Starting outer loop


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gensim.downloader as api
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from tqdm import tqdm
import pickle
import random

print("Loading Word2Vec embeddings...")
word2vec = api.load('word2vec-google-news-300')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load AG_NEWS dataset using Hugging Face datasets library
print("Loading AG_NEWS dataset...")
dataset = load_dataset('ag_news')

# Tokenizer function
def tokenize(text):
    return text.lower().split()

# Build vocabulary
def build_vocab(texts, tokenizer):
    vocab = {"<unk>": 0}
    for text in texts:
        for token in tokenizer(text):
            if token not in vocab:
                vocab[token] = len(vocab)
    return vocab

# Custom Dataset class
class AGNewsDataset(Dataset):
    def __init__(self, texts, labels, vocab, tokenizer):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        tokenized_text = [self.vocab[token] if token in self.vocab else self.vocab["<unk>"]
                          for token in self.tokenizer(text)]
        return torch.tensor(tokenized_text, dtype=torch.long), torch.tensor(label, dtype=torch.long)

def pad_sequences(sequences, max_len=None, pad_value=0):
    if not max_len:
        max_len = max(len(seq) for seq in sequences)
    padded_seqs = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
    for i, seq in enumerate(sequences):
        length = len(seq)
        padded_seqs[i, :length] = seq
    return padded_seqs

# Extract texts and labels
train_texts = [item['text'] for item in dataset['train']]
train_labels = [item['label'] for item in dataset['train']]
test_texts = [item['text'] for item in dataset['test']]
test_labels = [item['label'] for item in dataset['test']]

# Build vocabulary from training texts
print("Building vocabulary...")
vocab = build_vocab(train_texts, tokenize)

# Create Dataset and DataLoader
train_dataset = AGNewsDataset(train_texts, train_labels, vocab, tokenize)
test_dataset = AGNewsDataset(test_texts, test_labels, vocab, tokenize)

def collate_fn(batch):
    texts, labels = zip(*batch)
    texts_padded = pad_sequences(texts)
    labels = torch.tensor(labels, dtype=torch.long)
    return texts_padded, labels

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Create embedding matrix
print("Creating embeddings matrix...")
embedding_dim = 300
embedding_matrix = np.zeros((len(vocab), embedding_dim))
for word, idx in vocab.items():
    if word in word2vec:
        embedding_matrix[idx] = word2vec[word]
    else:
        embedding_matrix[idx] = np.random.normal(scale=0.6, size=(embedding_dim,))

embedding_matrix = torch.tensor(embedding_matrix, dtype=torch.float32).to(device)

# Define the LSTM model
class LSTMClassifier(nn.Module):
    def __init__(self, embedding_matrix, hidden_dim, output_dim, n_layers, drop_prob=0.5):
        super(LSTMClassifier, self).__init__()
        num_embeddings, embedding_dim = embedding_matrix.shape
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)  # Freeze embeddings
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, dropout=drop_prob, batch_first=True)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.embedding(x)
        lstm_out, (ht, ct) = self.lstm(x)
        out = self.dropout(ht[-1])
        out = self.fc(out)
        return out

hidden_dim = 256
output_dim = 4  # AG_NEWS has 4 classes
n_layers = 2
model = LSTMClassifier(embedding_matrix, hidden_dim, output_dim, n_layers).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def evaluate_syn(model, synthetic_text_data, synthetic_labels, test_loader, device, num_epochs=10, lr=0.001):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer_net = optim.Adam(model.parameters(), lr=lr)

    synthetic_dataset = DataLoader(torch.utils.data.TensorDataset(synthetic_text_data.detach(), synthetic_labels), batch_size=32, shuffle=True)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        for inputs, labels in tqdm(synthetic_dataset, desc=f'Training Epoch {epoch+1}/{num_epochs}', unit='batch'):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer_net.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer_net.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_accuracy = 100 * correct / total
        if epoch % 9 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')

    model.eval()
    total_acc, total_count = 0, 0
    with torch.no_grad():
        for text, labels in tqdm(test_loader, desc='Evaluating on test set', unit='batch'):
            text, labels = text.to(device), labels.to(device)
            output = model(text)
            _, predicted = torch.max(output, 1)
            total_acc += (predicted == labels).sum().item()
            total_count += labels.size(0)

    overall_accuracy = total_acc / total_count
    print(f'Test Accuracy: {overall_accuracy:.4f}')

    results = {
        'test_accuracy': overall_accuracy,
    }
    with open('results.pkl', 'wb') as f:
        pickle.dump(results, f)

    return overall_accuracy

n_experiments = 1
num_classes = len(set(train_labels))
n_spc = 10  # samples_per_class
n_iterations = 1000
batch_train = 32
batch_real = 32
batch_syn = 32
n_outer_loop = 1
n_inner_loop = 1
num_eval_epochs = 20

data_save = []
accs_hist = []

def get_class_indices(labels, n_classes):
    class_indices = [[] for _ in range(n_classes)]
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)
    return class_indices

train_labels_list = [item['label'] for item in dataset['train']]
class_indices = get_class_indices(train_labels_list, num_classes)

for exp_n in range(n_experiments):
    print(f'\n================== Exp {exp_n} ==================\n')

    # Randomly select synthetic data from train dataset per class
    synthetic_text_data = []
    synthetic_labels = []
    for class_idx in range(num_classes):
        selected_indices = random.sample(class_indices[class_idx], n_spc)
        for idx in selected_indices:
            synthetic_text_data.append(train_dataset[idx][0])
            synthetic_labels.append(train_dataset[idx][1])
    synthetic_text_data = pad_sequences(synthetic_text_data).to(device)
    synthetic_labels = torch.tensor(synthetic_labels, dtype=torch.long).to(device)
    
    synthetic_text_data.requires_grad = True

    optimizer_syn = optim.SGD([synthetic_text_data], lr=0.01, momentum=0.9)
    optimizer_syn.zero_grad()
    criterion = nn.CrossEntropyLoss().to(device)

    print('Training begins')

    for it in range(n_iterations + 1):
        model = LSTMClassifier(embedding_matrix, hidden_dim, output_dim, n_layers).to(device)
        model.train()
        model_parameters = list(model.parameters())
        optimizer_net = optim.SGD(model.parameters(), lr=0.01)
        optimizer_net.zero_grad()
        loss_avg = 0

        accs_hist.append(evaluate_syn(model, synthetic_text_data, synthetic_labels, test_loader, device, num_epochs=10, lr=0.001))

        for ol in tqdm(range(n_outer_loop)):
            print('Starting outer loop')
            loss = torch.tensor(0.0).to(device)

            for c in range(num_classes):
                text_real, lab_real = next(iter(train_loader))
                text_real, lab_real = text_real.to(device), lab_real.to(device)
                lab_real = torch.ones((text_real.shape[0],), device=device, dtype=torch.long) * c
                text_syn = synthetic_text_data[c * n_spc:(c + 1) * n_spc].to(device)
                lab_syn = torch.ones((n_spc,), device=device, dtype=torch.long) * c
                output_real = model(text_real)
                loss_real = criterion(output_real, lab_real)
                gw_real = torch.autograd.grad(loss_real, model_parameters, retain_graph=True)
                gw_real = [_.detach().clone() for _ in gw_real]
                output_syn = model(text_syn)
                loss_syn = criterion(output_syn, lab_syn)
                gw_syn = torch.autograd.grad(loss_syn, model_parameters, create_graph=True)
                loss += sum(torch.nn.functional.mse_loss(gw_real_, gw_syn_) for gw_real_, gw_syn_ in zip(gw_real, gw_syn))

            optimizer_syn.zero_grad()
            loss.backward()
            optimizer_syn.step()
            loss_avg += loss.item()

            text_syn_train, label_syn_train = synthetic_text_data.detach(), synthetic_labels.detach()
            dst_syn_train = DataLoader(torch.utils.data.TensorDataset(text_syn_train, label_syn_train), batch_size=batch_train, shuffle=True)

            for il in tqdm(range(n_inner_loop)):
                epoch_loss, epoch_acc = 0, 0
                model.train()
                for batch_text, batch_labels in dst_syn_train:
                    batch_text, batch_labels = batch_text.to(device), batch_labels.to(device)
                    optimizer_net.zero_grad()
                    outputs = model(batch_text)
                    loss = criterion(outputs, batch_labels)
                    loss.backward()
                    optimizer_net.step()
                    epoch_loss += loss.item()
                    _, preds = torch.max(outputs, 1)
                    epoch_acc += torch.sum(preds == batch_labels).item()

                epoch_loss /= len(dst_syn_train.dataset)
                epoch_acc /= len(dst_syn_train.dataset)
                print(f'Inner loop {il + 1}/{n_inner_loop}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

        loss_avg /= (num_classes * n_outer_loop)
        if it % 10 == 0:
            print(f'Iter = {it:04d}, Loss = {loss_avg:.4f}')

    data_save.append([synthetic_text_data.detach().cpu(), synthetic_labels.detach().cpu()])
    torch.save({'data': data_save}, 'results_synthetic_data.pth')


Loading Word2Vec embeddings...
Using device: cuda
Loading AG_NEWS dataset...
Building vocabulary...
Creating embeddings matrix...




RuntimeError: only Tensors of floating point and complex dtype can require gradients

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gensim.downloader as api
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from tqdm import tqdm
import pickle
import random

print("Loading Word2Vec embeddings...")
word2vec = api.load('word2vec-google-news-300')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load AG_NEWS dataset using Hugging Face datasets library
print("Loading AG_NEWS dataset...")
dataset = load_dataset('ag_news')

# Tokenizer function
def tokenize(text):
    return text.lower().split()

# Build vocabulary
def build_vocab(texts, tokenizer):
    vocab = {"<unk>": 0}
    for text in texts:
        for token in tokenizer(text):
            if token not in vocab:
                vocab[token] = len(vocab)
    return vocab

# Custom Dataset class
class AGNewsDataset(Dataset):
    def __init__(self, texts, labels, vocab, tokenizer):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        tokenized_text = [self.vocab[token] if token in self.vocab else self.vocab["<unk>"]
                          for token in self.tokenizer(text)]
        return torch.tensor(tokenized_text, dtype=torch.long), torch.tensor(label, dtype=torch.long)

def pad_sequences(sequences, max_len=None, pad_value=0):
    if not max_len:
        max_len = max(len(seq) for seq in sequences)
    padded_seqs = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
    for i, seq in enumerate(sequences):
        length = len(seq)
        padded_seqs[i, :length] = seq
    return padded_seqs

# Extract texts and labels
train_texts = [item['text'] for item in dataset['train']]
train_labels = [item['label'] for item in dataset['train']]
test_texts = [item['text'] for item in dataset['test']]
test_labels = [item['label'] for item in dataset['test']]

# Build vocabulary from training texts
print("Building vocabulary...")
vocab = build_vocab(train_texts, tokenize)

# Create Dataset and DataLoader
train_dataset = AGNewsDataset(train_texts, train_labels, vocab, tokenize)
test_dataset = AGNewsDataset(test_texts, test_labels, vocab, tokenize)

def collate_fn(batch):
    texts, labels = zip(*batch)
    texts_padded = pad_sequences(texts)
    labels = torch.tensor(labels, dtype=torch.long)
    return texts_padded, labels

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Create embedding matrix
print("Creating embeddings matrix...")
embedding_dim = 300
embedding_matrix = np.zeros((len(vocab), embedding_dim))
for word, idx in vocab.items():
    if word in word2vec:
        embedding_matrix[idx] = word2vec[word]
    else:
        embedding_matrix[idx] = np.random.normal(scale=0.6, size=(embedding_dim,))

embedding_matrix = torch.tensor(embedding_matrix, dtype=torch.float32).to(device)

# Define the LSTM model
class LSTMClassifier(nn.Module):
    def __init__(self, embedding_matrix, hidden_dim, output_dim, n_layers, drop_prob=0.5):
        super(LSTMClassifier, self).__init__()
        num_embeddings, embedding_dim = embedding_matrix.shape
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)  # Freeze embeddings
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, dropout=drop_prob, batch_first=True)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.embedding(x)
        lstm_out, (ht, ct) = self.lstm(x)
        out = self.dropout(ht[-1])
        out = self.fc(out)
        return out

hidden_dim = 256
output_dim = 4  # AG_NEWS has 4 classes
n_layers = 2
model = LSTMClassifier(embedding_matrix, hidden_dim, output_dim, n_layers).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

class SyntheticDataEmbeddings(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_samples):
        super(SyntheticDataEmbeddings, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.embeddings.weight.data.normal_(0, 1)
        self.embeddings.weight.requires_grad = True

    def forward(self, x):
        return self.embeddings(x)

def evaluate_syn(model, synthetic_data_embeddings, synthetic_text_data, synthetic_labels, test_loader, device, num_epochs=10, lr=0.001):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer_net = optim.Adam(model.parameters(), lr=lr)

    synthetic_text_data = synthetic_data_embeddings(synthetic_text_data)
    synthetic_dataset = DataLoader(torch.utils.data.TensorDataset(synthetic_text_data.detach(), synthetic_labels), batch_size=32, shuffle=True)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        for inputs, labels in tqdm(synthetic_dataset, desc=f'Training Epoch {epoch+1}/{num_epochs}', unit='batch'):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer_net.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer_net.step()
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_accuracy = 100 * correct / total
        if epoch % 9 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')

    model.eval()
    total_acc, total_count = 0, 0
    with torch.no_grad():
        for text, labels in tqdm(test_loader, desc='Evaluating on test set', unit='batch'):
            text, labels = text.to(device), labels.to(device)
            output = model(text)
            _, predicted = torch.max(output, 1)
            total_acc += (predicted == labels).sum().item()
            total_count += labels.size(0)

    overall_accuracy = total_acc / total_count
    print(f'Test Accuracy: {overall_accuracy:.4f}')

    results = {
        'test_accuracy': overall_accuracy,
    }
    with open('results.pkl', 'wb') as f:
        pickle.dump(results, f)

    return overall_accuracy

n_experiments = 1
num_classes = len(set(train_labels))
n_spc = 10  # samples_per_class
n_iterations = 1000
batch_train = 32
batch_real = 32
batch_syn = 32
n_outer_loop = 1
n_inner_loop = 1
num_eval_epochs = 20

data_save = []
accs_hist = []

def get_class_indices(labels, n_classes):
    class_indices = [[] for _ in range(n_classes)]
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)
    return class_indices

train_labels_list = [item['label'] for item in dataset['train']]
class_indices = get_class_indices(train_labels_list, num_classes)

for exp_n in range(n_experiments):
    print(f'\n================== Exp {exp_n} ==================\n')

    # Randomly select synthetic data from train dataset per class
    synthetic_text_data = []
    synthetic_labels = []
    for class_idx in range(num_classes):
        selected_indices = random.sample(class_indices[class_idx], n_spc)
        for idx in selected_indices:
            synthetic_text_data.append(train_dataset[idx][0])
            synthetic_labels.append(train_dataset[idx][1])
    synthetic_text_data = pad_sequences(synthetic_text_data).to(device)
    synthetic_labels = torch.tensor(synthetic_labels, dtype=torch.long).to(device)
    
    synthetic_data_embeddings = SyntheticDataEmbeddings(len(vocab), embedding_dim, num_classes * n_spc).to(device)

    optimizer_syn = optim.SGD(synthetic_data_embeddings.parameters(), lr=0.01, momentum=0.9)
    optimizer_syn.zero_grad()
    criterion = nn.CrossEntropyLoss().to(device)

    print('Training begins')

    for it in range(n_iterations + 1):
        model = LSTMClassifier(embedding_matrix, hidden_dim, output_dim, n_layers).to(device)
        model.train()
        model_parameters = list(model.parameters())
        optimizer_net = optim.SGD(model.parameters(), lr=0.01)
        optimizer_net.zero_grad()
        loss_avg = 0

        accs_hist.append(evaluate_syn(model, synthetic_data_embeddings, synthetic_text_data, synthetic_labels, test_loader, device, num_epochs=10, lr=0.001))

        for ol in tqdm(range(n_outer_loop)):
            print('Starting outer loop')
            loss = torch.tensor(0.0).to(device)

            for c in range(num_classes):
                text_real, lab_real = next(iter(train_loader))
                text_real, lab_real = text_real.to(device), lab_real.to(device)
                lab_real = torch.ones((text_real.shape[0],), device=device, dtype=torch.long) * c
                text_syn = synthetic_text_data[c * n_spc:(c + 1) * n_spc].to(device)
                text_syn = synthetic_data_embeddings(text_syn)
                lab_syn = torch.ones((n_spc,), device=device, dtype=torch.long) * c
                output_real = model(text_real)
                loss_real = criterion(output_real, lab_real)
                gw_real = torch.autograd.grad(loss_real, model_parameters, retain_graph=True)
                gw_real = [_.detach().clone() for _ in gw_real]
                output_syn = model(text_syn)
                loss_syn = criterion(output_syn, lab_syn)
                gw_syn = torch.autograd.grad(loss_syn, model_parameters, create_graph=True)
                loss += sum(torch.nn.functional.mse_loss(gw_real_, gw_syn_) for gw_real_, gw_syn_ in zip(gw_real, gw_syn))

            optimizer_syn.zero_grad()
            loss.backward()
            optimizer_syn.step()
            loss_avg += loss.item()

            text_syn_train, label_syn_train = synthetic_text_data.detach(), synthetic_labels.detach()
            dst_syn_train = DataLoader(torch.utils.data.TensorDataset(text_syn_train, label_syn_train), batch_size=batch_train, shuffle=True)

            for il in tqdm(range(n_inner_loop)):
                epoch_loss, epoch_acc = 0, 0
                model.train()
                for batch_text, batch_labels in dst_syn_train:
                    batch_text, batch_labels = batch_text.to(device), batch_labels.to(device)
                    optimizer_net.zero_grad()
                    outputs = model(batch_text)
                    loss = criterion(outputs, batch_labels)
                    loss.backward()
                    optimizer_net.step()
                    epoch_loss += loss.item()
                    _, preds = torch.max(outputs, 1)
                    epoch_acc += torch.sum(preds == batch_labels).item()

                epoch_loss /= len(dst_syn_train.dataset)
                epoch_acc /= len(dst_syn_train.dataset)
                print(f'Inner loop {il + 1}/{n_inner_loop}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

        loss_avg /= (num_classes * n_outer_loop)
        if it % 10 == 0:
            print(f'Iter = {it:04d}, Loss = {loss_avg:.4f}')

    data_save.append([synthetic_text_data.detach().cpu(), synthetic_labels.detach().cpu()])
    torch.save({'data': data_save}, 'results_synthetic_data.pth')


Loading Word2Vec embeddings...
Using device: cuda
Loading AG_NEWS dataset...
Building vocabulary...
Creating embeddings matrix...


Training begins


Training Epoch 1/10:   0%|                                                                                                                         | 0/2 [00:00<?, ?batch/s]


RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)