In [1]:
# =============================================================================
# Libs
# =============================================================================
from torch.utils.data import Dataset
import torch.nn.functional as F
from collections import Counter
from os.path import exists
import torch.optim as optim
import torch.nn as nn
import numpy as np
import random
import torch
import math
import re


# =============================================================================
# Transformer
# =============================================================================
def attention(q, k, v, mask = None, dropout = None):
    """Scaled dot-product attention.

    q, k, v: tensors whose last two dimensions are (sequence, features);
        any leading dimensions (batch, heads, ...) are carried through.
    mask: optional tensor broadcastable to the score matrix; positions
        where ``mask == 0`` are filled with -1e3 before the softmax.
    dropout: optional callable applied to the attention weights.

    Returns the attention-weighted combination of ``v``.
    """
    # similarity of every query with every key, scaled by sqrt(feature dim)
    scale = math.sqrt(q.shape[-1])
    weights = q.matmul(k.transpose(-2, -1)) / scale

    # push masked positions towards zero probability
    if mask is not None:
        weights = weights.masked_fill(mask == 0, -1e3)

    # normalise over the key axis
    weights = F.softmax(weights, dim = -1)
    if dropout is not None:
        weights = dropout(weights)

    return weights.matmul(v)

class MultiHeadAttention(nn.Module):
    """Multi-head attention with a single fused Q/K/V projection.

    n_heads: number of attention heads; must divide ``out_dim``.
    out_dim: model width (size of the last dimension of the inputs).
    dropout: dropout probability applied to the attention weights.

    Raises ValueError if ``out_dim`` is not divisible by ``n_heads``.
    """
    def __init__(self, n_heads, out_dim, dropout=0.1):
        super().__init__()

        # an uneven split would silently scramble the sequence axis in
        # split_heads, so reject it up front
        if out_dim % n_heads != 0:
            raise ValueError(
                f'out_dim ({out_dim}) must be divisible by n_heads ({n_heads})')

        # one matmul yields Q, K and V: output rows [0, out_dim) are Q,
        # [out_dim, 2*out_dim) are K and [2*out_dim, 3*out_dim) are V
        self.linear = nn.Linear(out_dim, out_dim*3)

        self.n_heads = n_heads
        self.out_dim = out_dim
        self.out_dim_per_head = out_dim // n_heads
        self.out = nn.Linear(out_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, t):
        """Reshape (BS, SEQ_LEN, out_dim) into (BS, SEQ_LEN, HEAD, out_dim_per_head)."""
        return t.reshape(t.shape[0], -1, self.n_heads, self.out_dim_per_head)

    def forward(self, x, y=None, mask=None):
        """Attend from ``x`` (queries) to ``y`` (keys/values).

        x: (BS, SEQ_LEN, out_dim) tensor the queries are computed from.
        y: optional (BS, SEQ_LEN_Y, out_dim) tensor the keys and values are
            computed from; when omitted, ``x`` is used (self-attention).
        mask: optional mask forwarded to ``attention``; it must broadcast
            against the (BS, HEAD, SEQ_LEN, SEQ_LEN_Y) score tensor.

        Returns a (BS, SEQ_LEN, out_dim) tensor.
        """
        d = self.out_dim

        if y is None:
            # self-attention: one fused projection, then slice Q / K / V
            qkv = self.linear(x) # BS * SEQ_LEN * (3*EMBED_SIZE_L)
            q = qkv[:, :, :d]
            k = qkv[:, :, d:d*2]
            v = qkv[:, :, d*2:]
        else:
            # cross-attention: queries come from x, keys/values from y.
            # The same fused weights are sliced so parameters are unchanged.
            weight, bias = self.linear.weight, self.linear.bias
            q = F.linear(x, weight[:d], bias[:d])
            kv = F.linear(y, weight[d:], bias[d:])
            k = kv[:, :, :d]
            v = kv[:, :, d:]

        #break into n_heads
        q, k, v = [self.split_heads(t) for t in (q,k,v)]  # BS * SEQ_LEN * HEAD * EMBED_SIZE_P_HEAD
        q, k, v = [t.transpose(1,2) for t in (q,k,v)]  # BS * HEAD * SEQ_LEN * EMBED_SIZE_P_HEAD

        #n_heads => attention => merge the heads => mix information
        scores = attention(q, k, v, mask, self.dropout) # BS * HEAD * SEQ_LEN * EMBED_SIZE_P_HEAD
        scores = scores.transpose(1,2).contiguous().view(scores.shape[0], -1, self.out_dim) # BS * SEQ_LEN * EMBED_SIZE_L
        out = self.out(scores)  # BS * SEQ_LEN * EMBED_SIZE

        return out

class FeedForward(nn.Module):
    """Two-layer MLP applied along the last dimension.

    inp_dim: size of the input and output features.
    inner_dim: size of the hidden layer.
    dropout: dropout probability applied after the ReLU.
    """
    def __init__(self, inp_dim, inner_dim, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(inp_dim, inner_dim)
        self.linear2 = nn.Linear(inner_dim, inp_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``inner_dim``, apply ReLU and dropout, project back."""
        hidden = self.linear1(x)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer is preceded by its own LayerNorm and followed by
    dropout, and its result is added back onto the un-normalised input.
    """
    def __init__(self, n_heads, inner_transformer_size, inner_ff_size, dropout=0.1):
        super().__init__()
        width = inner_transformer_size
        self.mha = MultiHeadAttention(n_heads, width, dropout)
        self.ff = FeedForward(width, inner_ff_size, dropout)
        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run both sub-layers on ``x``; ``mask`` is passed to the attention."""
        # attention sub-layer with residual connection
        attended = self.mha(self.norm1(x), mask=mask)
        x = x + self.dropout1(attended)

        # feed-forward sub-layer with residual connection
        transformed = self.ff(self.norm2(x))
        return x + self.dropout2(transformed)

class Transformer(nn.Module):
    """Encoder-only transformer with a vocabulary-sized output head.

    n_code: number of stacked encoder layers.
    n_heads: attention heads per layer.
    embed_size: model width.
    inner_ff_size: hidden width of each feed-forward sub-layer.
    n_embeddings: number of rows in the embedding table, and the size of
        the output logits.
    seq_len: number of positions in the positional table.
    dropout: dropout probability used inside the encoder layers.
    """
    def __init__(self, n_code, n_heads, embed_size, inner_ff_size, n_embeddings, seq_len, dropout=.1):
        super().__init__()

        # input side: token embeddings plus positional table
        self.embeddings = nn.Embedding(n_embeddings, embed_size)
        self.pe = PositionalEmbedding(embed_size, seq_len)

        # stack of identical encoder layers
        self.encoders = nn.ModuleList(
            [EncoderLayer(n_heads, embed_size, inner_ff_size, dropout) for _ in range(n_code)]
        )

        # output side: final norm and projection to vocabulary logits
        self.norm = nn.LayerNorm(embed_size)
        self.linear = nn.Linear(embed_size, n_embeddings, bias=False)

    def forward(self, x):
        """Map token ids to per-position logits over ``n_embeddings`` classes.

        Presumably ``x`` is a (batch, seq) integer tensor - confirm with caller.
        """
        hidden = self.embeddings(x)
        hidden = hidden + self.pe(hidden)
        for layer in self.encoders:
            hidden = layer(hidden)
        return self.linear(self.norm(hidden))

# Positional Embedding
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal position table.

    For position ``pos`` and pair index ``k`` the table holds
        pe[pos, 2k]   = sin(pos / 10000 ** (2k / d_model))
        pe[pos, 2k+1] = cos(pos / 10000 ** (2k / d_model))
    so both members of a pair share one frequency. The table is stored as
    a non-trainable buffer of shape (1, max_seq_len, d_model).

    d_model: feature size of each position vector (odd sizes are allowed;
        the last column then holds only the sine term).
    max_seq_len: number of positions to precompute.
    """
    def __init__(self, d_model, max_seq_len = 80):
        super().__init__()
        self.d_model = d_model

        # angles are computed in float64 and cast on assignment to keep the
        # stored values as accurate as the table's dtype allows
        position = torch.arange(max_seq_len, dtype=torch.float64).unsqueeze(1)
        # arange steps by 2, so this is exactly 2k / d_model for pair k
        pair_exponent = torch.arange(0, d_model, 2, dtype=torch.float64) / d_model
        angles = position / torch.pow(10000.0, pair_exponent)

        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # with an odd d_model there is one fewer odd column than even ones
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table, shape (1, seq, d_model)."""
        return self.pe[:,:x.size(1)] #x.size(1) = seq_len
    
# =============================================================================
# Dataset
# =============================================================================
class SentencesDataset(Dataset):
    """Masked-token dataset over pre-tokenised sentences.

    sentences: sequence of sentences, each a sequence of word strings.
    vocab: iterable of known words; duplicates are collapsed so ids stay
        dense. Three special tokens are appended: '<ignore>', '<oov>'
        and '<mask>'.
    seq_len: length of every returned sample.
    """
    #Init dataset
    def __init__(self, sentences, vocab, seq_len):
        self.sentences = sentences

        # dict.fromkeys keeps first-seen order while dropping repeats; with
        # repeats left in, the largest id would exceed len(vocab) - 1 and
        # overrun any table sized by len(vocab)
        tokens = dict.fromkeys(list(vocab) + ['<ignore>', '<oov>', '<mask>'])
        self.vocab = {e: i for i, e in enumerate(tokens)}
        self.rvocab = {v: k for k, v in self.vocab.items()}
        self.seq_len = seq_len

        #special tags
        self.IGNORE_IDX = self.vocab['<ignore>'] #replacement tag for tokens to ignore
        self.OUT_OF_VOCAB_IDX = self.vocab['<oov>'] #replacement tag for unknown words
        self.MASK_IDX = self.vocab['<mask>'] #replacement tag for the masked word prediction task


    #fetch data
    def __getitem__(self, index, p_random_mask=0.15):
        """Build one sample of ``seq_len`` token ids starting at sentence ``index``.

        Consecutive sentences (wrapping around the dataset) are concatenated
        until ``seq_len`` ids are available, then truncated. Each id is
        replaced by '<mask>' with probability ``p_random_mask``.

        Returns a dict of two int64 tensors of length ``seq_len``:
            'input':  ids with masked positions set to MASK_IDX
            'target': original id at masked positions, IGNORE_IDX elsewhere

        Raises ValueError if every sentence is empty (nothing to sample).
        """
        n_sentences = len(self)

        #while we don't have enough word to fill the sentence for a batch
        s = []
        empty_streak = 0
        while len(s) < self.seq_len:
            ids = self.get_sentence_idx(index % n_sentences)
            index += 1
            if ids:
                empty_streak = 0
                s.extend(ids)
            else:
                # a full lap of empty sentences means the loop can never finish
                empty_streak += 1
                if empty_streak >= n_sentences:
                    raise ValueError('all sentences are empty; cannot build a sample')

        #the loop guarantees at least seq_len ids, so only truncation is needed
        s = s[:self.seq_len]

        #apply random mask (one random draw per token, in order)
        inputs, targets = [], []
        for w in s:
            if random.random() < p_random_mask:
                inputs.append(self.MASK_IDX)
                targets.append(w)
            else:
                inputs.append(w)
                targets.append(self.IGNORE_IDX)

        # build int64 tensors directly instead of converting through float32
        return {'input': torch.tensor(inputs, dtype=torch.long),
                'target': torch.tensor(targets, dtype=torch.long)}

    #return length
    def __len__(self):
        return len(self.sentences)

    #get words id
    def get_sentence_idx(self, index):
        """Return the ids of sentence ``index``; unknown words map to OUT_OF_VOCAB_IDX."""
        oov = self.OUT_OF_VOCAB_IDX
        return [self.vocab.get(w, oov) for w in self.sentences[index]]

# =============================================================================
# Methods / Class
# =============================================================================
def get_batch(loader, loader_iter):
    """Fetch the next batch, restarting the iterator when it is exhausted.

    loader: iterable that can be re-iterated with ``iter(loader)``.
    loader_iter: the iterator currently being consumed.

    Returns ``(batch, iterator)``; the iterator is a fresh one if the old
    one had run out, so callers must keep the returned value.
    """
    exhausted = object()
    batch = next(loader_iter, exhausted)
    if batch is exhausted:
        # start a new pass over the loader
        loader_iter = iter(loader)
        batch = next(loader_iter)
    return batch, loader_iter

# =============================================================================
# #Init
# =============================================================================
print('initializing..')

# data
batch_size = 1024
seq_len = 20
n_vocab = 40000  # keep only the most frequent words
# n_workers = 12

# model
embed_size = 128
inner_ff_size = 4 * embed_size
n_heads = 8
n_code = 8  # number of encoder layers
dropout = 0.1

# optimizer settings, unpacked into the optimizer constructor
optim_kwargs = dict(lr=1e-4, weight_decay=1e-4, betas=(.9, .999))

# =============================================================================
# Input
# =============================================================================
#1) load text (one sentence per line, lower-cased)
# NOTE: files are opened with the platform default encoding, as before
print('loading text...')
pth = 'training.txt'
with open(pth) as f:  # context manager closes the handle
    sentences = f.read().lower().split('\n')

#2) tokenize sentences (can be done during training, you can also use spacy udpipe)
print('tokenizing sentences...')
special_chars = ',?;.:/*!+-()[]{}"\'&'
# compiled once instead of per sentence; the replacement is a raw string so
# \g<0> reaches the regex engine as a group reference rather than being an
# invalid string escape
special_chars_re = re.compile(f'[{re.escape(special_chars)}]')
# surround each special character with spaces so it becomes its own token
sentences = [special_chars_re.sub(r' \g<0> ', s).split(' ') for s in sentences]
# drop the empty strings produced by consecutive spaces
sentences = [[w for w in s if len(w)] for s in sentences]

#3) create vocab if not already created
print('creating/loading vocab...')
pth = 'vocab.txt'
if not exists(pth):
    words = [w for s in sentences for w in s]
    #keep the N most frequent words
    vocab = [w for w, _ in Counter(words).most_common(n_vocab)]
    with open(pth, 'w') as f:
        f.write('\n'.join(vocab))
else:
    with open(pth) as f:
        vocab = f.read().split('\n')

#4) create dataset
print('creating dataset...')
dataset = SentencesDataset(sentences, vocab, seq_len)

# loader options; add 'num_workers' here to load batches in worker processes
kwargs = dict(shuffle=True, drop_last=True, pin_memory=True, batch_size=batch_size)
data_loader = torch.utils.data.DataLoader(dataset, **kwargs)


# =============================================================================
# Model
# =============================================================================
#init model
print('initializing model...')
model = Transformer(n_code, n_heads, embed_size, inner_ff_size, len(dataset.vocab), seq_len, dropout)
# use the GPU when one is present, otherwise stay on the CPU instead of
# failing on hosts without CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# =============================================================================
# Optimizer
# =============================================================================
print('initializing optimizer and loss...')
# targets equal to the dataset's IGNORE_IDX are excluded from the loss
loss_model = nn.CrossEntropyLoss(ignore_index=dataset.IGNORE_IDX)
optimizer = optim.Adam(params=model.parameters(), **optim_kwargs)

# =============================================================================
# Train
# =============================================================================
print('training...')
print_each = 10
n_iteration = 10000
# follow whatever device the model lives on rather than hard-coding CUDA
train_device = next(model.parameters()).device
model.train()
batch_iter = iter(data_loader)
for it in range(n_iteration):

    #get batch (the iterator restarts automatically at the end of an epoch)
    batch, batch_iter = get_batch(data_loader, batch_iter)

    #infer
    masked_input = batch['input'].to(train_device, non_blocking=True)
    masked_target = batch['target'].to(train_device, non_blocking=True)
    output = model(masked_input)

    #compute the cross entropy loss
    #flatten to (positions, classes) and (positions,); reshape(-1) stays 1-D
    #even for a single element, unlike view(-1,1).squeeze()
    output_v = output.view(-1,output.shape[-1])
    target_v = masked_target.reshape(-1)
    loss = loss_model(output_v, target_v)

    #compute gradients
    loss.backward()

    #apply gradients
    optimizer.step()

    #print step
    #the value labelled 'Δw' is the summed absolute gradient of the embedding
    #table, which is why gradients are cleared only after printing
    if it % print_each == 0:
        print('it:', it,
              ' | loss', np.round(loss.item(),2),
              ' | Δw:', round(model.embeddings.weight.grad.abs().sum().item(),3))

    #reset gradients
    optimizer.zero_grad()
    

# =============================================================================
# Results analysis
# =============================================================================
print('saving embeddings...')
# export at most 3000 rows, capped at the vocabulary size so the rvocab
# lookup below cannot run past the last id
N = min(3000, len(dataset.rvocab))
embedding_values = model.embeddings.weight.detach().cpu().numpy()[0:N]
np.savetxt('values.tsv', np.round(embedding_values, 2), delimiter='\t', fmt='%1.2f')
# one token per line, in the same row order as values.tsv
s = [dataset.rvocab[i] for i in range(N)]
with open('names.tsv', 'w') as f:  # context manager closes the handle
    f.write('\n'.join(s))


print('end')





C:\Users\INHOPE\anaconda3\lib\site-packages\numpy\.libs\libopenblas.wcdjnk7yvmpzq2me2zzhjjrj3jikndb7.gfortran-win_amd64.dll
C:\Users\INHOPE\anaconda3\lib\site-packages\numpy\.libs\libopenblas64__v0.3.21-gcc_10_3_0.dll


initializing..
loading text...
tokenizing sentences...
creating/loading vocab...
creating dataset...
initializing model...
initializing optimizer and loss...
training...
it: 0  | loss 10.37  | Δw: 1.1
it: 10  | loss 9.62  | Δw: 0.603
it: 20  | loss 9.39  | Δw: 0.377
it: 30  | loss 9.24  | Δw: 0.308
it: 40  | loss 9.07  | Δw: 0.256
it: 50  | loss 8.91  | Δw: 0.222
it: 60  | loss 8.79  | Δw: 0.2
it: 70  | loss 8.62  | Δw: 0.199
it: 80  | loss 8.43  | Δw: 0.18
it: 90  | loss 8.23  | Δw: 0.18
it: 100  | loss 8.12  | Δw: 0.17
it: 110  | loss 8.03  | Δw: 0.163
it: 120  | loss 7.83  | Δw: 0.158
it: 130  | loss 7.74  | Δw: 0.149
it: 140  | loss 7.59  | Δw: 0.146
it: 150  | loss 7.49  | Δw: 0.151
it: 160  | loss 7.34  | Δw: 0.14
it: 170  | loss 7.23  | Δw: 0.136
it: 180  | loss 7.13  | Δw: 0.137
it: 190  | loss 7.01  | Δw: 0.13
it: 200  | loss 6.89  | Δw: 0.13
it: 210  | loss 6.9  | Δw: 0.136
it: 220  | loss 6.81  | Δw: 0.127
it: 230  | loss 6.73  | Δw: 0.129
it: 240  | loss 6.77  | Δw: 0.129
i

it: 2350  | loss 5.43  | Δw: 4.605
it: 2360  | loss 5.39  | Δw: 4.523
it: 2370  | loss 5.36  | Δw: 4.555
it: 2380  | loss 5.3  | Δw: 4.706
it: 2390  | loss 5.32  | Δw: 4.739
it: 2400  | loss 5.32  | Δw: 4.79
it: 2410  | loss 5.28  | Δw: 4.679
it: 2420  | loss 5.31  | Δw: 4.881
it: 2430  | loss 5.3  | Δw: 4.699
it: 2440  | loss 5.36  | Δw: 5.088
it: 2450  | loss 5.37  | Δw: 5.019
it: 2460  | loss 5.35  | Δw: 4.592
it: 2470  | loss 5.27  | Δw: 5.241
it: 2480  | loss 5.35  | Δw: 4.762
it: 2490  | loss 5.32  | Δw: 4.747
it: 2500  | loss 5.29  | Δw: 4.904
it: 2510  | loss 5.32  | Δw: 4.908
it: 2520  | loss 5.38  | Δw: 4.738
it: 2530  | loss 5.32  | Δw: 4.976
it: 2540  | loss 5.33  | Δw: 4.923
it: 2550  | loss 5.34  | Δw: 4.992
it: 2560  | loss 5.3  | Δw: 4.772
it: 2570  | loss 5.19  | Δw: 4.809
it: 2580  | loss 5.38  | Δw: 5.086
it: 2590  | loss 5.2  | Δw: 4.772
it: 2600  | loss 5.15  | Δw: 4.94
it: 2610  | loss 5.1  | Δw: 4.989
it: 2620  | loss 5.33  | Δw: 5.159
it: 2630  | loss 5.26  | Δw

it: 4710  | loss 4.72  | Δw: 6.762
it: 4720  | loss 4.73  | Δw: 7.202
it: 4730  | loss 4.69  | Δw: 6.77
it: 4740  | loss 4.65  | Δw: 6.795
it: 4750  | loss 4.72  | Δw: 6.498
it: 4760  | loss 4.76  | Δw: 6.541
it: 4770  | loss 4.77  | Δw: 6.683
it: 4780  | loss 4.62  | Δw: 6.79
it: 4790  | loss 4.7  | Δw: 6.48
it: 4800  | loss 4.76  | Δw: 6.901
it: 4810  | loss 4.72  | Δw: 6.837
it: 4820  | loss 4.67  | Δw: 6.878
it: 4830  | loss 4.7  | Δw: 6.927
it: 4840  | loss 4.68  | Δw: 6.576
it: 4850  | loss 4.77  | Δw: 6.908
it: 4860  | loss 4.69  | Δw: 6.939
it: 4870  | loss 4.7  | Δw: 6.635
it: 4880  | loss 4.67  | Δw: 6.753
it: 4890  | loss 4.74  | Δw: 7.001
it: 4900  | loss 4.79  | Δw: 7.18
it: 4910  | loss 4.69  | Δw: 6.973
it: 4920  | loss 4.7  | Δw: 6.887
it: 4930  | loss 4.62  | Δw: 6.948
it: 4940  | loss 4.64  | Δw: 6.924
it: 4950  | loss 4.76  | Δw: 6.717
it: 4960  | loss 4.67  | Δw: 6.896
it: 4970  | loss 4.7  | Δw: 6.672
it: 4980  | loss 4.7  | Δw: 6.663
it: 4990  | loss 4.72  | Δw: 6

it: 7070  | loss 4.54  | Δw: 8.212
it: 7080  | loss 4.4  | Δw: 8.357
it: 7090  | loss 4.51  | Δw: 8.31
it: 7100  | loss 4.42  | Δw: 8.347
it: 7110  | loss 4.51  | Δw: 8.286
it: 7120  | loss 4.51  | Δw: 8.319
it: 7130  | loss 4.5  | Δw: 8.354
it: 7140  | loss 4.46  | Δw: 8.08
it: 7150  | loss 4.5  | Δw: 8.271
it: 7160  | loss 4.54  | Δw: 8.498
it: 7170  | loss 4.56  | Δw: 8.476
it: 7180  | loss 4.54  | Δw: 8.536
it: 7190  | loss 4.5  | Δw: 8.376
it: 7200  | loss 4.44  | Δw: 8.456
it: 7210  | loss 4.56  | Δw: 8.475
it: 7220  | loss 4.45  | Δw: 8.49
it: 7230  | loss 4.43  | Δw: 8.133
it: 7240  | loss 4.37  | Δw: 8.239
it: 7250  | loss 4.57  | Δw: 8.479
it: 7260  | loss 4.41  | Δw: 8.734
it: 7270  | loss 4.46  | Δw: 8.405
it: 7280  | loss 4.44  | Δw: 8.29
it: 7290  | loss 4.54  | Δw: 8.309
it: 7300  | loss 4.46  | Δw: 8.399
it: 7310  | loss 4.49  | Δw: 8.995
it: 7320  | loss 4.45  | Δw: 8.353
it: 7330  | loss 4.5  | Δw: 8.447
it: 7340  | loss 4.45  | Δw: 8.462
it: 7350  | loss 4.48  | Δw: 

it: 9420  | loss 4.3  | Δw: 10.333
it: 9430  | loss 4.28  | Δw: 10.132
it: 9440  | loss 4.29  | Δw: 10.829
it: 9450  | loss 4.34  | Δw: 10.404
it: 9460  | loss 4.29  | Δw: 10.255
it: 9470  | loss 4.38  | Δw: 10.711
it: 9480  | loss 4.3  | Δw: 10.463
it: 9490  | loss 4.28  | Δw: 10.376
it: 9500  | loss 4.32  | Δw: 10.75
it: 9510  | loss 4.32  | Δw: 10.335
it: 9520  | loss 4.26  | Δw: 10.255
it: 9530  | loss 4.26  | Δw: 9.987
it: 9540  | loss 4.42  | Δw: 10.54
it: 9550  | loss 4.34  | Δw: 10.783
it: 9560  | loss 4.41  | Δw: 10.646
it: 9570  | loss 4.29  | Δw: 10.428
it: 9580  | loss 4.28  | Δw: 10.388
it: 9590  | loss 4.34  | Δw: 10.681
it: 9600  | loss 4.29  | Δw: 10.668
it: 9610  | loss 4.25  | Δw: 10.904
it: 9620  | loss 4.3  | Δw: 10.398
it: 9630  | loss 4.25  | Δw: 10.732
it: 9640  | loss 4.33  | Δw: 10.245
it: 9650  | loss 4.35  | Δw: 11.012
it: 9660  | loss 4.26  | Δw: 10.656
it: 9670  | loss 4.32  | Δw: 10.584
it: 9680  | loss 4.33  | Δw: 10.591
it: 9690  | loss 4.25  | Δw: 10.61