In [None]:
import torchvision.models as models
import os
import wandb
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np


In [None]:
#a function that calculates the accuracy of the model on the test set
def test_accuracy(model, test_loader):
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        return 100 * correct / total

In [None]:
# Kept so any cell that still refers to `resnet18` keeps working. ResNet18 below
# builds its own backbone and no longer borrows this instance's layers.
resnet18 = models.resnet18(num_classes=10)


class ResNet18(nn.Module):
    """torchvision ResNet-18 trunk -> (512 -> embedding_dim) ReLU adapter -> linear head.

    Args:
        num_classes: number of output logits of the final `fc` head.
        embedding_dim: width of the adapter output. This is also the feature size
            returned by `forward(..., no_fc=True)`. Defaults to 128, as before.
    """

    def __init__(self, num_classes, embedding_dim=128):
        super(ResNet18, self).__init__()
        # Build a fresh torchvision model for every instance. Borrowing layers
        # from one shared module-level model made every ResNet18 alias (and
        # co-train) the same weights, so re-running the training cell resumed
        # from already-trained parameters instead of starting over.
        backbone = models.resnet18(num_classes=num_classes)

        # Attribute names are unchanged so previously saved state_dicts still load.
        self.adapter = nn.Sequential(
            nn.Linear(backbone.fc.in_features, embedding_dim), nn.ReLU()
        )

        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4
        self.avgpool = backbone.avgpool
        self.fc = nn.Sequential(nn.Linear(embedding_dim, num_classes))

    def forward(self, x, no_fc=False):
        """Run the network on a batch of images.

        Args:
            x: image batch. Presumably (N, 3, H, W) as used with (3, 32, 32)
                elsewhere in this notebook; confirm for other inputs.
            no_fc: if True, return the (N, embedding_dim) adapter features
                instead of the (N, num_classes) logits.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.adapter(x)
        if no_fc:
            return x
        return self.fc(x)



In [None]:
# Output folder for sample images saved by later cells. exist_ok avoids the
# check-then-create race and makes re-running this cell a no-op.
os.makedirs('./mlp_img', exist_ok=True)

def to_img(x, channels=1, size=32):
    """Map a batch from [-1, 1] back to [0, 1] and reshape it to (N, channels, size, size).

    This inverts `Normalize([0.5], [0.5])` (the `img_transform` scaling), not
    the CIFAR-10 mean/std normalization used for the loaders.

    Args:
        x: tensor whose first dimension is the batch and which holds
            channels * size * size values per sample.
        channels: number of image channels (1 by default, as before; 3 for RGB).
        size: image height and width.
    """
    x = (0.5 * (x + 1)).clamp(0, 1)
    return x.reshape(x.size(0), channels, size, size)


# Hyper-parameters shared by both training loops in this notebook.
num_epochs = 10
batch_size = 512
learning_rate = 1e-3

# Legacy single-channel transform mapping pixels to [-1, 1]; not used by the CIFAR-10 loaders.
img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

# Per-channel CIFAR-10 normalization applied to both splits.
# NOTE(review): the std triple differs from the often-quoted (0.2470, 0.2435, 0.2616);
# confirm before changing, because saved checkpoints were trained with these values.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Augment the training split only: random 32x32 crops of a 4-pixel-padded image plus horizontal flips.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

dataset_train = CIFAR10('./data', train=True, transform=train_transform, download=True)
dataset_test = CIFAR10('./data', train=False, transform=test_transform, download=True)

loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(dataset_train, shuffle=True, **loader_kwargs)
test_loader = DataLoader(dataset_test, shuffle=False, **loader_kwargs)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


In [None]:
from time import sleep

# Train the supervised ResNet18 on CIFAR-10 and record test accuracy after each epoch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
modelResnet = ResNet18(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(modelResnet.parameters(), lr=learning_rate)
test_accuracy_list_before_freeze = []
for epoch in range(num_epochs):
    # (Re-)enter training mode every epoch so that no mode change made during
    # evaluation leaks into training (BatchNorm and Dropout depend on it).
    modelResnet.train()
    # Accumulate on-device to avoid a host sync on every batch.
    running_loss = torch.zeros((), device=device)
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = modelResnet(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.detach()
    print("Epoch: {}/{}  train loss: {:.4f}".format(
        epoch, num_epochs - 1, running_loss.item() / len(train_loader)))
    test_accuracy_list_before_freeze.append(test_accuracy(modelResnet, test_loader))
    print("Test accuracy: {}".format(test_accuracy_list_before_freeze[-1]))

# Test accuracy per epoch (epochs numbered from 1 on the x axis).
epochs = np.arange(1, num_epochs + 1)
accuracy = test_accuracy_list_before_freeze
plt.plot(epochs, accuracy, label='accuracy')
plt.xlabel('epoch')
plt.ylabel('test accuracy (%)')
plt.title('Supervised ResNet18 on CIFAR-10')
plt.legend()
plt.show()

In [None]:
import torch
from torchvision import models
from torchsummary import summary
import torchvision.models as models

# Layer-by-layer overview of the trained model (compare with torchvision's resnet18).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torchsummary pushes a random batch through the model. In train mode that would
# overwrite part of the trained BatchNorm running statistics, so summarize in
# eval mode and restore the previous mode afterwards.
was_training = modelResnet.training
modelResnet.eval()
try:
    summary(modelResnet, (3, 32, 32), device=device.type)
finally:
    modelResnet.train(was_training)

In [None]:

# Final evaluation of the supervised model, then save its weights and one sample test image.
os.makedirs('./mlp_img', exist_ok=True)

# Eval mode: BatchNorm uses its running statistics and Dropout is disabled.
was_training = modelResnet.training
modelResnet.eval()
try:
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = modelResnet(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
finally:
    modelResnet.train(was_training)

print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
torch.save(modelResnet.state_dict(), './resnet18.pth')

# Undo the test_transform Normalize so the PNG shows real colours. Without this,
# save_image clips the normalized values to [0, 1].
cifar_mean = torch.tensor((0.4914, 0.4822, 0.4465)).view(3, 1, 1)
cifar_std = torch.tensor((0.2023, 0.1994, 0.2010)).view(3, 1, 1)
img = (images[0].cpu() * cifar_std + cifar_mean).view(1, 3, 32, 32)
# num_epochs - 1 is the last training epoch; this avoids relying on the leaked loop variable.
save_image(img, './mlp_img/image_{}.png'.format(num_epochs - 1))

In [None]:

from linear_classifier import LinearClassifier
class JoinedModel(nn.Module):
    """Trained ResNet18 feature extractor followed by a `LinearClassifier` head.

    Args:
        num_classes: accepted for API compatibility but not used;
            `LinearClassifier()` is constructed with its own defaults.
        backbone: module called as `backbone(x, no_fc=True)` that returns
            features. Defaults to the notebook's trained `modelResnet`, which is
            shared rather than copied, so freezing or training it here affects
            that object too.
    """

    def __init__(self, num_classes=10, backbone=None):
        super(JoinedModel, self).__init__()
        self.resnet = modelResnet if backbone is None else backbone
        self.classifier = LinearClassifier()

    def forward(self, x):
        # Call the module rather than `.forward` so registered hooks (for example
        # the ones torchsummary installs) fire for the backbone as well.
        features = self.resnet(x, no_fc=True)
        return self.classifier(features)


joined_model = JoinedModel().to(device)

In [None]:
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary(joined_model, (3, 32, 32))

In [None]:
# print(joined_model)

In [None]:
# Freeze every parameter (backbone and head), then re-enable gradients for the
# classifier's two linear layers so only the head is optimized.
# Note: requires_grad=False does not stop BatchNorm running statistics from
# updating while a module is in train mode.
joined_model.requires_grad_(False)
for head_layer in (joined_model.classifier.fc1, joined_model.classifier.fc2):
    head_layer.weight.requires_grad = True
    head_layer.bias.requires_grad = True


In [None]:
# List each parameter with its requires_grad flag to confirm only the head is trainable.
for param_name, tensor in joined_model.named_parameters():
    print(param_name, tensor.requires_grad)


In [None]:
# Loss and optimizer for the linear probe; only parameters left trainable by the freeze step are optimized.
criterion = nn.CrossEntropyLoss()
trainable_params = [p for p in joined_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)


In [None]:


#a function that plots the accuracy of the model over the epochs
# def plot_accuracy(model, train_loader, test_loader):
#     # Plot the accuracy
#     train_acc = []
#     test_acc = []
#     for epoch in range(num_epochs):
#         print("Epoch: {}".format(epoch))
#         for i, (images, labels) in enumerate(train_loader):
#             print(i)
#             images = images.to(device)
#             labels = labels.to(device)

#             outputs = model(images)
#             loss = criterion(outputs, labels)

#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
#             print("Epoch: {}/{}".format(epoch, num_epochs - 1))
#         train_acc.append(test_accuracy(model, train_loader))
#         test_acc.append(test_accuracy(model, test_loader))

#     plt.plot(train_acc, label='train')
#     plt.plot(test_acc, label='test')
#     plt.legend()
#     plt.show()


In [None]:

# Linear probe: train only the classifier head on top of the frozen ResNet features.
test_accuracy_list = []
for epoch in range(num_epochs):
    # requires_grad=False does not stop BatchNorm from updating its running
    # statistics in train mode, so the frozen backbone is kept in eval mode while
    # the head trains. This is re-applied every epoch because evaluation may
    # change module modes.
    joined_model.train()
    joined_model.resnet.eval()
    running_loss = torch.zeros((), device=device)
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = joined_model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.detach()
    print("Epoch: {}/{}  train loss: {:.4f}".format(
        epoch, num_epochs - 1, running_loss.item() / len(train_loader)))
    test_accuracy_list.append(test_accuracy(joined_model, test_loader))
    print("Test accuracy: {}".format(test_accuracy_list[-1]))

# Test accuracy per epoch (epochs numbered from 1 on the x axis).
epochs = np.arange(1, num_epochs + 1)
accuracy = test_accuracy_list
plt.plot(epochs, accuracy, label='accuracy')
plt.xlabel('epoch')
plt.ylabel('test accuracy (%)')
plt.title('Linear probe on frozen supervised ResNet18')
plt.legend()
plt.show()









    
    

In [None]:
# Evaluate the linear-probe model, then save its weights and one sample test image.
# The image directory is created here; previously it never existed, so save_image failed.
os.makedirs('./mlp_img_resnet', exist_ok=True)

was_training = joined_model.training
joined_model.eval()
try:
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = joined_model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
finally:
    joined_model.train(was_training)

print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
torch.save(joined_model.state_dict(), './joined_model.pth')

# Undo the test_transform Normalize so the PNG shows real colours instead of clipped values.
cifar_mean = torch.tensor((0.4914, 0.4822, 0.4465)).view(3, 1, 1)
cifar_std = torch.tensor((0.2023, 0.1994, 0.2010)).view(3, 1, 1)
img = (images[0].cpu() * cifar_std + cifar_mean).view(1, 3, 32, 32)
# num_epochs - 1 is the last training epoch; this avoids relying on the leaked loop variable.
save_image(img, './mlp_img_resnet/image_{}.png'.format(num_epochs - 1))


# make a graph of the accuracy vs epoch




In [None]:
# Save the linear-probe model (supervised ResNet encoder + classifier head).
# Create the target folder first; torch.save does not create missing directories.
os.makedirs('./saved_models', exist_ok=True)
torch.save(joined_model.state_dict(), './saved_models/joined_model_resnet_supervised.pth')
