Improve FLOP budget and add double-skip ResNet18 #10

Open · wants to merge 14 commits into master
finetuning.py (31 changes: 24 additions & 7 deletions)
@@ -30,6 +30,8 @@
 ap.add_argument('--test_only', '-t', type=bool, default=False, help='test the best model')
 ap.add_argument('--workers', default=0, type=int, help='number of workers')
 ap.add_argument('--cuda_id', '-id', type=str, default='0', help='gpu number')
+ap.add_argument('--label_smoothing', '-ls', type=float, default=0, help='set label smoothing')
+
 args = ap.parse_args()

 valid_size=args.valid_size
@@ -58,10 +60,22 @@
 state = torch.load(model_path)['state_dict']
 model.load_state_dict(state, strict=False)
 CE = nn.CrossEntropyLoss()
-def criterion(model, y_pred, y_true):
+def criterion_test(model, y_pred, y_true):
     ce_loss = CE(y_pred, y_true)
     return ce_loss
+
+if args.label_smoothing>0:
+    CE_smooth = CrossEntropyLabelSmooth(data_object.num_classes, args.label_smoothing)
+    def criterion_train(model, y_pred, y_true):
+        ce_loss = CE_smooth(y_pred, y_true)
+        return ce_loss
+else:
+    def criterion_train(model, y_pred, y_true):
+        ce_loss = CE(y_pred, y_true)
+        return ce_loss
+
+
 
 optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.decay)
 device = torch.device(f"cuda:{str(args.cuda_id)}")
 model.to(device)
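Note: `CrossEntropyLabelSmooth` is used here but defined outside this diff. For readers, a typical label-smoothing cross entropy looks like the sketch below; this is an assumption about the class's behavior, not necessarily the repo's exact code. It mixes the one-hot target with a uniform distribution over the classes.

```python
# Hedged sketch of a CrossEntropyLabelSmooth module (typical implementation;
# the repo's actual class may differ). Smoothing factor epsilon spreads
# epsilon/K probability mass over all K classes, keeping 1 - epsilon on the
# true class.
import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
    def __init__(self, num_classes, epsilon):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # one-hot encode, then smooth toward the uniform distribution
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        soft_targets = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-soft_targets * log_probs).sum(dim=1).mean()
```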
@@ -119,12 +133,15 @@ def test(model, loss_fn, optimizer, phase):
 train_losses = []
 valid_losses = []
 valid_accuracy = []
+name = f'{args.name}_{args.dataset}_finetuned'
+if args.label_smoothing>0:
+    name += '_label_smoothing'
 if args.test_only == False:
     for epoch in range(num_epochs):
         adjust_learning_rate(optimizer, epoch, args)
         print('Starting epoch %d / %d' % (epoch + 1, num_epochs))
-        train_loss = train(model, criterion, optimizer)
-        accuracy, valid_loss = test(model, criterion, optimizer, "val")
+        train_loss = train(model, criterion_train, optimizer)
+        accuracy, valid_loss = test(model, criterion_test, optimizer, "val")
         remaining = model.get_remaining(20.,args.budget_type).item()

         if accuracy>best_accuracy:
@@ -135,16 +152,16 @@ def test(model, loss_fn, optimizer, phase):
                 "state_dict" : model.state_dict(),
                 "acc" : best_accuracy,
                 "rem" : remaining,
-            }, f"checkpoints/{args.name}_{args.dataset}_finetuned.pth")
+            }, f"checkpoints/{name}.pth")

         train_losses.append(train_loss)
         valid_losses.append(valid_loss)
         valid_accuracy.append(accuracy)
         df_data=np.array([train_losses, valid_losses, valid_accuracy]).T
         df = pd.DataFrame(df_data,columns = ['train_losses','valid_losses','valid_accuracy'])
-        df.to_csv(f"logs/{args.name}_{args.dataset}_finetuned.csv")
+        df.to_csv(f"logs/{name}.csv")

-state = torch.load(f"checkpoints/{args.name}_{args.dataset}_finetuned.pth")
+state = torch.load(f"checkpoints/{name}.pth")
 model.load_state_dict(state['state_dict'],strict=True)
-acc, v_loss = test(model, criterion, optimizer, "test")
+acc, v_loss = test(model, criterion_test, optimizer, "test")
 print(f"Test Accuracy: {acc} | Valid Accuracy: {state['acc']}")
models/base_model.py (73 changes: 58 additions & 15 deletions)
@@ -9,7 +9,6 @@ def __init__(self):
         super(BaseModel, self).__init__()
         self.prunable_modules = []
         self.prev_module = defaultdict()
-        # self.next_module = defaultdict()
         pass

     def set_threshold(self, threshold):
@@ -48,9 +47,10 @@ def calculate_prune_threshold(self, Vc, budget_type = 'channel_ratio'):
     def smoothRound(self, x, steepness=20.):
         return 1./(1.+torch.exp(-1*steepness*(x-0.5)))

-    def n_remaining(self, m, steepness=20.):
-        return (m.pruned_zeta if m.is_pruned else self.smoothRound(m.get_zeta_t(), steepness)).sum()
-
+    def n_remaining(self, m, steepness=20., do_sum=True):
+        rem = (m.pruned_zeta if m.is_pruned else self.smoothRound(m.get_zeta_t(), steepness))
+        return rem.sum() if do_sum else rem
+
     def is_all_pruned(self, m):
         return self.n_remaining(m) == 0

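Context for the change above: `smoothRound` is a steep sigmoid centred at 0.5, so it approximates rounding each gate to 0 or 1 while staying differentiable, and the new `do_sum=False` flag lets callers retrieve the per-channel soft mask instead of its sum, which the skip-connection handling in the next hunk needs. A standalone sketch of the behavior (our names, not repo code):

```python
import torch

def smooth_round(x, steepness=20.):
    # steep sigmoid: ~0 below 0.5, ~1 above, differentiable everywhere
    return 1. / (1. + torch.exp(-steepness * (x - 0.5)))

zeta = torch.tensor([0.05, 0.45, 0.55, 0.95])
print(smooth_round(zeta))        # ~[0.00, 0.27, 0.73, 1.00] -- the per-channel soft mask
print(smooth_round(zeta).sum())  # ~2.0 -- the soft count of surviving channels
```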
@@ -72,13 +72,57 @@ def get_remaining(self, steepness=20., budget_type = 'channel_ratio'):
                 n_rem += self.n_remaining(l_block, steepness)*prev_remaining*k*k
                 n_total += l_block.num_gates*prev_total*k*k
             elif budget_type == 'flops_ratio':
-                k = l_block._conv_module.kernel_size[0]
-                output_area = l_block._conv_module.output_area
-                prev_total = 3 if self.prev_module[l_block] is None else self.prev_module[l_block].num_gates
-                prev_remaining = 3 if self.prev_module[l_block] is None else self.n_remaining(self.prev_module[l_block], steepness)
+                k1 = l_block._conv_module.kernel_size[0]
+                k2 = l_block._conv_module.kernel_size[1]
+                active_elements_count = l_block._conv_module.output_area
+                if self.prev_module[l_block] is None:
+                    prev_total = 3
+                    prev_remaining = 3
+                elif isinstance(self.prev_module[l_block], nn.BatchNorm2d):
+                    prev_total = self.prev_module[l_block].num_gates
+                    prev_remaining = self.n_remaining(self.prev_module[l_block], steepness)
+                else:
+                    prev_total = self.prev_module[l_block][-1].num_gates
+                    def cal_max(prev):
+                        if isinstance(prev[0], nn.BatchNorm2d):
+                            prev1 = self.n_remaining(prev[0], steepness, do_sum=False)
+                            prev2 = self.n_remaining(prev[1], steepness, do_sum=False)
+                            return (torch.maximum(prev1, prev2) + torch.maximum(prev2, prev1))/2
+                        prev2 = self.n_remaining(prev[-1], steepness, do_sum=False)
+                        list_ = cal_max(prev[0])
+                        return (torch.maximum(list_, prev2) + torch.maximum(prev2, list_))/2
+
+                    prev_remaining = cal_max(self.prev_module[l_block]).sum()
+
                 curr_remaining = self.n_remaining(l_block, steepness)
-                n_rem += curr_remaining*prev_remaining*k*k*output_area + curr_remaining*output_area
-                n_total += l_block.num_gates*prev_total*k*k*output_area + l_block.num_gates*output_area
+
+                ## Pruned
+                # conv
+                conv_per_position_flops = k1 * k2 * prev_remaining * curr_remaining
+                n_rem += conv_per_position_flops * active_elements_count
+                if l_block._conv_module.bias is not None:
+                    n_rem += curr_remaining * active_elements_count
+
+                # bn
+                batch_flops = curr_remaining * active_elements_count
+                n_rem += batch_flops  ## ReLU flops
+                if l_block.affine:
+                    batch_flops *= 2
+                n_rem += batch_flops
+
+                ## normal
+                # conv
+                conv_per_position_flops = k1 * k2 * prev_total * l_block.num_gates
+                n_total += conv_per_position_flops * active_elements_count
+                if l_block._conv_module.bias is not None:
+                    n_total += l_block.num_gates * active_elements_count
+
+                # bn
+                batch_flops = l_block.num_gates * active_elements_count
+                n_total += batch_flops  ## ReLU flops
+                if l_block.affine:
+                    batch_flops *= 2
+                n_total += batch_flops
         return n_rem/n_total

     def give_zetas(self):
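A note on the arithmetic above: for each block the new flops_ratio accounting charges the conv k1·k2·c_in·c_out·H·W FLOPs (plus c_out·H·W if it has a bias), charges ReLU c_out·H·W, and charges BatchNorm c_out·H·W, doubled when affine; get_remaining then returns kept FLOPs over total FLOPs. For skip connections (the "double skip" of the PR title), `cal_max` merges the incoming branches' soft masks; since (max(a,b) + max(b,a))/2 equals the elementwise maximum, a previous-layer channel is counted as alive if it survives in any branch. A small standalone sanity check of the per-block bookkeeping (our helper with assumed shapes, not repo code, following the reconstruction above where the non-affine path still counts one BN pass):

```python
# Standalone sanity check of the FLOP-ratio bookkeeping: one conv block
# with BatchNorm (affine) and ReLU, no conv bias.
def block_flops(c_in, c_out, k1, k2, area, affine=True, bias=False):
    flops = k1 * k2 * c_in * c_out * area   # conv multiply-accumulates over all output positions
    if bias:
        flops += c_out * area               # conv bias adds
    flops += c_out * area                   # ReLU
    bn = c_out * area
    if affine:
        bn *= 2                             # scale and shift
    flops += bn                             # batch norm
    return flops

total = block_flops(64, 64, 3, 3, 32 * 32)   # unpruned 3x3 conv block
kept  = block_flops(40, 48, 3, 3, 32 * 32)   # after pruning 64->40 in, 64->48 out
print(kept / total)                          # flops_ratio for this block, ~0.47
```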
@@ -128,7 +172,7 @@ def prune(self, Vc, budget_type = 'channel_ratio', finetuning=False, threshold=None):
                     high = mid-1
                 else:
                     low = mid+1
-        elif budget_type == 'flops_ratio':
+        elif budget_type == 'flops_ratio' and threshold==None:
             zetas = sorted(self.give_zetas())
             high = len(zetas)-1
             low = 0
@@ -138,12 +182,11 @@ def prune(self, Vc, budget_type = 'channel_ratio', finetuning=False, threshold=None):
                 for l_block in self.prunable_modules:
                     l_block.prune(threshold)
                 self.remove_orphans()
-                if self.flops()<Vc:
+                if self.get_remaining(steepness=20., budget_type='flops_ratio')<Vc:
                     high = mid-1
                 else:
                     low = mid+1
-        else:
-            if threshold==None:
+        elif threshold==None:
             self.prune_threshold = self.calculate_prune_threshold(Vc, budget_type)
             threshold = min(self.prune_threshold, 0.9)

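The flops_ratio branch above is a bisection over the sorted gate values: try the median zeta as a trial threshold, prune, re-measure the FLOP ratio with get_remaining, and narrow [low, high] until the budget Vc is met. A standalone sketch of that search (our simplification, not repo code):

```python
# Sketch of the threshold bisection used above: find the smallest pruning
# threshold whose remaining-budget measurement drops below the target Vc.
def search_threshold(zetas, measure, Vc):
    """zetas: sorted gate values; measure(t): remaining-budget ratio after
    pruning at threshold t (non-increasing in t)."""
    low, high = 0, len(zetas) - 1
    while low <= high:
        mid = (low + high) // 2
        if measure(zetas[mid]) < Vc:   # pruned too much; try a lower threshold
            high = mid - 1
        else:                          # still over budget; prune more aggressively
            low = mid + 1
    return zetas[min(low, len(zetas) - 1)]
```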
@@ -166,7 +209,7 @@ def prepare_for_finetuning(self, device, budget, budget_type = 'channel_ratio'):
         self.device = device
         self(torch.rand(2,3,32,32).to(device))
         threshold = self.prune(budget, budget_type=budget_type, finetuning=True)
-        if budget_type not in ['parameter_ratio', 'flops_ratio']:
+        if budget_type not in ['parameter_ratio']:
             while self.get_remaining(steepness=20., budget_type=budget_type)<budget:
                 threshold-=0.0001
                 self.prune(budget, finetuning=True, budget_type=budget_type, threshold=threshold)