In [25]:
# Run configuration for the error analysis below.
split = 'test'                        # dataset split whose errors are inspected
cfg_path = 'configs/cfg_foccsd.yaml'  # YAML config consumed by the next cells
# NOTE(review): `output_folder` is not referenced by any visible cell — confirm it is still needed.
output_folder = '/datadrive/animals_training_dataset/predictions/clean_bs128_w8_cl_wts'

In [26]:
import yaml
import os
from PIL import Image
import matplotlib.pyplot as plt
import torch

from util import init_seed
from train import create_dataloader, load_model 

In [27]:
# Load the experiment config; close the file handle once parsed.
with open(cfg_path, 'r') as cfg_file:
    cfg = yaml.safe_load(cfg_file)
# `model_dir` is rewritten relative to this notebook's location
# (presumably two levels below the repo root — TODO confirm).
cfg['model_dir'] = os.path.join('..', '..', cfg['model_dir'])
init_seed(cfg.get('seed', None))

device = cfg['device']
if device != 'cpu' and not torch.cuda.is_available():
    print(f'WARNING: device set to "{device}" but CUDA not available; falling back to CPU...')
    # Update BOTH the local (used by `.to(device)` below) and the config
    # (passed to the project helpers); updating only cfg left `device`
    # pointing at CUDA and made the fallback crash.
    device = 'cpu'
    cfg['device'] = device

dataLoader = create_dataloader(cfg, split=split)
# Invert species -> index into index -> species for human-readable labels.
classnames = {
    index: species
    for species, index in dataLoader.dataset.species_to_index_mapping.items()
}
model, epoch = load_model(cfg)
model.to(device)
model.eval()

# At most this many misclassified images are *plotted* per batch (None = all).
# Every misclassification is still recorded in the err_* lists.
MAX_PLOTS_PER_BATCH = 1
# Set to True to also write each plotted figure under figs/errors_foccsd/<predicted class>/.
SAVE_FIGS = False

err_name = []
err_pred = []
err_orig = []
with torch.no_grad():
    for data, label, image_path in dataLoader:
        data = data.to(device)
        prediction = model(data)
        predict_label = torch.argmax(prediction.cpu(), dim=1)

        # Positions within this batch where the prediction disagrees with the label.
        error_indices = torch.nonzero(predict_label != label, as_tuple=True)[0].tolist()

        for n_plotted, err in enumerate(error_indices):
            path = image_path[err]
            fname = os.path.basename(path)
            pred = classnames[predict_label[err].item()]
            orig = classnames[label[err].item()]

            err_name.append(fname)
            err_pred.append(pred)
            err_orig.append(orig)

            if MAX_PLOTS_PER_BATCH is not None and n_plotted >= MAX_PLOTS_PER_BATCH:
                continue

            fig, ax = plt.subplots()
            # Context manager closes the underlying file; imshow copies the pixels eagerly.
            with Image.open(path) as img:
                ax.imshow(img)
            ax.set_title(f'pred: {pred}; actual: {orig}; name: {fname}')
            if SAVE_FIGS:
                # Save before show(): inline backends may discard the figure after display.
                dest = f'figs/errors_foccsd/{pred}'
                os.makedirs(dest, exist_ok=True)
                fig.savefig(os.path.join(dest, fname))
            plt.show()
            plt.close(fig)  # avoid accumulating open figures across batches
            print(f'{orig}/{fname}')
        

    



Resuming from epoch 200


In [None]:
import pandas as pd

# One row per misclassified image collected by the evaluation cell above.
df = pd.DataFrame({'predicted': err_pred, 'original': err_orig, 'name': err_name})
# Fixed filename typo ('erros'), add a .csv extension, and drop the meaningless
# integer index column from the output.
df.to_csv('errors_foccsd.csv', index=False)
df.head()