In [1]:
import argparse
import torch
import torch.nn as nn
import os
import resnet
import numpy as np
import cv2
import matplotlib.pyplot as plt

from torch.autograd import Variable
from torchvision import transforms
from PIL import Image

# Maps a 1-based class index (as a string) to a human-readable label.
# inference() looks labels up with str(argmax + 1), and the model cell sizes the
# classifier head with len(TRASH_DICT), so the number of entries must equal the
# checkpoint's number of output classes.
# NOTE(review): 'metal' appears twice ('2' and '5'), so two different model outputs
# are reported as the same class. TrashNet-style datasets also have a 'cardboard'
# class, which is absent here. Confirm the label order used when the checkpoint was
# trained and fix the duplicate.
TRASH_DICT = {
'1' : 'glass',
'2' : 'metal',
'3' : 'paper',
'4' : 'plastic',
'5' : 'metal',
'6' : 'trash'
}

In [2]:
def variance_of_laplacian(image):
    """Return the variance of the Laplacian of `image`.

    The Laplacian is computed with a 64-bit float output depth, so negative
    responses are kept rather than saturated. The loop below uses the returned
    value as a blur/sharpness score: larger values mean more high-frequency detail.
    """
    laplacian_response = cv2.Laplacian(image, cv2.CV_64F)
    return laplacian_response.var()


In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Build a ResNet-18 (from the local `resnet` module) with one output per TRASH_DICT
# entry, wrapped in DataParallel.
# NOTE(review): the wrapper presumably matches how the checkpoint was saved (state-dict
# keys prefixed with "module."). Confirm before removing it.
model = nn.DataParallel(resnet.resnet18(pretrained=True, num_classes=len(TRASH_DICT)))
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load('save/model_best.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']

model.load_state_dict(state_dict)
# load_state_dict copies values into the existing parameters without moving them, so
# put the model on `device` explicitly. On a CUDA machine, DataParallel.forward raises
# unless the wrapped module's parameters live on cuda:0 (device_ids[0]).
model.to(device)
# Inference mode (e.g. BatchNorm uses its running statistics). As the cell's last
# expression, this also displays the module tree.
model.eval()

pre-trained model loaded successfully


DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [5]:
def inference(save_dir):
    """Classify one image with the global `model` and return its top label.

    Args:
        save_dir: despite the name, the image to classify. Callers in this notebook
            pass an RGB ``PIL.Image``. The parameter name is kept for compatibility.

    Returns:
        ``(pred_class, confidence)``: the ``TRASH_DICT`` label of the highest-scoring
        class and its softmax probability (a NumPy float).
    """
    frame = save_dir
    # Normalisation statistics (the usual ImageNet mean/std). These are presumably the
    # values the checkpoint was trained with; confirm against the training script.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    # NOTE(review): there is no Resize before CenterCrop, so large crops are cut down to
    # their central 224x224 patch. Confirm this matches training-time preprocessing.
    img_transforms = transforms.Compose([transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=MEAN, std=STD)])

    # (C, H, W) -> (1, C, H, W): the model expects a batch.
    image_tensor = img_transforms(frame).float().unsqueeze(0)
    # Tensor.to() returns a new tensor rather than moving in place, so assign the result.
    # Follow the model's own device so this works whether or not the model was moved.
    model_device = next(model.parameters()).device
    image_tensor = image_tensor.to(model_device)

    # No autograd graph is needed for prediction.
    with torch.no_grad():
        output = model(image_tensor)

    # Accept either a plain logits tensor of shape (1, num_classes) or a tuple/list whose
    # first element is that tensor. Indexing output[0] and then taking softmax over
    # dim=1 fails for a plain tensor, because output[0] is 1-D.
    logits = output[0] if isinstance(output, (tuple, list)) else output
    # Softmax over the class dimension, keep the single sample, and bring it to host
    # memory before converting to NumPy (.numpy() fails on CUDA tensors).
    pred = torch.softmax(logits, dim=-1)[0].cpu().numpy()
    # Model outputs are 0-based; TRASH_DICT keys are 1-based strings.
    trash_idx = str(pred.argmax() + 1)
    pred_class, confidence = TRASH_DICT[trash_idx], pred.max()

    return pred_class, confidence


In [6]:
# Video source: a file path, or an integer camera index (e.g. 0 for the default webcam).
VIDEO_SOURCE = 'test1.mp4'
OUTPUT_DIR = 'output'               # where annotated crops are written

PROCESS_EVERY_N_FRAMES = 5          # run detection only on every 5th frame
BACKGROUND_FRAME_IDX = 10           # frame whose ROI replaces the initial background
MIN_CONTOUR_AREA = 9000             # blobs with area <= this (px^2) are ignored
CROP_MARGIN = 2                     # px trimmed from each side of a bounding box
SHARPNESS_THRESHOLD = 500           # variance-of-Laplacian above which a crop is classified
MORPH_KERNEL = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21, 21))
WAIT_MS = 30                        # per-frame cv2.waitKey delay in milliseconds


def _crop_to_roi(image, roi):
    """Return the ``(x, y, w, h)`` region ``roi`` of ``image`` as a NumPy view (not a copy)."""
    x, y, w, h = (int(v) for v in roi)
    return image[y:y + h, x:x + w]


# cv2.imwrite returns False (no exception) when the target directory does not exist.
os.makedirs(OUTPUT_DIR, exist_ok=True)

capture = cv2.VideoCapture(VIDEO_SOURCE)

# MOG2 background subtractor. Before each processed ROI, the loop feeds it the
# background patch `sfondo`, so the mask presumably highlights what differs from it.
backSub = cv2.createBackgroundSubtractorMOG2(varThreshold=5, detectShadows=False)
i = 0        # index of the current frame
fm = 101     # sharpness of the most recent crop; starts below SHARPNESS_THRESHOLD
try:
    while True:
        ret, frame = capture.read()
        if not ret or frame is None:
            break

        if i == 0:
            roi = cv2.selectROI(frame)
            if roi[2] == 0 or roi[3] == 0:
                raise ValueError('No region of interest selected (selection cancelled or empty).')
            # Initial background = the ROI pixels of the first frame. Assigning `roi` itself
            # would feed the (x, y, w, h) tuple to the subtractor instead of an image.
            sfondo = _crop_to_roi(frame, roi).copy()

        if i == BACKGROUND_FRAME_IDX:
            # Copy so the boxes drawn on this frame below do not end up in the background.
            sfondo = _crop_to_roi(frame, roi).copy()

        if i % PROCESS_EVERY_N_FRAMES == 0:
            _ = backSub.apply(sfondo)
            # `frame_roi` is a view into `frame`: boxes drawn on it show up in the 'img' window.
            frame_roi = _crop_to_roi(frame, roi)
            # Unannotated copy used for classification, saving and as the next background.
            frame_roi_clean = frame_roi.copy()
            mask = backSub.apply(frame_roi)

            # Morphological opening (drop specks) and closing (fill holes) to clean the mask
            mask_morph = cv2.morphologyEx(mask, cv2.MORPH_OPEN, MORPH_KERNEL)
            mask_morph = cv2.morphologyEx(mask_morph, cv2.MORPH_CLOSE, MORPH_KERNEL)

            # Foreground pixels of the ROI (kept for inspection; not displayed below)
            output = cv2.bitwise_and(frame_roi, frame_roi, None, mask_morph)

            contours, _ = cv2.findContours(mask_morph, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            for cnt in contours:
                # Remove small elements
                if cv2.contourArea(cnt) <= MIN_CONTOUR_AREA:
                    continue

                x, y, w, h = cv2.boundingRect(cnt)
                cv2.rectangle(frame_roi, (x, y), (x + w, y + h), (0, 255, 0), 2)
                # Skip boxes that start on the ROI's left or top edge.
                # NOTE(review): presumably meant to drop objects cut off by the ROI border;
                # the right and bottom edges are not checked. Confirm.
                if x == 0 or y == 0:
                    continue

                # Own copy, so the text drawn below stays out of frame_roi_clean.
                cropped_image = frame_roi_clean[y + CROP_MARGIN:y + h - CROP_MARGIN,
                                                x + CROP_MARGIN:x + w - CROP_MARGIN].copy()
                cv2.imshow("last cropped", cropped_image)

                # Blur check: only classify crops that are sharp enough.
                gray = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2GRAY)
                fm = variance_of_laplacian(gray)
                if fm <= SHARPNESS_THRESHOLD:
                    continue

                # OpenCV frames are BGR; the classifier gets an RGB PIL image.
                img = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB)
                im_pil = Image.fromarray(img)

                # Classification
                pred_class, confidence = inference(im_pil)
                print(f'Prediction: {pred_class}, Confidence: {confidence}')

                # Save the crop, annotated with its sharpness score.
                # NOTE(review): two crops from the same frame with the same predicted class
                # share a file name, so the later one overwrites the earlier.
                save_dir = os.path.join(OUTPUT_DIR, "Garbage_" + str(i) + "_" + pred_class + ".jpg")
                cv2.putText(cropped_image, "{}: {:.2f}".format("fm ", fm), (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 3)
                if not cv2.imwrite(save_dir, cropped_image):
                    print(f'Could not write {save_dir}')

            # Visualization
            cv2.imshow('img', frame)
            cv2.imshow('output', mask_morph)
            # Refresh the background with this ROI when the last measured crop was sharp.
            # NOTE(review): `fm` may come from an earlier frame if no crop was measured in
            # this one. Confirm that is intended.
            if fm > SHARPNESS_THRESHOLD:
                sfondo = frame_roi_clean
        i += 1

        keyboard = cv2.waitKey(WAIT_MS) & 0xFF      # waitKey returns an int key code (-1 if none)
        if keyboard == ord('q') or keyboard == 27:  # 'q' or Esc
            break
finally:
    # Release the video source and close windows even if the loop raised.
    capture.release()
    cv2.destroyAllWindows()