<a href="https://colab.research.google.com/github/tasn19/scan-repro/blob/main/SCAN_repro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

Import the setup files from GitHub

In [None]:
%%shell
# Import setup files from github.
# Cell magics such as %%shell must be the very first line of a cell, so the
# comment now follows the magic instead of preceding it.
set -e
start_dir="$(pwd)"

# Clone once; re-running the cell reuses the existing checkout.
if [ ! -d scan-repro ]; then
  git clone https://github.com/tasn19/scan-repro.git
fi
cd scan-repro
git checkout v0.3.0

# Copy the helper modules into the notebook's working directory so that
# `from utils import ...`, `from models import ...`, etc. resolve.
# NOTE(review): assumes the modules live in a `setup/` folder of the repository,
# as the original `cd setup` implied — confirm against the v0.3.0 tag.
cd setup
cp utils.py dataset.py models.py criterion.py memorybank.py train.py "$start_dir"/

In [None]:
# import libraries common to all tasks
import os

import numpy as np
import torch
import torchvision

from utils import get_transform
from models import get_model

In [None]:
# Root folder (on Google Drive) holding datasets, models and checkpoints.
# NOTE(review): assumes Drive is already mounted at /content/drive — no mount call is shown here.
base_path = "/content/drive/MyDrive/Colab Notebooks/SCANmaterials/Unsupervised-Classification/"
cifar_path = f"{base_path}datasets/cifar10"

# Pretext Task

In [None]:
# Import all pretext stuff
from dataset import CustomDataset
from utils import contrastive_evaluate
from criterion import SimCLR_loss
from memorybank import MemoryBank, fill_memory_bank
from train import SimCLR_train


In [None]:
# Install FAISS (GPU build). Per the inline note it is needed by
# MemoryBank.mine_nearest_neighbors for nearest-neighbour search.
!pip install faiss-gpu # needed for mine_nearest_neighbors -> Memory Bank

In [None]:
# Output files of the pretext step, all stored under <base_path>mymodels/
pretext_model_path = f"{base_path}mymodels/simclrmodel.pth.tar"
checkpoint_path = f"{base_path}mymodels/pretext_checkpoint.pth.tar"
knn_train_path = f"{base_path}mymodels/knn_train.npy"
knn_test_path = f"{base_path}mymodels/knn_test.npy"

In [None]:
# Get transformations & datasets for the SimCLR pretext task
step = "simclr"
featuresDim = 128  # feature dimension passed to the memory banks
numClasses = 10

# Get transformations
transform = get_transform("simclr")        # augmentations for contrastive training
base_transform = get_transform("base")
val_transform = get_transform("validate")  # used for evaluation and the memory-bank datasets

hyperparams = {"epochs": 500, "batchsize": 512, "weight decay": 0.0001, "momentum": 0.9, "lr": 0.4,
               "lr decay rate": 0.1, "num_workers": 8}

# Load training set
train1_set = torchvision.datasets.CIFAR10(root = cifar_path, train = True, transform = transform, download = False) # change to True
train_set = CustomDataset(train1_set, step, base_transform = base_transform)

# For initial testing, take a small random subset of the dataset
#indices = torch.randperm(len(train_set)).tolist() # INITIAL TESTING
#train_set = torch.utils.data.Subset(train_set, indices[:1000]) # INITIAL TESTING

# enable pin_memory to speed up host to device transfer
train_loader = torch.utils.data.DataLoader(train_set, batch_size = hyperparams["batchsize"], shuffle = True,
                                           num_workers = hyperparams["num_workers"], pin_memory = True, drop_last = True)

# Load testing set: the held-out CIFAR-10 test split (train = False) with the
# non-augmented validation transform. Omitting `train` defaults to True, which
# would make kNN evaluation and test-neighbour mining run on augmented training images.
test_set = torchvision.datasets.CIFAR10(root = cifar_path, train = False, transform = val_transform, download = False)
# Subset the test split with its own permutation (training indices run past the test-set size):
#test_set = torch.utils.data.Subset(test_set, torch.randperm(len(test_set))[:1000].tolist()) # INITIAL TESTING
test_loader = torch.utils.data.DataLoader(test_set, batch_size = hyperparams["batchsize"], shuffle = False,
                                          num_workers = hyperparams["num_workers"], pin_memory = True, drop_last = False)


In [None]:
# Let cuDNN auto-tune its kernels (author's note: without this, memory error during knn mining)
torch.backends.cudnn.benchmark = True

# SimCLR model, moved to the GPU
model = get_model(step)
model.cuda()

# Contrastive loss, constructed with the training batch size, also on the GPU
criterion = SimCLR_loss(hyperparams["batchsize"])
criterion.cuda()

In [None]:
# Memory banks for kNN evaluation and nearest-neighbour mining.
# The base bank is filled from the training split with the validation
# (non-augmented) transform. NOTE(review): author marked this choice "CHECK".
base_dataset = torchvision.datasets.CIFAR10(root= cifar_path, train=True, transform=val_transform, download=True)
#base_dataset = torch.utils.data.Subset(base_dataset, indices[:1000]) # INITIAL TEST
base_loader = torch.utils.data.DataLoader(
  base_dataset,
  batch_size=hyperparams["batchsize"],
  shuffle=False,
  num_workers=hyperparams["num_workers"],
  pin_memory=True,
)

# One bank slot per image of each dataset
base_memorybank = MemoryBank(len(base_dataset), featuresDim, numClasses, temperature = 0.1)
base_memorybank.cuda()
test_memorybank = MemoryBank(len(test_set), featuresDim, numClasses, temperature = 0.1)
test_memorybank.cuda()


In [None]:
epochs = hyperparams["epochs"]
lr = hyperparams["lr"]
lr_decay_rate = hyperparams["lr decay rate"]

# SGD optimizer (the original SimCLR paper used LARS), configured from `hyperparams`
params = model.parameters()
optimizer = torch.optim.SGD(params, lr, momentum=hyperparams["momentum"],
                            weight_decay=hyperparams["weight decay"], nesterov=False)
# Cosine-anneal from lr down to lr * lr_decay_rate**3 over the whole run
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=(lr*(lr_decay_rate**3)))

# Load checkpoint
if os.path.exists(checkpoint_path):
  print("Loading checkpoint")
  checkpoint = torch.load(checkpoint_path)
  model.load_state_dict(checkpoint['model_state_dict'])
  model.cuda()
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  # Older checkpoints have no scheduler state; for those the LR schedule restarts.
  if 'scheduler_state_dict' in checkpoint:
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
  startE = checkpoint['epoch']  # first epoch still to run
else:
  startE = 0

# Train model, resuming from startE when a checkpoint was loaded
# add warm-up? (to reduce effect of early training)
for epoch in range(startE, epochs):
  print('Epoch ', epoch)
  # Learning rate used for this epoch; the scheduler advances after training
  lr = optimizer.param_groups[0]['lr']
  print('Learning Rate ', lr)

  # Train
  SimCLR_train(train_loader, model, epoch, criterion, optimizer)
  print("SimCLR train complete")

  # Step the schedule once per epoch, after the optimizer steps (required order since PyTorch 1.1)
  scheduler.step()

  # Fill memory bank for knn step
  fill_memory_bank(base_loader, model, base_memorybank)

  # Check progress
  print("Evaluating")
  top1 = contrastive_evaluate(test_loader, model, base_memorybank)
  print('Result of kNN evaluation is %.2f' %(top1))

  # Save checkpoint; 'epoch' is the next epoch to run so a resume does not repeat a finished epoch
  torch.save({'epoch': epoch + 1,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'scheduler_state_dict': scheduler.state_dict()}, checkpoint_path)

# Save model (the path defined in the pretext path cell)
torch.save(model.state_dict(), pretext_model_path)


Mine the top-20 nearest neighbors of each training image to pass on to the SCAN step

In [None]:
# Refresh the training-set memory bank, then mine each image's k nearest neighbours
k = 20
fill_memory_bank(base_loader, model, base_memorybank)
train_indices, accuracy = base_memorybank.mine_nearest_neighbors(k)
report = 'Accuracy of top-%d nearest neighbors on train set is %.2f' % (k, 100 * accuracy)
print(report)
# Persist the neighbour positions for the SCAN step
np.save(knn_train_path, train_indices)

Mine the top-5 nearest neighbors of each test image for validation

In [None]:
# Fill the test memory bank and mine each test image's k nearest neighbours
fill_memory_bank(test_loader, model, test_memorybank)
k = 5
test_indices, accuracy = test_memorybank.mine_nearest_neighbors(k)
# Report the `accuracy` returned above (there is no `acc` variable)
print('Accuracy of top-%d nearest neighbors on test set is %.2f' %(k, 100*accuracy))
# save positions of nearest neighbors
np.save(knn_test_path, test_indices)

# SCAN

Cluster each image together with its semantically meaningful nearest neighbors (mined in the pretext task)

In [None]:
# Import all SCAN stuff
from dataset import NNDataset
from utils import get_predictions, SCAN_evaluate, hungarian_evaluate
from criterion import SCAN_loss
from train import SCAN_train

In [None]:
# Download augment.py (additional augmentation functions) from the original SCAN
# repository (wvangansbeke/Unsupervised-Classification) into the working directory.
# NOTE(review): presumably imported by the local transform helpers — confirm.
!wget https://raw.githubusercontent.com/wvangansbeke/Unsupervised-Classification/master/data/augment.py

In [None]:
# Set paths to store files
SCAN_model_path = base_path + "mymodels/scanmodel.pth.tar"
checkpoint_path_scan = base_path + "mymodels/checkpoint_scan.pth.tar"
cifar_path = base_path + "datasets/cifar10"
knn_train_path = base_path + "mymodels/knn_train.npy"

In [None]:
step2 = "scan"
scan_hyperparams = {"epochs": 20, "batchsize": 128, "lr": 0.0001, "weight decay": 0.0001, "num_workers": 8}
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Get transformations
scan_transforms = get_transform(step2)
base_transform = get_transform("base")
# Validation transform: the same get_transform("validate") used in the SimCLR step
val_transform = get_transform("validate")

# INITIAL TESTING: number of random samples kept from each split; set to None to use the full datasets
subset_size = 1000

# Load training set
train1_set_scan = torchvision.datasets.CIFAR10(root = cifar_path, train = True, transform = scan_transforms, download = False) # change to True
knn_indices = np.load(knn_train_path)
train_set_scan = NNDataset(train1_set_scan, knn_indices, numNeighbors=20, step=step2, base_transform = base_transform)
if subset_size is not None:
  indices1 = torch.randperm(len(train_set_scan)).tolist()
  train_set_scan = torch.utils.data.Subset(train_set_scan, indices1[:subset_size])
train_loader_scan = torch.utils.data.DataLoader(train_set_scan, batch_size = scan_hyperparams["batchsize"], shuffle = True,
                                           num_workers = scan_hyperparams["num_workers"], pin_memory = True, drop_last = True)

# Load testing set
test1_set_scan = torchvision.datasets.CIFAR10(root = cifar_path, train = False, transform = val_transform, download = False)
knn_test_indices = np.load(knn_test_path)
# Row i must hold the neighbours of test image i. A fixed reshape to 20 columns
# scrambles the rows (test neighbours are mined with k = 5), so only a flat
# array is reshaped, to one row per test image.
if knn_test_indices.ndim == 1:
  knn_test_indices = knn_test_indices.reshape(len(test1_set_scan), -1)
if knn_test_indices.shape[0] != len(test1_set_scan):
  raise ValueError("knn_test.npy has %d rows but the test set has %d images; re-run test-set neighbour mining"
                   % (knn_test_indices.shape[0], len(test1_set_scan)))
test_set_scan = NNDataset(test1_set_scan, knn_test_indices, numNeighbors=5, step=step2, base_transform = base_transform)
if subset_size is not None:
  # Permute the test split itself: training-set indices run past the number of test images
  test_subset_indices = torch.randperm(len(test_set_scan)).tolist()
  test_set_scan = torch.utils.data.Subset(test_set_scan, test_subset_indices[:subset_size])
test_loader_scan = torch.utils.data.DataLoader(test_set_scan, batch_size = scan_hyperparams["batchsize"], shuffle = False,
                                          num_workers = scan_hyperparams["num_workers"], pin_memory = True, drop_last = False)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# SCAN clustering model with numClasses=10.
# NOTE(review): pretrained_weights=None means SCAN does not start from the
# pretext (SimCLR) weights saved earlier — confirm this is intended.
model = get_model(step2, pretrained_weights=None, numClasses=10)

# SCAN loss with entropy_weight=5
criterion = SCAN_loss(entropy_weight = 5)

# Move both the model and the loss to the selected device
for module in (model, criterion):
  module.to(device)

In [None]:
# Optimizer
params = model.parameters()
optimizer = torch.optim.Adam(params, lr=scan_hyperparams["lr"], weight_decay=scan_hyperparams["weight decay"])
# use constant learning rate

# Load checkpoint (mapped to the active device so a GPU checkpoint also loads on CPU)
if os.path.exists(checkpoint_path_scan):
  print("Loading checkpoint")
  checkpoint = torch.load(checkpoint_path_scan, map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])
  model.to(device)
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  startE = checkpoint['epoch']  # first epoch still to run
  best_loss = checkpoint['best_loss']
else:
  startE = 0
  best_loss = 1e4

# Train
epochs = scan_hyperparams["epochs"]
for epoch in range(startE, epochs):
  print('Epoch ', epoch)

  # Train
  SCAN_train(train_loader_scan, model, epoch, criterion, optimizer)
  print("SCAN training complete")

  # Evaluate: loss on the test loader as returned by SCAN_evaluate
  print("Evaluating")
  lowest_loss = SCAN_evaluate(test_loader_scan, model)
  print('loss', lowest_loss)

  if lowest_loss < best_loss:
    print("New lowest loss: {}, previous lowest: {}".format(lowest_loss, best_loss))
    best_loss = lowest_loss
    # Keep the lowest-loss weights; the final evaluation below reloads them
    torch.save(model.state_dict(), SCAN_model_path)

  print('Evaluate with hungarian matching algorithm ...')
  # Check progress
  predictions = get_predictions(test_loader_scan, model)
  stats = hungarian_evaluate(predictions, compute_confusion_matrix=False)
  print(stats)

  # Save checkpoint; 'epoch' is the next epoch to run so a resume does not repeat a finished epoch
  torch.save({'epoch': epoch + 1,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'best_loss': best_loss}, checkpoint_path_scan)

# Evaluate best model with hungarian matching algorithm
# (falls back to the current weights if no best model was ever saved)
if os.path.exists(SCAN_model_path):
  model.load_state_dict(torch.load(SCAN_model_path, map_location=device))
predictions = get_predictions(test_loader_scan, model)
clustering_stats = hungarian_evaluate(predictions, class_names=classes, compute_confusion_matrix=True)
print(clustering_stats)


# Self-label

Fine-tune on the model's own confident predictions (pseudo-labels) to correct labelling mistakes caused by noisy nearest neighbors

In [None]:
# Import self-label stuff
from dataset import CustomDataset
from utils import get_predictions, hungarian_evaluate
from criterion import CE_loss
from train import selflabel_train

In [None]:
# Set paths to store files
cifar_path = base_path + "datasets/cifar10"
selflabel_model_path = base_path + "mymodels/selflabelmodel"
checkpoint_path_slbl = base_path + "mymodels/checkpoint_slbl"

In [None]:
step3 = 'selflabel'

#slbl_hyperparams = {"epochs": 200, "batchsize": 1000, "lr": 0.0001, "weight decay": 0.0001, "num_workers": 8} # ORIGINAL
slbl_hyperparams = {"epochs": 10, "batchsize": 512, "lr": 0.0001, "weight decay": 0.0001, "num_workers": 8}

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Get transformations - same as SCAN step
slbl_transforms = get_transform('scan')
base_transform = get_transform("base") # FIX !!!!! (author flagged this transform as needing a fix)
# Validation transform: the same get_transform("validate") used in the SimCLR step
val_transform = get_transform("validate")

# INITIAL TESTING: number of random samples kept from each split; set to None to use the full datasets
subset_size = 1000

# Get datasets
# Load training set
# CHECK: paper code using 'standard':val_transforms
train1_set_slbl = torchvision.datasets.CIFAR10(root = cifar_path, train = True, transform = slbl_transforms, download = False) # change to True
train_set_slbl = CustomDataset(train1_set_slbl, step=step3, base_transform = base_transform)
if subset_size is not None:
  # Permute the whole split so the subset is a random sample, not just the first images
  indices1 = torch.randperm(len(train_set_slbl)).tolist()
  train_set_slbl = torch.utils.data.Subset(train_set_slbl, indices1[:subset_size])
train_loader_slbl = torch.utils.data.DataLoader(train_set_slbl, batch_size = slbl_hyperparams["batchsize"], shuffle = True,
                                           num_workers = slbl_hyperparams["num_workers"], pin_memory = True, drop_last = True)

# Load testing set
test1_set_slbl = torchvision.datasets.CIFAR10(root = cifar_path, train = False, transform = val_transform, download = False)
test_set_slbl = test1_set_slbl
if subset_size is not None:
  # Separate permutation sized to the test split
  test_subset_indices = torch.randperm(len(test1_set_slbl)).tolist()
  test_set_slbl = torch.utils.data.Subset(test1_set_slbl, test_subset_indices[:subset_size])
test_loader_slbl = torch.utils.data.DataLoader(test_set_slbl, batch_size = slbl_hyperparams["batchsize"], shuffle = False,
                                          num_workers = slbl_hyperparams["num_workers"], pin_memory = True, drop_last = False)


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Self-label model with numClasses=10.
# TODO: change to pretrained_weight=SCAN_model_path so self-labelling starts from the SCAN weights.
model = get_model(step3, pretrained_weights=None, numClasses=10)
model.to(device)  # NOTE(review): CPU path not verified (author's CHECK)
# Let cuDNN auto-tune its kernels (author's note: needed to avoid memory errors)
torch.backends.cudnn.benchmark = True

# Cross-entropy on confident predictions (threshold=0.99). Author's note: weighted
# CE, to compensate for imbalance between confident samples across clusters.
criterion = CE_loss(threshold=0.99)
criterion.to(device)

In [None]:
# Optimizer
params = model.parameters()
optimizer = torch.optim.Adam(params, lr=slbl_hyperparams["lr"], weight_decay=slbl_hyperparams["weight decay"])
# use constant learning rate

# Load checkpoint (mapped to the active device so a GPU checkpoint also loads on CPU)
if os.path.exists(checkpoint_path_slbl):
  print("Loading checkpoint")
  checkpoint = torch.load(checkpoint_path_slbl, map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])
  model.to(device)
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  startE = checkpoint['epoch']  # first epoch still to run
else:
  startE = 0

# Train, resuming from startE when a checkpoint was loaded
epochs = slbl_hyperparams["epochs"]
for epoch in range(startE, epochs):
  print('Epoch ', epoch)

  # Train
  selflabel_train(train_loader_slbl, model, epoch, criterion, optimizer)
  print("Self-label training complete")

  # Check progress
  # NOTE(review): uses get_predictions (the helper imported for this section); the previous
  # call to get_predictions_slbl was never imported — confirm it handles this loader's batches.
  print("Evaluating with hungarian matching algorithm")
  predictions = get_predictions(test_loader_slbl, model)

  stats = hungarian_evaluate(predictions, compute_confusion_matrix=False)
  print(stats)

  # Save checkpoint; 'epoch' is the next epoch to run so a resume does not repeat a finished epoch
  torch.save({'epoch': epoch + 1,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict()}, checkpoint_path_slbl)

# Final Evaluation (same hungarian_evaluate call shape as the SCAN step)
print("Evaluating final model")
predictions = get_predictions(test_loader_slbl, model)
clustering_stats = hungarian_evaluate(predictions, class_names=classes,
                            compute_confusion_matrix=True)
print(clustering_stats)

# Save model
torch.save(model.state_dict(), selflabel_model_path)