In [None]:
# Pin gdown to 3.6.0; its command-line tool is used below to fetch the checkpoint.
!pip install gdown==3.6.0

import torch
import gdown
from torchvision import models
from torch import nn
from collections import OrderedDict

my_file_id = "1aBjrK4kIYS_swJ8PF59bnYxoA7aoZ09k"
!gdown https://drive.google.com/uc?id={my_file_id} 
checkpoint_path = ('/home/workspace/checkpoint.pt')

def load_checkpoint(checkpoint_path):
    """Rebuild the VGG19-based classifier and restore its weights from a checkpoint.

    The classifier head defined here must match the one stored in the
    checkpoint, because ``load_state_dict`` is called with its default strict
    key matching.

    Args:
        checkpoint_path: Path to a file written by ``torch.save`` holding a dict
            with at least the keys ``'state_dict'`` (a state dict covering every
            model parameter) and ``'class_to_dict'`` (stored on the model as
            ``class_to_idx``).

    Returns:
        The restored model, moved to CUDA when available, otherwise the CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load tensors onto the CPU first so a checkpoint saved on a GPU can still be
    # opened on a CPU-only machine; the finished model is moved to `device` below.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # The ImageNet weights fetched here are replaced by the checkpoint's state
    # dict below (strict loading requires it to cover every parameter).
    model = models.vgg19(pretrained=True)

    # Freeze the existing layers; the classifier assigned afterwards is built
    # from fresh modules and therefore stays trainable.
    for param in model.parameters():
        param.requires_grad_(False)

    model.classifier = nn.Sequential(OrderedDict([
                                    ('fc1', nn.Linear(25088, 4096)),
                                    ('relu1', nn.ReLU()),
                                    ('dropout1', nn.Dropout(p=0.6)),
                                    ('fc2', nn.Linear(4096, 2048)),
                                    ('relu2', nn.ReLU()),
                                    ('dropout2', nn.Dropout(p=0.6)),
                                    ('output', nn.LogSoftmax(dim=1))
                                    ]))

    # NOTE(review): the checkpoint key is 'class_to_dict' while the attribute is
    # class_to_idx; the key must match whatever the training code saved.
    model.class_to_idx = checkpoint['class_to_dict']
    model.load_state_dict(checkpoint['state_dict'])

    # NOTE(review): the model is returned in training mode (Dropout active);
    # call model.eval() before running predictions.
    model.to(device)

    return model


# Load your model to this variable
model = load_checkpoint(checkpoint_path)

# Preprocessing settings (crop size and normalization mean/std). They were never
# used inside load_checkpoint; presumably later cells read them at module level.
image_size = 224
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]