# ニューラルネットワークに基づく(実数)MIMO検出

Copyright (c) 2022 Tadashi Wadayama  
Released under the MIT license  
https://opensource.org/licenses/mit-license.php

In [1]:
using LinearAlgebra
using Plots
gr()
using Random
Random.seed!(1)
using Flux

### グローバル変数の設定

In [2]:
# Problem size and training hyper-parameters; later cells read these as globals.
K, noise_std = 50, 0.75   # mini-batch size; std. dev. of the additive Gaussian noise
n, h = 4, 50              # length of x and y (signal dimension); hidden-layer width
H = randn(n, n)           # n×n real channel matrix with i.i.d. standard normal entries
adam_lr = 0.01            # learning rate passed to the Adam optimizer

0.01

In [3]:
println(stdout, H)  # dump the randomly drawn channel matrix for inspection/reproduction

[-1.1886800049871964 0.4041521072069755 -0.6774167737228607 -0.4069644601387791; 0.7878269862203976 -0.10357598296850325 -0.13561837074286126 -0.06886416086215355; 1.1780259137155593 0.5953191566843961 -0.10787034449583965 0.6400915308449808; 0.3855116016279269 0.45547048325456163 0.7183315170355797 -0.10213552217354001]


### ミニバッチ生成関数

In [4]:
"""
    mini_batch(K; channel = H, σ = noise_std)

Generate `K` transmit/receive pairs for the real-valued MIMO channel
`y = channel * x + σ * w`.

Each column of `x` is drawn independently with entries uniformly in `{+1.0, -1.0}`.
`w` has i.i.d. standard normal entries. The defaults are the globals `H` and
`noise_std`, looked up at call time. A plain `mini_batch(K)` therefore draws exactly
the same samples (same RNG call order and sizes) as before the keywords were added.

# Arguments
- `K`: number of samples, i.e. columns of the returned matrices.

# Keywords
- `channel`: channel matrix. Its column count is the length of each `x` sample and
  its row count the length of each `y` sample.
- `σ`: standard deviation of the additive Gaussian noise.

# Returns
`(x, y)`, where `x` is `size(channel, 2) × K` and `y` is `size(channel, 1) × K`.
"""
function mini_batch(K; channel = H, σ = noise_std)
    # Derive dimensions from the channel rather than a separate global `n`, so the
    # shapes can never disagree with the matrix actually used.
    n_in, n_out = size(channel, 2), size(channel, 1)
    x = 1.0 .- 2.0 .* rand(0:1, n_in, K)    # bit 0 -> +1.0, bit 1 -> -1.0
    y = channel * x .+ σ .* randn(n_out, K)
    return x, y
end

mini_batch (generic function with 1 method)

### ネットワーク構造の定義

In [5]:
# Fully connected layers of the detector: n -> h -> h -> n.
# They are constructed in this order so any random initialisation draws match.
layer1 = Dense(n, h)
layer2 = Dense(h, h)
layer3 = Dense(h, n)

"""
    detector(x)

Neural MIMO detector. Takes received samples `x` (one per column) and returns soft
symbol estimates in (-1, 1). The input passes through two ReLU hidden layers and a
`tanh` output layer.
"""
detector(x) = tanh.(layer3(relu.(layer2(relu.(layer1(x))))))

detector (generic function with 1 method)

### 学習プロセス

In [6]:
# Train the detector with Adam on freshly generated mini-batches.
opt = Flux.ADAM(adam_lr)
train_itr = 500
# Squared Euclidean error, summed over every entry of the batch.
loss(x_true, x_est) = norm(x_true - x_est)^2

ps = Flux.params(layer1, layer2, layer3)
for itr in 1:train_itr
    xb, yb = mini_batch(K)
    gs = Flux.gradient(() -> loss(xb, detector(yb)), ps)
    Flux.Optimise.update!(opt, ps, gs)
    # Every 100 iterations, report the loss on a new batch not used for the update.
    if iszero(itr % 100)
        xv, yv = mini_batch(K)
        println(loss(xv, detector(yv)))
    end
end

62.14068795016006
60.018315089408155
89.26857220126797
65.1438902637938
75.32398886423107


### シンボル誤り率を測定する (ニューラル推定器)

In [7]:
# Symbol error rate of the neural detector: hard decision `sign.(detector(y))`,
# compared entry-wise with the transmitted symbols.
#
# The counters are accumulated inside a `let` block. A top-level `for` loop that
# does `total_syms += …` only works under interactive (REPL/notebook) soft-scope
# rules. Run as a script, it warns and then throws `UndefVarError`.
num_loops = 1000
total_syms, error_syms = let total = 0, errors = 0
    for _ in 1:num_loops
        x, y = mini_batch(K)
        x̂ = detector(y)
        total += n * K
        errors += count(sign.(x̂) .!= x)
    end
    (total, errors)
end
println("total_syms = ", total_syms)
println("error_syms = ", error_syms)
println("symbols error rate = ", error_syms/total_syms)

total_syms = 200000
error_syms = 22899
symbols error rate = 0.114495


### シンボル誤り率を測定する (ZF推定器)

In [8]:
# Symbol error rate of the zero-forcing (ZF) detector: x̂ = H \ y, followed by a
# hard `sign` decision per entry.
#
# The counters are accumulated inside a `let` block. A top-level `for` loop that
# does `total_syms += …` only works under interactive (REPL/notebook) soft-scope
# rules. Run as a script, it warns and then throws `UndefVarError`.
num_loops = 1000
# Factor H once and reuse it for every batch. Solving with the LU factors is
# numerically preferable to forming inv(H) explicitly. Like `inv`, `lu` throws
# if H is exactly singular.
Hlu = lu(H)
total_syms, error_syms = let total = 0, errors = 0
    for _ in 1:num_loops
        x, y = mini_batch(K)
        x̂ = Hlu \ y
        total += n * K
        errors += count(sign.(x̂) .!= x)
    end
    (total, errors)
end
println("total_syms = ", total_syms)
println("error_syms = ", error_syms)
println("symbols error rate = ", error_syms/total_syms)

total_syms = 200000
error_syms = 31193
symbols error rate = 0.155965
