In [1]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [2]:
!pwd
!ls
%cd drive
%cd MyDrive

/content
drive  sample_data
/content/drive
/content/drive/MyDrive


In [3]:
!git clone https://github.com/mtoneva/example_forgetting.git

Cloning into 'example_forgetting'...
remote: Enumerating objects: 194, done.[K
remote: Total 194 (delta 0), reused 0 (delta 0), pack-reused 194[K
Receiving objects: 100% (194/194), 566.18 KiB | 4.46 MiB/s, done.
Resolving deltas: 100% (102/102), done.


In [4]:
%cd example_forgetting/

/content/drive/My Drive/example_forgetting


In [5]:
!pip install -r requirements.txt


Collecting torch==0.4.1.post2
  Downloading torch-0.4.1.post2-cp37-cp37m-manylinux1_x86_64.whl (519.5 MB)
[K     |████████████████████████████████| 519.5 MB 21 kB/s 
[?25hCollecting torchvision==0.1.8
  Downloading torchvision-0.1.8-py2.py3-none-any.whl (37 kB)
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 1.9.0+cu102
    Uninstalling torch-1.9.0+cu102:
      Successfully uninstalled torch-1.9.0+cu102
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.10.0+cu102
    Uninstalling torchvision-0.10.0+cu102:
      Successfully uninstalled torchvision-0.10.0+cu102
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchtext 0.10.0 requires torch==1.9.0, but you have torch 0.4.1.post2 which is incompatible.
fastai 1.0.61 requires torch>=1.0.0, but you have tor

In [6]:
from __future__ import print_function
import argparse
import numpy as np
import numpy.random as npr
import time
import os
import sys
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms

In [7]:
# Format time for printing purposes
def get_hms(seconds):
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)

    return h, m, s


In [8]:
# Setup basic CNN model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))

        if args['no_dropout']:
            x = F.relu(F.max_pool2d(self.conv2(x), 2))
        else:
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))

        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))

        if not args['no_dropout']:
            x = F.dropout(x, training=self.training)

        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


In [9]:
# Train model for one epoch
#
# example_stats: dictionary containing statistics accumulated over every presentation of example
#
def train(args, model, device, trainset, optimizer, epoch, example_stats):
    train_loss = 0
    correct = 0
    total = 0
    batch_size = args['batch_size']

    model.train()

    # Get permutation to shuffle trainset
    trainset_permutation_inds = npr.permutation(
        np.arange(len(trainset.train_labels)))

    for batch_idx, batch_start_ind in enumerate(
            range(0, len(trainset.train_labels), batch_size)):

        # Get trainset indices for batch
        batch_inds = trainset_permutation_inds[batch_start_ind:
                                               batch_start_ind + batch_size]

        # Get batch inputs and targets, transform them appropriately
        transformed_trainset = []
        for ind in batch_inds:
            transformed_trainset.append(trainset.__getitem__(ind)[0])
        inputs = torch.stack(transformed_trainset)
        targets = torch.LongTensor(
            np.array(trainset.train_labels)[batch_inds].tolist())

        # Map to available device
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward propagation, compute loss, get predictions
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        _, predicted = torch.max(outputs.data, 1)

        # Update statistics and loss
        acc = predicted == targets
        for j, index in enumerate(batch_inds):

            # Get index in original dataset (not sorted by forgetting)
            index_in_original_dataset = train_indx[index]

            # Compute missclassification margin
            output_correct_class = outputs.data[
                j, targets[j].item()]  # output for correct class
            sorted_output, _ = torch.sort(outputs.data[j, :])
            if acc[j]:
                # Example classified correctly, highest incorrect class is 2nd largest output
                output_highest_incorrect_class = sorted_output[-2]
            else:
                # Example misclassified, highest incorrect class is max output
                output_highest_incorrect_class = sorted_output[-1]
            margin = output_correct_class.item(
            ) - output_highest_incorrect_class.item()

            # Add the statistics of the current training example to dictionary
            index_stats = example_stats.get(index_in_original_dataset,
                                            [[], [], []])
            index_stats[0].append(loss[j].item())
            index_stats[1].append(acc[j].sum().item())
            index_stats[2].append(margin)
            example_stats[index_in_original_dataset] = index_stats

        # Update loss, backward propagate, update optimizer
        loss = loss.mean()
        train_loss += loss.item()
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()
        loss.backward()
        optimizer.step()

        sys.stdout.write('\r')
        sys.stdout.write(
            '| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%' %
            (epoch, args['epochs'], batch_idx + 1,
             (len(trainset) // batch_size) + 1, loss.item(),
             100. * correct.item() / total))
        sys.stdout.flush()

        # Add training accuracy to dict
        index_stats = example_stats.get('train', [[], []])
        index_stats[1].append(100. * correct.item() / float(total))
        example_stats['train'] = index_stats



In [10]:
# Evaluate model predictions on heldout test data
#
# example_stats: dictionary containing statistics accumulated over every presentation of example
#
def test(args, model, device, testset, example_stats):
    test_loss = 0
    correct = 0
    total = 0
    test_batch_size = 32

    model.eval()

    for batch_idx, batch_start_ind in enumerate(
            range(0, len(testset.test_labels), test_batch_size)):

        # Get batch inputs and targets
        transformed_testset = []
        for ind in range(
                batch_start_ind,
                min(
                    len(testset.test_labels),
                    batch_start_ind + test_batch_size)):
            transformed_testset.append(testset.__getitem__(ind)[0])
        inputs = torch.stack(transformed_testset)
        targets = torch.LongTensor(
            np.array(testset.test_labels)[batch_start_ind:batch_start_ind +
                                          test_batch_size].tolist())

        # Map to available device
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward propagation, compute loss, get predictions
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss = loss.mean()
        test_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

    # Add test accuracy to dict
    acc = 100. * correct.item() / total
    index_stats = example_stats.get('test', [[], []])
    index_stats[1].append(100. * correct.item() / float(total))
    example_stats['test'] = index_stats
    print("\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" %
          (epoch, loss.item(), acc))



In [11]:
args = {'dataset': 'permuted_mnist',
        'batch_size': 64,
        'epochs':200,
        'lr':0.01,
        'momentum':0.5,
        'no_cuda':False,
        'seed':2,
        'sorting_file':"none",
        'remove_n':0,
        'keep_lowest_n':0,
        'no_dropout':False,
        'input_dir':'permuted_mnist_results/',
        'output_dir':'permuted_mnist_results/'

        }


# Set appropriate devices
use_cuda = not args['no_cuda'] and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

# Set random seed for initialization
torch.manual_seed(args['seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(args['seed'])
npr.seed(args['seed'])

# Setup transforms
all_transforms = [
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081, ))
]
if args['dataset'] == 'permuted_mnist':
    pixel_permutation = torch.randperm(28 * 28)
    all_transforms.append(
        transforms.Lambda(
            lambda x: x.view(-1, 1)[pixel_permutation].view(1, 28, 28)))
transform = transforms.Compose(all_transforms)

os.makedirs(args['output_dir'], exist_ok=True)

# Load the appropriate train and test datasets
trainset = datasets.MNIST(
    root='/tmp/data', train=True, download=True, transform=transform)
testset = datasets.MNIST(
    root='/tmp/data', train=False, download=True, transform=transform)

# Get indices of examples that should be used for training
if args['sorting_file'] == 'none':
    train_indx = np.array(range(len(trainset.train_labels)))
else:
    try:
        with open(
                os.path.join(args['input_dir'], args['sorting_file']) + '.pkl',
                'rb') as fin:
            ordered_indx = pickle.load(fin)['indices']
    except IOError:
        with open(os.path.join(args['input_dir'], args['sorting_file']),
                  'rb') as fin:
            ordered_indx = pickle.load(fin)['indices']

    # Get the indices to remove from training
    elements_to_remove = np.array(
        ordered_indx)[args['keep_lowest_n']:args['keep_lowest_n'] + args['remove_n']]

    # Remove the corresponding elements
    train_indx = np.setdiff1d(
        range(len(trainset.train_labels)), elements_to_remove)

# Remove remove_n number of examples from the train set at random
if args['keep_lowest_n'] < 0:
    train_indx = npr.permutation(np.arange(len(
        trainset.train_labels)))[:len(trainset.train_labels) - args['remove_n']]

# Reassign train data and labels
trainset.train_data = trainset.train_data[train_indx, :, :]
trainset.train_labels = np.array(trainset.train_labels)[train_indx].tolist()

print('Training on ' + str(len(trainset.train_labels)) + ' examples')

# Setup model and optimizer
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args['lr'], momentum=args['momentum'])

# Setup loss
criterion = nn.CrossEntropyLoss()
criterion.__init__(reduce=False)

# Initialize dictionary to save statistics for every example presentation
example_stats = {}

elapsed_time = 0

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
Training on 60000 examples




In [12]:
for epoch in range(args['epochs']):
    start_time = time.time()

    train(args, model, device, trainset, optimizer, epoch, example_stats)
    test(args, model, device, testset, example_stats)

    epoch_time = time.time() - start_time
    elapsed_time += epoch_time
    print('| Elapsed time : %d:%02d:%02d' % (get_hms(elapsed_time)))

    # Save the stats dictionary
    fname = os.path.join(args['output_dir'], str(epoch))
    with open(fname + "__stats_dict.pkl", "wb") as f:
        pickle.dump(example_stats, f)

    # Log the best train and test accuracy so far
    with open(fname + "__best_acc.txt", "w") as f:
        f.write('train test \n')
        f.write(str(max(example_stats['train'][1])))
        f.write(' ')
        f.write(str(max(example_stats['test'][1])))

| Epoch [  0/200] Iter[938/938]		Loss: 0.7132 Acc@1: 43.438%
| Validation Epoch #0			Loss: 0.3346 Acc@1: 83.98%
| Elapsed time : 0:00:42
| Epoch [  1/200] Iter[938/938]		Loss: 0.4494 Acc@1: 77.578%
| Validation Epoch #1			Loss: 0.0949 Acc@1: 90.24%
| Elapsed time : 0:01:24
| Epoch [  2/200] Iter[938/938]		Loss: 0.5114 Acc@1: 82.973%
| Validation Epoch #2			Loss: 0.0658 Acc@1: 91.82%
| Elapsed time : 0:02:07
| Epoch [  3/200] Iter[938/938]		Loss: 0.3167 Acc@1: 85.455%
| Validation Epoch #3			Loss: 0.0581 Acc@1: 92.72%
| Elapsed time : 0:02:50
| Epoch [  4/200] Iter[938/938]		Loss: 0.5026 Acc@1: 86.623%
| Validation Epoch #4			Loss: 0.0406 Acc@1: 93.10%
| Elapsed time : 0:03:32
| Epoch [  5/200] Iter[938/938]		Loss: 0.3511 Acc@1: 87.455%
| Validation Epoch #5			Loss: 0.0369 Acc@1: 93.60%
| Elapsed time : 0:04:14
| Epoch [  6/200] Iter[938/938]		Loss: 0.2670 Acc@1: 87.952%
| Validation Epoch #6			Loss: 0.0231 Acc@1: 93.42%
| Elapsed time : 0:04:57
| Epoch [  7/200] Iter[938/938]		Loss: 0.

In [13]:
import argparse
import numpy as np
import os
import pickle


In [14]:
#args

def compute_forgetting_statistics(diag_stats, npresentations):

    presentations_needed_to_learn = {}
    unlearned_per_presentation = {}
    margins_per_presentation = {}
    first_learned = {}

    for example_id, example_stats in diag_stats.items():

        # Skip 'train' and 'test' keys of diag_stats
        if not isinstance(example_id, str):

            # Forgetting event is a transition in accuracy from 1 to 0
            presentation_acc = np.array(example_stats[1][:npresentations])
            transitions = presentation_acc[1:] - presentation_acc[:-1]

            # Find all presentations when forgetting occurs
            if len(np.where(transitions == -1)[0]) > 0:
                unlearned_per_presentation[example_id] = np.where(
                    transitions == -1)[0] + 2
            else:
                unlearned_per_presentation[example_id] = []

            # Find number of presentations needed to learn example, 
            # e.g. last presentation when acc is 0
            if len(np.where(presentation_acc == 0)[0]) > 0:
                presentations_needed_to_learn[example_id] = np.where(
                    presentation_acc == 0)[0][-1] + 1
            else:
                presentations_needed_to_learn[example_id] = 0

            # Find the misclassication margin for each presentation of the example
            margins_per_presentation = np.array(
                example_stats[2][:npresentations])

            # Find the presentation at which the example was first learned, 
            # e.g. first presentation when acc is 1
            if len(np.where(presentation_acc == 1)[0]) > 0:
                first_learned[example_id] = np.where(
                    presentation_acc == 1)[0][0]
            else:
                first_learned[example_id] = np.nan

    return presentations_needed_to_learn, unlearned_per_presentation, margins_per_presentation, first_learned


# Sorts examples by number of forgetting counts during training, in ascending order
# If an example was never learned, it is assigned the maximum number of forgetting counts
# If multiple training runs used, sort examples by the sum of their forgetting counts over all runs
#
# unlearned_per_presentation_all: list of dictionaries, one per training run
# first_learned_all: list of dictionaries, one per training run
# npresentations: number of training epochs
#
# Returns 2 numpy arrays containing the sorted example ids and corresponding forgetting counts
#
def sort_examples_by_forgetting(unlearned_per_presentation_all,
                                first_learned_all, npresentations):

    # Initialize lists
    example_original_order = []
    example_stats = []

    for example_id in unlearned_per_presentation_all[0].keys():

        # Add current example to lists
        example_original_order.append(example_id)
        example_stats.append(0)

        # Iterate over all training runs to calculate the total forgetting count for current example
        for i in range(len(unlearned_per_presentation_all)):

            # Get all presentations when current example was forgotten during current training run
            stats = unlearned_per_presentation_all[i][example_id]

            # If example was never learned during current training run, add max forgetting counts
            if np.isnan(first_learned_all[i][example_id]):
                example_stats[-1] += npresentations
            else:
                example_stats[-1] += len(stats)

    print('Number of unforgettable examples: {}'.format(
        len(np.where(np.array(example_stats) == 0)[0])))
    return np.array(example_original_order)[np.argsort(
        example_stats)], np.sort(example_stats)


# Checks whether a given file name matches a list of specified arguments
#
# fname: string containing file name
# args_list: list of strings containing argument names and values, i.e. [arg1, val1, arg2, val2,..]
#
# Returns 1 if filename matches the filter specified by the argument list, 0 otherwise
#
def check_filename(fname, args_list):

    # # If no arguments are specified to filter by, pass filename
    # if args_list is None:
    #     return 1

    # for arg_ind in list(args_list):#np.arange(0, len(args_list), 2):
    #     arg = str(arg_ind)
    #     arg_value = str(args_list[arg_ind])

    #     # Check if filename matches the current arg and arg value
    #     if arg + '_' + arg_value + '__' not in fname:
    #         print('skipping file: ' + fname)
    #         return 0

    return 1



In [15]:


args = {'output_dir': 'permuted_mnist_results',
        'output_name': 'permuted_mnist_sorted',
        'input_dir':'permuted_mnist_results',
        'epochs': 20,
        'input_fname_args':
        {
            'dataset': 'permuted_mnist',
            'no_droput': False,
            'sorting_file': 'none',
            'remove_n': 0,
            'keep_lowest_n': 0

        }

        }

# Initialize lists to collect forgetting stastics per example across multiple training runs
unlearned_per_presentation_all, first_learned_all = [], []

for d, _, fs in os.walk(args['input_dir']):
    for f in fs:

        # Find the files that match input_fname_args and compute forgetting statistics
        if f.endswith('stats_dict.pkl') and check_filename(
                f, args['input_fname_args']):
            print('including file: ' + f)

            # Load the dictionary compiled during training run
            with open(os.path.join(d, f), 'rb') as fin:
                loaded = pickle.load(fin)

            # Compute the forgetting statistics per example for training run
            _, unlearned_per_presentation, _, first_learned = compute_forgetting_statistics(
                loaded, args['epochs'])

            unlearned_per_presentation_all.append(
                unlearned_per_presentation)
            first_learned_all.append(first_learned)

if len(unlearned_per_presentation_all) == 0:
    print('No input files found in {} that match {}'.format(
        args['input_dir'], args['input_fname_args']))
else:

    # Sort examples by forgetting counts in ascending order, over one or more training runs
    ordered_examples, ordered_values = sort_examples_by_forgetting(
        unlearned_per_presentation_all, first_learned_all, args['epochs'])

    # Save sorted output
    if args['output_name'].endswith('.pkl'):
        with open(os.path.join(args['output_dir'], args['output_name']),
                  'wb') as fout:
            pickle.dump({
                'indices': ordered_examples,
                'forgetting counts': ordered_values
            }, fout)
    else:
        with open(
                os.path.join(args['output_dir'], args['output_name'] + '.pkl'),
                'wb') as fout:
            pickle.dump({
                'indices': ordered_examples,
                'forgetting counts': ordered_values
            }, fout)

including file: 0__stats_dict.pkl
including file: 1__stats_dict.pkl
including file: 2__stats_dict.pkl
including file: 3__stats_dict.pkl
including file: 4__stats_dict.pkl
including file: 5__stats_dict.pkl
including file: 6__stats_dict.pkl
including file: 7__stats_dict.pkl
including file: 8__stats_dict.pkl
including file: 9__stats_dict.pkl
including file: 10__stats_dict.pkl
including file: 11__stats_dict.pkl
including file: 12__stats_dict.pkl
including file: 13__stats_dict.pkl
including file: 14__stats_dict.pkl
including file: 15__stats_dict.pkl
including file: 16__stats_dict.pkl
including file: 17__stats_dict.pkl
including file: 18__stats_dict.pkl
including file: 19__stats_dict.pkl
including file: 20__stats_dict.pkl
including file: 21__stats_dict.pkl
including file: 22__stats_dict.pkl
including file: 23__stats_dict.pkl
including file: 24__stats_dict.pkl
including file: 25__stats_dict.pkl
including file: 26__stats_dict.pkl
including file: 27__stats_dict.pkl
including file: 28__stats_dict