In [None]:
import matplotlib.pyplot as plt 
import os
import glob
import shutil
import numpy as np
import cv2
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch
import torchvision
from torch.nn.modules.loss import BCEWithLogitsLoss
from torch.optim import lr_scheduler
from tqdm import tqdm

from IPython.core.pylabtools import figsize


In [None]:
# data_dir = "./data"
# test_dir = "./data"
# train_dir = "./data"

In [None]:
# Root folder of the dataset. The original cell assigned this to `dat_dir`,
# while every line below reads `data_dir`, so a fresh kernel hit a NameError.
data_dir = "/content/data"

# Target layout (one sub-folder per class inside each split):
#   <data_dir>/training/{correct_seq,wrong_seq}
#   <data_dir>/validation/{correct_seq,wrong_seq}
training_dir = os.path.join(data_dir, "training")
correct_seq_training_dir = os.path.join(training_dir, "correct_seq")
wrong_seq_training_dir = os.path.join(training_dir, "wrong_seq")

validation_dir = os.path.join(data_dir, "validation")
correct_seq_validation_dir = os.path.join(validation_dir, "correct_seq")
wrong_seq_validation_dir = os.path.join(validation_dir, "wrong_seq")

# makedirs also creates the missing parents (training_dir / validation_dir),
# and exist_ok=True makes the cell safe to re-run.
for class_dir in (correct_seq_training_dir,
                  wrong_seq_training_dir,
                  correct_seq_validation_dir,
                  wrong_seq_validation_dir):
  os.makedirs(class_dir, exist_ok=True)

In [None]:
split_size = 0.80  # fraction of every class that goes to the training split
split_seed = 42    # fixed seed so the split is the same on every fresh run


def _split_and_move(pattern, training_dst, validation_dst):
  """Move the files matching `pattern` into a training and a validation folder.

  The matches are sorted and then shuffled with a fixed seed, because the
  order returned by glob is arbitrary and would otherwise make the split
  irreproducible. The first `split_size` share of the shuffled list is moved
  to `training_dst`, the rest to `validation_dst`.

  Returns the number of files that matched `pattern`.
  """
  img_paths = sorted(glob.glob(pattern))
  # Local generator: leaves the global numpy random state untouched.
  np.random.RandomState(split_seed).shuffle(img_paths)

  for i, img in enumerate(img_paths):
    if i < (len(img_paths) * split_size):
      shutil.move(img, training_dst)
    else:
      shutil.move(img, validation_dst)
  return len(img_paths)


wrong_seq_imgs_size = _split_and_move("/content/data/train/wrong_seq*",
                                      wrong_seq_training_dir,
                                      wrong_seq_validation_dir)
correct_seq_imgs_size = _split_and_move("/content/data/train/correct_seq*",
                                        correct_seq_training_dir,
                                        correct_seq_validation_dir)

In [None]:
# Pick eight random training images per class (drawn independently, so the
# same file can show up more than once).
samples_correct_seq = []
for _ in range(8):
  picked = np.random.choice(os.listdir(correct_seq_training_dir), 1)[0]
  samples_correct_seq.append(os.path.join(correct_seq_training_dir, picked))

samples_wrong_seq = []
for _ in range(8):
  picked = np.random.choice(os.listdir(wrong_seq_training_dir), 1)[0]
  samples_wrong_seq.append(os.path.join(wrong_seq_training_dir, picked))

nrows = 4
ncols = 4

fig, ax = plt.subplots(nrows, ncols, figsize=(10, 10))
ax = ax.flatten()

# Panels 0-7 show the correct_seq samples, panels 8-15 the wrong_seq samples.
for panel, img_path in zip(ax, samples_correct_seq + samples_wrong_seq):
  pic = plt.imread(img_path)
  panel.imshow(pic)
  panel.set_axis_off()
plt.show()

In [None]:
traindir = "/content/data/training"
testdir = "/content/data/validation"


def _make_preprocessing():
  """Build the resize -> tensor -> normalize pipeline shared by both splits."""
  resize = transforms.Resize((224, 224))
  to_tensor = transforms.ToTensor()
  # Per-channel mean/std; these are the values torchvision documents for its
  # ImageNet-pretrained models.
  normalize = transforms.Normalize(
      mean=[0.485, 0.456, 0.406],
      std=[0.229, 0.224, 0.225],
  )
  return transforms.Compose([resize, to_tensor, normalize])


# Two separate (but identical) pipelines, one per split.
train_transforms = _make_preprocessing()
test_transforms = _make_preprocessing()

# Datasets: ImageFolder reads one class per sub-folder of the given root.
train_data = datasets.ImageFolder(root=traindir, transform=train_transforms)
test_data = datasets.ImageFolder(root=testdir, transform=test_transforms)

# Loaders: both shuffle and serve batches of 16.
trainloader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=16, shuffle=True)

In [None]:
def make_train_step(model, optimizer, loss_fn):
  """Build a closure that performs one optimisation step on a batch.

  Args:
    model: the network to train.
    optimizer: optimizer that updates the trainable parameters.
    loss_fn: callable taking (prediction, target) and returning a scalar loss.

  Returns:
    A function train_step(x, y) that runs forward, backward and the optimizer
    update, and returns the batch loss as a detached tensor.
  """
  def train_step(x, y):
    # Switch to train mode BEFORE the forward pass. The model may have been
    # left in eval mode (e.g. by a validation loop), and the forward pass must
    # see training behaviour for layers such as BatchNorm and Dropout.
    model.train()
    #make prediction
    yhat = model(x)
    #compute loss
    loss = loss_fn(yhat, y)

    loss.backward()
    optimizer.step()
    # Clear the gradients so they do not accumulate into the next step.
    optimizer.zero_grad()

    # Detach so callers that store the loss do not keep autograd state alive.
    return loss.detach()
  return train_step

In [None]:
# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = models.resnet50(pretrained=True)

#freeze all params
# requires_grad_(False) is a method call. The original assigned
# `params.requires_grad_ = False`, which only replaced the method with an
# attribute and left every parameter trainable.
for params in model.parameters():
  params.requires_grad_(False)

#add a new final layer
# The new layer is created after the freeze, so its parameters keep the
# default requires_grad=True and are the only trainable ones.
nr_filters = model.fc.in_features  #number of input features of last layer
model.fc = nn.Linear(nr_filters, 1)  # single logit for the binary task

model = model.to(device)

In [None]:
# Loss: binary cross-entropy computed on raw logits (the sigmoid is applied
# inside the loss, so the model itself outputs logits).
loss_fn = nn.BCEWithLogitsLoss()

# Optimizer: Adam with default settings, given only the parameters of the
# final `fc` layer.
optimizer = torch.optim.Adam(params=model.fc.parameters())

# Closure that runs one optimisation step per call.
train_step = make_train_step(model=model, optimizer=optimizer, loss_fn=loss_fn)

In [None]:
%%capture

# NOTE: the %%capture magic at the top of this cell hides the progress bar and
# the prints below; remove it to watch the training progress.

losses = []       # per-batch training loss (float)
val_losses = []   # per-batch validation loss (float)

epoch_train_losses = []  # mean training loss per epoch
epoch_test_losses = []   # mean validation loss per epoch

n_epochs = 10
early_stopping_tolerance = 3      # consecutive epochs without improvement
early_stopping_threshold = 0.03   # stop once the best val loss is this low

best_loss = float("inf")
best_model_wts = None
# Must live outside the epoch loop: resetting it every epoch (as the original
# did) means it can never reach the tolerance.
early_stopping_counter = 0

for epoch in range(n_epochs):
  # Training mode for the whole training pass (validation below switches the
  # model to eval mode at the end of every epoch).
  model.train()
  epoch_loss = 0.0
  for x_batch, y_batch in tqdm(trainloader, total=len(trainloader)): #iterate over batches
    x_batch = x_batch.to(device) #move to gpu
    #convert target to the same shape/dtype as the network output, then move to gpu
    y_batch = y_batch.unsqueeze(1).float().to(device)

    # .item() turns the loss into a plain float so no tensors pile up in the lists.
    loss = train_step(x_batch, y_batch).item()
    epoch_loss += loss / len(trainloader)
    losses.append(loss)

  epoch_train_losses.append(epoch_loss)
  print('\nEpoch : {}, train loss : {}'.format(epoch+1, epoch_loss))

  #validation doesn't require gradients
  model.eval()
  cum_loss = 0.0
  with torch.no_grad():
    for x_batch, y_batch in testloader:
      x_batch = x_batch.to(device)
      y_batch = y_batch.unsqueeze(1).float().to(device)

      yhat = model(x_batch)
      val_loss = loss_fn(yhat, y_batch).item()
      # Accumulate the VALIDATION loss (the original added the last training
      # loss here, so the reported val loss was meaningless).
      cum_loss += val_loss / len(testloader)
      val_losses.append(val_loss)

  epoch_test_losses.append(cum_loss)
  print('Epoch : {}, val loss : {}'.format(epoch+1, cum_loss))

  if cum_loss <= best_loss:
    #save best model
    best_loss = cum_loss
    # state_dict() returns references to the live tensors; clone them so later
    # epochs cannot overwrite the snapshot.
    best_model_wts = {name: tensor.detach().clone()
                      for name, tensor in model.state_dict().items()}
    early_stopping_counter = 0
  else:
    #early stopping
    early_stopping_counter += 1

  if (early_stopping_counter == early_stopping_tolerance) or (best_loss <= early_stopping_threshold):
    print("\nTerminating: early stopping")
    break #terminate training

#load best model
# best_model_wts stays None only if no epoch produced a comparable loss
# (e.g. n_epochs == 0 or the loss was NaN).
if best_model_wts is not None:
  model.load_state_dict(best_model_wts)

In [None]:
def inference(test_data):
  """Classify one randomly chosen sample of `test_data` and display it.

  Uses the notebook-level `model` and `device`. Prints the predicted class
  and draws the sample with matplotlib; returns nothing.

  Args:
    test_data: indexable dataset whose items are (image, label) pairs.
      Assumes the image is a (C, H, W) tensor, as produced by the ToTensor
      transform used in this notebook - TODO confirm for other datasets.
  """
  # low=0 so the first sample can be drawn too (the original started at 1 and
  # failed on a one-item dataset); .item() yields a plain int index.
  idx = torch.randint(0, len(test_data), (1,)).item()
  # Load the sample once so the image shown is exactly the one classified.
  image = test_data[idx][0]
  sample = torch.unsqueeze(image, dim=0).to(device)  # add the batch dimension

  # Inference: eval mode and no gradient tracking.
  model.eval()
  with torch.no_grad():
    prob = torch.sigmoid(model(sample))

  # Assumes class index 0 is correct_seq and 1 is wrong_seq, i.e. the
  # alphabetical class ordering ImageFolder assigns - verify via
  # test_data.class_to_idx.
  if prob < 0.5:
    print("Prediction : correct_seq")
  else:
    print("Prediction : wrong_seq")

  # NOTE(review): if the dataset normalizes its images, the picture is shown
  # in normalized colour space (imshow clips out-of-range values).
  plt.imshow(image.permute(1, 2, 0))

In [None]:
# TODO: iterate over the validation images and collect the predictions
# TODO: plot a confusion matrix