In [7]:
# Relative path from the notebook's working directory to the directory that
# holds the `data`, `utils`, `models` and `results` folders used below.
PATH_PREFIX = '../../../'
import sys

# Make the project-local packages importable. The membership check keeps the
# cell idempotent: re-running it no longer appends a duplicate sys.path entry.
if PATH_PREFIX not in sys.path:
    sys.path.append(PATH_PREFIX)

In [8]:
import torch
import argparse
import os
import numpy as np

from data.imagenette import ImagenetteDataset
from data.utils.imagenet_utils import *
from utils.weight_sharing import *

In [9]:
# --- network / weight-sharing configuration ---
NET_REPO = 'pytorch/vision:v0.10.0'  # torch.hub repository the model is loaded from
NET_TYPE = 'mobilenet_v2'            # entry point requested from NET_REPO
DEVICE = 'cpu'                       # device argument handed to get_accuracy
BATCH_SIZE = 32                      # first argument of ImagenetteDataset
CLUST_ALG = 'minibatch-kmeans'       # passed as clust_alg to the WeightShare calls
PREC = 'f4'                          # precision tag used for every layer and in save-file names

# 53 values passed unchanged as the layer cluster specification
# (presumably one cluster count per model layer -- confirm against WeightShare).
LAYER_CLUSTERS = [
    98, 95, 77, 67, 115, 106, 55, 98, 110, 55,
    52, 44, 113, 61, 50, 19, 40, 107, 87, 10,
    60, 22, 95, 31, 12, 51, 37, 102, 45, 31,
    65, 115, 62, 13, 43, 112, 101, 62, 72, 59,
    76, 89, 29, 38, 41, 112, 23, 115, 44, 13,
    106, 79, 86,
]

# --- dataset location ---
DATA_PATH = os.path.join(PATH_PREFIX, 'data/imagenette')

In [10]:
# Dataset wrapper built from the configured batch size and data directory.
# NOTE(review): val_split=0.3 presumably reserves 30 % of the data for
# validation -- confirm in ImagenetteDataset.
dataset = ImagenetteDataset(BATCH_SIZE, DATA_PATH, val_split=0.3)

# Pretrained network obtained through torch.hub.
model = torch.hub.load(NET_REPO, NET_TYPE, pretrained=True)

# Only the test callback is supplied in this notebook; the other two stay unset.
lam_opt = lam_train = None


def lam_test():
    """Evaluate `model` on `dataset.test_dl` via get_accuracy with topk=1."""
    return get_accuracy(model, dataset.test_dl, DEVICE, topk=1)

Using cache found in /home/coupekv/.cache/torch/hub/pytorch_vision_v0.10.0


In [11]:
# Build the weight-sharing controller around `model`, handing it the
# lam_test / lam_opt / lam_train callables (lam_opt and lam_train are None
# in this notebook, so only the test callback is available to it).
ws_controller = WeightShare(model, lam_test, lam_opt, lam_train)

In [15]:
# NOTE(review): set_reset()/reset() presumably snapshot and later restore the
# model weights around the sharing step -- confirm in utils.weight_sharing.
ws_controller.set_reset()

# The same precision tag is applied to every layer the controller tracks.
per_layer_prec = [PREC for _layer in ws_controller.model_layers]
ws_controller.share(LAYER_CLUSTERS, prec_reduct=per_layer_prec, clust_alg=CLUST_ALG)

# Write the model's state dict as it stands right after share().
shared_ckpt = os.path.join(PATH_PREFIX, f'models/mobilenet_v2/saves/mobilenet_shared_{PREC}.save')
torch.save(model.state_dict(), shared_ckpt)

ws_controller.reset()

In [16]:
# Candidate focus values: 0 up to (but excluding) 10 in steps of 0.2.
focus_grid = np.arange(0, 10, 0.2)

# Per-layer settings: a constant spread of 2 and the shared precision tag.
per_layer_spread = [2 for _layer in ws_controller.model_layers]
per_layer_prec = [PREC for _layer in ws_controller.model_layers]

# Result CSV and the shared-model checkpoint handed to finetuned_mod.
results_csv = os.path.join(PATH_PREFIX, f'results/finetuning/mobilenet_v2_{PREC}.csv')
shared_ckpt = os.path.join(PATH_PREFIX, f'models/mobilenet_v2/saves/mobilenet_shared_{PREC}.save')

# Kept as the last expression so the returned per-layer list is displayed.
ws_controller.finetuned_mod(
    layer_clusters=LAYER_CLUSTERS,
    mods_focus=focus_grid,
    mods_spread=per_layer_spread,
    prec_reduct=per_layer_prec,
    savefile=results_csv,
    clust_alg=CLUST_ALG,
    verbose=True,
    shared_model_savefile=shared_ckpt,
)

Processing layer 0
Processing layer 1
Processing layer 2
Processing layer 3
Processing layer 4
Processing layer 5
Processing layer 6
Processing layer 7
Processing layer 8
Processing layer 9
Processing layer 10
Processing layer 11
Processing layer 12
Processing layer 13
Processing layer 14
Processing layer 15
Processing layer 16
Processing layer 17
Processing layer 18
Processing layer 19
Processing layer 20
Processing layer 21
Processing layer 22
Processing layer 23
Processing layer 24
Processing layer 25
Processing layer 26
Processing layer 27
Processing layer 28
Processing layer 29
Processing layer 30
Processing layer 31
Processing layer 32
Processing layer 33
Processing layer 34
Processing layer 35
Processing layer 36
Processing layer 37
Processing layer 38
Processing layer 39
Processing layer 40
Processing layer 41
Processing layer 42
Processing layer 43
Processing layer 44
Processing layer 45
Processing layer 46
Processing layer 47
Processing layer 48
Processing layer 49
Processing

[1.4000000000000001,
 0.0,
 0.8,
 6.2,
 1.6,
 0.0,
 1.0,
 0.0,
 6.2,
 8.0,
 2.0,
 6.6000000000000005,
 1.2000000000000002,
 0.0,
 9.8,
 0.0,
 0.0,
 9.8,
 7.800000000000001,
 0.0,
 5.6000000000000005,
 0.4,
 6.0,
 2.2,
 0.8,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 3.0,
 1.0,
 7.4,
 0.0,
 4.800000000000001,
 2.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 8.8,
 0.0,
 0.0,
 2.0,
 0.0,
 1.2000000000000002,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0]

In [17]:
# Store the model's current state dict under the precision-tagged
# "finetuned" checkpoint name.
finetuned_ckpt = os.path.join(PATH_PREFIX, f'models/mobilenet_v2/saves/mobilenet_finetuned_{PREC}.save')
torch.save(model.state_dict(), finetuned_ckpt)

In [12]:
# Restore the precision-tagged "finetuned" checkpoint into `model` and run the
# controller's test.
# map_location=DEVICE remaps the stored tensors onto the configured device, so
# a checkpoint written on another device (e.g. CUDA) still loads on this one.
finetuned_ckpt = os.path.join(PATH_PREFIX, f'models/mobilenet_v2/saves/mobilenet_finetuned_{PREC}.save')
model.load_state_dict(torch.load(finetuned_ckpt, map_location=DEVICE))
ws_controller.test()
# 0.8205968141555786
# NOTE(review): the value above is a hand-recorded result that differs from the
# output stored with this cell -- confirm which run it belongs to.

0.8391557335853577