In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
import scipy.io as sio
import os

In [2]:
from __future__ import print_function
import time

import numpy as np
np.random.seed(1234)
from functools import reduce
import math as m

import scipy.io
#import theano
#import theano.tensor as T

from scipy.interpolate import griddata
from sklearn.preprocessing import scale
#from utils import augment_EEG, cart2sph, pol2cart

#import lasagne
# from lasagne.layers.dnn import Conv2DDNNLayer as ConvLayer
#from lasagne.layers import Conv2DLayer, MaxPool2DLayer, InputLayer
#from lasagne.layers import DenseLayer, ElemwiseMergeLayer, FlattenLayer
#from lasagne.layers import ConcatLayer, ReshapeLayer, get_output_shape
#from lasagne.layers import Conv1DLayer, DimshuffleLayer, LSTMLayer, SliceLayer


def azim_proj(pos):
    """
    Computes the Azimuthal Equidistant Projection of input point in 3D Cartesian Coordinates.
    Imagine a plane being placed against (tangent to) a globe. If
    a light source inside the globe projects the graticule onto
    the plane the result would be a planar, or azimuthal, map
    projection.

    :param pos: position in 3D Cartesian coordinates
    :return: projected coordinates using Azimuthal Equidistant Projection
    """
    [r, elev, az] = cart2sph(pos[0], pos[1], pos[2])
    return pol2cart(az, m.pi / 2 - elev)


def gen_images(locs, features, n_gridpoints, normalize=True,
               augment=False, pca=False, std_mult=0.1, n_components=2, edgeless=False):
    """
    Generates EEG images given electrode locations in 2D space and multiple feature values for each electrode

    :param locs: An array with shape [n_electrodes, 2] containing X, Y
                        coordinates for each electrode.
    :param features: Feature matrix as [n_samples, n_features]
                                Features are as columns.
                                Features corresponding to each frequency band are concatenated.
                                (alpha1, alpha2, ..., beta1, beta2,...)
    :param n_gridpoints: Number of pixels in the output images
    :param normalize:   Flag for whether to normalize each band over all samples
    :param augment:     Flag for generating augmented images
    :param pca:         Flag for PCA based data augmentation
    :param std_mult     Multiplier for std of added noise
    :param n_components: Number of components in PCA to retain for augmentation
    :param edgeless:    If True generates edgeless images by adding artificial channels
                        at four corners of the image with value = 0 (default=False).
    :return:            Tensor of size [samples, colors, W, H] containing generated
                        images.
    """
    feat_array_temp = []
    nElectrodes = locs.shape[0]     # Number of electrodes
    # Test whether the feature vector length is divisible by number of electrodes
    assert features.shape[1] % nElectrodes == 0
    n_colors = int(features.shape[1] / nElectrodes)
    for c in range(int(n_colors)):
        feat_array_temp.append(features[:, c * nElectrodes : nElectrodes * (c+1)])
    if augment:
        if pca:
            for c in range(n_colors):
                feat_array_temp[c] = augment_EEG(feat_array_temp[c], std_mult, pca=True, n_components=n_components)
        else:
            for c in range(n_colors):
                feat_array_temp[c] = augment_EEG(feat_array_temp[c], std_mult, pca=False, n_components=n_components)
    nSamples = features.shape[0]
    # Interpolate the values
    grid_x, grid_y = np.mgrid[
                     min(locs[:, 0]):max(locs[:, 0]):n_gridpoints*1j,
                     min(locs[:, 1]):max(locs[:, 1]):n_gridpoints*1j
                     ]
    temp_interp = []
    for c in range(n_colors):
        temp_interp.append(np.zeros([nSamples, n_gridpoints, n_gridpoints]))
    # Generate edgeless images
    if edgeless:
        min_x, min_y = np.min(locs, axis=0)
        max_x, max_y = np.max(locs, axis=0)
        locs = np.append(locs, np.array([[min_x, min_y], [min_x, max_y],[max_x, min_y],[max_x, max_y]]),axis=0)
        for c in range(n_colors):
            feat_array_temp[c] = np.append(feat_array_temp[c], np.zeros((nSamples, 4)), axis=1)
    # Interpolating
    for i in range(nSamples):
        for c in range(n_colors):
            temp_interp[c][i, :, :] = griddata(locs, feat_array_temp[c][i, :], (grid_x, grid_y),
                                    method='cubic', fill_value=np.nan)
        print('Interpolating {0}/{1}\r'.format(i+1, nSamples), end='\r')
    # Normalizing
    for c in range(n_colors):
        if normalize:
            temp_interp[c][~np.isnan(temp_interp[c])] = \
                scale(temp_interp[c][~np.isnan(temp_interp[c])])
        temp_interp[c] = np.nan_to_num(temp_interp[c])
    return np.swapaxes(np.asarray(temp_interp), 0, 1)     # swap axes to have [samples, colors, W, H]

def get_fft(snippet):
    Fs = 200.0;  # sampling rate
    #Ts = len(snippet)/Fs/Fs; # sampling interval
    snippet_time = len(snippet)/Fs
    Ts = 0.5/Fs; # sampling interval
    t = np.arange(0,snippet_time,Ts) # time vector

    # ff = 5;   # frequency of the signal
    # y = np.sin(2*np.pi*ff*t)
    y = snippet
#     print('Ts: ',Ts)
#     print(t)
#     print(y.shape)
    n = len(y) # length of the signal
    k = np.arange(n)
    T = n/Fs
    frq = k/T # two sides frequency range
    frq = frq[range(n//2)] # one side frequency range

    Y = np.fft.fft(y)/n # fft computing and normalization
    Y = Y[range(n//2)]
    #Added in: (To remove bias.)
    #Y[0] = 0
    return frq,abs(Y)

def theta_alpha_beta_averages(f,Y):
    theta_range = (4,8)
    alpha_range = (8,12)
    beta_range = (12,40)
    theta = Y[(f>theta_range[0]) & (f<=theta_range[1])].mean()
    alpha = Y[(f>alpha_range[0]) & (f<=alpha_range[1])].mean()
    beta = Y[(f>beta_range[0]) & (f<=beta_range[1])].mean()
    return theta, alpha, beta

def make_frames(mark,df):
    '''
    mark: the time label of your dataset
    df: all channels in dyour dataset
    '''
    frames = []
    
    for i in range(len(mark)-1):
        frame = []
        for channel in df.columns:
            snippet = np.array(df.loc[mark[i]:mark[i+1],int(channel)])
            f,Y =  get_fft(snippet)
            theta, alpha, beta = theta_alpha_beta_averages(f,Y)
            frame.append([theta, alpha, beta])
        frames.append(frame)
        
    return np.array(frames)

def TrainTest_Model(model, trainloader, testloader, n_epoch=30, opti='Adam', learning_rate=0.0001, is_cuda=True, print_epoch =5, verbose=False):
    net = model()
        
    criterion = nn.CrossEntropyLoss()
    
    if opti=='SGD':
        optimizer = optim.SGD(net.parameters(), lr=learning_rate)
    elif opti =='Adam':
        optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    else: 
        print("Optimizer: "+optim+" not implemented.")
    a = 0
    count = 0
    for epoch in range(n_epoch):
        running_loss = 0.0
        evaluation = []
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            #a=a+1
            """if (a<2):
                print(inputs)
                print(labels)
                print(data)"""
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs.to(torch.float32))
            _, predicted = torch.max(outputs.cpu().data, 1)
            """ if a < 2:
                print(predicted)"""
            evaluation.append((predicted==labels).tolist())
            loss = criterion(outputs, labels.long())
            """if a < 2:
                print(evaluation)"""
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        running_loss = running_loss/(i+1)
        evaluation = [item for sublist in evaluation for item in sublist]
        #print(len(evaluation))
        running_acc = sum(evaluation)/len(evaluation)
        validation_loss, validation_acc = Test_Model(net, testloader, criterion,True)
        count = count+1
        if epoch%print_epoch==(print_epoch-1):
            print('[%d, %3d]\tloss: %.3f\tAccuracy : %.3f\t\tval-loss: %.3f\tval-Accuracy : %.3f' %
             (epoch+1, n_epoch, running_loss, running_acc, validation_loss, validation_acc))
        if count%10 == 0:
            print('epoch:%d \n loss: %.3f\tAccuracy : %.3f\t\tval-loss: %.3f\tval-Accuracy : %.3f' %
                 (count, running_loss, running_acc, validation_loss,validation_acc))
    if verbose:
        print('Finished Training \n loss: %.3f\tAccuracy : %.3f\t\tval-loss: %.3f\tval-Accuracy : %.3f' %
                 (running_loss, running_acc, validation_loss,validation_acc))
    torch.save(net.state_dict(), "BasicCNN_forall_mix.pth")
    return (running_loss, running_acc, validation_loss,validation_acc)

from torch.utils.data.dataset import Dataset
import torch
import scipy.io as sio
import torch.optim as optim
import torch.nn as nn
import numpy as np
class EEGImagesDataset(Dataset):
    """EEGLearn Images Dataset from EEG."""
    
    def __init__(self, label, image):
        self.label = label
        self.Images = image
        
    def __len__(self):
        return len(self.label)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.Images[idx]
        label = self.label[idx]
        sample = (image, label)
        
        return sample

def Test_Model(net, Testloader, criterion, is_cuda=True):
    running_loss = 0.0 
    evaluation = []
    for i, data in enumerate(Testloader, 0):
        input_img, labels = data
        input_img = input_img.to(torch.float32)
        """if is_cuda:
            input_img = input_img.cuda()"""
        outputs = net(input_img)
        _, predicted = torch.max(outputs.cpu().data, 1)
        evaluation.append((predicted==labels).tolist())
        loss = criterion(outputs, labels.long())
        running_loss += loss.item()
    running_loss = running_loss/(i+1)
    evaluation = [item for sublist in evaluation for item in sublist]
    running_acc = sum(evaluation)/len(evaluation)
    return running_loss, running_acc

class BasicCNN(nn.Module):
    '''
    Build the  Mean Basic model performing a classification with CNN 

    param input_image: list of EEG image [batch_size, n_window, n_channel, h, w]
    param kernel: kernel size used for the convolutional layers
    param stride: stride apply during the convolutions
    param padding: padding used during the convolutions
    param max_kernel: kernel used for the maxpooling steps
    param n_classes: number of classes
    return x: output of the last layers after the log softmax
    '''

    def __init__(self, input_image=torch.zeros(1, 3, 32, 32), kernel=(3, 3), stride=1, padding=1, max_kernel=(2, 2),
                 n_classes=4):
        super(BasicCNN, self).__init__()

        n_channel = input_image.shape[1]

        self.conv1 = nn.Conv2d(n_channel, 32, kernel, stride=stride, padding=padding)
        self.conv2 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.conv3 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.conv4 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.pool1 = nn.MaxPool2d(max_kernel)
        self.conv5 = nn.Conv2d(32, 64, kernel, stride=stride, padding=padding)
        self.conv6 = nn.Conv2d(64, 64, kernel, stride=stride, padding=padding)
        self.conv7 = nn.Conv2d(64, 128, kernel, stride=stride, padding=padding)

        self.pool = nn.MaxPool2d((1, 1))
        self.drop = nn.Dropout(p=0.5)

        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, n_classes)
        self.max = nn.LogSoftmax()

    def forward(self, x):
        batch_size = x.shape[0]
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool1(x)
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = self.pool1(x)
        x = F.relu(self.conv7(x))
        x = self.pool1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.max(x)
        return x

In [4]:
dataset = sio.loadmat("SEED/Preprocessed_EEG/1_20131027.mat")
features_struct = sio.loadmat("SEED/Preprocessed_EEG/1_20131027.mat")
features_1 = features_struct['djc_eeg1']
dfdata_1 = pd.DataFrame(features_1)
df_1 = dfdata_1.T
print(df_1.shape)
features_2 = features_struct['djc_eeg2']
dfdata_2 = pd.DataFrame(features_2)
df_2 = dfdata_2.T
print(df_2.shape)
features_3 = features_struct['djc_eeg3']
dfdata_3 = pd.DataFrame(features_3)
df_3 = dfdata_3.T
print(df_3.shape)


(47001, 62)
(46601, 62)
(41201, 62)


In [7]:
mark_a =np.linspace(0,47000,471)
mark_b =np.linspace(0,46600,467)
mark_c =np.linspace(0,41200,413)

In [8]:
def make_frames(mark,df):
    '''
    mark: the time label of your dataset
    df: all channels in dyour dataset
    '''
    frames = []
    
    for i in range(len(mark)-1):
        frame = []
        for channel in df.columns:
            snippet = np.array(df.loc[mark[i]:mark[i+1],int(channel)])
            f,Y =  get_fft(snippet)
            theta, alpha, beta = theta_alpha_beta_averages(f,Y)
            frame.append([theta, alpha, beta])
        frames.append(frame)
        
    return np.array(frames)

X = make_frames(mark_a,df_1)
Y = make_frames(mark_b,df_2)
Z = make_frames(mark_c,df_3)

In [9]:
X_1 = X.reshape(X.shape[0],X.shape[1]*X.shape[2])
eeg_locs = sio.loadmat("locs_seed.mat")
locs_2d =eeg_locs["locs"]
images_a = gen_images(locs_2d,X_1, 32, normalize=False)

Interpolating 470/470nterpolating 11/470Interpolating 16/470Interpolating 21/470Interpolating 26/470Interpolating 31/470Interpolating 36/470Interpolating 41/470Interpolating 46/470Interpolating 51/470Interpolating 56/470Interpolating 61/470Interpolating 66/470Interpolating 71/470Interpolating 76/470Interpolating 81/470Interpolating 86/470Interpolating 91/470Interpolating 96/470Interpolating 101/470Interpolating 106/470Interpolating 111/470Interpolating 116/470Interpolating 121/470Interpolating 126/470Interpolating 131/470Interpolating 136/470Interpolating 141/470Interpolating 146/470Interpolating 151/470Interpolating 156/470Interpolating 161/470Interpolating 166/470Interpolating 171/470Interpolating 176/470Interpolating 181/470Interpolating 186/470Interpolating 189/470Interpolating 192/470Interpolating 195/470Interpolating 197/470Interpolating 202/470Interpolating 207/470Interpolating 212/470Interpolating 217/470Interpolating 222/470Interpolating 227/470Interpolating 232/470Interpolati

In [10]:
Y_1 = Y.reshape(Y.shape[0],Y.shape[1]*Y.shape[2])
images_b = gen_images(locs_2d,Y_1, 32, normalize=False)

Interpolating 465/466nterpolating 11/466Interpolating 16/466Interpolating 21/466Interpolating 26/466Interpolating 31/466Interpolating 36/466Interpolating 41/466Interpolating 46/466Interpolating 51/466Interpolating 56/466Interpolating 61/466Interpolating 66/466Interpolating 71/466Interpolating 76/466Interpolating 81/466Interpolating 86/466Interpolating 91/466Interpolating 96/466Interpolating 101/466Interpolating 106/466Interpolating 111/466Interpolating 116/466Interpolating 121/466Interpolating 126/466Interpolating 131/466Interpolating 136/466Interpolating 141/466Interpolating 146/466Interpolating 151/466Interpolating 156/466Interpolating 161/466Interpolating 166/466Interpolating 171/466Interpolating 176/466Interpolating 181/466Interpolating 186/466Interpolating 191/466Interpolating 196/466Interpolating 201/466Interpolating 206/466Interpolating 211/466Interpolating 216/466Interpolating 221/466Interpolating 226/466Interpolating 231/466Interpolating 236/466Interpolating 241/466Interpolati

In [11]:
Z_1 = Z.reshape(Z.shape[0],Z.shape[1]*Z.shape[2])
images_c = gen_images(locs_2d,Z_1, 32, normalize=False)

Interpolating 412/412nterpolating 11/412Interpolating 16/412Interpolating 21/412Interpolating 26/412Interpolating 31/412Interpolating 36/412Interpolating 41/412Interpolating 46/412Interpolating 51/412Interpolating 56/412Interpolating 61/412Interpolating 66/412Interpolating 71/412Interpolating 76/412Interpolating 81/412Interpolating 86/412Interpolating 91/412Interpolating 96/412Interpolating 101/412Interpolating 106/412Interpolating 111/412Interpolating 116/412Interpolating 121/412Interpolating 126/412Interpolating 131/412Interpolating 136/412Interpolating 141/412Interpolating 146/412Interpolating 151/412Interpolating 155/412Interpolating 158/412Interpolating 160/412Interpolating 165/412Interpolating 170/412Interpolating 175/412Interpolating 180/412Interpolating 185/412Interpolating 190/412Interpolating 195/412Interpolating 200/412Interpolating 205/412Interpolating 210/412Interpolating 215/412Interpolating 220/412Interpolating 225/412Interpolating 230/412Interpolating 235/412Interpolati

In [12]:
label = []
for i in range(470):
    label.append(0)
for i in range(460):
    label.append(1)
for i in range(412):
    label.append(2)
print(label)
    

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [10]:
print(images.shape)


NameError: name 'images' is not defined

In [13]:
images = np.concatenate((images_a,images_b,images_c),axis=0)

print(images.shape)

(1348, 3, 32, 32)


In [12]:
import numpy as np
import scipy.io as sio
import torch
import os
from os import path

import torch.optim as optim
import torch.nn.functional as F


from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader,random_split

#from Utils import *
#from Models import *

torch.manual_seed(1234)
np.random.seed(1234)


# Introduction: training a simple CNN with the mean of the images.
train_part = 0.8
test_part = 0.2

batch_size = 64

n_epoch = 300
n_rep = 1
Result = []
for r in range(n_rep):
    EEG = EEGImagesDataset(label=label, image=images)
    lengths = [int(len(EEG) * train_part), int(len(EEG) * test_part)]
    if sum(lengths) != len(EEG):
        lengths[0] = lengths[0] + 1
    Train, Test = random_split(EEG, lengths)
    Trainloader = DataLoader(Train, batch_size=batch_size)
    Testloader = DataLoader(Test, batch_size=batch_size)
    res = TrainTest_Model(BasicCNN, Trainloader, Testloader, n_epoch=n_epoch, learning_rate=0.0001, print_epoch=-1,
                              opti='Adam')
    Result.append(res)
#sio.savemat("Res_Basic_Patient"+"all"+".mat", {"res":Result})
Result = np.mean(Result, axis=0)
print ('-'*100)
print('\nBegin Training for all Patient ')
print('End Training with \t loss: %.3f\tAccuracy : %.3f\t\tval-loss: %.3f\tval-Accuracy : %.3f' %
    (Result[0], Result[1], Result[2], Result[3]))
print('\n'+'-'*100)




epoch:10 
 loss: 0.697	Accuracy : 0.638		val-loss: 0.883	val-Accuracy : 0.481
epoch:20 
 loss: 0.354	Accuracy : 0.861		val-loss: 0.476	val-Accuracy : 0.784
epoch:30 
 loss: 0.272	Accuracy : 0.886		val-loss: 0.437	val-Accuracy : 0.821
epoch:40 
 loss: 0.247	Accuracy : 0.900		val-loss: 0.412	val-Accuracy : 0.840
epoch:50 
 loss: 0.213	Accuracy : 0.914		val-loss: 0.379	val-Accuracy : 0.862
epoch:60 
 loss: 0.191	Accuracy : 0.920		val-loss: 0.366	val-Accuracy : 0.877
epoch:70 
 loss: 0.177	Accuracy : 0.928		val-loss: 0.377	val-Accuracy : 0.869
epoch:80 
 loss: 0.152	Accuracy : 0.939		val-loss: 0.364	val-Accuracy : 0.881
epoch:90 
 loss: 0.136	Accuracy : 0.947		val-loss: 0.386	val-Accuracy : 0.877
epoch:100 
 loss: 0.121	Accuracy : 0.950		val-loss: 0.375	val-Accuracy : 0.877
epoch:110 
 loss: 0.109	Accuracy : 0.952		val-loss: 0.429	val-Accuracy : 0.873
epoch:120 
 loss: 0.092	Accuracy : 0.966		val-loss: 0.385	val-Accuracy : 0.892
epoch:130 
 loss: 0.076	Accuracy : 0.973		val-loss: 0.364	val

In [14]:
import torchvision as tv
import torchvision.transforms as transforms
import torch
from PIL import Image
import scipy.io as sio
import numpy as np
from PIL import Image
import warnings
import matplotlib.pyplot as plt


warnings.filterwarnings('ignore')
#i = 15
model = BasicCNN()
# print(model)
# model = torch.load("BasicCNN.pkl")
model.load_state_dict(torch.load("BasicCNN_forall_mix.txt"))
# model = model().cuda()
model.eval()
torch.no_grad()
predict = []
# img_ = image.to(device)

for i in range (10):
    image = images[i:i+1]
    outputs = model(torch.tensor(image).to(torch.float32))
    #print(outputs.data)
    _, predicted = torch.max(outputs.data, 1)
    predict.append(predicted.item())
print(predict)
print(len(predict))
print("all the labels of this patient")
print(label_a)
count = 0
for i in range(10):
    if predict[i]==label_a[i]:
        count= count+1
print(count)

import matplotlib.pyplot as plt
%matplotlib inline
image_1 = images[0]*10
plt.imshow((images[0]*10).astype(np.uint8))

RuntimeError: Error(s) in loading state_dict for BasicCNN:
	size mismatch for fc2.weight: copying a param with shape torch.Size([2, 512]) from checkpoint, the shape in current model is torch.Size([4, 512]).
	size mismatch for fc2.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([4]).