In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
from tqdm import tqdm as tqdm

import logging
import warnings

# Quiet the notebook: only CRITICAL records reach the root logger and all
# warnings are hidden.
# NOTE(review): ignoring every warning also hides deprecation notices.
logging.getLogger().setLevel(logging.CRITICAL)
warnings.filterwarnings('ignore')

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
# Force CPU execution; this overrides the CUDA auto-detection in the first cell.
# NOTE(review): confirm whether this override is still required before removing it.
device = 'cpu'

In [7]:
# GPT-2 medium provides both the tokenizer and the generator language model;
# Module.to() returns the module itself, so the move can be chained.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
generator_model = GPT2LMHeadModel.from_pretrained('gpt2-medium').to(device)

In [42]:
# Token id(s) for the string 'man', and the matching rows of GPT-2's
# input-embedding matrix (indexing rows with a list keeps every column).
text_index = tokenizer.encode('man')
vector = generator_model.transformer.wte.weight[text_index]

In [50]:
# Decode a single token id; per the captured output, 582 is ' man' (with a leading space).
tokenizer.decode(582)

' man'

In [126]:
def _choose_from_top(probs, n=5):
    """Sample one token id from the ``n`` most probable entries of ``probs``.

    Args:
        probs: 1-D numpy array of probabilities over the vocabulary.
        n: number of highest-probability candidates to sample from
            (clamped to ``[1, len(probs)]``).

    Returns:
        The sampled token id as a Python ``int``.
    """
    n = max(1, min(n, len(probs)))
    top_ids = np.argpartition(probs, -n)[-n:]
    # Renormalise over the top-n only; float64 keeps np.random.choice's sum check happy.
    top_probs = probs[top_ids].astype(np.float64)
    top_probs = top_probs / top_probs.sum()
    return int(np.random.choice(top_ids, p=top_probs))


def generate_jokes(num_jokes=1, return_embeddings=True, max_new_tokens=100):
    """Sample jokes from the GPT-2 generator, each prompted with ``"JOKE:"``.

    Tokens are drawn one at a time with top-n sampling (top-20 for the first
    three steps, top-3 afterwards). Generation stops at ``<|endoftext|>``,
    which is kept in the output, or after ``max_new_tokens`` tokens.

    Args:
        num_jokes: number of jokes to sample.
        return_embeddings: if False, return decoded strings. If True, return
            the generator's input embeddings (rows of ``wte``) for every
            token of every joke, prompt included.
        max_new_tokens: maximum number of tokens appended after the prompt.

    Returns:
        A list of ``num_jokes`` strings, or a tensor of shape
        ``(num_jokes, longest_joke_len, embedding_dim)``. In the tensor,
        jokes that stopped early are zero-padded at the end.
    """
    model = generator_model
    embedding_matrix = model.transformer.wte.weight
    # Loop-invariant tokenizations, hoisted out of the sampling loop.
    eos_ids = set(tokenizer.encode('<|endoftext|>'))
    prompt_ids = tokenizer.encode("JOKE:")
    jokes = []
    with torch.no_grad():
        for _ in range(num_jokes):
            cur_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)  # (1, seq_len)
            for step in range(max_new_tokens):
                # Only the logits are needed; omitting labels= skips the LM-loss computation.
                logits = model(cur_ids)[0]
                # Next-token distribution: the only batch row, last position.
                next_token_probs = torch.softmax(logits[0, -1], dim=0)
                # Wider candidate pool for the opening tokens, then stay near the mode.
                top_n = 20 if step < 3 else 3
                next_token_id = _choose_from_top(next_token_probs.cpu().numpy(), n=top_n)
                next_token = torch.tensor([[next_token_id]], dtype=cur_ids.dtype,
                                          device=cur_ids.device)
                cur_ids = torch.cat([cur_ids, next_token], dim=1)
                if next_token_id in eos_ids:
                    break

            token_ids = cur_ids[0]
            if return_embeddings:
                # Indexing the embedding matrix by token ids gives one row per token.
                jokes.append(embedding_matrix[token_ids])
            else:
                # .cpu() so decoding also works when the model lives on a GPU.
                jokes.append(tokenizer.decode(token_ids.cpu().tolist()))

    if return_embeddings:
        if not jokes:
            return embedding_matrix.new_zeros((0, 0, embedding_matrix.shape[1]))
        # Identical to torch.stack when all jokes are the same length, but does
        # not fail when some of them hit <|endoftext|> early.
        jokes = torch.nn.utils.rnn.pad_sequence(jokes, batch_first=True)
    return jokes

In [61]:
# Draw two jokes as embedding tensors and report the batched shape.
jokes = generate_jokes(num_jokes=2, return_embeddings=True)
print(jokes.shape)

torch.Size([2, 103, 1024])


In [128]:
# Sanity check: with return_embeddings=True, generate_jokes returns a single tensor rather than a list.
type(jokes)

torch.Tensor

In [72]:
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader
import os
import json
import csv

class JokesDataset(Dataset):
    """Map-style dataset of prompt-formatted jokes read from a CSV file.

    Reads ``dadjokesfinal.csv`` from ``jokes_dataset_path``. The second
    column of each row becomes one item of the form
    ``"JOKE:<text><end_of_text_token>"``. Rows with fewer than two columns
    (e.g. blank lines) are skipped.

    NOTE(review): every other row is kept, including a header row if the
    CSV has one; confirm against the file.
    """

    def __init__(self, jokes_dataset_path = 'data/'):
        super().__init__()

        short_jokes_path = os.path.join(jokes_dataset_path, 'dadjokesfinal.csv')

        self.joke_list = []
        # NOTE(review): generate_jokes stops on '<|endoftext|>', while real jokes end
        # with this carriage return; confirm the mismatch is intended.
        self.end_of_text_token = "\r"

        # newline='' lets the csv module handle quoted fields that contain newlines.
        # An explicit encoding avoids depending on the platform's locale default.
        with open(short_jokes_path, newline='', encoding='utf-8') as csv_file:
            for row in csv.reader(csv_file, delimiter=','):
                if len(row) < 2:
                    # Blank or malformed line: there is no joke column to read.
                    continue
                self.joke_list.append(f"JOKE:{row[1]}{self.end_of_text_token}")

    def __len__(self):
        """Number of jokes loaded."""
        return len(self.joke_list)

    def __getitem__(self, item):
        """Return the formatted joke string at index ``item``."""
        return self.joke_list[item]

In [142]:
# Hyper-parameters shared by the discriminator and the GAN training loop.
sequence_length = 103  # NOTE(review): not read by any visible cell
embedding_dim = 1024   # discriminator GRU input size
hidden_dim = 128       # discriminator GRU hidden size
batch_size = 2         # jokes per batch, for both real and generated examples

# Real jokes served as shuffled mini-batches of strings.
dataset = JokesDataset()
joke_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

#discriminator_model = torch.nn.Sequential(
#    torch.nn.GRU(input_size=embedding_dim,hidden_size=hidden_dim,num_layers=1,batch_first=True),
#    torch.nn.Linear(hidden_dim * 2 * 2, 1),
#    torch.nn.Sigmoid()
#)

In [210]:
class Discriminator(torch.nn.Module):
    """GRU classifier that maps a sequence of embeddings to one probability.

    The sequence is run through a GRU. The final hidden state of the top
    layer is projected to a single value and squashed with a sigmoid.
    In ``train_gan`` the targets are 1 for real jokes and 0 for generated ones.

    Args:
        embedding_dim: size of each input vector (GRU ``input_size``).
        hidden_dim: GRU hidden size.
        num_layers: number of stacked GRU layers.
        batch_first: if True, inputs are ``(batch, seq, embedding_dim)``;
            otherwise ``(seq, batch, embedding_dim)``.
    """

    def __init__(self, embedding_dim, hidden_dim, num_layers, batch_first):
        super(Discriminator, self).__init__()
        # Honour the caller's num_layers / batch_first instead of hard-coding them.
        self.gru = torch.nn.GRU(input_size=embedding_dim, hidden_size=hidden_dim,
                                num_layers=num_layers, batch_first=batch_first)
        self.linear = torch.nn.Linear(hidden_dim, 1)
        self.activation = torch.nn.Sigmoid()

    def forward(self, x):
        """Return a 1-D tensor with one value in (0, 1) per sequence in ``x``.

        NOTE(review): zero-padded time steps (as produced by pad_sequence) are
        fed through the GRU like real tokens, so they influence the final
        hidden state.
        """
        _, hidden = self.gru(x)          # hidden: (num_layers, batch, hidden_dim)
        # Use the top layer only, so the output has one entry per sequence.
        logits = self.linear(hidden[-1])
        return self.activation(logits).flatten()

In [228]:
def _embed_texts(texts):
    """Embed strings with the generator's token-embedding matrix and zero-pad them into a batch.

    Args:
        texts: iterable of strings (e.g. one batch from ``joke_loader``).

    Returns:
        Tensor of shape ``(len(texts), longest_token_count, embedding_dim)``.
        It is built under ``no_grad``, so discriminator updates do not also
        compute a gradient for the whole GPT-2 embedding matrix.
    """
    embedding_matrix = generator_model.transformer.wte.weight
    with torch.no_grad():
        embedded = [embedding_matrix[tokenizer.encode(text)] for text in texts]
    return torch.nn.utils.rnn.pad_sequence(embedded, batch_first=True)


def train_gan(n_iters, verbose=False):
    """Run ``n_iters`` alternating generator / discriminator updates.

    Each iteration does two steps:

    1. It samples ``batch_size`` jokes (as embeddings) with ``generate_jokes``.
       It then applies a BCE loss that rewards the generator when the
       discriminator scores those samples as real (1).
    2. It embeds one batch of real jokes from ``joke_loader``. It trains the
       discriminator to score real jokes 1 and the generated batch 0.

    A fresh ``Discriminator`` is created on every call. When ``joke_loader``
    runs out of batches, a new pass over it starts.

    Args:
        n_iters: number of training iterations.
        verbose: if True, print progress, predictions and losses.

    NOTE(review): as written, generate_jokes builds its embeddings inside
    ``torch.no_grad()`` from discretely sampled tokens. ``generator_loss``
    therefore has no path back to the GPT-2 weights, and the generator step
    receives no gradient. A policy-gradient or Gumbel-softmax scheme would be
    needed for the generator to learn.
    """
    data_iterator = iter(joke_loader)
    generator = generator_model
    discriminator = Discriminator(embedding_dim=embedding_dim, hidden_dim=hidden_dim,
                                  num_layers=1, batch_first=True).to(device)
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)
    bce = torch.nn.BCELoss()

    def log(*values):
        # All diagnostic output, including the prediction dumps, is gated on verbose.
        if verbose:
            print(*values)

    for i in range(n_iters):
        log(f"#### ITERATION {i} ####")

        # ---- Generator step: the generator wants its samples scored as real (1). ----
        log("training generator")
        generator_optimizer.zero_grad()
        gen_examples = generate_jokes(batch_size, return_embeddings=True)
        discriminator_preds = discriminator(gen_examples)
        log(discriminator_preds)
        log(discriminator_preds.shape)
        generator_loss = bce(discriminator_preds, torch.ones_like(discriminator_preds))
        log(f"generator loss: {generator_loss}")
        generator_loss.backward()
        generator_optimizer.step()

        # ---- Discriminator step: real jokes -> 1, generated jokes -> 0. ----
        log("training discriminator on true examples")
        # Also discards the discriminator gradients left by generator_loss.backward().
        discriminator_optimizer.zero_grad()
        try:
            true_examples_text = next(data_iterator)
        except StopIteration:
            # Out of batches: start another pass over the loader.
            data_iterator = iter(joke_loader)
            true_examples_text = next(data_iterator)
        true_examples = _embed_texts(true_examples_text)

        discriminator_preds_on_true = discriminator(true_examples)
        discriminator_loss_on_true = bce(discriminator_preds_on_true,
                                         torch.ones_like(discriminator_preds_on_true))
        log(f"discriminator loss on true examples: {discriminator_loss_on_true}")

        log("training discriminator on generated examples")
        discriminator_preds_on_gen = discriminator(gen_examples)
        discriminator_loss_on_gen = bce(discriminator_preds_on_gen,
                                        torch.zeros_like(discriminator_preds_on_gen))
        log(f"discriminator loss on generated examples: {discriminator_loss_on_gen}")

        discriminator_loss = (discriminator_loss_on_true + discriminator_loss_on_gen) / 2
        log(f"total discriminator loss: {discriminator_loss}")
        discriminator_loss.backward()
        discriminator_optimizer.step()

In [None]:
# Single smoke-test iteration with progress logging enabled.
train_gan(n_iters=1, verbose=True)

#### ITERATION 0 ####
training generator
