In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from torchvision import datasets, models, transforms
import torch.utils.data as data
import matplotlib.pyplot as plt

import pickle as pkl
import copy
import os

from sklearn.metrics import confusion_matrix
import sys
import utils

In [2]:
all_class = os.listdir("../Images/periodic/all")
print(len(all_class))

remove_classes = ["stone-wall4.o"]
test_classes   = all_class[:5]
all_class = [ i for i in all_class if i not in remove_classes ]

train_target_classes = [ i for i in all_class if i not in test_classes ]
test_target_classes =  [ i for i in all_class if i in test_classes ]

print(len(train_target_classes))
print(len(test_target_classes))

55
49
5


In [3]:
train_packs = utils.get_paths(train_target_classes)
test_packs  = utils.get_paths(test_target_classes)
train_transforms = utils.data_transformer_torch_train()
test_transforms  = utils.data_transformer_torch_test()

In [4]:
datasets_train = utils.Img_Dataset(file_list=train_packs[0],transform=train_transforms,labels=train_packs[1],class_labels=train_packs[2])
datasets_test  = utils.Img_Dataset(file_list=test_packs[0] ,transform=test_transforms,labels=test_packs[1],class_labels=test_packs[2])

dataloader_train = torch.utils.data.DataLoader(datasets_train, batch_size=8, shuffle=True,num_workers=8)
dataloader_test  = torch.utils.data.DataLoader(datasets_test , batch_size=8, shuffle=False,num_workers=8)

dataloaders  = {"train":dataloader_train,"val":dataloader_test }
dataset_sizes ={"train":len(datasets_train),"val":len(datasets_test)}

In [5]:
model = models.vgg16(pretrained=True)
model.classifier = nn.Sequential(
        nn.Linear(25088,100),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(100,1) ,nn.Sigmoid())

In [11]:
#criterion = nn.CrossEntropyLoss()
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

In [14]:
model,loss_dict,acc_dict = utils.training_model(dataloaders,dataset_sizes,model,criterion,
                     optimizer,scheduler,num_epochs=5)

Epoch 1/5
----------
---train---




train Loss: 0.5654 ,ACC:0.6925
---val---
val Loss: 0.5119 ,ACC:0.7300
Epoch 2/5
----------
---train---
train Loss: 0.4630 ,ACC:0.7770
---val---
val Loss: 0.4699 ,ACC:0.7100
Epoch 3/5
----------
---train---
   200/   982

KeyboardInterrupt: 

In [36]:
def test_model(dataloaders,dataset_sizes,model,criterion):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    sum_img = 0
    model.eval()
    
    all_labels = []
    all_preds  = []
    all_clses  = []
    
    running_loss = 0.0
    running_corrects = 0.
    
    phase="val"
    for inputs, labels, cls in dataloaders[phase]:
        
        sum_img += inputs.size(0)
        print("{:6}/{:6}".format(sum_img,dataset_sizes[phase]),end="\r")

        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        preds = model(inputs)
        labels = labels.view_as(preds)
        loss = criterion(preds,labels)
        
        running_loss += loss.item() * inputs.size(0)
        running_corrects  += torch.sum( (preds>0.5) == labels ).item()
        
        all_labels += list(labels.to("cpu").numpy().reshape(-1))
        all_preds  += list(preds.detach().to("cpu").numpy().reshape(-1))
        all_clses  += cls
        
    epoch_loss = running_loss / dataset_sizes[phase]
    epoch_acc  = running_corrects / dataset_sizes[phase]
    print('Loss: {:.4f} ,ACC:{:.4f}'.format(epoch_loss,epoch_acc))
    return epoch_loss,epoch_acc,np.array(all_labels),np.array(all_preds),all_clses

In [9]:
with open("models/log.pkl","wb") as f:
    pkl.dump([loss_dict,acc_dict],f)

torch.save(model.state_dict(),"models/final_model_wts.pt")

In [54]:
loss,acc,labels,preds,classes = test_model(dataloaders,dataset_sizes,model,criterion)
corect_index = labels == (preds > 0.5)
print(sum( corect_index )/100)
print( np.array(classes)[correct_index] )

Loss: 0.6974 ,ACC:0.5400
0.54
['Fabric.0008_s.o' 'Fabric.0008_s.o' 'Fabric.0008_s.o' 'Fabric.0008_s.o'
 'Fabric.0008_s.o' 'Fabric.0008_s.o' 'sawtooth-wiggle.o'
 'sawtooth-wiggle.o' 'sawtooth-wiggle.o' 'sawtooth-wiggle.o'
 'sawtooth-wiggle.o' 'sawtooth-wiggle.o' 'd30_2180.o' 'd30_2180.o'
 'd30_2180.o' 'X100_0628.o' 'X100_0628.o' 'checkerboard.o'
 'checkerboard.o' 'checkerboard.o' 'checkerboard.o' 'Fabric.0008_s.o'
 'Fabric.0008_s.o' 'Fabric.0008_s.o' 'Fabric.0008_s.o' 'Fabric.0008_s.o'
 'Fabric.0008_s.o' 'Fabric.0008_s.o' 'sawtooth-wiggle.o'
 'sawtooth-wiggle.o' 'sawtooth-wiggle.o' 'sawtooth-wiggle.o'
 'sawtooth-wiggle.o' 'sawtooth-wiggle.o' 'd30_2180.o' 'd30_2180.o'
 'd30_2180.o' 'd30_2180.o' 'd30_2180.o' 'd30_2180.o' 'X100_0628.o'
 'X100_0628.o' 'X100_0628.o' 'X100_0628.o' 'X100_0628.o' 'X100_0628.o'
 'X100_0628.o' 'X100_0628.o' 'X100_0628.o' 'X100_0628.o' 'checkerboard.o'
 'checkerboard.o' 'checkerboard.o' 'checkerboard.o']
