In [1]:
import os
import random
import re

import numpy as np  # explicit; previously only available through the star imports below
import torch
import torch.optim as optim
from torch.optim import Adam

from models import *
from dataformatter import *
from vanilla_kd import VanillaKD

### Utilities

In [2]:
def set_random_seed(seed):
	"""Seed Python, NumPy and PyTorch RNGs and put cuDNN into deterministic mode.

	Args:
		seed: integer seed applied to `random`, `np.random` and torch (CPU and, if
			present, every CUDA device).

	Returns:
		bool: True if CUDA is available, i.e. the CUDA generators were seeded too.
	"""
	# Deterministic kernels matter most for CNNs. cuDNN autotuning (benchmark mode)
	# may choose different algorithms between runs, so it has to be disabled as well.
	torch.backends.cudnn.deterministic = True
	torch.backends.cudnn.benchmark = False
	np.random.seed(seed)
	random.seed(seed)
	torch.manual_seed(seed)
	cuda_available = torch.cuda.is_available()
	if cuda_available:
		torch.cuda.manual_seed_all(seed)
	return cuda_available

#### Hyperparameters

In [3]:
# Training and evaluation traces (CSV) plus short labels describing them.
DATA_PATH = 'hawkeye_trace_belady_graph.csv'  # trace the models are trained on
TR_DESC = 'GRAPH'
EVAL_DATA_PATH = 'hawkeye_trace_belady_xalancbmk.csv'  # separate trace loaded as eval_dataset
EVAL_DESC = 'XALANBMK'  # NOTE(review): EVAL_DATA_PATH spells the benchmark "xalancbmk" — confirm label

SAVE_FLDR = 'results'  # results folder; created by the student-setup cell if missing
N_EPOCHS = 5  # NOTE(review): the training cell passes a literal epoch count instead
MAX_GRAD_NORM = 0.1  # NOTE(review): not referenced by any cell in this notebook
SET_WISE = True  # group rows with group_by_set and split train/validation by set id
RANDOM_SEED = 140982301

#### Teacher checkpoint file

In [4]:
# Pretrained teacher checkpoint, loaded from the results folder further down. The name
# follows <model type>_BSZ.<batch size>_LR.<learning rate>_saved_model.pth, and the
# next cell parses those fields back out of it.
teacher_model_file = (
    "TRANSFORMER_1_BSZ.32_LR.0.0001_saved_model.pth"
)

In [5]:
# Recover the teacher's training configuration from its checkpoint filename:
#   <MODEL_TYPE>_BSZ.<batch size>_LR.<learning rate>_saved_model.pth
# Matching on the BSZ./LR. markers (instead of fixed split("_") positions) keeps this
# working for model types with any number of underscores, e.g. "SFC" or "TRANSFORMER_1".
ckpt_match = re.match(r'^(?P<model>.+?)_BSZ\.(?P<bsz>\d+)_LR\.(?P<lr>[^_]+)_', teacher_model_file)
if ckpt_match is None:
    raise ValueError('unexpected checkpoint filename: {}'.format(teacher_model_file))
MODEL_TYPE = ckpt_match.group('model')
MODEL_DESC = '{} MODEL'.format(MODEL_TYPE)
BATCH_SZ = int(ckpt_match.group('bsz'))
LR = float(ckpt_match.group('lr'))
print(MODEL_TYPE, BATCH_SZ, LR)

TRANSFORMER_1 32 0.0001


#### Setting up the teacher

In [6]:
# Seed every RNG before constructing the teacher, so any random initialisation inside
# get_model is reproducible. The returned CUDA flag is kept only for inspection.
cuda_available = set_random_seed(RANDOM_SEED)
teacher_model = get_model(MODEL_TYPE)

In [7]:
# Load the training and evaluation traces, restricted to the columns the teacher consumes.
chosen_columns = teacher_model.get_data_columns()
train_dataset, eval_dataset = [
    csv_to_data(csv_path, chosen_columns) for csv_path in (DATA_PATH, EVAL_DATA_PATH)
]

# Mean of the last column of the training data (presumably the label column — confirm).
average_pred = np.mean(train_dataset[:, -1])

# NOTE(review): this runs before prep_for_data / remap_embedders in the next cell; if
# those create new submodules, this call does not move them to the GPU.
if torch.cuda.is_available():
    teacher_model.cuda()

# Legacy checkpoint path, superseded by the load from the results folder below:
# teacher_model.load_state_dict(torch.load('../teacher_FC.model'))

In [8]:
if SET_WISE:
    train_setwise_dataset = group_by_set(train_dataset)
    eval_setwise_dataset = group_by_set(eval_dataset)
    all_tr_keys = list(train_setwise_dataset.keys())

    # Hold out ~10% of the training sets for validation. Sampling distinct indices
    # (replace=False) guarantees no duplicate validation keys, so the held-out share
    # really is 10%. Indexing back into all_tr_keys keeps the original key objects
    # rather than numpy-converted copies.
    n_val = int(0.1 * len(all_tr_keys))
    val_idx = np.random.choice(len(all_tr_keys), size=n_val, replace=False)
    val_keys = [all_tr_keys[i] for i in val_idx]
    val_key_set = set(val_keys)
    # A list comprehension (not a set difference) keeps the training keys in the
    # original deterministic order, independent of hash randomization.
    tr_keys = [k for k in all_tr_keys if k not in val_key_set]

    # Prepare the teacher's input pipeline on the largest set (first one on ties).
    max_key = max(all_tr_keys, key=lambda k: len(train_setwise_dataset[k]))
    teacher_model.prep_for_data(train_setwise_dataset[max_key], temp_order=True)

    for set_id, this_dataset in train_setwise_dataset.items():
        teacher_model.remap_embedders(this_dataset, set_id)

    tr_val_setwise_dataset = {k: train_setwise_dataset[k] for k in val_keys}
    train_setwise_dataset = {k: train_setwise_dataset[k] for k in tr_keys}
else:
    teacher_model.prep_for_data(train_dataset, temp_order=True)

In [9]:
# Load the pretrained teacher weights. map_location='cpu' lets a checkpoint saved from a
# GPU load on a CPU-only machine. load_state_dict copies the values into the existing
# parameters, so their device is unaffected.
teacher_load_result = teacher_model.load_state_dict(
    torch.load(os.path.join(SAVE_FLDR, teacher_model_file), map_location='cpu')
)

# The earlier .cuda() call ran before prep_for_data / remap_embedders. Move the model
# again now that its final set of submodules exists, mirroring the student setup.
# NOTE(review): plausible cause of the "must be on the current device" error raised
# during distillation — confirm.
if torch.cuda.is_available():
    teacher_model.cuda()
teacher_model.use_cuda = torch.cuda.is_available()
teacher_load_result

<All keys matched successfully>

In [10]:
# Disabled helper kept for reference. `optimizer` is not defined in any cell of this
# notebook, so it must be passed in (or replaced) before this can be re-enabled.
# def evaluate(model, dataset, epoch_=-1, print_res=True, shuffle=False):
#     # get a data iterator for this epoch
#     model.eval()
#     data_iter = get_batch_iterator(dataset, BATCH_SZ, shuffle=shuffle)
#     epoch_stats = run_epoch(model, optimizer, data_iter, mode='test')
#     if print_res:
#         print('Epoch {} : Avg Loss = {} Avg Acc = {}'.format(epoch_, epoch_stats[0], epoch_stats[1]))
#     return epoch_stats

In [11]:
os.makedirs(SAVE_FLDR, exist_ok=True)

# The student reuses the set-wise train split built in the data-prep cell.
assert SET_WISE, 'student setup requires SET_WISE=True (it uses train_setwise_dataset)'

student_model = get_model('SFC')

# Prepare the student's input pipeline on the largest training set. The key must come
# from the same dict we index. After the train/validation split, train_setwise_dataset
# no longer holds every key in all_tr_keys. An argmax over its values, used as an index
# into all_tr_keys, therefore picks the wrong set, or a validation-only key (KeyError).
max_key = max(train_setwise_dataset, key=lambda k: len(train_setwise_dataset[k]))
student_model.prep_for_data(train_setwise_dataset[max_key], temp_order=True)

for set_id, this_dataset in train_setwise_dataset.items():
    student_model.remap_embedders(this_dataset, set_id)

# Move to the GPU only after prep/remap, so any modules they may create are moved too.
if torch.cuda.is_available():
    student_model.cuda()
student_model.use_cuda = torch.cuda.is_available()
print(student_model)

MLP(
  (loss_fn): CrossEntropyLoss()
  (model): Sequential(
    (0): Dropout(p=0.2, inplace=False)
    (1): Linear(in_features=192, out_features=300, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=300, out_features=2, bias=True)
  )
  (pc_embedding): Embedding(500, 96, padding_idx=0)
  (set_occ_embedding): Embedding(500, 96, padding_idx=0)
)


In [12]:
# Student learning rate is fixed here. The teacher reuses the LR parsed from its
# checkpoint filename.
student_optimizer = Adam(student_model.parameters(), 0.01)
# Each optimizer must own its own model's parameters. Previously the teacher optimizer
# was built on student_model.parameters(), so it could never update the teacher.
teacher_optimizer = Adam(teacher_model.parameters(), lr=LR)

In [13]:
# Hand teacher, student, their optimizers, the set-wise training data and the batch
# size to the vanilla knowledge-distillation driver.
# NOTE(review): the meaning of the trailing positional False is not visible here —
# check VanillaKD's signature.
distiller = VanillaKD(
    teacher_model,
    student_model,
    teacher_optimizer,
    student_optimizer,
    train_setwise_dataset,
    BATCH_SZ,
    False,
)

In [14]:
# NOTE(review): the student is trained with a hard-coded 10 here, while N_EPOCHS from
# the config cell (5) is not used by any cell — confirm which value is intended.
student_epochs = 10
distiller.train_student(student_epochs)

Training Student...


RuntimeError: Input, output and indices must be on the current device