In [17]:
import torchray
import torch
import attr
from pathlib import Path
from torchvision import datasets, transforms
from torchray.benchmark import get_example_data, plot_example
from torchvision import models
import torchvision
import matplotlib.pyplot as plt
from sklearn import metrics as sk_metrics
import numpy as np

In [18]:
@attr.s(auto_attribs=True)
class DataPreparation:
    """Build the torchvision dataset and data loader for the ``'test'`` split.

    Attributes:
        data_dir: Root folder handed to ``datasets.ImageFolder``. Any
            path-like value is accepted and normalised to ``pathlib.Path``.
        device: ``cuda:0`` when CUDA is available, otherwise ``cpu``. It is
            not an ``__init__`` argument (``init=False``); the default is
            evaluated once, when the class body is executed.
    """

    # converter=Path makes the attribute match its annotation even when the
    # caller passes a plain string.
    data_dir: Path = attr.ib(converter=Path)
    device: torch.device = attr.ib(
        default=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
        init=False,
    )

    @staticmethod
    def data_transformations():
        """Return the preprocessing pipelines, keyed by split name.

        Returns:
            dict: ``{'test': transforms.Compose}`` that resizes to 224x224,
            converts to a tensor and normalises each channel.
        """
        data_transforms = {
            'test': transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                # Per-channel mean / std; these are the values conventionally
                # paired with ImageNet-pretrained torchvision models.
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        }
        return data_transforms

    def create_dataloaders(self, batch_size, shuffle, num_workers):
        """Create the ``'test'`` data loader over ``self.data_dir``.

        Args:
            batch_size: Forwarded to ``torch.utils.data.DataLoader``.
            shuffle: Forwarded to ``torch.utils.data.DataLoader``.
            num_workers: Forwarded to ``torch.utils.data.DataLoader``.

        Returns:
            tuple: ``(dataloaders, dataset_sizes)``; both are dicts with the
            single key ``'test'`` holding the loader and the number of
            samples in the dataset respectively.
        """
        data_transforms = self.data_transformations()

        image_datasets = {
            'test': datasets.ImageFolder(self.data_dir, data_transforms['test'])
        }
        dataloaders = {
            'test': torch.utils.data.DataLoader(
                image_datasets['test'],
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
            )
        }
        dataset_sizes = {
            'test': len(image_datasets['test'])
        }
        return dataloaders, dataset_sizes

In [19]:
# Point the helper at the test image folder, then build the loader:
# batches of 16, original file order (no shuffling), 4 worker processes.
data_prep = DataPreparation(data_dir='../data/figures/test')
data, size = data_prep.create_dataloaders(
    batch_size=16,
    shuffle=False,
    num_workers=4,
)

In [20]:
# Checkpoint lives in models/ under the parent of the current working directory.
model_path = Path.cwd().parent / "models/resnet50_d_28_t_15_49.pth"

# SECURITY: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
# map_location remaps the stored tensors onto the device chosen for this
# session, so the checkpoint also loads on a machine without the GPU it was
# saved from.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects a fully pickled module such as this one - pass weights_only=False
# there if the file is trusted.
model = torch.load(model_path, map_location=data_prep.device)

# eval() switches layers such as dropout / batch-norm to inference behaviour.
model = model.eval()

In [23]:
# Score the model on the whole test loader.

# Follow whatever device the model's weights are on instead of hard-coding
# CUDA index 0.
model_device = next(model.parameters()).device

pred_batches = []
true_batches = []

# Inference only: without no_grad every forward pass would keep an autograd
# graph alive and waste memory.
with torch.no_grad():
    for inputs, labels in data['test']:
        outputs = model(inputs.to(model_device))
        _, preds = torch.max(outputs, 1)
        # Copy to host memory first: np.concatenate cannot consume CUDA tensors.
        pred_batches.append(preds.cpu().numpy())
        true_batches.append(labels.cpu().numpy())

y_pred = np.concatenate(pred_batches)
y_true = np.concatenate(true_batches)

# scikit-learn metrics take (y_true, y_pred), in that order.
accuracy = sk_metrics.accuracy_score(y_true, y_pred)
# NOTE(review): f1_score defaults to average='binary' and raises on targets
# with more than two classes - pass an explicit `average` if this dataset is
# multiclass.
f1 = sk_metrics.f1_score(y_true, y_pred)

# A cell only displays its last expression, so show both metrics together.
{"accuracy": accuracy, "f1": f1}

KeyboardInterrupt: 