# `amp` class: how to use it

This notebook shows a single run of the `AMP` class from the `amp` module.

## Example

```python
from amp import AMP
```

### Class inputs

```python
# Number of samples
n_samples = 1000
# Number of labels k
n_labels = 3
# Alpha = number of samples / input dimension
alpha = 1.5
# Channel
channel = 'argmax'
# Teacher
teacher = 'argmax'
# Prior on the weights: 'gauss' or 'rademacher'
prior = 'gauss'
# prior = 'rademacher'
# Damping: f_new = (1 - damping) * f_new + damping * f_old
damping = 0.5
# Value used in place of infinity in numerical computations
infinity = 10
# Maximum number of iterations
max_iter = 1000
# Convergence threshold on mean(abs(W - W_old))
conv = 1e-9
```

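The damping comment describes a convex combination of the newly computed quantity and the previous one. The sketch below writes that rule out, assuming that `damping` is the weight on the old value. `damped_update` is a hypothetical helper and is not part of the `amp` module.

```python
import numpy as np


def damped_update(f_new: np.ndarray, f_old: np.ndarray, damping: float) -> np.ndarray:
    """Blend the freshly computed quantity with the previous one.

    Hypothetical helper (assumed rule): damping = 0 keeps only the new value,
    damping = 1 keeps only the old value.
    """
    return (1.0 - damping) * f_new + damping * f_old


f_old = np.array([0.0, 1.0])
f_new = np.array([1.0, 3.0])
print(damped_update(f_new, f_old, damping=0.5))  # [0.5 2. ]
```

With `damping = 0.5`, as in this run, each step goes halfway from the old value toward the new one. Larger values slow the iteration down, but damping often helps AMP-type iterations converge.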
### Initialization: `amp=AMP()`

```python
amp = AMP(n_samples=n_samples, n_labels=n_labels, alpha=alpha,
          channel=channel, prior=prior,
          damping=damping, infinity=infinity)
```

### Generate the teacher weights and the data: `amp.data()`

```python
X, y, W_star = amp.data(teacher=teacher)
```

```
| | | | | AMP | | | | |

--- Teacher weights ---
W_star=  [[-1.34447305 -0.14212705]
 [ 0.73216496  1.35210949]
 [-1.38324991 -0.14730495]
 ...
 [-0.62083503 -0.34800672]
 [-0.36070264  0.4628032 ]
 [-1.15983366  0.16166457]]
```

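This notebook does not spell out the generative model behind `amp.data()`. The sketch below is a hedged illustration of a generic argmax teacher, built on several assumptions: Gaussian inputs, teacher weights drawn from the `'gauss'` or `'rademacher'` prior, one weight column per label, and labels given by the largest pre-activation. The `W_star` printed above has two columns for `n_labels = 3`, so `amp` appears to parametrize the teacher differently, and the sketch does not try to reproduce that. The function `make_argmax_data`, its `dim` and `seed` parameters, and the rounding of `n_samples / alpha` to an integer dimension are all hypothetical.

```python
import numpy as np


def make_argmax_data(n_samples: int, dim: int, n_labels: int,
                     prior: str = "gauss", seed: int = 0):
    """Hypothetical argmax-teacher sketch, not the amp.data() implementation."""
    rng = np.random.default_rng(seed)
    # Teacher weights: one column per label (assumption).
    if prior == "gauss":
        W_star = rng.standard_normal((dim, n_labels))
    elif prior == "rademacher":
        W_star = rng.choice([-1.0, 1.0], size=(dim, n_labels))
    else:
        raise ValueError(f"unknown prior: {prior!r}")
    X = rng.standard_normal((n_samples, dim))
    # Scale by sqrt(dim) so pre-activations stay O(1) as dim grows.
    preact = X @ W_star / np.sqrt(dim)
    # Label = index of the largest teacher pre-activation.
    y = np.argmax(preact, axis=1)
    return X, y, W_star


# alpha = n_samples / dim, so dim is about 1000 / 1.5 (rounded here by assumption).
X, y, W_star = make_argmax_data(n_samples=1000, dim=round(1000 / 1.5), n_labels=3)
print(X.shape, y.shape, W_star.shape)  # (1000, 667) (1000,) (667, 3)
```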

### Iterate `amp.fit()`

```python
amp.fit(max_iter=max_iter, conv=conv)
```

```
--- Initialization ---
W_hat0 ~ gauss
initial overlap =  [[-7.44075551e-05  2.21689346e-05]
 [ 6.11660704e-06  2.67868811e-05]]

--- Iterate AMP - alpha = 1.50000000 ---
alpha= 1.50000000 | it= 0 | diff_W= 0.16280799 | mses = 1.54604751 | time: 0.531s
overlap matrix =  [[0.20659624 0.09457087]
 [0.09264377 0.17324   ]]
alpha= 1.50000000 | it= 1 | diff_W= 0.20981440 | mses = 1.21360719 | time: 0.442s
alpha= 1.50000000 | it= 2 | diff_W= 0.19791563 | mses = 0.99924102 | time: 0.477s
alpha= 1.50000000 | it= 3 | diff_W= 0.16673277 | mses = 0.89056781 | time: 0.459s
alpha= 1.50000000 | it= 4 | diff_W= 0.12834615 | mses = 0.84959055 | time: 0.485s
alpha= 1.50000000 | it= 5 | diff_W= 0.09015214 | mses = 0.83910415 | time: 0.442s
alpha= 1.50000000 | it= 6 | diff_W= 0.06058756 | mses = 0.83531406 | time: 0.448s
alpha= 1.50000000 | it= 7 | diff_W= 0.04576554 | mses = 0.83075506 | time: 0.471s
alpha= 1.50000000 | it= 8 | diff_W= 0.03851530 | mses = 0.82599152 | time: 0.430s
alpha= 1.50
```

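From iteration 2 onward the log shows `diff_W`, the mean absolute change of the estimate, shrinking at each step. The sketch below is a minimal damped fixed-point loop that uses the stopping rule implied by `max_iter` and `conv`. It assumes that `fit` follows this control flow. `update` stands in for one AMP step and is replaced here by a toy contraction, so the sketch shows only the loop structure and not AMP itself. The name `iterate_until_converged` is hypothetical.

```python
import numpy as np


def iterate_until_converged(update, W0, damping=0.5, max_iter=1000, conv=1e-9):
    """Damped fixed-point loop; stops when mean(abs(W - W_old)) < conv.

    `update` is a hypothetical stand-in for one AMP step.
    """
    W = np.asarray(W0, dtype=float)
    diffs = []
    for _ in range(max_iter):
        W_old = W
        # Damped step: weight (1 - damping) on the new value, damping on the old.
        W = (1.0 - damping) * update(W_old) + damping * W_old
        diff = float(np.mean(np.abs(W - W_old)))
        diffs.append(diff)
        if diff < conv:
            break
    return W, diffs


# Toy contraction with fixed point W = 2 (not an AMP update).
W_final, diffs = iterate_until_converged(lambda W: 0.5 * W + 1.0, np.zeros(3))
print(len(diffs), W_final)
```

The `diffs` list records the same per-iteration quantity as the `diff_W` column in the log above.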
### After `amp.fit()`

```python
# Returns the estimator and its variance
W_hat, C_hat = amp.parameters()
# Returns the list of overlap matrices after fit
ov_matrices = amp.overlap_matrix()
# Returns the list of MSEs after fit
mses_ = amp.get_mses()
# Returns the list of mean(abs(W - W_old)) values after fit
W_diff_ = amp.get_diff_W()
# Computes the generalization error using new_samples new samples
er_gen = amp.eg(new_samples=1e5)
```

```python
# Final overlap
print(ov_matrices[-1])
```

```
[[1.13600785 0.59628221]
 [0.57293475 1.05867923]]
```

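This notebook does not define the overlap matrix or the MSE. The sketch below assumes a common convention from teacher-student analyses rather than anything taken from `amp`: normalize by the input dimension `d`, so that `m = W_hat.T @ W_star / d` is the student-teacher overlap (generally not symmetric, like the final matrix printed above) and `q = W_hat.T @ W_hat / d` is the student self-overlap. It computes both, plus an element-wise MSE, for a synthetic noisy estimate. `overlaps_and_mse` is a hypothetical helper.

```python
import numpy as np


def overlaps_and_mse(W_hat: np.ndarray, W_star: np.ndarray):
    """Assumed definitions (not taken from amp), normalized by the input dimension."""
    d = W_star.shape[0]
    m = W_hat.T @ W_star / d   # student-teacher overlap, generally not symmetric
    q = W_hat.T @ W_hat / d    # student self-overlap, symmetric
    mse = float(np.mean((W_hat - W_star) ** 2))
    return m, q, mse


rng = np.random.default_rng(0)
W_star = rng.standard_normal((500, 2))
W_hat = W_star + 0.5 * rng.standard_normal((500, 2))  # synthetic noisy estimate
m, q, mse = overlaps_and_mse(W_hat, W_star)
print(np.round(m, 2), round(mse, 2))
```

Under these definitions the MSE follows from the overlaps: `mse == (trace(q) - 2 * trace(m) + trace(rho)) / k`, with `rho = W_star.T @ W_star / d` and `k` the number of columns.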

```python
# Generalization error
print(er_gen)
```

```
0.33029
```

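`amp.eg(new_samples=1e5)` returns a generalization error that is evaluated on new samples. One way to compute such an estimate is Monte Carlo: draw fresh Gaussian inputs, label them with both the teacher and the student argmax rules, and count the disagreements. The sketch below does this. It is an assumption, not `amp`'s implementation, and it assumes one weight column per label. `argmax_generalization_error` is a hypothetical helper.

```python
import numpy as np


def argmax_generalization_error(W_hat: np.ndarray, W_star: np.ndarray,
                                new_samples: int = 100_000, seed: int = 0) -> float:
    """Hypothetical Monte Carlo estimate of the argmax classification error.

    Assumes Gaussian test inputs and weights of shape (dim, n_labels).
    """
    rng = np.random.default_rng(seed)
    dim = W_star.shape[0]
    X_new = rng.standard_normal((int(new_samples), dim))
    # Rescaling by sqrt(dim) would not change the argmax, so it is omitted.
    y_teacher = np.argmax(X_new @ W_star, axis=1)
    y_student = np.argmax(X_new @ W_hat, axis=1)
    # Fraction of fresh samples on which student and teacher disagree.
    return float(np.mean(y_teacher != y_student))


rng = np.random.default_rng(1)
W_star = rng.standard_normal((50, 3))
W_hat = W_star + rng.standard_normal((50, 3))  # noisy student
print(argmax_generalization_error(W_hat, W_star, new_samples=20_000))
```

In this three-label sketch, a student whose weights are independent of the teacher's would disagree with it roughly two thirds of the time. That gives a rough reference point for reading error values.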