In [1]:
import argparse
import time
import torch
import torch.nn
from torch.autograd import Variable
import torch.nn as nn
from lm import repackage_hidden, LM_LSTM
import numpy as np
import reader

Ниже - обучение базовой модели из статьи: bigLSTM(small PTB)
Обученная модель сохранена в файле lm_model.pt  
Поэтому это можно скипнуть

In [2]:
criterion = nn.CrossEntropyLoss()
data = 'data'
hidden_size = 200
num_steps = 35
num_layers = 2
batch_size = 20
num_epochs = 13 
dp_keep_prob = 0.35
inital_lr = 20.0
save = 'lm_model.pt'

In [None]:
def repackage_hidden(h):
  """Wraps hidden states in new Variables, to detach them from their history."""
  

  if type(h) == torch.Tensor:
    return Variable(h.data)
  else:
    d = tuple(repackage_hidden(v) for v in h)
    return d

In [None]:
def run_epoch(model, data, is_train=False, lr=1.0):
    """Runs the model on the given data."""
    if is_train:
        model.train()
    else:
        model.eval()
    epoch_size = ((len(data) // model.batch_size) - 1) // model.num_steps
    start_time = time.time()
    hidden = model.init_hidden()
    hidden[0].requires_grad=True
    hidden[1].requires_grad=True
    costs = 0.0
    iters = 0
    for step, (x, y) in enumerate(reader.ptb_iterator(data, model.batch_size, model.num_steps)):
        inputs =torch.from_numpy(x.astype(np.int64)).transpose(0, 1).contiguous().cuda()
        model.zero_grad()
        hidden = repackage_hidden(hidden)
        outputs, hidden = model(inputs, hidden)
        targets = torch.from_numpy(y.astype(np.int64)).transpose(0, 1).contiguous().cuda()
        tt = torch.squeeze(targets.view(-1, model.batch_size * model.num_steps))

        loss = criterion(outputs.view(-1, model.vocab_size), tt)
        #print( loss.data.item() , model.num_steps)
        costs += loss.data.item() * model.num_steps
        iters += model.num_steps

        if is_train:
            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(), 0.25)
            for p in model.parameters():
                p.data.add_(-lr, p.grad.data)
            if step % (epoch_size // 10) == 10:
                print("{} perplexity: {:8.2f} speed: {} wps".format(step * 1.0 / epoch_size, np.exp(costs / iters),
                                                       iters * model.batch_size / (time.time() - start_time)))
    return np.exp(costs / iters)


In [None]:
#for google collab
!mkdir /content/data

In [None]:

raw_data = reader.ptb_raw_data(data_path=data)
train_data, valid_data, test_data, word_to_id, id_2_word = raw_data
vocab_size = len(word_to_id)
print('Vocabluary size: {}'.format(vocab_size))
model = LM_LSTM(embedding_dim=hidden_size, num_steps=num_steps, batch_size=batch_size,
                vocab_size=vocab_size, num_layers=num_layers, dp_keep_prob=dp_keep_prob)
model.cuda()
lr = inital_lr
# decay factor for learning rate
lr_decay_base = 1 / 1.15
# we will not touch lr for the first m_flat_lr epochs
m_flat_lr = 14.0

print("########## Training ##########################")
for epoch in range(num_epochs):
    lr_decay = lr_decay_base ** max(epoch - m_flat_lr, 0)
    lr = lr * lr_decay # decay lr if it is time
    train_p = run_epoch(model, train_data, True, lr)
    print('Train perplexity at epoch {}: {:8.2f}'.format(epoch, train_p))
    print('Validation perplexity at epoch {}: {:8.2f}'.format(epoch, run_epoch(model, valid_data)))
print("########## Testing ##########################")
model.batch_size = 1 # to make sure we process all the data
print('Test Perplexity: {:8.2f}'.format(run_epoch(model, test_data)))
with open(save, 'wb') as f:
    torch.save(model, f)
print("########## Done! ##########################")

Загрузка обученной модели

In [3]:
raw_data = reader.ptb_raw_data(data_path=data)
train_data, valid_data, test_data, word_to_id, id_2_word = raw_data
vocab_size = len(word_to_id)
print('Vocabluary size: {}'.format(vocab_size))
model =  torch.load('./lm_model.pt')
model.cuda()
model.eval()

Vocabluary size: 10000


LM_LSTM(
  (dropout): Dropout(p=0.65)
  (word_embeddings): Embedding(10000, 1500)
  (lstm): LSTM(1500, 1500, num_layers=2, dropout=0.65)
  (sm_fc): Linear(in_features=1500, out_features=10000, bias=True)
)

In [4]:
freq = {key:0 for key in id_2_word.keys()}
for j in train_data:
    freq[j] +=1

In [5]:
from math import sqrt
freq_sqr = np.array([sqrt(q) for key, q in freq.items()])

In [6]:
def weighted_svd(A, freq_sqr): 
    Q = torch.diag(torch.from_numpy(freq_sqr).cuda().type(torch.cuda.FloatTensor))
    QA = torch.matmul(Q, A)
    U_, S_, V_ = torch.svd(QA)
    #print(Q.shape, V_.shape, S_.shape)
    #print(torch.matmul(torch.inverse(Q),U_).shape)
    U = torch.matmul(torch.inverse(Q),U_) * S_
    return U, V_

In [7]:
w=model.word_embeddings.weight
U, V = weighted_svd(w, freq_sqr)

In [8]:
new_weights = torch.matmul(U, V.t())
new_weights.shape

torch.Size([10000, 1500])

In [9]:
freq_sqr.shape

(10000,)

In [10]:
def group_reduce(A, c, r, t_max, m_min):
    #init clusters
    cluster_len = A.shape[0] / c
    clusters = torch.split(A, int(cluster_len))
    freq_cls = np.array_split(freq_sqr, c)
    ranks = {}
    V = {}
    U = {}
    for p in range(c): 
        U[p], V[p]  = weighted_svd(clusters[p],freq_cls[p])
        ranks[p] = torch.matrix_rank(torch.matmul(U[p], V[p].t()))
    for t in range(t_max):
        M = {}
        for i in range(A.shape[0]):
            errors = [torch.norm(A[i] - torch.matmul(torch.matmul(V[p], V[p].t()), A[i])) for p in range(c)]
            e_i = min(errors)
            g_i = np.argmin(errors)
            if g_i != i // cluster_len:
                M[i] = (e_i, g_i)
        print(M)
        m = 0.1*len(M)
        t = 0
        is_changed = [False]*c
        for key, value in sorted(M.items(), key=lambda kv: kv[1]):
            freq_cls[value[1]].append(freq_cls[key // cluster_len][key % cluster_len])
            freq_cls[key // cluster_len].delete(freq_cls[key // cluster_len][key % cluster_len])
            clusters[value[1]] = torch.cat(clusters[value[1]], clusters[key // cluster_len][key % cluster_len])
            clusters[key // cluster_len] = torch.cat(clusters[key // cluster_len][ :key % cluster_len], 
                                                     clusters[key // cluster_len][key % cluster_len + 1: ]) 
            is_chaged[value[1]] = True
            is_chaged[key // cluster_len] = True
            
            t += 1
            if m == t:
                break;
        if m < m_min:
            return [torch.matmul(U[p], V[p].t()) for p in range(c)]
        for p in range(c):
            if is_changed[p]:
                U[p], V[p]  = weighted_svd(clusters[p],freq_cls[p])
                ranks[p] = torch.matrix_rank(torch.matmul(U[p], V[p].t()))
    return [torch.matmul(U[p], V[p].t()) for p in range(c)]

In [11]:
group_reduce(w, c = 10, r = 10, t_max = 100, m_min = 5)

10
10
{}


[tensor([[ 0.1311,  0.1610, -0.0722,  ...,  0.2115,  0.1102, -0.1117],
         [ 0.0727,  0.1335, -0.1277,  ...,  0.3543, -0.0537, -0.0631],
         [ 0.1373,  0.3537, -0.2748,  ..., -0.0159, -0.2245, -0.1115],
         ...,
         [ 0.1178, -0.0063,  0.0412,  ...,  0.1042, -0.0494, -0.0364],
         [ 0.0614, -0.1251, -0.0319,  ...,  0.1271,  0.1594, -0.1130],
         [ 0.1573,  0.1949,  0.0006,  ...,  0.1193, -0.0180, -0.0163]],
        device='cuda:0', grad_fn=<MmBackward>),
 tensor([[ 0.1173,  0.0042,  0.0267,  ..., -0.0471, -0.0061,  0.0128],
         [-0.1184,  0.0134,  0.0045,  ..., -0.0175, -0.0679, -0.1969],
         [ 0.0963,  0.1117,  0.1898,  ...,  0.1214,  0.1565, -0.0877],
         ...,
         [-0.0310,  0.1439, -0.0608,  ...,  0.0596,  0.0313, -0.0884],
         [ 0.0678,  0.0133, -0.0158,  ..., -0.0106, -0.1032,  0.1046],
         [ 0.0837, -0.1673,  0.1950,  ...,  0.1291,  0.0456, -0.0165]],
        device='cuda:0', grad_fn=<MmBackward>),
 tensor([[-2.4464e-01,

In [None]:
x = {1: (2, 1), 3: (4, 0), 4: (3, 1), 2: (1, 5), 0: (0, 6)}
sorted_by_value = sorted(x.items(), key=lambda kv: kv[1])
sorted_by_value

In [None]:
A = [1, 2, 3, 4, 5, 6]
A[0:2-3:]

In [None]:
A = np.array([[[1, 2],[3, 4]],[[5, 6],[7, 8]],[[9, 0],[11,12]]])
A_tensor = torch.from_numpy(A)
A_tensor