In [4]:
import torch
import os
import argparse

import mixed_precision
from stats import AverageMeterSet
from datasets import Dataset, build_dataset, get_dataset, get_encoder_size
from model import Model
from checkpoint import Checkpointer
from utils import test_model
import matplotlib.pyplot as plt

In [12]:
# Checkpoint evaluated by this notebook; edit to point at your own training run.
CHECKPOINT_PATH = "/root/amdim-public/runs/amdim_cpt.pth"

parser = argparse.ArgumentParser(description='Infomax Representations - Testing Script')
# evaluation parameters
parser.add_argument('checkpoint_path', type=str,
                    help='path from which to load checkpoint')
parser.add_argument('--dataset', type=str, default='STL10',
                    help='dataset to evaluate on, e.g. C10, C100 or STL10 (default: STL10)')
parser.add_argument('--batch_size', type=int, default=200,
                    help='input batch size (default: 200)')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--amp', action='store_true', default=False,
                    help='Enables automatic mixed precision')
parser.add_argument('--input_dir', type=str, default='/mnt/imagenet',
                    help="Input directory for the dataset. Not needed for C10,"
                    " C100 or STL10 as the data will be automatically downloaded.")
parser.add_argument('--run_name', type=str, default='default_run',
                    help='name to use for the tensorboard summary for this run')
# Parse an explicit argument list rather than sys.argv: inside a Jupyter kernel
# sys.argv carries the kernel's own launch flags, which this parser would reject.
args = parser.parse_args(args=["--amp", "--dataset", "C10", CHECKPOINT_PATH])


In [13]:
if args.amp:
    # Switch mixed precision on before the model is restored in a later cell.
    mixed_precision.enable_mixed_precision()

# One seed value drives both the CPU and the CUDA random number generators.
seed = args.seed
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Resolve the dataset by name, then build the loaders. Only the second value
# returned by build_dataset (the test loader) is kept; the other two are ignored.
dataset = get_dataset(args.dataset)
_, test_loader, _ = build_dataset(
    dataset=dataset,
    batch_size=args.batch_size,
    input_dir=args.input_dir,
)

# Evaluation runs on the GPU; the checkpointer is used below to restore the model.
torch_device = torch.device('cuda')
checkpointer = Checkpointer()

Files already downloaded and verified
Files already downloaded and verified


In [14]:
# Restore the trained model from the checkpoint given on the (simulated) command line.
model = checkpointer.restore_model_from_checkpoint(args.checkpoint_path)
# Move it to the evaluation device before mixed_precision.initialize wraps it.
model = model.to(torch_device)
# The second argument/return slot is presumably the optimizer (the log below is
# apex's amp O2 output); evaluation needs none, so None goes in and the result is dropped.
model, _ = mixed_precision.initialize(model, None)
# Note: the AverageMeterSet for test metrics is created in the evaluation cell
# itself, so no meter state needs to be set up here.

Using a 32x32 encoder
***** CHECKPOINTING *****
Model restored from checkpoint.
Self-supervised training epoch 115
Classifier training epoch 0
*************************
Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.

Defaults for this optimization level are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic


In [15]:
def test(model, test_loader, device, stats):
    """Evaluate ``model`` on ``test_loader`` via the project's ``utils.test_model``.

    A thin notebook-level alias: the four arguments are forwarded positionally and
    unchanged. Whatever ``test_model`` returns is discarded, so this function
    returns ``None``. Results are expected to be accumulated into ``stats`` (an
    ``AverageMeterSet`` in this notebook), which the caller reads afterwards.
    """
    forwarded_args = (model, test_loader, device, stats)
    test_model(*forwarded_args)

In [16]:
# A fresh meter set, so the report reflects only this evaluation pass.
test_stats = AverageMeterSet()
test(model=model, test_loader=test_loader, device=torch_device, stats=test_stats)

# Render the collected meters. model.tasks is passed as `ignore`, presumably to
# hide the self-supervised task meters so only classifier accuracies remain.
stat_str = test_stats.pretty_string(ignore=model.tasks)
print(stat_str)

test_accuracy_mlp_classifier: 0.896, test_accuracy_linear_classifier: 0.867


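# The cells below replay, step by step, what utils.test_model does internally: a
# train-mode forward pass over 50 test batches (outputs discarded, presumably to
# refresh BatchNorm running statistics), then an eval-mode forward pass on one batch.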

In [18]:
# Warm-up pass: forward a fixed number of test batches with the model in train
# mode and discard the outputs. The only lasting effect is on train-mode layer
# state (e.g. BatchNorm running mean/var, assuming the model contains such layers).
batches = 50  # number of batches used for the warm-up
model.train()
# The outputs are thrown away, so building autograd graphs is pure overhead.
# Train-mode running-statistic updates still happen under no_grad.
# NOTE(review): the statistics are refreshed on *test* data, so test-set
# statistics influence the evaluated model; use training data if that matters.
with torch.no_grad():
    for i, (images, _) in enumerate(test_loader):
        if i == batches:
            break
        images = images.to(torch_device)
        model(x1=images, x2=images, class_only=True)

In [29]:
# Evaluate just the first test batch in eval mode and keep its tensors and logits
# around for inspection in the following cells. next(iter(...)) states "first
# batch only" directly and raises StopIteration if the loader is empty, instead of
# silently leaving `images` undefined.
model.eval()
images, labels = next(iter(test_loader))
print(images.shape)
images = images.to(torch_device)
labels = labels.cpu()
with torch.no_grad():
    res_dict = model(x1=images, x2=images, class_only=True)
    # Two logit tensors; per the variable names, the MLP classifier's first and
    # the linear classifier's second.
    lgt_glb_mlp, lgt_glb_lin = res_dict['class']

torch.Size([200, 3, 32, 32])


In [34]:
# Encode the current batch (no_grad=True, presumably disabling gradient tracking)
# and display the shape of the first tensor the encoder returns.
encoded = model.encode(images, no_grad=True)
encoded[0].shape

torch.Size([200, 1024, 1, 1])