In [3]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [4]:
# What is a state_dict? The next cell prints one for a model and one for its optimizer.

In [6]:
# Define model
class TheModelClass(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    The ``16 * 5 * 5`` input size of ``fc1`` fixes the expected image size:
    with 5x5 convolutions (no padding) and two 2x2 max-pools, a 3x32x32 input
    shrinks 32 -> 28 -> 14 -> 10 -> 5, giving 16x5x5 = 400 features.
    Inputs are therefore expected to be shaped (N, 3, 32, 32).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)    # 3 -> 6 channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)     # 2x2 max-pool, stride 2 (reused twice)
        self.conv2 = nn.Conv2d(6, 16, 5)   # 6 -> 16 channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)       # 10 output scores

    def forward(self, x):
        """Return un-normalized scores of shape (N, 10) for a batch ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike view(-1, 400),
        # a wrongly sized image now fails loudly in fc1 instead of being
        # silently split into extra "samples", and non-contiguous input works.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict: each parameter name with its tensor shape.
# state_dict() rebuilds the dict on every call, so take it once and iterate
# its items rather than calling it twice per key inside the loop.
print("Model's state_dict:")
for param_tensor, value in model.state_dict().items():
    print(param_tensor, "\t", value.size())

# Print optimizer's state_dict. `state` prints as {} here (no optimizer.step()
# has been taken yet); `param_groups` lists the hyperparameters and the
# indices of the parameters each group manages.
print("Optimizer's state_dict:")
for var_name, value in optimizer.state_dict().items():
    print(var_name, "\t", value)


Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'fused': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]


In [8]:
# Reference snippet only: this cell is a bare string literal, so nothing below
# is executed -- the cell just echoes the text back as its output.
# PATH, args and kwargs are placeholders (not defined anywhere above).
# Pattern shown: save only the model's state_dict, then re-create the model
# class, load the saved weights into it, and switch it to eval mode.
'''
torch.save(model.state_dict(), PATH)

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, weights_only=True))
model.eval()

'''


'\ntorch.save(model.state_dict(), PATH)\n\nmodel = TheModelClass(*args, **kwargs)\nmodel.load_state_dict(torch.load(PATH, weights_only=True))\nmodel.eval()\n\n'

In [9]:
# Reference snippet only: this cell is a bare string literal, so the TorchScript
# code below is not executed; the cell just echoes the text as its output.
# It shows compiling `model` with torch.jit.script, saving the result to
# 'model_scripted.pt', then loading it back with torch.jit.load and calling eval().
'''
-export/load model in torchscript format

#export 
model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save('model_scripted.pt') # Save

#load
model = torch.jit.load('model_scripted.pt')
model.eval()


'''


"\n-export/load model in torchscript format\n\n#export \nmodel_scripted = torch.jit.script(model) # Export to TorchScript\nmodel_scripted.save('model_scripted.pt') # Save\n\n#load\nmodel = torch.jit.load('model_scripted.pt')\nmodel.eval()\n\n\n"

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleModel(nn.Module):
    """Minimal regression model: one linear layer from 2 features to 1 output."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(in_features=2, out_features=1)

    def forward(self, x):
        """Apply the single linear layer to ``x`` (last dimension of size 2)."""
        projected = self.fc(x)
        return projected

# Fix the RNG so weight init, the random inputs/targets and therefore the saved
# checkpoints are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Initialize model and optimizer
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()  # Example loss function

# Simulate training for 100 epochs, checkpointing every CHECKPOINT_EVERY epochs
num_epochs = 100
CHECKPOINT_EVERY = 10
for epoch in range(1, num_epochs + 1):  # Start from 1 for better readability
    # Simulated training step on one random (input, target) pair. Input and
    # target are independent noise, so this only demonstrates the mechanics.
    optimizer.zero_grad()
    output = model(torch.randn(1, 2))  # Random input
    loss = loss_fn(output, torch.randn(1, 1))  # Random target
    loss.backward()
    optimizer.step()

    # Save model + optimizer state, the epoch reached, and the last step's loss
    # (as a plain Python float) into one checkpoint file per interval.
    if epoch % CHECKPOINT_EVERY == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item()
        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')
        print(f"Checkpoint saved at epoch {epoch}")

print("Training completed!")


Checkpoint saved at epoch 10
Checkpoint saved at epoch 20
Checkpoint saved at epoch 30
Checkpoint saved at epoch 40
Checkpoint saved at epoch 50
Checkpoint saved at epoch 60
Checkpoint saved at epoch 70
Checkpoint saved at epoch 80
Checkpoint saved at epoch 90
Checkpoint saved at epoch 100
Training completed!
