In [15]:
# Imports of Pytorch and matplotlib and other supporting modules
 
import math

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader


In [16]:
# Setting default device

if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

display(f'{device} is available')

dtype = torch.float
torch.set_default_device(device)

def train_loop(dataloader, model, loss_fn, optimize, epoch):
    """Training loop funciton for non-linear function"""
    model.train()
    training_loss = 0
    num_batches = len(dataloader)
    for batch, (X, y) in enumerate(dataloader):
        y_pred = model(X)
        y_pred = y_pred.unsqueeze(1)
        loss = loss_fn(y_pred, y)
        training_loss+=loss.item()

        loss.backward()
        optimize.step()
        optimize.zero_grad()
  
    training_loss /= num_batches
    return training_loss

def val_loop(dataloader, model, loss_fn, epoch):
    """Eval loop function for non-linear function"""
    model.eval()
    size= len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0,0

    with torch.no_grad():
        for X,y in dataloader:
            y_pred = model(X)
            y_pred = y_pred.unsqueeze(1)
            test_loss += loss_fn(y_pred, y).item()
            correct +=(y_pred.argmax(1) == y).type(torch.float).sum().item()
    
    test_loss /= num_batches
    correct /=size

    #if epoch % 100 == 0:
    #    display(f"Accuracy: {(100*correct)}%, Avg loss: {test_loss}")

    return test_loss

def cnt_model_params(model):
    """Count model parameters"""
    count = 0
    with torch.no_grad():
        for param in model.parameters():
            count+=param.numel()
    return count

def display_model_info(model_name, model):
    """ Display model information"""
    count = 0
    for module in model.modules():
        if isinstance(module, nn.Module):
            count+=1
    display(model)
    display(f"{model_name}. parameters: {cnt_model_params(model)}")

'cuda is available'

In [17]:
class CosNetwork(nn.Module):
        """First DNN for Cosine function"""
        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear_relu_stack = nn.Sequential(
                nn.Linear(1, 128),
                nn.ReLU(),
                nn.Linear(128, 12),
                nn.ReLU(),
                nn.Linear(12, 1),
                torch.nn.Flatten(0,1)
            )
        
        def forward(self, x):
              logits = self.linear_relu_stack(x)
              return logits
        
cos_model1 = CosNetwork()
display_model_info("cos_model1", cos_model1)


CosNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=1, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=12, bias=True)
    (3): ReLU()
    (4): Linear(in_features=12, out_features=1, bias=True)
    (5): Flatten(start_dim=0, end_dim=1)
  )
)

'cos_model1. parameters: 1817'

In [18]:
x, x2 = torch.utils.data.random_split(torch.linspace(-math.pi, math.pi, 20000, dtype=dtype), [16000, 4000], generator=torch.Generator(device=torch.device('cuda')).manual_seed(42))
x, x2 = torch.Tensor([b for idx, b in enumerate(x.dataset) if idx in x.indices]).to(device), torch.Tensor([c for idx, c in enumerate(x2.dataset) if idx in x2.indices]).to(device)
# torch.linspace(-math.pi, math.pi, 10000, dtype=dtype)
y = torch.cos(x)
y2 = torch.cos(x2)

lossy1, lossy2, lossy3 = list(), list(), list()
epochx1, epochx2, epochx3 = list(), list(), list()

loss_fn = torch.nn.MSELoss()
epochs = 1000
lr = 1e-3
batch_size=8000
optimizer1 = torch.optim.SGD(cos_model1.parameters(), lr=lr)
train_dataloader = DataLoader(TensorDataset(x.unsqueeze(1),y.unsqueeze(1)), batch_size=batch_size)
val_dataloader = DataLoader(TensorDataset(x2.unsqueeze(1),y2.unsqueeze(1)), batch_size=batch_size)

display("Training on original loss function")
for epoch in range(epochs):
    test_loss = train_loop(train_dataloader, cos_model1, loss_fn, optimizer1, epoch)
        
    if epoch % 10 == 0:
        display(f"Epoch {epoch}.")

'Training on original loss function'

'Epoch 0. Train loss: 0.685237318277359. Test loss: 0.687676191329956.'

'Epoch 10. Train loss: 0.615065336227417. Test loss: 0.6159908771514893.'

'Epoch 20. Train loss: 0.5618208646774292. Test loss: 0.5605651140213013.'

'Epoch 30. Train loss: 0.5183338075876236. Test loss: 0.5155771970748901.'

'Epoch 40. Train loss: 0.48615869879722595. Test loss: 0.48309189081192017.'

'Epoch 50. Train loss: 0.4632952809333801. Test loss: 0.4596228003501892.'

'Epoch 60. Train loss: 0.4422297179698944. Test loss: 0.4380805194377899.'

'Epoch 70. Train loss: 0.42119134962558746. Test loss: 0.4167417287826538.'

'Epoch 80. Train loss: 0.3992723673582077. Test loss: 0.39474108815193176.'

'Epoch 90. Train loss: 0.3799649327993393. Test loss: 0.37634003162384033.'

'Epoch 100. Train loss: 0.36694158613681793. Test loss: 0.3633324205875397.'

'Epoch 110. Train loss: 0.35409417748451233. Test loss: 0.3503592312335968.'

'Epoch 120. Train loss: 0.3405783772468567. Test loss: 0.33666980266571045.'

'Epoch 130. Train loss: 0.3259272575378418. Test loss: 0.3219146430492401.'

'Epoch 140. Train loss: 0.3127497136592865. Test loss: 0.3088265657424927.'

'Epoch 150. Train loss: 0.3005062937736511. Test loss: 0.29647329449653625.'

'Epoch 160. Train loss: 0.2882854640483856. Test loss: 0.2841627895832062.'

'Epoch 170. Train loss: 0.27607035636901855. Test loss: 0.27188190817832947.'

'Epoch 180. Train loss: 0.2638775259256363. Test loss: 0.25964587926864624.'

'Epoch 190. Train loss: 0.2517319619655609. Test loss: 0.24747879803180695.'

'Epoch 200. Train loss: 0.23966187238693237. Test loss: 0.23540830612182617.'

'Epoch 210. Train loss: 0.2276974320411682. Test loss: 0.22346292436122894.'

'Epoch 220. Train loss: 0.21586862951517105. Test loss: 0.21167293190956116.'

'Epoch 230. Train loss: 0.20420679450035095. Test loss: 0.20006835460662842.'

'Epoch 240. Train loss: 0.19274357706308365. Test loss: 0.18867944180965424.'

'Epoch 250. Train loss: 0.18151052296161652. Test loss: 0.1775350123643875.'

'Epoch 260. Train loss: 0.17053785920143127. Test loss: 0.1666647046804428.'

'Epoch 270. Train loss: 0.15985483676195145. Test loss: 0.15609599649906158.'

'Epoch 280. Train loss: 0.1494891569018364. Test loss: 0.14585435390472412.'

'Epoch 290. Train loss: 0.13946563005447388. Test loss: 0.1359640508890152.'

'Epoch 300. Train loss: 0.12980739027261734. Test loss: 0.12644624710083008.'

'Epoch 310. Train loss: 0.12053437530994415. Test loss: 0.11731915920972824.'

'Epoch 320. Train loss: 0.11166350543498993. Test loss: 0.1085977852344513.'

'Epoch 330. Train loss: 0.10320815071463585. Test loss: 0.10029403865337372.'

'Epoch 340. Train loss: 0.0951782651245594. Test loss: 0.09241659939289093.'

'Epoch 350. Train loss: 0.08758025243878365. Test loss: 0.08497069031000137.'

'Epoch 360. Train loss: 0.0804169587790966. Test loss: 0.0779581218957901.'

'Epoch 370. Train loss: 0.07368754595518112. Test loss: 0.07137774676084518.'

'Epoch 380. Train loss: 0.06738847494125366. Test loss: 0.06522440910339355.'

'Epoch 390. Train loss: 0.061512697488069534. Test loss: 0.05949031189084053.'

'Epoch 400. Train loss: 0.05605054274201393. Test loss: 0.05416553094983101.'

'Epoch 410. Train loss: 0.05098992399871349. Test loss: 0.04923732206225395.'

'Epoch 420. Train loss: 0.04631645418703556. Test loss: 0.04469085857272148.'

'Epoch 430. Train loss: 0.04201403819024563. Test loss: 0.04050986096262932.'

'Epoch 440. Train loss: 0.03806533105671406. Test loss: 0.03667663782835007.'

'Epoch 450. Train loss: 0.034451890736818314. Test loss: 0.033172354102134705.'

'Epoch 460. Train loss: 0.03115438763052225. Test loss: 0.02997760847210884.'

'Epoch 470. Train loss: 0.02815305069088936. Test loss: 0.027072694152593613.'

'Epoch 480. Train loss: 0.025427956134080887. Test loss: 0.024437768384814262.'

'Epoch 490. Train loss: 0.022959405556321144. Test loss: 0.02205342799425125.'

'Epoch 500. Train loss: 0.020727905444800854. Test loss: 0.019900336861610413.'

'Epoch 510. Train loss: 0.018714629113674164. Test loss: 0.017959974706172943.'

'Epoch 520. Train loss: 0.016901462338864803. Test loss: 0.01621423102915287.'

'Epoch 530. Train loss: 0.015271017793565989. Test loss: 0.014646082185208797.'

'Epoch 540. Train loss: 0.013806858565658331. Test loss: 0.013239379972219467.'

'Epoch 550. Train loss: 0.012493470218032598. Test loss: 0.011978895403444767.'

'Epoch 560. Train loss: 0.011316515505313873. Test loss: 0.010850309394299984.'

'Epoch 570. Train loss: 0.010262649040669203. Test loss: 0.009840489365160465.'

'Epoch 580. Train loss: 0.009319362696260214. Test loss: 0.008937345817685127.'

'Epoch 590. Train loss: 0.008475392125546932. Test loss: 0.008129711262881756.'

'Epoch 600. Train loss: 0.007720276713371277. Test loss: 0.007407612167298794.'

'Epoch 610. Train loss: 0.007044532801955938. Test loss: 0.006761807017028332.'

'Epoch 620. Train loss: 0.006439672550186515. Test loss: 0.006184041500091553.'

'Epoch 630. Train loss: 0.005897914757952094. Test loss: 0.005666851066052914.'

'Epoch 640. Train loss: 0.005412437254562974. Test loss: 0.0052034384571015835.'

'Epoch 650. Train loss: 0.004976924858056009. Test loss: 0.004787955898791552.'

'Epoch 660. Train loss: 0.004585906630381942. Test loss: 0.0044149598106741905.'

'Epoch 670. Train loss: 0.004234463325701654. Test loss: 0.004079681821167469.'

'Epoch 680. Train loss: 0.003918194328434765. Test loss: 0.0037779472768306732.'

'Epoch 690. Train loss: 0.0036331575829535723. Test loss: 0.003506028326228261.'

'Epoch 700. Train loss: 0.0033759172074496746. Test loss: 0.0032605892047286034.'

'Epoch 710. Train loss: 0.003143410081975162. Test loss: 0.0030386741273105145.'

'Epoch 720. Train loss: 0.0029329126700758934. Test loss: 0.002837659092620015.'

'Epoch 730. Train loss: 0.0027420241967774928. Test loss: 0.00265524722635746.'

'Epoch 740. Train loss: 0.002568628580775112. Test loss: 0.0024893959052860737.'

'Epoch 750. Train loss: 0.0024108492652885616. Test loss: 0.002338321413844824.'

'Epoch 760. Train loss: 0.0022670335019938648. Test loss: 0.002200424438342452.'

'Epoch 770. Train loss: 0.0021357174264267087. Test loss: 0.002074338961392641.'

'Epoch 780. Train loss: 0.0020155872334726155. Test loss: 0.001958876848220825.'

'Epoch 790. Train loss: 0.0019055148004554212. Test loss: 0.0018529793014749885.'

'Epoch 800. Train loss: 0.001804534811526537. Test loss: 0.0017557364189997315.'

'Epoch 810. Train loss: 0.0017117765382863581. Test loss: 0.0016663246788084507.'

'Epoch 820. Train loss: 0.001626527402549982. Test loss: 0.0015840240521356463.'

'Epoch 830. Train loss: 0.0015481006703339517. Test loss: 0.0015082627069205046.'

'Epoch 840. Train loss: 0.0014759356854483485. Test loss: 0.0014384790556505322.'

'Epoch 850. Train loss: 0.0014095103542786092. Test loss: 0.0013741833390668035.'

'Epoch 860. Train loss: 0.0013483546499628574. Test loss: 0.0013149421429261565.'

'Epoch 870. Train loss: 0.0012920440931338817. Test loss: 0.0012603556970134377.'

'Epoch 880. Train loss: 0.0012401780404616147. Test loss: 0.0012100657913833857.'

'Epoch 890. Train loss: 0.0011924076243303716. Test loss: 0.0011637257412075996.'

'Epoch 900. Train loss: 0.0011483918060548604. Test loss: 0.0011210308875888586.'

'Epoch 910. Train loss: 0.0011078261013608426. Test loss: 0.0010816920548677444.'

'Epoch 920. Train loss: 0.0010704267479013652. Test loss: 0.0010454390430822968.'

'Epoch 930. Train loss: 0.0010359418229199946. Test loss: 0.0010120092192664742.'

'Epoch 940. Train loss: 0.001004139136057347. Test loss: 0.0009811694035306573.'

'Epoch 950. Train loss: 0.0009747917065396905. Test loss: 0.0009527064394205809.'

'Epoch 960. Train loss: 0.0009476944687776268. Test loss: 0.000926427892409265.'

'Epoch 970. Train loss: 0.0009226581023540348. Test loss: 0.0009021559380926192.'

'Epoch 980. Train loss: 0.0008995089738164097. Test loss: 0.0008797339396551251.'

'Epoch 990. Train loss: 0.000878091057529673. Test loss: 0.0008590014185756445.'

In [75]:
def gradient_norm(model, dataloader, loss_fn, optimize):
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        y_pred = model(X)
        y_pred = y_pred.unsqueeze(1)
        loss = loss_fn(y_pred, y)
        loss.backward()
        
    grad_all = 0
    for params in model.parameters():
        grad = 0.0
        if params.grad is not None:
            grad = params.grad.detach().cpu().data.norm(2)
            grad_all += grad.item() ** 2
    grad_norm = grad_all ** 0.5
    return grad_norm

def minimum_ratio(model, dataloader, loss_fn, optimize):
    model.train()
    vals = 0
    length = 0
    for batch, (X, y) in enumerate(dataloader):
        y_pred = model(X)
        y_pred = y_pred.unsqueeze(1)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimize.step()
        weights=torch.Tensor().to(device)
        for name, params in model.named_parameters():
            if 'weight' in name:
                #print(name, len(params), params.shape)
                hess = torch.func.hessian(lambda x: x)(torch.nn.utils.parameters_to_vector(params).to(device)) 
                eigs_vals = torch.linalg.eigvalsh(hess)
                display(eigs_vals)
                vals += torch.sum(eigs_vals > 0).item()
                length+=len(eigs_vals)
    return vals/length

In [76]:
for epoch in range(5):
    train_loop(train_dataloader, cos_model1, loss_fn, optimizer1, epoch)
    norm = gradient_norm(cos_model1, train_dataloader, loss_fn, optimizer1)
    min_ratio = minimum_ratio(cos_model1, train_dataloader, loss_fn, optimizer1)
    val_loss = val_loss = val_loop(val_dataloader, cos_model1, loss_fn, epoch)
    display(f'{epoch}: gradient norm: {norm}. min_ratio: {min_ratio}. val_loss: {val_loss}')
    
    
    

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0')

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0')

'0: gradient norm: 0.04588260363377215. min_ratio: 0.0. val_loss: 0.0007256894605234265'

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

KeyboardInterrupt: 

In [60]:
torch.Tensor([1,2]).view(2).shape

torch.Size([2])