# モデルから重みとバイアスを抽出

In [1]:
# import random
import numpy as np
# import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
# import torch.optim as optim
# import torch.autograd as autograd

## モデル

### Encoder

In [2]:
class Encoder(nn.Module):
  """Map a (..., N, char_num) sequence to a (..., hid_dim, 1) code.

  Parameters (names are part of the saved state_dict):
    W_emb: (char_num, emb_dim) embedding matrix shared by all positions.
    W_1, b_1: one (N, hid_dim) weight and (1, hid_dim) bias per embedding
      channel, stacked to (emb_dim, N, hid_dim) and (emb_dim, 1, hid_dim).
    W_2, b_2: one (emb_dim, 1) weight and (1, 1) bias per hidden unit,
      stacked to (hid_dim, emb_dim, 1) and (hid_dim, 1, 1).

  Weights are drawn uniformly from +/- sqrt(6 / (fan_in + fan_out)) (the
  Glorot/Xavier uniform bound) using NumPy's global RNG; biases start at 0.
  """

  def __init__(self, N, char_num, emb_dim, hid_dim):
    super().__init__()

    def uniform_weight(fan_in, fan_out, shape):
      # Symmetric uniform init whose bound depends only on the two fan sizes.
      bound = np.sqrt(6 / (fan_in + fan_out))
      values = np.random.uniform(low=-bound, high=bound, size=shape)
      return nn.Parameter(torch.tensor(values.astype('float32')))

    def zero_bias(shape):
      return nn.Parameter(torch.tensor(np.zeros(shape).astype('float32')))

    # Assignment order fixes both the state_dict order and the RNG draw order.
    self.W_emb = uniform_weight(char_num, emb_dim, (char_num, emb_dim))
    self.W_1 = uniform_weight(N, hid_dim, (emb_dim, N, hid_dim))
    self.b_1 = zero_bias((emb_dim, 1, hid_dim))
    self.W_2 = uniform_weight(emb_dim, 1, (hid_dim, emb_dim, 1))
    self.b_2 = zero_bias((hid_dim, 1, 1))

  def forward(self, x):
    """Return the tanh-activated code for `x` (last two dims: N, char_num)."""
    # (..., N, char_num) -> (..., N, emb_dim); no activation on the embedding.
    embedded = torch.matmul(x, self.W_emb)

    # Put each embedding channel in its own matrix slot: (..., emb_dim, 1, N),
    # so channel k is multiplied by its own W_1[k].
    per_channel = embedded.transpose(-2, -1).unsqueeze(-2)
    hidden = (torch.matmul(per_channel, self.W_1) + self.b_1).squeeze(-2)
    hidden = F.tanh(hidden)  # (..., emb_dim, hid_dim)

    # Same trick per hidden unit: (..., hid_dim, 1, emb_dim) @ W_2[j].
    per_unit = hidden.transpose(-2, -1).unsqueeze(-2)
    code = (torch.matmul(per_unit, self.W_2) + self.b_2).squeeze(-2)
    return F.tanh(code)  # (..., hid_dim, 1)

In [3]:
# Shape check on a random batch of 16 sequences (N=10, char_num=137);
# the notebook displays torch.Size([16, 32, 1]) for this cell.
encoder = Encoder(N=10, char_num=137, emb_dim=8, hid_dim=32)
x = torch.rand((16, 10, 137))
encoder(x).shape

torch.Size([16, 32, 1])

### Decoder

In [4]:
class Decoder(nn.Module):
  """Map a (..., hid_dim, 1) code to (..., N, char_num) per-position scores.

  Parameters (names are part of the saved state_dict):
    W_1, b_1: one (hid_dim, hid_dim) weight and (1, hid_dim) bias per output
      position, stacked to (N, hid_dim, hid_dim) and (N, 1, hid_dim).
    W_out: (hid_dim, char_num) projection shared by all positions.
    b_out: (N, char_num) bias, one row per output position.

  `emb_dim` is accepted but not used by this class. Weights are drawn
  uniformly from +/- sqrt(6 / (fan_in + fan_out)) (the Glorot/Xavier uniform
  bound) using NumPy's global RNG; biases start at 0.
  """

  def __init__(self, N, char_num, emb_dim, hid_dim):
    super().__init__()

    def uniform_weight(fan_in, fan_out, shape):
      # Symmetric uniform init whose bound depends only on the two fan sizes.
      bound = np.sqrt(6 / (fan_in + fan_out))
      values = np.random.uniform(low=-bound, high=bound, size=shape)
      return nn.Parameter(torch.tensor(values.astype('float32')))

    def zero_bias(shape):
      return nn.Parameter(torch.tensor(np.zeros(shape).astype('float32')))

    # Assignment order fixes both the state_dict order and the RNG draw order.
    self.W_1 = uniform_weight(hid_dim, N, (N, hid_dim, hid_dim))
    self.b_1 = zero_bias((N, 1, hid_dim))
    self.W_out = uniform_weight(hid_dim, char_num, (hid_dim, char_num))
    self.b_out = zero_bias((N, char_num))

  def forward(self, x):
    """Return unnormalised scores for `x` (last two dims: hid_dim, 1)."""
    # (..., hid_dim, 1) -> (..., 1, 1, hid_dim); broadcasting against the N
    # stacked matrices in W_1 yields one hidden vector per output position.
    code_row = x.transpose(-2, -1).unsqueeze(-2)
    hidden = (torch.matmul(code_row, self.W_1) + self.b_1).squeeze(-2)
    hidden = F.tanh(hidden)  # (..., N, hid_dim)

    # No activation on the output layer.
    return torch.matmul(hidden, self.W_out) + self.b_out

In [5]:
# Shape check on a random batch of 16 codes (hid_dim=32);
# the notebook displays torch.Size([16, 10, 137]) for this cell.
decoder = Decoder(N=10, char_num=137, emb_dim=8, hid_dim=32)
x = torch.rand((16, 32, 1))
decoder(x).shape

torch.Size([16, 10, 137])

### Generator

In [6]:
class Generator(nn.Module):
  """Encoder followed by Decoder, both built from the same four sizes.

  The submodules are registered as `encoder` and `decoder`; those attribute
  names prefix every key of the model's state_dict.
  """

  def __init__(self, N, char_num, emb_dim, hid_dim):
    super().__init__()
    sizes = (N, char_num, emb_dim, hid_dim)
    self.encoder = Encoder(*sizes)
    self.decoder = Decoder(*sizes)

  def forward(self, x):
    """Encode `x`, then decode the resulting code and return the scores."""
    return self.decoder(self.encoder(x))

In [7]:
# End-to-end shape check: the output has the same shape as the input batch
# (the notebook displays torch.Size([16, 10, 137]) for this cell).
generator = Generator(N=10, char_num=137, emb_dim=8, hid_dim=32)
x = torch.rand((16, 10, 137))
generator(x).shape

torch.Size([16, 10, 137])

## モデルの読み込み


In [8]:
# Model sizes. load_state_dict() is strict about parameter shapes, so these
# have to agree with the shapes stored in the checkpoint file.
N = 10
char_num = 200
emb_dim = 24
hid_dim = 24

model = Generator(N, char_num, emb_dim, hid_dim)

# map_location='cpu' makes the checkpoint loadable on a machine without CUDA
# even if the tensors were saved from a GPU; the model itself lives on the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
model.load_state_dict(
  torch.load('complin_params_acc616.pth', map_location='cpu')
)

# Switch to inference mode; as the last expression of the cell this also
# displays the module structure in the notebook.
model.eval()

Generator(
  (encoder): Encoder()
  (decoder): Decoder()
)

## モデルの動作テスト

In [9]:
# Target length of every sample: shorter kaomoji are padded with '<PAD>'
# tokens up to this many characters before being indexed.
KAOMOJI_MAX = 10

# Sample kaomoji that are fed through the loaded model further below.
kmj_sample = ['(　́ω`)ノ', 'ヾ(*　∀́　*)ノ', '(*　̄∇　̄)ノ']

In [11]:
# Load the vocabulary: one token per line, and a token's position in the list
# is its index. Only the newline is removed, so tokens that are themselves
# whitespace characters are kept intact.
# The encoding is given explicitly so that non-ASCII tokens are decoded the
# same way on every platform instead of depending on the locale default.
# NOTE(review): assumes char_list.txt is stored as UTF-8 -- confirm.
with open('char_list.txt', mode='r', encoding='utf-8') as file:
  char_list = [line.replace('\n', '') for line in file]

In [12]:
# Convert every sample kaomoji into a list of vocabulary indices.
kmj_index = []

for kmj in kmj_sample:
  # Split into characters and pad with '<PAD>' tokens up to KAOMOJI_MAX.
  # A sample that is already KAOMOJI_MAX long or longer gets no padding.
  tokens = list(kmj) + ['<PAD>'] * (KAOMOJI_MAX - len(kmj))

  indices = []
  for token in tokens:
    try:
      indices.append(char_list.index(token))
    except ValueError:
      # list.index raises ValueError for a token missing from the vocabulary;
      # only that case falls back to '<UNK>'. Anything else propagates.
      indices.append(char_list.index('<UNK>'))
  kmj_index.append(indices)

kmj_num = len(kmj_index)        # total number of kaomoji
kmj_size = len(kmj_index[0])    # length of one kaomoji
# NOTE(review): this rebinds the module-level char_num; the one-hot width
# must equal the char_num the model was built with for model(x) to work.
char_num = len(char_list)       # number of distinct tokens

# One-hot encoding with shape (kmj_num, kmj_size, char_num).
kmj_onehot = np.zeros((kmj_num, kmj_size, char_num))

positions = np.arange(kmj_size)
for i, index in enumerate(kmj_index):
  # Set exactly one entry per position: row p, column index[p].
  kmj_onehot[i, positions, index] = 1

In [13]:
def convert_str(x):
  """Decode a 2-D score tensor into a string.

  For each row of `x` the column with the largest value is looked up in the
  module-level `char_list`; '<PAD>' and '<UNK>' tokens are dropped and the
  remaining tokens are concatenated in row order.
  """
  hidden_tokens = ('<PAD>', '<UNK>')
  best_columns = x.argmax(dim=1).tolist()
  decoded = (char_list[column] for column in best_columns)
  return ''.join(token for token in decoded if token not in hidden_tokens)

# Feed each sample through the model and print the input next to the output.
# Iterating over the tensor yields one unbatched (kmj_size, char_num) sample
# at a time. Gradients are not needed for this check, so autograd tracking is
# switched off to avoid building a graph for every forward pass.
with torch.no_grad():
  for x in torch.tensor(kmj_onehot.astype('float32')):
    y = model(x)
    print('base     :', convert_str(x))
    print('generate :', convert_str(y))

base     : (　́ω`)ノ
generate : (　́Д`)ノ
base     : ヾ(*　∀́　*)ノ
generate : ヾ(*　Д́　*)ノ
base     : (*　̄∇　̄)ノ
generate : (*　́Д　́)ノ


## パラメータの抽出

In [114]:
# Parameters to export, keyed by the Generator submodule that owns them.
param_names = dict(
  encoder=['W_emb', 'W_1', 'b_1', 'W_2', 'b_2'],
  decoder=['W_1', 'b_1', 'W_out', 'b_out'],
)

In [127]:
# Write every listed parameter to '<module>_<parameter>.txt', one value per
# line, in the element order produced by Tensor.flatten().
for module_name, names in param_names.items():
  submodule = model.get_submodule(module_name)

  for param_name in names:
    value_list = submodule.get_parameter(param_name).flatten().tolist()
    lines = [str(value) + '\n' for value in value_list]

    with open(f'{module_name}_{param_name}.txt', 'w') as file:
      file.writelines(lines)