In [None]:
import cocpit

import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import cv2
from natsort import natsorted
from PIL import Image

import torch
from torchvision import transforms
from torch.utils.data import Dataset

%load_ext autoreload
%autoreload 2

In [None]:
# Free cached GPU memory left over from earlier runs; skipped when CUDA is unavailable.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Large serif fonts so labels stay legible in the classification figures below.
plt.rcParams["font.family"] = "serif"
plt_params = {
    'axes.labelsize': 'xx-large',
    'axes.titlesize': 'xx-large',
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'xx-large',
}
for rc_key, rc_value in plt_params.items():
    plt.rcParams[rc_key] = rc_value

### Check classifications from a DataFrame or database file

In [None]:
# Visually review the stored classification of every non-blurry image in one
# campaign's CSV; each figure is titled with its label.
campaign = 'MACPEX'
df = pd.read_csv('final_databases_v2/no_mask/' + campaign + '.csv')
desired_size = 1000

# Blurry images are never displayed, so drop them before loading any pixels.
to_show = df[df['classification'] != 'blurry']
for filename, class_ in zip(to_show['filename'], to_show['classification']):
    image = cocpit.pic.Image('cpi_data/campaigns/' + campaign + '/single_imgs/', filename)
    image.resize_stretch(desired_size)
    fig, ax = plt.subplots()
    ax.imshow(image.image_og)
    ax.set_title(class_)
    plt.show()
    # Release the figure explicitly; this loop can produce many plots.
    plt.close(fig)


### Check classifications from a specific model and validation DataLoader

In [None]:
class_names = ['aggs', 'blank', 'blurry', 'budding',
               'bullets', 'columns', 'compact irregulars',
               'fragments', 'needles', 'plates', 'rimed aggs',
               'rimed columns', 'spheres']

# Choose the device first so the model is loaded straight onto it; falls back to
# CPU when no GPU is present (a hard-coded .cuda() would crash there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.load unpickles whole objects (a model and a Dataset): only load trusted files.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which may reject pickled full models — confirm against the installed version.
model = torch.load('/data/data/saved_models/no_mask/e20_bs128_k0_1models_vgg19',
                   map_location=device).to(device)
val_data = torch.load('/data/data/saved_models/no_mask/val_data_vgg19_e20_b128.pt')

val_loader = torch.utils.data.DataLoader(val_data,
                                         batch_size=128,
                                         shuffle=True,
                                         num_workers=20,
                                         pin_memory=True)


In [None]:
model.eval()
# Show only validation images whose first predicted class (presumably the top-1
# prediction from `predict`) disagrees with the stored label.
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for (_, labels, paths), _ in val_loader:
        for path, label in zip(paths, labels):
            probs, classes = cocpit.check_classifications.predict(path, device, model)
            crystal_names = [class_names[e] for e in classes]
            true_name = class_names[int(label)]
            if crystal_names[0] != true_name:
                print('labeled as: ', true_name)
                cocpit.check_classifications.view_classify(path, probs, crystal_names)

In [None]:
# VGG19, NO OVERFIT run, 10 epochs
class_names = ['aggs', 'blank', 'blurry', 'budding',
               'bullets', 'columns', 'compact irregulars',
               'fragments', 'needles', 'plates', 'rimed aggs',
               'rimed columns', 'spheres']

# Choose the device first so the model is loaded straight onto it; falls back to
# CPU when no GPU is present (a hard-coded .cuda() would crash there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.load unpickles whole objects (a model and a Dataset): only load trusted files.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which may reject pickled full models — confirm against the installed version.
model = torch.load('/data/data/saved_models/no_mask/e10_bs128_k0_1models_vgg19',
                   map_location=device).to(device)
val_data = torch.load('/data/data/saved_models/no_mask/val_data_no_overfit.pt')

val_loader = torch.utils.data.DataLoader(val_data,
                                         batch_size=128,
                                         shuffle=True,
                                         num_workers=20,
                                         pin_memory=True)


In [None]:
model.eval()
# Plot the model's ranked predictions for every validation image, batch by batch.
for (_, _, batch_paths), _ in val_loader:
    for img_path in batch_paths:
        probs, classes = cocpit.check_classifications.predict(img_path, device, model)
        crystal_names = list(map(class_names.__getitem__, classes))
        cocpit.check_classifications.view_classify(img_path, probs, crystal_names)

# Predict on new data (test data)

In [None]:
# Dataset over a raw image folder (i.e., single_imgs)
class TestDataSet(Dataset):
    """Unlabeled images read from a single directory, in natural-sort order.

    Each item is ``(transform(rgb_image), file_name)``, where ``file_name`` is
    relative to ``main_dir``.
    """

    def __init__(self, main_dir, transform):
        """
        Args:
            main_dir: directory whose entries are the images to load.
            transform: callable applied to each RGB PIL image.
        """
        self.main_dir = main_dir
        self.transform = transform
        self.total_imgs = natsorted(os.listdir(main_dir))

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        file_name = self.total_imgs[idx]
        img_loc = os.path.join(self.main_dir, file_name)
        # Close the file handle promptly, and force 3 channels so grayscale or
        # RGBA images match the 3-channel Normalize in the transform pipeline.
        with Image.open(img_loc) as img:
            image = img.convert('RGB')
        tensor_image = self.transform(image)
        return tensor_image, file_name

# Load the model onto the GPU when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.load unpickles a whole model object: only load trusted files.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which may reject pickled full models — confirm against the installed version.
model = torch.load('/data/data/saved_models/no_mask/e50_bs128_k0_8models_vgg19',
                   map_location=device).to(device)
model.eval()

campaign = 'ARM'
data_dir = '/data/data/cpi_data/campaigns/' + campaign + '/single_imgs/'

# Apply the same preprocessing the model was trained with. A fixed square size is
# required so images of different aspect ratios can be collated into one batch.
# NOTE(review): confirm training also resized to (224, 224).
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

testdata = TestDataSet(data_dir, transform=test_transforms)
# Keep the final partial batch: every image in the folder should get a prediction.
test_loader = torch.utils.data.DataLoader(testdata, batch_size=100, shuffle=False,
                                          num_workers=20, drop_last=False)

In [None]:
model.eval()
# Plot ranked predictions for every unlabeled test image. TestDataSet yields bare
# file names, so they are joined with data_dir to get a loadable path.
with torch.no_grad():  # inference only
    for _, img_names in test_loader:
        for name in img_names:
            path = os.path.join(data_dir, name)
            probs, classes = cocpit.check_classifications.predict(path, device, model)
            crystal_names = [class_names[e] for e in classes]
            cocpit.check_classifications.view_classify(path, probs, crystal_names)

In [None]:
model.eval()
# Send inputs to whichever device holds the model's weights.
model_device = next(model.parameters()).device
with torch.no_grad():  # inference only: no autograd graph
    for _, img_paths in test_loader:
        for im in img_paths:
            path = os.path.join(data_dir, im)
            # Copy so the pixels stay available after the file handle is closed.
            with Image.open(path) as img_file:
                img_og = img_file.copy()
            img = cocpit.check_classifications.process_image(img_og.convert('RGB'))

            # Prepend a batch axis of size 1. The .float() cast guards against
            # process_image returning float64 (weights presumably float32).
            batch = torch.from_numpy(np.expand_dims(img, 0)).float().to(model_device)
            prediction = model(batch)
            result = prediction.cpu().numpy()
            predicted_name = class_names[result.argmax()]

            fig, ax = plt.subplots(figsize=(5, 5))
            ax.imshow(img_og)
            ax.set_title(predicted_name)
            ax.axis('off')
            plt.show()
            plt.close(fig)