Adding the following weighted MSE loss to my code:
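import torch

def weighted_mse(y, pred, x, mask, weight=10):
    # First channel of x, reshaped to N x 1 x 600 x 600; used as the "und" weight map
    x = torch.reshape(x[:, 0, :, :], (-1, 1, 600, 600))
    # Damage mask reshaped to the same N x 1 x 600 x 600 layout
    mask = torch.reshape(mask, (-1, 1, 600, 600))
    num_und_pix = torch.sum(x != 0)     # pixels where x is non-zero ("und" region)
    num_dam_pix = torch.sum(mask == 1)  # pixels where mask == 1 ("dam" region)
    # Squared error weighted by x, averaged over the "und" pixels
    out_und = (((y - pred) ** 2) * x).sum() / num_und_pix
    # Squared error over the "dam" pixels, averaged and scaled by `weight`
    out_dam = (((y - pred) ** 2) * mask).sum() / num_dam_pix * weight
    loss = out_und + out_dam
    return loss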
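A minimal sketch of the same loss without the hard-coded 600x600 size, assuming `y` and `pred` have shape (N, 1, H, W) and `mask` has N*H*W elements. The `clamp(min=1)` guard is an addition not in the original: if a batch contains no pixels with `mask == 1`, the original divides by zero and the loss becomes NaN.

```python
import torch


def weighted_mse(y: torch.Tensor, pred: torch.Tensor, x: torch.Tensor,
                 mask: torch.Tensor, weight: float = 10.0) -> torch.Tensor:
    """Squared error weighted by x over the 'und' region (x != 0), plus a
    weighted MSE over the 'dam' region (mask == 1).
    Assumes y and pred have shape (N, 1, H, W)."""
    # First channel of x, kept as (N, 1, H, W) instead of a hard-coded 600x600
    x0 = x[:, :1, :, :]
    # Damage mask reshaped to the same layout as y
    mask = mask.reshape(y.shape).to(y.dtype)
    sq_err = (y - pred) ** 2
    # Pixel counts; clamp(min=1) is an added guard against dividing by zero
    num_und_pix = (x0 != 0).sum().clamp(min=1)
    num_dam_pix = (mask == 1).sum().clamp(min=1)
    # As in the original, the "und" term is weighted by the values of x0
    out_und = (sq_err * x0).sum() / num_und_pix
    out_dam = (sq_err * mask).sum() / num_dam_pix * weight
    return out_und + out_dam


# Usage on small random tensors (shapes are illustrative only)
torch.manual_seed(0)
y = torch.rand(2, 1, 8, 8)
pred = torch.rand(2, 1, 8, 8, requires_grad=True)
x = torch.rand(2, 3, 8, 8)
mask = (torch.rand(2, 8, 8) > 0.5).float()
loss = weighted_mse(y, pred, x, mask)
loss.backward()
print(loss.item(), pred.grad.abs().sum().item())
```

Passing the mask through `reshape(y.shape)` keeps the function usable for any image size, at the cost of assuming that `y` already has the (N, 1, H, W) layout.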
generates the following warning in the console:
148: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
value = torch.tensor(value, device=device, dtype=torch.float)
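The warning is emitted whenever `torch.tensor()` receives something that is already a tensor. The `148:` prefix is a truncated file/line reference, so which code calls `torch.tensor(value, ...)` is not known here; the sketch below only reproduces the warning in isolation, using a stand-in tensor, and shows the two alternatives that the warning message itself recommends. It also shows that the copy made by `torch.tensor()` has no `grad_fn`, i.e. it is not connected to the autograd graph.

```python
import warnings

import torch

# Stand-in for a tensor that is already part of the autograd graph (e.g. a loss)
loss = torch.tensor(0.5, requires_grad=True) * 2

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Copy-constructing from an existing tensor emits the UserWarning
    copied = torch.tensor(loss, dtype=torch.float)

for w in caught:
    print(w.category.__name__, str(w.message)[:50])

# The copy is a new leaf tensor: no grad_fn and requires_grad=False
print(copied.requires_grad, copied.grad_fn)

with warnings.catch_warnings(record=True) as caught_alt:
    warnings.simplefilter("always")
    # The two alternatives suggested by the warning text; neither warns
    detached = loss.clone().detach()
    tracked = loss.clone().detach().requires_grad_(True)
```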
Before I added this loss function, the model trained with torch.nn.MSELoss without any warning.
Since my loss curve is now flat and does not decrease, I suspect this is related to the warning above.
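One way to check the custom loss against `torch.nn.MSELoss`: if channel 0 of `x` and all of `mask` are ones, every pixel is counted in both terms, so the loss should reduce to `(1 + weight)` times the plain MSE. The sketch below copies the function from the question and uses random tensors of the hard-coded 600x600 size; the all-ones `x` and `mask` are test values chosen for this check, not real data.

```python
import torch
import torch.nn as nn


def weighted_mse(y, pred, x, mask, weight=10):
    # Same computation as the function in the question
    x = torch.reshape(x[:, 0, :, :], (-1, 1, 600, 600))
    mask = torch.reshape(mask, (-1, 1, 600, 600))
    num_und_pix = torch.sum(x != 0)
    num_dam_pix = torch.sum(mask == 1)
    out_und = (((y - pred) ** 2) * x).sum() / num_und_pix
    out_dam = (((y - pred) ** 2) * mask).sum() / num_dam_pix * weight
    return out_und + out_dam


torch.manual_seed(0)
y = torch.rand(2, 1, 600, 600)
pred = torch.rand(2, 1, 600, 600)
x = torch.ones(2, 3, 600, 600)   # every pixel is "und", with weight 1
mask = torch.ones(2, 600, 600)   # every pixel is "dam"

mse = nn.MSELoss()(pred, y)
custom = weighted_mse(y, pred, x, mask, weight=10)
# With all-ones x and mask, custom should equal (1 + 10) * MSE
print(custom.item(), (11 * mse).item())
```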
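Whether the flat loss curve is caused by the warning is not established here. One thing that can be checked directly is whether the custom loss still reaches the model parameters: the loss should have a `grad_fn`, and after `backward()` every parameter should have a finite, non-zero gradient. The sketch below assumes a hypothetical one-layer `nn.Conv2d` model and random 600x600 inputs in place of the real model and data, and reuses the function from the question unchanged.

```python
import torch
import torch.nn as nn


def weighted_mse(y, pred, x, mask, weight=10):
    # Same computation as the function in the question
    x = torch.reshape(x[:, 0, :, :], (-1, 1, 600, 600))
    mask = torch.reshape(mask, (-1, 1, 600, 600))
    num_und_pix = torch.sum(x != 0)
    num_dam_pix = torch.sum(mask == 1)
    out_und = (((y - pred) ** 2) * x).sum() / num_und_pix
    out_dam = (((y - pred) ** 2) * mask).sum() / num_dam_pix * weight
    return out_und + out_dam


torch.manual_seed(0)
model = nn.Conv2d(3, 1, kernel_size=3, padding=1)  # hypothetical stand-in model
x = torch.rand(1, 3, 600, 600)
y = torch.rand(1, 1, 600, 600)
mask = (torch.rand(1, 600, 600) > 0.9).float()

loss = weighted_mse(y, model(x), x, mask)
# The loss must still be attached to the graph for training to make progress
print(loss.requires_grad, loss.grad_fn is not None)

loss.backward()
# Sum of per-parameter gradient norms; zero would mean nothing is being learned
grad_norm = sum(p.grad.norm() for p in model.parameters())
print(grad_norm.item())
```

If the same check on the real training loop shows `grad_fn` as `None` or all-zero gradients, the loss is being detached somewhere before `backward()` is called.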