# ニューラルネットワークに基づく(実数)MIMO検出

Copyright (c) 2022 Tadashi Wadayama  
Released under the MIT license  
https://opensource.org/licenses/mit-license.php

In [82]:
using LinearAlgebra
using Plots
gr()
using Random
Random.seed!(1)
using Flux

### グローバル変数の設定

In [83]:
# Simulation / training parameters, read as globals by the cells below.
K = 50             # mini-batch size: number of symbol vectors (columns) per batch
noise_std = 0.75   # standard deviation of the additive Gaussian noise
n = 4              # signal dimension; the channel H is n×n
h = 50             # width of the detector's two hidden layers
H = randn(n, n)    # channel matrix, sampled once with standard normal entries
adam_lr = 1e-2     # learning rate passed to the ADAM optimizer

0.01

In [84]:
println(stdout, H)  # show the sampled channel matrix

[0.2972879845354616 -0.839026854388764 0.2290095549097807 0.5837082875687786; 0.3823959677906078 0.31111133849833383 -2.2670863488005306 0.9632716050381906; -0.5976344767282311 2.2950878238373105 0.5299655761667461 0.45879095505371686; -0.01044524463737564 -0.050451229933665284 0.43142152642291204 -0.5223367574215084]


### ミニバッチ生成関数

In [85]:
"""
    mini_batch(K; channel = H, σ = noise_std)

Generate one mini-batch of `K` noisy observations of the linear channel model
`y = channel * x + σ * w`.

Every entry of `x` is a symbol drawn uniformly from `{+1, -1}` (via `rand(0:1)`),
and `w` has independent standard normal entries (via `randn`).

# Arguments
- `K`: number of samples (columns) in the mini-batch.

# Keywords
- `channel`: channel matrix; defaults to the global `H`, read at call time.
- `σ`: noise standard deviation; defaults to the global `noise_std`, read at call time.

# Returns
`(x, y)`, where `x` is `size(channel, 2) × K` and `y` is `size(channel, 1) × K`.
"""
function mini_batch(K; channel = H, σ = noise_std)
    # Taking the channel/noise level as keywords (instead of reading untyped
    # globals in the body) makes them concretely typed locals here.
    ntx = size(channel, 2)
    nrx = size(channel, 1)
    # rand(0:1) ∈ {0, 1} is mapped to the symbols {+1, -1}
    x = 1.0 .- 2.0 .* rand(0:1, ntx, K)
    y = channel * x .+ σ .* randn(nrx, K)
    return x, y
end

mini_batch (generic function with 1 method)

### ネットワーク構造の定義

In [86]:
# Fully connected layers of the detector: n → h → h → n.
# Kept as named globals because the training cell collects their parameters;
# they are constructed in this order so the random initialization is unchanged.
layer1 = Dense(n, h)
layer2 = Dense(h, h)
layer3 = Dense(h, n)

"""
    detector(x)

Forward pass of the neural MIMO detector.

Applies `layer1` and `layer2`, each followed by an elementwise `relu`, and then
`layer3` followed by an elementwise `tanh`. The output therefore lies in
`[-1, 1]` and serves as a soft estimate of the transmitted ±1 symbols.
"""
detector(x) = tanh.(layer3(relu.(layer2(relu.(layer1(x))))))

detector (generic function with 1 method)

### 学習プロセス

In [87]:
# Train the detector with ADAM on freshly generated mini-batches.
opt = ADAM(adam_lr)
train_itr = 500

# Sum of squared entrywise differences (squared Frobenius norm for matrices).
loss(x, y) = norm(x - y)^2

ps = Flux.params(layer1, layer2, layer3)
for itr in 1:train_itr
    x_train, y_train = mini_batch(K)
    grads = gradient(() -> loss(x_train, detector(y_train)), ps)
    Flux.Optimise.update!(opt, ps, grads)

    # Every 100 iterations, report the loss on a separate fresh mini-batch.
    if itr % 100 == 0
        x_mon, y_mon = mini_batch(K)
        println(loss(x_mon, detector(y_mon)))
    end
end

40.50489921130954
39.19253350512622
54.97720835986238
37.522830596656405
47.177756500637095


### シンボル誤り率を測定する (ニューラル推定器)

In [88]:
# Monte Carlo estimate of the symbol error rate (SER) of the neural detector:
# run `num_loops` fresh mini-batches, make hard decisions with `sign`, and
# count symbols that differ from the transmitted ones.
#
# The counters are accumulated inside a `let` block and only then bound to the
# globals `total_syms` / `error_syms`. Updating a global with `+=` inside a
# top-level `for` loop only works in interactive soft scope (REPL/notebook);
# run as a script file it would raise an UndefVarError.
num_loops = 1000
total_syms, error_syms = let total = 0, errors = 0
    for _ in 1:num_loops
        x, y = mini_batch(K)
        x̂ = detector(y)
        total += length(x)  # n*K symbols per mini-batch
        # a detector output of exactly 0 gives sign 0 and counts as an error
        errors += count(sign.(x̂) .!= x)
    end
    (total, errors)
end
println("total_syms = ", total_syms)
println("error_syms = ", error_syms)
println("symbols error rate = ", error_syms/total_syms)

total_syms = 200000
error_syms = 13264
symbols error rate = 0.06632


### シンボル誤り率を測定する (ZF推定器)

In [89]:
# Monte Carlo estimate of the symbol error rate (SER) of the zero-forcing (ZF)
# detector: equalize with the precomputed inverse channel `Hinv`, make hard
# decisions with `sign`, and count symbols that differ from the transmitted ones.
#
# The counters are accumulated inside a `let` block and only then bound to the
# globals `total_syms` / `error_syms`. Updating a global with `+=` inside a
# top-level `for` loop only works in interactive soft scope (REPL/notebook);
# run as a script file it would raise an UndefVarError.
num_loops = 1000
Hinv = inv(H)  # ZF equalizer, computed once outside the loop
total_syms, error_syms = let total = 0, errors = 0
    for _ in 1:num_loops
        x, y = mini_batch(K)
        x̂ = Hinv*y
        total += length(x)  # n*K symbols per mini-batch
        errors += count(sign.(x̂) .!= x)
    end
    (total, errors)
end
println("total_syms = ", total_syms)
println("error_syms = ", error_syms)
println("symbols error rate = ", error_syms/total_syms)

total_syms = 200000
error_syms = 53896
symbols error rate = 0.26948
