In [110]:
from torch import manual_seed, tensor, nonzero, logical_not, load, save
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split, IterableDataset, Dataset, sampler
from torchvision.transforms import Compose, ToTensor, Normalize, RandomRotation
from torchvision.datasets.mnist import MNIST 
from torchshow import show
from CNN_setup.model.CIFAR_CNN import CIFAR_CNN_Classifier
from CNN_setup.vars.CIFARvars import CIFAR10_classes
from CNN_setup.datasets.datasets import CustomCIFAR10, CustomMNIST

from CNN_setup.model.MNIST_CNN import Mnist_CNN_Classifier
from CNN_setup.vars.MNISTvars import MNIST_classes

from CNN_setup.utils.cnn_models_utils import load_model, evaluate

from PIL import Image
from torchshow import show

from CNN_setup.datasets.dataset_tools import save_dataset

from torchvision.datasets import ImageFolder

# Incremental, abrupt, and transformations

In [2]:
transform = Compose([ToTensor()])

In [3]:
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = CIFAR10(root='./data', train=False, download=False, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=32)

Files already downloaded and verified


### Withhold a class

In [95]:
raw_data = CIFAR10(root='./data', train=False, download=True)

Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ./data
    Split: Test

In [None]:
save_dataset(path='data/transformed/cifar-rotated90', data=raw_data, transform=Image.Image.rotate, args=(90,))

In [None]:
save_dataset(path='data/transformed/cifar-rotated90', data=raw_data, transform=Image.Image.rotate, args=(90,))

In [115]:
rotated = ImageFolder(root='data/transformed/cifar-rotated90', transform=ToTensor())
dataloader_rotated = DataLoader(dataset=rotated, batch_size = 32)

In [116]:
model = load_model('trained_models\CNN_cifar_downloaded.torch', CIFAR_CNN_Classifier())

In [118]:
evaluate(model=model, test_dataloader=dataloader_rotated, classes=CIFAR10_classes)

Accuracy for class: plane is 40.0 %
Accuracy for class: car is 1.1 %
Accuracy for class: bird is 10.5 %
Accuracy for class: cat is 9.3 %
Accuracy for class: deer is 68.6 %
Accuracy for class: dog is 3.0 %
Accuracy for class: frog is 12.0 %
Accuracy for class: horse is 0.9 %
Accuracy for class: ship is 19.2 %
Accuracy for class: truck is 11.0 %
Total Accuracy: 17.6 %


### Check MNIST rotate

In [None]:
transform_rotate = Compose([ToTensor(), RandomRotation(degrees=90)])

In [None]:
CustomerMnist_obj = CustomMNIST(root='./data', train=True, download=True, transform=transform_rotate)
CustomerMnist_obj_dataloader = DataLoader(CustomerMnist_obj, batch_size=32)

MNIST_model = load_model("trained_models/CNN_mnist_downloaded.torch",Mnist_CNN_Classifier())

In [None]:
evaluate(CustomerMnist_obj_dataloader,MNIST_model,classes = MNIST_classes)

Accuracy for class: 0 is 94.5 %
Accuracy for class: 1 is 96.9 %
Accuracy for class: 2 is 89.7 %
Accuracy for class: 3 is 85.2 %
Accuracy for class: 4 is 91.1 %
Accuracy for class: 5 is 90.6 %
Accuracy for class: 6 is 93.0 %
Accuracy for class: 7 is 90.5 %
Accuracy for class: 8 is 89.6 %
Accuracy for class: 9 is 86.9 %
Total Accuracy: 90.9 %


In [None]:
evaluate(CustomerMnist_obj_dataloader,MNIST_model,classes = MNIST_classes)

Accuracy for class: 0 is 94.5 %
Accuracy for class: 1 is 96.9 %
Accuracy for class: 2 is 89.7 %
Accuracy for class: 3 is 85.2 %
Accuracy for class: 4 is 91.1 %
Accuracy for class: 5 is 90.6 %
Accuracy for class: 6 is 93.0 %
Accuracy for class: 7 is 90.5 %
Accuracy for class: 8 is 89.6 %
Accuracy for class: 9 is 86.9 %
Total Accuracy: 90.9 %


## Incremental

In [1]:
from torch.utils.data import ConcatDataset

def combine_datasets(dataset1, dataset2, threshold1, threshold2):
    # Take the first 'threshold' samples from each dataset
    data1 = [dataset1[i] for i in range(threshold1)]
    data2 = [dataset2[i] for i in range(threshold2)]
    
    # Combine the two datasets
    combined_data = data1 + data2
    
    return combined_data