In [1]:
import numpy as np
import pandas as pd
import h5py

pd.set_option("display.max_columns", None)

In [2]:
trainset = h5py.File("datasets/train_catvnoncat.h5")
testset = h5py.File("datasets/test_catvnoncat.h5")
#print(trainset.keys())

features = np.array(trainset["train_set_x"])
labels = np.array(trainset["train_set_y"]).reshape(-1, 1)

print(features.shape)
print(labels.shape)


(209, 64, 64, 3)
(209, 1)


In [3]:
from sklearn.model_selection import train_test_split

X_train, X_cv, Y_train, Y_cv = train_test_split(features, labels, test_size=0.3)

print(X_train.shape)
print(Y_train.shape)
print(X_cv.shape)
print(Y_cv.shape)

(146, 64, 64, 3)
(146, 1)
(63, 64, 64, 3)
(63, 1)


In [19]:
import torch
from torch.nn import Module, Linear, BatchNorm1d, BatchNorm2d, Sigmoid, BCELoss, ReLU, Conv2d, MaxPool2d
from torch.optim import SGD
from torchmetrics import Accuracy

class Classifier(Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.bnin = BatchNorm2d(3)
        self.conv1 = Conv2d(3, 32, kernel_size=7, padding=3)
        self.bn1 = BatchNorm2d(32)
        self.act1 = ReLU()
        self.mp1 = MaxPool2d(2)
        
        self.conv2 = Conv2d(32, 96, kernel_size=5, padding=2)
        self.bn2 = BatchNorm2d(96)
        self.act2 = ReLU()
        self.mp2 = MaxPool2d(2)     
        
        self.conv3 = Conv2d(96, 192, kernel_size=3, padding=1)
        self.bn3 = BatchNorm2d(192)
        self.act3 = ReLU()
        self.mp3 = MaxPool2d(2)
        
        self.fc4 = Linear(192*8*8, 200)
        self.bn4 = BatchNorm1d(200)
        self.act4 = ReLU()
        
        self.out = Linear(200, 1)
        self.act5 = Sigmoid()
    
    def forward(self, x):
        x = x.permute(0,3,1,2)
        
        x = self.bnin(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.mp1(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act2(x)
        x = self.mp2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.act3(x)
        x = self.mp3(x)

        x = torch.flatten(x, start_dim=1)
        x = self.fc4(x)
        x = self.bn4(x)
        x = self.act4(x)
        
        x = self.out(x)
        x = self.act5(x)

        return x


    
    def fit(self, X_train, Y_train, loss_fn, opt, X_cv=None, Y_cv=None, epochs=1):
            
        for i in range(epochs):   
            self.train()
            opt.zero_grad()
        
            Y_pred = self(X_train)
            cost = loss_fn(Y_pred, Y_train)
            
            acc = Accuracy().cuda()
            
            msg = f"Iter: {i}    loss: {cost.item(): .4f}    accuracy: {acc(torch.round(Y_pred).int(), Y_train.int()): .4f}    "
            if torch.is_tensor(X_cv) and torch.is_tensor(Y_cv):
                Y_pred_val = self(X_cv)
                cost_val = loss_fn(Y_pred_val, Y_cv)
                msg += f"cv_loss: {cost_val.item(): .4f}    cv_accuracy: {acc(torch.round(Y_pred_val).int(), Y_cv.int()): .4f}"
            
            print(msg)
            
            cost.backward()
        
            opt.step()

In [20]:
input_dim, output_dim = (X_train.shape[1], 1)
LEARNING_RATE = 0.0001
EPOCHS = 10000

model_v1 = Classifier(input_dim, output_dim).cuda()
optimizer = SGD(model_v1.parameters(), lr=LEARNING_RATE, weight_decay=0)
criterion = BCELoss().cuda()

In [None]:
X_train_t = torch.from_numpy(X_train).float().cuda()
Y_train_t = torch.from_numpy(Y_train).float().cuda()
X_cv_t = torch.from_numpy(X_cv).float().cuda()
Y_cv_t = torch.from_numpy(Y_cv).float().cuda()

model_v1.fit(X_train_t, Y_train_t, criterion, optimizer, X_cv=X_cv_t, Y_cv=Y_cv_t, epochs=EPOCHS)

Iter: 0    loss:  0.7247    accuracy:  0.4658    cv_loss:  0.7065    cv_accuracy:  0.5556
Iter: 1    loss:  0.7225    accuracy:  0.4726    cv_loss:  0.7051    cv_accuracy:  0.5397
Iter: 2    loss:  0.7203    accuracy:  0.4726    cv_loss:  0.7038    cv_accuracy:  0.5397
Iter: 3    loss:  0.7181    accuracy:  0.4795    cv_loss:  0.7024    cv_accuracy:  0.5238
Iter: 4    loss:  0.7160    accuracy:  0.4795    cv_loss:  0.7011    cv_accuracy:  0.5238
Iter: 5    loss:  0.7139    accuracy:  0.4932    cv_loss:  0.6999    cv_accuracy:  0.5397
Iter: 6    loss:  0.7119    accuracy:  0.4863    cv_loss:  0.6986    cv_accuracy:  0.5397
Iter: 7    loss:  0.7098    accuracy:  0.4932    cv_loss:  0.6973    cv_accuracy:  0.5397
Iter: 8    loss:  0.7078    accuracy:  0.5000    cv_loss:  0.6961    cv_accuracy:  0.5397
Iter: 9    loss:  0.7059    accuracy:  0.5068    cv_loss:  0.6949    cv_accuracy:  0.5397
Iter: 10    loss:  0.7039    accuracy:  0.5068    cv_loss:  0.6937    cv_accuracy:  0.5397
Iter: 11 

Iter: 91    loss:  0.5948    accuracy:  0.7260    cv_loss:  0.6331    cv_accuracy:  0.6825
Iter: 92    loss:  0.5938    accuracy:  0.7260    cv_loss:  0.6327    cv_accuracy:  0.6825
Iter: 93    loss:  0.5928    accuracy:  0.7329    cv_loss:  0.6322    cv_accuracy:  0.6825
Iter: 94    loss:  0.5919    accuracy:  0.7329    cv_loss:  0.6317    cv_accuracy:  0.6825
Iter: 95    loss:  0.5909    accuracy:  0.7329    cv_loss:  0.6312    cv_accuracy:  0.6825
Iter: 96    loss:  0.5899    accuracy:  0.7329    cv_loss:  0.6307    cv_accuracy:  0.6825
Iter: 97    loss:  0.5890    accuracy:  0.7329    cv_loss:  0.6303    cv_accuracy:  0.6825
Iter: 98    loss:  0.5880    accuracy:  0.7329    cv_loss:  0.6298    cv_accuracy:  0.6825
Iter: 99    loss:  0.5871    accuracy:  0.7466    cv_loss:  0.6293    cv_accuracy:  0.6825
Iter: 100    loss:  0.5861    accuracy:  0.7466    cv_loss:  0.6289    cv_accuracy:  0.6825
Iter: 101    loss:  0.5852    accuracy:  0.7466    cv_loss:  0.6284    cv_accuracy:  0.68

Iter: 181    loss:  0.5214    accuracy:  0.8082    cv_loss:  0.6005    cv_accuracy:  0.6984
Iter: 182    loss:  0.5207    accuracy:  0.8082    cv_loss:  0.6002    cv_accuracy:  0.6984
Iter: 183    loss:  0.5201    accuracy:  0.8082    cv_loss:  0.5999    cv_accuracy:  0.6984
Iter: 184    loss:  0.5194    accuracy:  0.8082    cv_loss:  0.5996    cv_accuracy:  0.6984
Iter: 185    loss:  0.5187    accuracy:  0.8082    cv_loss:  0.5994    cv_accuracy:  0.6984
Iter: 186    loss:  0.5180    accuracy:  0.8082    cv_loss:  0.5991    cv_accuracy:  0.6984
Iter: 187    loss:  0.5174    accuracy:  0.8082    cv_loss:  0.5989    cv_accuracy:  0.6984
Iter: 188    loss:  0.5167    accuracy:  0.8082    cv_loss:  0.5986    cv_accuracy:  0.6984
Iter: 189    loss:  0.5160    accuracy:  0.8082    cv_loss:  0.5983    cv_accuracy:  0.6984
Iter: 190    loss:  0.5154    accuracy:  0.8082    cv_loss:  0.5981    cv_accuracy:  0.6984
Iter: 191    loss:  0.5147    accuracy:  0.8082    cv_loss:  0.5978    cv_accura

Iter: 271    loss:  0.4673    accuracy:  0.8630    cv_loss:  0.5804    cv_accuracy:  0.7143
Iter: 272    loss:  0.4668    accuracy:  0.8630    cv_loss:  0.5802    cv_accuracy:  0.7143
Iter: 273    loss:  0.4663    accuracy:  0.8630    cv_loss:  0.5800    cv_accuracy:  0.7143
Iter: 274    loss:  0.4657    accuracy:  0.8630    cv_loss:  0.5798    cv_accuracy:  0.7143
Iter: 275    loss:  0.4652    accuracy:  0.8630    cv_loss:  0.5796    cv_accuracy:  0.7143
Iter: 276    loss:  0.4647    accuracy:  0.8630    cv_loss:  0.5795    cv_accuracy:  0.7143
Iter: 277    loss:  0.4642    accuracy:  0.8630    cv_loss:  0.5793    cv_accuracy:  0.7143
Iter: 278    loss:  0.4636    accuracy:  0.8630    cv_loss:  0.5791    cv_accuracy:  0.7143
Iter: 279    loss:  0.4631    accuracy:  0.8630    cv_loss:  0.5789    cv_accuracy:  0.7143
Iter: 280    loss:  0.4626    accuracy:  0.8630    cv_loss:  0.5787    cv_accuracy:  0.7143
Iter: 281    loss:  0.4621    accuracy:  0.8630    cv_loss:  0.5786    cv_accura

Iter: 361    loss:  0.4242    accuracy:  0.8973    cv_loss:  0.5660    cv_accuracy:  0.7302
Iter: 362    loss:  0.4238    accuracy:  0.8973    cv_loss:  0.5659    cv_accuracy:  0.7302
Iter: 363    loss:  0.4233    accuracy:  0.8973    cv_loss:  0.5658    cv_accuracy:  0.7302
Iter: 364    loss:  0.4229    accuracy:  0.8973    cv_loss:  0.5656    cv_accuracy:  0.7302
Iter: 365    loss:  0.4225    accuracy:  0.8973    cv_loss:  0.5655    cv_accuracy:  0.7302
Iter: 366    loss:  0.4220    accuracy:  0.8973    cv_loss:  0.5653    cv_accuracy:  0.7302
Iter: 367    loss:  0.4216    accuracy:  0.8973    cv_loss:  0.5652    cv_accuracy:  0.7302
Iter: 368    loss:  0.4212    accuracy:  0.8973    cv_loss:  0.5651    cv_accuracy:  0.7302
Iter: 369    loss:  0.4207    accuracy:  0.8973    cv_loss:  0.5649    cv_accuracy:  0.7302
Iter: 370    loss:  0.4203    accuracy:  0.8973    cv_loss:  0.5648    cv_accuracy:  0.7302
Iter: 371    loss:  0.4199    accuracy:  0.8973    cv_loss:  0.5647    cv_accura