In [3]:
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from sklearn import datasets
from torch import nn, optim
from typing_extensions import Literal

%matplotlib inline
# Use the first CUDA GPU when PyTorch reports one as available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# prepare data

In [4]:
# ! rm words_alpha*

In [5]:
# download dictionary

# ! curl https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt -o words_alpha.txt
! curl https://www.mit.edu/~ecprice/wordlist.10000 -o words_alpha.txt

# Local file name that the curl command above writes the downloaded word list to.
corpus_fn = 'words_alpha.txt'
# corpus_fn = 'https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt'

# Read the whole file and split on whitespace, producing one list entry per word.
with open(corpus_fn, 'r') as f:
    wordlist = f.read().split()

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 75880  100 75880    0     0   148k      0 --:--:-- --:--:-- --:--:--  148k


In [6]:
# Preview the first ten entries of the freshly loaded word list.
wordlist[:10]

['a',
 'aa',
 'aaa',
 'aaron',
 'ab',
 'abandoned',
 'abc',
 'aberdeen',
 'abilities',
 'ability']

In [7]:
# Append the '$' end-of-word marker to every word.  The endswith guard keeps this
# cell idempotent: re-running it does not stack additional '$' characters.
wordlist = [word if word.endswith('$') else word + '$' for word in wordlist]

In [8]:
# Preview the first ten entries again, now carrying the '$' end-of-word marker.
wordlist[:10]

['a$',
 'aa$',
 'aaa$',
 'aaron$',
 'ab$',
 'abandoned$',
 'abc$',
 'aberdeen$',
 'abilities$',
 'ability$']

In [9]:
# make character list

def build_char_list(wordlist):
    """Return the sorted list of distinct characters occurring in ``wordlist``.

    The end-of-word marker ``'$'`` is always part of the result, even when no
    word contains it.
    """
    # Start from the end character and merge in the characters of every word at once.
    alphabet = {'$'}.union(*wordlist)
    return sorted(alphabet)

# Quick check on a toy word list: '$' is added even though no word contains it.
build_char_list(['abc', 'abd', 'aba'])

['$', 'a', 'b', 'c', 'd']

In [10]:
# Sanity check: every entry of the word list must be a string.
# The former `max(['a', 'b', 'cc'], key=len)` scratch line was dropped because its
# value was neither displayed nor stored.
all(isinstance(w, str) for w in wordlist)

True

In [11]:
# Derive the alphabet and the array dimensions shared by the encoders and models below.
charlist = build_char_list(wordlist)
input_dim = len(charlist)  # number of distinct characters, '$' included
# Sequence length: the length of the longest entry in wordlist, but never less than 32.
input_length = max(32, len(max(wordlist, key=len)))
print('Number of unique characters: ', input_dim)
print('Max word length (32 if less): ', input_length)

Number of unique characters:  27
Max word length (32 if less):  32


In [12]:
# Two-way lookup between characters and integer ids; ids follow the sorted order of charlist.
id2char = dict(zip(range(input_dim), charlist))
char2id = dict(zip(charlist, range(input_dim)))

In [13]:
import pandas as pd

def encode(wordlist, 
           charlist=charlist, 
           input_dim=input_dim, 
           input_length=input_length,
           char2id=char2id,
           unknown_policy='zero'):
    """One-hot encode a list of words.

    Args:
        wordlist: sequence of strings to encode.
        charlist: not used by the function; kept so existing calls keep working.
        input_dim: size of the one-hot axis.
        input_length: number of character positions reserved for each word.
        char2id: mapping from a character to its index on the one-hot axis.
        unknown_policy: what to do with characters missing from ``char2id``.
            ``'zero'`` keeps their position but leaves it all-zero;
            ``'skip'`` drops them, so the following characters shift left.

    Returns:
        ``numpy.ndarray`` of shape ``(len(wordlist), input_length, input_dim)``
        filled with 0/1 values.  Positions past the end of a word stay all-zero.

    Raises:
        ValueError: if ``unknown_policy`` is neither ``'skip'`` nor ``'zero'``
            (previously such a value silently produced an all-zero array).
        IndexError: if a known character sits at a position >= ``input_length``.
    """
    if unknown_policy not in ('skip', 'zero'):
        raise ValueError(
            "unknown_policy must be 'skip' or 'zero', got {!r}".format(unknown_policy)
        )

    inp = np.zeros((len(wordlist), input_length, input_dim))
    for i, word in enumerate(wordlist):
        # -1 marks characters that are absent from the vocabulary.
        ids = [char2id.get(ch, -1) for ch in word]
        if unknown_policy == 'skip':
            ids = [idx for idx in ids if idx != -1]
        # Explicit integer dtype: an empty word would otherwise give a float array,
        # which numpy rejects as an index.
        ids = np.array(ids, dtype=int)
        known = ids != -1
        inp[i, np.arange(len(ids))[known], ids[known]] = 1
    return inp

def encode1d(wordlist, 
           charlist=charlist, 
           input_dim=input_dim, 
           input_length=input_length,
           char2id=char2id,
           unknown_policy='zero'):
    """Encode words as rows of character ids.

    Args:
        wordlist: sequence of strings to encode.
        charlist: not used by the function; kept so existing calls keep working.
        input_dim: not used by the function; kept so existing calls keep working.
        input_length: number of positions reserved for each word.
        char2id: mapping from a character to its integer id.
        unknown_policy: what to do with characters missing from ``char2id``.
            ``'zero'`` keeps their position and stores -1 there;
            ``'skip'`` drops them, so the following characters shift left.

    Returns:
        Float ``numpy.ndarray`` of shape ``(len(wordlist), input_length)``.
        Unused trailing positions hold -1, the same value used for unknown
        characters under the ``'zero'`` policy.

    Raises:
        ValueError: if ``unknown_policy`` is neither ``'skip'`` nor ``'zero'``
            (previously such a value silently produced an all -1 array), or,
            raised by numpy, if a word needs more than ``input_length`` positions.
    """
    if unknown_policy not in ('skip', 'zero'):
        raise ValueError(
            "unknown_policy must be 'skip' or 'zero', got {!r}".format(unknown_policy)
        )

    # Start with every position marked as "no character" (-1).
    inp = np.full((len(wordlist), input_length), -1.0)
    for i, word in enumerate(wordlist):
        ids = [char2id.get(ch, -1) for ch in word]
        if unknown_policy == 'skip':
            ids = [idx for idx in ids if idx != -1]
        inp[i, :len(ids)] = ids
    return inp


def decode(out, 
           charlist=charlist,
           id2char=id2char):
    """Turn a batch of per-position character scores back into strings.

    ``out`` is reduced with ``argmax`` over axis 2, each resulting id is looked
    up in ``id2char`` (ids without an entry are dropped), and every word is cut
    just before its first ``'$'``.  ``charlist`` is not used.

    Returns a list with one string per row of ``out``.
    """
    best_ids = np.argmax(out, axis=2)
    texts = []
    for row in best_ids:
        chars = [id2char[idx] for idx in row if idx in id2char]
        # Everything from the first end marker onwards is discarded.
        stop = chars.index('$') if '$' in chars else len(chars)
        texts.append(''.join(chars[:stop]))
    return texts

In [14]:
# Digits are not in the alphabet, so with the default 'zero' policy they are stored
# as -1, the same value that fills the unused trailing positions.
encode1d(['aaabbb333'])

array([[ 1.,  1.,  1.,  2.,  2.,  2., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1., -1., -1., -1.]])

## 2. Vanilla GAN: MNIST

Теперь давайте обучим ту же самую архитектуру на чуть-чуть более серьёзных данных. Попробуем генерировать цифры из датасета MNIST.

In [17]:
class GeneratorChars(nn.Module):
    """Generator built from three Conv1d stages followed by three linear layers.

    Args:
        input_length: accepted but not used when building the layers.
        input_dim: channel count expected by the first convolution and size of
            the last axis of the output.
        n_conv_filters: number of output channels of every convolution.
        n_fc_neurons: width of the two hidden linear layers.
    """

    def __init__(self, 
                 input_length=input_length, 
                 input_dim=input_dim,
                 n_conv_filters=64,
                 n_fc_neurons=128,
                ):
        super().__init__()
        
        self.conv1 = nn.Sequential(nn.Conv1d(input_dim, n_conv_filters, kernel_size=3, padding=0), nn.ReLU(),
                                   )
        self.conv2 = nn.Sequential(nn.Conv1d(n_conv_filters, n_conv_filters, kernel_size=3, padding=0), nn.ReLU(),
                                   )
        self.conv3 = nn.Sequential(nn.Conv1d(n_conv_filters, n_conv_filters, kernel_size=3, padding=0), nn.ReLU(),
                                   nn.MaxPool1d(3))
        
        # Size of the last axis fed to fc1.  With the stack applied as in forward()
        # a length-32 input shrinks 32 -> 30 -> 28 -> 8 -> 2.
        dimension = 2

        self.fc1 = nn.Sequential(nn.Linear(dimension, n_fc_neurons), nn.Dropout(0.5))
        self.fc2 = nn.Sequential(nn.Linear(n_fc_neurons, n_fc_neurons), nn.Dropout(0.5))
        self.fc3 = nn.Linear(n_fc_neurons, input_dim)

    def forward(
        self, z: torch.Tensor, y: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Map the input tensor ``z`` through the conv stack and the linear head.

        ``y`` is accepted but ignored.  ``conv3`` is applied twice, so both
        applications share the same weights.  The linear layers act on the last
        axis of the convolution output; the result is returned without a final
        activation.
        """
        x = self.conv1(z)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv3(x)  # second pass through the same (weight-shared) stage
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        # The leftover debugging print of x.shape was removed: it wrote to stdout on
        # every forward pass.
        return x


class DiscriminatorChars(nn.Module):
    """Discriminator: channels-first Conv1d stack, linear head, sigmoid output.

    ``input_length`` is accepted but not used when building the layers.  An
    embedding layer ``emb`` is created as well, although ``forward`` does not
    call it.
    """

    def __init__(self, 
                 input_length=32, 
                 input_dim=input_dim,
                 n_conv_filters=64,
                 n_fc_neurons=128,
                 embedding_dim=128):
        super().__init__()

        def conv_stage(in_channels, pool=False):
            # Conv1d(kernel 3, no padding) + ReLU, optionally followed by MaxPool1d(3).
            layers = [
                nn.Conv1d(in_channels, n_conv_filters, kernel_size=3, padding=0),
                nn.ReLU(),
            ]
            if pool:
                layers.append(nn.MaxPool1d(3))
            return nn.Sequential(*layers)

        def dense_stage(in_features):
            # Linear layer followed by dropout with probability 0.5.
            return nn.Sequential(nn.Linear(in_features, n_fc_neurons), nn.Dropout(0.5))

        # Modules are created in the same order as before so parameter
        # initialisation consumes the random stream identically.
        self.emb = nn.Embedding(input_dim, embedding_dim)
        self.conv1 = conv_stage(input_dim)
        self.conv2 = conv_stage(n_conv_filters)
        self.conv3 = conv_stage(n_conv_filters, pool=True)

        dimension = 2

        self.fc1 = dense_stage(dimension)
        self.fc2 = dense_stage(n_fc_neurons)
        self.fc3 = nn.Linear(n_fc_neurons, 1)

    def forward(
        self, x: torch.Tensor, y: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Score ``x`` and squash the result with a sigmoid; ``y`` is ignored.

        Axes 1 and 2 of ``x`` are swapped before the first convolution, and
        ``conv3`` is applied twice in a row.
        """
        pipeline = (
            self.conv1, self.conv2, self.conv3, self.conv3,
            self.fc1, self.fc2, self.fc3,
        )
        hidden = x.transpose(1, 2)
        for stage in pipeline:
            hidden = stage(hidden)
        return torch.sigmoid(hidden)

In [18]:
class DatasetChars:
    """Indexable dataset of encoded words.

    Args:
        wordlist: sequence of strings to encode.
        test: when true, reuse the notebook-level ``charlist``, ``id2char``,
            ``char2id``, ``input_dim`` and ``input_length`` instead of deriving
            them from ``wordlist``.
        emb: when true, store integer-id rows (``encode1d``); otherwise store
            one-hot tensors (``encode``).

    Indexing returns one encoded word; ``len()`` returns the number of words.
    """

    def __init__(self, 
                 wordlist,
                 test=False,
                 emb=False,
                ):
        self.wordlist = wordlist
        
        if test:
            self.charlist = charlist
            self.id2char = id2char
            self.char2id = char2id
            self.input_dim = input_dim
            self.input_length = input_length
        else:
            # Derive everything from this dataset's own words.  The lookups used to
            # be built from notebook globals, ignoring the alphabet computed here.
            self.charlist = self.build_char_list()
            self.input_dim = len(self.charlist)
            self.id2char = dict(enumerate(self.charlist))
            self.char2id = {ch: idx for idx, ch in enumerate(self.charlist)}
            # default=0 keeps an empty word list from raising inside max().
            self.input_length = max(32, max((len(word) for word in wordlist), default=0))
        if emb:
            self.encoded_list = self.encode1d()
        else:
            self.encoded_list = self.encode()
        
    def build_char_list(self):
        """Return the sorted distinct characters of the word list, always including '$'."""
        charlist = set()
        for word in self.wordlist:
            charlist.update(word)
        charlist.add('$') # end character
        return sorted(charlist)
    
    def encode(self, unknown_policy='zero'):
        """One-hot encode the word list.

        Returns a ``torch.Tensor`` of shape
        ``(len(wordlist), input_length, input_dim)``.  Characters missing from
        ``self.char2id`` keep an all-zero position under ``'zero'`` and are
        dropped under ``'skip'``.

        Raises:
            ValueError: if ``unknown_policy`` is neither ``'skip'`` nor ``'zero'``.
        """
        if unknown_policy not in ('skip', 'zero'):
            raise ValueError(
                "unknown_policy must be 'skip' or 'zero', got {!r}".format(unknown_policy)
            )
        inp = np.zeros((len(self.wordlist), self.input_length, self.input_dim))
        for i, word in enumerate(self.wordlist):
            ids = [self.char2id.get(ch, -1) for ch in word]
            if unknown_policy == 'skip':
                ids = [idx for idx in ids if idx != -1]
            # Explicit integer dtype: an empty word would otherwise give a float
            # array, which numpy rejects as an index.
            ids = np.array(ids, dtype=int)
            known = ids != -1
            inp[i, np.arange(len(ids))[known], ids[known]] = 1
        return torch.Tensor(inp)

    def encode1d(self, unknown_policy='zero'):
        """Encode the word list as rows of character ids.

        Returns a float ``numpy.ndarray`` of shape ``(len(wordlist), input_length)``
        in which unused positions, and unknown characters under ``'zero'``, hold -1.
        Unlike ``encode`` the result is a numpy array, not a tensor.

        Raises:
            ValueError: if ``unknown_policy`` is neither ``'skip'`` nor ``'zero'``.
        """
        if unknown_policy not in ('skip', 'zero'):
            raise ValueError(
                "unknown_policy must be 'skip' or 'zero', got {!r}".format(unknown_policy)
            )
        inp = np.full((len(self.wordlist), self.input_length), -1.0)
        for i, word in enumerate(self.wordlist):
            # Use this dataset's own lookup rather than the notebook-level char2id.
            ids = [self.char2id.get(ch, -1) for ch in word]
            if unknown_policy == 'skip':
                ids = [idx for idx in ids if idx != -1]
            inp[i, :len(ids)] = ids
        return inp

    def __getitem__(self, idx):
        return self.encoded_list[idx]

    def __len__(self):
        return len(self.wordlist)

In [19]:
# transform = transforms.Compose(
#     [transforms.Resize(32), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
# )

from sklearn.model_selection import train_test_split

# Hold out 20% of the words; the fixed random_state keeps the split reproducible.
train_wordlist, test_wordlist = train_test_split(
     wordlist, test_size=0.2, random_state=42)

trainset = DatasetChars(train_wordlist, emb=True)
# test=True makes the held-out set reuse the notebook-level charlist, char2id and
# input_length rather than deriving its own sizes from the test words alone.
testset = DatasetChars(test_wordlist, test=True, emb=True)
trainloader_chars = torch.utils.data.DataLoader(
    trainset, batch_size=16, shuffle=True, num_workers=16, pin_memory=True
)



In [20]:
!pip install -q torchgan

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchaudio 0.8.0 requires torch==1.8.0, but you have torch 1.11.0 which is incompatible.[0m


In [21]:
from torchgan.models import Generator, Discriminator

class GeneratorChars(Generator):
    """torchgan ``Generator`` subclass with a Conv1d stack and a linear head.

    Args:
        encoding_dims: accepted but ignored; the base class is initialised with
            ``input_dim`` and ``self.encoding_dims`` is set to ``input_dim``.
        input_length: accepted but not used when building the layers.
        input_dim: channel count expected by the first convolution and size of
            the last axis of the output.
        n_conv_filters: number of output channels of every convolution.
        n_fc_neurons: width of the two hidden linear layers.
    """

    def __init__(self, 
                 encoding_dims=None,
                 input_length=input_length, 
                 input_dim=32,
                 n_conv_filters=64,
                 n_fc_neurons=128,
                ):
        super().__init__(encoding_dims=input_dim)
        self.encoding_dims = input_dim
        # The leftover debugging print of self.encoding_dims was removed.
        
        self.conv1 = nn.Sequential(nn.Conv1d(input_dim, n_conv_filters, kernel_size=3, padding=0), nn.ReLU(),
                                   )
        self.conv2 = nn.Sequential(nn.Conv1d(n_conv_filters, n_conv_filters, kernel_size=3, padding=0), nn.ReLU(),
                                   )
        self.conv3 = nn.Sequential(nn.Conv1d(n_conv_filters, n_conv_filters, kernel_size=3, padding=0), nn.ReLU(),
                                   nn.MaxPool1d(3))
        
        # Size of the last axis fed to fc1.  With the stack applied as in forward()
        # a length-32 input shrinks 32 -> 30 -> 28 -> 8 -> 2.
        dimension = 2

        self.fc1 = nn.Sequential(nn.Linear(dimension, n_fc_neurons), nn.Dropout(0.5))
        self.fc2 = nn.Sequential(nn.Linear(n_fc_neurons, n_fc_neurons), nn.Dropout(0.5))
        self.fc3 = nn.Linear(n_fc_neurons, input_dim)

    def forward(
        self, z: torch.Tensor, y: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Map ``z`` through the conv stack and the linear head; ``y`` is ignored.

        ``conv3`` is applied twice, so both applications share the same weights.
        The result is returned without a final activation.
        """
        x = self.conv1(z)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv3(x)  # second pass through the same (weight-shared) stage
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


class DiscriminatorChars(Discriminator):
    """torchgan ``Discriminator`` subclass that scores sequences of character ids.

    Ids are embedded, passed channels-first through a Conv1d stack and a linear
    head, and squashed with a sigmoid.

    Args:
        input_dims: accepted but ignored; the base class is initialised with
            ``input_dim`` and ``self.input_dims`` is set to ``input_dim``.
        input_length: accepted but not used when building the layers.
        input_dim: number of rows in the embedding table (vocabulary size).
        n_conv_filters: number of output channels of every convolution.
        n_fc_neurons: width of the two hidden linear layers.
        embedding_dim: size of each embedding vector, which is also the channel
            count of the first convolution.
    """

    def __init__(self, 
                 input_dims=None,
                 input_length=32, 
                 input_dim=input_dim,
                 n_conv_filters=64,
                 n_fc_neurons=128,
                 embedding_dim=128,
                ):
        super().__init__(input_dims=input_dim)
        self.input_dims = input_dim
        
        self.emb = nn.Embedding(input_dim, embedding_dim)
        # The conv stack consumes embedding vectors, so its channel count is
        # embedding_dim.  It used to be declared with input_dim channels, which does
        # not match what forward() feeds it.
        self.conv1 = nn.Sequential(nn.Conv1d(embedding_dim, n_conv_filters, kernel_size=3, padding=0), nn.ReLU(),
                                   )
        self.conv2 = nn.Sequential(nn.Conv1d(n_conv_filters, n_conv_filters, kernel_size=3, padding=0), nn.ReLU(),
                                   )
        self.conv3 = nn.Sequential(nn.Conv1d(n_conv_filters, n_conv_filters, kernel_size=3, padding=0), nn.ReLU(),
                                   nn.MaxPool1d(3))
        
        # Size of the last axis fed to fc1.  With the stack applied as in forward()
        # a length-32 sequence shrinks 32 -> 30 -> 28 -> 8 -> 2.
        dimension = 2

        self.fc1 = nn.Sequential(nn.Linear(dimension, n_fc_neurons), nn.Dropout(0.5))
        self.fc2 = nn.Sequential(nn.Linear(n_fc_neurons, n_fc_neurons), nn.Dropout(0.5))
        self.fc3 = nn.Linear(n_fc_neurons, 1)

    def forward(
        self, x: torch.Tensor, y: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return sigmoid scores for ``x``; ``y`` is ignored.

        Assumes ``x`` holds character ids shaped (batch, length) -- TODO confirm
        against the data loader.

        NOTE(review): ``nn.Embedding`` needs ids in ``[0, input_dim)``, while the
        ``encode1d`` helpers in this notebook fill padding and unknown characters
        with -1.  Confirm how those positions should be mapped before training.
        """
        # The debugging prints of x and x.shape were removed: they dumped whole
        # tensors on every forward pass.
        x = self.emb(x.long())   # appends an embedding axis of size embedding_dim
        x = x.transpose(1, 2)    # Conv1d expects channels on axis 1
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv3(x)        # second pass through the same (weight-shared) stage
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        res = torch.sigmoid(x)
        return res

AttributeError: module 'torch.nn.quantized' has no attribute 'Dropout'

In [None]:
# The positional 100 binds to the first constructor parameter (`encoding_dims` in the
# torchgan-based class, `input_length` in the plain nn.Module one); neither definition
# uses that value when building its layers.
gen_chars = GeneratorChars(100)
gen_chars.to(device)

discr_chars = DiscriminatorChars()
discr_chars.to(device)

# prior_chars = torch.distributions.Normal(
#     torch.zeros(input_dim, input_length).to(device), torch.ones(input_dim, input_length).to(device)
# )
# NOTE(review): prior_chars is only assigned in the commented-out lines above, yet later
# cells reference it -- confirm where it is supposed to come from.

# Separate Adam optimizers for the two networks; only the discriminator overrides betas.
gen_opt_chars = optim.Adam(gen_chars.parameters(), lr=3e-5)
discr_opt_chars = optim.Adam(discr_chars.parameters(), lr=3e-5, betas=(0.5, 0.999))

In [None]:
# Model/optimizer specification handed to the torchgan Trainer below: for each network
# the class to build, its constructor keyword arguments and its optimizer settings.
gan_network = {
    "generator": {
        "name": GeneratorChars,
        "args": {
            # Accepted by the torchgan-based GeneratorChars but not used by it.
            "encoding_dims": 100,
        },
        "optimizer": {"name": optim.Adam, "args": {"lr": 0.0001, "betas": (0.5, 0.999)}},
    },
    "discriminator": {
        "name": DiscriminatorChars,
        # No overrides: the discriminator is built with its default arguments.
        "args": {
        },
        "optimizer": {"name": optim.Adam, "args": {"lr": 0.0003, "betas": (0.5, 0.999)}},
    },
}


In [None]:
from torchgan.trainer import Trainer
from torchgan.losses import LeastSquaresDiscriminatorLoss, LeastSquaresGeneratorLoss

# Build a torchgan Trainer from the gan_network specification, pairing the
# least-squares discriminator and generator losses.
# NOTE(review): the remaining keyword values are passed through as-is; check their exact
# meaning against the torchgan Trainer documentation before tuning them.
trainer = Trainer(
    models=gan_network,
    losses_list=[LeastSquaresDiscriminatorLoss(), LeastSquaresGeneratorLoss()],
    ncritic=1, 
    epochs=1, 
    sample_size=2, 
    checkpoints='./model/gan', 
    retain_checkpoints=5, 
#     recon='./images', 
    log_dir=None, 
    test_noise=None, 
    nrow=8
)

In [None]:
# Run the torchgan training loop on the character data loader.
trainer.train(trainloader_chars)

In [None]:
# Alternative training path using the manually created models and optimizers.
# NOTE(review): train_gan, gan_loss and prior_chars are not defined in the visible part
# of this notebook -- confirm they exist before running this cell on a fresh kernel.
train_gan(
    trainloader_chars,
    gen_chars,
    discr_chars,
    gen_opt_chars,
    discr_opt_chars,
    gan_loss,
    prior_chars,
    num_epochs=8,
    gen_steps=10,
    discr_steps=1,
    verbose_num_iters=100,
#     data_type="mnist",
)

In [None]:
# Draw 16 samples from the prior and pass them through the generator.
# NOTE(review): prior_chars is not assigned in the visible part of this notebook.
z = prior_chars.sample((16,))

sampled_chars = gen_chars(z)

# decode() takes the argmax over axis 2 and cuts each word at its first '$';
# the decoded words are printed one per line.
print(*decode(sampled_chars.detach().cpu().numpy()), sep='\n')

In [None]:
# Display the raw generator output behind the decoded words.
sampled_chars