In [None]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

from torch import nn
import glob

In [None]:

# TODO: collate_fn for the DataLoader (not implemented yet)

class RNN(nn.Module):
    """Single-layer Elman RNN cell that is advanced one time step per call.

    The hidden state is updated as ``h' = tanh(W_hh h + W_xh x + b_h)`` and
    the output is ``y = W_hy h' + b_y``. Recurrence over a sequence is left to
    the caller, which feeds the returned hidden state back in on the next step.

    Args:
        hidden_n: Size of the hidden state.
        vocab_n: Size of the input vectors and of the output logits.
    """

    def __init__(self, hidden_n, vocab_n):
        super().__init__()
        # A single bias is enough for the hidden update; it lives on wxh,
        # so whh is created without one.
        self.whh = nn.Linear(in_features=hidden_n, out_features=hidden_n, bias=False)
        self.wxh = nn.Linear(in_features=vocab_n, out_features=hidden_n, bias=True)
        self.why = nn.Linear(in_features=hidden_n, out_features=vocab_n, bias=True)

    def forward(self, h, x):
        """Advance the cell by one step.

        Args:
            h: Previous hidden state; last dimension ``hidden_n``.
            x: Current input; last dimension ``vocab_n``.

        Returns:
            ``(h_new, y)``: the updated hidden state and the unnormalised
            output scores (last dimension ``vocab_n``).
        """
        h_new = torch.tanh(self.whh(h) + self.wxh(x))
        y = self.why(h_new)
        return h_new, y

In [None]:
# Load every name file into one DataFrame: column 0 holds the name and
# 'label' holds the file's base name (used as the class label).
dataframes = []
# sorted(): glob order is filesystem-dependent; sorting keeps row order stable.
for path in sorted(glob.glob("data/names/*.txt")):
    # basename/splitext instead of split('/') so this also works on Windows.
    classname = os.path.splitext(os.path.basename(path))[0]
    # keep_default_na=False: otherwise names such as "NA" or "None" are
    # parsed as NaN and str.lower() fails on them later.
    names = pd.read_csv(os.path.abspath(path), header=None, keep_default_na=False)
    names['label'] = classname  # scalar broadcasts to every row
    dataframes.append(names)
if not dataframes:
    raise FileNotFoundError("no name files found matching data/names/*.txt")
# ignore_index=True: give the combined frame a unique 0..N-1 index instead of
# repeating each file's own 0..k index.
data = pd.concat(dataframes, ignore_index=True)

In [None]:
# Store a lower-cased copy of each name (column 0) in the new column 'X'.
data['X'] = data[0].map(str.lower)

In [None]:
# Character vocabulary: every character seen in the lower-cased names plus two
# reserved symbols.
SEQUENCE_END = '_'   # appended to each name below to mark its end
NULL_VALUE = '0'     # reserved symbol; NOTE(review): unused in the cells shown — confirm intent
# NOTE(review): a name that already contains '_' or '0' would collide with the
# reserved symbols above.
vocab = set(''.join(data.X.values.tolist()))
vocab.add(SEQUENCE_END)
vocab.add(NULL_VALUE)

# Sort before numbering: iteration order of a set of str depends on the
# per-process hash seed, so an unsorted mapping changes between runs.
vocab2int = {char: idx for idx, char in enumerate(sorted(vocab))}

# Encode each name, terminated by SEQUENCE_END, as a list of vocabulary indices.
appended = data.X + SEQUENCE_END
converted = appended.apply(lambda name: [vocab2int[c] for c in name])
data['converted'] = converted

In [None]:
# Class labels and a label -> index mapping. The labels are sorted so the
# indices do not depend on the per-process string hash seed.
classes = set(data.label)
classes_n = len(classes)
class2int = {label: idx for idx, label in enumerate(sorted(classes))}

def onehot_label(dico, value):
    """Return a float one-hot column vector of shape ``(len(dico), 1)`` for ``value``.

    Args:
        dico: Mapping from label to integer row index (e.g. ``class2int``).
        value: Label to encode; must be a key of ``dico``.

    Raises:
        KeyError: if ``value`` is not a key of ``dico``.
    """
    y_onehot = np.zeros((len(dico), 1))
    # Look the index up in the mapping that was passed in, not in a global.
    y_onehot[dico[value]] = 1
    return y_onehot

In [None]:
# Use GPU 0 when available; otherwise fall back to the CPU so the notebook
# still runs on machines without CUDA.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

vocab_n = len(vocab)
hidden_n = 3
rnn = RNN(hidden_n=hidden_n, vocab_n=vocab_n).to(device=device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters())

# TRAINING: next-character prediction, one name per optimizer step.
# NOTE(review): the class label is not used by this objective.
for name in data['converted']:
    # `name` is a list of vocabulary indices ending with the end marker.
    x = torch.tensor(name, dtype=torch.long, device=device)
    x_onehot = torch.zeros(len(name), vocab_n, device=device)
    x_onehot.scatter_(1, x.view(-1, 1), 1)

    # Encoding: run the cell over every character except the last; the output
    # of step i is the prediction for character i + 1.
    output = []
    hx = torch.zeros(hidden_n, device=device)
    for i in range(len(name) - 1):
        hx, y_pred = rnn(hx, x_onehot[i, :])
        output.append(y_pred)

    if not output:
        continue  # a one-character sequence has no next character to predict

    # Decoding: the targets are the next characters. criterion averages over
    # characters, so multiply by their count to get the summed loss. One
    # backward() on the sum gives the same gradients as one backward() per
    # character, and the graph is released afterwards.
    logits = torch.stack(output)  # (len(name) - 1, vocab_n)
    targets = x[1:]
    loss = criterion(logits, targets) * targets.numel()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Sanity check: from a zero hidden state, feed row 2 of x_onehot (left over
# from the last iteration of the training loop) and print the index of the
# highest-scoring next character.
with torch.no_grad():
    h_start = torch.zeros(hidden_n, dtype=torch.float32, device=device)
    _, scores = rnn(h_start, x_onehot[2, :])
    print(scores.argmax())

In [None]:
# Display the character-index tensor of the last name processed by the training loop.
x