In [26]:
import os

import torch
from torch import nn
import torch.distributed as dist
from torchvision import datasets
from torchvision import transforms as pth_transforms
from torchvision import models as torchvision_models

import utils
import vision_transformer as vits

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [27]:
class LinearClassifier(nn.Module):
    """Linear layer to train on top of frozen features.

    Args:
        dim: Number of input features per sample once the input is flattened.
        num_labels: Number of output classes, i.e. the width of the logits.

    The weight is drawn from a normal distribution (mean 0, std 0.01) and the
    bias starts at zero.
    """
    def __init__(self, dim, num_labels=1000):
        super().__init__()
        self.num_labels = num_labels
        self.linear = nn.Linear(dim, num_labels)
        self.linear.weight.data.normal_(mean=0.0, std=0.01)
        self.linear.bias.data.zero_()

    def forward(self, x):
        """Flatten every sample to a vector and return its logits.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are collapsed into one of size ``dim``.

        Returns:
            Tensor of shape ``(x.size(0), num_labels)``.
        """
        # reshape behaves like view for contiguous inputs but also accepts
        # non-contiguous ones (e.g. permuted tensors), where view would raise.
        x = x.reshape(x.size(0), -1)

        # linear layer
        return self.linear(x)

In [42]:
@torch.no_grad()
def validate_network(val_loader, model, linear_classifier, n, avgpool):
    """Evaluate ``linear_classifier`` on features produced by ``model``.

    Args:
        val_loader: Iterable yielding ``(inp, target)`` batches.
        model: Feature extractor. When it exposes ``get_intermediate_layers``
            the first token of each returned layer is concatenated into the
            feature vector; otherwise ``model(inp)`` is used directly.
        linear_classifier: Module with a ``num_labels`` attribute that maps the
            features to logits. It is switched to eval mode here.
        n: Forwarded to ``model.get_intermediate_layers`` as its second argument.
        avgpool: If true, the mean over tokens ``1:`` of the last returned
            layer is stacked next to the first-token features.

    Returns:
        Dict mapping every meter name of the metric logger to its ``global_avg``.
    """
    linear_classifier.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # Run on the device the backbone lives on instead of assuming CUDA, so the
    # function also works on CPU-only machines. A parameter-less model keeps
    # the previous behaviour (CUDA).
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cuda")

    # Decide the feature path from the model itself rather than from a
    # constant string comparison that was always true.
    use_intermediate = hasattr(model, "get_intermediate_layers")

    # Top-5 accuracy is only reported when there are at least five classes.
    report_top5 = linear_classifier.num_labels >= 5

    # Stateless, so one instance serves every batch.
    criterion = nn.CrossEntropyLoss()

    for inp, target in metric_logger.log_every(val_loader, 20, header):
        # move to the evaluation device
        inp = inp.to(target_device, non_blocking=True)
        target = target.to(target_device, non_blocking=True)

        # forward (gradients are already disabled by the decorator)
        if use_intermediate:
            intermediate_output = model.get_intermediate_layers(inp, n)
            output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
            if avgpool:
                # NOTE(review): stacking along a new trailing axis requires
                # `output` to be as wide as one token embedding, so this path
                # presumably only works with n == 1 -- TODO confirm against
                # the shapes returned by get_intermediate_layers.
                output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
                output = output.reshape(output.shape[0], -1)
        else:
            output = model(inp)
        output = linear_classifier(output)
        loss = criterion(output, target)

        if report_top5:
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        else:
            acc1, = utils.accuracy(output, target, topk=(1,))

        batch_size = inp.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        if report_top5:
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    if report_top5:
        print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
    else:
        print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, losses=metric_logger.loss))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

In [29]:
# Set Pretext and Linear Classifier weights here
# Both paths are relative, so they resolve against the kernel's working directory.

# Checkpoint of the pretext (backbone) training; loaded into the ViT backbone.
PRETRAINED_WEIGHTS = "models/checkpoint0200.pth"
# Checkpoint of the trained linear classifier; this is the file the
# classifier / optimizer / scheduler state is restored from further down.
LINEAR_WEIGHTS = "models/checkpoint.pth.tar"

In [31]:
# Build the 'vit_base' backbone with 16x16 patches.
# NOTE(review): num_classes=4 creates a classification head that is absent from
# the pretext checkpoint (reported as missing head.weight / head.bias when the
# weights are loaded); the evaluation reads intermediate layers, so the head
# looks unused -- confirm before relying on it.
model = vits.__dict__['vit_base'](patch_size=16, num_classes=4)

# Width of the feature vector fed to the linear classifier: one embedding per
# concatenated block, plus one more when average-pooled patch tokens are
# appended. These two values must match the `n` and `avgpool` arguments that
# are passed to validate_network.
n_last_blocks = 4
avgpool_patchtokens = False
embed_dim = model.embed_dim * (n_last_blocks + int(avgpool_patchtokens))

utils.load_pretrained_weights(model, PRETRAINED_WEIGHTS, 'teacher', 'vit_base', 16)
model.to(device)
# A freshly built nn.Module starts in training mode; the backbone is only used
# as a frozen feature extractor, so switch it to eval mode explicitly.
model.eval()

Take key teacher in provided checkpoint dict
Pretrained weights found at models/checkpoint0200.pth and loaded with msg: _IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate=none)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Atte

In [32]:
# Linear head on top of the frozen features; placed on the same device as the
# backbone so the notebook also runs on CPU-only machines.
linear_classifier = LinearClassifier(embed_dim, num_labels=4)
linear_classifier = linear_classifier.to(device)

lr = 0.001

# set optimizer
# The optimizer and scheduler are built so that their saved state can be
# restored from the checkpoint together with the classifier weights.
optimizer = torch.optim.SGD(
    linear_classifier.parameters(),
    lr * (256 * utils.get_world_size()) / 256., # linear scaling rule
    momentum=0.9,
    weight_decay=0, # we do not apply weight decay
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 100, eta_min=0)

# Restore classifier, optimizer and scheduler from the configured checkpoint.
# to_restore supplies the defaults for the run variables and is passed to
# restart_from_checkpoint as run_variables.
to_restore = {"epoch": 0, "best_acc": 0.}
utils.restart_from_checkpoint(
    LINEAR_WEIGHTS,
    run_variables=to_restore,
    state_dict=linear_classifier,
    optimizer=optimizer,
    scheduler=scheduler,
)

start_epoch = to_restore["epoch"]
best_acc = to_restore["best_acc"]

Found checkpoint at models/checkpoint.pth.tar
Value is dict_keys(['linear.weight', 'linear.bias'])
=> loaded 'state_dict' from checkpoint 'models/checkpoint.pth.tar' with msg <All keys matched successfully>
=> loaded 'optimizer' from checkpoint: 'models/checkpoint.pth.tar'
=> loaded 'scheduler' from checkpoint: 'models/checkpoint.pth.tar'


In [38]:
# ============ preparing data ... ============
# Evaluation preprocessing: resize the shorter side to 256 with bicubic
# interpolation, take the central 224x224 crop, convert to a tensor and
# normalise each channel with the given mean / std.
test_transform = pth_transforms.Compose([
    # The enum replaces the integer code 3, which torchvision deprecates.
    pth_transforms.Resize(256, interpolation=pth_transforms.InterpolationMode.BICUBIC),
    pth_transforms.CenterCrop(224),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# ImageFolder layout: one sub-directory per class below the "test" split.
test_dir = os.path.join("../data/ship_100train_500val_200test", "test")
dataset_test = datasets.ImageFolder(test_dir, transform=test_transform)
test_loader = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=256,
    num_workers=16,
    pin_memory=True,
)



In [43]:
# Evaluate on the test split using the first tokens of 4 intermediate layers
# and no average-pooled patch tokens.
test_stats = validate_network(test_loader, model, linear_classifier, 4, False)
# Report the size of the dataset that was actually evaluated; `dataset_val`
# is not defined anywhere in this notebook and fails on a fresh kernel.
print(f"Accuracy of the network on the {len(dataset_test)} test images: {test_stats['acc1']:.1f}%")

Test:  [0/4]  eta: 0:00:22  loss: 0.235036 (0.235036)  acc1: 88.281250 (88.281250)  time: 5.640758  data: 4.197881  max mem: 5518
Test:  [3/4]  eta: 0:00:02  loss: 0.153946 (0.185878)  acc1: 93.750000 (92.625000)  time: 2.172544  data: 1.049648  max mem: 5518
Test: Total time: 0:00:08 (2.227809 s / it)
* Acc@1 92.625 loss 0.186
Accuracy of the network on the 800 test images: 92.6%
