## Import Libraries

In [1]:
import os

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torchvision.models import ResNet50_Weights

import models.resnet as resnet
from kan import KAN
from trainer import Trainer

In [2]:
# Shared configuration dictionary; it is filled in further down and handed to the Trainer.
args_dict = {}

# Per-channel statistics used to normalise the inputs. The frozen ResNet-50
# backbone in this notebook is loaded with "IMAGENET1K_V2" weights, and
# torchvision documents those weights as trained on inputs normalised with the
# ImageNet mean/std below (assuming models.resnet wraps the torchvision
# weights -- TODO confirm). Feeding it inputs scaled with 0.5/0.5 instead
# shifts the input distribution the pretrained filters were fitted to.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Upscale the CIFAR-10 images, convert to tensors, then normalise per channel.
transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.ToTensor(),
     transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

batch_size = 16

# Training split: shuffled every epoch.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

# Test split: fixed order so evaluation is repeatable.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)


# Human-readable label names, one per class.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')




Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Pretrained ResNet-50 used purely as a fixed feature extractor: swapping the
# final fully-connected layer for Identity makes the model return the features
# that would otherwise be fed to the classification layer.
resnet_50 = resnet.resnet50_v2(weights="IMAGENET1K_V2")
resnet_50.fc = nn.Identity()

# Freezing the layers of the resnet model
for param in resnet_50.parameters():
    param.requires_grad = False

# requires_grad=False only stops gradient updates; a module is still in
# training mode after construction, so BatchNorm layers would normalise with
# per-batch statistics and keep updating their running averages. Switch to
# eval mode so the frozen backbone produces deterministic features.
# NOTE(review): if Trainer calls .train() on the feature extractor this is
# overridden there -- confirm against trainer.py.
resnet_50.eval()


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> activation -> Linear.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of the single hidden layer.
        num_classes: Number of output values produced per input.
        activation: Callable applied to the hidden layer's output.
    """

    def __init__(self, input_size, hidden_size, num_classes, activation):
        super().__init__()
        # Layers are created in input-to-output order.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=num_classes)
        self.activation = activation

    def forward(self, x):
        """Return the output of the second linear layer for input ``x``."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)
    


# ---- Device ----------------------------------------------------------------
# Resolved first so the KAN head is built on the same device that the models
# are moved to below; a hard-coded 'cuda' breaks on CPU-only machines.
args_dict['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---- Models ----------------------------------------------------------------
# Classifier head with layer widths 2048 -> 512 -> 10.
KAN_model = KAN(width=[2048, 512, 10], grid=3, k=3, device=args_dict['device'])

# Alternative head with the same layer widths:
#model1 = MLP(2048, 512, 10, nn.LeakyReLU())
model1 = KAN_model
feature_extractor = resnet_50

args_dict['kan'] = True
args_dict['opt'] = 'Adam'
loss_function = nn.CrossEntropyLoss()

# ---- Hyper-parameters --------------------------------------------------------
# NOTE(review): lr = 1.0 is far above Adam's documented default of 1e-3 --
# confirm this value is intentional.
lr = 1.0
args_dict['epochs'] = 15

# ---- Data --------------------------------------------------------------------
args_dict['dataset_name'] = 'CIFAR10'
args_dict['trainloader'] = trainloader
args_dict['testloader'] = testloader

args_dict['model_name'] = 'KAN_Model' #'ResNet50_v2'
args_dict['model'] = model1.to(args_dict['device'])

args_dict['feature_extractor'] = feature_extractor.to(args_dict['device'])

# ---- Optimiser and schedule ----------------------------------------------------
# Only the head's parameters are handed to the optimiser.
#args_dict['optimizer'] = optim.SGD(model1.parameters(), lr = lr, momentum=0.9)
args_dict['optimizer'] = optim.Adam(model1.parameters(), lr = lr)

# Multiply the learning rate by 0.1 every 5 scheduler steps.
#args_dict['scheduler'] = optim.lr_scheduler.CyclicLR(args_dict['optimizer'], base_lr=0.00005, max_lr=0.00001, step_size_up=3, mode='exp_range', gamma=0.5)
args_dict['scheduler'] = optim.lr_scheduler.StepLR(args_dict['optimizer'], step_size=5, gamma=0.1)

args_dict['loss_function'] = loss_function

# ---- Output locations ----------------------------------------------------------
# Paths are built with os.path.join so they resolve inside the same directory
# on every OS; a literal backslash is only a separator on Windows.
args_dict['weights_save_path'] = 'saved_models'
os.makedirs(args_dict['weights_save_path'], exist_ok=True)
args_dict['record_save_path'] = os.path.join(args_dict['weights_save_path'], 'initial_trainings.txt')


trainer_obj = Trainer(args_dict)

## calculate number of parameters
print('Model:', args_dict['model_name'], ' No. of Parameters:', trainer_obj.calculate_no_of_parameters(args_dict['model']))

trainer_obj.train_models()

Model: KAN_Model  No. of Parameters: 12644874


description:  70%|███████████████████████████████▍             | 7/10 [3:11:10<1:21:56, 1638.68s/it]
  0%|          | 0/3125 [3:11:10<?, ?it/s]


KeyboardInterrupt: 