In [1]:
import torch
import os
import random
from models import *
from torch.optim import Adam
from dataformatter import * 
import torch.optim as optim
from vanilla_kd import VanillaKD
import copy

### Utilities

In [2]:
def set_random_seed(seed):
	# Esp important for ensuring deterministic behavior with CNNs
	torch.backends.cudnn.deterministic = True
	np.random.seed(seed)
	random.seed(seed)
	torch.manual_seed(seed)
	cuda_available = torch.cuda.is_available()
	if cuda_available:
		torch.cuda.manual_seed_all(seed)
	return cuda_available

#### Hyper Params

In [3]:
N_EPOCHS = 5
DATA_PATH = '../lru_trace_belady_bzip.csv' # This is the CSV FILE WE ARE TRYING TO ANALYZE
TR_DESC = 'GRAPH'
EVAL_DATA_PATH = '../lru_trace_belady_bzip.csv' # This is the CSV FILE WE ARE TRYING TO ANALYZE
EVAL_DESC = 'GRAPH'

SAVE_FLDR = 'results'
N_EPOCHS = 5
MAX_GRAD_NORM = 0.1
SET_WISE = True
RANDOM_SEED = 140982301
join=False
student_model_initialization="SFC"

#### Set location of the teacher model

In [4]:
teacher_model_file="HWKY_ONLY_BZIP-TRANSFORMER_BSZ.32_LR.0.0001_saved_model.pth"
student_model_file="STUDENT_SFC_HWKY_ONLY_BZIP-TRANSFORMER_BSZ.32_LR.0.0001_saved_model.pth"
if 'JOINT' in teacher_model_file:
    join=True
else:
    join=False

In [5]:
join

False

In [6]:
data_type,remaining=teacher_model_file.split('-')
MODEL_TYPE=''.join(remaining.split("_")[0])
MODEL_DESC='{} MODEL'.format(MODEL_TYPE)
BATCH_SZ=int(remaining.split("_")[1][4:])
LR=float(teacher_model_file.split("_")[3][3:])
print(MODEL_TYPE,MODEL_DESC,BATCH_SZ,LR)

TRANSFORMER TRANSFORMER MODEL 32 0.32


#### Setting up the teacher

In [7]:
set_random_seed(RANDOM_SEED)
teacher_model = get_model(MODEL_TYPE)

In [8]:
chosen_columns = teacher_model.get_data_columns()
chosen_columns.append('Set')

train_dataset = csv_to_data(DATA_PATH, chosen_columns)
eval_dataset = csv_to_data(EVAL_DATA_PATH, chosen_columns)

average_pred = np.mean(train_dataset[:, -1])
if torch.cuda.is_available():
    teacher_model.cuda()
teacher_model.use_cuda=torch.cuda.is_available()


# teacher_model.load_state_dict(torch.load('../teacher_FC.model'))

In [9]:
# Always SetWise
train_setwise_dataset = group_by_set(train_dataset, set_idx=3)
eval_setwise_dataset = group_by_set(eval_dataset, set_idx=3)

if join:
    MAX_TR_KEY = max([int(x) for x in train_setwise_dataset.keys()]) + 1
    new_eval_set = {}
    for k, v in eval_setwise_dataset.items():
        new_id = MAX_TR_KEY + int(k)
        train_setwise_dataset[new_id] = v
        new_eval_set[new_id] = v
#     print('Total keys After : ', len(train_setwise_dataset.keys()))
    eval_setwise_dataset = new_eval_set
    for k, v in train_setwise_dataset.items():
        if v.shape[0] < 2:
            print(k)
res = [x.shape[0] for _, x in train_setwise_dataset.items()]
# print(min(res), np.mean(res), np.median(res), max(res))
all_tr_keys = list(train_setwise_dataset.keys())
val_keys = np.random.choice(all_tr_keys, size=int(0.2 * len(all_tr_keys)))
tr_keys = set(all_tr_keys) - set(val_keys)

vals = [len(x) for x in list(train_setwise_dataset.values())]
max_key = all_tr_keys[np.argmax(vals)]

teacher_model.prep_for_data(train_setwise_dataset[max_key], temp_order=True)

for set_id, this_dataset in train_setwise_dataset.items():
    teacher_model.remap_embedders(this_dataset, set_id)

tr_val_setwise_dataset = {k: train_setwise_dataset[k] for k in val_keys}
train_setwise_dataset = {k: train_setwise_dataset[k] for k in tr_keys}


In [10]:
teacher_model.load_state_dict(torch.load('../'+teacher_model_file))

<All keys matched successfully>

In [11]:
if not os.path.exists(SAVE_FLDR):
    os.makedirs(SAVE_FLDR)

student_model = get_model(student_model_initialization)
if torch.cuda.is_available():
    student_model.cuda()
student_model.use_cuda = torch.cuda.is_available()

# chosen_columns = student_model.get_data_columns()

all_tr_keys = list(train_setwise_dataset.keys())
vals = [len(x) for x in list(train_setwise_dataset.values())]
max_key = all_tr_keys[np.argmax(vals)]

student_model.prep_for_data(train_setwise_dataset[max_key], temp_order=True)

student_model.pc_emb_map=copy.deepcopy(teacher_model.pc_emb_map) # clone
student_model.set_occ_emb_map= copy.deepcopy(teacher_model.set_occ_emb_map) # clone
student_model.setid_to_map_map=copy.deepcopy(teacher_model.setid_to_map_map) #clone

student_model.pc_embedding=nn.Embedding.from_pretrained(teacher_model.pc_embedding.weight.clone().detach(),freeze=True) 
student_model.set_occ_embedding=nn.Embedding.from_pretrained(teacher_model.set_occ_embedding.weight.clone().detach(),freeze=True)# clone and detach

for set_id, this_dataset in train_setwise_dataset.items():
    student_model.remap_embedders(this_dataset, set_id)



# print(student_model)

In [12]:
student_model.load_state_dict(torch.load('../student_models/'+student_model_file))

<All keys matched successfully>

In [13]:
student_optimizer = Adam(student_model.parameters(), lr=1e-4)
teacher_optimizer = Adam(teacher_model.parameters(), lr=LR)

In [14]:
distiller = VanillaKD(teacher_model, student_model,teacher_optimizer, student_optimizer,
                      train_setwise_dataset,tr_val_setwise_dataset,batch_size=BATCH_SZ,shuffle=False,student_type=student_model_initialization,
                      distil_weight=0.5,temp=20.0)

In [15]:
distiller.evaluate_student()

Evaluating Student...
Student Average Accuracy:0.8500331338762578 Student Median Accuracy:0.8999999761581421 |Student Loss:2.392160177230835


In [16]:
distiller.evaluate_teacher()

Evaluating Teacher...
Teacher Average Accuracy:0.9861546872399789 Teacher Median Accuracy:1.0 |Teacher Loss:5.116590182296932e-05


In [181]:
SAVE_FLDR="../student_models"
torch.save(distiller.student_model.state_dict(), 
           os.path.join(SAVE_FLDR,'STUDENT_'+student_model_initialization+'_'+teacher_model_file))