# Getting started with Dictionary Learning

This notebook contains examples of how to solve orthogonal dictionary learning problem.

For more details, please check the documentation website https://pygranso.readthedocs.io/en/latest/

1. Import all necessary modules and add PyGRANSO src folder to system path.

In [1]:
import time
import numpy as np
import torch
import numpy.linalg as la
from scipy.stats import norm
import sys
## Adding PyGRANSO directories. Should be modified by user
sys.path.append('/home/buyun/Documents/GitHub/PyGRANSO')
from pygranso import pygranso
from pygransoStruct import Options, Data, GeneralStruct 

2. Specify torch device, and generate data

In [2]:
device = torch.device('cuda')
n = 30

np.random.seed(1)
# data_in
data_in = Data()
m = 10*n**2   # sample complexity
theta = 0.3   # sparsity level
Y = norm.ppf(np.random.rand(n,m)) * (norm.ppf(np.random.rand(n,m)) <= theta)  # Bernoulli-Gaussian model
Y = torch.from_numpy(Y).to(device=device, dtype=torch.double)

3. Spceify optimization variables and corresponding objective and constrained function.

Note: please strictly follow the format of evalObjFunction and combinedFunction, which will be used in the PyGRANSO main algortihm. *X_struct* and *data_in* are always required.

In [3]:
# variables and corresponding dimensions.
var_in = {"q": [n,1]}


def evalObjFunction(X_struct,data_in = None):
    q = X_struct.q
    q.requires_grad_(True)
    
    # objective function
    qtY = q.t() @ Y
    f = 1/m * torch.norm(qtY, p = 1)
    return f

def combinedFunction(X_struct,data_in = None):
    q = X_struct.q
    q.requires_grad_(True)
    
    # objective function
    qtY = q.t() @ Y
    f = 1/m * torch.norm(qtY, p = 1)

    # inequality constraint, matrix form
    ci = None

    # equality constraint 
    ce = GeneralStruct()
    ce.c1 = q.t() @ q - 1

    return [f,ci,ce]

obj_eval_fn = lambda X_struct,data_in = None : evalObjFunction(X_struct,data_in = None)
comb_fn = lambda X_struct,data_in = None : combinedFunction(X_struct,data_in = None)

4. Specify user-defined options for PyGRANSO algorithm

In [4]:
opts = Options()
opts.QPsolver = 'osqp' 
opts.maxit = 500
x0 = norm.ppf(np.random.rand(n,1))
x0 /= la.norm(x0,2)
opts.x0 = torch.from_numpy(x0).to(device=device, dtype=torch.double)
# opts.opt_tol = 1e-6
# opts.fvalquit = 1e-6
# opts.print_level = 1
opts.print_frequency = 10
# opts.print_ascii = True

4. Run main algorithm

In [5]:
start = time.time()
soln = pygranso(combinedFunction = comb_fn, objEvalFunction = obj_eval_fn,var_dim_map = var_in, torch_device = device, user_opts = opts)
end = time.time()
print("Total Wall Time: {}s".format(end - start))
print(max(abs(soln.final.x))) # should be close to 1



[33m╔═════ QP SOLVER NOTICE ══════════════════════════════════════════════════════════════╗
[0m[33m║  PyGRANSO requires a quadratic program (QP) solver that has a quadprog-compatible   ║
[0m[33m║  interface, as defined by osqp and Gurobi...                                        ║
[0m[33m╚═════════════════════════════════════════════════════════════════════════════════════╝
[0m═════════════════════════════════════════════════════════════════════════════════════════════════════════════════╗
Problem specifications:                                                                                          ║ 
 # of variables                     :   30                                                                       ║ 
 # of inequality constraints        :    0                                                                       ║ 
 # of equality constraints          :    1                                                                       ║ 
═════╦═════════════════════════

  warn("Converting sparse A to a CSC " +


  20 ║ 1.000000 │  0.56954089844 ║  0.56715684330 ║   -  │ 0.002384 ║ S  │     1 │ 1.000000 ║     1 │ 0.052213   ║ 
  30 ║ 1.000000 │  0.49831354891 ║  0.49809841990 ║   -  │ 2.15e-04 ║ S  │     2 │ 0.500000 ║     1 │ 0.060126   ║ 
  40 ║ 1.000000 │  0.49647014551 ║  0.49639000250 ║   -  │ 8.01e-05 ║ S  │     3 │ 0.250000 ║     1 │ 0.066361   ║ 
  50 ║ 1.000000 │  0.49621088442 ║  0.49616684496 ║   -  │ 4.40e-05 ║ S  │     7 │ 0.015625 ║     1 │ 0.035920   ║ 
  60 ║ 1.000000 │  0.49614852900 ║  0.49611868341 ║   -  │ 2.98e-05 ║ S  │     4 │ 0.125000 ║     1 │ 0.053879   ║ 
  70 ║ 1.000000 │  0.49609734060 ║  0.49608064651 ║   -  │ 1.67e-05 ║ S  │     5 │ 0.062500 ║     1 │ 0.035289   ║ 
  80 ║ 1.000000 │  0.49605615563 ║  0.49605445840 ║   -  │ 1.70e-06 ║ S  │     3 │ 0.250000 ║     1 │ 4.82e-04   ║ 
  90 ║ 1.000000 │  0.49604762008 ║  0.49604703924 ║   -  │ 5.81e-07 ║ S  │     4 │ 0.125000 ║     1 │ 2.27e-04   ║ 
 100 ║ 1.000000 │  0.49604548758 ║  0.49604529936 ║   -  │ 1.88e-07 ║ S 