In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

In [25]:
class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, pred, true):
        loss = (pred - true)**2
        return loss.mean()
    
        

In [26]:
class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(2, 8),
            nn.ReLU(),
            nn.Linear(8, 1)
        )
        
    def forward(self, x):
        return self.network(x)

In [27]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, X, y):
        super(CustomDataset, self).__init__()
        self.X = X
        self.y = y
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return X[idx, :], y[idx]

In [28]:
X = torch.randn(1000, 2)
y = 3 * torch.mul(X[:, 0], X[:, 1]).unsqueeze(1) + torch.randn(1000, 1)*3
print(X.shape, y.shape)
dataset = CustomDataset(X, y)
print(dataset)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=True)
print(len(next(iter(dataloader))))

torch.Size([1000, 2]) torch.Size([1000, 1])
<__main__.CustomDataset object at 0x131cbefe0>
2


In [29]:
model = CustomModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = CustomLoss()

epochs = 1000
for epoch in range(epochs):
    for X_train, y_train in dataloader:
        pred = model(X_train)
        loss = criterion(pred, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    if epoch % 100 == 0:
        print(f"Loss: {loss.item()}")

Loss: 20.795549392700195
Loss: 11.477770805358887
Loss: 10.886364936828613
Loss: 8.334315299987793
Loss: 8.210415840148926
Loss: 7.873299598693848
Loss: 9.052928924560547
Loss: 8.413309097290039
Loss: 8.281631469726562
Loss: 9.855108261108398


In [30]:
#Explore state_dict
for param in model.state_dict():
    print(f"name: {param} and size: {model.state_dict()[param].size()}")

name: network.0.weight and size: torch.Size([8, 2])
name: network.0.bias and size: torch.Size([8])
name: network.2.weight and size: torch.Size([1, 8])
name: network.2.bias and size: torch.Size([1])


In [31]:
torch.save(model.state_dict(), "model.pth")

In [32]:
model1 = CustomModel()
loaded_model = model1.load_state_dict(torch.load("model.pth"))

  loaded_model = model1.load_state_dict(torch.load("model.pth"))


In [35]:
# Verify the model works after loading
X_test = torch.tensor([[0.5, 1.0]])
with torch.no_grad():
    pred = model1(X_test)
    print(f"Predictions after loading: {pred}")

Predictions after loading: tensor([[2.4699]])
