In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from CovTypeDataset import CovTypeDataset, PermutedCovTypeDataset
from torch.utils.data import DataLoader
from utils import test

In [2]:
dataset = CovTypeDataset('./covtype/covtype.data')
permuted_dataset = PermutedCovTypeDataset('./covtype/covtype.data', 6)

In [3]:
class MLP(nn.Module):
    def __init__(self, num_features, num_classes, hidden_size=400):
        super(MLP, self).__init__()
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(num_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, num_classes)

    def forward(self, input):
        x = F.leaky_relu(self.fc1(input))
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x))
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x))
        x = self.dropout(x)
        x = F.leaky_relu(self.fc4(x))
        return x

In [4]:
model = torch.load('model')
model.eval()

model_ewc = torch.load('model_ewc')
model.eval()

MLP(
  (dropout): Dropout(p=0.25, inplace=False)
  (fc1): Linear(in_features=54, out_features=400, bias=True)
  (fc2): Linear(in_features=400, out_features=400, bias=True)
  (fc3): Linear(in_features=400, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=7, bias=True)
)

In [5]:
combined_dataset = torch.utils.data.ConcatDataset([dataset, permuted_dataset])
combined_dataloader = DataLoader(dataset=combined_dataset, batch_size=2000)
acc = test(model, combined_dataloader) * 100
acc_ewc = test(model_ewc, combined_dataloader) * 100

In [6]:
print(f'No EWC Accuracy: {round(acc.item(), 2)}%')
print(f'EWC Accuracy: {round(acc_ewc.item(), 2)}%')

No EWC Accuracy: 67.95%
EWC Accuracy: 82.7%
