# Training a Simple GAN Model for Sentence Embeddings

In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard
import pandas as pd
import ast
import numpy as np

MAX_LENGTH = 768


class Discriminator(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)

class Generator(nn.Module):
    def __init__(self, z_dim, emb_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, emb_dim),
            nn.Tanh(),  # Assuming you want to normalize the outputs
        )

    def forward(self, x):
        return self.gen(x)


class CustomDataset(Dataset):
    def __init__(self, embeddings):
        self.embeddings = embeddings

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        return self.embeddings[idx]



# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
embed_dim = MAX_LENGTH  # 784
batch_size = 32
num_epochs = 50

disc = Discriminator(embed_dim).to(device)
gen = Generator(z_dim, embed_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)


# Preprocessing Data

In [55]:
#Limit number of rows for experimentation
df = pd.read_csv('author_csv.csv')
num_rows = len(df)
df = df[:num_rows]
df

Unnamed: 0.1,Unnamed: 0,author_labels,cls_tokens
0,0,1,"[0.278237104415893, -0.33750003576278603, 0.84..."
1,1,1,"[-0.13946822285652102, 0.093057677149772, 0.73..."
2,2,1,"[-1.493581295013427, 0.748962104320526, 1.0086..."
3,3,1,"[-0.9141901135444641, 0.6804959774017331, 0.90..."
4,4,1,"[0.040547348558902005, 0.08085644245147701, 1...."
...,...,...,...
32592,32592,0,"[0.036508537828922, -0.9830706119537351, 0.194..."
32593,32593,0,"[-0.49513417482376104, -0.42453348636627203, 0..."
32594,32594,0,"[-0.6720252633094781, -0.37544131278991705, 2...."
32595,32595,0,"[0.153745874762535, -0.533583104610443, -0.371..."


In [56]:
# Load the text lines

embeddings = df['cls_tokens']

#Turn EagerTensors list to Normal Tensors list
embeddings_pytorch = [torch.tensor(np.array(ast.literal_eval(e)), dtype=torch.float32) for e in embeddings]

# Convert list of tensors to a single tensor
embeddings_tensor = torch.stack(embeddings_pytorch).squeeze(1)  # Adjust dimensions as needed

embeddings_tensor

tensor([[ 0.2782, -0.3375,  0.8478,  ..., -0.7637, -0.5338,  0.3607],
        [-0.1395,  0.0931,  0.7390,  ...,  0.8308, -0.3973,  1.1339],
        [-1.4936,  0.7490,  1.0087,  ...,  1.5033, -1.2829, -0.5658],
        ...,
        [-0.6720, -0.3754,  2.1211,  ...,  1.0377, -0.6048, -0.7254],
        [ 0.1537, -0.5336, -0.3715,  ...,  1.6371,  0.6499, -0.5228],
        [-0.7284, -0.3828, -0.2024,  ...,  1.4613,  0.1921,  1.5803]])

In [57]:
embeddings_tensor.shape

torch.Size([32597, 768])

In [58]:
embeddings = embeddings_tensor


In [59]:
#Initialize Dataset
embeddings = embeddings_tensor 

# Instantiate the custom dataset
dataset = CustomDataset(embeddings)

# Training the Actual GAN model

In [60]:
# Now, create the DataLoader using the dataset
batch_size = 64  # Or any other batch size you wish to use
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Optimizers
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

# Loss function
criterion = nn.BCELoss()

# TensorBoard writers
writer_fake = SummaryWriter(f"logs/fake")
writer_real = SummaryWriter(f"logs/real")
step = 0

# Assuming the generator (gen), discriminator (disc), and their optimizers (opt_gen, opt_disc) are defined
# Also assuming a loss function (criterion) is defined
# z_dim is the dimensionality of the latent space (noise vector)

num_epochs = 10  # Number of epochs to train for

for epoch in range(num_epochs):
    for batch_idx, real_embeddings in enumerate(loader):
        batch_size = real_embeddings[0].size(0)
        real_embeddings = real_embeddings[0].to(device)

        # Train Discriminator
        # Generate fake embeddings
        noise = torch.randn(batch_size, z_dim, device=device)
        fake_embeddings = gen(noise)

        # Get discriminator predictions on real and fake data
        disc_real = disc(real_embeddings).view(-1)
        disc_fake = disc(fake_embeddings.detach()).view(-1)

        # Calculate loss on real and fake
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2

        # Update discriminator
        opt_disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # Train Generator
        # Generate fake embeddings
        output = disc(fake_embeddings).view(-1)
        lossG = criterion(output, torch.ones_like(output))

        # Update generator
        opt_gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Optional: Print out loss values or save models/checkpoints here

       
        print(f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
              Loss D: {lossD:.4f}, loss G: {lossG:.4f}")


Epoch [0/10] Batch 0/510               Loss D: 0.2026, loss G: 1.8443
Epoch [0/10] Batch 1/510               Loss D: 0.2227, loss G: 1.9815
Epoch [0/10] Batch 2/510               Loss D: 0.3088, loss G: 2.0508
Epoch [0/10] Batch 3/510               Loss D: 0.1642, loss G: 2.1059
Epoch [0/10] Batch 4/510               Loss D: 0.1674, loss G: 2.1592
Epoch [0/10] Batch 5/510               Loss D: 0.1575, loss G: 2.2047
Epoch [0/10] Batch 6/510               Loss D: 0.1380, loss G: 2.2544
Epoch [0/10] Batch 7/510               Loss D: 0.1456, loss G: 2.3149
Epoch [0/10] Batch 8/510               Loss D: 0.1931, loss G: 2.3384
Epoch [0/10] Batch 9/510               Loss D: 0.1785, loss G: 2.3505
Epoch [0/10] Batch 10/510               Loss D: 0.1213, loss G: 2.3451
Epoch [0/10] Batch 11/510               Loss D: 0.1086, loss G: 2.3685
Epoch [0/10] Batch 12/510               Loss D: 0.1304, loss G: 2.3658
Epoch [0/10] Batch 13/510               Loss D: 0.2303, loss G: 2.3554
Epoch [0/10] Bat

# Generating a Fake CLS Token

In [61]:
num_examples = 1
noise = torch.randn(num_examples, z_dim)

with torch.no_grad():  # We don't need to track gradients for generation
    fake_data = gen(noise) 
    
fake_data.shape

torch.Size([1, 768])

In [62]:
fake_data

tensor([[ 1.7973e-01,  1.2505e-01,  9.1502e-02,  5.6659e-01,  2.0206e-01,
          6.0736e-01, -5.7118e-01,  2.9765e-01,  7.5517e-01,  3.0931e-02,
         -5.6140e-01, -8.9618e-02, -1.2270e-01,  4.5628e-01,  7.5275e-01,
         -5.9645e-01,  2.0139e-01,  8.1219e-01, -3.6552e-01, -5.2813e-01,
          3.5227e-01, -5.0198e-01,  3.6557e-01, -7.6460e-01, -1.9204e-01,
          2.6984e-01,  2.6896e-01,  5.2190e-01,  2.2123e-01, -8.2898e-01,
          5.6847e-01, -1.0157e-01, -3.3123e-01,  5.6599e-01,  5.3486e-01,
          5.4067e-02, -3.4344e-01,  1.1928e-01, -2.4428e-01,  6.4715e-01,
         -2.9957e-01, -8.0467e-01,  1.9973e-01, -2.2332e-01,  3.4915e-01,
         -2.8261e-01, -6.6266e-01, -6.2377e-01, -2.1833e-01,  2.2443e-01,
         -5.7208e-01,  3.7842e-01,  7.2347e-01, -3.9107e-02, -2.1125e-01,
          3.3803e-01, -4.7618e-01, -4.0101e-01,  4.7798e-01, -2.7365e-01,
         -4.0838e-01,  7.3872e-01, -5.5772e-01, -4.5435e-01, -3.2097e-01,
          1.2993e-01, -4.8096e-01,  3.