diff --git a/train.py b/train.py index b6a8798..ce34fa9 100644 --- a/train.py +++ b/train.py @@ -88,6 +88,8 @@ def save_model(name): optimizer.zero_grad() loss.backward() optimizer.step() + + proto = None; logits = None; loss = None tl = tl.item() ta = ta.item() @@ -114,6 +116,8 @@ def save_model(name): vl.add(loss.item()) va.add(acc) + + proto = None; logits = None; loss = None vl = vl.item() va = va.item()