# Experiment 5: Scattering Features + CNN Model




In [76]:
import sys
sys.path.append('../src')
import warnings
warnings.filterwarnings("ignore") 

from utils.reduce import reduce_pca
from utils.split import train_test_split, train_test_split_pytorch
from utils.UltrasoundDataset import UltrasoundDataset
from utils.Networks import  BasicBlock, Scattering2dResNet
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.autograd import Variable
from kymatio.torch import Scattering2D
import argparse
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim
import pickle
import pandas as pd
import numpy as np
import mlflow
import matplotlib.pyplot as plt

In [77]:
# create an experiment in mlruns if doesn't exit and specifiy where to log
mlflow.set_experiment('scattering_cnn_experiment')

INFO: 'scattering_cnn_experiment' does not exist. Creating a new experiment


## Upload Ultrasound images

In [78]:
with open('../data/02_interim/bmodes_steatosis_assessment_IJCARS.pickle', 'rb') as handle:
    df = pickle.load(handle)

In [79]:
M, N= 434, 636 # ultrasound image dimension

In [80]:
# split training and test (by making sure the 10 ultrasound images of one patient is in the same set)
train_data, test_data = train_test_split(df)
train_data, val_data = train_test_split(train_data)

In [81]:
###############################################################################
# If a GPU is available, let's use it!
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
###############################################################################

# Create dataset

In [82]:
# create dataset
# from https://github.com/python-engineer/pytorchTutorial/blob/master/09_dataloader.py

train_dataset = UltrasoundDataset(train_data)
test_dataset  = UltrasoundDataset(test_data)
val_dataset  = UltrasoundDataset(val_data)

In [83]:
def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    # in_planes = in_channels
    # out_planes = out_channels
    return nn.Conv2d(in_planes, out_planes, kernel_size=10, stride=stride,
                     padding=1, bias=False)


## Defining scattering transformations

In [84]:
# Set the parameters by cross-validation
from itertools import product
param_batch_size = [10, 20, 50]
param_J = [2,3]
param_max_order = [1,2]
params = list(product(param_batch_size,param_J, param_max_order))


## Training and Testing Functions

In [85]:
def train(model, device, train_loader, optimizer, epoch, scattering):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(scattering(data))
        loss = F.cross_entropy(output, target.type(torch.long))
        loss.backward()
        optimizer.step()
        if batch_idx % 20 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader, scattering):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(scattering(data))
            test_loss += F.cross_entropy(output, target.type(torch.long), reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, 100. * correct / len(test_loader.dataset)


## Training

In [None]:
for param in params:
    # Do cross-validation
    with mlflow.start_run():
        batch_size = param[0]
        J = param[1]
        max_order = param[2]
        mlflow.log_param('batch_size',batch_size)
        mlflow.log_param('J', J)
        mlflow.log_param('max_order', max_order)

        #create data loader
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=2)
        test_loader = DataLoader(dataset=test_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=2)
        val_loader = DataLoader(dataset=val_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=2)
        

        # compute scattering features---------------------------------------                       
        if max_order == 1:
            scattering = Scattering2D(J=J, shape=(M, N), max_order=max_order)
            K = 17*1
        else:
            scattering = Scattering2D(J=J, shape=(M, N))
            K = 81*1

        if use_cuda:
            scattering = scattering.cuda()     
        
        model = Scattering2dResNet(K,2).to(device)

        #training------------------------------------------------------------                         
        from tqdm import tqdm
        # Optimizer
        lr = 0.1
        val_loss_no_imp = 0
        val_loss_last = 0
        patience = 5
        for epoch in tqdm(range(0, 50)):
            optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9,
                                      weight_decay=0.0005)
            lr*=0.2
            train(model, device, train_loader, optimizer, epoch+1, scattering)
            val_loss , _ = test(model, device, val_loader, scattering)
            if val_loss > val_loss_last: 
                val_loss_no_imp +=1
            else:
                val_loss_no_imp = 0
            val_loss_last = val_loss
            if val_loss_no_imp >= patience: break
            
        val_loss , val_accuracy = test(model, device, val_loader, scattering)
        mlflow.log_metric('validation_accuracy',val_accuracy)  
        print('Combinaison Done')






  0%|          | 0/50 [00:00<?, ?it/s][A[A





  2%|▏         | 1/50 [00:11<08:59, 11.00s/it][A[A


Validation set: Average loss: 0.7052, Accuracy: 30/50 (60.00%)





  4%|▍         | 2/50 [00:22<08:50, 11.06s/it][A[A


Validation set: Average loss: 0.4883, Accuracy: 40/50 (80.00%)





  6%|▌         | 3/50 [00:33<08:42, 11.13s/it][A[A


Validation set: Average loss: 0.4780, Accuracy: 40/50 (80.00%)





  8%|▊         | 4/50 [00:44<08:33, 11.17s/it][A[A


Validation set: Average loss: 0.4756, Accuracy: 40/50 (80.00%)





 10%|█         | 5/50 [00:56<08:24, 11.20s/it][A[A


Validation set: Average loss: 0.4763, Accuracy: 40/50 (80.00%)





 12%|█▏        | 6/50 [01:07<08:13, 11.23s/it][A[A


Validation set: Average loss: 0.4773, Accuracy: 40/50 (80.00%)





 14%|█▍        | 7/50 [01:18<08:03, 11.24s/it][A[A


Validation set: Average loss: 0.4769, Accuracy: 40/50 (80.00%)





 16%|█▌        | 8/50 [01:29<07:53, 11.26s/it][A[A


Validation set: Average loss: 0.4778, Accuracy: 40/50 (80.00%)





 18%|█▊        | 9/50 [01:41<07:42, 11.28s/it][A[A


Validation set: Average loss: 0.4778, Accuracy: 40/50 (80.00%)





 20%|██        | 10/50 [01:52<07:31, 11.28s/it][A[A


Validation set: Average loss: 0.4778, Accuracy: 40/50 (80.00%)





 22%|██▏       | 11/50 [02:03<07:20, 11.29s/it][A[A


Validation set: Average loss: 0.4776, Accuracy: 40/50 (80.00%)





 24%|██▍       | 12/50 [02:15<07:09, 11.30s/it][A[A


Validation set: Average loss: 0.4765, Accuracy: 40/50 (80.00%)





 26%|██▌       | 13/50 [02:26<06:58, 11.30s/it][A[A


Validation set: Average loss: 0.4768, Accuracy: 40/50 (80.00%)





 28%|██▊       | 14/50 [02:37<06:46, 11.30s/it][A[A


Validation set: Average loss: 0.4767, Accuracy: 40/50 (80.00%)





 30%|███       | 15/50 [02:49<06:35, 11.30s/it][A[A


Validation set: Average loss: 0.4769, Accuracy: 40/50 (80.00%)





 32%|███▏      | 16/50 [03:00<06:23, 11.29s/it][A[A


Validation set: Average loss: 0.4774, Accuracy: 40/50 (80.00%)





 34%|███▍      | 17/50 [03:11<06:12, 11.29s/it][A[A


Validation set: Average loss: 0.4764, Accuracy: 40/50 (80.00%)





 36%|███▌      | 18/50 [03:22<06:01, 11.29s/it][A[A


Validation set: Average loss: 0.4766, Accuracy: 40/50 (80.00%)





 38%|███▊      | 19/50 [03:34<05:49, 11.29s/it][A[A


Validation set: Average loss: 0.4783, Accuracy: 40/50 (80.00%)





 40%|████      | 20/50 [03:45<05:38, 11.29s/it][A[A


Validation set: Average loss: 0.4767, Accuracy: 40/50 (80.00%)





 42%|████▏     | 21/50 [03:56<05:27, 11.29s/it][A[A


Validation set: Average loss: 0.4764, Accuracy: 40/50 (80.00%)





 44%|████▍     | 22/50 [04:08<05:16, 11.30s/it][A[A


Validation set: Average loss: 0.4766, Accuracy: 40/50 (80.00%)





 46%|████▌     | 23/50 [04:19<05:05, 11.30s/it][A[A


Validation set: Average loss: 0.4775, Accuracy: 40/50 (80.00%)





 48%|████▊     | 24/50 [04:30<04:54, 11.31s/it][A[A


Validation set: Average loss: 0.4779, Accuracy: 40/50 (80.00%)





 50%|█████     | 25/50 [04:42<04:42, 11.31s/it][A[A


Validation set: Average loss: 0.4776, Accuracy: 40/50 (80.00%)





 52%|█████▏    | 26/50 [04:53<04:31, 11.30s/it][A[A


Validation set: Average loss: 0.4772, Accuracy: 40/50 (80.00%)





 54%|█████▍    | 27/50 [05:04<04:19, 11.30s/it][A[A


Validation set: Average loss: 0.4763, Accuracy: 40/50 (80.00%)





 56%|█████▌    | 28/50 [05:15<04:08, 11.29s/it][A[A


Validation set: Average loss: 0.4772, Accuracy: 40/50 (80.00%)





 58%|█████▊    | 29/50 [05:27<03:57, 11.30s/it][A[A


Validation set: Average loss: 0.4775, Accuracy: 40/50 (80.00%)





 60%|██████    | 30/50 [05:38<03:45, 11.30s/it][A[A


Validation set: Average loss: 0.4775, Accuracy: 40/50 (80.00%)





 62%|██████▏   | 31/50 [05:49<03:34, 11.29s/it][A[A


Validation set: Average loss: 0.4779, Accuracy: 40/50 (80.00%)





 64%|██████▍   | 32/50 [06:01<03:23, 11.31s/it][A[A


Validation set: Average loss: 0.4768, Accuracy: 40/50 (80.00%)





 66%|██████▌   | 33/50 [06:12<03:12, 11.30s/it][A[A


Validation set: Average loss: 0.4774, Accuracy: 40/50 (80.00%)





 68%|██████▊   | 34/50 [06:23<03:00, 11.30s/it][A[A


Validation set: Average loss: 0.4765, Accuracy: 40/50 (80.00%)





 70%|███████   | 35/50 [06:34<02:49, 11.29s/it][A[A


Validation set: Average loss: 0.4787, Accuracy: 40/50 (80.00%)





 72%|███████▏  | 36/50 [06:46<02:38, 11.29s/it][A[A


Validation set: Average loss: 0.4774, Accuracy: 40/50 (80.00%)





 74%|███████▍  | 37/50 [06:57<02:26, 11.30s/it][A[A


Validation set: Average loss: 0.4774, Accuracy: 40/50 (80.00%)





 76%|███████▌  | 38/50 [07:08<02:15, 11.30s/it][A[A


Validation set: Average loss: 0.4767, Accuracy: 40/50 (80.00%)





 78%|███████▊  | 39/50 [07:20<02:04, 11.30s/it][A[A


Validation set: Average loss: 0.4782, Accuracy: 40/50 (80.00%)





 80%|████████  | 40/50 [07:31<01:52, 11.30s/it][A[A


Validation set: Average loss: 0.4763, Accuracy: 40/50 (80.00%)





 82%|████████▏ | 41/50 [07:42<01:41, 11.30s/it][A[A


Validation set: Average loss: 0.4772, Accuracy: 40/50 (80.00%)





 84%|████████▍ | 42/50 [07:54<01:30, 11.30s/it][A[A


Validation set: Average loss: 0.4770, Accuracy: 40/50 (80.00%)





 86%|████████▌ | 43/50 [08:05<01:19, 11.30s/it][A[A


Validation set: Average loss: 0.4783, Accuracy: 40/50 (80.00%)





 88%|████████▊ | 44/50 [08:16<01:07, 11.29s/it][A[A


Validation set: Average loss: 0.4782, Accuracy: 40/50 (80.00%)





 90%|█████████ | 45/50 [08:27<00:56, 11.30s/it][A[A


Validation set: Average loss: 0.4778, Accuracy: 40/50 (80.00%)





 92%|█████████▏| 46/50 [08:39<00:45, 11.30s/it][A[A


Validation set: Average loss: 0.4769, Accuracy: 40/50 (80.00%)





 94%|█████████▍| 47/50 [08:50<00:33, 11.29s/it][A[A


Validation set: Average loss: 0.4775, Accuracy: 40/50 (80.00%)





 96%|█████████▌| 48/50 [09:01<00:22, 11.29s/it][A[A


Validation set: Average loss: 0.4769, Accuracy: 40/50 (80.00%)





 98%|█████████▊| 49/50 [09:13<00:11, 11.30s/it][A[A


Validation set: Average loss: 0.4774, Accuracy: 40/50 (80.00%)





100%|██████████| 50/50 [09:24<00:00, 11.29s/it][A[A


Validation set: Average loss: 0.4770, Accuracy: 40/50 (80.00%)







Validation set: Average loss: 0.4770, Accuracy: 40/50 (80.00%)

Combinaison Done




  0%|          | 0/50 [00:00<?, ?it/s][A[A





  2%|▏         | 1/50 [00:15<12:25, 15.21s/it][A[A


Validation set: Average loss: 0.4882, Accuracy: 40/50 (80.00%)





  4%|▍         | 2/50 [00:30<12:11, 15.24s/it][A[A


Validation set: Average loss: 0.4457, Accuracy: 50/50 (100.00%)





  6%|▌         | 3/50 [00:45<11:57, 15.27s/it][A[A


Validation set: Average loss: 0.4278, Accuracy: 47/50 (94.00%)





  8%|▊         | 4/50 [01:01<11:43, 15.30s/it][A[A


Validation set: Average loss: 0.3951, Accuracy: 50/50 (100.00%)





 10%|█         | 5/50 [01:16<11:28, 15.31s/it][A[A


Validation set: Average loss: 0.3942, Accuracy: 50/50 (100.00%)





 12%|█▏        | 6/50 [01:31<11:13, 15.31s/it][A[A


Validation set: Average loss: 0.3964, Accuracy: 49/50 (98.00%)





 14%|█▍        | 7/50 [01:47<10:58, 15.32s/it][A[A


Validation set: Average loss: 0.3953, Accuracy: 50/50 (100.00%)





 16%|█▌        | 8/50 [02:02<10:43, 15.32s/it][A[A


Validation set: Average loss: 0.3967, Accuracy: 49/50 (98.00%)





 18%|█▊        | 9/50 [02:17<10:28, 15.33s/it][A[A


Validation set: Average loss: 0.3961, Accuracy: 50/50 (100.00%)





 20%|██        | 10/50 [02:33<10:13, 15.33s/it][A[A


Validation set: Average loss: 0.3963, Accuracy: 50/50 (100.00%)





 22%|██▏       | 11/50 [02:48<09:57, 15.33s/it][A[A


Validation set: Average loss: 0.3938, Accuracy: 50/50 (100.00%)





 24%|██▍       | 12/50 [03:03<09:42, 15.33s/it][A[A


Validation set: Average loss: 0.3987, Accuracy: 49/50 (98.00%)





 26%|██▌       | 13/50 [03:19<09:27, 15.34s/it][A[A


Validation set: Average loss: 0.3963, Accuracy: 50/50 (100.00%)





 28%|██▊       | 14/50 [03:34<09:12, 15.34s/it][A[A


Validation set: Average loss: 0.3925, Accuracy: 50/50 (100.00%)





 30%|███       | 15/50 [03:49<08:56, 15.33s/it][A[A


Validation set: Average loss: 0.3983, Accuracy: 49/50 (98.00%)





 32%|███▏      | 16/50 [04:05<08:40, 15.31s/it][A[A


Validation set: Average loss: 0.3959, Accuracy: 50/50 (100.00%)





 34%|███▍      | 17/50 [04:20<08:25, 15.31s/it][A[A


Validation set: Average loss: 0.3947, Accuracy: 50/50 (100.00%)





 36%|███▌      | 18/50 [04:35<08:09, 15.31s/it][A[A


Validation set: Average loss: 0.3957, Accuracy: 50/50 (100.00%)





 38%|███▊      | 19/50 [04:51<07:54, 15.32s/it][A[A


Validation set: Average loss: 0.3980, Accuracy: 49/50 (98.00%)





 40%|████      | 20/50 [05:06<07:39, 15.31s/it][A[A


Validation set: Average loss: 0.3956, Accuracy: 49/50 (98.00%)





 42%|████▏     | 21/50 [05:21<07:24, 15.33s/it][A[A


Validation set: Average loss: 0.3944, Accuracy: 50/50 (100.00%)





 44%|████▍     | 22/50 [05:37<07:09, 15.33s/it][A[A


Validation set: Average loss: 0.3962, Accuracy: 49/50 (98.00%)





 46%|████▌     | 23/50 [05:52<06:53, 15.33s/it][A[A


Validation set: Average loss: 0.3950, Accuracy: 50/50 (100.00%)





 48%|████▊     | 24/50 [06:07<06:38, 15.33s/it][A[A


Validation set: Average loss: 0.3931, Accuracy: 50/50 (100.00%)





 50%|█████     | 25/50 [06:23<06:23, 15.34s/it][A[A


Validation set: Average loss: 0.3960, Accuracy: 49/50 (98.00%)





 52%|█████▏    | 26/50 [06:38<06:07, 15.32s/it][A[A


Validation set: Average loss: 0.3982, Accuracy: 49/50 (98.00%)





 70%|███████   | 35/50 [08:56<03:49, 15.32s/it][A[A


Validation set: Average loss: 0.3947, Accuracy: 50/50 (100.00%)





 72%|███████▏  | 36/50 [09:11<03:34, 15.33s/it][A[A


Validation set: Average loss: 0.3975, Accuracy: 49/50 (98.00%)





 74%|███████▍  | 37/50 [09:26<03:19, 15.32s/it][A[A


Validation set: Average loss: 0.3981, Accuracy: 49/50 (98.00%)





 76%|███████▌  | 38/50 [09:42<03:03, 15.33s/it][A[A


Validation set: Average loss: 0.3960, Accuracy: 50/50 (100.00%)





 78%|███████▊  | 39/50 [09:57<02:48, 15.34s/it][A[A


Validation set: Average loss: 0.3981, Accuracy: 49/50 (98.00%)





 80%|████████  | 40/50 [10:13<02:33, 15.35s/it][A[A


Validation set: Average loss: 0.3967, Accuracy: 50/50 (100.00%)





 82%|████████▏ | 41/50 [10:28<02:18, 15.36s/it][A[A


Validation set: Average loss: 0.3940, Accuracy: 50/50 (100.00%)





 84%|████████▍ | 42/50 [10:43<02:02, 15.35s/it][A[A


Validation set: Average loss: 0.3943, Accuracy: 50/50 (100.00%)





 86%|████████▌ | 43/50 [10:59<01:47, 15.35s/it][A[A


Validation set: Average loss: 0.3976, Accuracy: 49/50 (98.00%)





 88%|████████▊ | 44/50 [11:14<01:32, 15.35s/it][A[A


Validation set: Average loss: 0.3954, Accuracy: 50/50 (100.00%)



# Run Mlflow to see results

`!mlflow ui`

Should launch something like this:



In [None]:
# !mlflow ui

In [47]:
test_data['id'].unique()

array([ 4,  9, 28, 33, 52, 53], dtype=uint8)