Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

An important correction if you want to Fine Tune an already pre-train… #91108

Closed

Conversation

kraidiky
Copy link

@kraidiky kraidiky commented Dec 19, 2022

…ed model with a big moment. If you want to continue training with this optimizer, for example, with a moment of 0.99, the gradient of the first batch will be x100 more efficient than all the others. This may even knock the model out of the optimum it already founded.

No issue

cc @vincentqb @jbschlosser @albanD @janeyx99

…ed model with a big moment. If you want to continue training with this optimizer, for example, with a moment of 0.99, the gradient of the first batch will be x100 more efficient than all the others. This may even knock the model out of the optimum it already founded.
@pytorch-bot
Copy link

pytorch-bot bot commented Dec 19, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91108

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 073f677:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Contributor

@janeyx99 janeyx99 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great catch!

could a unit test be added in test_optim.py to ensure this doesn’t regress?

@kraidiky
Copy link
Author

could a unit test be added in test_optim.py to ensure this doesn’t regress?
I am sorry, it will be difficult to me, I'm a novice and have no environment to test torch. Can you help with it?

Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is BC-breaking.
You're now biasing towards a zero momentum at the beginning. I could see this significantly slowing down training for some users.
We should definitely look into more details at the impact of this change before merging it!

@kraidiky
Copy link
Author

kraidiky commented Dec 20, 2022

Not from zero, from 1-momentum, else was removed. This is the normal behavior of sgd with momentum

Let's see, for example momentum will be 3/4, old version:
step 1:
grad = a;
momentum buffer = a;
value changes/lr = a;
step 2:
grad = b;
momentum buffer = a3/4 + b1/4;
value changes/lr = a(1 + 3/4) + b(1/4);
step 3:
grad = c;
momentum buffer = a9/16 + b3/16 + c1/4;
value changes/lr = a(1 + 3/4 + 9/16) + b(1/4 + 3/16) + c(1/4)
step 4:
grad = d;
momentum buffer = a
27/64 + b9/64 + c3/16 + d(1/4);
value changes/lr = a(1 + 3/4 + 9/16) + b(1/4 + 3/16) + c(1/4)
Limit of a sequence (geometric progression):
b,c,d and all other will be: (1/4)/(1-3/4) = 1
but for a it will be: 1/(1-3/4) = 4

It means that the training is not faster, just the contribution of the first batch is a multiple higher, and the rest of the training goes the same way.

@ngimel ngimel added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Dec 21, 2022
@kraidiky
Copy link
Author

This is example based on my project:

image0df = FileToImageDF(filenames[0])
image_size = (image0df.x.max()+1, image0df.y.max()+1)
ImageDFToInputOutput(image0df)
(train, check, test) = FrameSquareSplitDF(image0df, .4, 40, .20, .10, shift = (5, 25))
#demo for github
print(f'original fine tuning, original dataset size: {image_size[0]*image_size[1]}, batch size 64, 100 epoches')
torch.manual_seed(999)
net = SimpleNet(2, 3, 16, 3)
log1 = Train(net, train, 100, 64, optimizer = torch.optim.SGD(net.parameters(), lr=0.5, momentum=1-1/4, dampening=1-1/4)) #Start learning
log2 = Train(net, train, 100, 64, optimizer = torch.optim.SGD(net.parameters(), lr=0.5, momentum=1-1/200, dampening=1-1/200)) #Fine Tuning
log = ConcatLogs([log1, log2])
print('fixed fine tuning')
torch.manual_seed(999)
net = SimpleNet(2, 3, 16, 3)
log1 = Train(net, train, 100, 64, optimizer = fixed.SGD(net.parameters(), lr=0.5, momentum=1-1/4, dampening=1-1/4)) #Start learning
log2 = Train(net, train, 100, 64, optimizer = fixed.SGD(net.parameters(), lr=0.5, momentum=1-1/200, dampening=1-1/200)) #Fine Tuning
log_fixed = ConcatLogs([log1, log2])

output log:

original fine tuning, original dataset size: 468480, batch size 64, 100 epoches
epoch=0, epoch_loss:0.12477422104009632, grad:0.028602054342627525, time:4.102067232131958
epoch=1, epoch_loss:0.12466210855710846, grad:0.006887256167829037, time:8.249023675918579
epoch=2, epoch_loss:0.12464358331370394, grad:0.009717237204313278, time:11.789730787277222
...
epoch=98, epoch_loss:0.03149822497088306, grad:0.03132188320159912, time:370.8234317302704
epoch=99, epoch_loss:0.03142608949287453, grad:0.0316515788435936, time:374.6575493812561
Train time: 374.6575493812561
epoch=0, epoch_loss:0.03256784819278685, grad:0.07165557146072388, time:3.6130664348602295
epoch=1, epoch_loss:0.0312966939392601, grad:0.02535572648048401, time:7.2863898277282715
epoch=2, epoch_loss:0.031234623639028476, grad:0.03266138955950737, time:10.81709599494934
epoch=3, epoch_loss:0.03106951178418132, grad:0.03074655309319496, time:14.325260162353516
...
epoch=99, epoch_loss:0.02801777953079198, grad:0.04015304893255234, time:361.7370517253876
Train time: 361.7370517253876
fixed fine tuning
epoch=0, epoch_loss:0.1248286312748639, grad:0.04875539243221283, time:4.0784361362457275
epoch=1, epoch_loss:0.1246622925827052, grad:0.006915505044162273, time:8.866403102874756
epoch=2, epoch_loss:0.12464372644472362, grad:0.00974933709949255, time:14.37080192565918
...
epoch=98, epoch_loss:0.03124530990319436, grad:0.03003973513841629, time:393.6995451450348
epoch=99, epoch_loss:0.031167515638086264, grad:0.029645301401615143, time:398.4200973510742
Train time: 398.4200973510742
epoch=0, epoch_loss:0.031061670249031618, grad:0.03348888084292412, time:4.562276601791382
epoch=1, epoch_loss:0.031036571202365998, grad:0.03117925301194191, time:9.416330814361572
epoch=2, epoch_loss:0.030937636996833124, grad:0.028899099677801132, time:13.785247087478638
...
epoch=98, epoch_loss:0.027763123887667503, grad:0.03270133584737778, time:405.28802037239075
epoch=99, epoch_loss:0.02775227148928235, grad:0.03276437148451805, time:409.34902334213257
Train time: 409.34902334213257

You can see that first 2 epoch of fine tuning series loss was increased. It's effect of this error. With my fixed version there is no difference btw first and second optimizator. Function is smooth.

But I do not know how to implement such a test in your test_optim.py

@kraidiky kraidiky requested review from janeyx99 and albanD and removed request for janeyx99 and albanD December 22, 2022 21:17
@albanD
Copy link
Collaborator

albanD commented Dec 26, 2022

Not from zero, from 1-momentum, else was removed.

With the current version, the else had no effect because you interpolate between twice the same value. So it is the same as if you were initializing with the grad and then going from there.
While the new version is initializing at 0 and going from there.

@kraidiky
Copy link
Author

The difference is way to use dampening. For example, when you use momentum = 1-1/4, it means that the effective lr will be quadrupled. To prevent this effect you must set dampening = momentum = 1-1/4 too. With this property integral effect of each batch will be the same as momentum=0 but more smoothed.
But in original erroneous implementation first value of momentum will be grade and each other damped. this string is:

.add_(d_p, alpha=1 - dampening)

In my fix buffer will be zeros and after that on the same call will be executed old else branch and bush will receive dumped grad, like an all other calls.

This is the the same effect if you write:

if buf is None:
buf = momentum_buffer_list[i] = d_p * (1 - dampening)
else:
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

but my fix has lesser copy-paste

@IvanYashchuk IvanYashchuk added the module: optimizer Related to torch.optim label Dec 27, 2022
@kraidiky
Copy link
Author

kraidiky commented Jan 5, 2023

@albanD To illustrate the changes I made, I made an example of a network on the mnist dataset so that you could experiment with it yourself. fixed file sgd.py must be at the same folder:

import time
import torch
import torch.nn.functional as F
import torchvision
from typing import cast, List, Optional, Dict, Tuple
import sgd as fixed # file sgd.py

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#Load Dataset
mnist = torchvision.datasets.MNIST("", train = True, download = True)
x_train = torch.Tensor(mnist.data).unsqueeze(1).to(device)/255 # shape=(60000, 1, 28, 28)
y_train = torch.Tensor(mnist.targets).to(device)
mnist = torchvision.datasets.MNIST("", train = False, download = True)
x_test = torch.Tensor(mnist.data).unsqueeze(1).to(device)/255 # shape=(10000, 1, 28, 28)
y_test = torch.Tensor(mnist.targets).to(device)
print(f'MNIST: train x:{tuple(x_train.shape)}, y:{tuple(y_train.shape)}; test x:{tuple(x_test.shape)}, y:{tuple(y_test.shape)}')

#Shuffle and split dataset
def Batches(X, Y, shuffle = True, batch_size = 256):
    if shuffle:
        shuffle_indexes = torch.randperm(len(Y))
        X = X[shuffle_indexes]
        Y = Y[shuffle_indexes]
    return zip(torch.split(X, batch_size), torch.split(Y, batch_size))

#Sample net
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size = 3)
        self.conv2 = nn.Conv2d(16, 32, kernel_size = 3)
        self.max_pool2d1 = nn.MaxPool2d(2)
        self.max_pool2d2 = nn.MaxPool2d(2)
        self.conv2_drop = nn.Dropout2d(p = 0.3)
        self.dropout = nn.Dropout(p = 0.3)
        self.fc1 = nn.Linear(800, 32)
        self.fc2 = nn.Linear(32, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.max_pool2d1(x)
        x = self.conv2_drop(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max_pool2d2(x) # torch.Size([256, 32, 5, 5])
        # x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) # 204800
        # x = F.relu(self.conv2_drop(self.conv2(x))) # 991232
        x = x.view(-1, 800) # torch.Size([256, 800]) (32 * 5 * 5 = 800)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        # x = F.dropout(x, training = self.training) # self.training = True при обучении
        x = self.fc2(x)
        return F.log_softmax(x, dim = -1)

def load(model:Net, device, file):
    model.load_state_dict(torch.load(file, map_location = torch.device(device)))

def train(model:Net, epochs:int, optimizer:torch.optim.Optimizer, save:str = None) -> List[float]:
    log = []
    for e in range(epochs):
        start_time = time.time()
        learn_loss = trn_loss = tst_loss = 0 # learn_loss and trn_loss is different, cause of dropout
        model.train() # Режим обучения
        for batch_no, (data, target) in enumerate(Batches(x_train, y_train)):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            learn_loss += loss.item() * data.size(0)
        model.eval() # Режим оценки
        for data, target in Batches(x_train, y_train, shuffle=False):
            output = model(data)
            loss = criterion(output, target)
            trn_loss += loss.item() * data.size(0)
        for data, target in Batches(x_test, y_test, shuffle=False):
            output = model(data)
            loss = criterion(output, target)
            tst_loss += loss.item() * data.size(0)
        learn_loss /= x_train.size(0)
        trn_loss /= x_train.size(0)
        tst_loss /= x_test.size(0)
        print('Epoch: {} learn loss: {:.6f} \ttrain loss: {:.6f} \ttest loss: {:.6f} \ttime:{:.6f}'.format(
            e + 1, learn_loss, trn_loss, tst_loss, time.time()-start_time))
        if save is not None:
            torch.save(model.state_dict(), save)
        log.append([trn_loss, tst_loss, learn_loss])
    return log

def test(model:Net):
    print('Detalied test')
    num_classes = len(y_train.unique())
    tst_loss = 0
    cls_correct = [0]*num_classes
    cls_total = [0]*num_classes
    model.eval() #Режим оценки
    for data, target in Batches(x_test, y_test):
        output = model(data)
        loss = criterion(output, target)
        tst_loss += loss.item() * data.size(0)
        _, pred = torch.max(output, 1)
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        for i in range(len(target)):
            label = target.data[i]
            cls_correct[label] += correct[i].item()
            cls_total[label] += 1
    tst_loss /= y_test.size(0)
    print('Test Loss: {:.6f}'.format(tst_loss))
    print('Accuracy: {:.4f}'.format(sum(cls_correct) / sum(cls_total)))
    print('Accuracy by classes:')
    cls = 0
    for cc, ct in zip(cls_correct, cls_total):
        print('{} - {:.4f}'.format(cls, cc / ct))
        cls += 1

criterion = nn.CrossEntropyLoss()
torch.manual_seed(999)
model = Net()
model.to(device)
print('torch.optim.SGD:')
log1 = train(model, 30,  torch.optim.SGD(model.parameters(), lr = 1, momentum=1-1/100, dampening=1-1/100))
log2 = train(model, 10,  torch.optim.SGD(model.parameters(), lr = 0.1, momentum=1-1/100, dampening=1-1/100))
#test(model)

if log1[-1][0] < log2[0][0]:
    print(f'ERROR! Loss was dramatically increased. log1[-1][0] < log2[0][0]: {log1[-1][0]} < {log2[0][0]}')

torch.manual_seed(999)
model = Net()
model.to(device)
print('fixed.SGD')
log4 = train(model, 30,  fixed.SGD(model.parameters(), lr = 1, momentum=1-1/100, dampening=1-1/100))
log5 = train(model, 10,  fixed.SGD(model.parameters(), lr = 0.1, momentum=1-1/100, dampening=1-1/100))
#test(model)

if log4[-1][0] < log5[0][0]:
    print(f'ERROR! Loss was dramatically increased. log4[-1][0] < log5[0][0]: {log4[-1][0]} < {log5[0][0]}')

RESULTS:

torch.optim.SGD:
Epoch: 1 learn loss: 0.959126 train loss: 0.180986 test loss: 0.166251 time:2.176997
...
Epoch: 30 learn loss: 0.064928 train loss: 0.026299 test loss: 0.034716 time:0.821000
Epoch: 1 learn loss: 0.108850 train loss: 0.032349 test loss: 0.041665 time:0.817999
Epoch: 2 learn loss: 0.074319 train loss: 0.024775 test loss: 0.034391 time:0.819999
...
Epoch: 10 learn loss: 0.044947 train loss: 0.016667 test loss: 0.026992 time:0.829000

ERROR! Loss was dramatically increased. log1[-1][0] < log2[0][0]: 0.026298949959874154 < 0.032348830696940424
fixed.SGD
Epoch: 1 learn loss: 1.025694 train loss: 0.177346 test loss: 0.162757 time:0.814998
...
Epoch: 30 learn loss: 0.046051 train loss: 0.013763 test loss: 0.028930 time:0.807999
Epoch: 1 learn loss: 0.038613 train loss: 0.011081 test loss: 0.026304 time:0.800999
...
Epoch: 10 learn loss: 0.031889 train loss: 0.009083 test loss: 0.024232 time:0.814999

P.S. By the way, in Adam, where exponential moving average are also used, buffers are initialized with zeros:
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)

@albanD
Copy link
Collaborator

albanD commented Jan 11, 2023

Hi,
I am not arguing that this change doesn't improve behavior in some cases.
My main hesitation is around the trade-off:

  • this is a BC-breaking change and it will change the behavior of user code (for better or worse, we can't predict for 100% of the use cases). So this is not something we want to do lightly.
  • initializing at 0 is the "usual" thing to do from the literature. So this is most likely what users expect to happen (as shown by the existence of this PR). So there is an incentive toward changing the current behavior.

Keeping this second point in mind, I'm wondering why we ended up with this different initialization. I think it would be worth digging into the git history to see if that has always been the case or if it was changed at some point from 0 to what it is today.
If it has always been like that, then I'm happy to consider this an oversight during the initial implementation and change the default. If it was intentionally set to this non-zero value, I would be curious to know why and if the reason for doing so is still valid.

@soumith
Copy link
Member

soumith commented Jan 12, 2023

i havent followed this discussion closely, but the original nesterov momentum is a bit simplified / re-derived and is from: torch/optim#27 (comment) (see the linked PDF in the comment).

hope this helps.

@kraidiky
Copy link
Author

kraidiky commented Jan 12, 2023

i havent followed this discussion closely, but the original nesterov momentum is a bit simplified / re-derived and is from: torch/optim#27 (comment) (see the linked PDF in the comment).

The error was already into the code when the Nesterov momentum was added there in 2017, and when using it, it just did not show itself, because for Nesterov must necessarily dampening == 0. But thanks for the link, I looked into nag.lua from which file clementfarabet the code peeked and initialization with zeros is also implemented there. In the code it looks like this:

-- (4) apply momentum
   if mom ~= 0 then
      if not state.dfdx then
         state.dfdx = torch.Tensor():typeAs(dfdx):resizeAs(dfdx):fill(0)
      else
         state.dfdx:mul(mom)
      end
   end

@albanD However , during sgd.py the error is present from the very beginning, from commite
Revision: 554a1d8
Author: Adam Paszke adam.paszke@gmail.com
Date: 19.07.2016 6:21:47
Message: Add optim

if 'dfdx' not in state:
            state['dfdx'] = torch.Tensor().typeAs(dfdx).resizeAs(dfdx).copy(dfdx)
        else:
            state['dfdx'].mul(mom).add(1-damp, dfdx)

It was in the fourth month of development and it is unlikely that we will now be able to find out from Adam where he was peeping when he wrote this code. At least google claims that this is not an exact copy-paste from some other place before 19.07.2016. Perhaps it is just a rare and very old error.

@albanD
Copy link
Collaborator

albanD commented Jan 12, 2023

Hey!

From looking at these in more details, I don't think we want to change the default here. The main reasons are:

  • There is no theoretical difference between the two
  • There is no original paper that enforces the original value that we can use as a justification
  • It is easy to find special cases where one will perform better or the other (in particular if you look at a case where the first step is optimal in one of the two settings, it is simple to build problems where one will be optimal and not the other)
  • Changing this default is BC-breaking and so we don't want to do it unless we have a good reason to do so.

I understand that this might not be optimal in your case though. There are a couple things you could do here:

  • Do a first step with gradients full of zeros.
  • Copy/paste the PT optimizer from

    pytorch/torch/optim/sgd.py

    Lines 220 to 252 in b7cad02

    def _single_tensor_sgd(params: List[Tensor],
    d_p_list: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool):
    for i, param in enumerate(params):
    d_p = d_p_list[i] if not maximize else -d_p_list[i]
    if weight_decay != 0:
    d_p = d_p.add(param, alpha=weight_decay)
    if momentum != 0:
    buf = momentum_buffer_list[i]
    if buf is None:
    buf = torch.clone(d_p).detach()
    momentum_buffer_list[i] = buf
    else:
    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
    if nesterov:
    d_p = d_p.add(buf, alpha=momentum)
    else:
    d_p = buf
    param.add_(d_p, alpha=-lr)
    and update the initial momentum buffer to your prefered value.

While the doc already explicitly states what the initial value is, we can add one line to the note about the details of our SGD vs other frameworks to highlight the initialization. https://pytorch.org/docs/stable/generated/torch.optim.SGD.html?highlight=sgd#torch.optim.SGD

@github-actions
Copy link

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Mar 13, 2023
@github-actions github-actions bot closed this Apr 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: optimizer Related to torch.optim open source Stale triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants