In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [36]:
import pandas as pd
import numpy as np
import sklearn.preprocessing

In [60]:
class NameDataset(Dataset):
    def __init__(self, csv_file="Indian-Female-Names.csv"):
        """
        Arguments:
            csv_file (string): Path to the csv file.
        """
        data = pd.read_csv(csv_file)
        self.data = [data.iloc[i]['name'] for i in range(len(data)) if len(str(data.iloc[i]['name']))==5]
    def __len__(self):
        return len(self.data)
    def _toNP(self, data:list[str]):
        label_binarizer = sklearn.preprocessing.LabelBinarizer()
        label_binarizer.fit(range(26))
        return label_binarizer.transform(np.array([ord(c)-96-1 for c in data.lower()])).astype(np.float32)
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self._toNP(self.data[idx])

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [211]:
# Hyperparameter
lr = 3e-4
latent_dim = 32
inp_dim = 5*26
batch_size = 32
num_epochs = 64

In [212]:
dataset = NameDataset()
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [213]:
class Discriminator(nn.Module):
    def __init__(self, inp_dim):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(inp_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )
    def forward(self, x):
        return self.disc(x)
    
class Generator(nn.Module):
    def __init__(self, latent_dim, inp_dim):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, inp_dim),
            nn.Tanh()
        )
    def forward(self, x):
        return self.gen(x)

In [266]:
disc = Discriminator(inp_dim).to(device)
gen = Generator(latent_dim, inp_dim).to(device)

fixed_noise = torch.randn((batch_size, latent_dim)).to(device)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

criterion = nn.BCELoss()

In [267]:
# Load model weights
state_gen = torch.load("gen_model.pt")
gen.load_state_dict(state_gen['state_dict'])
opt_gen.load_state_dict(state_gen['optimizer'])

state_disc = torch.load("disc_model.pt") 
disc.load_state_dict(state_disc['state_dict'])
opt_disc.load_state_dict(state_disc['optimizer'])

In [218]:
for epoch in range(num_epochs):
    for batch_idx, real in enumerate(loader): 
        real = real.view(-1, inp_dim).to(device)
        
        ### train Discriminator: max log(D(real)) + log(1-D(G(z)))
        noise = torch.randn((real.shape[0], latent_dim)).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real+lossD_fake)/2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()
        
        
        ### train Generator: min log(1-D(G(z))) <-> max log(D(G(z)))
        ## i want to use the fake = gen(noise) again 
        # so disc(fake).view(-1) => disc(fake.detach()).view(-1) or lossD.backward(retain_graph=True)
        
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()
        
        
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] \ "
                f"Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )
            with torch.no_grad():
                fake = gen(fixed_noise).cpu().numpy().reshape(-1, 5, 26)
                print("Generating some names...")
                for i, name in enumerate([''.join([chr(c+96+1) for c in i]) for i in fake.argmax(2)]):
                    print(name, end=' ')
                    if(i+1)%8==0: print()
                print('\n****************************************************************')

Epoch [0/64] \ Loss D: 0.6879, loss G: 0.7686
Generating some names...
pdzyq uwykb zxykb swufv zvcrx kvyix ucbfa jbuya 
fwysq ucjyz uiddf fwjir ujzcq ibcxq udkyz pemio 
umpxr ujjlx eclyp dbrca udcya ubpys ucjxp kxbip 
gekix umpur djjcp fbjca ugjif zcycx lbpuw ftyks 

****************************************************************
Epoch [1/64] \ Loss D: 0.7039, loss G: 0.7762
Generating some names...
rosta mwypi mxspa mwifz ribri kabpa rabfa jasta 
swsta rispi kibti jwhix posta jzhta kisfz jamra 
tibti jojta raspa maspa raspa kzipi mahra kxbpa 
kasra jihtw rosta mxsfa kajpa royda kaspa mwhpi 

****************************************************************
Epoch [2/64] \ Loss D: 0.4882, loss G: 1.4792
Generating some names...
auzea yurkb duuka dvuki nvmki aurea vvzya aamka 
aazea nuoyi nvzai yumki anzea yuoka yhmka yamko 
yumki duzka eflka dvnka nuoya barki dvnka yvrka 
narea ymzai dunya buoga yvrga nunca bvnka nvrki 

****************************************************************
E

In [268]:
print("Training dataset...")
for i, name in enumerate([''.join([chr(c+96+1) for c in i]) for i in real.cpu().numpy().reshape(-1, 5, 26).argmax(2)]):
    print(name, end=' ')
    if(i+1)%8==0: print()

Training dataset...
seema nikki swati sonia seema reema radha sapna 
laxmi salma pinky panku pooja aarti deepa roopa 
deepa suman seema mercy naina deepa laxmi arshi 
kanta nisha rekha geeta priti komal 

In [281]:
with torch.no_grad():
    fake = gen(torch.randn((batch_size, latent_dim)).to(device))
    scores = disc(fake).view(-1).cpu().numpy()
    print("Generating some names...")
    for i, name in enumerate([''.join([chr(c+96+1) for c in i]) for i in fake.cpu().numpy().reshape(-1, 5, 26).argmax(2)]):
        print(f"{name}[{scores[i]:.2f}]", end=' ')
        if(i+1)%8==0: print()
    print('\n****************************************************************')

Generating some names...
kamal[0.26] rumli[0.21] reena[0.43] reena[0.49] pooja[0.30] seena[0.24] resha[0.36] rarhi[0.27] 
sayma[0.27] kumal[0.42] kajal[0.31] risha[0.32] mumla[0.26] sumla[0.20] sakha[0.19] panni[0.27] 
sajal[0.26] kadal[0.36] rosha[0.27] saxma[0.20] nirha[0.40] reeha[0.40] panai[0.27] sasha[0.18] 
pooja[0.42] gamna[0.20] geeta[0.21] gamna[0.15] gumla[0.35] samaa[0.34] andta[0.21] poola[0.24] 

****************************************************************


In [254]:
data = []
data_score = []
alpha_score = 0
with torch.no_grad():
    for _ in range(128):
        fake = gen(torch.randn((batch_size, latent_dim)).to(device))
        scores = disc(fake).view(-1).cpu().numpy()
        
        print("Generating some names...")
        for i, name in enumerate([''.join([chr(c+96+1) for c in i]) for i in fake.cpu().numpy().reshape(-1, 5, 26).argmax(2)]):
            score = scores[i]
            if score>=alpha_score:
                if name not in data:
                    data.append(name)
                    data_score.append(score)
        
with open('output.txt', 'w') as f: 
    f.write('\n'.join( [f"{i} {s:.2f}" for i, s in sorted([(i, s) for i, s in zip(data, data_score)], key=lambda x: x[1], reverse=True)] ))

Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...
Generating some names...


In [262]:
state_gen = {
    'epoch': epoch,
    'state_dict': gen.state_dict(),
    'optimizer': opt_gen.state_dict(),
}
torch.save(state_gen, "gen_model.pt")
state_disc = {
    'epoch': epoch,
    'state_dict': disc.state_dict(),
    'optimizer': opt_disc.state_dict(),
}
torch.save(state_disc, "disc_model.pt")