In [None]:
import os

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from models import TorusClassifier, resnet18, resnet50
from utils import train_model, visualize_loss_landscape, generate_random_direction

# Experiment tag; reused below to build the weights and figure file names.
setting = 'resnet18_adam2'
# Render plotly figures with the VS Code notebook renderer.
pio.renderers.default = "vscode"

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the MNIST dataset; the only preprocessing is conversion to tensors.
transform = transforms.Compose([transforms.ToTensor()])

# Arguments shared by the train and test splits / loaders.
_mnist_kwargs = dict(root='./data', download=True, transform=transform)
_loader_kwargs = dict(batch_size=64, num_workers=2)

# Training split: reshuffled every epoch.
trainset = torchvision.datasets.MNIST(train=True, **_mnist_kwargs)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_kwargs)

# Test split: kept in a fixed order.
testset = torchvision.datasets.MNIST(train=False, **_mnist_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)

# Loss landscape parameters

In [None]:
# Settings forwarded to visualize_loss_landscape in the final cell.
# NOTE(review): the meanings below are inferred from the names only;
# confirm against utils.visualize_loss_landscape.
grid_size = 31          # presumably the number of grid points per direction
grid_range = (-1, 1)    # presumably the (min, max) step along each direction
maximum_loss = 5        # presumably an upper bound applied to plotted loss values
percentage = 0.5        # presumably a fraction (e.g. of the data) - TODO confirm

# Set model 

In [None]:
# Checkpoint location, derived from the experiment tag.
model_path = f'./saved_weights/{setting}.pth'

# Define the model, criterion and optimizer.
# NOTE(review): `setting` is 'resnet18_adam2' but a resnet50 is built here.
# If the checkpoint at `model_path` was saved from a resnet18, loading it into
# this model will fail on mismatched keys - confirm which one is intended.
model = resnet50(num_classes=10)
model = model.to(device)

# reduction='sum' adds up the per-sample losses of a batch instead of averaging.
criterion = nn.CrossEntropyLoss(reduction='sum')

optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

# Load model

In [None]:
# Load the saved state dictionary from disk.
# map_location remaps every stored tensor onto the device selected in the
# first cell, so a checkpoint written from a CUDA model can still be loaded
# when the notebook falls back to the CPU (and vice versa).
loaded_state_dict = torch.load(model_path, map_location=device)
# `model` must have the same architecture as the one the checkpoint was saved
# from; load_state_dict raises on missing or unexpected keys by default.
model.load_state_dict(loaded_state_dict)

# Train/save model

In [None]:
# Optional: uncomment the two statements below to train the model for a few
# epochs and save its state dictionary to `model_path` (e.g. when no saved
# weights exist yet).
# model, total_loss, accuracy = train_model(model, criterion, optimizer, trainloader, 30, print_every=10)
# torch.save(model.state_dict(), model_path)

# Visualize/save loss landscape

In [None]:
# Output file for the interactive figure, derived from the experiment tag.
picture_path = './pics/' + setting + '.html'
# Create the output directory before computing the landscape: write_html does
# not create missing parent directories, and failing only at the final step
# would throw away the computed figure.
os.makedirs(os.path.dirname(picture_path), exist_ok=True)

# Two independently drawn random directions span the 2-D slice that is plotted.
# NOTE(review): no random seed is set in this notebook, so the directions (and
# the resulting figure) will differ between runs - confirm that is acceptable.
direction1 = generate_random_direction(model)
direction2 = generate_random_direction(model)

# NOTE(review): confirm visualize_loss_landscape puts the model in eval mode;
# nothing in this notebook calls model.eval() before evaluating on testloader.
fig = visualize_loss_landscape(model, criterion, testloader, direction1, direction2, grid_size, grid_range, maximum_loss, percentage)
fig.update_layout(width=2000, height=1000)  # Adjust these values as needed
fig.write_html(picture_path)