In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import GPT2Tokenizer, GPT2LMHeadModel

model_name = 'gpt2'

# Seed torch's global RNG (CPU and CUDA) so that weight initialisation,
# noise sampling and RandomSampler shuffling are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# The GPT-2 tokenizer has no pad token by default; add one so that
# padding='max_length' can be used when tokenizing the reviews.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Load the GPT2 model and resize its token embeddings to accommodate the new padding token
generator_model = GPT2LMHeadModel.from_pretrained(model_name)
generator_model.resize_token_embeddings(len(tokenizer))

# The discriminator backbone is configured to also return every layer's
# hidden states; the classifier head reads the final layer.
discriminator_model = GPT2LMHeadModel.from_pretrained(model_name, output_hidden_states=True)
discriminator_model.resize_token_embeddings(len(tokenizer))

class Generator(nn.Module):
    """Thin wrapper exposing a causal language model's output logits.

    Args:
        model: the language model to wrap. Defaults to the notebook-level
            ``generator_model`` so existing ``Generator()`` calls keep working.
            Note that with the default, every ``Generator()`` shares that single
            module object (and its weights); pass a model explicitly to get an
            independent instance.
    """

    def __init__(self, model=None):
        super().__init__()
        # Fall back to the global model for backward compatibility.
        self.model = generator_model if model is None else model

    def forward(self, input_ids, attention_mask=None):
        """Return the wrapped model's logits for ``input_ids``.

        For a GPT-2 LM head these are presumably (batch, seq_len, vocab) —
        callers take ``argmax(dim=-1)`` over the last axis.
        """
        outputs = self.model(input_ids, attention_mask=attention_mask)
        return outputs.logits

class Discriminator(nn.Module):
    """Real/fake sequence classifier on top of a GPT-2 backbone.

    GPT-2 uses causal self-attention: position ``i`` only attends to positions
    ``<= i``. The hidden state at position 0 therefore encodes only the first
    token. To summarise the whole sequence, this class reads the hidden state
    of the *last non-padding* token, which has attended to every earlier token.

    Args:
        model: backbone to wrap. Defaults to the notebook-level
            ``discriminator_model``. It must expose ``config.n_embd`` and return
            ``hidden_states`` when called with ``output_hidden_states=True``.
    """

    def __init__(self, model=None):
        super().__init__()
        self.model = discriminator_model if model is None else model
        # One unnormalised real/fake score per sequence (pair with BCEWithLogitsLoss).
        self.classifier = nn.Linear(self.model.config.n_embd, 1)

    def forward(self, input_ids, attention_mask=None):
        """Return a (batch, 1) tensor of real/fake logits.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: optional (batch, seq_len) mask, 1 for real tokens.
                Without a mask, the final position of every row is used.
        """
        outputs = self.model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]  # final layer, (batch, seq_len, n_embd)
        batch_size, seq_len = hidden_states.shape[:2]

        if attention_mask is None:
            last_index = torch.full(
                (batch_size,), seq_len - 1, dtype=torch.long, device=hidden_states.device
            )
        else:
            # Highest position whose mask is 1; works for right or left padding.
            # argmax returns the first maximum, so an all-zero mask yields 0.
            positions = torch.arange(seq_len, device=attention_mask.device)
            last_index = (attention_mask.long() * positions).argmax(dim=1).to(hidden_states.device)

        rows = torch.arange(batch_size, device=hidden_states.device)
        pooled = hidden_states[rows, last_index]
        return self.classifier(pooled)

# Wrap the pretrained backbones in their GAN roles (generator first, then
# discriminator, whose linear head is randomly initialised here).
generator, discriminator = Generator(), Discriminator()

# The discriminator's classifier returns unnormalised scores, so use the
# sigmoid-fused binary cross-entropy.
criterion = nn.BCEWithLogitsLoss()

# One Adam optimizer per network, both with the same learning rate.
optimizer_G, optimizer_D = (
    optim.Adam(network.parameters(), lr=2e-5)
    for network in (generator, discriminator)
)



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import load_dataset
from transformers import BertTokenizer

# Load the raw review dataset. Set the REVIEWS_DATA_FILE environment variable
# to point at a different JSON-lines file instead of editing this default path.
DATA_FILE = os.environ.get('REVIEWS_DATA_FILE', '/home/hmasson/deepL/data/original_data.jsonl')
dataset = load_dataset('json', data_files=DATA_FILE)['train']

# Tokenize the dataset
def tokenize_function(example):
    """Encode a batch of reviews into fixed-length (64-token) ids and masks.

    Intended for ``dataset.map(..., batched=True)``, so ``example['review']``
    is presumably a list of strings. Longer reviews are truncated and shorter
    ones are padded with the tokenizer's pad token.
    """
    review_texts = example['review']
    encoding = tokenizer(
        review_texts,
        max_length=64,
        truncation=True,
        padding='max_length',
    )
    return encoding

tokenized_dataset = dataset.map(tokenize_function, batched=True)

# Convert tokenized dataset to PyTorch tensors
input_ids = torch.tensor(tokenized_dataset['input_ids'])
attention_masks = torch.tensor(tokenized_dataset['attention_mask'])
labels = torch.tensor(tokenized_dataset['label'])

# Split ids, masks and labels in ONE call so their rows stay aligned by
# construction. Two separate calls only lined up because they happened to
# share random_state and input length. The split itself is unchanged.
(
    train_inputs, validation_inputs,
    train_masks, validation_masks,
    train_labels, validation_labels,
) = train_test_split(input_ids, attention_masks, labels, random_state=42, test_size=0.1)
assert len(train_inputs) + len(validation_inputs) == len(input_ids)

# Create TensorDataset
train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
validation_dataset = TensorDataset(validation_inputs, validation_masks, validation_labels)

# Create DataLoader (shuffled for training, fixed order for validation)
train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=32)
validation_dataloader = DataLoader(validation_dataset, sampler=SequentialSampler(validation_dataset), batch_size=32)


In [3]:
import random
import torch

N_EPOCHS = 15
best_loss_G = float('inf')

# Train on the GPU when one is available. Module.to() moves parameters in
# place, so the optimizers built in the first cell keep tracking them.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
discriminator.to(device)


def discriminator_logits(inputs_embeds, attention_mask=None):
    """Score sequences given as input embeddings; returns (batch, 1) logits.

    Feeding embeddings rather than token ids lets gradients flow from the
    discriminator back into the generator's (soft) token choices. The sequence
    is summarised by its last non-padding position because GPT-2's causal
    attention means earlier positions never see later tokens.
    """
    outputs = discriminator.model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        output_hidden_states=True,
    )
    last_hidden = outputs.hidden_states[-1]
    batch_size, seq_len = last_hidden.shape[:2]
    if attention_mask is None:
        last_index = torch.full(
            (batch_size,), seq_len - 1, dtype=torch.long, device=last_hidden.device
        )
    else:
        positions = torch.arange(seq_len, device=attention_mask.device)
        last_index = (attention_mask.long() * positions).argmax(dim=1).to(last_hidden.device)
    rows = torch.arange(batch_size, device=last_hidden.device)
    return discriminator.classifier(last_hidden[rows, last_index])


for epoch in range(N_EPOCHS):
    generator.train()
    discriminator.train()

    total_loss_G = 0.0
    total_loss_D = 0.0

    for batch in train_dataloader:
        # Distinct names so the loop does not overwrite the data cell's
        # `input_ids` / `attention_masks` / `labels`. The review labels are not
        # used by the real-vs-fake objective.
        batch_input_ids, batch_attention_mask, _batch_labels = (t.to(device) for t in batch)
        batch_size = batch_input_ids.size(0)

        real_labels = torch.ones((batch_size, 1), device=device)
        fake_labels = torch.zeros((batch_size, 1), device=device)

        # Token-embedding table of the discriminator backbone.
        embed_tokens = discriminator.model.get_input_embeddings()

        # ---- Discriminator step ----
        optimizer_D.zero_grad()

        real_scores = discriminator_logits(embed_tokens(batch_input_ids), batch_attention_mask)
        loss_real = criterion(real_scores, real_labels)

        # Generate fake text from random token ids.
        noise = torch.randint(0, tokenizer.vocab_size, batch_input_ids.shape, device=device)
        fake_logits = generator(noise)
        # Straight-through Gumbel-softmax: one-hot tokens in the forward pass,
        # soft-sample gradients in the backward pass. A plain argmax cuts the
        # autograd graph, so the generator would never receive a gradient.
        # NOTE(review): (batch, seq, vocab) float tensors are large; lower the
        # batch size if this runs out of memory.
        fake_one_hot = torch.nn.functional.gumbel_softmax(fake_logits, tau=1.0, hard=True)

        fake_scores = discriminator_logits(fake_one_hot.detach() @ embed_tokens.weight)
        loss_fake = criterion(fake_scores, fake_labels)

        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        # ---- Generator step ----
        # This backward also fills discriminator grads; they are discarded by
        # optimizer_D.zero_grad() at the start of the next iteration.
        optimizer_G.zero_grad()

        fake_scores = discriminator_logits(fake_one_hot @ embed_tokens.weight)
        loss_G = criterion(fake_scores, real_labels)
        loss_G.backward()
        optimizer_G.step()

        total_loss_G += loss_G.item()
        total_loss_D += loss_D.item()

    num_batches = max(len(train_dataloader), 1)  # guard against an empty loader
    avg_loss_G = total_loss_G / num_batches
    avg_loss_D = total_loss_D / num_batches

    print(f'Epoch {epoch+1}/{N_EPOCHS} | Generator Loss: {avg_loss_G} | Discriminator Loss: {avg_loss_D}')

    # Save the best generator model.
    # NOTE(review): in adversarial training the generator loss is a weak
    # quality signal; consider a separate sample-quality metric.
    if avg_loss_G < best_loss_G:
        best_loss_G = avg_loss_G
        torch.save(generator.state_dict(), 'best-generator-model.pt')


In [None]:
# Use the same device policy as training: GPU if available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Generator() wraps the notebook-level `generator_model`, i.e. the same object
# that was trained above; loading the checkpoint overwrites its weights with
# those of the best epoch.
generator = Generator()

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only
# machine (and vice versa); then place the model where inference will run.
generator.load_state_dict(torch.load('best-generator-model.pt', map_location=device))
generator.to(device)

# Set the model to evaluation mode (disables dropout); returns the module for display.
generator.eval()


In [None]:
# Function to decode generated text tokens
def decode_generated_text(generated_text_tokens):
    """Decode each row of token ids into a string, dropping special tokens.

    Uses the notebook-level ``tokenizer``. Returns a list with one string
    per row of ``generated_text_tokens``.
    """
    return [
        tokenizer.decode(row, skip_special_tokens=True)
        for row in generated_text_tokens
    ]

# Generate text from random noise
def generate_text(generator, tokenizer, num_samples=5, max_length=64, device=None):
    """Sample text by pushing random token ids through the generator.

    Args:
        generator: module mapping (batch, seq) token ids to logits whose last
            axis indexes the vocabulary.
        tokenizer: provides ``vocab_size`` (noise range) and ``decode``.
        num_samples: number of sequences to produce.
        max_length: tokens per sequence.
        device: device for the noise. Defaults to the device of the
            generator's first parameter (CPU if it has none), so the noise
            always matches the model instead of assuming CUDA.

    Returns:
        list[str] with ``num_samples`` decoded texts (special tokens removed).
    """
    if device is None:
        first_param = next(generator.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Generate random noise as input
    noise = torch.randint(0, tokenizer.vocab_size, (num_samples, max_length), device=device)

    # Generate logits from the noise without tracking gradients
    with torch.no_grad():
        generated_logits = generator(noise)

    # Greedy choice of the most likely token at every position
    generated_text_tokens = torch.argmax(generated_logits, dim=-1)

    # Decode with the tokenizer passed in, not a global one
    return [
        tokenizer.decode(tokens, skip_special_tokens=True)
        for tokens in generated_text_tokens.cpu().tolist()
    ]

# Sample five texts from the trained generator and print them, numbered from 1.
generated_texts = generate_text(generator, tokenizer, num_samples=5)
for sample_number, text in enumerate(generated_texts, start=1):
    print(f"Generated Text {sample_number}:\n{text}\n")
