<a href="https://colab.research.google.com/github/tasn19/scan-repro/blob/main/SCAN_repro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

Clone the helper modules from GitHub and copy them next to the notebook so they can be imported.

In [None]:
%%shell
# Import setup files from GitHub.
# NOTE: the %%shell cell magic has to be the very first line of the cell;
# with a comment above it IPython runs the cell as ordinary Python.

# Clone into a directory named "setup" so the `cd` below finds it
# (a bare `git clone` would create "scan-repro" instead).
# Assumes the helper modules live at the repository root - TODO confirm.
git clone https://github.com/tasn19/scan-repro.git setup
cd setup
# Pin the helper code to a known release.
git checkout v0.3.0

# Copy the modules one level up, next to the notebook, so the import
# cell can find them.
cp utils.py dataset.py models.py criterion.py memorybank.py train.py ../

# Pretext Task

In [None]:
# Import all pretext stuff
import torch
import torchvision
from dataset import CustomDataset
from utils import get_transform, contrastive_evaluate
from models import get_model
from criterion import SimCLR_loss
from memorybank import MemoryBank, fill_memory_bank
from train import SimCLR_train


In [None]:
# Get transformations & datasets 
cifar_path = "/content/drive/MyDrive/Colab Notebooks/SCANmaterials/Unsupervised-Classification/datasets/cifar10"
pretext_model_path = "/content/drive/MyDrive/Colab Notebooks/SCANmaterials/Unsupervised-Classification/mymodels"
step = "simclr"
featuresDim = 128
numClasses = 10

# Get transformations
transform = get_transform("simclr")
base_transform = get_transform("base") # FIX !!!!!
val_transform = get_transform("validate")

# author code epochs = 500, batchsize = 512, num_workers = 8
hyperparams = {"epochs": 20, "batchsize": 512, "weight decay": 0.0001, "momentum": 0.9, "lr": 0.4, 
                   "lr decay rate": 0.1, "num_workers": 8}

# Load training set
train1_set = torchvision.datasets.CIFAR10(root = cifar_path, train = True, transform = transform, download = False) # change to True
train_set = CustomDataset(train1_set, step, base_transform = base_transform)
# enable pin_memory to speed up host to device transfer
# Probably highest possible batch_size=128 & num_workers=2 with memory limits CHECK
train_loader = torch.utils.data.DataLoader(train_set, batch_size = hyperparams["batchsize"], shuffle = True, 
                                           num_workers = hyperparams["num_workers"], pin_memory = True, drop_last = True)

# Load testing set
test_set = torchvision.datasets.CIFAR10(root = cifar_path, transform = transform, download = False)
test_loader = torch.utils.data.DataLoader(test_set, batch_size = hyperparams["batchsize"], shuffle = False, 
                                          num_workers = hyperparams["num_workers"], pin_memory = True, drop_last = False)


In [None]:
# For initial testing, take a small subset of dataset
indices = torch.randperm(len(train_set)).tolist()
expset = torch.utils.data.Subset(train_set, indices[:1000])
expload = torch.utils.data.DataLoader(expset, batch_size = 512, shuffle = True, num_workers = 8,  pin_memory = True, drop_last = True)

testexpset = torch.utils.data.Subset(test_set, indices[:1000])
testexpload = torch.utils.data.DataLoader(testexpset, batch_size = 512, shuffle = False, num_workers = 8, pin_memory = True, drop_last = False)

In [None]:
# Build the network for the current step and the SimCLR loss, which is
# constructed with the training batch size.
model = get_model(step)
criterion = SimCLR_loss(hyperparams["batchsize"])

# Move both onto the GPU
model.cuda()
criterion.cuda()

SimCLR_loss()

In [None]:
# Build the memory banks used for kNN evaluation and neighbour mining.
# The base dataset is the training split with the validation transform
# (intended to be augmentation-free for kNN evaluation - CHECK).
base_dataset = torchvision.datasets.CIFAR10(
    root=cifar_path,
    train=True,
    transform=val_transform,
    download=True,
)
# Restrict to the same 1000-sample subset used for the quick experiments
base_dataset = torch.utils.data.Subset(base_dataset, indices[:1000])  # remove
base_loader = torch.utils.data.DataLoader(
    base_dataset,
    batch_size=hyperparams["batchsize"],
    shuffle=False,
    num_workers=hyperparams["num_workers"],
    pin_memory=True,
)

# One bank per split, each sized to the dataset it will be filled from
base_memorybank = MemoryBank(len(base_dataset), featuresDim, numClasses, temperature=0.1)
test_memorybank = MemoryBank(len(testexpset), featuresDim, numClasses, temperature=0.1)  # change testexpset to test_set

base_memorybank.cuda()
test_memorybank.cuda()


In [None]:
# Optimisation settings are read from the shared hyperparams dict so the
# values used here cannot drift from the ones declared with the datasets.
lr = hyperparams["lr"]
decay_rate = hyperparams["lr decay rate"]
epochs = hyperparams["epochs"]

# Plain SGD for now; the original paper used LARS.
#params = [p for p in model.parameters() if p.requires_grad] # CHECK THIS
params = model.parameters()
optimizer = torch.optim.SGD(params, lr, momentum=hyperparams["momentum"],
                            weight_decay=hyperparams["weight decay"], nesterov=False)
# Cosine decay from lr down to lr * decay_rate**3 over the whole run
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=(lr*(decay_rate**3)))

# Checkpoint - use paper code

# Train model
# add warm-up? (to reduce effect of early training)
for epoch in range(epochs):
  print('Epoch ', epoch)

  # Report the rate the optimizer will actually use for this epoch
  current_lr = optimizer.param_groups[0]["lr"]
  print('Learning Rate ', current_lr)

  # Train
  SimCLR_train(expload, model, epoch, criterion, optimizer) # change expload to train_loader
  print("SimCLR train complete")

  # Fill memory bank for knn step
  fill_memory_bank(base_loader, model, base_memorybank)

  # Check progress
  print("Evaluating")
  top1 = contrastive_evaluate(testexpload, model, base_memorybank) # change testexpload to test_loader
  print('Result of kNN evaluation is %.2f' %(top1))

  # Advance the schedule once per epoch, AFTER the optimizer has been stepped
  # (presumably inside SimCLR_train). Stepping first skips the initial rate
  # and shifts the whole cosine schedule by one epoch.
  scheduler.step()

  # checkpoint

# Save model
#torch.save(model.state_dict(), my_model_path)


Mine the top-20 nearest neighbors to pass on to the SCAN step.

In [None]:
# Refresh the base memory bank with the current model, then mine the
# k nearest neighbours from it.
fill_memory_bank(base_loader, model, base_memorybank)
k = 20
train_indices, accuracy = base_memorybank.mine_nearest_neighbors(k)
train_report = 'Accuracy of top-%d nearest neighbors on train set is %.2f' %(k, 100*accuracy)
print(train_report)
# Saving the neighbours is not wired up yet:
#np.save(p['topk_neighbors_train_path'], indices)

Mine the top-5 nearest neighbors for validation.

In [None]:
# Fill the validation memory bank from the subset loader: test_memorybank was
# sized with len(testexpset), so it must be filled from the matching loader.
# When moving to the full test set, change the bank size and this loader together.
fill_memory_bank(testexpload, model, test_memorybank)
k = 5
test_indices, accuracy = test_memorybank.mine_nearest_neighbors(k)
# The mined accuracy is bound to `accuracy`; the previous `acc` was undefined.
print('Accuracy of top-%d nearest neighbors on test set is %.2f' %(k, 100*accuracy))
# Saving the neighbours is not wired up yet:
#np.save(p['topk_neighbors_val_path'], indices)

# SCAN