# MoCo Preptraining implementation

## Init

In [1]:
import torch
import os.path
from Datasets import get_loaders
from Augmentations import augment
from MoCoTrainer import get_MoCo_feature_extractor
from tqdm.notebook import tqdm

In [2]:
# Hyperparams taken from paper
TEMPERATURE = 0.07
MOMENTUM = 0.999
KEY_DICTIONARY_SIZE = 4096
NUM_EPOCHS = 500
MOCO_DIM=128

## Define classifier

In [3]:
def accuracy(predicted_labels, true_labels):
    return (predicted_labels == true_labels).sum() / len(true_labels)

In [4]:
class Classifier(torch.nn.Module):
    def __init__(self, feature_extractor: torch.nn.Module, loss_func, optimizer_type, num_classes, weights_path=None, weights_save_file="classifier_weights.pth"):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.weights_save_file = weights_save_file
        self.feature_extractor = feature_extractor.to(device=self.device)
        self.feature_extractor.fc = torch.nn.Identity()  # disable final fc layer of feature extractor
        ftr_exctr_output_dim = 2048  # flattening of final resnet convolution layer's output
        intermediate_dim = ftr_exctr_output_dim
        fc2 = torch.nn.Linear(in_features=intermediate_dim, out_features=num_classes, device=self.device)
        self.classifier_head = torch.nn.Sequential(fc2, torch.nn.Softmax(dim=1))
        self.loss_func = loss_func
        self.optimizer = optimizer_type(self.classifier_head.parameters())
        if weights_path and os.path.exists(weights_path):
            self.load_state_dict(torch.load(weights_path))
        self.weights_path = weights_path

    def forward(self, x):
        with torch.no_grad():
            x = self.feature_extractor(x)  # Only fine-tune classifier head
        x = self.classifier_head(x)
        return x

    def train(self, train_loader, val_loader=None, num_epochs=NUM_EPOCHS):
        validation_accuracies = []
        loss = float("inf")
        losses = []
        for epoch in tqdm(range(num_epochs)):
            print(f"epoch={epoch}, loss={loss}")
            with tqdm(total=46) as pbar:
                for inputs, true_labels in train_loader:
                    self.optimizer.zero_grad()
                    inputs, true_labels = inputs.to(self.device), true_labels.to(self.device)

                    label_confidences = self.forward(inputs)
                    loss = self.loss_func(label_confidences, true_labels)
                    loss.backward()
                    self.optimizer.step()

                    if val_loader:
                        inputs, true_labels = next(val_loader)
                        with torch.no_grad():
                            predicted_labels = self.forward(inputs)
                        validation_accuracies.append(accuracy(predicted_labels=predicted_labels, true_labels=true_labels))
                    
                    predicted_labels = torch.multinomial(input=label_confidences, num_samples=1).squeeze(dim=1)
                    batch_accuracy = 100.0 * accuracy(predicted_labels=predicted_labels, true_labels=true_labels)
                    pbar.set_description(
                        f"(Loss {loss.item():.3f}, "
                        f"Accuracy {batch_accuracy:.1f}%)"
                    )
                    losses.append(loss)
                    torch.save(self.state_dict(), self.weights_path)
                    pbar.update()
        return losses

    def test(self, test_loader):
        total_correct = 0
        test_size = 0
        with tqdm(total=19) as pbar:
            for inputs, true_labels in test_loader:
                inputs, true_labels = inputs.to(self.device), true_labels.to(self.device)

                with torch.no_grad():
                    predicted_labels = self.forward(inputs).topk(k=1)[1].squeeze(dim=1)
                    total_correct += (predicted_labels == true_labels).sum().item()
                    test_size += len(true_labels)
                
                pbar.update()

        print(f"Accuracy: {total_correct / test_size}")
        return total_correct, test_size


## Obtain trained MoCo feature extractor

In [5]:
dl_train_moco, dl_val_moco, dl_train_clf, dl_val_clf = get_loaders(data_path="../imagenette2", batch_size=256//6)  
extractor, moco_training_losses = get_MoCo_feature_extractor(temperature=TEMPERATURE, loader=dl_train_moco, augment=augment, momentum=MOMENTUM, key_dictionary_size=KEY_DICTIONARY_SIZE, num_epochs=NUM_EPOCHS, moco_dim=MOCO_DIM, early_stopping_count=1)
#device = "cuda" if torch.cuda.is_available() else "cpu"
#extractor = get_encoder().to(device)
#extractor.load_state_dict(torch.load('f_q_weights.pth', map_location=device))

Training moco dataset has 7575 images with a total of 181 batches
Testing moco dataset has 3140 images with a total of 75 batches
Training clf dataset has 1894 images with a total of 46 batches
Testing clf dataset has 785 images with a total of 19 batches
[34mInitializing feature extractor training[0m


Using cache found in /home/a.block/.cache/torch/hub/pytorch_vision_v0.10.0


[34mLoading pretrained weights from file[0m
[34mGenerating initial keys queue[0m


100%|██████████| 97/97 [00:18<00:00,  5.35it/s]


[34mBeginning training loop[0m
epoch = 0.  Experimentally 4 epochs ought to do the trick.


  1%|          | 1/119 [00:01<02:16,  1.16s/it]

batch = 0
[32mCompleted minibatch, loss=0.5017343759536743, with accuracy=0.8333333730697632[0m


  1%|          | 1/119 [00:01<03:05,  1.57s/it]

[34mGot a perfect match![0m
[32mEarly stopping!  Loss=0.17173293232917786, with accuracy=1.0[0m
[34mCompleted training MoCo feature extractor early![0m





## Obtain classifier

In [6]:
model = Classifier(feature_extractor=extractor, loss_func=torch.nn.CrossEntropyLoss(), optimizer_type=torch.optim.Adam, num_classes=10, weights_path="classifier_weights.pth")
classifier_training_losses = model.train(train_loader=dl_train_clf, num_epochs=3)

  0%|          | 0/3 [00:00<?, ?it/s]

epoch=0, loss=inf


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=1, loss=1.7099717855453491


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=2, loss=1.479665994644165


  0%|          | 0/30 [00:00<?, ?it/s]

In [7]:
model.test(dl_val_clf)

  0%|          | 0/13 [00:00<?, ?it/s]

Accuracy: 0.978343949044586


(768, 785)

## Linear probing

In [8]:
class LinearProbe(Classifier):
    def __init__(self, feature_extractor: torch.nn.Module, loss_func, optimizer_type, num_classes, num_layers_applied, ftr_exctr_output_dim, weights_path=None):
        super().__init__(feature_extractor, loss_func, optimizer_type, num_classes, weights_path)
        
        extractor_layers = [
            self.feature_extractor.conv1,
            self.feature_extractor.bn1,
            self.feature_extractor.relu,
            self.feature_extractor.maxpool,

            self.feature_extractor.layer1,
            self.feature_extractor.layer2,
            self.feature_extractor.layer3,
            self.feature_extractor.layer4
        ]
        self.shortened_extractor = torch.nn.Sequential(*extractor_layers[:num_layers_applied]).to(device=self.device)
        fc = torch.nn.Linear(in_features=ftr_exctr_output_dim, out_features=num_classes, device=self.device)
        self.classifier_head = torch.nn.Sequential(torch.nn.Flatten(), fc, torch.nn.Softmax(dim=1))
        self.optimizer = optimizer_type(self.classifier_head.parameters())
        if weights_path and os.path.exists(weights_path):
            self.load_state_dict(torch.load(weights_path))
            
    
    def forward(self, x):
        with torch.no_grad():
            x = self.shortened_extractor(x)
        x = self.classifier_head(x)
        return x

In [10]:
intermediate_dims = [150528, 802816, 802816, 802816, 200704, 802816, 401408, 200704, 100352]
for i in range(9):  # Number of ResNet layers
    print(f"\nProbe number {i}/9:")
    probe = LinearProbe(feature_extractor=extractor, loss_func=torch.nn.CrossEntropyLoss(), optimizer_type=torch.optim.Adam, num_classes=10, num_layers_applied=i, ftr_exctr_output_dim=intermediate_dims[i], weights_path=f"probe{i}_weights.pth")
    probe.train(train_loader=dl_train_clf, num_epochs=3)
    probe.test(dl_val_clf)
    


Probe number 0/9:


  0%|          | 0/3 [00:00<?, ?it/s]

epoch=0, loss=inf


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=1, loss=2.4611501693725586


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=2, loss=2.4611501693725586


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

Accuracy: 0.14904458598726114

Probe number 1/9:


  0%|          | 0/3 [00:00<?, ?it/s]

epoch=0, loss=inf


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=1, loss=2.4611501693725586


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=2, loss=2.4611473083496094


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

Accuracy: 0.1643312101910828

Probe number 2/9:


  0%|          | 0/3 [00:00<?, ?it/s]

epoch=0, loss=inf


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=1, loss=2.4611501693725586


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=2, loss=2.4611501693725586


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

Accuracy: 0.10318471337579618

Probe number 3/9:


  0%|          | 0/3 [00:00<?, ?it/s]

epoch=0, loss=inf


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=1, loss=2.4611501693725586


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=2, loss=2.2111501693725586


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

Accuracy: 0.09681528662420383

Probe number 4/9:


  0%|          | 0/3 [00:00<?, ?it/s]

epoch=0, loss=inf


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=1, loss=2.2111501693725586


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=2, loss=2.4611501693725586


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

Accuracy: 0.10445859872611465

Probe number 5/9:


  0%|          | 0/3 [00:00<?, ?it/s]

epoch=0, loss=inf


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=1, loss=1.9611501693725586


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=2, loss=2.4611501693725586


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

Accuracy: 0.09426751592356689

Probe number 6/9:


  0%|          | 0/3 [00:00<?, ?it/s]

epoch=0, loss=inf


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=1, loss=2.2111501693725586


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=2, loss=2.45890474319458


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

Accuracy: 0.1681528662420382

Probe number 7/9:


  0%|          | 0/3 [00:00<?, ?it/s]

epoch=0, loss=inf


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=1, loss=2.43827486038208


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=2, loss=2.008925437927246


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

Accuracy: 0.5261146496815287

Probe number 8/9:


  0%|          | 0/3 [00:00<?, ?it/s]

epoch=0, loss=inf


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=1, loss=1.9566727876663208


  0%|          | 0/30 [00:00<?, ?it/s]

epoch=2, loss=1.9611494541168213


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

Accuracy: 0.5490445859872611
