# PSからPLを制御するためのプログラム

## import

In [1]:
import numpy as np
from time import sleep
from pynq import PL, Overlay, MMIO, allocate

In [2]:
# Sequence length: every kaomoji is padded with '<PAD>' up to N symbols
N = 10
# Number of sequences per mini-batch (also the first dimension of the DMA input buffer)
batch_size = 32

## PL制御クラス

In [3]:
class Controller:
  """Drive the PL design from the PS through one AXI DMA and the 'top_0' registers.

  Register usage, as written by this class:
    offset 0*4 (slv_reg0[3:0]) : control bits {next, set, run, rst_n}
    offset 1*4 (slv_reg1[1:0]) : mode
    offset 2*4 (slv_reg2[0])   : finish flag, polled until it reads 1

  Attributes:
    dma_send / dma_recv : send and receive channels of 'axi_dma_0'
    top_mmio            : MMIO window over the 'top_0' IP
    input_buffer        : uint8 buffer of shape (batch_size, N) used by send_data
    output_buffer       : uint8 buffer of shape (N,) used by recv_data
  """

  # Byte offsets of the registers inside the 'top_0' address range
  _CTRL_OFFSET = 0*4
  _MODE_OFFSET = 1*4
  _FINISH_OFFSET = 2*4
  # Seconds slept between two consecutive polls
  _POLL_INTERVAL = 0.001

  def __init__(self, bitfile, batch_size, N):
    """Reset the PL, load `bitfile` and prepare the DMA channels, registers and buffers."""
    PL.reset()

    overlay = Overlay(bitfile)

    axi_dma = overlay.axi_dma_0
    self.dma_send = axi_dma.sendchannel
    self.dma_recv = axi_dma.recvchannel

    top_info = overlay.ip_dict['top_0']
    self.top_mmio = MMIO(top_info['phys_addr'], top_info['addr_range'])
    # Start with every control bit cleared (rst_n = 0 as well)
    self.top_mmio.write(self._CTRL_OFFSET, 0b000)

    self.input_buffer = allocate(shape=(batch_size, N), dtype=np.uint8)
    self.output_buffer = allocate(shape=(N,), dtype=np.uint8)

  def _wait_idle(self, channel):
    """Sleep-poll until the given DMA channel reports idle."""
    while not channel.idle:
      sleep(self._POLL_INTERVAL)

  def send_data(self, data):
    """Copy `data` into the input buffer and send it to the PL, blocking until the channel is idle."""
    self.input_buffer[:] = data

    self.dma_send.transfer(self.input_buffer)
    self._wait_idle(self.dma_send)

  def recv_data(self):
    """Receive one result from the PL into the output buffer and return that buffer.

    The same buffer object is returned on every call, so its contents are
    overwritten by the next call.
    """
    self.dma_recv.transfer(self.output_buffer)
    self._wait_idle(self.dma_recv)
    return self.output_buffer

  def run_mode(self, mode=1):
    """Select `mode`, pulse set then run, and block until the finish flag reads 1."""
    mmio = self.top_mmio

    # slv_reg1[1:0] : mode
    mmio.write(self._MODE_OFFSET, mode)

    # slv_reg0[3:0] = {next, set, run, rst_n} : set = 1, rst_n = 1
    mmio.write(self._CTRL_OFFSET, 0b0101)

    # slv_reg0[3:0] = {next, set, run, rst_n} : run = 1, rst_n = 1
    mmio.write(self._CTRL_OFFSET, 0b0011)

    # slv_reg2[0] : finish
    while mmio.read(self._FINISH_OFFSET) != 1:
      sleep(self._POLL_INTERVAL)


## 関数

### 文字列変換用

In [4]:
# Load the list of usable symbols: one entry per line, newline characters removed
char_list = []
with open('char_list_72.txt', 'r', encoding='utf-8') as char_file:
  char_list.extend(entry.replace('\n', '') for entry in char_file)


def convert_str_int(kmj_data, N=10):
  """Convert kaomoji strings into lists of symbol indices.

  Each string is split into characters and right-padded with '<PAD>' up to
  N symbols. Every symbol is then replaced by the position of its first
  occurrence in the module-level char_list; symbols that are not in
  char_list are replaced by the position of '<UNK>'.

  Args:
    kmj_data: iterable of kaomoji strings.
    N: target length. Shorter strings are padded; longer strings are
      returned at their full length (no truncation).

  Returns:
    A list holding one list of int indices per input string.

  Raises:
    ValueError: if a symbol is unknown and '<UNK>' is not in char_list.
  """
  # First-occurrence position of every symbol (same result as list.index),
  # built once so each character lookup is O(1) instead of a linear scan
  symbol_to_index = {}
  for position, symbol in enumerate(char_list):
    symbol_to_index.setdefault(symbol, position)

  kmj_index = []
  for kmj in kmj_data:
    symbols = list(kmj)
    symbols += ['<PAD>'] * (N - len(symbols))

    row = []
    for symbol in symbols:
      index = symbol_to_index.get(symbol)
      if index is None:
        # Unknown symbol: fall back to '<UNK>' (list.index raises ValueError if it is missing)
        index = char_list.index('<UNK>')
      row.append(index)
    kmj_index.append(row)

  return kmj_index


def convert_int_str(x):
  """Convert a sequence of symbol indices back into a string.

  The indices select entries of the module-level char_list; '<PAD>' and
  '<UNK>' entries are dropped and the remaining symbols are concatenated.
  """
  symbols = np.array(char_list)[x]

  kept = []
  for symbol in symbols:
    if symbol in ['<PAD>', '<UNK>']:
      continue
    kept.append(symbol)

  return ''.join(kept)

### ミニバッチの作成用

In [5]:
def read_dataset(filename):
  """Read a UTF-8 text file and return its lines with newline characters removed."""
  with open(filename, 'r', encoding='utf-8') as dataset_file:
    return [entry.replace('\n', '') for entry in dataset_file]


def preprocess(kmj_data):
  """Turn kaomoji strings into a one-hot array.

  Pipeline: split each kaomoji into characters -> map them to char_list
  indices (padded to the module-level N) -> one-hot encode.

  Args:
    kmj_data: iterable of kaomoji strings, each at most N characters long.

  Returns:
    A float array of shape (len(kmj_data), N, len(char_list)) holding
    exactly one 1 per (kaomoji, position) pair and 0 elsewhere.

  Raises:
    ValueError: if a kaomoji is longer than N, or if a symbol is unknown
      and '<UNK>' is not in char_list.
  """
  # Reuse the shared string -> index conversion instead of duplicating it here
  kmj_index = convert_str_int(kmj_data, N)

  kmj_num = len(kmj_index)                        # number of kaomoji
  char_num = len(char_list)                       # number of distinct symbols
  kmj_onehot = np.zeros((kmj_num, N, char_num))   # one-hot output
  if kmj_num == 0:
    return kmj_onehot

  # The reshape fails with ValueError when any row is not exactly N long
  index_matrix = np.array(kmj_index).reshape((kmj_num, N))

  # Set kmj_onehot[i, j, index_matrix[i, j]] = 1 for every i, j in one assignment
  rows = np.arange(kmj_num).reshape((kmj_num, 1))
  cols = np.arange(N).reshape((1, N))
  kmj_onehot[rows, cols, index_matrix] = 1

  return kmj_onehot


def create_batch(batch_size, filename='kaomoji_MAX=10_DA.txt'):
  """Load the kaomoji dataset and split it into train/validation mini-batches and a test set.

  The dataset is converted to integer symbol indices and split in file
  order: the first 85% for training, the next 10% for validation and the
  remainder for testing. Training and validation data are cut into
  batches of exactly batch_size rows; leftover rows that do not fill a
  whole batch are dropped.

  Args:
    batch_size: number of kaomoji per mini-batch.
    filename: dataset file, one kaomoji per line.

  Returns:
    (dataloader_train, dataloader_valid, dataset_test): two lists of
    index arrays with batch_size rows each, and the test rows as a
    single index array.
  """
  # Load the dataset
  kmj_dataset = read_dataset(filename)

  # Preprocess, then collapse the one-hot axis back to integer indices
  kmj_onehot = preprocess(kmj_dataset)
  kmj_int = kmj_onehot.argmax(axis=-1)

  # Split the dataset; whatever is left after train + valid is the test set
  total = len(kmj_dataset)
  train_size = int(total * 0.85)
  valid_size = int(total * 0.10)
  valid_end = train_size + valid_size

  dataset_train = kmj_int[:train_size]
  dataset_valid = kmj_int[train_size:valid_end]
  dataset_test  = kmj_int[valid_end:]

  # Split into full mini-batches
  n_train = train_size // batch_size
  n_valid = valid_size // batch_size

  dataloader_train = [dataset_train[i*batch_size:(i+1)*batch_size] for i in range(n_train)]
  dataloader_valid = [dataset_valid[i*batch_size:(i+1)*batch_size] for i in range(n_valid)]

  return dataloader_train, dataloader_valid, dataset_test

## 動作検証

In [22]:
# Build the training/validation mini-batches and the test set
dataloader_train, dataloader_valid, dataset_test = create_batch(batch_size)

### 学習

In [151]:
# Instantiate the PL controller (resets the PL and loads the bitstream)
pl = Controller('kmj_gen_v3_0.bit', batch_size, N)

In [152]:
# Number of passes over the training mini-batches
n_epochs = 50

# slv_reg1[1:0] : mode (TRAIN)
pl.top_mmio.write(1*4, 0)

# slv_reg0[3:0] : {next, set, run, rst_n}
pl.top_mmio.write(0*4, 0b0101)

for epoch in range(n_epochs):
  # Number of output symbols that matched the input symbols in this epoch
  acc_train = 0
  # acc_valid = 0

  for batch in dataloader_train:
    # Send one mini-batch to the PL over DMA (blocks until the send channel is idle)
    pl.send_data(batch)

    pl.top_mmio.write(0*4, 0b1001)    # slv_reg0[3:0] : {next, set, run, rst_n}
    pl.top_mmio.write(0*4, 0b0011)    # slv_reg0[3:0] : {next, set, run, rst_n}
    
    while pl.top_mmio.read(2*4) != 1: # slv_reg2[0] : finish
      sleep(0.001)
      
    # Receive one N-symbol result per sample and count positions equal to the input
    for i in range(len(batch)):
      recv_data = pl.recv_data()
      acc_train += (recv_data == batch[i]).sum()

  # Accuracy = matching symbols / total symbols sent during the epoch
  print('EPOCH: {:>3}, Train Acc: {:>.3f}'.format(epoch+1, acc_train / (len(dataloader_train) * batch_size * N)))

EPOCH:   1, Train Acc: 0.238
EPOCH:   2, Train Acc: 0.343
EPOCH:   3, Train Acc: 0.384
EPOCH:   4, Train Acc: 0.412
EPOCH:   5, Train Acc: 0.433
EPOCH:   6, Train Acc: 0.454
EPOCH:   7, Train Acc: 0.468
EPOCH:   8, Train Acc: 0.477
EPOCH:   9, Train Acc: 0.484
EPOCH:  10, Train Acc: 0.490
EPOCH:  11, Train Acc: 0.498
EPOCH:  12, Train Acc: 0.507
EPOCH:  13, Train Acc: 0.516
EPOCH:  14, Train Acc: 0.527
EPOCH:  15, Train Acc: 0.536
EPOCH:  16, Train Acc: 0.544
EPOCH:  17, Train Acc: 0.552
EPOCH:  18, Train Acc: 0.559
EPOCH:  19, Train Acc: 0.567
EPOCH:  20, Train Acc: 0.574
EPOCH:  21, Train Acc: 0.580
EPOCH:  22, Train Acc: 0.586
EPOCH:  23, Train Acc: 0.591
EPOCH:  24, Train Acc: 0.596
EPOCH:  25, Train Acc: 0.599
EPOCH:  26, Train Acc: 0.603
EPOCH:  27, Train Acc: 0.606
EPOCH:  28, Train Acc: 0.610
EPOCH:  29, Train Acc: 0.613
EPOCH:  30, Train Acc: 0.616
EPOCH:  31, Train Acc: 0.619
EPOCH:  32, Train Acc: 0.623
EPOCH:  33, Train Acc: 0.626
EPOCH:  34, Train Acc: 0.630
EPOCH:  35, Tr

### 検証

In [161]:
# Single kaomoji used for the check (alternative input kept below for reference)
# kmj_data = ['ヾ(*　∀́　*)ノ']
kmj_data = ['(^_^--)']
# Convert to symbol indices, padded to the default length of 10
kmj_index = convert_str_int(kmj_data)

# DMA buffer holding the one N-symbol input sequence
test_input = allocate(shape=(N,), dtype=np.uint8)
test_input[:] = kmj_index[0]

In [162]:
# Start the DMA transfer of the test sequence; unlike Controller.send_data,
# this does not poll the channel for completion
pl.dma_send.transfer(test_input)

In [163]:
# slv_reg0[3:0] : {next, set, run, rst_n}
# Clear every control bit (rst_n = 0 as well) before starting the run
pl.top_mmio.write(0*4, 0b0000)
# Run with mode = 1 and block until the finish flag reads 1
# NOTE(review): presumably mode 1 is the inference/generation mode - confirm against the PL design
pl.run_mode(1)

In [164]:
# Receive the N-symbol result produced by the PL
test_output = pl.recv_data()
# test_output = allocate(shape=(N,), dtype=np.uint8)
# pl.dma_recv.transfer(test_output)
# Print the input kaomoji followed by the decoded output for comparison
print(kmj_data[0])
print(convert_int_str(test_output))

ヾ(*　∀́　*)ノ
ヾ(*　∀́　*)ノ
