# SVM Classification

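In this notebook, we evaluate SVM classification of MNIST digits using scattering-transform features. We also apply the POC algorithm as a preprocessing step, projecting the scattering features down to 2 dimensions before classification, and compare both classifiers on the test set. In the recorded run, the SVM on the full scattering features reaches 97.75% test accuracy and the SVM on the 2-D POC projection reaches 97.31%.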

In [1]:
import os
import sys
from tqdm import tqdm
import warnings
# NOTE(review): this blanket filter silences *every* warning, including library
# deprecation warnings (e.g. scikit-learn's SVC `gamma` default change visible in the
# fitted-model reprs below). Consider narrowing it to the specific noisy warnings.
warnings.filterwarnings('ignore')

import numpy as np
import torch
import kymatio as km
from  matplotlib import pyplot as plt
from sklearn.svm import SVC

# Make the project-local `lib` package and `CONFIG` module importable
# (assumes they live in the parent directory of this notebook -- TODO confirm).
sys.path.append("..")

# NOTE(review): tqdm, plt, convert_images_to_scat, display_subset_data,
# visualize_accuracy_landscape, USPEC and CONFIG appear unused in this notebook;
# consider dropping them so the real dependencies are obvious.
from lib.data.data_loading import ClassificationDataset
from lib.data.data_processing import convert_images_to_scat, convert_loader_to_scat
from lib.utils.visualizations import display_subset_data, visualize_accuracy_landscape
from lib.clustering.uspec import USPEC
from lib.projections.POC import POC
from CONFIG import CONFIG

In [2]:
# Automatically reload edited Python modules (e.g. the project-local `lib` package)
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
# Run on the GPU when one is available; everything else falls back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Repository root, two directory levels above the current working directory
# (assumes the notebook is launched from its own folder -- TODO confirm).
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
DATA_PATH = os.path.join(ROOT_PATH, "data")

DATASET = "mnist"

# Seed the global RNGs so any randomness in the train/valid split and loader
# shuffling is reproducible under Restart & Run All.
# NOTE(review): assumes ClassificationDataset draws from the numpy/torch global
# RNGs -- confirm, and pass SEED explicitly if it accepts one.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

## Dataset and Scattering

In [4]:
# Load the dataset, holding out 25% of the training images as a validation split,
# and build one batched loader per split (train, valid, test -- in that order).
BATCH_SIZE = 128
dataset = ClassificationDataset(data_path=DATA_PATH, dataset_name=DATASET, valid_size=0.25)
train_loader, valid_loader, test_loader = (
    dataset.get_data_loader(split=split_name, batch_size=BATCH_SIZE)
    for split_name in ("train", "valid", "test")
)

In [5]:
# Scattering-network hyperparameters.
J = 3              # spatial field of the kernel is 2**J
L = 6              # number of angles in the kernel
shape = (32, 32)   # shape of the input images (NOTE(review): MNIST digits are 28x28, so the loader presumably pads/resizes -- confirm)
max_order = 2      # depth of the network

scattering_kwargs = dict(J=J, shape=shape, max_order=max_order, L=L)
scattering_layer = km.Scattering2D(**scattering_kwargs)

# Move the layer to the GPU only when DEVICE is CUDA; otherwise leave it on the CPU.
if DEVICE.type == "cuda":
    scattering_layer = scattering_layer.cuda()

In [6]:
# Push every split through the scattering network. Each call yields
# (images, scattering features, labels); all three splits share the same options.
scat_kwargs = dict(scattering=scattering_layer, device=DEVICE, equalize=True, verbose=1)

print("Processing train-set images...")
train_imgs, train_scat_features, train_labels = convert_loader_to_scat(
    train_loader, **scat_kwargs)

print("Processing valid-set images...")
valid_imgs, valid_scat_features, valid_labels = convert_loader_to_scat(
    valid_loader, **scat_kwargs)

print("Processing test-set images...")
test_imgs, test_scat_features, test_labels = convert_loader_to_scat(
    test_loader, **scat_kwargs)

  0%|          | 0/352 [00:00<?, ?it/s]

Processing train-set images...


100%|██████████| 352/352 [00:45<00:00,  7.78it/s]
  0%|          | 0/118 [00:00<?, ?it/s]

Processing valid-set images...


100%|██████████| 118/118 [00:16<00:00,  7.19it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Processing test-set images...


100%|██████████| 79/79 [00:11<00:00,  7.10it/s]


## POC Preprocessing

In [7]:
# Fit the POC projection on the training scattering features only, then project
# every split (train, valid, test -- in that order) with the same fitted model.
N_PROJ_DIMS = 2
poc = POC()
poc.fit(data=train_scat_features)
proj_train_scat, proj_valid_scat, proj_test_scat = (
    poc.transform(data=split_features, n_dims=N_PROJ_DIMS)
    for split_features in (train_scat_features, valid_scat_features, test_scat_features)
)

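## SVM Training

Both SVMs use identical hyperparameters and are fit on the **validation** split; the training split is only used to fit the POC projection above. **TODO:** confirm this is intended (e.g. to keep RBF-SVM training time manageable) and document the reason here.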

In [8]:
# One hyperparameter set shared by both classifiers, so the only difference between
# them is the input representation (full scattering features vs. 2-D POC projection).
# gamma is pinned explicitly: the recorded runs used scikit-learn's implicit default
# ('auto_deprecated', which behaves like "auto" = 1 / n_features), whereas newer
# scikit-learn releases default to "scale". Leaving it implicit would silently change
# the reported accuracies depending on the installed version.
SVM_PARAMS = dict(C=10, kernel="rbf", gamma="auto")
svm = SVC(**SVM_PARAMS)
svm_proj = SVC(**SVM_PARAMS)

In [11]:
# SVC needs 2-D input: flatten each sample's scattering coefficients into one row.
# NOTE(review): the classifier is fit on the *validation* split, not the training
# split -- confirm this is intended.
n_valid_samples = valid_scat_features.shape[0]
valid_scat_features_fit = valid_scat_features.reshape(n_valid_samples, -1)
svm.fit(valid_scat_features_fit, valid_labels)

SVC(C=10, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='auto_deprecated',
    kernel='rbf', max_iter=-1, probability=False, random_state=None,
    shrinking=True, tol=0.001, verbose=False)

In [12]:
# Same SVC setup, fit on the 2-D POC projection of the validation features.
svm_proj.fit(X=proj_valid_scat, y=valid_labels)

SVC(C=10, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='auto_deprecated',
    kernel='rbf', max_iter=-1, probability=False, random_state=None,
    shrinking=True, tol=0.001, verbose=False)

## SVM Evaluation

In [13]:
# Flatten the test features exactly as was done for fitting, then report mean accuracy.
n_test_samples = test_scat_features.shape[0]
test_scat_features_fit = test_scat_features.reshape(n_test_samples, -1)
svm.score(test_scat_features_fit, test_labels)

0.9775

In [15]:
# Mean test accuracy of the SVM trained on the 2-D POC projection.
svm_proj.score(X=proj_test_scat, y=test_labels)

0.9731