In [15]:
import torch
import os
import random
from models import *
from torch.optim import Adam
from dataformatter import *
import torch.optim as optim
from vanilla_kd import VanillaKD
import copy

### Utilities

In [16]:
def set_random_seed(seed):
    """Seed every RNG used in this notebook and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and,
    when available, every CUDA device).

    Args:
        seed: integer seed applied to all generators.

    Returns:
        bool: ``torch.cuda.is_available()``, so callers can branch on GPU use.
    """
    import numpy as np  # explicit; do not rely on `np` leaking from star imports

    # Esp. important for CNNs: deterministic kernels alone are not enough, because the
    # cuDNN benchmark auto-tuner may still pick different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.manual_seed_all(seed)
    return cuda_available

#### Hyperparameters

In [17]:
# NOTE(review): N_EPOCHS is not referenced by any visible cell; the training call
# passes a hard-coded epoch count instead. (It was also previously assigned twice.)
N_EPOCHS = 5

# Training trace: loaded as train_dataset and used to build the set-wise training data.
DATA_PATH = '../hawkeye_trace_belady_graph.csv'
TR_DESC = 'GRAPH'
# Evaluation trace from a different benchmark: loaded as eval_dataset.
EVAL_DATA_PATH = '../hawkeye_trace_belady_xalancbmk.csv'
# NOTE(review): spelled 'XALANBMK' while the file name says 'xalancbmk' — confirm
# before using this label in result file names.
EVAL_DESC = 'XALANBMK'

SAVE_FLDR = 'results'
# NOTE(review): not passed to VanillaKD or used anywhere visible — confirm it is needed.
MAX_GRAD_NORM = 0.1
# When True, data is grouped per cache set (group_by_set) and embedders are remapped per set.
SET_WISE = True
RANDOM_SEED = 140982301

#### Teacher model checkpoint file

In [18]:
# Checkpoint name encodes <TR_DESC>-<MODEL_TYPE>_BSZ.<batch size>_LR.<learning rate>;
# the following cell parses those fields back out of it.
teacher_model_file = (
    "GRAPH-TRANSFORMER"
    "_BSZ.64"
    "_LR.0.0001"
    "_saved_model.pth"
)

In [19]:
# Decode the hyper-parameters embedded in the checkpoint filename:
#   <TR_DESC>-<MODEL_TYPE>_BSZ.<batch size>_LR.<learning rate>_...
_ckpt_fields = teacher_model_file.split("_")
MODEL_TYPE = _ckpt_fields[0].split("-")[1]
MODEL_DESC = MODEL_TYPE + " MODEL"
BATCH_SZ = int(_ckpt_fields[1].split(".")[1])
LR = float(_ckpt_fields[2][len("LR."):])
del _ckpt_fields  # keep the notebook namespace unchanged
print(MODEL_TYPE, BATCH_SZ, LR)

TRANSFORMER 64 0.0001


#### Setting up the teacher

In [20]:
# Seed every RNG before constructing the teacher so model construction is reproducible.
# The weights are replaced later by the checkpoint loaded with load_state_dict.
# get_model comes from the star imports; presumably it maps MODEL_TYPE to an architecture — confirm.
set_random_seed(RANDOM_SEED)
teacher_model = get_model(MODEL_TYPE)

In [21]:
# Project both traces onto the columns the teacher architecture asks for.
chosen_columns = teacher_model.get_data_columns()
train_dataset, eval_dataset = [
    csv_to_data(csv_path, chosen_columns)
    for csv_path in (DATA_PATH, EVAL_DATA_PATH)
]

# Mean of the last training column (presumably the label) — TODO confirm where it is used.
average_pred = np.mean(train_dataset[:, -1])

if torch.cuda.is_available():
    teacher_model.cuda()

In [22]:
# Set the CUDA flag before any prep_for_data call. Previously only the set-wise
# branch set it, so the non-set-wise path ran with whatever default the model had.
teacher_model.use_cuda = torch.cuda.is_available()

if SET_WISE:
    # Split the flat traces into per-set datasets (dicts keyed by set id).
    train_setwise_dataset = group_by_set(train_dataset)
    eval_setwise_dataset = group_by_set(eval_dataset)
    all_tr_keys = list(train_setwise_dataset.keys())

    # Hold out ~10% of the training sets for validation.
    # - Sample WITHOUT replacement: NumPy's default (replace=True) can draw the same
    #   set twice, which silently shrinks the held-out split.
    # - Sample positions rather than keys, so val_keys keeps the original key objects
    #   (this also works for non-scalar set ids).
    n_val_sets = int(0.1 * len(all_tr_keys))
    val_positions = np.random.choice(len(all_tr_keys), size=n_val_sets, replace=False)
    val_keys = [all_tr_keys[i] for i in val_positions]
    tr_keys = set(all_tr_keys) - set(val_keys)

    # prep_for_data is given the largest set. On ties the first key wins, as with np.argmax.
    # Presumably this sizes the model's inputs for the worst case — confirm in models.py.
    max_key = max(all_tr_keys, key=lambda k: len(train_setwise_dataset[k]))
    teacher_model.prep_for_data(train_setwise_dataset[max_key], temp_order=True)

    for set_id, this_dataset in train_setwise_dataset.items():
        teacher_model.remap_embedders(this_dataset, set_id)

    # TODO: the train/validation split above is computed but not applied yet:
#     tr_val_setwise_dataset = {k: train_setwise_dataset[k] for k in val_keys}
#     train_setwise_dataset = {k: train_setwise_dataset[k] for k in tr_keys}
else:
    teacher_model.prep_for_data(train_dataset, temp_order=True)

In [23]:
# Load the trained teacher weights.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine; with a GPU
# present (None) tensors keep their saved device, exactly as before.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
teacher_model.load_state_dict(
    torch.load(
        os.path.join('./pytorch_c++', teacher_model_file),
        map_location=None if torch.cuda.is_available() else torch.device('cpu'),
    )
)

<All keys matched successfully>

In [24]:
# Student architecture key. It is passed to get_model() and to VanillaKD as student_type.
# 'SFC' yields the small MLP shown in the student-setup cell's output.
student_model_initialization = "SFC"

In [25]:
# Output folder for saved models. exist_ok avoids the check-then-create race.
os.makedirs(SAVE_FLDR, exist_ok=True)

student_model = get_model(student_model_initialization)

print("TC",torch.cuda.is_available())
student_model.use_cuda = torch.cuda.is_available()

# Give the student the teacher's input vocabularies. These are deep copies, so any
# later remapping of the student's maps cannot mutate the teacher's.
student_model.pc_emb_map = copy.deepcopy(teacher_model.pc_emb_map)
student_model.set_occ_emb_map = copy.deepcopy(teacher_model.set_occ_emb_map)
student_model.setid_to_map_map = copy.deepcopy(teacher_model.setid_to_map_map)

# Reuse the teacher's trained embedding tables as frozen student embeddings.
# detach() before clone() so the copy is never recorded by autograd.
student_model.pc_embedding = nn.Embedding.from_pretrained(
    teacher_model.pc_embedding.weight.detach().clone(), freeze=True
)
student_model.set_occ_embedding = nn.Embedding.from_pretrained(
    teacher_model.set_occ_embedding.weight.detach().clone(), freeze=True
)

# Mirror the teacher setup: per-set remapping only applies in set-wise mode.
# The old unguarded loop raised NameError, or silently used stale kernel state, when SET_WISE was False.
if SET_WISE:
    for set_id, this_dataset in train_setwise_dataset.items():
        student_model.remap_embedders(this_dataset, set_id)

if torch.cuda.is_available():
    student_model.cuda()

print(student_model)

TC True
MLP(
  (loss_fn): CrossEntropyLoss()
  (model): Sequential(
    (0): Dropout(p=0.2, inplace=False)
    (1): Linear(in_features=128, out_features=256, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=256, out_features=2, bias=True)
  )
  (pc_embedding): Embedding(500, 64)
  (set_occ_embedding): Embedding(500, 64)
)


In [26]:
# The student learning rate is chosen independently of the teacher's. Keep it as a
# named constant rather than a bare literal so it is easy to find and tune.
STUDENT_LR = 1e-4
student_optimizer = Adam(student_model.parameters(), lr=STUDENT_LR)
# The teacher optimizer uses LR parsed from the checkpoint filename; VanillaKD takes it
# as a constructor argument.
teacher_optimizer = Adam(teacher_model.parameters(), lr=LR)

In [27]:
# Build the knowledge-distillation driver. Positional order must match VanillaKD's signature.
# NOTE(review): the meaning of the bare positional `False` is not visible here — pass it
# by keyword once confirmed against vanilla_kd.py.
# distil_weight / temp are presumably the distillation-loss weight and softmax temperature.
distiller = VanillaKD(
    teacher_model,
    student_model,
    teacher_optimizer,
    student_optimizer,
    train_setwise_dataset,
    BATCH_SZ,
    False,
    student_type=student_model_initialization,
    distil_weight=0.5,
    temp=20.0,
)

In [14]:
# Run distillation. The argument appears to be the epoch count, judging by the "Epoch: N" log lines.
# NOTE(review): this hard-codes 10, while the config cell defines N_EPOCHS = 5 and
# nothing visible uses it — confirm which value is intended.
distiller.train_student(10)

Training Student...
Epoch: 0 | Student Average Accuracy:0.9024984679031053 Student Median Accuracy:1.0 |Distillation Loss: 0.45778968930244446
Epoch: 0 | Teacher Average Accuracy:0.9430225392364279 Teacher Median Accuracy:1.0 
Epoch: 1 | Student Average Accuracy:0.9430225392364279 Student Median Accuracy:1.0 |Distillation Loss: 0.033493876457214355
Epoch: 1 | Teacher Average Accuracy:0.9430225392364279 Teacher Median Accuracy:1.0 


KeyboardInterrupt: 

In [None]:
# Disabled: would save the distilled student's state_dict to
# <SAVE_FLDR>/STUDENT_<student_model_initialization>_saved_model.pth.
# Uncomment once training has completed (the run above was interrupted).
# SAVE_FLDR="results"
# torch.save(distiller.student_model.state_dict(), '{}/{}_saved_model.pth'.format(SAVE_FLDR,'STUDENT_'+student_model_initialization))