# Imports, Paths and Parameters

In [1]:
import os
from pathlib import Path
import getpass
import numpy as np
import time
import torch
from torch import nn
from tqdm import tqdm
import random
import sys

# Allow `src.*` imports when running from within the project directory:
# '.' covers the current working directory, '..' its parent.
# A plain loop (not a side-effect-only list comprehension), and a membership
# check so that re-running this cell does not append duplicate entries.
for _extra_path in ('.', '..'):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

# local
from src.helpers.helpers import get_random_indexes, get_random_classes
from src.model.dino_model import get_dino
from src.model.data import create_loader
from src.model.eval import validate_network

# Seed every RNG used here (Python stdlib, PyTorch, NumPy) with one value
# so that runs are reproducible.
SEED = 42
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(SEED)
del _seed_fn  # keep the notebook namespace tidy

# Dataset root. Defaults to the per-user scratch space on the cluster; set the
# DL_DATA_PATH environment variable to run against data stored elsewhere.
username = getpass.getuser()
DATA_PATH = Path(os.environ.get('DL_DATA_PATH')
                 or Path('/', 'cluster', 'scratch', username, 'dl_data'))

# DamageNet: label file and image directory
DN_PATH = DATA_PATH / 'damageNet'
DN_LABEL_PATH = DN_PATH / 'val_damagenet.txt'
DN_IMAGES_PATH = DN_PATH / 'images'

# Original images: label file and image directory
ORI_PATH = DATA_PATH / 'ori_data'
ORI_LABEL_PATH = ORI_PATH / 'correct_labels.txt'
ORI_IMAGES_PATH = ORI_PATH / 'images'

In [2]:
# Random subset of indexes; passed to create_loader for the original-image loader.
INDEX_SUBSET = get_random_indexes()

BATCH_SIZE = 16
# NOTE(review): not referenced in the visible cells; presumably the number of
# final ViT blocks whose outputs feed the linear head — confirm.
N_LAST_BLOCKS = 4

# Fall back to CPU so the notebook can still be executed without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Optional: uncomment to run the project's environment-report script; it is not needed for the results below.
#!python $HOME/deeplearning/setup/collect_env.py

# Import DINO
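Official repository: [facebookresearch/dino](https://github.com/facebookresearch/dino). `get_dino()` builds the `vit_small` backbone and loads the reference pretrained DINO weights and linear-classifier weights (see the cell output below).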

In [4]:
# Pretrained backbone plus its linear classification head (see log below).
(model, linear_classifier) = get_dino()

Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.
Since no pretrained weights have been provided, we load the reference pretrained DINO weights.
Model vit_small built.
We load the reference pretrained linear weights.


# Load data

In [5]:
# Original images; INDEX_SUBSET presumably restricts the loader to that random subset.
ori_loader = create_loader(
    ORI_IMAGES_PATH,
    ORI_LABEL_PATH,
    INDEX_SUBSET,
    BATCH_SIZE,
)

In [6]:
# DamageNet images with no index subset; the evaluation log below shows
# 3125 batches of BATCH_SIZE=16, i.e. 50,000 images.
dn_loader = create_loader(
    DN_IMAGES_PATH,
    DN_LABEL_PATH,
    None,
    BATCH_SIZE,
)

# Inference

In [7]:
# Evaluate the pretrained model + linear head on DamageNet.
# NOTE: despite its name, `pred` holds aggregate metrics ('loss', 'acc1',
# 'acc5' — see the next cell's output), not per-sample predictions.
pred = validate_network(
    dn_loader,
    model,
    linear_classifier,
)

Test:  [   0/3125]  eta: 0:09:20  loss: 4.988029 (4.988029)  acc1: 25.000000 (25.000000)  acc5: 56.250000 (56.250000)  time: 0.179343  data: 0.114727  max mem: 182
Test:  [  20/3125]  eta: 0:10:44  loss: 4.675150 (4.746682)  acc1: 25.000000 (23.214286)  acc5: 43.750000 (47.619048)  time: 0.208980  data: 0.171253  max mem: 201
Test:  [  40/3125]  eta: 0:10:32  loss: 4.596530 (4.649653)  acc1: 25.000000 (23.475610)  acc5: 50.000000 (47.560976)  time: 0.202114  data: 0.164337  max mem: 201
Test:  [  60/3125]  eta: 0:10:25  loss: 4.631100 (4.677447)  acc1: 18.750000 (22.540984)  acc5: 43.750000 (47.131148)  time: 0.201948  data: 0.164218  max mem: 201
Test:  [  80/3125]  eta: 0:10:20  loss: 4.518159 (4.630482)  acc1: 18.750000 (21.990741)  acc5: 50.000000 (48.225309)  time: 0.202663  data: 0.164859  max mem: 201
Test:  [ 100/3125]  eta: 0:10:15  loss: 4.399150 (4.618075)  acc1: 18.750000 (21.782178)  acc5: 43.750000 (48.143564)  time: 0.203042  data: 0.165346  max mem: 201
Test:  [ 120/312

In [8]:
# Overall DamageNet loss plus top-1 / top-5 accuracy in percent
# (the log's 'acc1: 25.0' on a batch of 16 corresponds to 4 correct).
pred

{'loss': 4.718800410919189, 'acc1': 21.284, 'acc5': 48.546}