In [57]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision.models import resnet50, ResNet50_Weights
import torch.optim as optim


In [58]:
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = datasets.ImageFolder('stata_dataset', transform=transform)

train, val = torch.utils.data.random_split(dataset, [0.8, 0.2])
train_loader = torch.utils.data.DataLoader(train, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val, batch_size=32, shuffle=False)

In [81]:
class StataModel(torch.nn.Module):

    def __init__(self):
        super(StataModel, self).__init__()
        
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        resnet_layers = list(resnet.children())[:-1]
        self.resnet_layers = nn.Sequential(*resnet_layers)
        for param in self.resnet_layers.parameters():
            param.requires_grad = False
        
        self.linear = nn.Linear(2048, len(dataset.classes))

    def forward(self, x):
        x = self.resnet_layers(x)
        x = torch.squeeze(x)
        x = self.linear(x)
        return x
    

In [82]:
stata_model = StataModel()

In [83]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(stata_model.parameters())

In [86]:
def test_model(net):
    correct_pred = 0
    total_pred = 0

    with torch.no_grad():
        for data in val_loader:
            images, labels = data
            outputs = net(images)
            _, predictions = torch.max(outputs, 1)
            for label, prediction in zip(labels, predictions):
                if label == prediction:
                    correct_pred += 1
                total_pred += 1

    accuracy = 100 * float(correct_pred) / total_pred
    print(f'Test accuracy is {accuracy:.1f} %')
        
def train(net):
    for epoch in range(10):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = stata_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        
        print(f'{epoch}: loss: {running_loss}')
        test_model(net)

    print('Finished Training')

In [87]:
train(stata_model)

0: loss: 5.763929724693298
Test accuracy is 29.4 %
1: loss: 5.479592680931091
Test accuracy is 29.4 %
2: loss: 4.971322536468506
Test accuracy is 29.4 %
3: loss: 4.7355101108551025
Test accuracy is 29.4 %
4: loss: 4.441720604896545
Test accuracy is 29.4 %
5: loss: 4.331196427345276
Test accuracy is 29.4 %
6: loss: 4.044917821884155
Test accuracy is 29.4 %
7: loss: 3.644752860069275
Test accuracy is 41.2 %
8: loss: 3.487971544265747
Test accuracy is 41.2 %
9: loss: 3.054932177066803
Test accuracy is 41.2 %
Finished Training
