# PSからPLを制御するためのプログラム

## import

In [1]:
import numpy as np
from time import sleep, time
from pynq import PL, Overlay, MMIO, allocate

In [2]:
# N: symbols per kaomoji (shorter strings are padded with '<PAD>' up to N).
# batch_size: kaomoji per mini-batch, i.e. rows in each DMA buffer.
N, batch_size = 10, 32

## PL制御クラス

In [3]:
class Controller:
  """Drive the `top_0` IP and its AXI DMA (`axi_dma_0`) from the PS side.

  Register map of `top_0` as documented in this notebook (slv_reg<k> sits at
  byte offset 4*k):
    slv_reg0[2:0] : {next, run, rst_n}
    slv_reg1[1:0] : mode
    slv_reg2[0]   : finish

  Args:
    bitfile: bitstream loaded as a PYNQ Overlay.
    batch_size: number of rows in the input/output DMA buffers.
    N: number of uint8 symbols per row.
  """

  REG_CTRL   = 0 * 4   # slv_reg0[2:0] : {next, run, rst_n}
  REG_MODE   = 1 * 4   # slv_reg1[1:0] : mode
  REG_STATUS = 2 * 4   # slv_reg2[0]   : finish

  def __init__(self, bitfile, batch_size, N):
    PL.reset()

    # Keep a reference to the overlay for the lifetime of the controller.
    self.overlay = Overlay(bitfile)

    dma = self.overlay.axi_dma_0
    self.dma_send = dma.sendchannel
    self.dma_recv = dma.recvchannel

    top = self.overlay.ip_dict['top_0']
    self.top_mmio = MMIO(top['phys_addr'], top['addr_range'])
    # Start with every control bit cleared ({next, run, rst_n} = 0).
    self.top_mmio.write(self.REG_CTRL, 0b000)

    self.input_buffer  = allocate(shape=(batch_size, N), dtype=np.uint8)
    self.output_buffer = allocate(shape=(batch_size, N), dtype=np.uint8)
    # Set once the first send transfer has been issued (see send_data).
    self._send_issued = False

  def send_data(self, data):
    """Copy `data` into the input buffer and start the PS -> PL DMA.

    Does not wait for the transfer to complete, so the next batch can be
    streamed while the PL is still computing. It does wait for the previous
    send to finish before overwriting the buffer, so an in-flight transfer is
    never corrupted.
    """
    if self._send_issued:
      # Only poll after a transfer has been issued. The idle flag is not set
      # before the first transfer has completed (PYNQ's own DMA driver skips its
      # idle check for the first transfer for this reason), so polling earlier
      # could spin forever.
      while not self.dma_send.idle:
        sleep(0.001)
    self.input_buffer[:] = data
    self.dma_send.transfer(self.input_buffer)
    self._send_issued = True

  def recv_data(self):
    """Receive one batch from the PL via the PL -> PS DMA.

    Blocks until the transfer has completed, so the returned buffer holds the
    received data. PYNQ's DMA usage pattern pairs transfer() with wait() before
    the buffer is read. The same buffer object is reused on every call; copy it
    if the result must survive the next call.
    """
    self.dma_recv.transfer(self.output_buffer)
    self.dma_recv.wait()
    return self.output_buffer

  def run_mode(self, mode=1):
    """Run one batch in `mode` and block until the PL reports finish.

    This notebook uses mode 0 for training and 1 for forward-only validation,
    and calls mode 3 in its generation cell (TODO confirm against the RTL).
    """
    self.top_mmio.write(self.REG_MODE, mode)
    self._pulse_next()
    self._wait_finish()

  def _pulse_next(self):
    """Pulse `next` with rst_n held high, then leave the core in `run`."""
    self.top_mmio.write(self.REG_CTRL, 0b101)    # {next, run, rst_n} = {1, 0, 1}
    self.top_mmio.write(self.REG_CTRL, 0b011)    # {next, run, rst_n} = {0, 1, 1}

  def _wait_finish(self):
    """Poll slv_reg2[0] (finish) until it is set."""
    # Mask bit 0: the register map only defines the finish flag there.
    while not (self.top_mmio.read(self.REG_STATUS) & 1):
      sleep(0.001)


## 関数

### 文字列変換用

In [4]:
# Load the symbol vocabulary: one entry per line of char_list_72.txt.
# Later code expects it to contain the special tokens '<PAD>' and '<UNK>'.
char_list = []
with open('char_list_72.txt' , 'r', encoding='utf-8') as vocab_file:
  char_list = [entry.replace('\n', '') for entry in vocab_file]


# Map each character of a kaomoji to its index in char_list.
def convert_str_int(kmj_data, N=10):
  """Convert kaomoji strings into lists of indices into `char_list`.

  Each kaomoji is split into characters (code points) and right-padded with
  '<PAD>' up to N entries. Characters not present in `char_list` map to the
  index of '<UNK>'. Strings longer than N are kept at full length (no
  truncation).

  Args:
    kmj_data: iterable of kaomoji strings.
    N: padded length.

  Returns:
    List of lists of int, one inner list per kaomoji.
  """
  kmj_index = []
  for kmj in kmj_data:
    chars = list(kmj) + ['<PAD>'] * (N - len(kmj))
    row = []
    for c in chars:
      try:
        row.append(char_list.index(c))
      except ValueError:  # list.index raises ValueError for unknown symbols
        row.append(char_list.index('<UNK>'))
    kmj_index.append(row)

  return kmj_index


# Convert a sequence of symbol indices back into a string.
def convert_int_str(x):
  """Decode indices `x` via `char_list`, dropping '<PAD>' and '<UNK>' tokens."""
  skip = ('<PAD>', '<UNK>')
  symbols = np.array(char_list)[x]
  return ''.join(ch for ch in symbols if ch not in skip)

### ミニバッチの作成用

In [5]:
# Read the dataset file.
def read_dataset(filename):
  """Return the lines of `filename` (UTF-8) with newline characters removed."""
  with open(filename, 'r', encoding='utf-8') as src:
    return [entry.replace('\n', '') for entry in src]


# Preprocessing: split each kaomoji into characters -> index -> one-hot vector.
def preprocess(kmj_data, length=None):
  """Turn kaomoji strings into a one-hot tensor over `char_list`.

  Each kaomoji is split into characters and right-padded with '<PAD>' to
  `length`. Each character is mapped to its index in `char_list` ('<UNK>' if
  missing), and the indices are one-hot encoded.

  Args:
    kmj_data: sequence of kaomoji strings.
    length: padded sequence length; defaults to the module-level N.

  Returns:
    float64 array of shape (len(kmj_data), length, len(char_list)).

  Raises:
    ValueError: if a kaomoji has more than `length` characters.
  """
  if length is None:
    length = N
  char_num = len(char_list)                       # number of distinct symbols

  kmj_index = []
  for kmj in kmj_data:
    chars = list(kmj)
    if len(chars) > length:
      raise ValueError('kaomoji {!r} has more than {} characters'.format(kmj, length))
    chars += ['<PAD>'] * (length - len(chars))
    row = []
    for c in chars:
      try:
        row.append(char_list.index(c))
      except ValueError:  # list.index raises ValueError for unknown symbols
        row.append(char_list.index('<UNK>'))
    kmj_index.append(row)

  # Explicit reshape keeps the (0, length) shape for empty input.
  index = np.array(kmj_index, dtype=np.intp).reshape(len(kmj_index), length)
  # Row i of the identity matrix is the one-hot vector for symbol i.
  return np.eye(char_num)[index]


# Build the mini-batches.
def create_batch(batch_size, filename='kaomoji_MAX=10_DA.txt'):
  """Load the kaomoji dataset and split it into train/valid/test mini-batches.

  The dataset is split 85% / 10% / remainder in file order (no shuffling). Each
  split is cut into consecutive batches of `batch_size` rows of symbol indices.
  Trailing samples that do not fill a whole batch are dropped, which matches
  the fixed (batch_size, N) DMA buffers used by Controller.

  Args:
    batch_size: rows per batch.
    filename: dataset file with one kaomoji per line.

  Returns:
    (dataloader_train, dataloader_valid, dataloader_test), each a list of
    integer arrays of shape (batch_size, N).
  """
  kmj_dataset = read_dataset(filename)

  # One-hot -> argmax recovers each character's index into char_list.
  kmj_int = preprocess(kmj_dataset).argmax(axis=-1)

  n_total = len(kmj_dataset)
  train_size = int(n_total * 0.85)
  valid_size = int(n_total * 0.10)

  dataset_train = kmj_int[:train_size]
  dataset_valid = kmj_int[train_size:train_size + valid_size]
  dataset_test  = kmj_int[train_size + valid_size:]

  return (_split_batches(dataset_train, batch_size),
          _split_batches(dataset_valid, batch_size),
          _split_batches(dataset_test, batch_size))


def _split_batches(data, batch_size):
  """Cut `data` into consecutive full batches of `batch_size` rows."""
  n_batches = len(data) // batch_size
  return [data[i * batch_size:(i + 1) * batch_size] for i in range(n_batches)]

## 動作検証

In [6]:
# Build the train / valid / test mini-batch lists.
dataloader_train, dataloader_valid, dataloader_test = create_batch(batch_size=batch_size)

### 学習

In [48]:
# Reset the PL, load the bitstream and allocate the DMA buffers.
pl = Controller(bitfile='kmj_gen_v5_3.bit', batch_size=batch_size, N=N)

In [49]:
# Register offsets of the `top_0` IP (slv_reg<k> is at byte offset 4*k).
REG_CTRL   = 0 * 4   # slv_reg0[2:0] : {next, run, rst_n}
REG_MODE   = 1 * 4   # slv_reg1[1:0] : mode
REG_STATUS = 2 * 4   # slv_reg2[0]   : finish

MODE_TRAIN   = 0     # slv_reg1 value used for training
MODE_FORWARD = 1     # slv_reg1 value used for forward-only (validation)


def _pulse_next(ctrl):
  """Pulse `next` with rst_n held high, then leave the core in `run`."""
  ctrl.top_mmio.write(REG_CTRL, 0b101)    # {next, run, rst_n} = {1, 0, 1}
  ctrl.top_mmio.write(REG_CTRL, 0b011)    # {next, run, rst_n} = {0, 1, 1}


def _wait_finish(ctrl):
  """Poll slv_reg2[0] (finish) until the current batch is done."""
  while not (ctrl.top_mmio.read(REG_STATUS) & 1):
    sleep(0.001)


def _run_pipelined(ctrl, dataloader, mode):
  """Stream every batch of `dataloader` through the PL in `mode`.

  The input DMA for batch i+1 is issued before waiting for batch i to finish,
  so transfers overlap with computation. The output read right after starting
  batch i+1 is compared against batch i.

  Returns:
    Number of output symbols equal to the corresponding input symbol.
  """
  if not dataloader:
    return 0

  ctrl.top_mmio.write(REG_MODE, mode)

  ctrl.send_data(dataloader[0])
  _pulse_next(ctrl)

  correct = 0
  for prev_batch, next_batch in zip(dataloader, dataloader[1:]):
    ctrl.send_data(next_batch)      # prefetch while the PL works on prev_batch
    _wait_finish(ctrl)
    _pulse_next(ctrl)               # start next_batch
    correct += (ctrl.recv_data() == prev_batch).sum()

  _wait_finish(ctrl)
  correct += (ctrl.recv_data() == dataloader[-1]).sum()
  return correct


def _accuracy(correct, dataloader):
  """Fraction of matching symbols over all batches (0.0 for an empty loader)."""
  total = len(dataloader) * batch_size * N
  return correct / total if total else 0.0


n_epochs = 100
learning_time = 0

for epoch in range(n_epochs):
  start = time()

  acc_train = _run_pipelined(pl, dataloader_train, MODE_TRAIN)
  acc_valid = _run_pipelined(pl, dataloader_valid, MODE_FORWARD)

  end = time()
  learning_time += end - start

  print('EPOCH: {:>3}, Train Acc: {:>.3f}, Valid Acc: {:>.3f}, Time: {:>6.3f}'.format(
          epoch+1,
          _accuracy(acc_train, dataloader_train),
          _accuracy(acc_valid, dataloader_valid),
          learning_time
       ))

EPOCH:   1, Train Acc: 0.238, Valid Acc: 0.334, Time:  0.141
EPOCH:   2, Train Acc: 0.343, Valid Acc: 0.363, Time:  0.282
EPOCH:   3, Train Acc: 0.384, Valid Acc: 0.396, Time:  0.426
EPOCH:   4, Train Acc: 0.412, Valid Acc: 0.419, Time:  0.566
EPOCH:   5, Train Acc: 0.433, Valid Acc: 0.443, Time:  0.710
EPOCH:   6, Train Acc: 0.454, Valid Acc: 0.459, Time:  0.851
EPOCH:   7, Train Acc: 0.468, Valid Acc: 0.470, Time:  0.994
EPOCH:   8, Train Acc: 0.477, Valid Acc: 0.476, Time:  1.135
EPOCH:   9, Train Acc: 0.484, Valid Acc: 0.482, Time:  1.278
EPOCH:  10, Train Acc: 0.490, Valid Acc: 0.490, Time:  1.419
EPOCH:  11, Train Acc: 0.498, Valid Acc: 0.497, Time:  1.563
EPOCH:  12, Train Acc: 0.507, Valid Acc: 0.505, Time:  1.704
EPOCH:  13, Train Acc: 0.516, Valid Acc: 0.516, Time:  1.847
EPOCH:  14, Train Acc: 0.527, Valid Acc: 0.524, Time:  1.989
EPOCH:  15, Train Acc: 0.536, Valid Acc: 0.533, Time:  2.131
EPOCH:  16, Train Acc: 0.544, Valid Acc: 0.541, Time:  2.272
EPOCH:  17, Train Acc: 0

### 検証

In [101]:
pl.top_mmio.write(0, 0)    # slv_reg0 (offset 0): clear {next, run, rst_n}

In [94]:
kmj_data = ['ヾ(*　∀́　*)ノ']
# preprocess -> one-hot of shape (1, N, vocab); argmax gives indices (1, N),
# which are then replicated to fill a whole (batch_size, N) batch.
kmj = np.repeat(preprocess(kmj_data).argmax(axis=-1), batch_size, axis=0)
print('base     :', convert_int_str(kmj[0]))

base     : ヾ(*　∀́　*)ノ


In [116]:
# Alternative input: kmj = dataloader_test[0]
pl.send_data(kmj)

# Measure the wall-clock time of one mode-3 run (the generation step here).
start = time()
pl.run_mode(mode=3)
end = time()
elapsed = end - start
print(elapsed)

recv_data = pl.recv_data()

0.0019061565399169922


In [117]:
# Decode and print every generated row of the batch.
for row in recv_data[:batch_size]:
  print('generate :', convert_int_str(row))

generate : ヾ(;^　)・)
generate : ノ(　-・
generate : 　　(　̄)
generate : 　(́()　!
generate : !(((-ω))
generate : (_-))　)
generate : ヾ)>∀゚́　)ノ
generate : (`^_)・
generate : (ヾ(-　)
generate : (*・　　ω・・
generate : (　́ω))
generate : )(　∀　̄)
generate : /゚ω゚　)・・)
generate : (　́　゚　))
generate : (▽　))!!
generate : !・　　^(!・*
generate : ゚()゚　　-・(
generate : !.^-゚^ノ)
generate : 　゚(̄　)・
generate : (゚∀^)
generate : ノ(・-*)
generate : *-(゚
generate : *^(^　^)
generate : 　　　́　̄　)　)
generate : ヾ(-`)
generate : *́(・()(
generate : )(　　・̄))ノ
generate : (^　^̄　̄)
generate : 　^()()・゚
generate : *-((・　^)
generate : (∀　^))!
generate : )ノ-ω　)ノ!!
